#! python

import bigfile
import logging
import argparse
import os

# Command-line interface of the converter.
#
# RawDescriptionHelpFormatter is used so that the hand-formatted bullet lists
# in the description are shown verbatim by --help; the default formatter
# re-wraps the whole description into a single paragraph.
ap = argparse.ArgumentParser('bigfile-convert',
    formatter_class=argparse.RawDescriptionHelpFormatter,
    description="""
Converting between HDF5, bigfile, and FITS files.

Bigfile: (directories)
    - Composite Data Types are converted as sub-file,
    - Simple Columns are converted as BigBlock.
    - Shape are stored as ndarray.shape attribute.
      Multidimensional arrays are flattened and stored.
    - attrs are stored as attrs.

FITS: (.fits, .fz, .fits.gz)
    - Composite Data Types are stored as TableHDU
    - Simple Columns are stored as TableHDU with '_'
    - attrs are stored as strings.
    - u8 is converted to i8

HDF5: (.h5, .hdf5, .hdf)
    - everything, but u1 is not supported.

Vector columns in a DataSet is converted to 2d tables (nmemb >= 1)

include and exclude list are not implemented

Type of file is inferred from file name / type
""")

ap.add_argument("src", help="path of the file to read")
ap.add_argument("dest", help="path of the file to write")
# The format names are the keys of the reader/writer tables used by main();
# restricting the choices reports a bad value as a usage error.
ap.add_argument("--src-format", default=None,
    choices=["hdf", "fits", "bigfile"],
    help="format of src; inferred from the path when omitted")
ap.add_argument("--dest-format", default=None,
    choices=["hdf", "fits", "bigfile"],
    help="format of dest; inferred from the path when omitted")
ap.add_argument("--stop-on-errors", default=False, action='store_true',
    help="abort on the first failing block instead of collecting errors")
# Verification is on by default; --verify is accepted for symmetry.
ap.add_argument("--verify", action='store_true', default=True,
    help="read the output back and compare it with the source (default)")
ap.add_argument("--no-verify", dest='verify', action='store_false', default=True,
    help="skip the verification pass")
#ap.add_argument("--include", action="append")
#ap.add_argument("--exclude", action="append")

def ishdf5(path):
    """Return True if `path` looks like an HDF5 file.

    An existing path that is not a regular file is never HDF5; otherwise
    the decision is made on the file name suffix alone, so paths that do
    not exist yet are accepted.
    """
    exists_but_not_file = os.path.exists(path) and not os.path.isfile(path)
    if exists_but_not_file:
        return False
    return path.endswith((".hdf5", ".hdf", ".h5"))

def isfits(path):
    """Return True if `path` looks like a FITS file.

    An existing path that is not a regular file is never FITS; otherwise
    the decision is made on the file name suffix alone, so paths that do
    not exist yet are accepted.
    """
    exists_but_not_file = os.path.exists(path) and not os.path.isfile(path)
    if exists_but_not_file:
        return False
    return path.endswith((".fits", ".fz", ".fits.gz"))

def isbigfile(path):
    """Return True if `path` can be a bigfile, i.e. a directory.

    A path that does not exist yet is accepted; an existing path is
    accepted only when it is a directory.
    """
    return (not os.path.exists(path)) or os.path.isdir(path)

def _detect_format(path):
    """Guess the container format of `path`: 'hdf', 'fits', 'bigfile' or None."""
    if ishdf5(path):
        return "hdf"
    if isfits(path):
        return "fits"
    if isbigfile(path):
        return "bigfile"
    return None

