#!python

import os
import glob
import argparse
import logging
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Configure the root logger once at import time: INFO and above, with each
# record rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)


def convert_tab_to_csv(tab_file, csv_file):
    """Rewrite a tab-separated table as a comma-separated file.

    Args:
        tab_file (str): Path of the tab-delimited input file.
        csv_file (str): Destination path for the CSV output.

    Raises:
        Exception: Any error raised while reading or writing is logged and
            then re-raised unchanged.
    """
    try:
        # The DataFrame index is not part of the data, so it is not written.
        pd.read_csv(tab_file, sep="\t").to_csv(csv_file, index=False)
    except Exception as err:
        logging.error(f"Error converting {tab_file} to CSV: {err}")
        raise
    else:
        logging.info(f"Converted {tab_file} to {csv_file}")


def load_and_merge_data(ncbi_clean_path, abricate_summary_file):
    """Load the NCBI table and an Abricate summary and left-join them.

    The join key "Assembly Accession" is parsed out of the summary's file-name
    column ("#File" or "#FILE"). Both RefSeq (``GCF_``) and GenBank (``GCA_``)
    accessions are recognised. Summary rows whose file name contains no
    accession are dropped before merging, because pandas would otherwise
    match their missing key against missing keys in the NCBI table.

    Args:
        ncbi_clean_path (str): CSV file with an "Assembly Accession" column.
        abricate_summary_file (str): Abricate summary in CSV form.

    Returns:
        pandas.DataFrame: The NCBI rows left-joined with the summary columns;
        every missing value is replaced by the string "0".

    Raises:
        KeyError: If the summary has neither a "#File" nor a "#FILE" column.
        Exception: Any other read/merge error is logged and re-raised.
    """
    try:
        ncbi_clean_df = pd.read_csv(ncbi_clean_path)
        abricate_summary_df = pd.read_csv(abricate_summary_file)

        # Locate the file-name column; fail with a message naming the file.
        file_col = next(
            (name for name in ("#File", "#FILE") if name in abricate_summary_df.columns),
            None,
        )
        if file_col is None:
            raise KeyError(
                f"No '#File' or '#FILE' column in {abricate_summary_file}"
            )

        # Extract the assembly accession (GCF_ = RefSeq, GCA_ = GenBank).
        accessions = abricate_summary_df[file_col].str.extract(
            r"(GC[AF]_\d+\.\d+)", expand=False
        )
        abricate_summary_df["Assembly Accession"] = accessions

        # Rows without an accession cannot be attributed to any assembly.
        without_accession = accessions.isna()
        if without_accession.any():
            logging.warning(
                f"{int(without_accession.sum())} row(s) in {abricate_summary_file} "
                "have no assembly accession in their file name and were ignored."
            )
            abricate_summary_df = abricate_summary_df[~without_accession]

        merged_df = pd.merge(
            ncbi_clean_df, abricate_summary_df, on="Assembly Accession", how="left"
        )
        return merged_df.fillna("0")
    except Exception as e:
        logging.error(f"Error loading or merging data: {e}")
        raise


def save_merged_data(merged_df, output_dir, output_filename):
    """Write the merged table to ``output_dir/output_filename`` as CSV.

    The output directory is created first when it does not exist yet, and the
    DataFrame index is not written to the file.
    """
    destination = os.path.join(output_dir, output_filename)
    os.makedirs(output_dir, exist_ok=True)
    merged_df.to_csv(destination, index=False)
    logging.info(f"Merged file saved to: {destination}")


