#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
# ATLAS FITTING

This is a CLI script to apply an atlas fitting to a given spectral image.

The purpose of this script is to select the lines in the atlas and in the data set that are used for the dispersion fit.
To check the result, the script presents a plot of the fitted continuum, the fitted dispersion and the
remaining chi²-error. It also writes the line positions used to a YAML file (`atlas_fit_lines.yaml`), so that the
same configuration can be applied to a calibration set created by the spectroflat library. A previously written
line file can be passed as the optional LINEFILE_PATH argument to skip the interactive selection.

## Configuration
The configuration is given as a YAML file. See examples or `atlas_fit/config.py`
for documentation on parameters.

## Execution
If the script is executable (preferred) you can run it directly as `./bin/prepare --help`.
The script should automatically be marked as executable after git checkout.
If this is not the case and/or you do not want this, you can execute it with
your favorite python interpreter as `python3 bin/prepare --help`.

See `--help` switch for further usage details.

## Author(s)
hoelken@mps.mpg.de
"""
import copy
import logging
import os
import sys
from argparse import ArgumentParser

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import yaml
from mpl_point_clicker import clicker
from scipy.ndimage import zoom
from scipy.signal import resample

sys.path.insert(0, os.path.join(sys.path[0], '..'))
from src.atlas_fit import Config, Comperator
from src.atlas_fit.models import Spectrum
from src.atlas_fit.base.constants import NM2ANGSTROM
from src.atlas_fit.models.model_convolution import ModelPSF
from src.atlas_fit.models.differential_convolution import FWHM, DifferentialSpectralPSF
from src.atlas_fit.fitting.parameter_fit import ParameterFit
from src.atlas_fit.utils import parse_shape, read_fits_data, find_nearest
from src.atlas_fit.utils.tailoring import select_best_atlas_range, select_data_fit_region
from src.atlas_fit.utils.corrections import compute_continuum_fit_error
from src.atlas_fit.atlantes import atlas_factory

# Resolution (dots per inch) used for every figure of this script.
DPI = 120
# Margin in pixels cut from both ends of the fitted spectra (see the [BORDER:-BORDER] slices).
BORDER = 10
# Figure size in inches, chosen so that figures render at 1200 x 900 pixels.
FIGSIZE = tuple(pixels / DPI for pixels in (1200, 900))
# x positions clicked by the user, one list per clicker class.
points = dict(atlas=[], data=[])


def point_added(position: tuple, clazz: str):
    """Remember the x coordinate of a point the user just clicked.

    Used as the ``on_point_added`` callback of the clickers; the y coordinate
    is ignored.

    :param position: (x, y) data coordinates of the new point.
    :param clazz: clicker class of the point, i.e. a key of the module-level ``points``.
    """
    x_coord, _y_coord = position
    selection = points[clazz]
    selection.append(x_coord)


def point_removed(position: tuple, clazz: str, idx):
    """Forget a point the user removed from one of the plots.

    Used as the ``on_point_removed`` callback of the clickers.

    :param position: (x, y) data coordinates of the removed point (only unpacked).
    :param clazz: clicker class of the point, i.e. a key of the module-level ``points``.
    :param idx: position of the removed point within ``points[clazz]``.
    """
    _x_coord, _y_coord = position
    del points[clazz][idx]


def select_lines_from_plot():
    """Let the user click matching line cores in the atlas and in the data.

    Loads the atlas for the configured wavelength range, degrades it either with
    the configured model PSF or with a differential FWHM convolution, and adds the
    configured stray-light offset. It then shows the atlas (top) and the
    module-level ``spectrum`` (bottom) with point clickers attached.

    After the window is closed, the clicked positions collected in the
    module-level ``points`` dict are post-processed in place:

    * ``points['data']``: rounded to integer pixel positions, then sorted.
    * ``points['atlas']``: mapped onto ``atlas.wl`` via ``find_nearest``, then sorted.

    Reads the module-level ``config`` and ``spectrum``. It may call
    ``config.atlas.compute_fwhm``, which presumably populates ``config.atlas.fwhm``.
    Returns nothing.
    """
    atlas = atlas_factory(config.atlas.key, config.input.wl_start, config.input.wl_end, conversion=NM2ANGSTROM)
    if config.input.model_psf is not None:
        # Degrade the atlas with the configured model PSF.
        mpsf = ModelPSF(config.input.model_psf)
        atlas.intensity = mpsf.convolve(atlas.intensity, atlas.wl)
    else:
        # Otherwise use a differential convolution between atlas and data FWHM,
        # but only if an atlas FWHM is available.
        config.atlas.compute_fwhm(config.input.central_wl)
        if config.atlas.fwhm is not None:
            fwhm = FWHM(atlas=config.atlas.fwhm, data=config.input.fwhm)
            atlas.intensity = DifferentialSpectralPSF(fwhm).convolve(atlas.intensity, atlas.wl)
    # Add the stray-light offset. The sum must be assigned back, because a bare
    # expression is a no-op. `stray_light` is given in percent, as in the rest of
    # the script, which divides it by 100 before use.
    atlas.intensity = atlas.intensity + atlas.intensity.mean() * config.input.stray_light / 100

    fig, ax = plt.subplots(nrows=2, figsize=FIGSIZE, dpi=DPI)
    title = 'Click on the line cores to be used for calibration in ATLAS and DATA plot.\n'
    title += 'I.e., select same lines in both plots (the more the better), then close this window.'
    fig.suptitle(title)

    ax[0].set_title(f'ATLAS ({config.atlas.key.upper()})')
    ax[0].plot(atlas.wl, atlas.intensity)
    ax[0].set_xlabel(r'$\lambda$ [nm]')
    ax[0].set_xlim([min(atlas.wl), max(atlas.wl)])

    ax[1].set_title('DATA')
    ax[1].plot(spectrum)
    ax[1].set_xlabel(r'$\lambda$ [px]')
    ax[1].set_xlim([0, len(spectrum)])
    plt.tight_layout()

    # Clicks are forwarded to point_added / point_removed, which fill `points`.
    atlas_clicker = clicker(ax[0], ["atlas"], markers=["*"])
    atlas_clicker.on_point_added(point_added)
    atlas_clicker.on_point_removed(point_removed)
    data_clicker = clicker(ax[1], ["data"], markers=["o"])
    data_clicker.on_point_added(point_added)
    data_clicker.on_point_removed(point_removed)

    plt.show()
    # Data clicks are pixel positions: round to the nearest integer pixel.
    points['data'] = [int(np.round(p)) for p in points['data']]
    # Snap atlas clicks onto the atlas wavelength grid.
    points['atlas'] = [find_nearest(atlas.wl, p) for p in points['atlas']]
    # Sort both lists so lines are paired by their position along the axis,
    # regardless of click order (presumably what Comperator expects).
    points['data'].sort()
    points['atlas'].sort()


aparser = ArgumentParser(description='Fit a Solar Atlas to the data')
aparser.add_argument('path', type=str, metavar='CONFIG_PATH',
                     help='Path to the configuration file to use')
aparser.add_argument('linefile', nargs='?', metavar='LINEFILE_PATH', default=None,
                     help='Path to file with list of lines to use (optional)')
args = aparser.parse_args()

# An unconfigured root logger only emits WARNING and above, which would hide all
# progress messages of this script. basicConfig() does nothing if the root logger
# already has handlers (e.g. configured by an imported package).
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
log = logging.getLogger()
config = Config.from_yaml(args.path)
# Region of interest of the frame; 's' placeholders in the ROI string are replaced by '0'.
shape = parse_shape(config.input.roi.replace('s', '0'))

log.info('Loading spectra, generating selection plot ...')
spectrum = read_fits_data(config.input.corrected_frame)[shape]
if config.input.flipped:
    spectrum = np.flip(spectrum)
if args.linefile is None:
    # Interactive selection; fills the module-level `points` dict.
    select_lines_from_plot()
else:
    # Reuse a line file, e.g. one written by a previous run (atlas_fit_lines.yaml).
    with open(args.linefile, 'r') as f:
        points = yaml.safe_load(f)
    if not isinstance(points, dict) or not all(key in points for key in ('atlas', 'data')):
        log.critical('Line file %s must define both "atlas" and "data" lists. Aborting.', args.linefile)
        sys.exit(1)

log.info('SELECTED ATLAS POSITIONS:\n %s', points['atlas'])
log.info('SELECTED DATA POSITIONS:\n %s', points['data'])
# Atlas and data lines are matched pairwise, so both lists must have equal length.
if len(points['atlas']) != len(points['data']):
    log.critical('Selected %s DATA points, but %s ATLAS points.', len(points['data']), len(points['atlas']))
    log.error('Please start again and select the same lines in both plots. Aborting.')
    sys.exit(1)
# Without any line pair there is nothing to derive a wavelength solution from.
if not points['atlas']:
    log.critical('No lines were selected. Aborting.')
    sys.exit(1)

# --- Wavelength solution and data preparation -----------------------------------
# Derive the wavelength solution from the selected atlas/data line pairs. The
# Comperator gets a copy of the spectrum; the original `spectrum` is used again
# below in warp_profile().
log.info('Finding wavelength solution based on configured lines...')
c = Comperator(copy.copy(spectrum), config, points)
c.run()

# Integer pixel grid of the data, used as the sampling grid for warp_profile().
xes = np.arange(len(c.spectrum.data))
log.info('Fitting FWHM and stray-light components...')
# Atlas excerpt chosen by select_best_atlas_range() for this calibration.
orig_atlas = select_best_atlas_range(c, config)
log.info('Spectral window: %.2f nm to %.2f nm', min(orig_atlas.wl), max(orig_atlas.wl))
# Pixel range [min_x, max_x) of the data used for the parameter fit.
min_x, max_x = select_data_fit_region(c, orig_atlas)
log.info('Best fit window: %s px to %s px (%s of %s px)', min_x, max_x, max_x - min_x, len(spectrum))
correction = compute_continuum_fit_error(orig_atlas, config.fitting)

# Warp the raw spectrum with the fitted solution (presumably onto the calibrated
# grid), then filter, normalize to the atlas mean and straighten the continuum.
# Order matters: straighten() uses data.data after normalize().
data = Spectrum(c.warp_profile(spectrum, xes))
data.apply_lowpass_filter_correction(config.fitting)
data.normalize(orig_atlas.intensity.mean())
# The continuum correction is rescaled by the current data/atlas mean ratio.
data.straighten(config.fitting, correction * (data.data.mean() / orig_atlas.intensity.mean()))
new_spectrum = data.data[min_x:max_x]

continuum_correction = data.continuum_correction[min_x:max_x]
# Length ratio of the cropped data to the atlas excerpt; used below to zoom the
# atlas grid onto the data pixel grid.
scale_factor = len(new_spectrum) / len(orig_atlas.intensity)
log.info('Scale factor data to atlas = %.4f', scale_factor)

# Optional debug overlay: normalized data vs. raw and convolved atlas, plus the
# positions of the matched lines, all in cropped data pixel coordinates.
if config.debug:
    plt.plot(new_spectrum / np.mean(new_spectrum), label='data')
    # Fourier-resample the atlas to the length of the cropped data.
    res_atlas = resample(orig_atlas.intensity, len(new_spectrum))
    plt.plot(res_atlas / np.mean(res_atlas), '--', label='atlas')
    # Same atlas, convolved with the differential PSF and offset by the configured
    # stray light (given in percent), before resampling.
    fwhm = FWHM(atlas=config.atlas.fwhm, data=config.input.fwhm)
    res_atlas = DifferentialSpectralPSF(fwhm).convolve(orig_atlas.intensity, orig_atlas.wl)
    res_atlas += (res_atlas.mean() * config.input.stray_light / 100)
    res_atlas = resample(res_atlas, len(new_spectrum))
    plt.plot(res_atlas / np.mean(res_atlas), '-.', label='convolved atlas')

    # NOTE: this rebinds the module-level `xes` to the warped data pixel positions.
    xes = c.new_data_xes
    # Atlas index closest to the start of `orig_atlas`, scaled to data pixels;
    # shifts atlas line positions into the cropped frame.
    off = np.argmin(np.abs(c.atlas.wl - orig_atlas.wl.min())) * scale_factor
    for i in range(len(c.spectrum.lines)):
        # Skip pairs flagged as bad on either the data or the atlas side.
        if i in c.spectrum.bad_lines or i in c.asp.bad_lines:
            continue
        # Blue dashed: data line (warped, shifted by min_x); green dash-dot: atlas line.
        plt.axvline(x=xes[int(np.round(c.spectrum.lines[i]))] - min_x, linestyle='--', color='tab:blue')
        plt.axvline(x=c.asp.lines[i] * scale_factor - off, linestyle='-.', color='tab:green')
    # Centre of the cropped spectrum.
    plt.axvline(x=len(new_spectrum) // 2)
    plt.legend()
    plt.xlim([0, len(res_atlas)])
    plt.show()

# --- FWHM / stray-light fit -----------------------------------------------------
fwhm_fit = config.input.fwhm
log.info('setup done, starting with FWHM: %.3e nm and stray-light: %.2f %%', fwhm_fit, config.input.stray_light)
pf = ParameterFit(new_spectrum, orig_atlas, config, BORDER)
# The optimizer works with the stray light as a fraction, hence the /100.
# Result vector: res.x = [stray_light_fraction, fwhm_nm].
res = pf.minimize_delta(stray_light=config.input.stray_light / 100, fwhm=fwhm_fit)
success = f'[{"success" if res.success else "no fit"}]'
# Report the fitted stray light in percent, consistent with the start value
# logged above and the plot title.
log.info('Optimization done: FWHM: %.3e nm and stray-light: %.3f %% %s', res.x[1], res.x[0] * 100, success)
print(res)
if not res.success:
    log.warning("STRAY-LIGHT/FWHM FIT WAS NOT SUCCESSFUL")
log.info('Global error: %.4e', pf.global_error())

# --- Result plot ----------------------------------------------------------------
# Zoom the atlas wavelength grid and continuum onto the data pixel grid. Drop
# BORDER px at both ends so they line up with `pf.atlas` / `pf.data`, which are
# plotted against `wl` below.
wl = zoom(orig_atlas.wl, scale_factor, order=5)[BORDER:-BORDER]
cont = zoom(orig_atlas.continuum, scale_factor, order=5)[BORDER:-BORDER]

y_label = r'$W\; cm^{-1}\; ster^{-1}\; Å^{-1}$'
fig, ax = plt.subplots(nrows=3, sharex='col', figsize=FIGSIZE, dpi=DPI, gridspec_kw={'height_ratios': [3, 3, 1]})
fig.suptitle(f'{config.label}\nWavelength calibration result (deg: {c.deg})')
# res.x[0] is the fitted stray-light fraction, shown here in percent.
ax[0].set_title(f'Fitted parameters: FWHM: {res.x[1]:.2e} nm, stray-light: {res.x[0] * 100:.2f} % {success}')

# Panel 0: reference atlas, its continuum and the corrected data over wavelength.
ax[0].plot(wl, pf.atlas, label='ref. atlas')
ax[0].plot(wl, cont, '--', label='ref. continuum')
ax[0].plot(wl, continuum_correction[BORDER:-BORDER], '-.', label='continuum corr.')
ax[0].plot(wl, pf.data, label='corrected')
ax[0].tick_params(bottom=True, labelbottom=True)

ax[0].legend()
ax[0].set_ylabel(y_label)
ax[0].set_xlim([wl[0], wl[-1]])
# Secondary (top) axis in data pixels, offset by the cropped border.
ax2 = ax[0].twiny()
ax2.set_xlabel(r'$\lambda$ [px]')
ax2.set_xlim([BORDER, len(pf.data) + BORDER])

delta = pf.compute_error()

# Panel 1: per-pixel chi² error. Ignored regions are shaded, and in `error` they
# are replaced by the mean so they do not affect the statistics shown.
ax[1].set_title(rf'$\chi^2$ Error ({pf.global_error():.2e})')
ax2 = ax[1].twiny()
error = delta.copy()
for r in pf.ignored_regions:
    ax2.axvspan(r.start, r.stop, alpha=0.2, color='red')
    # Region bounds appear to be in border-inclusive pixels (cf. the axis limits).
    # Shift them to `error` indices and clamp at 0, so a region reaching into the
    # left border cannot become a negative index that wraps to the array end.
    error[max(r.start - BORDER, 0):max(r.stop - BORDER, 0)] = delta.mean()

ax2.set_xlim([BORDER, len(pf.data) + BORDER])
ax[1].plot(wl, delta, label='Error values')
# A trailing rolling window of N samples yields its first value at index N - 1.
# Trim `wl` by the same total amount, split around the window centre, so both
# arrays have equal length.
rolling_window = 150
half_window = rolling_window // 2
rolling_mean = pd.Series(error).rolling(window=rolling_window).mean().iloc[rolling_window - 1:].values
ax[1].plot(wl[half_window:len(wl) - (rolling_window - 1 - half_window)], rolling_mean,
           linewidth=0.85, label=f'Rolling Mean ({rolling_window}px window)')
ax[1].axhline(y=error.mean(), linestyle='--', color='gray', label=f'Mean {error.mean():.2e}')
ax[1].legend()
ax[1].set_xlim([wl[0], wl[-1]])
if not np.isnan(delta.mean()):
    ax[1].set_ylim([0, delta.mean() + 5 * error.std()])
ax[1].grid(axis='y', linestyle=':')

# Panel 2: fitted dispersion as pixel shift of the warped grid vs. the plain grid.
new_xes = c.new_data_xes
xes = np.arange(len(new_xes))
ax[2].set_title('Fitted dispersion function')
ax[2].set_ylabel('shift [px]')
ax[2].set_xlabel(r'$\lambda$ [nm]')
ax2 = ax[2].twiny()
ax2.plot(xes, c.new_data_xes - xes, label=f'{c.dispersion}')
ax2.set_xlim([BORDER, len(pf.data) + BORDER])
ax2.tick_params(bottom=False, labelbottom=False, labeltop=False)

for a in ax:
    a.grid(axis='x', which='major')
    a.grid(axis='x', which='minor', linestyle='--')
    a.minorticks_on()

plt.subplots_adjust(top=0.849, bottom=0.062, left=0.069, right=0.995, hspace=0.574, wspace=0.18)
plt.show()

# --- Output files ---------------------------------------------------------------
if config.fitting.save_dispersion_function:
    dispersion_file = 'dispersion.txt'
    log.info('Writing dispersion function to %s', dispersion_file)
    # Written as two rows: row 0 is the atlas wavelength grid zoomed to the data
    # pixel count, row 1 the pixel offset of the fitted dispersion function.
    dispersion = np.array([
        zoom(orig_atlas.wl, len(xes) / len(orig_atlas.wl), order=5),
        c.new_data_xes - xes,
    ])
    log.debug('Dispersion table shape: %s', dispersion.shape)
    np.savetxt(dispersion_file, dispersion, header='lambda [nm], offset [px]')

linefile = 'atlas_fit_lines.yaml'
log.info('Writing used lines to %s', linefile)
# Convert possible NumPy scalars (e.g. from find_nearest) into plain Python
# numbers. Otherwise yaml.dump emits python-object tags that yaml.safe_load,
# which reads this file back via LINEFILE_PATH, refuses to load.
plain_points = {key: np.asarray(values).tolist() for key, values in points.items()}
with open(linefile, 'w') as file:
    yaml.safe_dump(plain_points, file)
