#!/usr/bin/python
'''`cat` for avro files'''

__author__ = 'Miki Tebeka <miki.tebeka@gmail.com>'
__version__ = '0.3.1'

from avro.io import DatumReader
from avro.datafile import DataFileReader

import json
import csv
from sys import stdout
from itertools import ifilter, imap
from functools import partial
from operator import itemgetter

def print_json(row, fields, getter):
    '''Print the selected fields of `row` as a single-line JSON object.

    row    -- mapping holding the record's values.
    fields -- list of field names, in the order `getter` yields their values.
    getter -- callable extracting the values of `fields` from `row`; expected
              to behave like `operator.itemgetter(*fields)`.
    '''
    values = getter(row)
    if len(fields) == 1:
        # itemgetter() with a single key returns the bare value rather than a
        # 1-tuple; wrap it so zip() pairs the value with its field name
        # instead of iterating over the value itself.
        values = (values,)
    print(json.dumps(dict(zip(fields, values))))

def print_json_pretty(row, fields, getter):
    '''Print the selected fields of `row` as JSON indented by 4 spaces.

    row    -- mapping holding the record's values.
    fields -- list of field names, in the order `getter` yields their values.
    getter -- callable extracting the values of `fields` from `row`; expected
              to behave like `operator.itemgetter(*fields)`.
    '''
    values = getter(row)
    if len(fields) == 1:
        # itemgetter() with a single key returns the bare value rather than a
        # 1-tuple; wrap it so zip() pairs the value with its field name
        # instead of iterating over the value itself.
        values = (values,)
    print(json.dumps(dict(zip(fields, values)), indent=4))

# Shared CSV row writer bound to standard output; used for both the CSV
# header and the data rows.
_write_row = csv.writer(stdout).writerow
# Encoding used by _encode(): the one reported by standard output, or UTF-8
# when stdout.encoding is unset/empty.
_encoding = stdout.encoding or 'UTF-8'

def _encode(v):
    '''Return `v` encoded with the module-level `_encoding` if it is a string.

    Any value that is not a `basestring` instance is returned unchanged.
    '''
    is_string = isinstance(v, basestring)
    return v.encode(_encoding) if is_string else v

def print_csv(row, fields, getter):
    '''Write the selected fields of `row` to stdout as one CSV row.

    row    -- mapping holding the record's values.
    fields -- list of field names; columns are written in this order.
    getter -- callable extracting the values of `fields` from `row`; expected
              to behave like `operator.itemgetter(*fields)`.
    '''
    values = getter(row)
    if len(fields) == 1:
        # itemgetter() with a single key returns the bare value rather than a
        # 1-tuple; wrap it so the writer emits one column instead of treating
        # the value itself as the sequence of columns.
        values = (values,)
    # NOTE(review): values are written as-is; `_encode` is not applied in the
    # visible code - confirm whether non-ASCII text needs encoding first.
    _write_row(values)

def select_printer(format):
    '''Return the function that prints one record in the given format.

    `format` is one of 'json', 'json-pretty' or 'csv'; any other value
    raises KeyError.
    '''
    printers = {
        'csv' : print_csv,
        'json' : print_json,
        'json-pretty' : print_json_pretty,
    }
    return printers[format]

def record_match(expr, record):
    '''Evaluate the filter expression `expr` against a single record.

    `expr` is Python source handed to eval(); the record is visible to it
    under the name `r`.  The value of the expression is returned unchanged.

    NOTE: eval() executes arbitrary code, so `expr` must only come from a
    trusted source (in this script, the user's own --filter option).
    '''
    namespace = {'r' : record}
    return eval(expr, None, namespace)

def parse_fields(fields):
    '''Split a comma separated string of field names into a list.

    Whitespace around each name is stripped and empty entries are dropped,
    so ' a, b ,,c ' yields ['a', 'b', 'c'] and '' yields [].

    A real list is always returned (never a lazy iterator), because the
    result is traversed more than once by the callers.
    '''
    stripped = (name.strip() for name in fields.split(','))
    return [name for name in stripped if name]

def fields_selector(fields, getter):
    '''Return a function that reduces a record to the selected fields.

    fields -- list of field names to keep.
    getter -- callable extracting the values of `fields` from a record;
              expected to behave like `operator.itemgetter(*fields)`.

    The returned function takes a record and returns a new dict mapping
    each name in `fields` to its value.
    '''
    # itemgetter() with a single key returns the bare value rather than a
    # 1-tuple, so that case must be wrapped before zipping with the names.
    single = len(fields) == 1

    def keys_filter(obj):
        values = getter(obj)
        if single:
            values = (values,)
        return dict(zip(fields, values))
    return keys_filter

