#!python
"""
This script provides basic auto-detection for brightly colored lacrosse balls
under fluorescent light. 

It is possible to modify this for other situations; see the comments within 
the body of the code.  Automatic segmentation in general can be sensitive
to lighting, color, shape, etc so each application will require modification
to the template here
"""

import argparse
import logging
import cv2
import pandas as pd
import numpy as np

class ColoredBall(object):
   """
   Description of one ball color to look for.

   Holds the lower/upper bounds used for color thresholding, the color
   used when drawing the detection, identifying labels, and the range of
   blob areas accepted as a plausible ball.
   """

   def __init__(self,
                lowerb,upperb,
                display,
                name=None,colorname=None,
                min_area=20.0**2,max_area=200.0**2):
      """
      Parameters
      ----------
      lowerb, upperb : inclusive lower / upper bounds handed to
         cv2.inRange (the script applies them to HSV images)
      display : color tuple used when drawing this ball's detections
      name : short label; the script uses it as the output column prefix
      colorname : human-readable description of the ball
      min_area, max_area : a blob is accepted when
         min_area <= area < max_area
      """
      super(ColoredBall,self).__init__()
      # thresholding bounds
      self.lowerb,self.upperb = lowerb,upperb
      # drawing color and labels
      self.display = display
      self.name,self.colorname = name,colorname
      # accepted blob size range
      self.min_area,self.max_area = min_area,max_area

   def inRange(self,hsv):
      """
      Threshold an image against this ball's color bounds.

      Returns whatever cv2.inRange returns for the given image and the
      stored lowerb / upperb bounds.
      """
      lo = self.lowerb
      hi = self.upperb
      return cv2.inRange(hsv,lo,hi)



# Preset color bands, one per ball.
#   lowerb / upperb : bounds passed to cv2.inRange on the HSV frame
#   display         : color used when drawing on the BGR preview image
#   name            : prefix of the <name>_X / <name>_Y output columns
# NOTE(review): 8-bit OpenCV hue normally spans 0-179, so the 185 upper
# hue bound on the pink ball presumably just means "up to the top of the
# hue range" -- confirm before tightening it.
PinkBall = ColoredBall(
   name="pt1",
   colorname="pink lacrosse ball",
   lowerb=(165,190,90),
   upperb=(185,255,255),
   display=(131,59,236),
)
OrangeBall = ColoredBall(
   name="pt2",
   colorname="orange lacrosse ball",
   lowerb=(0,190,90),
   upperb=(20,255,255),
   display=(0,69,255),
)
GreenBall = ColoredBall(
   name="pt3",
   colorname="light green lacrosse ball",
   lowerb=(25,110,100),
   upperb=(50,180,255),
   display=(51,255,153),
)
BlueBall = ColoredBall(
   name="pt4",
   colorname="blue lacrosse ball",
   lowerb=(95,105,0),
   upperb=(135,165,255),
   display=(255,51,51),
)

# Order matters: it fixes the column order of the output file and the
# line on which each ball's "DETECTED" label is drawn.
balls = [
   PinkBall,
   OrangeBall,
   GreenBall,
   BlueBall,
]








def _locate_ball(ball,hsv):
   """
   Find the centroid of the largest blob matching one ball's color band.

   Parameters
   ----------
   ball : object providing inRange(hsv), min_area and max_area
   hsv : image to threshold (the script passes the blurred HSV frame)

   Returns
   -------
   (x, y) tuple with the centroid of the largest contour, or None when
   no contour is found or the largest one falls outside
   min_area <= area < max_area.
   """
   detected = ball.inRange(hsv)
   # findContours returns (contours, hierarchy) in OpenCV 2.x/4.x but
   # (image, contours, hierarchy) in 3.x; the contours are always the
   # second-to-last item, so index instead of unpacking.
   contours = cv2.findContours(detected,
                               mode = cv2.RETR_LIST,
                               method = cv2.CHAIN_APPROX_NONE)[-2]
   if len(contours) == 0:
      return None

   # moments of the largest-area contour (m00 is the contour area);
   # computed once per contour rather than once for sorting and again
   # for the centroid
   moments = max((cv2.moments(c) for c in contours),
                 key = lambda m:m['m00'])
   area = moments['m00']
   if area <= 0:
      # degenerate contour: centroid undefined, avoid dividing by zero
      return None
   if not ((area >= ball.min_area) and (area < ball.max_area)):
      return None
   return (moments['m10']/area, moments['m01']/area)