def main(ns):
    """Run a conversion described by the parsed command line `ns`.

    `ns` provides `src`, `dest`, `src_format`, `dest_format`,
    `stop_on_errors` and `verify`.  A format left as None is inferred from
    the corresponding path; an explicitly given format is always honoured
    (for both source and destination).

    Unless `stop_on_errors` is set, per-block errors are collected, printed
    after the conversion, and the process exits with status -1 if there
    were any.

    Raises ValueError when a format can neither be inferred nor is known.
    """
    import sys

    if getattr(ns, 'src_format', None) is None:
        ns.src_format = _detect_format(ns.src)
    # Only guess the destination format when the user did not choose one.
    if getattr(ns, 'dest_format', None) is None:
        ns.dest_format = _detect_format(ns.dest)

    readers = {'hdf' : HDFReader,
               'fits': FITSReader,
               'bigfile' : BigFileReader,
              }
    writers = {'hdf' : HDFWriter,
               'fits': FITSWriter,
               'bigfile' : BigFileWriter,
              }

    for role, path, fmt in (('source', ns.src, ns.src_format),
                            ('destination', ns.dest, ns.dest_format)):
        if fmt not in readers:
            raise ValueError(
                "cannot determine a supported format for %s %r (got %r); "
                "expected one of %s" % (role, path, fmt, sorted(readers)))

    # A list collects error messages; None makes the first error propagate.
    logger = None if ns.stop_on_errors else []

    convert(ns.src, ns.dest, readers[ns.src_format], writers[ns.dest_format],
            readers[ns.dest_format] if ns.verify else None,
            errlogger=logger)

    if logger:
        print('\n'.join(logger))
        sys.exit(-1)

class ReadProxy(object):
    """Wrap an indexable object and present its data with a given dtype/shape.

    The wrapper only stores `obj`; the `dtype` and `shape` attributes used
    by indexing and `len()` must be assigned on the instance by the caller.
    """

    def __init__(self, obj):
        self.obj = obj

    def __getitem__(self, index):
        # Read from the wrapped object, then reinterpret and reshape.
        raw = self.obj[index]
        typed = raw.view(dtype=self.dtype)
        return typed.reshape(self.shape)

    def __len__(self):
        first_dim = self.shape[0]
        return first_dim

    def __repr__(self):
        return repr(self.obj)

class IOBase:
    """Shared behaviour of the readers and writers.

    Subclasses are expected to provide `self.fileobj` (a context manager)
    and, for traversal and verification, `getchildren(prefix)` and
    `getobj(path)`.
    """

    def __enter__(self):
        # Delegate to the wrapped file object but hand back the wrapper.
        self.fileobj.__enter__()
        return self

    def __exit__(self, *args):
        self.fileobj.__exit__(*args)

    def traverse(self, operation, prefix="", errlogger=None):
        """Call `operation(obj, path)` on every leaf at or below `prefix`.

        A path with no children is a leaf.  When `errlogger` is a list,
        exceptions raised by `operation` are appended to it as messages
        instead of propagating.
        """
        kids = self.getchildren(prefix)

        if len(kids) == 0:
            leaf = self.getobj(prefix)
            if errlogger is None:
                operation(leaf, prefix)
                return
            try:
                operation(leaf, prefix)
            except Exception as exc:
                errlogger.append('Error processing %s : %s' % (prefix, str(exc)))
            return

        for name in kids:
            self.traverse(operation, prefix + '/' + name, errlogger=errlogger)

    def verify(self, srcobj, path):
        """Assert that the object stored at `path` matches `srcobj`.

        Shapes are compared only for non-empty sources; dtypes are not
        compared because e.g. the endianness may change.  Structured data
        is compared field by field.
        """
        from numpy.testing import assert_allclose

        destobj = self.getobj(path)

        if len(srcobj) > 0:
            expected_shape = tuple(srcobj.shape)
            actual_shape = tuple(destobj.shape)
            assert_allclose(expected_shape, actual_shape)

        expected = srcobj[:].ravel()
        actual = destobj[:].ravel()

        if not expected.dtype.fields:
            assert_allclose(expected, actual)
            return

        for field in expected.dtype.names:
            assert_allclose(expected[field], actual[field])

