#!/usr/bin/env python3
import os
import sys
import glob


class TSVJoiner:
    """Join several tab-separated files into one table keyed by row ID.

    The first column of every input file is used as the row ID; the
    remaining header fields are column names. Rows with the same ID coming
    from different files (or repeated within one file) are merged, and when
    the same (ID, column) pair is seen more than once the values are
    concatenated with a comma.

    Attributes:
        header: set of all column names seen so far (ID column excluded).
        data: dict mapping row ID -> {column name: value string}.
        is_empty: True until at least one file with a header line is read.
    """

    def __init__(self, tsv_file_path=""):
        """Create a joiner, optionally pre-loading ``tsv_file_path``.

        The path is only read when it exists; the default empty string
        therefore creates an empty joiner.
        """
        self.header = set()
        self.data = {}
        self.is_empty = True
        if os.path.exists(tsv_file_path):
            self.read_tsv(tsv_file_path)

    def _join_header(self, header):
        """Add every column name in ``header`` to the known column set."""
        self.header.update(header)

    def read_tsv(self, tsv_file_path):
        """Read one TSV file and merge its rows into ``self.data``.

        The first line is the header; its first field (the ID column name)
        is ignored. A completely empty file is skipped without marking the
        joiner as non-empty. Blank lines are skipped. Rows shorter than the
        header contribute only the fields they actually have; fields beyond
        the header length are ignored.

        Args:
            tsv_file_path: path of the TSV file to read.

        Raises:
            OSError: if the file cannot be opened.
        """
        # The context manager guarantees the handle is closed even when a
        # malformed line raises part-way through the file.
        with open(tsv_file_path, "r") as reader:
            first_line = next(reader, None)
            if first_line is None:
                # Empty file: nothing to merge, not even column names.
                return
            columns = first_line.rstrip("\r\n").split("\t")[1:]
            self.is_empty = False
            self._join_header(columns)

            for raw_line in reader:
                stripped = raw_line.rstrip("\r\n")
                if not stripped:
                    # Tolerate blank lines (commonly a trailing one).
                    continue
                fields = stripped.split("\t")
                row = self.data.setdefault(fields[0], {})
                # zip() stops at the shorter sequence, so short rows do not
                # raise and surplus fields are dropped.
                for column, value in zip(columns, fields[1:]):
                    if column in row:
                        row[column] = row[column] + "," + value
                    else:
                        row[column] = value

    def write_tsv(self, W):
        """Write the merged table to ``W`` and close it.

        The output starts with a header line ``ID<TAB>col1<TAB>...`` with
        columns in sorted order so that the result is reproducible across
        runs. Cells with no value are written as the literal string
        ``None``. ``W`` is always closed before returning, including when
        an error is raised.

        Args:
            W: writable text file object; ownership passes to this method.

        Raises:
            AssertionError: if no file has been read yet.
        """
        try:
            # Raised explicitly (rather than via `assert`) so the check
            # survives `python -O` while keeping the same exception type.
            if self.is_empty:
                raise AssertionError("No data to write!")
            header_list = sorted(self.header)
            W.write("\t".join(["ID"] + header_list) + "\n")
            for _id, row in self.data.items():
                cells = [row.get(head, "None") for head in header_list]
                W.write("\t".join([_id] + cells) + "\n")
        finally:
            W.close()


if __name__ == "__main__":
    # Command-line entry point: join every TSV matched by the given
    # paths/glob patterns and write the merged table to the -o file.
    OUT_STR = "Usage: tsv-join -o outfile <file-path(s)>"

    # Explicit checks instead of `assert`: assertions are removed under
    # `python -O`, and sys.exit prints the usage without a traceback.
    if len(sys.argv) < 4 or sys.argv[1] != "-o":
        sys.exit(OUT_STR)

    file_list = []

    for arg in sys.argv[3:]:
        # glob order depends on the filesystem; sort each pattern's matches
        # so the merged value order is reproducible.
        matches = sorted(glob.glob(arg))
        if not matches:
            sys.stderr.write("Warning: no files match '{}'\n".format(arg))
        for _file in matches:
            file_list.append(os.path.abspath(_file))

    if not file_list:
        sys.exit("No input files found.\n" + OUT_STR)

    tsv = TSVJoiner()
    for _file in file_list:
        tsv.read_tsv(_file)

    # Check before opening the output so a failed run does not create or
    # truncate the output file.
    if tsv.is_empty:
        sys.exit("No data to write!")

    tsv.write_tsv(open(os.path.abspath(sys.argv[2]), "w"))
