#!/usr/bin/env python
import argparse
import ast
import logging
import os
import sys

import numpy as np
import pandas as pd
import scipy.io

# Append "<directory of this script>/../python" (resolved to a real path) to
# the import search path before the ``upsp`` imports below run.
# NOTE(review): presumably this is where the ``upsp`` package lives in a
# source checkout -- confirm against the repository layout.
sys.path.append(
    os.path.realpath(
        os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "python")
    )
)

import warnings  # noqa

# Ignore all warnings while the upsp modules are being imported.
warnings.filterwarnings("ignore")
import upsp.processing.context as context  # noqa
import upsp.kulite_comparison.selection as selection  # noqa
import upsp.kulite_comparison.plotting as plotting  # noqa

# Switch the warning action back to "default" once the imports are done.
# NOTE(review): filterwarnings() adds a new catch-all filter entry rather than
# restoring whatever filters were active before the "ignore" call -- confirm
# that this is the intended end state.
warnings.filterwarnings("default")

# Module-level logger; main() configures logging via logging.basicConfig().
log = logging.getLogger(__name__)


def _read_pressure_transpose(dp_ctx: context.Datapoint, node_index):
    """Return the pressure time history of a single model grid node.

    Both input files are located through ``dp_ctx.output_path("psp_process", ...)``.
    Neither the node count nor the frame count is passed in; they are derived
    from file sizes, treating both files as flat arrays of 32-bit values:

    - the "X" file is taken to hold one value per grid node, and
    - the "pressure_transpose" file is read as ``number_grid_nodes`` consecutive
      runs of ``number_frames`` float32 samples, one run per node.

    Parameters
    ----------
    dp_ctx:
        Datapoint context; only its ``output_path`` method is used here.
    node_index:
        Zero-based index of the grid node whose history is wanted.

    Returns
    -------
    numpy.ndarray
        1-D ``float32`` array of length ``number_frames``.

    Raises
    ------
    ValueError
        If the "X" file holds no nodes, or if the pressure file size is not a
        whole number of frames for that node count.
    IndexError
        If ``node_index`` is outside ``[0, number_grid_nodes)``.
    """
    grid_x_filename = dp_ctx.output_path("psp_process", "X")
    pressure_filename = dp_ctx.output_path("psp_process", "pressure_transpose")
    item_size_bytes = 4  # 32-bit floating point values

    # Integer arithmetic throughout: file sizes are exact byte counts, so there
    # is no reason to round-trip them through floating point.
    number_grid_nodes = os.path.getsize(grid_x_filename) // item_size_bytes
    if number_grid_nodes == 0:
        raise ValueError(
            "Grid file %r contains no nodes; cannot infer the frame count"
            % (grid_x_filename,)
        )

    pressure_size_bytes = os.path.getsize(pressure_filename)
    number_frames, leftover_bytes = divmod(
        pressure_size_bytes, number_grid_nodes * item_size_bytes
    )
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if leftover_bytes:
        raise ValueError(
            "Size of %r (%d bytes) is not a whole number of frames for %d grid nodes"
            % (pressure_filename, pressure_size_bytes, number_grid_nodes)
        )

    # Without this guard an out-of-range index would seek outside the file
    # instead of failing with a clear message.
    if not 0 <= node_index < number_grid_nodes:
        raise IndexError(
            "node_index %r is out of range for %d grid nodes"
            % (node_index, number_grid_nodes)
        )

    # Each node owns ``number_frames`` consecutive samples; skip the runs of
    # all preceding nodes.
    offset = node_index * number_frames * item_size_bytes
    return np.fromfile(
        pressure_filename, dtype=np.float32, count=number_frames, offset=offset
    )


def dump_vertices(args):
    """Export per-vertex pressure time histories to a MATLAB ``.mat`` file.

    Reads the spreadsheet at ``args.input`` and, for every row, evaluates the
    "Selection Vertices" cell as a Python literal and loads one time history
    per listed vertex via ``_read_pressure_transpose``. The row's
    "Pipeline Directory", "Datapoint", "Kulite Name", "Direction" and
    "Distance (in)" values are carried through unchanged.

    The collected columns are written to ``args.output`` with
    ``scipy.io.savemat`` under the keys ``pipelineDir``, ``datapoint``,
    ``kuliteName``, ``direction``, ``distance``, ``selectionVertexIndices``
    and ``selectionVertexDeltaCp``; the last two are stored as object arrays.
    """
    queries = pd.read_excel(args.input, dtype={"Datapoint": str})
    total_queries = len(queries.index)

    pipeline_dirs = []
    datapoints = []
    kulite_names = []
    directions = []
    distances = []
    vertex_indices = []
    vertex_histories = []

    for row_label, query in queries.iterrows():
        pipeline_dir = query["Pipeline Directory"]
        datapoint = query["Datapoint"]
        # The cell holds the textual form of a Python list of vertex indices.
        node_ids = ast.literal_eval(query["Selection Vertices"])
        dp_ctx = context.Pipeline(pipeline_dir).datapoint(datapoint)
        histories = [_read_pressure_transpose(dp_ctx, node) for node in node_ids]

        pipeline_dirs.append(pipeline_dir)
        datapoints.append(datapoint)
        kulite_names.append(query["Kulite Name"])
        directions.append(query["Direction"])
        distances.append(query["Distance (in)"])
        vertex_indices.append(node_ids)
        vertex_histories.append(histories)
        log.info(
            "Dumped pressure-time histories for query (%d/%d)",
            row_label + 1,
            total_queries,
        )

    # Key order matches the order in which the variables are handed to savemat.
    mat_contents = {
        "pipelineDir": pipeline_dirs,
        "datapoint": datapoints,
        "kuliteName": kulite_names,
        "direction": directions,
        "distance": distances,
        "selectionVertexIndices": np.asarray(vertex_indices, dtype=object),
        "selectionVertexDeltaCp": np.asarray(vertex_histories, dtype=object),
    }
    scipy.io.savemat(args.output, mat_contents)