if __name__ == "__main__":

   # get arguments
   parser = argparse.ArgumentParser(description="find lacrosse balls",epilog=__doc__)
   parser.add_argument("ifile",
                       help="input movie to process")
   parser.add_argument("--ofile",default="pts.csv",
                       help="output filename for result, default pts.csv")
   # type=int lets argparse report bad values as a usage error instead
   # of failing later with a traceback
   parser.add_argument("--mask",nargs=4,type=int,
                       help="take pixels only inside [xmin xmax ymin ymax]")
   parser.add_argument("--frames",nargs=2,type=int,
                       help="take frames only inside [start finish]")
   parser.add_argument("--display",action="store_true",
                       help="display while working")
   parser.add_argument("--verbose",action="store_true",
                       help="toggle verbosity")
   args = parser.parse_args()
   if args.verbose:
      logging.basicConfig(level=logging.DEBUG)
      logging.debug("running in verbose mode")
   else:
      logging.basicConfig(level=logging.INFO)

   # open movie; stop here if it cannot be read, rather than silently
   # writing an empty result file
   movie = cv2.VideoCapture(args.ifile)
   if not movie.isOpened():
      parser.error("unable to open input movie: {0}".format(args.ifile))
   n_frames = int(movie.get(cv2.CAP_PROP_FRAME_COUNT))
   height = int(movie.get(cv2.CAP_PROP_FRAME_HEIGHT))
   width = int(movie.get(cv2.CAP_PROP_FRAME_WIDTH))

   # one (ball, X column name, Y column name) triple per tracked ball
   tracked = [(eachball,
               '{0}_X'.format(eachball.name),
               '{0}_Y'.format(eachball.name)) for eachball in balls]

   # initialize result columns
   result = dict()
   result['frame'] = list()
   for _,xcol,ycol in tracked:
      result[xcol] = list()
      result[ycol] = list()

   # mask optional: 255 inside the requested rectangle, 0 elsewhere
   if args.mask is not None:
      xmin,xmax,ymin,ymax = args.mask
      mask = np.zeros((height,width,3),dtype=np.uint8)
      mask[ymin:ymax,xmin:xmax,:] = 255
   else:
      mask = None

   # frame range optional: process first_frame <= frame < end_frame
   if args.frames is not None:
      first_frame,end_frame = args.frames
   else:
      first_frame,end_frame = 0,n_frames

   # initialize display window
   if args.display:
      cv2.namedWindow('argus_detect_balls',cv2.WINDOW_NORMAL)
      if args.verbose:
         cv2.namedWindow('p',cv2.WINDOW_NORMAL)
         cv2.namedWindow('o',cv2.WINDOW_NORMAL)
         cv2.namedWindow('g',cv2.WINDOW_NORMAL)
         cv2.namedWindow('b',cv2.WINDOW_NORMAL)

   # go through frame by frame
   logging.info("processing frame by frame")
   for frame in range(n_frames):
      retval, raw = movie.read()
      result['frame'].append(frame)

      if not (retval and
              (frame >= first_frame) and
              (frame < end_frame)):
         # unreadable or out-of-range frame: keep the row, mark missing
         for _,xcol,ycol in tracked:
            result[xcol].append(np.nan)
            result[ycol].append(np.nan)
         continue

      blurred = cv2.medianBlur(raw,5)
      hsv = cv2.cvtColor(blurred,cv2.COLOR_BGR2HSV)

      # apply mask (optional)
      if mask is not None:
         hsv = cv2.bitwise_and(hsv,mask)

      # grayscale copy of the frame, converted back to 3 channels so
      # the colored markers can be drawn on it
      if args.display:
         draw = cv2.cvtColor(raw,cv2.COLOR_BGR2GRAY)
         draw = cv2.cvtColor(draw,cv2.COLOR_GRAY2BGR)

      for ix,(eachball,xcol,ycol) in enumerate(tracked):
         center = _locate_ball(eachball,hsv)
         if center is None:
            # nothing found, or blob outside the accepted size range
            result[xcol].append(np.nan)
            result[ycol].append(np.nan)
            continue

         result[xcol].append(center[0])
         result[ycol].append(center[1])
         if args.display:
            cv2.circle(draw,
                       (int(round(center[0])),int(round(center[1]))),
                       2,
                       eachball.display,
                       2)
            cv2.putText(draw,
                        '{0} DETECTED'.format(eachball.name.upper()),
                        (10,50*(ix+1)),cv2.FONT_HERSHEY_SIMPLEX,1,
                        eachball.display,3)

      if args.display:
         cv2.imshow('argus_detect_balls',draw)
         if args.verbose:
            # per-color threshold images, for tuning the color bands
            cv2.imshow('p',PinkBall.inRange(hsv))
            cv2.imshow('o',OrangeBall.inRange(hsv))
            cv2.imshow('g',GreenBall.inRange(hsv))
            cv2.imshow('b',BlueBall.inRange(hsv))
         cv2.waitKey(1)

   # output results, with columns ordered frame, then X/Y per ball
   logging.info("writing result to {0}".format(args.ofile))
   cols = ['frame']
   for _,xcol,ycol in tracked:
      cols.append(xcol)
      cols.append(ycol)
   result = pd.DataFrame(result)[cols]
   result.to_csv(args.ofile,index=False)

   # cleanup: release the capture explicitly rather than relying on
   # garbage collection
   logging.info("done, cleaning up")
   movie.release()
   if args.display:
      cv2.destroyAllWindows()
