#!python
"""
Finds chessboards or dots in a video, typically an MP4 video from a GoPro.
The detected patterns are written to a Python pickle (.pkl) file as three
objects dumped in the following order: objectPoints (a dict mapping frame
number to an array of size (N,1,3)), imagePoints (a dict mapping frame number
to an array of size (N,1,2)), and imageSize (width,height) in pixels.  If you
have trouble detecting, check that the pattern size is correct; you may have to
use the --inverted flag or swap the rows and columns depending on how the
pattern is held in the frame.
"""

import argparse
import logging
import cv2
import os
import numpy as np
import pickle
import scipy.spatial.distance

def build_board_points(rows, cols, spacing, inverted=False, asymmetric=False):
    """Return the 3D object points of one calibration board.

    Args:
        rows: number of points per column of the pattern.
        cols: number of points per row of the pattern.
        spacing: pattern spacing; symmetric patterns are scaled by
            ``spacing`` and asymmetric ones by ``spacing / 2``.
        inverted: if true, negate both the x and y coordinates.
        asymmetric: if true, build an asymmetric circle grid in which odd
            rows are shifted by one half-step in x.

    Returns:
        A float32 numpy array of shape (N, 1, 3) with z fixed at 0, ordered
        row by row with x varying fastest.
    """
    sign = -1 if inverted else 1
    if asymmetric:
        # Even rows use x = 0, 2, 4, ...; odd rows use x = 1, 3, 5, ...
        points = [[sign * x, sign * y, 0]
                  for y in range(rows)
                  for x in range(y % 2, cols * 2, 2)]
        scale = float(spacing) / 2.
    else:
        points = [[sign * x, sign * y, 0]
                  for y in range(rows)
                  for x in range(cols)]
        scale = spacing
    board = np.array(points, dtype=np.float32).reshape(-1, 1, 3)
    return board * scale


def find_pattern(gray, boardsize, dots=False, asymmetricdots=False):
    """Search one grayscale frame for the calibration pattern.

    Args:
        gray: single-channel image to search.
        boardsize: (cols, rows) tuple describing the pattern.
        dots: look for a symmetric circle grid.
        asymmetricdots: look for an asymmetric circle grid (only used when
            ``dots`` is false).

    Returns:
        The ``(found, corners)`` pair returned by the OpenCV detector.
        Chessboard corners are refined with ``cv2.cornerSubPix`` when found.
    """
    if dots or asymmetricdots:
        if dots:
            kwargs = {}
        else:
            kwargs = {"centers": None,
                      "flags": (cv2.CALIB_CB_ASYMMETRIC_GRID |
                                cv2.CALIB_CB_CLUSTERING)}
        try:
            return cv2.findCirclesGrid(gray, boardsize, **kwargs)
        except Exception:
            # Some old OpenCV builds only expose the detector under this
            # name; if it is not there, surface the original error instead
            # of masking it with an AttributeError.
            fallback = getattr(cv2, "findCirclesGridDefault", None)
            if fallback is None:
                raise
            return fallback(gray, boardsize, **kwargs)

    found, corners = cv2.findChessboardCorners(gray, boardsize)
    if found:
        criteria = (cv2.TERM_CRITERIA_MAX_ITER | cv2.TERM_CRITERIA_EPS,
                    30, 0.1)
        # cornerSubPix refines ``corners`` in place.
        cv2.cornerSubPix(gray, corners, (3, 3), (-1, -1), criteria)
    return found, corners


def passes_quality_check(corners, min_separation=1, min_extent=100):
    """Reject detections with duplicate points or a very small footprint.

    Args:
        corners: detected image points, reshapeable to (N, 2).
        min_separation: any pair of points closer than this (in pixels) is
            treated as a duplicate.
        min_extent: the largest pairwise distance must reach this value (in
            pixels) for the pattern to be accepted.

    Returns:
        True if the detection passes both checks, False otherwise.  A warning
        is logged for each failed check.
    """
    points = np.asarray(corners).reshape(-1, 2)
    dists = scipy.spatial.distance.pdist(points)
    ok = True
    if (dists < min_separation).any():
        ok = False
        logging.warning("duplicate point detected")
    # A pattern with fewer than two points has no extent at all.
    if dists.size == 0 or np.max(dists) < min_extent:
        ok = False
        logging.warning("pattern too small")
    return ok


