#! /usr/bin/env python3

import pandas as pd
from astropy.coordinates import SkyCoord
import astropy.units as u
from collections import namedtuple
import numpy as np
from astropy.coordinates import SkyCoord
import astropy.units as u

import argparse

# Command line: the robot file, then the names of the two hexabundles to
# move between (the first one may also be the literal string 'centre',
# which is handled further down the script).
parser = argparse.ArgumentParser()
parser.add_argument("filename", help="A Hector Robot file")
parser.add_argument("hexabundle_1", type=str)
parser.add_argument("hexabundle_2", type=str)

args = parser.parse_args()
fname, hexabundle_1, hexabundle_2 = (
    args.filename,
    args.hexabundle_1,
    args.hexabundle_2,
)

# Skip the first 11 lines of the file before parsing it as CSV, then keep
# only the rows whose "type" value is non-negative.
# NOTE(review): presumably the skipped lines are a header and negative types
# mark non-bundle rows -- confirm against the robot file format.
df = pd.read_csv(fname, skiprows=11).loc[lambda frame: frame["type"] >= 0]

# Sanity check on the number of rows that survived the filter.
# NOTE(review): the limit enforced here is 27 rows, yet the message talks
# about 6 bundles -- confirm which number is the intended one.
if len(df) > 27:
    raise ValueError(f"We seem to have {len(df)} bundles but Hector only has 6!")



def plate2sky(x, y, linear=False):
    """Convert position on plate to position on sky, relative to plate centre.

    x and y are input as positions on the plate in microns, with (0, 0) at
    the centre. Sign conventions are defined as in the CSV allocation files.
    Scalars and array-likes are both accepted.

    Return a named tuple (xi, eta) with the angular coordinates in arcseconds,
    relative to plate centre with the same sign convention. If linear is set
    to True then a simple linear scaling is used, otherwise pincushion
    distortion model is applied.

    A single (0, 0) input returns plain float zeros. Points lying on either
    axis (x == 0 or y == 0) are handled like any other point: the distortion
    model is solved radially, so no division by an individual coordinate is
    ever performed.
    """

    # Define the return named tuple type
    AngularCoords = namedtuple('AngularCoords', ['xi', 'eta'])

    # Make sure we're dealing with floats
    x = np.array(x, dtype='d')
    y = np.array(y, dtype='d')

    if np.size(x) == 1 and np.size(y) == 1 and x == 0.0 and y == 0.0:
        # Plate centre: nothing to compute
        return AngularCoords(0.0, 0.0)

    if linear:
        # Just do a really simple transformation
        plate_scale = 15.22 / 1000.0   # arcseconds per micron
        return AngularCoords(x * plate_scale, y * plate_scale)

    # Include pincushion distortion, found by inverting:
    #    x = xi * f * P(xi, eta)
    #    y = eta * f * P(xi, eta)
    # where:
    #    P(xi, eta) = 1 + p * (xi**2 + eta**2)
    #    p = 191.0
    #    f = 13.515e6 microns, the telescope focal length
    #
    # Both equations share the factor f * P, so (xi, eta) is parallel to
    # (x, y) and only the radial relation has to be inverted:
    #    r = f * rho * (1 + p * rho**2)
    # with r = hypot(x, y) and rho = hypot(xi, eta). Written as the
    # depressed cubic rho**3 + rho / p - r / (f * p) = 0 it has exactly one
    # real root (the linear coefficient is positive), given by Cardano.
    p = 191.0
    f = 13.515e6
    r = np.hypot(x, y)
    half_q = r / (2.0 * f * p)
    disc = np.sqrt(half_q**2 + 1.0 / (27.0 * p**3))
    # The second term is the cube root of a negative number, hence np.cbrt
    # rather than ** (1/3), which would give NaN.
    rho = np.cbrt(half_q + disc) + np.cbrt(half_q - disc)

    # Radians of sky per micron of plate at each radius, in arcseconds.
    # At r == 0 the ratio is 0/0; those elements map to exactly zero.
    radians_to_arcsec = (180.0 / np.pi) * 3600.0
    with np.errstate(divide='ignore', invalid='ignore'):
        scale = np.where(r > 0.0, rho / r, 0.0) * radians_to_arcsec

    return AngularCoords(x * scale, y * scale)

def offset(base, target):
    """Return the offsets from base to target as a (dra, ddec) pair.

    This is a thin wrapper around ``base.spherical_offsets_to(target)``;
    whatever two values that method yields are returned unchanged.
    NOTE(review): given the file's imports, base and target are presumably
    astropy SkyCoord objects -- confirm with the intended callers.
    """

    dra, ddec = base.spherical_offsets_to(target)
    # The original computed the pair and then fell off the end of the
    # function, so callers always received None.
    return dra, ddec

# Sky position of the starting point. 'centre' means the plate centre,
# which is the origin of the plate2sky coordinates.
if hexabundle_1 == 'centre':
    delta_x1 = 0
    delta_y1 = 0
else:
    bundle_1 = df.loc[df.Hexabundle == f"{hexabundle_1}"]
    # A boolean-mask selection never raises KeyError for a missing name; it
    # just returns no rows, so the absence has to be checked explicitly.
    if bundle_1.empty:
        raise ValueError(f"The first hexabundle is {hexabundle_1} but it doesn't appear in the file!")

    # plate2sky expects microns; the factor of 1000 presumably converts the
    # file's x/y columns from millimetres -- TODO confirm the file's units.
    delta_x1, delta_y1 = plate2sky(bundle_1['x'] * 1000, bundle_1['y'] * 1000, linear=True)

# Sky position of the destination hexabundle.
bundle_2 = df.loc[df.Hexabundle == f"{hexabundle_2}"]
if bundle_2.empty:
    raise ValueError(f"The second hexabundle is {hexabundle_2} but it doesn't appear in the file!")

delta_x2, delta_y2 = plate2sky(bundle_2['x'] * 1000, bundle_2['y'] * 1000, linear=True)

# plate2sky returns plain floats for a single point at exactly (0, 0), so
# force arrays here: the code that reports the result indexes element [0].
offset_x = np.atleast_1d(delta_x2 - delta_x1)
offset_y = np.atleast_1d(delta_y2 - delta_y1)


def print_offsets(offset_x, offset_y):
    """Print a one-line description of an offset with compass directions.

    offset_y selects the N/S label and offset_x the E/W label: values less
    than or equal to zero are reported as 'S' and 'E' respectively, anything
    else as 'N' and 'W'. The magnitudes are printed to one decimal place,
    labelled as arcseconds. Nothing is returned.
    """
    NS_direction = 'S' if offset_y <= 0 else "N"
    EW_direction = 'E' if offset_x <= 0 else 'W'

    print(f"\t--> offset {np.abs(offset_y):.1f} arcseconds {NS_direction} and {np.abs(offset_x):.1f} arcseconds {EW_direction}")

# Announce which move is being described, then print the offset of the
# first matching row via print_offsets.
if hexabundle_1 != 'centre':
    print(f"\nMoving from Hexabundle {hexabundle_1} to Hexabundle {hexabundle_2}:")
else:
    print(f"\nMoving from the centre of the plate to Hexabundle {hexabundle_2}:")

print_offsets(offset_x[0], offset_y[0])