#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
# SPECTROFLAT CALIBRATION DATA AMENDMENT

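This is a CLI script to amend the calibration data produced by the `spectroflat` library.
It adds the dispersion correction (and chi^2 error) to the `OffsetMap` and amends the soft flat
with the additional offset correction and a continuum-correction HDU.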

## Configuration
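The configuration is given as a YAML file. See the examples folder or `atlas_fit/config.py`
for documentation of the parameters. The selected lines are given in a second YAML file
(`LINEFILE_PATH`) containing the `atlas` and `data` position lists.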

## Execution
If the script is executable (preferred) you can run it directly as `./bin/amend_spectroflat --help`.
The script should automatically be marked as executable after git checkout.
If this is not the case and/or you do not want this, you can execute it with
your favorite python interpreter as `python3 bin/amend_spectroflat --help`.

See `--help` switch for further usage details.

## Author(s)
hoelken@mps.mpg.de
"""
import copy
import logging
import os
import sys
from argparse import ArgumentParser

import numpy as np
import yaml
from astropy.io import fits
from scipy.interpolate import pchip_interpolate
from scipy.ndimage import gaussian_filter
from spectroflat.smile import OffsetMap, SmileInterpolator
from spectroflat.utils.processing import MP

sys.path.insert(0, os.path.join(sys.path[0], '..'))
from src.atlas_fit import Config, Comperator
from src.atlas_fit.models import Spectrum
from src.atlas_fit.utils import parse_shape, read_fits_data, read_hdu
from src.atlas_fit.utils.tailoring import select_best_atlas_range
from src.atlas_fit.utils.corrections import compute_continuum_fit_error

# --- Command line, configuration and input data -----------------------------
aparser = ArgumentParser(description='Amend calibration data with atlas correction')
aparser.add_argument('path', type=str, metavar='CONFIG_PATH',
                     help='Path to the configuration file to use')
aparser.add_argument('linefile', type=str, metavar='LINEFILE_PATH',
                     help='Path to file with list of lines to use')
aparser.add_argument("-o", "--out", type=str, nargs="?", default=".", help='Output folder. default: current folder')

args = aparser.parse_args()
log = logging.getLogger()

config = Config.from_yaml(args.path)
# Line file: YAML with (at least) the 'atlas' and 'data' position lists read below.
# The file is user-provided, so only yaml.safe_load is acceptable here.
with open(args.linefile, 'r', encoding='utf-8') as f:
    points = yaml.safe_load(f)

log.info('Active config:\n%s', repr(config))
log.info('SELECTED ATLAS POSITIONS:\n %s', points['atlas'])
log.info('SELECTED DATA POSITIONS:\n %s', points['data'])

log.info('Amending calibration data according to selected atlas & lines.')
offsets = OffsetMap.from_file(config.input.offset_map)
img = read_fits_data(config.input.corrected_frame)
if config.input.flipped:
    # The rest of the script works in a frame mirrored along axis 2 (the last
    # axis); results are flipped back right before they are written.
    offsets.map = np.flip(offsets.map, axis=2)
    img = np.flip(img, axis=2)
    roi = parse_shape(config.input.roi)
    # Mirror the column slice of the ROI into the flipped frame. slice.indices()
    # resolves open (None) and negative bounds against the frame width; plain
    # arithmetic on .start/.stop raised TypeError for an open bound (e.g. `200:`)
    # and produced out-of-range columns for negative bounds.
    width = img.shape[2]
    col_start, col_stop, _ = roi[2].indices(width)
    config.input.roi = f'[s, {roi[1]}, {width - col_stop}:{width - col_start}]'
    print(config.input.roi)

# --- Per-state dispersion fit -----------------------------------------------
# For every modulation state the ROI spectrum is compared against the atlas
# (Comperator), the continuum correction is derived, and the fitted column
# offsets are added to every row of that state's offset map.
# NOTE: `c`, `shape`, `xes` and `orig_atlas` deliberately outlive the loop;
# later parts of this script read the values of the LAST state.
log.info('Amending OffsetMap %s', config.input.offset_map)
log.info('\t- Fitting dispersion...')
delta_offsets, orig_means, correction_state = [], [], []
for s in range(config.input.mod_states):
    shape = parse_shape(config.input.roi.replace('s', str(s)))
    print(f'\t  State {s} of {config.input.mod_states - 1}', shape)
    roi_cut = img[shape]
    c = Comperator(copy.copy(roi_cut), config, points).run()
    orig_means.append(roi_cut.mean())

    orig_atlas = select_best_atlas_range(c, config)
    continuum_error = compute_continuum_fit_error(orig_atlas, config.fitting)
    atlas_mean = orig_atlas.intensity.mean()

    # Resample the ROI onto the fitted pixel grid, then normalise/straighten it.
    xes = np.arange(len(c.spectrum.data))
    data = Spectrum(pchip_interpolate(c.new_data_xes, roi_cut, xes))
    data.apply_lowpass_filter_correction(config.fitting)
    data.normalize(atlas_mean)
    data.straighten(config.fitting, continuum_error * (data.data.mean() / atlas_mean))
    correction_state.append(data.continuum_correction / atlas_mean**2 * orig_means[s])

    # Column shift between fitted and nominal pixel positions, applied to all rows.
    col_offsets = c.new_data_xes - xes
    delta_offsets.append(col_offsets)
    state_map = offsets.map[s]
    state_map[:, shape[2]] = state_map[:, shape[2]] + col_offsets
print('')

# --- Combine or keep the per-state offsets ----------------------------------
# delta_offsets always ends up as an array with one row per modulation state.
if not config.fitting.squash_offsets:
    delta_offsets = np.array(delta_offsets)
    log.warning('Providing MODULATED OFFSETS: Check demodulated data for artifacts!')
else:
    log.info('Averaging modulated offsets to a global state (squash).')
    offsets.squash()
    # Every state receives the same mean offset, so per-state indexing still works.
    mean_offsets = np.array(delta_offsets).mean(axis=0, keepdims=True)
    delta_offsets = np.repeat(mean_offsets, config.input.mod_states, axis=0)

# --- Wavelength calibration header and offset-map output --------------------
log.info('\t- Provide WL calibration information...')
# `c`, `xes` and `shape` still hold the values of the last modulation state
# processed by the fitting loop.
wl = c.atlas.wl_func(c.dispersion(xes))
# slice.indices() resolves open (None) and negative ROI bounds against the
# frame width; reading .start/.stop directly mishandled negative bounds.
start, stop, _ = shape[2].indices(img.shape[2])
offsets.header['HIERARCH MIN_WL_NM'] = wl[0]
offsets.header['HIERARCH MIN_WL_PX'] = stop if config.input.flipped else start
offsets.header['HIERARCH MAX_WL_NM'] = wl[-1]
offsets.header['HIERARCH MAX_WL_PX'] = start if config.input.flipped else stop
# NOTE(review): `stop` is an exclusive slice bound, and the wavelength span is
# divided by the pixel count rather than the number of intervals (count - 1);
# confirm this is the intended convention for MAX_WL_PX / DISPERSION.
offsets.header['HIERARCH DISPERSION'] = (wl[-1] - wl[0]) / (stop - start)
log.info(repr(offsets.header))

out_file = os.path.join(args.out, 'wl_calibrated_offsets.fits')
if config.input.flipped:
    # Undo the working-frame flip. The spectral axis is 1 for a squashed map
    # (presumably no state axis any more) and 2 otherwise.
    axis = 1 if offsets.is_squashed() else 2
    offsets.map = np.flip(offsets.map, axis=axis)
offsets.dump(out_file=out_file)

# --- Continuum correction image ---------------------------------------------
# Log the file that is actually amended here (previously logged the offset map path).
log.info('Amending Soft Flat %s', config.input.soft_flat)
log.info('\t- Compute correction image...')
if config.input.correction_roi is None:
    # Default: use the whole frame.
    roi = (slice(None, None), slice(None, None))
else:
    roi = parse_shape(config.input.correction_roi)
    if config.input.flipped:
        # Mirror the column slice into the flipped working frame. slice.indices()
        # handles open (None) and negative bounds, which previously raised
        # TypeError or produced out-of-range columns.
        width = img.shape[2]
        col_start, col_stop, _ = roi[1].indices(width)
        roi = (roi[0], slice(width - col_stop, width - col_start))

vmeans = []
for s in range(config.input.mod_states):
    # Per-row mean over the ROI columns, normalised to unit mean and smoothed.
    vmean = img[s][roi].mean(axis=1)
    vmean /= vmean.mean()
    vmeans.append(gaussian_filter(vmean, sigma=7))

correction_state = np.array(correction_state)
vmeans = np.array(vmeans)
# Correction image per state: outer product of the row profile and the state's
# continuum correction, i.e. cimg[s][r, ...] == vmeans[s][r] * correction_state[s].
cimg = [np.multiply.outer(vmeans[s], correction_state[s]) for s in range(config.input.mod_states)]

# --- Amend the soft flat and attach the continuum correction ----------------
with read_hdu(config.input.soft_flat) as soft_flat:
    if config.input.flipped:
        # Work in the same mirrored frame as the offsets and the correction ROI.
        soft_flat.data = np.flip(soft_flat.data, axis=2)
    states, n_rows, cols = soft_flat.data.shape
    rows = np.arange(n_rows)
    # Resolve open (None) and negative column bounds of the correction ROI
    # against the frame width, so `xes` has exactly as many entries as each
    # sliced row (a negative stop previously yielded an empty `xes`).
    start, stop, _ = roi[1].indices(cols)
    xes = np.arange(stop - start)
    log.info('\t- Applying additional offset correction...')
    for s in range(states):
        arguments = [(r, xes, delta_offsets[s], soft_flat.data[s, r, start:stop]) for r in rows]
        # dict() keys the results by row index; they are reassembled in row order.
        res = dict(MP.simultaneous(SmileInterpolator.desmile_row, arguments))
        soft_flat.data[s, :, start:stop] = np.array([res[row] for row in rows])
    if config.input.flipped:
        soft_flat.data = np.flip(soft_flat.data, axis=2)
    out_file = os.path.join(args.out, 'amended_soft_flat.fits')

    hdul = fits.HDUList()
    hdul.append(soft_flat)

    log.info('\t- adding continuum correction HDU...')
    cont_data = np.ones(soft_flat.data.shape)
    for s in range(states):
        # `roi` is never None at this point (it defaults to the full frame), so
        # each state's correction image is written into its ROI.
        cont_data[s][roi] *= np.asarray(cimg[s])
    if config.input.flipped:
        cont_data = np.flip(cont_data, axis=2)
    cont = fits.ImageHDU(cont_data)
    # `orig_atlas` is the atlas of the last state handled by the fitting loop.
    aim = orig_atlas.intensity.mean()
    cont.header.append(('HIERARCH MEAN_ATLAS_CONTINUUM', f"{orig_atlas.continuum.mean():.5e}", "W cm-1 ster-1 A-1"))
    cont.header.append(('HIERARCH MEAN_ATLAS_INTENSITY', f"{aim:.5e}", "W cm-1 ster-1 A-1"))
    for s in range(states):
        cont.header.append((f'HIERARCH MEAN_INTENSITY_CONVERSION_{s}', f"{aim / orig_means[s]}"))
    print(repr(cont.header))
    hdul.append(cont)

    log.info('\t- Write to "%s"...', out_file)
    hdul.writeto(out_file, overwrite=True)
