#!/usr/bin/env python

from minknow_api.manager import Manager
import plotext as plt
from collections import defaultdict
import time
import numpy as np
import argparse


def get_mux_scan_data(connection):
    """
    Collect the mux scan history reported in a position's acquisition info.

    Reads ``bream_info.mux_scan_results`` from
    ``connection.acquisition.get_acquisition_info()`` and copies every
    result's ``counts`` mapping into a plain dict.

    Returns a list of (timestamp, counts_dict) tuples in ascending
    timestamp order.
    """
    info = connection.acquisition.get_acquisition_info()

    # One (timestamp, counts) pair per mux scan result; dict() takes a
    # fresh copy of the counts mapping.
    pairs = [
        (result.mux_scan_timestamp, dict(result.counts.items()))
        for result in info.bream_info.mux_scan_results
    ]

    # sorted() is stable, so results sharing a timestamp keep the order in
    # which they were reported.
    return sorted(pairs, key=lambda pair: pair[0])


def plot_status(scan_data):
    """
    Draw a terminal dashboard with one column per sequencing position.

    Args:
        scan_data: dict mapping a position name to the list of
            (timestamp, counts_dict) tuples built by get_mux_scan_data().

    For each position the top cell is split into a 2x2 grid holding run
    details, yield figures, a log10 read-length histogram and the pore
    counts of the most recent mux scan; the bottom cell holds a stacked
    bar chart of the pore counts of every mux scan.

    Run details come from get_flowcell_yields(). When that function has
    no summary dict for a position (it stores a message string instead,
    or has no entry at all), the message is shown in place of the run
    details. A position without any mux scan results yet gets a
    placeholder instead of pore counts and no bar chart.

    Nothing is drawn when scan_data is empty.
    """
    if not scan_data:
        # No running positions: a grid with zero columns has nothing to show
        return

    yields = get_flowcell_yields()

    # Clear the plot
    plt.clf()
    plt.theme('dark')

    plt.subplots(2, len(scan_data))

    categories = ['single_pore', 'saturated', 'zero', 'unavailable', 'reserved_pore', 'other', 'multiple']
    # Display names, in the same order as `categories`
    pore_labels = ['Available', 'Saturated', 'Zero', 'Unavailable', 'Reserved', 'Other', 'Multiple']

    for n_plot, (posname, mux_data) in enumerate(scan_data.items(), start=1):
        summary = yields.get(posname)

        # Top cell of this column becomes a 2x2 grid of panels
        plt.subplot(1, n_plot).subplots(2, 2)

        plt.subplot(1, n_plot).subplot(1, 1)
        if isinstance(summary, dict):
            _draw_text_panel(f"{posname}", [
                f"Elapsed: {summary['elapsed_time']}",
                f"Sample: {summary['sample']}",
                f"Flowcell Type: {summary['flowcell_type']}",
                f"Flowcell ID: {summary['flowcell_id']}",
                f"Prep Kit: {summary['kit']}",
            ])

            plt.subplot(1, n_plot).subplot(1, 2)
            _draw_text_panel("Yield", [
                f"{summary['yield_gb']} Gbases",
                f"{summary['reads']} Reads",
                f"{summary['wrote_gb']} Gbytes",
                f"N50 {summary['n50']} bp",
            ])

            plt.subplot(1, n_plot).subplot(2, 1)
            plt.title("Log10 Count vs Read Length (excl outliers)")
            _plot_log_histogram(summary['hist_starts'], summary['hist_values'])
        else:
            # get_flowcell_yields() stores a message string when it could
            # not build a summary; show it rather than indexing into it.
            message = "No run information available" if summary is None else summary
            _draw_text_panel(f"{posname}", [str(message)])

        plt.subplot(1, n_plot).subplot(2, 2)
        if mux_data:
            latest_time, latest_counts = mux_data[-1]
            # .get(..., 0) mirrors the bar-chart code below, so a scan that
            # lacks a category shows 0 instead of raising KeyError.
            _draw_text_panel(f"Pores at {int(latest_time/60)} m", [
                f"{label + ':':<13}{latest_counts.get(cat, 0)}"
                for label, cat in zip(pore_labels, categories)
            ])
        else:
            _draw_text_panel("Pores", ["No mux scan results yet"])

        plt.subplot(2, n_plot)
        plt.plotsize(None, plt.th()//1.8)

        if mux_data:
            # One list of values per category, aligned with `times`
            values = [
                [counts.get(cat, 0) for _, counts in mux_data]
                for cat in categories
            ]

            # Convert timestamps to minutes for readability
            times = [int(t/60) for t, _ in mux_data]

            # Create stacked bar chart
            plt.stacked_bar(
                times,
                values,
                labels=categories,
                width=50
            )

        plt.xlabel("Minutes")
        plt.ylabel("Pores")

    # Show the plot
    plt.show()


def _draw_text_panel(title, lines):
    """Fill the active subplot with a title and left-aligned text lines, first line on top."""
    plt.title(title)
    top = len(lines) - 1
    for offset, line in enumerate(lines):
        plt.text(line, 0, top - offset, alignment='left', color="blue+")
    plt.xlim(0, None)
    plt.xticks([])
    plt.yticks([])


def _plot_log_histogram(starts, counts):
    """Plot log10(count) against bucket start in the active subplot, skipping empty buckets."""
    # log10(0) is -inf, so buckets without reads are left out; zip() also
    # guards against the two sequences differing in length.
    kept = [(start, count) for start, count in zip(starts, counts) if count > 0]
    if not kept:
        return
    xs = [start for start, _ in kept]
    ys = np.log10([count for _, count in kept]).tolist()
    plt.plot(xs, ys)


def get_flowcell_yields():
    """
    Summarise the current run on every flow cell position.

    Opens a new Manager connection, visits each position it reports and,
    for positions whose acquisition status is 3, gathers yield, output,
    read-length and run metadata.

    Returns:
        dict keyed by position name. For a position whose fields could be
        read the value is a dict with the keys 'hist_values',
        'hist_starts', 'yield_gb', 'wrote_gb', 'n50', 'elapsed_time',
        'reads', 'run_id', 'reads_directory', 'sample', 'flowcell_type',
        'flowcell_id' and 'kit'. Otherwise the value is a string saying
        why no summary is available (status other than 3, or the fields
        could not be read).
    """
    # Connect to the MinKNOW server
    manager = Manager()

    yields = {}

    # Iterate through each position (flowcell)
    for position in manager.flow_cell_positions():
        connection = position.connect()

        # Get the current acquisition state
        run_info = connection.acquisition.current_status()

        # Status 3 is the value this script treats as "running"
        if run_info.status != 3:
            yields[position.name] = f"No active sequencing run (Status: {run_info.status})"
            continue

        acquisition_info = connection.acquisition.get_acquisition_info()

        try:
            # Yield and output volume
            yield_info = acquisition_info.yield_summary
            read_count = yield_info.read_count
            bases = float(yield_info.estimated_selected_bases)
            gigabases = bases / 1e9
            gigabytes = float(acquisition_info.writer_summary.bytes_to_write_completed)/1e9

            n50 = connection.statistics.read_length_n50(acquisition_run_id=acquisition_info.run_id)

            # Only the first message of the histogram stream is used.
            hist_stream = connection.statistics.stream_read_length_histogram(acquisition_run_id=acquisition_info.run_id, bucket_value_type='ReadCounts', discard_outlier_percent=0.05)
            try:
                len_hist = next(hist_stream)
            finally:
                # NOTE(review): the stream is presumably a gRPC call object;
                # cancel it when possible so it is not left open after the
                # single read - confirm against minknow_api.
                cancel = getattr(hist_stream, 'cancel', None)
                if callable(cancel):
                    cancel()

            # Extract values from histogram_data
            hist_values = list(len_hist.histogram_data)[0].bucket_values
            hist_starts = [br.start if hasattr(br, 'start') else 0 for br in len_hist.bucket_ranges]

            # Elapsed wall-clock time since the acquisition started
            start_time = acquisition_info.start_time.seconds
            run_time_seconds = time.time() - start_time
            hours = int(run_time_seconds // 3600)
            minutes = int((run_time_seconds % 3600) // 60)

            prot_info = connection.protocol.get_run_info()
            user_info = prot_info.user_info

            yields[position.name] = {
                'hist_values': hist_values,
                'hist_starts': hist_starts,
                'yield_gb': round(gigabases, 2),
                'wrote_gb': round(gigabytes, 2),
                'n50': int(n50.n50_data.estimated_n50),
                'elapsed_time': f"{hours:02d}h:{minutes:02d}m",
                'reads': read_count,
                'run_id': acquisition_info.run_id,
                'reads_directory': acquisition_info.config_summary.reads_directory,
                'sample': user_info.sample_id.value,
                'flowcell_type': prot_info.flow_cell.product_code,
                'flowcell_id': prot_info.flow_cell.flow_cell_id,
                'kit': prot_info.meta_info.tags['kit'].string_value
            }
        except (AttributeError, IndexError, StopIteration) as e:
            # IndexError: the histogram message carried no histogram_data;
            # StopIteration: the histogram stream produced no message.
            print(f"Error accessing fields: {e}")
            yields[position.name] = f"Error accessing acquisition info fields: {e}"

    return yields


def monitor(args):
    """
    Poll every flow cell position and redraw the status plot until interrupted.

    Args:
        args: parsed command-line arguments; ``args.interval`` is the number
            of seconds to wait between refreshes (converted with int()).

    Each cycle collects mux scan data for the positions whose acquisition
    status is 3, hands the result to plot_status() and then sleeps. A
    failure on one position is printed and the remaining positions are
    still processed. KeyboardInterrupt ends the loop with a message.
    """
    interval = int(args.interval)
    # Connect to the MinKNOW server
    manager = Manager()

    try:
        while True:
            print(f"\nChecking mux scans at: {time.strftime('%Y-%m-%d %H:%M:%S')}")

            scan_data = {}

            for position in manager.flow_cell_positions():
                try:
                    connection = position.connect()
                    status = connection.acquisition.current_status()

                    # Anything other than status 3 is reported and skipped
                    if status.status != 3:
                        print(f"\nPosition {position.name}: No active sequencing run")
                        continue

                    print(f"\nPosition: {position.name}")
                    scan_data[position.name] = get_mux_scan_data(connection)

                except Exception as e:
                    print(f"Error with position {position.name}: {str(e)}")

            # Redraw with whatever was gathered this cycle
            plot_status(scan_data)

            time.sleep(interval)

    except KeyboardInterrupt:
        print("\nMonitoring stopped by user")

if __name__ == "__main__":
    # Command-line entry point: read the refresh interval and start monitoring.
    parser = argparse.ArgumentParser(description='monitor minknow via command line')
    # type=int lets argparse reject a non-numeric interval with a usage
    # error instead of a traceback from the int() call inside monitor().
    parser.add_argument('-i', '--interval', type=int, default=300, help='update interval (seconds, default = 300)')

    args = parser.parse_args()
    monitor(args)