def convert_to_tidy_format(df):
    """Reshape the merged table to one row per (record, gene).

    The first 13 columns are kept as identifier columns; every remaining
    column is treated as a gene and melted into "Gene"/"Identity" pairs.

    "Identity" is the numeric value of the cell. A cell holding several
    ``;``-separated values (how Abricate summaries report multiple hits of
    the same gene) yields the highest of them; anything non-numeric (such as
    the ``.`` placeholder) becomes 0. "Presence" is 1 where Identity is
    non-zero and 0 otherwise.

    Args:
        df (pandas.DataFrame): Merged table with at least 14 columns.

    Returns:
        pandas.DataFrame: The identifier columns plus "Gene", "Identity" and
        "Presence".

    Raises:
        ValueError: If ``df`` has fewer than 14 columns.
        Exception: Any other error is logged and re-raised.
    """
    try:
        if df.shape[1] < 14:
            raise ValueError("Not enough columns in the dataframe.")

        # Metadata columns come first; everything after them is a gene column.
        id_vars = list(df.columns[:13])
        df_tidy = df.melt(id_vars=id_vars, var_name="Gene", value_name="Identity")

        identity = pd.to_numeric(df_tidy["Identity"], errors="coerce")

        # Cells that are not a single number may hold several hits
        # ("100.00;98.50"); without this step they would be coerced to 0 and
        # the gene would be reported as absent.
        unparsed = identity.isna()
        if unparsed.any():
            pieces = (
                df_tidy.loc[unparsed, "Identity"].astype(str).str.split(";", expand=True)
            )
            best_hit = pieces.apply(pd.to_numeric, errors="coerce").max(axis=1)
            identity = identity.fillna(best_hit)

        df_tidy["Identity"] = identity.fillna(0)
        df_tidy["Presence"] = (df_tidy["Identity"] != 0).astype(int)

        return df_tidy
    except Exception as e:
        logging.error(f"Error converting to tidy format: {e}")
        raise


def analyze_gene_presence(df, output_dir, base_name, fig_format):
    """Summarise per-gene presence, save it as CSV and draw both gene plots.

    Presence is summed per gene and sorted from highest to lowest. The
    "Percentage" column is that total divided by the number of distinct
    "BioSample" values in ``df``, times 100.

    Args:
        df (pandas.DataFrame): Tidy table with "Gene", "Presence" and
            "BioSample" columns.
        output_dir (str): Directory receiving the CSV and the figures.
        base_name (str): Prefix for every output file name.
        fig_format (str): Figure format forwarded to the plot functions.

    Raises:
        Exception: Any error is logged and re-raised.
    """
    try:
        # Per-gene totals, most frequent gene first.
        presence_by_gene = (
            df.groupby("Gene")["Presence"]
            .sum()
            .reset_index()
            .sort_values(by="Presence", ascending=False)
        )

        # Express each total relative to the number of distinct BioSamples.
        biosample_count = df["BioSample"].nunique()
        presence_by_gene["Percentage"] = (presence_by_gene["Presence"] / biosample_count) * 100

        # Persist the table before plotting it.
        count_percentage_file = os.path.join(
            output_dir, f"{base_name}_gene_presence_count_percentage.csv"
        )
        presence_by_gene.to_csv(count_percentage_file, index=False)
        logging.info(f"Count and percentages saved to {count_percentage_file}")

        # Lollipop plot of the counts, then bar plot of the percentages.
        generate_lollipop_plot(presence_by_gene, output_dir, base_name, fig_format)
        generate_percentage_bar_plot(presence_by_gene, output_dir, base_name, fig_format)
    except Exception as err:
        logging.error(f"Error analyzing gene presence: {err}")
        raise


def generate_lollipop_plot(gene_presence, output_dir, base_name, fig_format):
    """Draw and save a lollipop (stem) plot of the per-gene presence totals.

    Args:
        gene_presence (pandas.DataFrame): Table with "Gene" and "Presence"
            columns; genes are drawn in row order.
        output_dir (str): Directory for the figure (not created here).
        base_name (str): Prefix of the output file name.
        fig_format (str): File extension and format passed to ``savefig``.
    """
    lollipop_output_file = os.path.join(output_dir, f"{base_name}_gene_presence_lollipop.{fig_format}")

    fig = plt.figure(figsize=(12, 8))
    try:
        plt.stem(gene_presence["Gene"], gene_presence["Presence"], basefmt=" ")
        plt.xlabel("Gene", fontsize=14)
        plt.ylabel("Total Presence", fontsize=14)
        plt.xticks(rotation=90)
        # Leave 20% headroom above the tallest stem for the value labels.
        plt.ylim(0, gene_presence["Presence"].max() * 1.2)

        # Write each count just above its lollipop head.
        for i, presence in enumerate(gene_presence["Presence"]):
            plt.text(i, presence + 2, str(presence), ha="center", va="bottom", fontsize=12, rotation=90)

        plt.tight_layout()
        plt.savefig(lollipop_output_file, format=fig_format, dpi=300)
    finally:
        # Release the figure even when drawing or saving fails, so a failing
        # input does not leave an open figure behind.
        plt.close(fig)

    logging.info(f"Lollipop plot saved to {lollipop_output_file}")


