#!/usr/bin/env python
"""
Main Catactor classes and functions for meta scATAC-seq analyses.
"""

__author__ = "Kawaguchi RK"
__copyright__ = "Copyright 2019, ent Project"
__credits__ = ["Kawaguchi RK"]
__license__ = "MIT"
__version__ = "0.1"
__maintainer__ = "Kawaguchi RK"
__email__ = "rkawaguc@cshl.edu"
__status__ = "Development"



from Catactor_base import *
from learn_alternate import *
from collections import Counter
from sklearn.model_selection import StratifiedKFold
import tempfile
import datetime

# Upper bound on the number of genes kept when genes are selected through
# the 'cov' column of adata.var (see normalize_features_for_cls and
# CatactorLassoAlt.extract_features).
TOP_GENES_MAX = 5000

def normalize_features_for_cls(adata, rank_flag, marker_list, top_markers):
    """Normalize every column of ``adata.X`` and store the result in ``adata.X``.

    When the matrix has more than 50000 rows and more than 30000 columns,
    the columns are first restricted to the ``TOP_GENES_MAX`` genes with the
    largest ``adata.var['cov']`` plus the first ``top_markers`` entries of
    every list in ``marker_list`` (only genes present in ``adata.var.index``).

    Parameters:
        adata: object exposing ``X`` (dense or scipy sparse matrix) and, for
            the large-matrix path, ``var`` and column slicing.
        rank_flag: if true, each column is replaced by its average rank
            divided by the number of rows; otherwise each column is min-max
            scaled with sklearn's ``MinMaxScaler``.
        marker_list: mapping from a marker-set name to a list of gene names.
        top_markers: number of leading genes taken from each marker list.

    Returns:
        The (possibly column-subsetted) ``adata`` with the normalized matrix.
    """
    if adata.X.shape[0] > 50000 and adata.X.shape[1] > 30000:  # too large datasets
        # argsort() returns the sorting permutation, not ranks, so the top
        # genes are the first TOP_GENES_MAX entries of that permutation.
        order = np.argsort(-adata.var.loc[:, 'cov'].values, kind='stable')
        all_genes = adata.var.index[order[:TOP_GENES_MAX]].tolist()
        for key in marker_list:
            all_genes.extend(marker_list[key][:top_markers])
        known_genes = set(adata.var.index)
        all_genes = [x for x in sorted(set(all_genes)) if x in known_genes]
        adata = adata[:, np.array(all_genes)]
    X = adata.X
    if scipy.sparse.issparse(X):
        X = X.todense()
    # np.matrix / array-likes are turned into a plain 2-D ndarray.
    X = np.asarray(X)
    if rank_flag:
        n_rows = X.shape[0]
        # The ranked matrix must replace X; it used to be computed and dropped.
        X = np.apply_along_axis(lambda x: rankdata(x, 'average') / n_rows, 0, X)
    else:
        from sklearn.preprocessing import MinMaxScaler
        # MinMaxScaler scales each column independently, so one call on the
        # 2-D matrix equals the per-column loop and keeps the 2-D shape even
        # when there is a single row or column.
        X = MinMaxScaler().fit_transform(X)
    adata.X = X
    return adata

def fill_prediction_vec(vec, total_length, index):
    """Scatter ``vec`` into a NaN-filled Series of length ``total_length``.

    The i-th value of ``vec`` is written at position ``index[i]``; every
    other position stays NaN.
    """
    positions = np.array(index)
    values = list(vec)
    filled = pd.Series([np.nan] * total_length)
    filled.iloc[positions] = values
    return filled

def fill_prediction(Y_pred, Y_order, total_length, index):
    """Expand predictions and scores to ``total_length`` with ``fill_prediction_vec``.

    Returns the pair ``(filled_predictions, filled_scores)``.
    """
    filled_pred = fill_prediction_vec(Y_pred, total_length, index)
    filled_order = fill_prediction_vec(Y_order, total_length, index)
    return filled_pred, filled_order

def preprocess_feature_matrix(X, y_true):
    """Prepare a feature matrix (and optional labels) for prediction/evaluation.

    Without labels, ``X`` is returned untouched together with its row count
    and the identity index ``0..n-1``.  With labels, unannotated rows are
    dropped through ``remove_nan_data`` and one boolean membership vector is
    built per remaining label.

    Returns:
        ``(X, y_true, Y_true, total_length, index)`` where ``Y_true`` maps each
        label to its boolean vector (``None`` when no labels were given) and
        ``total_length`` is the length before filtering.
    """
    if y_true is None:
        n_rows = X.shape[0]
        return X, None, None, n_rows, np.arange(n_rows)
    total_length = len(y_true)
    X, y_true, index = remove_nan_data(X, y_true)
    n_kept = len(y_true)
    Y_true = {}
    for label in set(y_true):
        Y_true[label] = np.array([cl == label for cl in y_true]).reshape((n_kept))
    return X, y_true, Y_true, total_length, index


def norm_row_columns(X):
    """Min-max scale every column of ``X`` to [0, 1], then every row.

    ``MinMaxScaler`` normalizes each column of a 2-D array independently, so
    the column pass is a single ``fit_transform`` and the row pass is the
    same call on the transpose.  Unlike the former per-vector
    ``apply_along_axis`` + ``np.squeeze`` implementation, the 2-D shape is
    preserved for inputs with a single row or a single column.

    Parameters:
        X: 2-D array-like of numbers.

    Returns:
        ndarray with the same shape as ``X``.
    """
    from sklearn.preprocessing import MinMaxScaler
    X = np.array(X)
    X = MinMaxScaler().fit_transform(X)        # column-wise scaling
    X = MinMaxScaler().fit_transform(X.T).T    # row-wise scaling
    return X

def get_max_indices(Y_pred):
    """Mark, for each column, the row holding the unique positive maximum.

    Entry ``(i, j)`` of the indicator is 1 when ``Y_pred[i, j]`` equals the
    column maximum, is strictly positive, and no other row reaches that
    maximum (ties give an all-zero column).  The computation is vectorized,
    which also makes single-row input work (the previous loop called
    ``max()`` on an empty selection in that case).

    Parameters:
        Y_pred: 2-D array of scores (rows are candidates, columns are samples).

    Returns:
        ``(indicator, max_y)``: an integer 0/1 array shaped like ``Y_pred``
        and the vector of column maxima.
    """
    Y_pred = np.asarray(Y_pred)
    max_y = Y_pred.max(axis=0)
    is_max = Y_pred == max_y
    # A column has a valid winner only if exactly one row attains the maximum.
    unique_max = is_max.sum(axis=0) == 1
    indicator = is_max & (Y_pred > 0) & unique_max
    return indicator.astype(int), max_y

def comp_auroc(y_pred, y_true, y_order, roc_out):
    """Compute AUROC, accuracy, precision and recall for a binary task.

    ``y_order`` supplies the scores for the ROC curve, ``y_pred`` the hard
    predictions for the other metrics.  When ``roc_out`` is non-empty the
    ``[fpr, tpr]`` pair is pickled to that path and a summary is printed.

    Returns:
        ``(auc, acc, precision, recall)``.
    """
    fpr, tpr, _ = metrics.roc_curve(y_true, y_order, pos_label=1)
    auc = metrics.auc(fpr, tpr)
    acc = metrics.accuracy_score(y_true, y_pred)
    precision = metrics.precision_score(y_true, y_pred)
    recall = metrics.recall_score(y_true, y_pred)
    scores = (auc, acc, precision, recall)
    if len(roc_out) == 0:
        return scores
    print('Computing AUROC', Counter(y_true), Counter(y_pred), auc, acc, precision, recall, 'stored in', roc_out)
    with open(roc_out, 'wb') as handle:
        pickle.dump([fpr, tpr], handle)
    return scores

def comp_regressed_signals(X, X_columns=[]):
    """Derive per-column predictions from pairwise regressed coordinates.

    ``compute_regressed_cord`` is evaluated for every ordered pair of columns
    (all columns of ``X`` when ``X_columns`` is empty).  For each target
    column, the second coordinate of its results against every *other*
    column is reduced with an element-wise minimum; a sample is predicted
    positive (1) when that minimum is above zero.

    Returns:
        ``(Y_pred, Y_order)`` as arrays with one row per column: the 0/1
        predictions and the minimum values used as ordering scores.
    """
    if len(X_columns) == 0:
        X_columns = X.columns
    regressed_xy = {}
    for target in X_columns:
        for other in X_columns:
            regressed_xy[(other, target)] = compute_regressed_cord(X, other, target, outlier=False)[0]
    Y_pred, Y_order = [], []
    for target in X_columns:
        against_others = [regressed_xy[(other, target)][:, 1] for other in X_columns if other != target]
        y_min = np.array(against_others).min(axis=0)
        Y_pred.append([1 if value > 0 else 0 for value in y_min])
        Y_order.append(y_min)
    return np.array(Y_pred), np.array(Y_order)

def make_prediction(Y_pred, labels):
    """Turn an indicator matrix (labels x samples) into one label per sample.

    A sample whose column maximum is zero stays ``'NA'``; otherwise it gets
    the label of the row holding the column maximum.

    Returns:
        list of labels, one per column of ``Y_pred``.
    """
    n_samples = Y_pred.shape[1]
    result = pd.Series(['NA'] * n_samples)
    column_max = Y_pred.max(axis=0)
    assigned = np.array([pos for pos, top in enumerate(column_max) if top > 0 or top])
    if len(assigned) > 0:
        winners = np.argmax(Y_pred[:, assigned], axis=0)
        result[assigned] = [labels[int(row)] for row in winners]
    return result.tolist()