class HDFReader(IOBase):
    """Reader that opens an HDF5 file read-only through h5py."""

    def __init__(self, path):
        import h5py
        self.fileobj = h5py.File(path, mode='r')

    def getobj(self, path):
        """Resolve a '/'-separated `path`, starting from the file root.

        Empty segments are skipped.  Objects exposing `keys` (data groups)
        or `dtype` (datasets) are indexed by the segment; anything else is
        left as is.
        """
        current = self.fileobj
        for seg in path.split('/'):
            if len(seg) == 0:
                continue
            indexable = hasattr(current, 'keys') or hasattr(current, 'dtype')
            if indexable:
                current = current[seg]
        return current

    def getchildren(self, prefix):
        """List the names below `prefix`; an empty list marks a leaf."""
        target = self.getobj(prefix)

        if hasattr(target, 'keys'):
            # this is a datagroup
            return list(target.keys())

        if not hasattr(target, 'dtype'):
            raise TypeError("Unknown data block type")

        # A dataset: its fields are the children, a plain dataset is a leaf.
        fields = target.dtype.names
        if fields is None:
            return []
        return list(fields)

class node(dict):
    """Tree node mapping child names to child nodes.

    Nodes built by `_blocks_to_tree` additionally carry two instance
    attributes: `path` (the '/'-prefixed path of the node) and `final`
    (True when the node corresponds to a listed block).
    """
    pass

def _blocks_to_tree(blocks):
    """Build a tree of `node` objects from '/'-separated block names.

    Empty path segments are ignored.  The root has path '/' and is never
    final: a name without any segment (e.g. '' or '/') cannot address a
    block and is skipped, so that it does not turn the root into a leaf.

    Every node visited on the way to a block is (re)marked non-final and
    only the last segment is marked final, so when one name is a prefix of
    another the outcome depends on the order of `blocks`.
    """
    tree = node()
    # Set on the root instance; assigning on the `node` class would leak
    # the defaults into every node.
    tree.final = False
    tree.path = '/'

    def _add(tree, b):
        segs = [seg for seg in b.split('/') if len(seg) > 0]
        if not segs:
            return
        obj = tree
        path = ""
        for seg in segs:
            obj = obj.setdefault(seg, node())
            path = path + '/' + seg
            obj.path = path
            obj.final = False
        obj.final = True

    for b in blocks:
        _add(tree, b)

    return tree

class BigFileReader(IOBase):
    """Reader for an existing bigfile, navigated through a tree of its blocks."""

    def __init__(self, path):
        self.fileobj = bigfile.File(path, create=False)
        self.tree = _blocks_to_tree(self.fileobj.blocks)

    def getnode(self, path):
        """Walk the block tree along `path`, ignoring empty segments."""
        current = self.tree
        for seg in path.split('/'):
            if len(seg) == 0:
                continue
            current = current[seg]
        return current

    def _getobj_from_node(self, node):
        """Open the bigfile object a tree node stands for.

        Final nodes are opened by their path.  Other nodes are opened with
        a trailing '/' and wrapped in a `bigfile.Dataset` when that works;
        otherwise the object opened with the trailing '/' is returned.
        """
        if node.final:
            return self.fileobj[node.path]

        group = self.fileobj[node.path + '/']
        try:
            return bigfile.Dataset(group)
        except Exception:
            # Not usable as a Dataset: fall back to the plain object.
            return group

    def getobj(self, path):
        """Return a `ReadProxy` for `path` carrying dtype, size, shape and attrs."""
        raw = self._getobj_from_node(self.getnode(path))

        proxy = ReadProxy(raw)
        proxy.dtype = raw.dtype
        proxy.size = raw.size

        has_attrs = hasattr(raw, 'attrs')

        if has_attrs and 'ndarray.shape' in raw.attrs:
            # terminal
            proxy.shape = raw.attrs['ndarray.shape']
        elif hasattr(raw, 'shape'):
            proxy.shape = raw.shape
        else:
            proxy.shape = (raw.size, )

        proxy.attrs = dict(raw.attrs) if has_attrs else None

        return proxy

    def getchildren(self, prefix):
        """List the names below `prefix`; datasets and columns are leaves."""
        node = self.getnode(prefix)
        opened = self._getobj_from_node(node)
        is_leaf = isinstance(opened, (bigfile.Dataset, bigfile.Column))
        if is_leaf:
            return []
        # this is a datagroup
        return list(node.keys())