def print_avro(avro, args):
    '''Print the records of `avro` according to the parsed options `args`.

    Steps, in order: reject --header for non-CSV formats (ValueError), keep
    only records matching `args.filter` (if given), drop the first
    `args.skip` records, reduce each record to `args.parsed_fields`, then
    print records with the printer for `args.format` until `args.count`
    records have been printed.

    The CSV header is written when the first remaining record is reached,
    before the count limit is checked; nothing is written if no record
    remains after filtering and skipping.
    '''
    wants_header = args.header
    if wants_header and args.format != 'csv':
        raise ValueError('--header applies only to CSV format')

    # Filtering happens before skipping, so --skip counts matching records.
    records = avro
    if args.filter:
        matches = partial(record_match, args.filter)
        records = ifilter(matches, records)

    for _ in xrange(args.skip):
        try:
            next(records)
        except StopIteration:
            # Fewer records than requested to skip: nothing left to print.
            return

    fields = args.parsed_fields
    getter = itemgetter(*fields)
    selected = imap(fields_selector(fields, getter), records)
    printer = select_printer(args.format)

    for index, record in enumerate(selected):
        if wants_header and index == 0:
            _write_row(fields)
        if index >= args.count:
            break
        printer(record, fields, getter)

def get_schema(avro):
    '''Return the schema of `avro`, parsed from JSON.

    The JSON text is read from the 'avro.schema' entry of `avro.meta`.
    '''
    raw_schema = avro.meta['avro.schema']
    return json.loads(raw_schema)

def print_schema(avro):
    '''Print the schema of `avro` to stdout as JSON indented by 4 spaces.'''
    schema = get_schema(avro)
    # Pretty print.  print() is called with a single parenthesised argument,
    # like the JSON printers in this file, so the line produces the same
    # output on Python 2 and is also valid syntax on Python 3.
    print(json.dumps(schema, indent=4))

def schema_fields(avro):
    '''Return the schema's field names as one comma separated string.

    Names are taken from the 'name' key of every entry in the schema's
    'fields' list, in schema order.
    '''
    names = []
    for field in get_schema(avro)['fields']:
        names.append(field['name'])
    return ','.join(names)

def main(argv=None):
    '''Command line entry point: print Avro file(s) to standard output.

    argv -- full argument vector including the program name; `sys.argv` is
            used when it is omitted or empty.

    Raises SystemExit when no file is given or a file cannot be opened
    (argparse also exits on invalid options).  With --schema, the schema of
    the first file is printed and the function returns.
    '''
    import sys
    from argparse import ArgumentParser

    argv = argv or sys.argv

    parser = ArgumentParser(description='`cat` for Avro files')
    parser.add_argument('-v', '--version', action='version',
                        version='%(prog)s {0}'.format(__version__))
    parser.add_argument('files', help='avro file(s)', nargs='*',
                        metavar='FILE')
    parser.add_argument('-n', '--count', default=float('Infinity'),
                        help='number of records to print', type=int)
    parser.add_argument('-s', '--skip', help='number of records to skip',
                        type=int, default=0)
    parser.add_argument('-f', '--format', help='record format',
                        default='json',
                        choices=['json', 'csv', 'json-pretty'])
    parser.add_argument('--header', help='print CSV header', default=False,
                        action='store_true')
    parser.add_argument('--filter', help='filter records (e.g. r["age"]>1)',
                        default=None)
    parser.add_argument('--schema', help='print schema', action='store_true',
                        default=False)
    parser.add_argument('--fields', default='',
                        help='Fields to show (comma separated)')

    args = parser.parse_args(argv[1:])

    if not args.files:
        raise SystemExit('error: missing filename(s)')

    for filename in args.files:
        try:
            fo = open(filename, 'rb')
        except (OSError, IOError) as e:
            raise SystemExit('error: cannot open %s - %s' % (filename, e))

        # Close the file on every exit path: normal completion, the early
        # return for --schema, and any error raised while reading.
        try:
            avro = DataFileReader(fo, DatumReader())

            if args.schema:
                print_schema(avro)
                return

            # Default to every field of the file's schema when --fields is
            # not given.
            fields = args.fields or schema_fields(avro)
            args.parsed_fields = parse_fields(fields)

            print_avro(avro, args)
        finally:
            fo.close()

# Run the command line tool only when executed as a script, not on import.
if __name__ == '__main__':
    main()