def plot_mean_roc(fptpr, output):
    """Average ROC curves over folds and optionally plot them.

    Parameters:
        fptpr: container indexable by ``0 .. len(fptpr)-1`` whose entries are
            ``[fpr, tpr, auc, acc]`` for one fold.
        output: path of the figure to save, or ``None`` to skip plotting.

    Returns:
        ``(mean_auc, mean_acc)``: the AUC of the mean ROC curve (TPR
        interpolated on 100 evenly spaced FPR points) and the mean accuracy.
    """
    # numpy.interp replaces the alias formerly imported from the scipy
    # top-level namespace, which newer SciPy releases no longer provide.
    mean_fpr = np.linspace(0, 1, 100)
    tprs, aucs, accs = [], [], []
    plot = (output is not None)
    for i in range(len(fptpr)):
        fpr, tpr, auc, acc = fptpr[i]
        if plot:
            plt.plot(fpr, tpr, lw=1, alpha=0.3,
                        label='ROC fold %d (AUC = %0.2f)' % (i, auc))
        tprs.append(np.interp(mean_fpr, fpr, tpr))
        aucs.append(auc)
        accs.append(acc)
    mean_tpr = np.mean(tprs, axis=0)
    mean_tpr[-1] = 1.0  # force the mean curve to end at (1, 1)
    mean_auc = metrics.auc(mean_fpr, mean_tpr)
    std_auc = np.std(aucs)
    if not plot:
        return mean_auc, np.mean(accs)
    plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r',
            label='Chance', alpha=.8)
    plt.plot(mean_fpr, mean_tpr, color='b',
        label=r'Mean ROC (AUC = %0.2f $\pm$ %0.2f)' % (mean_auc, std_auc),
                lw=2, alpha=.8)
    std_tpr = np.std(tprs, axis=0)
    tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
    tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
    plt.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2,
                    label=r'$\pm$ 1 std. dev.')
    plt.xlim([-0.05, 1.05])
    plt.ylim([-0.05, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic example')
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, ncol=2, borderaxespad=0.)
    plt.savefig(output, bbox_inches='tight')
    # Clear before closing: calling clf() after close('all') would create a
    # new empty figure and leave it open.
    plt.clf()
    plt.close('all')
    return mean_auc, np.mean(accs)

def evaluate_by_raw_signal_cv(X, y_true, header, sampling, classifier='la', feature_names=[], verbose=True, saved=False):
    """Cross-validate one classifier on a binary task and plot the mean ROC.

    Per-fold ``[fpr, tpr, auc, acc]`` entries are pickled to
    ``<header>_<classifier>_cv_fptpr.npy``.  When ``saved`` is true and that
    file already exists, the folds are not recomputed.  The mean ROC figure
    is written to ``<header>_<classifier>_mean_auc.png``.

    Returns:
        ``(mean_auc, mean_acc)`` as computed by ``plot_mean_roc``.
    """
    is_lasso = (classifier == 'la')
    if is_lasso:
        kwargs = {'rho':1e-5, 'tol':1e-4, 'check':0.90, 'max_alternates':-1, 'maxitr':1000}
    else:
        kwargs = {}
    clf = get_classifier(classifier, kwargs)
    ofname = header+'_'+classifier+'_cv_fptpr.npy'
    reuse_saved = saved and os.path.exists(ofname)
    if not reuse_saved:
        fold_results = {}
        for fold, (train, test) in enumerate(sampling.split(X, y_true)):
            if verbose:
                print('-- Computing cross validation', fold, classifier)
            if is_lasso:
                clf = clf.fit(X[train], y_true[train], featurename=feature_names)
                y_pred = clf.predict(X[test])
                probas_ = clf.predict_proba_mean_weight(X[test], max_alternates=5)
            else:
                clf = clf.fit(X[train], y_true[train])
                y_pred = clf.predict(X[test])
                probas_ = clf.predict_proba(X[test])[:,1]
            fpr, tpr, _ = metrics.roc_curve(y_true[test], probas_)
            acc = metrics.accuracy_score(y_true[test], y_pred)
            auc = metrics.auc(fpr, tpr)
            fold_results[fold] = [fpr, tpr, auc, acc]
        with open(ofname, 'wb') as handle:
            pickle.dump(fold_results, handle)
    # Always read the folds back from disk, whether fresh or cached.
    with open(ofname, 'rb') as handle:
        fptpr = pickle.load(handle)
    figure_path = header+'_'+classifier+'_mean_auc.png'
    mean_auc, mean_acc = plot_mean_roc(fptpr, figure_path)
    return mean_auc, mean_acc

def remove_nan_data(X, y_true):
    """Drop rows whose label is missing (``'NA'`` or ``'nan'`` after ``str()``).

    Parameters:
        X: 2-D array supporting ``X[index, :]`` row selection.
        y_true: iterable of labels, one per row of ``X``.

    Returns:
        ``(X_kept, labels, index)``: the kept rows as an ndarray, the kept
        labels as a tuple of strings, and the integer positions of the kept
        rows.  When no row is annotated, all three are empty instead of the
        unpacking error the former ``zip(*...)`` raised.
    """
    kept = [(i, str(x)) for i, x in enumerate(y_true) if str(x) not in ('NA', 'nan')]  # remove unannotated cells or clusters
    # An explicit integer dtype keeps an empty index usable for row selection.
    index = np.array([i for i, _ in kept], dtype=int)
    labels = tuple(label for _, label in kept)
    X = np.array(X[index, :])
    return X, labels, index

def feature_is_large(m, X):
    """Tell whether method ``m`` should be skipped because ``X`` is too wide.

    Only the 'la', 'svm' and 'rf' methods are limited; for them the answer
    is true when ``X`` has more than 500 columns.
    """
    if m not in ['la', 'svm', 'rf']:
        return False
    n_features = X.shape[1]
    return n_features > 500

def train_by_raw_signal(X, y_true, header, sampling, feature_names, verbose=True, no_cv=False, cores=1, saved=False):
    """Train every method in ``METHODS`` and optionally cross-validate them.

    Unannotated rows are removed first.  Methods rejected by
    ``feature_is_large`` are skipped.  'la' is trained through
    ``make_lasso_matrix``, the others through ``make_other_clfs``.  Unless
    ``no_cv`` is set, each label is then evaluated one-vs-rest with
    ``evaluate_by_raw_signal_cv``.

    Returns:
        ``None`` when ``no_cv`` is true, otherwise a DataFrame with columns
        ``celltype, clf, auc, acc, class``.
    """
    X, y_true, index = remove_nan_data(np.array(X), y_true)
    if verbose:
        print('Training option: cv=', no_cv, "size=", X.shape)
    usable_methods = [m for m in METHODS if not feature_is_large(m, X)]
    for m in usable_methods:
        if verbose:
            print('Train with ', m)
        if m == 'la':
            print(set(y_true))
            make_lasso_matrix(X, y_true, header+'_'+'la', feature_names=feature_names, cores=cores, saved=saved)
        else:
            make_other_clfs(X, y_true, header, classifier=m, cores=cores, saved=saved)
    if no_cv:
        return None
    rows = []
    n_samples = len(y_true)
    cell_types = sorted(set(y_true))
    for m in usable_methods:
        if verbose:
            print('Evaluate AUC by cross validation', header, m)
        for c in cell_types:  # predict independently
            y = np.array([cl == c for cl in y_true]).reshape(n_samples)
            print(cell_types, sum(y))
            mean_auc, mean_acc = evaluate_by_raw_signal_cv(X, y, header+'_'+c, sampling, m, feature_names, verbose, saved)
            rows.append([c, m, mean_auc, mean_acc, 'sc'])
    return pd.DataFrame(rows, columns=['celltype', 'clf', 'auc', 'acc', 'class'])

def predict_by_raw_signal_trained(X, train_header, test_header, feature_names, verbose=True, cell_level=False, mt=False):
    """Predict labels for ``X`` with previously trained classifiers.

    For every method in ``METHODS`` (except those rejected by
    ``feature_is_large``) the stored classifiers are loaded, a positive-class
    score is computed per label, and the label with the unique positive
    maximum is assigned to each sample (``'NA'`` otherwise).  With ``mt`` the
    score matrix is first normalized by ``norm_row_columns``.

    Failures for one method are reported and the remaining methods are still
    processed.  Only ``Exception`` subclasses are caught, so
    ``KeyboardInterrupt`` and ``SystemExit`` are no longer swallowed.

    Parameters:
        X: feature matrix (rows are samples).
        train_header: path prefix of the stored classifiers.
        test_header: only used in the progress message.
        feature_names: unused here; kept for interface symmetry.
        verbose: print progress.
        cell_level: expand outputs to the full input length with
            ``fill_prediction_vec``.
        mt: normalize scores across labels before choosing the maximum.

    Returns:
        ``(Y_pred_all, Y_order_all, None)``: predictions keyed by method and
        per-label scores keyed by ``<method>_<label>``.
    """
    X, _, _, total_length, index = preprocess_feature_matrix(X, None)
    Y_pred_all, Y_order_all = {}, {}
    cell_types = []
    for m in METHODS:
        if feature_is_large(m, X):
            continue
        if verbose:
            print('Predicting ', m, 'for', test_header, 'based on ', train_header)
        if m == 'la':
            clfs = read_lasso_clf(train_header+'_'+m)
        else:
            clfs = read_other_clfs(train_header, m)
        if clfs is None:
            continue
        try:
            if len(cell_types) == 0:
                # Labels come from the first classifier set that loads.
                cell_types = list(clfs.keys())
            y_proba = {}
            for c in cell_types:
                if m == 'la':
                    y_proba[c] = clfs[c].predict_proba_mean_weight(X, max_alternates=5)
                else:
                    y_proba[c] = clfs[c].predict_proba(X)[:,1]
            y_proba_mat = np.array([y_proba[c] for c in cell_types])
            if mt: # evaluate a normalized signal against multiple types
                y_proba_mat = norm_row_columns(y_proba_mat)
            Y_pred, _ = get_max_indices(y_proba_mat)
            y_pred = make_prediction(Y_pred, list(y_proba))
            if cell_level:
                Y_pred_all[m] = fill_prediction_vec(y_pred, total_length, index)
                for c in cell_types:
                    Y_order_all[m+'_'+c] = fill_prediction_vec(y_proba[c], total_length, index)
            else:
                Y_pred_all[m] = y_pred
                for c in cell_types:
                    Y_order_all[m+'_'+c] = y_proba[c]
        except Exception as err:
            # Best effort: report and continue with the next method.
            print("Prediction error:", type(err), err)
    return Y_pred_all, Y_order_all, None
    
def evaluate_by_raw_signal_trained_single_task(m, c, X, Y_true, clf, test_header, roc_file, label):
    """Evaluate one trained binary classifier against ``Y_true``.

    The score comes from ``predict_proba_mean_weight`` for method 'la' and
    from the second column of ``predict_proba`` for every other method.

    Returns:
        ``(pred, proba, row)`` where ``row`` is
        ``[c, m, auc, acc, precision, recall, n, predicted_pos, true_pos,
        roc_file, label]``.
    """
    pred = clf.predict(X)
    proba = clf.predict_proba_mean_weight(X, max_alternates=5) if m == 'la' else clf.predict_proba(X)[:,1]
    auc, acc, prec, rec = comp_auroc(pred, Y_true, proba, roc_file)
    row = [c, m, auc, acc, prec, rec, len(Y_true), sum(pred), sum(Y_true), roc_file, label]
    return pred, proba, row

def evaluate_by_raw_signal_trained_multi_task(m, c, pred, proba, Y_true, test_header, roc_file, label):
    """Score precomputed predictions/scores for one label against ``Y_true``.

    Returns:
        ``[c, m, auc, acc, precision, recall, n, predicted_pos, true_pos,
        roc_file, label]``.
    """
    auc, acc, prec, rec = comp_auroc(pred, Y_true, proba, roc_file)
    n_predicted = sum(pred)
    n_true = sum(Y_true)
    n_total = len(Y_true)
    return [c, m, auc, acc, prec, rec, n_total, n_predicted, n_true, roc_file, label]

def evaluate_by_raw_signal_trained(X, y_true, train_header, test_header, feature_names, verbose=True, cell_level=False, mt=False):
    """Evaluate previously trained classifiers on annotated data.

    Unannotated rows are removed, then for each method in ``METHODS`` (except
    those rejected by ``feature_is_large``) every label is evaluated
    one-vs-rest ('st' rows).  With ``mt`` the per-label scores are normalized
    by ``norm_row_columns`` and evaluated again as a multi-label decision
    ('mt' rows).  The multi-label step previously referenced an undefined
    name, so it always failed inside the error handler; it now uses the
    score dictionary built in the single-task step.

    Failures for one method are reported and the remaining methods are still
    processed; only ``Exception`` subclasses are caught.

    Parameters:
        X: feature matrix (rows are samples).
        y_true: labels, one per row; 'NA'/'nan' entries are dropped.
        train_header: path prefix of the stored classifiers.
        test_header: path prefix of the ROC files that are written.
        feature_names: unused here; kept for interface symmetry.
        verbose: print progress.
        cell_level: expand outputs to the original length.
        mt: also run the normalized multi-label evaluation.

    Returns:
        ``(Y_pred_all, Y_order_all, df)`` where ``df`` has the columns
        ``celltype, method, auc, acc, precision, recall, whole, ppos, tpos,
        roc_file, class``.
    """
    all_result = []
    X, y_true, Y_true, total_length, index = preprocess_feature_matrix(X, y_true)
    Y_pred_all, Y_order_all = {}, {}
    cell_types = sorted(set(y_true))  # unique labels, so ROC file names never need an index suffix
    for m in METHODS:
        if feature_is_large(m, X):
            continue
        if verbose:
            print('Evaluating ', m, 'for', test_header, 'based on ', train_header)
        if m == 'la':
            clfs = read_lasso_clf(train_header+'_'+m)
        else:
            clfs = read_other_clfs(train_header, m)
        if clfs is None:
            continue
        y_proba_dic = {}
        try: # sklearn version should be same to use trained classifiers
            # single task
            for c in cell_types:
                roc_file = test_header+'_'+c+'_'+m+'_fptpr.npy'
                pred, proba, result = evaluate_by_raw_signal_trained_single_task(m, c, X, Y_true[c], clfs[c], test_header, roc_file, 'st')
                y_proba_dic[c] = proba
                print(min(proba), max(proba), np.mean(proba))
                all_result.append(result)
            y_proba_mat = np.array([y_proba_dic[c] for c in cell_types])
            if mt:
                y_proba_mat = norm_row_columns(y_proba_mat)
                print('statistics for AUROC', [(min(p), max(p), np.mean(p)) for p in y_proba_mat])
            Y_pred, y_order = get_max_indices(y_proba_mat)
            y_pred = make_prediction(Y_pred, list(y_proba_dic))
            if mt:
                # multi task: score the joint decision label by label
                for i, c in enumerate(cell_types):
                    roc_file = test_header+'_'+c+'_'+m+'_mt_fptpr.npy'
                    result = evaluate_by_raw_signal_trained_multi_task(m, c, [1 if y == c else 0 for y in y_pred], y_proba_mat[i], Y_true[c], test_header, roc_file, 'mt')
                    all_result.append(result)
            if cell_level:
                y_pred = fill_prediction_vec(y_pred, total_length, index)
                Y_pred_all[m] = y_pred
                for c in cell_types:
                    Y_order_all[m+'_'+c] = fill_prediction_vec(y_proba_dic[c], total_length, index)
            else:
                Y_pred_all[m] = y_pred
                for c in cell_types:
                    # NOTE(review): every label receives the same column-maximum
                    # vector here, whereas the cell-level branch stores the
                    # per-label scores -- confirm which one is intended.
                    Y_order_all[m+'_'+c] = y_order
        except Exception as err:
            # Best effort: report and continue with the next method.
            print("Evaluation error:", type(err), err)
    return Y_pred_all, Y_order_all, pd.DataFrame(all_result, columns=['celltype', 'method', 'auc', 'acc', 'precision', 'recall', 'whole', 'ppos', 'tpos', 'roc_file', 'class'])

def predict_by_raw_signal(X, regress):
    """Predict one label per row of ``X`` directly from its signal columns.

    With ``regress`` the decision comes from ``comp_regressed_signals`` and
    the score is the column-wise maximum of its ordering matrix; otherwise
    the values are normalized with ``norm_row_columns`` and the unique
    positive maximum per row decides.

    Returns:
        ``(y_pred, y_order)``: list of labels (column names of ``X`` or
        ``'NA'``) and the per-row score vector.
    """
    labels = X.columns.tolist()
    if regress:
        Y_pred, Y_order = comp_regressed_signals(X)
        y_order = Y_order.max(axis=0)
    else:
        Y_order = norm_row_columns(X.values).transpose()
        Y_pred, y_order = get_max_indices(Y_order)
    return make_prediction(Y_pred, labels), y_order


def evaluate_by_raw_signal(X, y_true, header, regress, simulate):
    """Predict labels from raw signals and score them against ``y_true``.

    Predictions are obtained as in ``predict_by_raw_signal``.  Each column of
    ``X`` whose name occurs in ``y_true`` is then scored one-vs-rest with
    ``comp_auroc``.  ROC data is written to
    ``<header>_<label>[_<i>]_fptpr.npy`` unless ``simulate`` is set, in which
    case nothing is stored.

    Returns:
        ``(y_pred, y_order, df)`` where ``df`` has the columns
        ``celltype, auc, acc, precision, recall, whole, ppos, tpos, roc_file``.
    """
    if regress:
        Y_pred, Y_order = comp_regressed_signals(X, X.columns)
        y_order = Y_order.max(axis=0)
    else:
        Y_order = norm_row_columns(X.values).transpose()
        Y_pred, y_order = get_max_indices(Y_order)
    y_pred = make_prediction(Y_pred, X.columns)
    observed_labels = set(y_true)
    rows = []
    for i, c in enumerate(X.columns):
        if c not in observed_labels:
            continue
        y_true_bin = [1 if c == y else 0 for y in y_true]
        y_pred_bin = [1 if c == y else 0 for y in y_pred]
        if simulate:
            roc_file = ''
        else:
            # The column position disambiguates duplicated column names.
            duplicated = len([x for x in X.columns if x == c]) > 1
            suffix = '_'+str(i) if duplicated else ''
            roc_file = header+'_'+c+suffix+'_fptpr.npy'
        auc, acc, prec, rec = comp_auroc(y_pred_bin, y_true_bin, Y_order[i,:], roc_file)
        rows.append([c, auc, acc, prec, rec, len(y_true_bin), sum(y_pred_bin), sum(y_true_bin), roc_file])
    summary = pd.DataFrame(rows, columns=['celltype', 'auc', 'acc', 'precision', 'recall', 'whole', 'ppos', 'tpos', 'roc_file'])
    return y_pred, y_order, summary



class CatactorLassoAlt(Catactor):

    def __init__(self, args):
        """Initialize the base class and derive the scanpy object paths.

        The paths are taken from ``args['scobj']`` when set, else from
        ``args['scmobj']``, else built from the output/group arguments under
        ``args['odir']``.  Sampling parameters are prepared only for
        training runs (``train_out`` set and ``test_out`` empty).
        """
        Catactor.__init__(self, args)

        def drop_pyn(path):
            # str.rstrip('.pyn') removes any trailing '.', 'p', 'y' or 'n'
            # characters, not the suffix; strip the exact extension instead.
            return path[:-len('.pyn')] if path.endswith('.pyn') else path

        if self.args['scobj'] != '':
            self.all_cell_path = self.args['scobj']
            self.all_cell_modif_path = self.args['scobj'].replace('_obj', '_obj_with_feat')
            self.all_cell_cluster_path = drop_pyn(self.args['scobj'].replace('_obj', '_obj_clust_ave'))
            self.all_header = self.args['scobj'].replace('_scanpy_obj.pyn', '')
        elif self.args['scmobj'] != '':
            self.all_cell_modif_path = self.args['scmobj']
            self.all_header = self.args['scmobj'].replace('_scanpy_obj_with_feat.pyn', '')
            self.all_cell_path = self.args['scmobj'].replace('_with_feat', '')
            self.all_cell_cluster_path = drop_pyn(self.args['scmobj'].replace('_with_feat', '_clust_ave'))
        else:
            self.all_header = '_'.join([self.args['output'].split(',')[0], self.args['gene_group'], self.args['cell_group']])+'_all'+('_bin' if self.args['binary'] else '')
            self.all_cell_path = os.path.join(self.args['odir'], self.scobj, self.all_header+'_scanpy_obj.pyn')
            self.all_cell_modif_path = os.path.join(self.args['odir'], self.scobj, self.all_header+'_scanpy_obj_with_feat.pyn')
            self.all_cell_cluster_path = os.path.join(self.args['odir'], self.scobj, self.all_header+'_scanpy_obj_clust_ave')
        if self.args['train_out'] != '' and self.args['test_out'] == '':
            self.set_sampling_parameters()
        if self.args['verbose']:
            print('-- Set scanpy obj paths')
            print(self.all_header, '->', self.all_cell_path)
            print('->', self.all_cell_modif_path)

    def set_sampling_parameters(self):
        """Store seed, fold count and metric names, and build the CV splitter.

        The splitter is a shuffled ``StratifiedKFold`` with ``args['cv']``
        folds seeded by the module-level ``SEED``.
        """
        seed = SEED
        n_folds = self.args['cv']
        self.seed = seed
        self.cv = n_folds
        self.metrics = ['roc_auc', 'precision', 'recall', 'accuracy']
        self.sampling = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)

    def evaluate_by_cell_level_signals(self, X, y_true, header, regress=False, simulate=False):
        """Evaluate raw-signal predictions row by row.

        Unannotated rows are removed before ``evaluate_by_raw_signal`` runs;
        the predictions and scores are then expanded back to the original
        length (NaN for removed rows).

        Returns:
            ``(y_pred, y_order, df)`` with both vectors as lists and ``df``
            tagged with ``target='cell'``.
        """
        n_total = len(y_true)
        column_names = X.columns
        values, labels, kept = remove_nan_data(np.array(X), y_true)
        frame = pd.DataFrame(values, columns=column_names)
        y_pred, y_order, df = evaluate_by_raw_signal(frame, labels, header, regress, simulate)
        df = df.assign(target='cell')
        filled_pred, filled_order = fill_prediction(y_pred, y_order, n_total, kept)
        return filled_pred.tolist(), filled_order.tolist(), df

    def predict_by_cell_level_signals(self, X, header, regress=False):
        """Predict one label per row of ``X`` from its raw signal columns.

        ``predict_by_raw_signal`` returns the labels as a plain list (built
        by ``make_prediction``), so calling ``.tolist()`` on it raised
        ``AttributeError``; both outputs are now converted defensively.

        Parameters:
            X: DataFrame whose column names are the candidate labels.
            header: unused here; kept for interface symmetry.
            regress: passed to ``predict_by_raw_signal``.

        Returns:
            ``(y_pred, y_order)`` as two lists.
        """
        y_pred, y_order = predict_by_raw_signal(X, regress)
        return list(y_pred), np.asarray(y_order).tolist()


    def evaluate_by_cluster_level_signals(self, X, y_cluster, cluster_annotation, header, regress=False, simulate=False):
        """Evaluate raw-signal predictions on per-cluster averages.

        Rows of ``X`` are averaged per cluster with
        ``average_for_each_cluster``; the reference label of each cluster is
        read from the ``'celltype'`` column of ``cluster_annotation`` and
        passed through ``convert_to_raw_celltype``.

        Returns:
            ``(y_pred, y_order, df)`` with ``df`` tagged ``target='cluster'``.
        """
        cluster_means, cluster_ids = average_for_each_cluster(X, y_cluster)
        mean_frame = pd.DataFrame(cluster_means, index=cluster_ids, columns=X.columns)
        annotated = [cluster_annotation.loc[cid, 'celltype'] for cid in cluster_ids]
        y_true = convert_to_raw_celltype(annotated)
        y_pred, y_order, df = self.evaluate_by_cell_level_signals(mean_frame, y_true, header, regress, simulate)
        return y_pred, y_order, df.assign(target='cluster')

    def predict_by_cluster_level_signals(self, X, y_cluster, answer, header, regress=False):
        """Predict labels for per-cluster averaged signals.

        Rows of ``X`` are averaged per cluster, clusters whose ``'celltype'``
        entry in ``answer`` is NaN are dropped, and the remaining averages
        are passed to ``predict_by_cell_level_signals``.

        Fixes over the previous body: it referenced an undefined
        ``cluster_annotation`` (the annotation table is the ``answer``
        parameter -- NOTE(review): confirm this reading against the callers),
        called ``predict_by_cell_level_signals`` with one argument too many,
        selected rows by position through label-based ``.loc``, and
        overwrote the caller's ``X.index``.

        Parameters:
            X: DataFrame of per-cell signals.
            y_cluster: cluster id of every row of ``X``.
            answer: table indexed by cluster id with a ``'celltype'`` column.
            header: forwarded to ``predict_by_cell_level_signals``.
            regress: forwarded to ``predict_by_cell_level_signals``.

        Returns:
            ``(Y_pred, Y_order)`` for the annotated clusters, or two empty
            lists when no cluster is annotated.
        """
        # Group on the values only, so the caller's frame is left untouched.
        mX = X.groupby(np.asarray(y_cluster)).mean()
        y_true = [answer.loc[i, 'celltype'] for i in mX.index]
        # x == x is False only for NaN, i.e. unannotated clusters.
        keep = [pos for pos, label in enumerate(y_true) if label == label]
        if not keep:
            return [], []
        mX = mX.iloc[keep, :]
        Y_pred, Y_order = self.predict_by_cell_level_signals(mX, header, regress)
        return Y_pred, Y_order

    def convert_celltype_labels_to_target(self, cluster, celltype_labels):
        if cluster in ['cluster', 'celltype']:
            return celltype_labels
        elif cluster == 'neuron':
            dict = {'EX':'P', 'IN':'P', 'NN':'N', 'NA':'NA'}
            return [dict[x] if x in dict else x for x in celltype_labels]
        elif cluster == 'inex':
            dict = {'EX':'EX', 'IN':'IN', 'NN':'NA'}
            return [dict[x] if x in dict else x for x in celltype_labels]

    def evaluate_exp_prediction_for_X(self, adata, all_result, marker_type, mode, key, cluster, new_X, cluster_list, cluster_annotation, evaluate_flag, simulate=False):
        """Predict (and optionally evaluate) labels from one signal table.

        With ``evaluate_flag`` the cluster-level or cell-level evaluation is
        run and its result rows, tagged with ``marker``, ``mode`` and
        ``problem``, are appended to ``all_result``.  Without it only the
        prediction is computed.  For every problem except 'cluster' the
        predictions and scores are stored in ``adata.obs`` under
        ``pred_<mode>_<key><tail>`` / ``ord_<mode>_<key><tail>``.

        The cluster-level prediction call now passes the cluster assignment
        and annotation table required by
        ``predict_by_cluster_level_signals``; it previously omitted them and
        raised ``TypeError``.

        Returns:
            ``(adata, all_result)``.
        """
        tail = {'celltype':'each', 'neuron':'neach', 'cluster':'cluster', 'inex':'ieach'}
        result_tail = {'celltype':'', 'neuron':'_neuron', 'inex':'_inex'}
        celltype_labels = new_X.columns
        header = self.all_header+'_'+mode+'_'+key+'_'+tail[cluster]
        # The regression-based predictor is used only for exactly three labels.
        regress = len(set(celltype_labels)) == 3
        if evaluate_flag:
            if cluster == 'cluster':
                Y_pred, Y_order, df = self.evaluate_by_cluster_level_signals(new_X, adata.obs.loc[:, cluster_list['cluster']], cluster_annotation, header, regress, simulate)
            elif cluster == 'celltype':
                Y_pred, Y_order, df = self.evaluate_by_cell_level_signals(new_X, convert_to_raw_celltype(adata.obs[cluster]), header, regress, simulate)
            else:
                Y_pred, Y_order, df = self.evaluate_by_cell_level_signals(new_X, adata.obs[cluster], header, regress, simulate)
            df = df.assign(marker=marker_type)
            df = df.assign(mode=mode)
            df = df.assign(problem=cluster)
            all_result = pd.concat([all_result, df])
        else:
            if cluster == 'cluster':
                Y_pred, Y_order = self.predict_by_cluster_level_signals(new_X, adata.obs.loc[:, cluster_list['cluster']], cluster_annotation, header, regress)
            else:
                Y_pred, Y_order = self.predict_by_cell_level_signals(new_X, header, regress)
        if cluster != 'cluster':
            adata.obs.loc[:,'pred_'+mode+'_'+key+result_tail[cluster]] = Y_pred
            adata.obs.loc[:,'ord_'+mode+'_'+key+result_tail[cluster]] = Y_order
        return adata, all_result

    def evaluate_agg_exp_prediction(self, adata, marker_list, cluster_list, cluster_annotation, evaluate_flag, evaluate_labels):
        """Predict labels from aggregated marker signals for every marker group.

        Marker groups are the marker-set names without their last
        ``_``-separated token.  For both signal modes ('average',
        'rankmean') and every group, the matching ``adata.obs`` columns are
        relabelled per target problem, columns whose label contains 'NA' are
        dropped, and ``evaluate_exp_prediction_for_X`` is called.  The
        evaluation table is written to ``<all_header>_auroc.csv`` after each
        mode (when ``evaluate_flag``) and ``adata.obs`` to
        ``<all_header>_prediction.csv`` at the end.
        """
        all_result = None
        marker_groups = set(['_'.join(c.split('_')[0:-1]) for c in marker_list])
        adata = self.compute_each_cell_signals(adata, marker_list)
        ordered_groups = sorted(marker_groups)
        for mode in ['average', 'rankmean']:
            for key in ordered_groups:
                marker_columns = [m for m in marker_list if m.startswith(key)]
                marker_type = key.split('_')[0]
                if self.args['verbose']:
                    print('-- Predicting', key, marker_columns)
                if marker_type == "SM":
                    marker_type = '_'.join(key.split('_')[0:-1])
                    label_position = -1
                else:
                    label_position = 1
                ori_celltype_labels = [c.split('_')[label_position] for c in marker_columns]
                X = adata.obs.loc[:, [mode+'_'+x for x in marker_columns]]
                for cluster in evaluate_labels:
                    X.columns = self.convert_celltype_labels_to_target(cluster, ori_celltype_labels)
                    kept_columns = np.array([i for i, x in enumerate(X.columns) if 'NA' not in x])
                    new_X = X.iloc[:, kept_columns]
                    assert X.shape[1] >= new_X.shape[1]
                    adata, all_result = self.evaluate_exp_prediction_for_X(adata, all_result, marker_type, mode, key, cluster, new_X, cluster_list, cluster_annotation, evaluate_flag)
            if evaluate_flag:
                all_result.to_csv(self.all_header+'_'+'auroc.csv')
        adata.obs.to_csv(self.all_header+'_'+'prediction.csv')

    def sample_averaged_cells(self, adata, cluster, threshold, sample_size=100):
        """Build an AnnData of randomly sampled (and averaged) cells per label.

        For every value of ``adata.obs[cluster]`` that does not contain 'NA',
        cells are drawn with ``np.random.choice``.  With ``threshold == 1``
        ``sample_size`` individual cells are taken per label; otherwise
        ``sample_size`` rounds are run, each drawing ``threshold`` cells per
        label and averaging them with ``average_for_each_cluster_less_memory``.

        Returns:
            An AnnData whose ``obs`` holds the label in column ``cluster`` and
            whose ``var`` index is copied from ``adata``.
        """
        celltypes = [x for x in adata.obs[cluster].unique() if 'NA' not in x]
        each_index = {}
        for cell in celltypes:
            # Row positions of the cells carrying this label.
            each_index[cell] = np.linspace(0, adata.obs.shape[0]-1, num=adata.obs.shape[0], dtype=int)[(adata.obs[cluster] == cell)]
        new_X, y_true = None, None
        # For thresholds up to 50, at least 100 samples are always drawn.
        if threshold <= 50 and sample_size < 100:
            sample_size = 100
        if threshold == 1:
            # Single-cell sampling: np.random.choice draws with replacement by default.
            sampled_index = [np.random.choice(each_index[cell], sample_size) for cell in celltypes]
            new_X = adata.X[np.array([i for x in sampled_index for i in x]),:]
            y_true = [[cell]*sample_size for cell in celltypes]
            y_true = pd.Series([i for x in y_true for i in x])
        else:
            # Label vector matching the stacked per-label draws of one round.
            y_true_temp = [[cell]*threshold for cell in celltypes]
            y_true_temp = pd.Series([i for x in y_true_temp for i in x])
            for s in range(sample_size):
                sampled_index = [np.random.choice(each_index[cell], threshold) for cell in celltypes]
                # assumes adata.X is a scipy sparse matrix (.todense()) -- TODO confirm
                X = np.vstack(tuple([adata.X[x,:].todense() for x in sampled_index]))
                mX, mX_index  = average_for_each_cluster_less_memory(X, y_true_temp, verbose=False)
                # pd.concat skips the initial None, so y_true grows round by round.
                y_true = pd.concat((y_true, pd.Series(mX_index)))
                if new_X is None: new_X = mX
                else: new_X = np.vstack((new_X, mX))
        aggdata = sc.AnnData(new_X, obs=pd.DataFrame(pd.Series(y_true, name=cluster)), var=pd.DataFrame(index=adata.var.index))
        return aggdata

    def simulate_agg_exp_prediction(self, adata, marker_list, cluster_list, cluster_annotation, evaluate_flag, evaluate_labels):
        """Evaluate marker-signal predictions on down-sampled, averaged cells.

        For each aggregation size in a fixed list, every target problem
        except 'cluster' is sampled ``sample_times`` times with
        ``sample_averaged_cells``; the aggregated signals are evaluated as in
        ``evaluate_agg_exp_prediction``.  With ``evaluate_flag`` one table per
        size is written to
        ``<all_header>_simulate_down_sample_<size>_auroc.csv``.

        Fixes over the previous body: the misspelled ``simualte`` variable
        left ``simulate`` undefined when the file name was built; the
        per-size table was reset inside the sampling loop so only the last
        sample survived; ``self.verbose`` is replaced by
        ``self.args['verbose']`` as used elsewhere in this class; and empty
        results no longer reach ``to_csv``/``pd.concat``.

        Returns:
            The concatenation of all per-size tables, or ``None`` when
            nothing was evaluated.
        """
        marker_groups = set(['_'.join(c.split('_')[0:-1]) for c in marker_list])
        simulate = "down_sample"
        threshold = [1, 5, 10, 25, 50, 75, 100, 150, 200]
        sample_times = 10
        all_result = None
        for th in threshold:
            # Collects the rows of every label set and every sample for this size.
            temp_result = None
            for cluster in evaluate_labels:
                if cluster == 'cluster': continue
                for s in range(sample_times):
                    aggdata = self.sample_averaged_cells(adata, cluster, th)
                    aggdata = self.compute_each_cell_signals(aggdata, marker_list)
                    for mode in ['average', 'rankmean']:
                        for key in sorted(marker_groups):
                            marker_columns = [m for m in marker_list if m.startswith(key)]
                            marker_type = key.split('_')[0]
                            if self.args['verbose']:
                                print('-- Predicting', key, marker_columns)
                            if marker_type == "SM":
                                marker_type = '_'.join(key.split('_')[0:-1])
                                ori_celltype_labels = [c.split('_')[-1] for c in marker_columns]
                            else:
                                ori_celltype_labels = [c.split('_')[1] for c in marker_columns]
                            X = aggdata.obs.loc[:, [mode+'_'+x for x in marker_columns]]
                            X.columns = self.convert_celltype_labels_to_target(cluster, ori_celltype_labels)
                            assert adata.X.shape[1] >= X.shape[1]
                            _, temp_result = self.evaluate_exp_prediction_for_X(aggdata, temp_result, marker_type, mode, key, cluster, X, cluster_list, cluster_annotation, evaluate_flag, simulate=True)
            if temp_result is None:
                continue
            if evaluate_flag:
                ofname = self.all_header+'_simulate_'+simulate+'_'+str(th)+'_auroc.csv'
                if self.args['verbose']:
                    print('Writing to ...', ofname)
                temp_result.to_csv(ofname)
            all_result = temp_result if all_result is None else pd.concat((all_result, temp_result))
        return all_result

    def get_feature_file(self, key, train_label):
        return os.path.join(self.args['clf_dir'], self.args['train_out']+'_'+key+'_'+train_label+'_features.csv')
    
    def extract_features(self, key, adata, marker_list, train_label):
        """Select the feature genes for ``key`` and subset ``adata`` to them.

        In training mode (``args['test_out']`` empty) the genes are either
        the ``TOP_GENES_MAX`` genes with the largest ``adata.var['cov']``
        (``key == 'all'``) or the union of all marker sets whose name starts
        with ``key``.  In test mode they are read from the feature file
        written during training.

        Fixes over the previous body: ``argsort()`` returns the sorting
        permutation, not ranks, so comparing it with ``TOP_GENES_MAX`` did
        not select the top genes; a header row present in the feature file
        (written by pandas versions whose ``Series.to_csv`` emits one) is no
        longer read as a gene; gene membership is tested against a set.

        Returns:
            ``(bdata, marker_columns, all_genes)``: the column-subsetted data
            (only genes present in ``adata.var.index``), the matching marker
            set names (empty for 'all' and in test mode), and the full list
            of requested genes.

        Raises:
            Exception: in test mode when the feature file does not exist.
        """
        marker_columns = []
        if self.args['test_out'] == '':
            if key == 'all':
                order = np.argsort(-adata.var.loc[:, 'cov'].values, kind='stable')
                all_genes = sorted(adata.var.index[order[:TOP_GENES_MAX]].tolist())
            else:
                marker_columns = [m for m in marker_list if m.startswith(key)]
                all_genes = sorted(set(x for m in marker_columns for x in marker_list[m]))
        else:
            feature_file = self.get_feature_file(key, train_label)
            if not os.path.exists(feature_file):
                raise Exception('No feature file', feature_file)
            features = pd.read_csv(feature_file, header=None, index_col=0)
            # A header line has an empty first field and parses as a NaN index.
            features = features[features.index.notna()]
            all_genes = features.iloc[:,0].tolist()
        known_genes = set(adata.var.index)
        detected_genes = [x for x in all_genes if x in known_genes]
        bdata = adata[:, detected_genes]
        if self.args['verbose']:
            print('Total genes:',  len(all_genes))
            print('Detected genes:', len(detected_genes))
            print(bdata.shape)
        return bdata, marker_columns, all_genes

    def obtain_x_and_y(self, bdata, key, train_label, cluster, cluster_list, cluster_annotation, all_genes):
        """Assemble the feature matrix, labels and file prefixes for one task.

        For ``cluster == 'cell'`` the matrix is a copy of ``bdata.X`` and the
        labels come from ``bdata.obs``; otherwise rows are averaged per
        cluster with ``average_for_each_cluster_less_memory`` and the labels
        are either the annotated cell types of the clusters or the cluster
        ids themselves.

        Returns:
            ``(header, X, y_true, min_cluster, test_header, y_original)``
            where ``min_cluster`` is the smallest label count for
            ``cluster == 'cluster'`` and ``args['cv']`` otherwise.
        """
        X = bdata.X.copy()
        clf_dir = self.args['clf_dir']
        stem = self.args['train_out']+'_'+key+'_'+train_label+'_'
        header = os.path.join(clf_dir, stem+('cluster' if cluster == 'cluster' else 'each'))
        test_header = os.path.join(clf_dir, self.args['test_out']+'_by_'+stem+('each' if cluster == 'cell' else 'cluster'))
        use_celltype = (train_label == 'celltype')
        if cluster == 'cell':
            if use_celltype:
                y_true = convert_to_raw_celltype(bdata.obs[cluster_list[train_label]])
            else:
                y_true = bdata.obs[cluster_list[train_label]]
            y_original = bdata.obs[cluster_list[train_label]]
        else:
            mX, mX_index = average_for_each_cluster_less_memory(X, bdata.obs[cluster_list[cluster]])
            assert mX.shape[0] == len(mX_index)
            print(mX_index)
            if use_celltype:
                y_true = convert_to_raw_celltype([cluster_annotation.loc[i,'celltype'] for i in mX_index])
            else:
                y_true = mX_index
            y_original = mX_index
            X = mX
        min_cluster = min(Counter(y_true).values()) if cluster == 'cluster' else self.args['cv']
        return header, X, y_true, min_cluster, test_header, y_original


    def train_lasso_prediction(self, adata, marker_list, cluster_list, cluster_annotation, train_label, resolution): # supervised, only annotated data is allowed for training
        """Train classifiers for every supported marker group and for all genes.

        For each key among 'SF', 'SC', 'CU', 'TA', 'TN' and 'all', the
        features are extracted, the feature gene list is written to the
        feature file, and ``train_by_raw_signal`` is run for each entry of
        ``resolution``.  Cross-validation results, when produced, are written
        to ``<train_out>_<train_label>_<key>_cv_auroc.csv``.

        The feature file is written without a header row: it is read back
        with ``header=None``, so a header line emitted by newer pandas would
        be taken for a gene.
        """
        marker_groups = set(['_'.join(c.split('_')[0:-1]) for c in marker_list])
        for key in sorted(list(marker_groups)+['all']): # there is a prior for marker genes
            all_result = None
            if key not in ['SF', 'SC', 'CU', 'TA', 'TN', 'all']: continue
            bdata, marker_columns, all_genes = self.extract_features(key, adata, marker_list, train_label)
            print(len(all_genes), all_genes)
            print(bdata.var.shape[0], bdata.var.index.tolist())
            pd.Series(bdata.var.index.tolist()).to_csv(self.get_feature_file(key, train_label), header=False)
            if self.args['verbose']:
                print('-- Train classifiers', key, len(all_genes), marker_columns)
            for cluster in resolution:
                header, X, y_true, min_cluster, _, _ = self.obtain_x_and_y(bdata, key, train_label, cluster, cluster_list, cluster_annotation, all_genes)
                # Cross-validation is skipped when there are too few folds or
                # samples, and for any cluster-level training.
                no_cv = ((self.args['cv'] <= 1) or (min_cluster < self.args['cv']) or (cluster == 'cluster') or (train_label == 'cluster'))
                df = train_by_raw_signal(X, y_true, header, self.sampling, bdata.var.index.tolist(), self.args['verbose'], no_cv=no_cv, cores=1, saved=False)
                if df is None: continue
                df = df.assign(prior=key)
                df = df.assign(target=cluster)
                all_result = pd.concat([all_result, df])
            if all_result is not None:
                all_result.to_csv(self.args['train_out']+'_'+train_label+'_'+key+'_cv_auroc.csv') # update in each for loop

    def fill_undetected_genes(self, X, detected_genes, all_genes):
        """Expand ``X`` to one column per entry of ``all_genes``, zero-filling the rest.

        Args:
            X: array whose columns follow the order of ``detected_genes``
                (a 1-D array is treated as a single column).
            detected_genes: list of the column names present in ``X``.
            all_genes: list of the column names wanted in the output, in order.

        Returns:
            Array with ``X.shape[0]`` rows and ``len(all_genes)`` columns;
            columns for genes absent from ``detected_genes`` are all zeros.
        """
        if self.args['verbose']:
            print('Filling a matrix', len(all_genes), len(detected_genes), 'all_genes', all_genes[0:10], 'detected', detected_genes[0:10])
        n_rows = X.shape[0]
        if len(detected_genes) == 0:
            # np.zeros takes the shape as a single tuple; a second positional
            # argument would be interpreted as the dtype.
            return np.zeros((n_rows, len(all_genes)))
        detected_set = set(detected_genes)
        missing_genes = [gene for gene in all_genes if gene not in detected_set]
        filled_X = np.zeros((n_rows, len(all_genes)-len(detected_genes)))
        X = np.concatenate((X.reshape((n_rows, (X.shape[1] if len(X.shape) == 2 else 1))), filled_X), axis=1)
        # Position of each gene in the concatenated matrix; setdefault keeps the
        # first occurrence, matching list.index semantics.
        column_of = {}
        for position, gene in enumerate(detected_genes + missing_genes):
            column_of.setdefault(gene, position)
        X = X[:,np.array([column_of[gene] for gene in all_genes])]
        return X
    
    def assign_cell_level_prediction(self, key, train_label, Y_pred, Y_order, adata):
        """Store per-cell predictions in ``adata.obs`` and return that table.

        Each entry ``m`` of ``Y_pred`` becomes column ``pred_<key>_<m>`` and
        each entry of ``Y_order`` becomes column ``ord_<key>_<m>``. In verbose
        mode the rows whose 'celltype' is not "NA" are printed for inspection.
        """
        for method in Y_pred:
            pred_column = 'pred_'+key+'_'+method
            values = Y_pred[method].values
            adata.obs.loc[:,pred_column] = values
            if not self.args['verbose']:
                continue
            annotated = adata.obs.loc[:,'celltype'] != "NA"
            print(values[annotated][0:10])
            print(adata.obs.loc[annotated,train_label])
            print(adata.obs.loc[annotated,pred_column])
        for method in Y_order:
            adata.obs.loc[:,('ord_'+key+'_'+method)] = Y_order[method].values
        return adata.obs
    
    def assign_cluster_level_prediction(self, key, train_label, Y_pred, Y_order, pred_result):
        """Collect cluster-level predictions into a DataFrame and return it.

        A new frame is created from the first ``Y_pred`` entry when
        ``pred_result`` is None; further entries are added as columns named
        ``pred_<key>_<m>`` and ``ord_<key>_<m>``. ``train_label`` is not used.
        """
        for method in Y_pred:
            column = 'pred_'+key+'_'+method
            if pred_result is not None:
                pred_result.loc[:,column] = Y_pred[method]
            else:
                pred_result = pd.DataFrame(Y_pred[method], columns=[column])
        for method in Y_order:
            column = 'ord_'+key+'_'+method
            pred_result.loc[:,column] = Y_order[method]
        return pred_result

    def evaluate_lasso_prediction(self, adata, marker_list, cluster_list, cluster_annotation, train_label, resolution):
        """Apply previously trained classifiers to ``adata`` and write results as CSV.

        For every accepted marker-group key and every entry of ``resolution``
        the features are built with ``obtain_x_and_y`` and padded to the full
        gene list with ``fill_undetected_genes``. When ``train_label`` is
        'cluster' only predictions are made; otherwise
        ``evaluate_by_raw_signal_trained`` also returns a score table which is
        accumulated per key and written to
        ``<test_out>_by_<train_out>_<train_label>_<key>_auroc.csv``.
        Predictions go to
        ``<test_out>_by_<train_out>_<train_label>_<key>_<cluster>_prediction.csv``.

        Side effects: for the 'cell' resolution, prediction columns are added
        to ``adata.obs``. Returns None.
        """
        marker_groups = set(['_'.join(c.split('_')[0:-1]) for c in marker_list])
        for key in sorted(list(marker_groups)+['all']): # a prior for marker genes
            all_result = None
            # Only these marker-group names are evaluated; others are skipped.
            if key not in ['SF', 'SC', 'CU', 'TA', 'TN', 'all']: continue
            bdata, marker_columns, all_genes = self.extract_features(key, adata, marker_list, train_label)
            if self.args['verbose']:
                print('-- Evaluate trained classifiers', train_label, key)
            for cluster in resolution:
                pred_result = None
                # NOTE(review): shape[0] is the first axis of X, while the
                # message refers to genes -- confirm which axis was intended.
                if bdata.X.shape[0] == 0:
                    print('No gene is detected', bdata.var.index, all_genes)
                    continue
                train_header, X, y_true, min_cluster, test_header, y_original = self.obtain_x_and_y(bdata, key, train_label, cluster, cluster_list, cluster_annotation, all_genes)
                # Pad X so that its columns follow all_genes (zeros for genes absent from bdata).
                X = self.fill_undetected_genes(X, bdata.var.index.tolist(), all_genes)
                assert len(X.shape) == 2
                if train_label =='cluster': # answer is not aligned
                    Y_pred, Y_order, _ = predict_by_raw_signal_trained(X, train_header, test_header, bdata.var.index.tolist(), self.args['verbose'], (cluster == 'cell'))
                else:
                    Y_pred, Y_order, df = evaluate_by_raw_signal_trained(X, y_true, train_header, test_header, bdata.var.index.tolist(), self.args['verbose'], (cluster == 'cell'))
                    # Tag the score table so rows from different runs can be told apart.
                    df = df.assign(prior=key)
                    df = df.assign(target=cluster)
                    df = df.assign(problem=train_label)
                    if self.args['verbose']:
                        print('- Predicted')
                        print(df)
                    if all_result is None: all_result = df
                    else: all_result = pd.concat([all_result, df])
                if cluster == 'cell':
                    # Cell-level predictions are stored as new columns of adata.obs.
                    adata.obs = self.assign_cell_level_prediction(key, train_label, Y_pred, Y_order, adata)
                else:
                    # Cluster-level predictions get their own table.
                    pred_result = self.assign_cluster_level_prediction(key, train_label, Y_pred, Y_order, None)
                if all_result is not None:
                    all_result.to_csv(self.args['test_out']+'_by_'+self.args['train_out']+'_'+train_label+'_'+key+'_auroc.csv') # update in each for-loop
                if pred_result is not None:
                    print([len(Y_pred[x]) for x in Y_pred], [len(Y_order[x]) for x in Y_order], key, cluster, train_label, X.shape)
                    print(pred_result.shape)
                    print(len(y_true), len(y_original))
                    # Index by the original labels, keeping only entries whose
                    # y_true is not NaN (t == t is False for NaN) and does not contain 'NA'.
                    pred_result.index = [o for t, o in zip(y_true, y_original) if t == t and 'NA' not in str(t)]
                    pred_result.to_csv(self.args['test_out']+'_by_'+self.args['train_out']+'_'+train_label+'_'+key+'_'+cluster+'_prediction.csv')
                else:
                    # Cell resolution: dump the whole obs table, including the new columns.
                    with open(self.args['test_out']+'_by_'+self.args['train_out']+'_'+train_label+'_'+key+'_'+cluster+'_prediction.csv', 'w') as file:
                        adata.obs.to_csv(file)


    def test_raw_expression(self, all_cells):
        """Debug hook called by ``read_all_cells`` when ``args['debug']`` is set; currently a no-op."""
        pass

    def set_neuron_column(self, cluster_list, all_cells, resolution):
        """Derive 'neuron' and 'inex' label columns from 'celltype'.

        When ``all_cells.obs`` has no 'celltype' column the inputs are returned
        untouched. Otherwise 'celltype' is normalised with
        ``convert_to_raw_celltype`` and two columns are added:
        'neuron' ('P' for EX/IN, 'NA' for NA, 'N' otherwise) and
        'inex' ('EX', 'IN', or 'NA' for everything else). Both new labels are
        registered in ``resolution`` (cell level only) and ``cluster_list``.

        Returns:
            Tuple ``(all_cells, cluster_list, resolution)``.
        """
        def to_neuron_flag(label):
            if label in ['EX', 'IN']:
                return 'P'
            if label == 'NA':
                return 'NA'
            return 'N'

        def to_inex(label):
            if label == 'EX':
                return 'EX'
            if label == 'IN':
                return 'IN'
            return 'NA'

        if 'celltype' not in all_cells.obs.columns:
            return all_cells, cluster_list, resolution
        all_cells.obs['celltype'] = convert_to_raw_celltype(all_cells.obs.loc[:,'celltype'])
        all_cells.obs['neuron'] = convert_to_raw_celltype(all_cells.obs['celltype'])
        all_cells.obs['neuron'] = [to_neuron_flag(c) for c in all_cells.obs['neuron']]
        all_cells.obs['inex'] = convert_to_raw_celltype(all_cells.obs['celltype'])
        all_cells.obs['inex'] = [to_inex(c) for c in all_cells.obs['inex']]
        for name in ('neuron', 'inex'):
            resolution[name] = ['cell']
        for name in ('neuron', 'inex'):
            cluster_list[name] = name
        return all_cells, cluster_list, resolution

    def read_all_cells(self):
        """Load the pickled data object, preferring the modified copy.

        Returns None (after printing a hint) when ``self.all_cell_path`` does
        not exist. Otherwise the object stored at ``self.all_cell_modif_path``
        is returned when that file exists and can be unpickled. If it is
        missing or unreadable, the object is loaded from
        ``self.all_cell_path``, passed to ``self.write_modif_adata`` and then
        pickled to ``self.all_cell_modif_path`` so the next call finds it.

        NOTE: ``pickle.load`` can execute arbitrary code; both paths must point
        to trusted files.
        """
        if not os.path.exists(self.all_cell_path):
            print('No file found, run visualization first', self.all_cell_path)
            return None
        all_cells = None
        if os.path.exists(self.all_cell_modif_path):
            try:
                print('Try to read:', self.all_cell_modif_path)
                with open(self.all_cell_modif_path, "rb") as f:
                    all_cells = pickle.load(f)
            except Exception as err:
                # Only ordinary errors trigger the rebuild; KeyboardInterrupt
                # and SystemExit now propagate instead of being swallowed.
                print('Failed to read:', self.all_cell_modif_path, err)
                all_cells = None
        if all_cells is None:
            # Modified copy missing or unreadable: rebuild it from the base object.
            print('Try to read:', self.all_cell_path)
            with open(self.all_cell_path, "rb") as f:
                all_cells = pickle.load(f)
            self.write_modif_adata(all_cells)
            with open(self.all_cell_modif_path, "wb") as f:
                pickle.dump(all_cells, f)
        if self.args['debug']:
            self.test_raw_expression(all_cells)
        return all_cells
        
    def run_prediction(self):
        """Entry point of the prediction mode.

        Loads the data object via ``read_all_cells`` (returning silently when
        nothing could be loaded), reads the marker lists, and then dispatches on
        ``self.args['train_out']``:

        * empty string: aggregated-expression prediction
          (``simulate_agg_exp_prediction`` when ``args['simulate']`` is set,
          ``evaluate_agg_exp_prediction`` otherwise);
        * non-empty: classifier workflow -- evaluation of trained classifiers
          when ``args['test_out']`` is non-empty, training otherwise.

        Side effects: ``args['test_out']`` or ``args['train_out']`` gets a
        '_rank' suffix when ``args['rank']`` is set. Returns None.
        """
        all_cells = self.read_all_cells()
        if all_cells is None: return
        marker_list = read_biomarker_matrix(self.args['markers'], self.args['mdir'], self.args['top_markers'], self.args['verbose'], self.args['data_markers'])
        # Label name -> resolutions at which it is handled.
        resolution = {'celltype':['cell', 'cluster'], 'cluster':['cell', 'cluster']}
        # Label name -> obs column holding that label.
        cluster_list = dict([('celltype', 'celltype')])
        if 'Ident' in all_cells.obs and 'Ident' in self.args['cluster']:
            cluster_list['cluster'] = 'Ident'
        else:
            cluster_list['cluster'] = 'cluster'
        all_cells, cluster_list, resolution = self.set_neuron_column(cluster_list, all_cells, resolution)
        if self.args['verbose']:
            print(all_cells)
        # A 'celltype' column is what distinguishes evaluation from plain prediction.
        evaluate_flag = ('celltype' in all_cells.obs.columns)
        if self.args['verbose']:
            if evaluate_flag: print('-- Evaluation mode')
            else:   print('-- Prediction mode')
        if self.args['train_out'] == '':
            cluster_annotation = (pd.read_csv(self.args['cannotation'], index_col=0) if self.args['cannotation'] is not None else None)
            if self.args['simulate']:
                self.simulate_agg_exp_prediction(all_cells, marker_list, cluster_list, cluster_annotation, evaluate_flag, resolution.keys())
            else:
                self.evaluate_agg_exp_prediction(all_cells, marker_list, cluster_list, cluster_annotation, evaluate_flag, resolution.keys())
        elif len(self.args['train_out']) > 0:
            if self.args['rank']:
                if self.args['verbose']: print('-- Converting to rank data!')
                # The '_rank' suffix goes on test_out when it is set, else on train_out.
                if len(self.args['test_out']) > 0: self.args['test_out'] = self.args['test_out']+'_rank'
                else:   self.args['train_out'] = self.args['train_out']+'_rank'
            all_cells = normalize_features_for_cls(all_cells, self.args['rank'], marker_list, self.args['top_markers'])
            # Process every known label unless a single one was requested via args['prediction'].
            for key in (['celltype', 'neuron', 'inex', 'cluster'] if self.args['prediction'] == '' else [self.args['prediction']]):
                if key not in resolution:
                    continue
                # The cluster annotation table is only loaded for the 'celltype' label.
                if key == 'celltype':
                    cluster_annotation = (pd.read_csv(self.args['cannotation'], index_col=0) if self.args['cannotation'] is not None else None)
                else:
                    cluster_annotation = None
                if self.args['test_out'] != '':
                    # Snapshot obs before evaluation and restore it afterwards, so
                    # columns added while evaluating one label are not carried
                    # over to the next label.
                    obs_table = all_cells.obs.copy(deep=True)
                    assert obs_table.loc[:,obs_table.columns.str.startswith('ord_CU')].shape[1] == 0
                    print(obs_table.columns)
                    self.evaluate_lasso_prediction(all_cells, marker_list, cluster_list, cluster_annotation, key, resolution[key])
                    # The snapshot is a deep copy and must not have gained prediction columns.
                    assert obs_table.loc[:,obs_table.columns.str.startswith('ord_CU')].shape[1] == 0
                    all_cells.obs = obs_table
                    print(all_cells.obs.columns)
                else:
                    self.train_lasso_prediction(all_cells, marker_list, cluster_list, cluster_annotation, key, resolution[key])
        else:
            pass