class BigFileWriter(IOBase):
    """Writer that creates a new bigfile and stores one block per object."""

    def __init__(self, path):
        self.fileobj = bigfile.File(path, create=True)

    def create(self, srcobj, path):
        """Write `srcobj` as the block `path`.

        For element types without a sub-shape the data is flattened and the
        original shape is recorded in the 'ndarray.shape' attribute.  Element
        types with a one-dimensional sub-shape are written as vector columns;
        deeper sub-shapes raise ValueError.
        """
        dtype = srcobj.dtype
        member_ndim = len(dtype.shape)

        if member_ndim > 1:
            raise ValueError("vector types unsupported")

        isvector = member_ndim != 0

        if isvector:
            size, Nmemb = srcobj.shape
            assert Nmemb == dtype.shape[0]
        else:
            size = srcobj.size

        block = self.fileobj.create(path, dtype, size=size)
        flat = srcobj[:].ravel()
        block.write(0, flat)

        if isvector:
            return
        block.attrs['ndarray.shape'] = srcobj.shape

class HDFWriter(IOBase):
    """Writer that creates (or truncates) an HDF5 file through h5py."""

    def __init__(self, path):
        import h5py
        self.fileobj = h5py.File(path, mode='w')

    def create(self, srcobj, path):
        """Store `srcobj` as the dataset `path`, creating parent groups.

        Attributes of `srcobj` are copied one by one; an attribute that
        cannot be stored is reported on stdout and skipped.
        """
        parts = [seg for seg in path.split('/') if len(seg) > 0]
        name = parts[-1]

        # The parent group is the root when the path has a single segment.
        parent = '/'.join(parts[:-1]) or '/'
        group = self.fileobj.require_group(parent)
        dataset = group.create_dataset(name, data=srcobj[:])

        attrs = srcobj.attrs
        if attrs is None:
            return

        for key, value in attrs.items():
            try:
                dataset.attrs[key] = value
            except Exception as e:
                print("missed property: %s %s for %s" % (key, value, e))

def _purify_path(path):
    """Normalise `path`: drop empty segments and leading/trailing slashes."""
    return '/'.join(filter(None, path.split('/')))
def _basename(path):
    """Return the last non-empty segment of `path`.

    Raises IndexError when the path has no non-empty segment.
    """
    parts = list(filter(None, path.split('/')))
    return parts[-1]

class FITSReader(IOBase):
    """Reader for FITS files opened read-only through fitsio.

    Every named HDU is treated as a block whose extension name is its
    '/'-separated path.
    """

    def __init__(self, path):
        import fitsio
        self.fileobj = fitsio.FITS(path, mode='r')
        names = [self.fileobj[i].get_extname()
                 for i in range(len(self.fileobj))]
        # HDUs without an extension name cannot be addressed by path, and an
        # empty name would mark the root of the tree as a leaf, hiding every
        # other block from traversal.
        # NOTE(review): presumably the primary HDU of files written by
        # FITSWriter is such an unnamed HDU - confirm against fitsio.
        self.tree = _blocks_to_tree([name for name in names if name])

    def getnode(self, path):
        """Walk the block tree along `path`, ignoring empty segments."""
        obj = self.tree
        for seg in path.split('/'):
            if len(seg) == 0: continue
            obj = obj[seg]
        return obj

    def _getobj_from_node(self, node):
        """Return the HDU whose extension name is the node path without leading '/'."""
        return self.fileobj[node.path.lstrip('/')]

    def getobj(self, path):
        """Return a `ReadProxy` for the HDU at `path`.

        A table with the single column '_' is exposed with that column's
        dtype, otherwise with the record dtype.  The shape is taken from the
        NDARRAY_NDIM / NDARRAY_SHAPE_<i> header keys when present and
        defaults to one dimension of length `get_nrows()`.  The header is
        exposed as the `attrs` dict.
        """
        import numpy
        node = self.getnode(path)
        obj = self._getobj_from_node(node)

        obj1 = ReadProxy(obj)
        tbldtype = obj.get_rec_dtype()[0]
        if len(tbldtype.fields) == 1 and tbldtype.names[0] == '_':
            obj1.dtype = tbldtype['_']
        else:
            obj1.dtype = tbldtype

        # Read the header once and use it for both the shape and the attrs.
        header = obj.read_header()
        if 'NDARRAY_NDIM' in header:
            ndim = int(header['NDARRAY_NDIM'])
            shape = tuple(
                int(header['NDARRAY_SHAPE_%d' % i]) for i in range(ndim))
        else:
            shape = (obj.get_nrows(), )

        obj1.shape = shape
        obj1.size = numpy.prod(obj1.shape, dtype='intp')

        obj1.attrs = dict(header)

        return obj1

    def getchildren(self, prefix):
        """List the names below `prefix`; a node that is an HDU is a leaf."""
        node = self.getnode(prefix)

        if node.final:
            return [] # leaf
        # this is a datagroup
        return list(node.keys())

