#!python
"""
This script detects a single moving red laser spot for an alternative 
calibration procedure under development by Pranav Khandelwal and
Dylan Ray. 
"""

import argparse
import logging
import cv2
import numpy as np
import pandas as pd


def build_parser():
    """Build the command-line parser for the laser spot detector.

    Numeric options are converted by argparse itself (``type=``), so a
    malformed value is rejected with a usage message before the movie is
    opened rather than with a traceback part-way through processing.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(
        description="get red laser spot position from movie",
        epilog=__doc__)
    parser.add_argument("ifile",
                        help="filename of movie to process")
    parser.add_argument("--ofile", default="dot.csv",
                        help="output filename for csv")
    parser.add_argument("--mask", nargs=4, type=int,
                        help="take pixels only inside [xmin xmax ymin ymax]")
    parser.add_argument("--frames", nargs=2, type=int,
                        help="take frames only inside [start finish]")
    parser.add_argument("--threshold", default=0.35, type=float,
                        help="threshold, default 0.35")
    parser.add_argument("--alpha", default=0.05, type=float,
                        help="alpha for background subtractor, default 0.05")
    parser.add_argument("--display", action="store_true",
                        help="display while processing")
    parser.add_argument("--verbose", action="store_true",
                        help="toggle verbosity")
    return parser


def build_mask(height, width, bounds):
    """Return a uint8 mask that is 255 inside ``bounds`` and 0 elsewhere.

    Args:
        height: number of rows in the mask.
        width: number of columns in the mask.
        bounds: ``[xmin, xmax, ymin, ymax]`` in pixels (upper limits are
            exclusive, as in slicing), or None for "no mask".

    Returns:
        A ``(height, width)`` uint8 array, or None when ``bounds`` is None.
    """
    if bounds is None:
        return None
    xmin, xmax, ymin, ymax = (int(v) for v in bounds)
    mask = np.zeros((height, width), dtype=np.uint8)
    mask[ymin:ymax, xmin:xmax] = 255
    return mask


def locate_spot(thresh, max_spread=100., min_area=1., max_area=144.):
    """Find the centroid of the thresholded pixels if they look like one dot.

    The detection is accepted only when the second-order central moments
    are all below ``max_spread`` (the dot is concentrated enough) and the
    zeroth moment lies strictly between ``min_area`` and ``max_area`` (the
    dot is within size limits).  You may want to adjust these limits.

    Args:
        thresh: single-channel uint8 image holding 1 where a pixel exceeded
            the threshold and 0 elsewhere.
        max_spread: upper limit for mu20, mu02 and mu11.
        min_area: exclusive lower limit for m00.
        max_area: exclusive upper limit for m00.

    Returns:
        ``(x, y)`` centroid as floats, or None when nothing acceptable
        was found.
    """
    m = cv2.moments(thresh)
    concentrated = (m['mu20'] < max_spread and
                    m['mu02'] < max_spread and
                    m['mu11'] < max_spread)
    if concentrated and min_area < m['m00'] < max_area:
        return m['m10'] / m['m00'], m['m01'] / m['m00']
    return None


def main(argv=None):
    """Detect the laser spot in every frame of a movie and write a csv.

    The output has one row per frame with columns ``frame``, ``pt1_X`` and
    ``pt1_Y``; frames that are unreadable, outside ``--frames``, or have no
    acceptable detection get NaN coordinates.

    Args:
        argv: argument list to parse; None means ``sys.argv[1:]``.
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        logging.debug("running in verbose mode")
    else:
        logging.basicConfig(level=logging.INFO)

    # open the movie
    logging.info("opening {0}".format(args.ifile))
    movie = cv2.VideoCapture(args.ifile)
    if not movie.isOpened():
        # Without this check a missing/unreadable file reports 0 frames and
        # the script would silently write an empty csv.
        parser.error("could not open movie {0}".format(args.ifile))

    try:
        n_frames = int(movie.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(movie.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(movie.get(cv2.CAP_PROP_FRAME_HEIGHT))
        logging.debug("got {0} frames size ({1},{2})".format(
            n_frames, height, width))

        # running-average background, updated with every processed frame
        background = np.zeros((height, width), dtype=np.float32)

        # initialize display window
        if args.display:
            logging.debug("creating display window")
            cv2.namedWindow('argus_detect_laser', cv2.WINDOW_NORMAL)

        # mask optional
        mask = build_mask(height, width, args.mask)

        # frame window optional: process frames with first <= index < last
        if args.frames is not None:
            first, last = args.frames
        else:
            first, last = 0, n_frames

        # loop invariants
        threshold = args.threshold
        alpha = args.alpha

        # initialize result
        result = {'frame': [], 'pt1_X': [], 'pt1_Y': []}

        # walk through frame by frame
        logging.info("processing frame by frame")
        for a in range(n_frames):
            retval, raw = movie.read()
            spot = None

            if retval and first <= a < last:
                # red laser, so use the red channel (frames are BGR)
                # an alternative here may be to subtract g-b?
                r = raw[:, :, 2]

                scaled = np.array(r / 255., dtype=np.float32)
                thresh = ((scaled - background) > threshold).astype(np.uint8)
                # update the background only after it has been used
                cv2.accumulateWeighted(scaled, background, alpha=alpha)

                # do any additional morphological operations here
                if mask is not None:
                    thresh = cv2.bitwise_and(thresh, mask)

                spot = locate_spot(thresh)
                if spot is not None:
                    logging.debug("frame %d got point at (%0.1f,%0.1f)",
                                  a, spot[0], spot[1])
                else:
                    logging.debug("frame %d no output", a)

                # display optional
                if args.display:
                    draw = cv2.cvtColor(raw, cv2.COLOR_BGR2GRAY)
                    draw = cv2.cvtColor(draw, cv2.COLOR_GRAY2BGR)
                    if spot is not None:
                        center = (int(np.round(spot[0])),
                                  int(np.round(spot[1])))
                        cv2.circle(draw, center, 2, (0, 0, 255), 2)
                        cv2.putText(draw, 'PT1 DETECTED',
                                    (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 1,
                                    (0, 0, 255), 3)
                    cv2.imshow('argus_detect_laser', draw)
                    cv2.waitKey(1)

            # exactly one row per frame, NaN when there is no detection
            result['frame'].append(a)
            result['pt1_X'].append(np.nan if spot is None else spot[0])
            result['pt1_Y'].append(np.nan if spot is None else spot[1])

        # write the output
        logging.info("writing result to {0}".format(args.ofile))
        table = pd.DataFrame(result, columns=['frame', 'pt1_X', 'pt1_Y'])
        table['frame'] = table['frame'].astype(int)
        table.to_csv(args.ofile, index=False)

        logging.debug("done, cleaning up")
    finally:
        # release the capture and any window even if processing failed
        movie.release()
        if args.display:
            cv2.destroyAllWindows()


if __name__ == "__main__":
    main()