def generate_percentage_bar_plot(gene_presence, output_dir, base_name, fig_format):
    """Draw and save a bar plot of the per-gene presence percentages.

    Args:
        gene_presence (pandas.DataFrame): Table with "Gene" and "Percentage"
            columns; genes are drawn in row order.
        output_dir (str): Directory for the figure (not created here).
        base_name (str): Prefix of the output file name.
        fig_format (str): File extension and format passed to ``savefig``.
    """
    percentage_output_file = os.path.join(output_dir, f"{base_name}_gene_presence_percentage.{fig_format}")

    fig = plt.figure(figsize=(12, 8))
    try:
        sns.barplot(data=gene_presence, x="Gene", y="Percentage", color="skyblue")
        plt.xlabel("Gene", fontsize=14)
        plt.ylabel("Prevalence", fontsize=14)
        plt.xticks(rotation=90)
        # Leave 20% headroom above the tallest bar for the value labels.
        plt.ylim(0, gene_presence["Percentage"].max() * 1.2)

        # Write each percentage just above its bar.
        for i, percentage in enumerate(gene_presence["Percentage"]):
            plt.text(i, percentage + 0.8, f"{percentage:.1f}%", ha="center", va="bottom", fontsize=12, rotation=90)

        plt.tight_layout()
        plt.savefig(percentage_output_file, format=fig_format, dpi=300)
    finally:
        # Release the figure even when drawing or saving fails, so a failing
        # input does not leave an open figure behind.
        plt.close(fig)

    logging.info(f"Percentage plot saved to {percentage_output_file}")