class FITSWriter(IOBase):
    """Writer that stores every object as a table HDU of a FITS file.

    The file is opened through fitsio in 'rw' mode with clobber=True.
    """

    def __init__(self, path):
        import fitsio
        self.fileobj = fitsio.FITS(path, mode='rw', clobber=True)

    @staticmethod
    def _make_compatible_dtype(dtype):
        """Return `dtype` with every 8-byte unsigned integer replaced by i8.

        The result has the same memory layout as `dtype`, so that data can
        be reinterpreted with `ndarray.view`: byte order, field offsets and
        the total item size (including padding) are preserved.  Sub-array
        and nested structured types are converted recursively.
        """
        import numpy
        convert = FITSWriter._make_compatible_dtype

        if dtype.subdtype is not None:
            # sub-array type such as ('u8', (3,)): convert the element type
            base, shape = dtype.subdtype
            return numpy.dtype((convert(base), shape))

        if dtype.fields is None:
            if dtype.kind == 'u' and dtype.itemsize == 8:
                # keep the byte order so that view() reinterprets in place
                return numpy.dtype(dtype.byteorder + 'i8')
            return dtype

        # Iterate over names rather than fields: the fields mapping also
        # contains title aliases of the same fields.
        names = list(dtype.names)
        return numpy.dtype({
            'names': names,
            'formats': [convert(dtype.fields[name][0]) for name in names],
            'offsets': [dtype.fields[name][1] for name in names],
            # without an explicit itemsize trailing padding would be lost
            'itemsize': dtype.itemsize,
        })

    def create(self, srcobj, path):
        """Write `srcobj` as a table HDU named after the normalised `path`.

        Non-structured data is stored as a table with the single column
        '_'.  The attributes of `srcobj` are written to the header as
        strings, followed by the NDARRAY_NDIM / NDARRAY_SHAPE_<i> keys that
        record the array shape.
        """
        obj = self.fileobj
        extname = _purify_path(path)

        data = srcobj[:]
        data = data.view(self._make_compatible_dtype(data.dtype))

        if data.dtype.fields is None:
            # workaround FITS not allowing non-structured data
            data = data.view(dtype=[('_', data.dtype)])

        obj.create_table_hdu(extname=extname, dtype=data.dtype)
        bb = obj[extname]
        bb.write(data.ravel())

        header = {}
        if srcobj.attrs is not None:
            header.update({k: str(v) for k, v in srcobj.attrs.items()})

        header['NDARRAY_NDIM'] = len(data.shape)

        for i, s in enumerate(data.shape):
            header['NDARRAY_SHAPE_%d' % i] = s

        bb.write_keys(header)

def convert(srcpath, outpath, readertype, writertype, verifiertype, errlogger):
    """Copy every leaf of `srcpath` into `outpath`, optionally verifying it.

    `readertype`, `writertype` and `verifiertype` are called with a path
    and used as context managers.  The source is traversed once with the
    writer's `create`; when `verifiertype` is truthy the output is reopened
    after the writer has been closed and the source is traversed again with
    the verifier's `verify`.  `errlogger` is forwarded to both traversals.
    """
    with readertype(srcpath) as source:
        with writertype(outpath) as sink:
            source.traverse(sink.create, errlogger=errlogger)

        if not verifiertype:
            return

        with verifiertype(outpath) as checker:
            source.traverse(checker.verify, errlogger=errlogger)

if __name__ == "__main__":
    # Script entry point: parse the command line and run the conversion.
    main(ap.parse_args())