def main():
    """Detect calibration patterns in a movie and pickle the results.

    Writes three pickled objects to the output file, in order: a dict of
    object points keyed by frame number, a dict of image points keyed by
    frame number, and the (width, height) image size.
    """
    # parse arguments
    parser = argparse.ArgumentParser(description="Detects patterns in movie", epilog=__doc__)
    parser.add_argument("ifile",help="movie to process, eg Green/GOPRO130.MP4")
    parser.add_argument("--ofile","-o",default=None,
                        help="output file, default corners.pkl")

    # Which frames to analyze
    parser.add_argument("--start",type=int,default=0,
                        help="patterns start here, default frame 0")
    parser.add_argument("--stop",type=int,default=None,
                        help="patterns stop here, defaults to entire sequence")

    # information about the pattern
    parser.add_argument("--chessboard",action="store_true",default=True,
                        help="pattern is chessboard (default)")
    parser.add_argument("--dots",action="store_true",
                        help="pattern is dot pattern")
    parser.add_argument("--asymmetricdots",action="store_true",
                        help="pattern is asymmetric dot pattern")
    parser.add_argument("--inverted",action="store_true",
                        help="use if pattern is inverted?")
    parser.add_argument("--rows",type=int,default=None,
                        help="rows, or number of points per column in pattern, "
                             "default 10 (chessboard), 12 (dots) or 21 (asymmetric dots)")
    parser.add_argument("--cols",type=int,default=None,
                        help="cols, or number of points per row in pattern, "
                             "default 7 (chessboard), 9 (dots) or 6 (asymmetric dots)")
    parser.add_argument("--spacing",type=float,default=0.02,
                        help="pattern spacing, default 0.02 m")

    # other options
    parser.add_argument("--no_qa",action="store_true",
                        help="skip quality check for duplicate points")
    parser.add_argument("--display","-d",action="store_true",
                        help="display frames during processing")
    parser.add_argument("--verbose","-v",action="store_true",
                        help="toggle verbosity")
    args = parser.parse_args()

    # set verbosity
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        logging.debug("running in verbose mode")
    else:
        logging.basicConfig(level=logging.INFO)

    # configure display
    if args.display:
        WINDOWNAME = "argus_detect_patters in {0}".format(args.ifile)
        cv2.namedWindow(WINDOWNAME, cv2.WINDOW_NORMAL)
        logging.debug("display window created")

    # configure ofile
    if args.ofile is None:
        args.ofile = os.path.join(os.path.dirname(args.ifile),"corners.pkl")

    # open movie, set start and stop, and get imageSize
    logging.debug("opening video from {0}".format(args.ifile))
    movie = cv2.VideoCapture(args.ifile)
    if not movie.isOpened():
        # Fail early with a clear message instead of writing an empty file.
        parser.error("could not open video {0}".format(args.ifile))
    nframes = int(movie.get(cv2.CAP_PROP_FRAME_COUNT))
    if args.stop is None or args.stop < 0 or args.stop > nframes:
        args.stop = nframes
    if args.start < 0 or args.start > args.stop:
        args.start = 0
    logging.debug("using from frame {0} to {1}".format(args.start,args.stop))
    imageSize = (int(movie.get(cv2.CAP_PROP_FRAME_WIDTH)),
                 int(movie.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    logging.debug("image size is {0}".format(imageSize))

    # initialize pattern stuff; defaults depend on the pattern type
    if args.dots:
        default_cols, default_rows = 9, 12
        label, plural = "dot pattern", "dot patterns"
    elif args.asymmetricdots:
        default_cols, default_rows = 6, 21
        label, plural = "asymmetric dot pattern", "asymmetric dot patterns"
    else:
        default_cols, default_rows = 7, 10
        label, plural = "pattern", "chessboards"
    if args.cols is None:
        args.cols = default_cols
    if args.rows is None:
        args.rows = default_rows
    boardsize = (args.cols,args.rows)
    logging.debug("{0} size is {1}".format(label,boardsize))

    # build object points (shared by every frame in which a pattern is found)
    single_board = build_board_points(args.rows, args.cols, args.spacing,
                                      inverted=args.inverted,
                                      asymmetric=args.asymmetricdots)

    # find patterns in range
    objectPoints = dict()
    imagePoints = dict()
    checked = 0 # number of frames actually searched
    logging.debug("beginning frame by frame pattern search")
    for frame in range(args.stop):
        grabbed,raw = movie.read() # read the frame
        if not grabbed or frame < args.start:
            continue
        checked += 1

        # VideoCapture delivers frames in BGR channel order
        gray = cv2.cvtColor(raw,cv2.COLOR_BGR2GRAY)
        found, corners = find_pattern(gray, boardsize,
                                      dots=args.dots,
                                      asymmetricdots=args.asymmetricdots)

        # the overlay is only needed when the frame is going to be shown
        if args.display:
            cv2.drawChessboardCorners(raw,boardsize,corners,found)

        # quality control check for duplicate points or very tiny boards
        if found and not args.no_qa:
            found = passes_quality_check(corners)

        # add to output
        if found:
            objectPoints[frame] = single_board
            imagePoints[frame] = corners
            logging.info("Pattern found at frame {0}".format(frame))
        else:
            logging.debug("Pattern not found at frame {0}".format(frame))

        # display result if ordered to
        if args.display:
            cv2.imshow(WINDOWNAME,raw)
            cv2.waitKey(1)

    # save results
    logging.info("Saving results to {0}".format(args.ofile))
    with open(args.ofile,"wb") as ofile:
        pickle.dump(objectPoints,ofile)
        pickle.dump(imagePoints,ofile)
        pickle.dump(imageSize,ofile)

    # report completion
    logging.info("Checked {0} frames from {1}, found {2} complete {3} of size ({4},{5}), spacing {6}".format(checked,args.ifile,len(imagePoints),plural,args.cols,args.rows,args.spacing))

    # cleanup
    logging.debug("cleaning up")
    movie.release()
    if args.display: # destroy display if on
        cv2.destroyWindow(WINDOWNAME)


if __name__=="__main__":
    main()
