#!python
# -*- coding: utf-8 -*-
# @Time : 2023/1/5 11:03
# @Author : jmzhang
# @Email : zhangjm@biomarker.com.cn


import spatial_tools
import pandas as pd
import scanpy as sc
import argparse
import logging

def str2bool(value):
    """Convert a command-line token into a real ``bool``.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value -- including ``'False'`` -- is parsed as ``True``.  This
    converter interprets the text instead.

    :param value: raw option value (a ``str`` from argparse, or already a ``bool``).
    :return: ``True`` for true/t/yes/y/1, ``False`` for false/f/no/n/0 and the
        empty string (the empty string was the only falsy input of ``type=bool``).
        Matching is case-insensitive and ignores surrounding whitespace.
    :raises argparse.ArgumentTypeError: if the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value

    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got {!r}'.format(value))


def build_parser():
    """Build the command-line parser of this script.

    Option names, defaults and help strings are kept as they were; only the
    boolean options changed their converter from ``bool`` to :func:`str2bool`.

    :return: the configured ``argparse.ArgumentParser``.
    """
    desc = """
    Version: Version v2
    Contact: zhangjm <zhangjm@biomarker.com.cn>
    Program Date: 2022.10.25
    Description: spatial tools
    """

    parser = argparse.ArgumentParser(description=desc, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--pic', type=str, help='seurat_obj')
    parser.add_argument('--barcodes_pos', type=str, help='barcodes pos')
    parser.add_argument('--adata', type=str, help='pd.DataFrame or AnnData')
    parser.add_argument('--color', type=str, help='seurat_clusters')
    parser.add_argument('--feature', type=str, help='feature', default=None)
    parser.add_argument('--dpi', type=int, help='dpi', default=500)
    parser.add_argument('--groups', type=str, help='groups', default=None)
    parser.add_argument('--size', type=float, help='dot size', default=1)
    parser.add_argument('--size_auto', type=str2bool, help='just size auto', default=False)
    parser.add_argument('--alpha', type=float, help='alpha', default=1)
    parser.add_argument('--figsize', type=str, help='figsize', default='(10, 10)')
    parser.add_argument('--to_save', type=str, help='to_save', default=None)
    parser.add_argument('--darken', type=float, help='darken the color', default=None)
    parser.add_argument('--return_fig', type=str2bool, help='return_fig', default=False)
    parser.add_argument('--return_table', type=str2bool, help='return_table', default=False)
    parser.add_argument('--split', type=str2bool, help='colnames of groups', default=False)
    parser.add_argument('--ncol', type=int, help='ncol', default=2)
    parser.add_argument('--color_dict', type=str, help='color_dict', default=None)
    parser.add_argument('--hspace', type=float, help='hspace', default=None)
    parser.add_argument('--wspace', type=float, help='wspace', default=None)
    parser.add_argument('--para_dict', type=str, help='para_dict', default=None)
    return parser


def load_adata(path):
    """Load the table/object that is handed to the plotting call.

    :param path: value of ``--adata``.  A name ending in ``.xls`` is read as a
        tab-separated text table with pandas; a name ending in ``.loom`` is
        read with ``scanpy.read_loom``.
    :return: a ``pandas.DataFrame`` indexed by its ``barcode`` column (index
        name cleared), or whatever ``scanpy.read_loom`` returns.
    :raises ValueError: if the suffix is neither ``.xls`` nor ``.loom`` (this
        includes a missing ``--adata``), or if the table has no ``barcode`` column.
    """
    path_text = str(path)

    if path_text.endswith('.xls'):
        logging.info('reading xls')
        # Despite the suffix the file is parsed as tab-separated text.
        table = pd.read_csv(path, sep='\t')
        if 'barcode' not in table.columns:
            raise ValueError('xls file must contains barcode column')
        table.index = table['barcode']
        table.index.name = None
        return table

    if path_text.endswith('.loom'):
        logging.info('reading loom file ...')
        return sc.read_loom(path)

    raise ValueError('wrong adata args!')


def main():
    """Parse the command line, load the input data and run the spatial plot."""
    # The root logger defaults to WARNING, which would drop every
    # logging.info() call below.  basicConfig() does nothing when the root
    # logger already has handlers, so an existing configuration is respected.
    logging.basicConfig(level=logging.INFO)

    input_args = build_parser().parse_args()

    st_data = spatial_tools.SpatialTools(pic=input_args.pic, barcodes_pos=input_args.barcodes_pos)

    # String options that are converted through the SpatialTools helper, with
    # the extra positional arguments each one passes (same order as before).
    # NOTE(review): the helper is defined in spatial_tools; presumably the
    # extra arguments are container type and element type -- confirm there.
    structured_options = (
        ('groups', (list, str)),
        ('figsize', (tuple, float)),
        ('feature', (list, str)),
        ('color_dict', (dict,)),
        ('para_dict', (dict,)),
    )
    for option, extra_args in structured_options:
        raw_value = getattr(input_args, option)
        if raw_value:
            setattr(input_args, option,
                    st_data.resolve_para_to_list_tuple_dict(raw_value, *extra_args))

    logging.info('input args: {}'.format(input_args))

    adata = load_adata(input_args.adata)

    logging.info('plotting ...')
    st_data.s1000_spatial_plot(
        adata=adata,
        color=input_args.color,
        groups=input_args.groups,
        size=input_args.size,
        auto_size=input_args.size_auto,
        alpha=input_args.alpha,
        figsize=input_args.figsize,
        to_save=input_args.to_save,
        darken=input_args.darken,
        dpi=input_args.dpi,
        return_fig=input_args.return_fig,
        return_table=input_args.return_table,
        feature=input_args.feature,
        split=input_args.split,
        ncol=input_args.ncol,
        color_dict=input_args.color_dict,
        hspace=input_args.hspace,
        wspace=input_args.wspace,
        para_dict=input_args.para_dict)

    logging.info('done!')


if __name__ == '__main__':
    main()