class CatactorItemSet(Catactor):
    """Iteratively merge matrix columns that co-occur, using the external ``lcm`` miner.

    Each round writes the binarised matrix as a transaction file, runs
    ``lcm`` on it, groups the columns that appear together in the reported
    itemsets, and collapses every group into a single column.
    """

    def __init__(self, args):
        """Initialise the base class, the input paths and the convergence thresholds."""
        Catactor.__init__(self, args)
        if self.args['scobj']:
            self.all_cell_path = self.args['scobj']
            self.all_header = self.args['scobj'].replace('_scanpy_obj.pyn', '')
            # A '_trans' marker in the file name selects the transposed orientation.
            self.use_row = ('_trans' in self.all_cell_path)
        self.min_lift = 0.5  # lower bound of the LCM threshold
        self.min_cov  = 10   # columns whose 'cov' is below this are inactive

    def _is_inactive(self, cov, max_r, all_cell_num):
        """Return True when a column is too rare or too common to take part in merging."""
        return cov < self.min_cov or cov/float(all_cell_num) >= max_r

    def convert_to_transaction(self, mat, dir, outfile, inactive_cluster=None):
        """Write ``mat`` as a transaction file: one line of column indices per row slice.

        Args:
            mat: sparse matrix exposing ``indptr`` and ``indices``.
            dir: base directory joined in front of ``self.matrix`` and ``outfile``.
            outfile: output file name (an absolute path overrides the prefix).
            inactive_cluster: container of column indices to leave out;
                defaults to none.
        """
        if inactive_cluster is None:
            inactive_cluster = {}
        if self.args['verbose']:
            print('-- Write transaction to ', dir, outfile)
        # NOTE(review): self.matrix is not set in this class; presumably it is
        # provided by the base class -- confirm.
        with open(os.path.join(dir, self.matrix, outfile), 'w') as f:
            for i in range(len(mat.indptr)):
                if i < len(mat.indptr)-1:
                    trans = list(mat.indices[mat.indptr[i]:mat.indptr[i+1]])
                else:
                    # NOTE(review): this last pass starts at indptr[-1], which for
                    # a well-formed sparse matrix is past the stored indices, so
                    # it emits one trailing empty line. Kept to leave the LCM
                    # input unchanged -- confirm whether it is wanted.
                    trans = list(mat.indices[mat.indptr[i]:mat.shape[0]])
                trans = [x for x in trans if x not in inactive_cluster]
                f.write(' '.join(map(str, trans))+'\n')

    def check_transaction_number(self, file):
        """Return the number of lines in ``file`` that are longer than one character."""
        with open(file) as f:
            return sum(1 for line in f if len(line) > 1)

    def before_convergence(self, cluster_list, max_r, min_c, all_cell_num, threshold):
        """Return True while another merging round is worthwhile.

        Merging stops when at most ``min_c`` columns remain, when every column
        is inactive, or when ``threshold`` fell below ``self.min_lift``.
        """
        if cluster_list.shape[0] <= min_c:
            return False
        if all(self._is_inactive(row['cov'], max_r, all_cell_num) for _, row in cluster_list.iterrows()):
            return False
        if threshold < self.min_lift:
            return False
        return True

    def obtain_inactive_cluster_list(self, clust_list, max_r, all_cell_num):
        """Return a dict keyed by the ``n_index`` of every inactive column.

        A column is inactive under the same predicate that
        ``before_convergence`` uses (coverage below ``self.min_cov`` OR a cell
        fraction of at least ``max_r``). The previous code joined the two
        conditions with ``and``, which could almost never be satisfied.
        """
        inactive_index = {row['n_index']: 0 for _, row in clust_list.iterrows() if self._is_inactive(row['cov'], max_r, all_cell_num)}
        if self.args['verbose']:
            print('Inactive columns:', len(inactive_index), inactive_index.keys())
        return inactive_index

    def run_lcm_for_cluster_merging(self, dir, path, out, freq=50, threshold=0.001):
        """Run ``<lcm_dir>/lcm FQs`` on the transaction file ``path`` and write to ``out``.

        Only itemsets of exactly two items are requested (``-l 2 -u 2``).
        """
        min_comb = 2
        lcm_binary = os.path.join(self.args['lcm_dir'], 'lcm')
        tokens = [lcm_binary, 'FQs', '-f', threshold, '-l', min_comb, '-u', min_comb, os.path.join(dir, path), freq, os.path.join(dir, out)]
        commands = [' '.join(map(str, tokens))]
        for i, com in enumerate(commands):
            print('time_start_', i, datetime.datetime.now())
            print(com)
            # Executed through the shell as one string, as before; the arguments
            # come from local configuration and temp-file paths.
            subprocess.run(com, shell=True, stdout=subprocess.DEVNULL)
            print('time_end_all', datetime.datetime.now())

    def merge_column_lcm(self, mat, new_cluster, cluster_list):
        """Project the columns of ``mat`` onto the merged clusters given by ``new_cluster``."""
        cluster_list = cluster_list.set_index('n_index')
        column_max = cluster_list.shape[0]
        conv_mat = self.compute_column_conversion(cluster_list, column_max, projected=new_cluster)
        return mat.dot(conv_mat)

    def integrate_merge_history(self, cluster_list, new_cluster):
        """Build the cluster table after merging.

        ``all_indices`` of every new cluster is the comma-joined
        ``all_indices`` of the old clusters assigned to it.
        """
        new_cluster_list = pd.DataFrame({'n_index':list(range(max(new_cluster)+1)), 'all_indices':''})
        print(new_cluster)
        cluster_list = cluster_list.loc[:,['n_index', 'all_indices', 'cov']]
        for i in range(cluster_list.shape[0]):
            # Column 1 is 'all_indices' in both tables.
            separator = (',' if new_cluster_list.iloc[new_cluster[i],1] != '' else '')
            new_cluster_list.iloc[new_cluster[i],1] += separator+cluster_list.iloc[i,1]
        print(new_cluster_list.shape, cluster_list.shape)
        assert all(new_cluster_list.loc[:,'all_indices'] != '')
        return new_cluster_list

    def aggregate_cluster_based_on_list(self, mat, cluster_list, new_cluster):
        """Merge the columns of ``mat`` and return it with the updated cluster table."""
        new_cluster_list = self.integrate_merge_history(cluster_list, new_cluster)
        mat = self.merge_column_lcm(mat, new_cluster, cluster_list)
        print('agg', mat.shape, new_cluster_list.shape, cluster_list.shape, new_cluster.shape)
        new_cluster_list = new_cluster_list.assign(cov=np.squeeze(np.array(mat.sum(axis=0))))
        return mat, new_cluster_list

    def read_lcm_result(self, dir, path, clust_list, inactive_cluster):
        """Turn the itemsets reported by LCM into a cluster id per column.

        Columns that share an itemset end up in the same cluster; inactive
        columns and columns in no itemset each get a cluster of their own.

        Returns:
            Integer array with one consecutive cluster id per row of ``clust_list``.
        """
        temp_file = os.path.join(dir, path)
        # The builtin int replaces the np.int alias, which NumPy 1.24 removed.
        new_cluster = np.full((clust_list.shape[0]), -1, dtype=int)
        count = 0
        with open(temp_file, 'r') as f:
            for line in f:
                # The first three tokens of each line are skipped; the rest are column indices.
                contents = np.array(list(map(int, re.split(r'[\[ ,\]]', line.rstrip('\n'))[3:])))
                clust = max(new_cluster[contents])
                if clust < 0:
                    # None of the columns was seen before: open a new cluster.
                    new_cluster[contents] = count
                    count += 1
                else:
                    for c in contents:
                        if new_cluster[c] < 0:
                            new_cluster[c] = clust
                        elif new_cluster[c] != clust:
                            # Two existing clusters are linked: relabel one into the other.
                            new_cluster = np.where(new_cluster == new_cluster[c], clust, new_cluster)
        for ic in inactive_cluster:
            new_cluster[ic] = count
            count += 1
        for i, x in enumerate(new_cluster):
            if x < 0:
                new_cluster[i] = count
                count += 1
        clusters = list(set(new_cluster))
        assert (-1 not in clusters)
        # Renumber to consecutive ids; the dict gives the same mapping as
        # clusters.index without the quadratic scan.
        position = {label: i for i, label in enumerate(clusters)}
        new_cluster = np.array([position[x] for x in new_cluster])
        print('new_cluster')
        print(new_cluster)
        return new_cluster

    def run_itemset_based_clustering(self):
        """Run up to 100 merging rounds and save the final matrix and cluster table.

        Writes ``<all_header>_cluster.csv`` and ``<all_header>_mat.pyn``.
        """
        if self.args['verbose']:
            print('-- Run LCM-based clustering', self.all_cell_path)
        min_c, max_r = self.args['min_c'], self.args['max_r']
        # NOTE: pickle.load can execute arbitrary code; the input must be trusted.
        with open(self.all_cell_path, "rb") as f:
            all_cells = pickle.load(f)
        if self.use_row:
            mat = all_cells.X.transpose()
            all_peak_num, all_cell_num = all_cells.shape
        else:
            mat = all_cells.X.copy()
            all_cell_num, all_peak_num = all_cells.shape
        temporal_cluster_list = pd.DataFrame({'all_indices': [str(x) for x in range(all_peak_num)], 'cov':np.squeeze(np.array(mat.sum(axis=0))), 'n_index': list(range(0, all_peak_num))})
        threshold = 5
        for round_idx in range(100):
            if self.args['verbose']:
                print('Peaks are being merged...', round_idx, 'times', mat.shape)
            if not self.before_convergence(temporal_cluster_list, max_r, min_c, all_cell_num, threshold):
                break
            mat.data = np.ones(mat.data.shape, dtype=int)
            inactive_cluster = self.obtain_inactive_cluster_list(temporal_cluster_list, max_r, all_cell_num)
            if len(inactive_cluster) >= mat.shape[1]-1:
                break
            # mkstemp returns an OS-level descriptor that the caller must close;
            # only the path is needed here.
            fd_in, path = tempfile.mkstemp()
            os.close(fd_in)
            fd_out, out = tempfile.mkstemp()
            os.close(fd_out)
            try:
                self.convert_to_transaction(mat, '/', path, inactive_cluster)
                self.run_lcm_for_cluster_merging('/', path, out, threshold=threshold)
                new_cluster = self.read_lcm_result('/', out, temporal_cluster_list, inactive_cluster)
                threshold = max(min(threshold/1.1, threshold-0.1), self.min_lift)
                if temporal_cluster_list.shape[0]-1 <= max(new_cluster):
                    # Nothing was merged this round; retry with the lowered threshold.
                    print('Update threshold only')
                    continue
                mat, temporal_cluster_list = self.aggregate_cluster_based_on_list(mat, temporal_cluster_list, new_cluster)
            finally:
                os.remove(path)
                # NOTE(review): the LCM output file `out` is left on disk, as
                # before -- confirm whether it should be removed as well.
        temporal_cluster_list.to_csv(self.all_header+'_cluster.csv')
        with open(self.all_header+'_mat.pyn', 'wb') as f:
            pickle.dump(mat, f)