def generate_country_comparison_heatmap(tidy_file, output_dir, fig_format):
    """Plot gene prevalence per geographic location as a heatmap.

    Steps:
      1. Read ``tidy_file`` and keep the distinct (Assembly Accession,
         Geographic Location, Gene, Presence) rows.
      2. Keep only locations with more than 4 distinct assemblies.
      3. For each location/gene pair compute
         ``Percentage = summed Presence / assemblies in that location * 100``
         and save that table as ``<base_name>_gene_percentage.csv``.
      4. Keep the location/gene cells above 10%, pivot them into a
         location x gene matrix (missing cells become 0), append an
         "Overall" row and save the heatmap.

    The function returns early, with a warning and without drawing, when no
    location passes step 2 or no cell passes the 10% threshold.

    NOTE(review): the 10% threshold is applied per cell, so a gene that is
    kept because of one location is drawn as 0 in locations where it is at
    or below 10% -- confirm that this masking is intended.

    Args:
        tidy_file (str): Tidy CSV; ``_tidy_summary.csv`` is stripped from its
            file name to form the output prefix.
        output_dir (str): Directory for the CSV and the figure; created if
            missing.
        fig_format (str): File extension and format passed to ``savefig``.

    Raises:
        Exception: Any error is logged and re-raised.
    """
    try:
        # Output prefix, e.g. "A" from "A_tidy_summary.csv".
        base_name = os.path.basename(tidy_file).replace("_tidy_summary.csv", "")

        key_columns = ['Assembly Accession', 'Geographic Location', 'Gene', 'Presence']
        df = pd.read_csv(tidy_file)[key_columns].drop_duplicates()

        # Distinct assemblies per location; only locations with more than 4
        # are compared.
        assemblies_per_location = df.groupby('Geographic Location')['Assembly Accession'].nunique()
        assemblies_per_location = assemblies_per_location[assemblies_per_location > 4]
        filtered_df = df[df['Geographic Location'].isin(assemblies_per_location.index)]

        if filtered_df.empty:
            logging.warning(f"No data for {base_name} after filtering locations. Skipping.")
            return

        # Presence totals per location/gene, relative to the location's size.
        grouped_df = filtered_df.groupby(['Geographic Location', 'Gene'], as_index=False)['Presence'].sum()
        assembly_count_per_location = assemblies_per_location.rename('Assembly Count').reset_index()
        merged_df = pd.merge(grouped_df, assembly_count_per_location, on='Geographic Location')
        merged_df['Percentage'] = (merged_df['Presence'] / merged_df['Assembly Count']) * 100

        os.makedirs(output_dir, exist_ok=True)
        csv_file = os.path.join(output_dir, f"{base_name}_gene_percentage.csv")
        merged_df.to_csv(csv_file, index=False)
        logging.info(f"Gene percentage CSV saved to {csv_file}")

        # Make every location/gene combination explicit (absent ones are 0%).
        complete_grid = pd.MultiIndex.from_product(
            [merged_df['Geographic Location'].unique(), merged_df['Gene'].unique()],
            names=['Geographic Location', 'Gene'],
        ).to_frame(index=False)
        merged_df = pd.merge(complete_grid, merged_df, on=['Geographic Location', 'Gene'], how='left')
        merged_df['Percentage'] = merged_df['Percentage'].fillna(0)

        # Keep the cells above 10%.
        filtered_genes_df = merged_df[merged_df['Percentage'] > 10]
        if filtered_genes_df.empty:
            logging.warning(f"No data for {base_name} after filtering genes. Skipping.")
            return

        heatmap_data = (
            filtered_genes_df
            .pivot(index='Geographic Location', columns='Gene', values='Percentage')
            .fillna(0)
            .round(0)
            .astype(int)
        )
        if heatmap_data.empty:
            logging.warning(f"No data for {base_name} after pivoting. Skipping.")
            return

        # Overall prevalence of each displayed gene across ALL assemblies of
        # the compared locations. The totals come from the unfiltered
        # per-location sums so that numerator and denominator cover the same
        # assemblies (summing only the >10% cells would understate it).
        total_assembly_count = filtered_df['Assembly Accession'].nunique()
        gene_totals = grouped_df.groupby('Gene')['Presence'].sum()
        overall_row = (
            gene_totals.reindex(heatmap_data.columns) / total_assembly_count * 100
        ).to_frame('Overall').T

        # Label each location with its assembly count, e.g. "Kenya (12)",
        # then append the "Overall" row.
        location_labels = {
            loc: f"{loc} ({count})" for loc, count in assemblies_per_location.items()
        }
        heatmap_data = pd.concat([heatmap_data.rename(index=location_labels), overall_row])

        heatmap_output_file = os.path.join(output_dir, f"{base_name}_country_comparison_heatmap.{fig_format}")

        # seaborn's style helpers modify matplotlib's global rcParams.
        # rc_context() restores them afterwards so figures drawn later (for
        # the next input file) are not restyled as a side effect.
        with plt.rc_context():
            fig = plt.figure(figsize=(20, 10))
            try:
                sns.set_style("white")
                # NOTE(review): sns.set() applies seaborn's default theme and
                # therefore appears to override the style set just above;
                # kept in the original order to preserve the figure's look.
                sns.set(font_scale=1.2)

                sns.heatmap(
                    heatmap_data,
                    annot=True,
                    fmt=".0f",
                    cmap="Reds",
                    linewidths=0.5,
                    cbar_kws={"shrink": 0.8},
                    vmin=0,
                    vmax=100,
                )

                plt.xlabel("Genes/Plasmids", fontsize=14)
                plt.ylabel("Geographic Location (Total Submitted Sequence)", fontsize=14)
                plt.xticks(rotation=90)
                plt.yticks(rotation=0)
                plt.tight_layout()

                plt.savefig(heatmap_output_file, format=fig_format, dpi=300, bbox_inches='tight')
            finally:
                # Release the figure even when drawing or saving fails.
                plt.close(fig)

        logging.info(f"Heatmap saved to {heatmap_output_file}")
    except Exception as e:
        logging.error(f"Error generating country comparison heatmap: {e}")
        raise