def plot_vertices(args):
    """Generate selection-area images for the spreadsheet at ``args.input``.

    The sheet is loaded with the "Datapoint" column forced to ``str`` and the
    resulting frame is passed straight to
    ``plotting.make_selection_area_images``.
    """
    plotting.make_selection_area_images(
        pd.read_excel(args.input, dtype={"Datapoint": str})
    )


def select_vertices(args):
    """Run every vertex-selection query in ``args.input`` and save the results.

    The input spreadsheet is read with the "Datapoint" column forced to
    ``str``. Each row supplies a "Pipeline Directory", "Datapoint", "Kulite",
    "Direction", "Distance (in)" and "# nodes"; the direction is lower-cased
    before use. One ``selection.KuliteNeighborhoodSearch`` is built per
    distinct (pipeline directory, datapoint) pair and reused for all rows of
    that pair, so the indexing done at construction is not repeated.

    The result table is written to ``args.output`` with
    ``DataFrame.to_excel``. Rows appear grouped by (pipeline directory,
    datapoint), which is not necessarily the input row order.

    ``args.duplicate_vertex_tol`` is forwarded unchanged to each
    ``search.query`` call.
    """
    df = pd.read_excel(args.input, dtype={"Datapoint": str})
    results = {
        "Pipeline Directory": [],
        "Datapoint": [],
        "Kulite Name": [],
        "Direction": [],
        "Distance (in)": [],
        "Kulite Nearest Vertex": [],
        "Selection Vertices": [],
    }
    # Group on the two columns themselves rather than on a "dir-datapoint"
    # string: a joined key is ambiguous when either part contains "-"
    # (("a-b", "c") and ("a", "b-c") would land in the same group and be
    # queried against the wrong datapoint).
    group_columns = ["Pipeline Directory", "Datapoint"]
    for (pipeline_dir, datapoint_name), df_query in df.groupby(group_columns):
        dp_ctx = context.Pipeline(pipeline_dir).datapoint(datapoint_name)
        search = selection.KuliteNeighborhoodSearch(dp_ctx)
        for _, q in df_query.iterrows():
            direction = q["Direction"].lower()
            displacement = q["Distance (in)"]
            number_vertices = q["# nodes"]
            kulite_name = q["Kulite"]
            res = search.query(
                kulite_name,
                direction,
                displacement,
                number_vertices,
                duplicate_vertex_tol=args.duplicate_vertex_tol,
            )
            results["Pipeline Directory"].append(pipeline_dir)
            results["Datapoint"].append(datapoint_name)
            results["Kulite Name"].append(kulite_name)
            results["Direction"].append(direction)
            results["Distance (in)"].append(displacement)
            results["Kulite Nearest Vertex"].append(res["Kulite Nearest Vertex"])
            results["Selection Vertices"].append(res["Selection Vertices"])
    out = pd.DataFrame(results)
    out.to_excel(args.output)


def main():
    """Command-line entry point for the uPSP-Kulite comparison utility.

    Configures logging at INFO level, builds the argument parser and runs the
    handler of the chosen subcommand:

    - ``select-vertices INPUT OUTPUT [--duplicate_vertex_tol TOL]``
    - ``plot-vertices INPUT``
    - ``dump-vertices INPUT [OUTPUT]`` (OUTPUT defaults to
      ``dump_vertices.mat``)

    Invoking the program without a subcommand prints a usage error and exits
    instead of failing on the missing handler.
    """
    logging.basicConfig(level=logging.INFO)
    ap = argparse.ArgumentParser(
        prog="upsp-kulite-comparison", description="uPSP-Kulite comparison utility"
    )
    sp = ap.add_subparsers(title="Available subcommands")

    select_vertices_sp = sp.add_parser("select-vertices")
    select_vertices_sp.add_argument("input", help="query input (*.xlsx)")
    select_vertices_sp.add_argument("output", help="query result (*.xlsx)")
    select_vertices_sp.add_argument(
        "--duplicate_vertex_tol",
        default=0.02,
        type=float,
        # A single literal: the previous two-part list lacked a comma, so the
        # parts were concatenated without a space ("verticesare").
        help=(
            "For each query row, ensure selected vertices "
            "are at least this far apart (default=0.02)"
        ),
    )
    select_vertices_sp.set_defaults(func=select_vertices)

    plot_vertices_sp = sp.add_parser("plot-vertices")
    plot_vertices_sp.add_argument("input", help="select-vertices output (*.xlsx)")
    plot_vertices_sp.set_defaults(func=plot_vertices)

    dump_vertices_sp = sp.add_parser("dump-vertices")
    dump_vertices_sp.add_argument("input", help="select-vertices output (*.xlsx)")
    dump_vertices_sp.add_argument(
        "output",
        # nargs="?" makes the positional optional; without it argparse always
        # requires the argument and the default below is never used.
        nargs="?",
        help="output filename (default='dump_vertices.mat')",
        default="dump_vertices.mat",
    )
    dump_vertices_sp.set_defaults(func=dump_vertices)

    args = ap.parse_args()
    # No subcommand means no ``func`` default was set; report it as a usage
    # error rather than raising AttributeError.
    if getattr(args, "func", None) is None:
        ap.error("a subcommand is required")
    args.func(args)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