def run_Catactor(args):
    """Parse ``args``, pick the analysis object for the requested mode and run it.

    Returns:
        True when no mode could be extracted, False when the mode matches
        none of the known scanpy-object modes, and None after a mode was
        handled.
    """
    args = parse_args(args)
    mode, args = extract_mode(args)
    if args['verbose']:
        print('Start mode:', mode)
    if mode is None:
        return True
    if 'annotation' in mode:
        annotator = Annotation(args)
        if mode == 'column_annotation':
            annotator.construct_basic_annotation_columns()
        elif mode == 'row_annotation':
            annotator.construct_basic_annotation_rows()
        return None
    if mode == 'rank_analysis':
        rank_ann = RankAnalysis(args)
        rank_ann.run_rank_analysis()
        return None
    # Every remaining mode works on scanpy objects.
    if mode in ['transaction', 'detection']:
        runner = CatactorItemSet(args)
    elif mode == 'prediction':
        runner = CatactorLassoAlt(args)
    else:
        runner = Catactor(args)
    if mode == 'transaction':
        # The object is still constructed, but nothing is run for this mode.
        return None
    if mode == 'detection':
        runner.run_itemset_based_clustering()
    elif mode == 'preprocess':
        runner.run_visualization(True)
    elif mode == 'average':
        runner.run_average_profiling()
    elif mode == 'visualization':
        runner.run_visualization(args['test_vis'])
    elif mode == 'prediction':
        runner.run_prediction()
    else:
        return False
    return None

if __name__ == "__main__":
    # Fixed seed so that runs are repeatable.
    np.random.seed(42)
    parser = get_parser()
    args = parser.parse_args()
    # --silent overrides verbosity; an empty adir falls back to dir.
    if args.silent:
        args.verbose = False
    if args.adir == '':
        args.adir = args.dir
    # With no input files nothing is run and the help text is shown.
    error_flag = run_Catactor(args) if len(args.files) > 0 else True
    if error_flag:
        parser.print_help()