def main(ncbi_dir, abricate_dir, output_dir, fig_format):
    """Process every Abricate summary in ``abricate_dir`` against the NCBI table.

    For each ``*.csv`` or ``*.tab`` summary the function converts ``.tab``
    input to CSV next to the source file, merges it with ``ncbi_clean.csv``,
    saves the merged and tidy tables under ``output_dir/merged_output`` and
    writes the figures under ``output_dir/figures``. A failure on one summary
    is logged and does not stop the remaining ones.

    Args:
        ncbi_dir (str): Directory containing ``ncbi_clean.csv``.
        abricate_dir (str): Directory containing the Abricate summaries.
        output_dir (str): Base output directory.
        fig_format (str): Figure format passed on to the plot functions.

    Returns:
        None. Missing inputs are reported through the log and end the run
        before any output directory is created.
    """
    logging.info("Starting the script.")

    ncbi_clean_path = os.path.join(ncbi_dir, "ncbi_clean.csv")
    if not os.path.exists(ncbi_clean_path):
        logging.error(f"ncbi_clean.csv not found in {ncbi_dir}.")
        return

    # Map each CSV path to be processed to the file it comes from. Only the
    # exact extensions .csv and .tab are accepted. A .tab file takes the
    # place of a same-named .csv (the conversion overwrites that CSV anyway),
    # so no summary is processed twice.
    summary_sources = {}
    for csv_path in glob.glob(os.path.join(abricate_dir, "*.csv")):
        summary_sources[csv_path] = csv_path
    for tab_path in glob.glob(os.path.join(abricate_dir, "*.tab")):
        summary_sources[os.path.splitext(tab_path)[0] + ".csv"] = tab_path

    if not summary_sources:
        logging.error(f"No CSV or TAB files found in {abricate_dir}.")
        return

    # Create output subdirectories
    merged_output_dir = os.path.join(output_dir, "merged_output")
    figures_dir = os.path.join(output_dir, "figures")
    os.makedirs(merged_output_dir, exist_ok=True)
    os.makedirs(figures_dir, exist_ok=True)

    summary_suffix = "_summary"
    failed_files = []
    for csv_file, source_file in sorted(summary_sources.items()):
        try:
            # Convert .tab to .csv if necessary
            if source_file != csv_file:
                convert_tab_to_csv(source_file, csv_file)

            merged_df = load_and_merge_data(ncbi_clean_path, csv_file)

            file_stem = os.path.splitext(os.path.basename(csv_file))[0]
            save_merged_data(merged_df, merged_output_dir, f"ncbi_{file_stem}.csv")

            # "card_summary" -> "card"; names without the suffix are used
            # as they are.
            if file_stem.endswith(summary_suffix):
                base_name = file_stem[:-len(summary_suffix)]
            else:
                base_name = file_stem

            # The tidy file always gets its own name, so it can never
            # overwrite the merged file saved above.
            tidy_df = convert_to_tidy_format(merged_df)
            tidy_file = os.path.join(merged_output_dir, f"ncbi_{base_name}_tidy_summary.csv")
            tidy_df.to_csv(tidy_file, index=False)
            logging.info(f"Tidied file saved to {tidy_file}")

            analyze_gene_presence(tidy_df, figures_dir, base_name, fig_format)
            generate_country_comparison_heatmap(tidy_file, figures_dir, fig_format)
        except Exception as e:
            # Keep going: one bad summary should not block the others.
            logging.error(f"Error processing {source_file}: {e}")
            failed_files.append(source_file)

    if failed_files:
        logging.warning(
            f"Script completed with errors in {len(failed_files)} of "
            f"{len(summary_sources)} file(s): {', '.join(failed_files)}"
        )
    else:
        logging.info("Script completed successfully.")


if __name__ == "__main__":
    # Command-line interface: three required directories and the figure format.
    parser = argparse.ArgumentParser(description="Process NCBI and Abricate data.")
    for flag, help_text in (
        ("--ncbi-dir", "Directory containing ncbi_clean.csv."),
        ("--abricate-dir", "Directory containing Abricate summary CSV or TAB files."),
        ("--output-dir", "Base output directory."),
    ):
        parser.add_argument(flag, required=True, help=help_text)
    parser.add_argument(
        "--format",
        choices=["tiff", "svg", "png", "pdf"],
        default="tiff",
        help="Output format for figures (tiff, svg, png, pdf).",
    )
    args = parser.parse_args()

    # Hand the parsed options to the pipeline.
    main(args.ncbi_dir, args.abricate_dir, args.output_dir, args.format)
