class EventList(list):
    """
    A plain ``list`` subclass whose items are ``Event`` objects.

    The class body defines nothing beyond this docstring, so every list
    operation behaves exactly as on a built-in ``list``.
    """


class AnimalAnimalEvent(Dict[str, AnimalEvent]):
    """
    Return type of animals_social_events(): a dictionary with ``str`` keys and ``AnimalEvent`` values.

    Examples
    --------
    >>> # Define <|run_after|> as a social behavior where the distance between animals is less than 100, one animal is in front of the other animal and the target animal has speed faster than 3. Get duration of run_after
    >>> def task_program():
    >>>     behavior_analysis = AnimalBehaviorAnalysis()
    >>>     run_after_social_events = behavior_analysis.animals_social_events(['distance', 'orientation'],
    >>>         [f'< 100', f'=={Orientation.FRONT}'],
    >>>         individual_animal_state_query_list = ['speed'],
    >>>         individual_animal_state_comparison_list=['>=3'],
    >>>         bodyparts = ['all'],
    >>>         otheranimal_bodyparts = ['all'])
    >>>     run_after_social_events_duration = run_after_social_events.duration
    >>>     return run_after_social_events_duration
    """
    # NOTE(review): the base class expression above names ``AnimalEvent``,
    # which is only defined further down in this file. Evaluated as ordinary
    # Python this raises NameError unless the name is supplied earlier --
    # confirm how this file is consumed.

class AnimalEvent(Dict[str, EventList]):
    """
    Dictionary with ``str`` keys and ``EventList`` values.

    Do not use the iterator of this class.

    Attributes
    ----------
    duration -> int
        duration of the AnimalEvent
    Methods
    -------
    sort()
    filter_tensor_by_events(data, frame_range = None)
    Examples
    --------
    >>> # get events where the animal stays in ROI0 and ROI1
    >>> def task_program():
    >>>     behavior_analysis = AnimalBehaviorAnalysis()
    >>>     overlap_roi0_events = behavior_analysis.animals_object_events('ROI0', 'overlap', bodyparts = ['all'])
    >>>     overlap_roi1_events = behavior_analysis.animals_object_events('ROI1', 'overlap', bodyparts = ['all'])
    >>>     overlap_roi0_and_roi1_events = overlap_roi0_events | overlap_roi1_events
    >>>     return overlap_roi0_and_roi1_events
    >>> # get events where the animal first enters ROI0
    >>> def task_program():
    >>>     behavior_analysis = AnimalBehaviorAnalysis()
    >>>     enter_roi0_events = behavior_analysis.enter_object('ROI0', bodyparts = ['all'])
    >>>     enter_roi0_first_time_events = enter_roi0_events[0]
    >>>     return enter_roi0_first_time_events
    """

    def filter_tensor_by_events(self, data, frame_range = None):
        """
        Filter the tensor by events. The filtered data will contain NaNs, which the caller needs to handle.

        Parameters
        ----------
        data: np.array
        frame_range: optional, constrains the events by frame range
        Returns
        -------
        ret: np.ndarray
        Examples
        --------
        >>> # plot orientation vectors formed by nose and mouse_center filtered by events where animal0 is orienting towards object 8. Only plot the frame range (2000, 3000)
        >>> def task_program():
        >>>     behavior_analysis = AnimalBehaviorAnalysis()
        >>>     orientation_vectors = behavior_analysis.get_orientation_vector('mouse_center', 'nose')
        >>>     orienting_towards_object_8_events = behavior_analysis.animals_object_events('8', 'orientation', comparison = '==Orientation.FRONT', bodyparts = ['all'])
        >>>     filtered_orientation_vectors = orienting_towards_object_8_events.filter_tensor_by_events(orientation_vectors, frame_range = range(2000,3000))
        >>>     fig, axes = plt.subplots(1, 2, figsize=(12, 8))
        >>>     frame_indices = np.arange(filtered_orientation_vectors.shape[0])
        >>>     axes[0].plot(frame_indices, filtered_orientation_vectors[...,0])
        >>>     axes[1].plot(frame_indices, filtered_orientation_vectors[...,1])
        >>>     return fig, axes
        >>> # Give me the distance travelled for animals in roi0
        >>> def task_program():
        >>>     behavior_analysis = AnimalBehaviorAnalysis()
        >>>     in_roi0_events = behavior_analysis.animals_object_events('ROI0', 'overlap', bodyparts = ['all'])
        >>>     # speed is of shape (n_frames, n_individuals, n_kpts, 1)
        >>>     speed = behavior_analysis.get_speed()
        >>>     # Important! Averaging the speed over the dimension of n_kpts to avoid duplicate calculation of distance travelled
        >>>     averaged_speed_over_kpts = np.nanmean(speed, axis = 2)
        >>>     averaged_speed_in_roi0 = in_roi0_events.filter_tensor_by_events(averaged_speed_over_kpts)
        >>>     distance_travelled_in_roi0 = np.nansum(averaged_speed_in_roi0, axis = 0)
        >>>     return distance_travelled_in_roi0
        """
        # NOTE(review): nothing follows the docstring in this file, so this
        # definition as written returns None rather than the documented array.


class AnimalBehaviorAnalysis:
    """
    Facade exposing keypoint, kinematics, plotting and event queries.

    The methods defined in this class body delegate to module-level functions
    of the same name.

    Methods
    -------
    Note for kinematics such as speed, acceleration and velocity, give the averaged value across n_kpts if no bodypart is specified

    get_object_names() -> List[str]
        get names of all objects.
    get_roi_object_names() -> List[str]
        get names of ROI objects
    get_animal_names() -> List[str]
        get names of all animals
    get_keypoints() -> np.array (n_frames, n_individuals, n_kpts, 2). where 2 is x and y
        Can be used to get the number of frames from the video, number of animals in the video and number of keypoints for the animals.
    get_speed() -> np.array (n_frames, n_individuals, n_kpts, 1).
        This function returns an array of speed of the animals. Note when used to calculate distance travelled, the speed must be averaged across n_kpts dimension first.
    get_velocity() -> np.array (n_frames, n_individuals, n_kpts, 2), where 2 is x and y
        This function returns an array of velocity of the animals. Note velocity is a vector
    get_acceleration() -> np.array (n_frames, n_individuals, n_kpts, 1).
        This function returns an array of acceleration.
    get_bodypart_names() -> List[str]
        get names of all bodyparts of the animal
    get_bodypart_indices(bodypart_name: List[str]) -> List[int]
        Get the indices of the named bodyparts of the animal. If named bodyparts are given, MUST use this function before accessing tensor from get_keypoints(), get_speed() and get_acceleration().
    get_object_center(object_name: str) -> np.array (2), where 2 is x and y:
        get the center of the named object
    get_orientation_vector(bodypart_name1, bodypart_name2): -> np.array shape (n_frames, n_animals, 2)
        return the orientation vector of the vector connecting bodypart_name1 and bodypart_name2 for all animals

    plot_ethogram(events: AnimalEvent) -> Tuple[plt.Figure, plt.Axes, plot_caption]
        get ethogram for corresponding AnimalEvent

    animals_social_events(inter_individual_animal_state_query_list,
                                inter_individual_animal_state_comparison_list,
                                individual_animal_state_query_list = [],
                                individual_animal_state_comparison_list = [],
                                bodyparts = ['all'],
                                otheranimal_bodyparts = ['all'],
                                min_window = 11,
                                pixels_per_cm = 8,
                                smooth_window_size = 5) -> AnimalAnimalEvent

    """
    def get_keypoints(self):
        """
        Return the result of the module-level get_keypoints().

        Examples
        --------
        >>> # plot histogram for keypoints
        >>> def task_program():
        >>>     behavior_analysis = AnimalBehaviorAnalysis()
        >>>     keypoints = behavior_analysis.get_keypoints()
        >>>     n_frames, n_individuals, n_kpts, _ = keypoints.shape
        >>>     fig, axes = plt.subplots(n_individuals, 2, figsize=(12, 8))
        >>>     axes = np.array(axes)
        >>>     axes = axes.reshape(n_individuals, -1)
        >>>     for i in range(n_individuals):
        >>>         # pick animal i and pool all of its keypoints; last axis is (x, y)
        >>>         axes[i, 0].hist(keypoints[:, i, :, 0].flatten(), bins=50, color='blue', alpha=0.5)
        >>>         axes[i, 1].hist(keypoints[:, i, :, 1].flatten(), bins=50, color='blue', alpha=0.5)
        >>>         axes[i, 0].set_title(f'Histogram of x for Animal {i}')
        >>>         axes[i, 1].set_title(f'Histogram of y for Animal {i}')
        >>>     return fig, axes

        """
        return get_keypoints()

    def get_speed(self):
        """
        Return the result of the module-level get_speed().

        Examples
        --------
        >>> # plot histogram for speed
        >>> def task_program():
        >>>     behavior_analysis = AnimalBehaviorAnalysis()
        >>>     speed = behavior_analysis.get_speed()
        >>>     n_frames, n_individuals, n_kpts, _ = speed.shape
        >>>     fig, axes = plt.subplots(n_individuals, 1, figsize=(12, 8))
        >>>     axes = np.array(axes)
        >>>     axes = axes.reshape(n_individuals, -1)
        >>>     for i in range(n_individuals):
        >>>         # speed is a scalar per keypoint, so one histogram per animal
        >>>         axes[i, 0].hist(speed[:, i].flatten(), bins=50, color='blue', alpha=0.5)
        >>>         axes[i, 0].set_title(f'Histogram of speed for Animal {i}')
        >>>     return fig, axes
        """
        return get_speed()

    def get_acceleration(self):
        """
        Return the result of the module-level get_acceleration().

        Examples
        --------
        >>> # plot histogram for acceleration
        >>> def task_program():
        >>>     behavior_analysis = AnimalBehaviorAnalysis()
        >>>     acceleration = behavior_analysis.get_acceleration()
        >>>     n_frames, n_individuals, n_kpts, _ = acceleration.shape
        >>>     fig, axes = plt.subplots(n_individuals, 1, figsize=(12, 8))
        >>>     axes = np.array(axes)
        >>>     axes = axes.reshape(n_individuals, -1)
        >>>     for i in range(n_individuals):
        >>>         axes[i, 0].hist(acceleration[:, i].flatten(), bins=50, color='blue', alpha=0.5)
        >>>         axes[i, 0].set_title(f'Histogram of acceleration for Animal {i}')
        >>>     return fig, axes
        """
        return get_acceleration()

    def animals_state_events(self,
                            state_type,
                            comparison,
                            bodyparts = ['all'],
                            ) -> AnimalEvent:
        """
        Get events defined by a state of the animal itself.

        Parameters
        ----------
        state_type: str
            Must be one of 'speed', 'acceleration', 'bodypart_pairwise_distance'
        comparison: str
            Must be a comparison operator followed by a number like <50,
        bodyparts: List[str], optional
            bodyparts of the animal
        Examples
        --------
        >>> # Get events where animal moving faster than 3 pixels across frames.
        >>> def task_program():
        >>>     behavior_analysis = AnimalBehaviorAnalysis()
        >>>     animal_faster_than_3_events = behavior_analysis.animals_state_events('speed', '>3')
        >>>     return animal_faster_than_3_events
        >>> # Get events where animal's nose to its own tail_base distance larger than 10
        >>> def task_program():
        >>>     behavior_analysis = AnimalBehaviorAnalysis()
        >>>     nose_tail_base_larger_than_10_events = behavior_analysis.animals_state_events('bodypart_pairwise_distance', '>10', bodyparts = ['nose', 'tail_base'])
        >>>     return nose_tail_base_larger_than_10_events
        """
        # NOTE(review): `bodyparts` is accepted but not forwarded to the
        # delegate -- confirm whether that is intentional.
        return animals_state_events(state_type, comparison)

    def plot_trajectory(self, bodyparts: List[str], events: Union[Dict[str, dict],dict] = None, frame_range = None,
                        cmap = 'rainbow',
                        marker = 'o',
                        alpha = None,
                        linewidths = None):

        """
        Plot the trajectory of the animals
        Parameters
        ----------
        bodyparts : List[str]
            name of the animal's bodypart. Can either be 'all' or list of body parts. If 'all' is given, it takes all bodyparts and uses the center of those bodyparts.
        events: Union[Dict[str, dict],dict], optional
            The type must be either dict or dictionary of dict
        frame_range: range, optional
            selection of frames
        cmap, marker, alpha, linewidths : parameters for plt.scatter
        Returns
        -------
        Tuple[plt.Figure, List[plt.Axes], plot_caption]
           Always return a tuple of figure and axes
        Examples
        --------
        >>> # plot trajectory of the animal.
        >>> def task_program():
        >>>     behavior_analysis = AnimalBehaviorAnalysis()
        >>>     traj_plot_info = behavior_analysis.plot_trajectory(["all"])
        >>>     return traj_plot_info
        >>> # plot trajectory of the nose of the animal for the events where the animal overlaps with object 6
        >>> def task_program():
        >>>     behavior_analysis = AnimalBehaviorAnalysis()
        >>>     overlap_6_events = behavior_analysis.animals_object_events('6', 'overlap', bodyparts = ['all'])
        >>>     traj_plot_info = behavior_analysis.plot_trajectory(["nose"], events = overlap_6_events)
        >>>     return overlap_6_events, traj_plot_info
        >>> # plot the trajectory of the animal in the first 1000 frames
        >>> def task_program():
        >>>     behavior_analysis = AnimalBehaviorAnalysis()
        >>>     traj_plot_info = behavior_analysis.plot_trajectory(["all"], frame_range = range(1000))
        >>>     return traj_plot_info

        """
        # NOTE(review): `frame_range` and the scatter styling arguments are
        # accepted but not forwarded to the delegate -- confirm.
        return plot_trajectory(bodyparts, events)

    def superanimal_video_inference(self) -> None:
        """
        Return the result of the module-level superanimal_video_inference().

        Examples
        --------
        >>> # extract keypoints (aka pose) from the video file
        >>> def task_program():
        >>>     behavior_analysis = AnimalBehaviorAnalysis()
        >>>     resultfile = behavior_analysis.superanimal_video_inference()
        >>>     return resultfile
        """
        # NOTE(review): annotated `-> None`, yet the delegate's result is
        # returned and the example uses it as a result file -- confirm the
        # intended return type.
        return superanimal_video_inference()

    # Do not call animals_social_events unless there are multiple animals and social events are being queried.
    def animals_social_events(self,
                                inter_individual_animal_state_query_list = [],
                                inter_individual_animal_state_comparison_list = [],
                                individual_animal_state_query_list = [],
                                individual_animal_state_comparison_list = [],
                                bodyparts = ['all'],
                                otheranimal_bodyparts = ['all'],
                                min_window = 11,
                                pixels_per_cm = 8,
                                smooth_window_size = 5,
                                *,
                                individual_individual_animal_state_query_list = None,
                                individual_individual_animal_state_comparison_list = None) -> AnimalAnimalEvent:
        """
        The function is specifically for capturing multiple animals social events
        Parameters
        ----------
        inter_individual_animal_state_query_list: List[str], optional
        list of queries describing relative states among animals. The valid relative states can be and only can be any subset of the following ['to_left', 'to_right', 'to_below', 'to_above', 'overlap', 'distance', 'relative_speed', 'orientation', 'closest_distance', 'relative_angle', 'relative_head_angle']. Note distance and closest_distance are distinct queries. Also relative_angle and relative_head_angle are distinct queries.
        inter_individual_animal_state_comparison_list: List[str], optional
        This list consists of comparison operators such as '==', '<', '>', '<=', '>='. Every comparison operator uniquely corresponds to an item in inter_individual_animal_state_query_list.
        IMPORTANT: the `inter_individual_animal_state_comparison_list[i]` and `inter_individual_animal_state_query_list[i]` should correspond to each other. Also the length of two lists should be the same
        Note the same relative state can appear twice to define a range.  For example, we can have ['distance', 'distance'] correspond to comparison list ['<100', '>20']
        individual_animal_state_query_list: List[str], optional
        This is a list of queries related to the individual animal. The valid animal states can be any subset of the following ['speed', 'acceleration', 'confidence'] or empty list []
        individual_animal_state_comparison_list: List[str], optional
        list of comparison operator such as '==', '<', '>', '<=', '>='. individual_animal_state_comparison_list[i] should correspond to individual_animal_state_query_list[i] or empty list []
        bodyparts: List[str], optional
        list of bodyparts for this animal. Length of the list must be 1
        otheranimal_bodyparts: list[str], optional
        list of bodyparts for the other animals. Length of the list must be 1
        min_window: int, optional
        Only include events that are longer than min_window
        pixels_per_cm: int, optional
        how many pixels for 1 centimeter
        smooth_window_size: int, optional
        smooth window size for smoothing the events.
        individual_individual_animal_state_query_list: List[str], optional, keyword-only
        Legacy spelling of individual_animal_state_query_list. When given (not None) it takes precedence.
        individual_individual_animal_state_comparison_list: List[str], optional, keyword-only
        Legacy spelling of individual_animal_state_comparison_list. When given (not None) it takes precedence.
        Returns
        -------
        AnimalAnimalEvent:
        Examples
        --------
        >>> # Define <|run_after|> as a social behavior where the distance between animals is less than 100, one animal is in front of the other animal and the target animal has speed faster than 3. Get events for run_after.
        >>> def task_program():
        >>>     behavior_analysis = AnimalBehaviorAnalysis()
        >>>     # note we use individual_animal_state_query_list for speed here because it is describing the independent state of animals instead of relative states between animals
        >>>     run_after_social_events = behavior_analysis.animals_social_events(inter_individual_animal_state_query_list = ['distance', 'orientation'],
        >>>         inter_individual_animal_state_comparison_list = [f'< 100', f'=={Orientation.FRONT}'],
        >>>         individual_animal_state_query_list = ['speed'],
        >>>         individual_animal_state_comparison_list=['>=3'],
        >>>         bodyparts = ['all'],
        >>>         otheranimal_bodyparts = ['all'])
        >>>     return run_after_social_events
        >>> # Define <|far_away|> as a social behavior where distance between animals is larger than 20 pixels and smaller than 100 pixels and head angle less than 15
        >>> def task_program():
        >>>     behavior_analysis = AnimalBehaviorAnalysis()
        >>>     far_away_social_events = behavior_analysis.animals_social_events(inter_individual_animal_state_query_list = ['distance', 'distance', 'relative_head_angle'],
        >>>         inter_individual_animal_state_comparison_list = ['> 20', '< 100', '<15'],
        >>>         individual_animal_state_query_list = [],
        >>>         individual_animal_state_comparison_list= [],
        >>>         bodyparts = ['all'],
        >>>         otheranimal_bodyparts = ['all'])
        >>>     return far_away_social_events
        """
        # The signature used to spell these two parameters with a doubled
        # "individual_" prefix while the body read the shorter names, which
        # raised NameError on every call. The shorter names are now the real
        # parameters; the old spelling is still honored for keyword callers.
        if individual_individual_animal_state_query_list is not None:
            individual_animal_state_query_list = individual_individual_animal_state_query_list
        if individual_individual_animal_state_comparison_list is not None:
            individual_animal_state_comparison_list = individual_individual_animal_state_comparison_list

        # NOTE(review): bodyparts, otheranimal_bodyparts, min_window,
        # pixels_per_cm and smooth_window_size are accepted but not forwarded
        # to the delegate -- confirm whether that is intentional.
        return animals_social_events(inter_individual_animal_state_query_list,
                                     inter_individual_animal_state_comparison_list,
                                     individual_animal_state_query_list = individual_animal_state_query_list,
                                     individual_animal_state_comparison_list = individual_animal_state_comparison_list)

    def plot_rois(self):
        """
        Return the result of the module-level plot_rois().

        Examples
        --------
        >>> # plot rois
        >>> def task_program():
        >>>     behavior_analysis = AnimalBehaviorAnalysis()
        >>>     plot_info = behavior_analysis.plot_rois()
        >>>     return plot_info
        """
        return plot_rois()

    def animals_object_events(self,
                                object_name: str,
                                relation_query,
                                comparison = None,
                                negate = False,
                                bodyparts: List[str] = ['all'],
                                min_window = 0,
                                ) -> AnimalEvent:
        """
        This function is only used when there is an object with a name involved in the queries.

        Parameters
        ----------
        object_name : str
        This parameter represents the name of the object of interest. It is expected to be a string.
        The accepted naming conventions include numeric strings (e.g., '0', '1', '2', ...), or
        the prefix 'ROI' followed by a number (e.g., 'ROI0', 'ROI1', ...).
        relation_query: str. Must be one of ['to_left', 'to_right', 'to_below', 'to_above', 'overlap', 'distance', 'angle', 'orientation']
        comparison : str, Must be a comparison operator followed by a number like <50, optional
        negate: bool, default false
           whether to negate the spatial events. For example, if negate is set True, inside roi would be outside roi
        bodyparts: List[str], optional
           bodyparts of the animal
        min_window: min length of the event to include
        Examples
        --------
        >>> # find events where the animal is to the left of object 6
        >>> def task_program():
        >>>     behavior_analysis = AnimalBehaviorAnalysis()
        >>>     left_to_object6_events = behavior_analysis.animals_object_events('6', 'to_left',  bodyparts = ['all'])
        >>>     return left_to_object6_events
        """
        # NOTE(review): this docstring previously also listed `max_window`
        # (max length of the event to include), which the signature does not
        # accept. In addition, comparison, negate, bodyparts and min_window
        # are accepted but not forwarded to the delegate -- confirm both.
        return animals_object_events(object_name, relation_query)

  
    
class Orientation(IntEnum):
   """
   Orientation constants used in 'orientation' comparisons, e.g. f'=={Orientation.FRONT}'.

   Attributes
   ----------
   FRONT
   BACK
   LEFT
   RIGHT
   """
   # NOTE(review): the four members listed above are documented but not
   # declared in this class body, so Orientation.FRONT cannot be resolved
   # from this definition alone -- confirm where the actual values live.
class Event:
    """
    An event, plus class-level helpers for counting and combining events.

    Attributes
    ----------
    mask: -> np.array, dtype bool
        array that represents whether the event is met or not
    duration: -> float
        number of seconds for that event. The unit is seconds
    Methods
    -------
    add_simultaneous_events
    add_sequential_events
    count_bouts(AnimalEvent) -> str
       description of the number of bouts in the events
    """
    @classmethod
    def count_bouts(cls, events: AnimalEvent):
        """
        Return the result of the module-level count_bouts() for `events`.

        Examples
        --------
        >>> # Count the number of bouts for the events where the animal stays in ROI0
        >>> def task_program():
        >>>     behavior_analysis = AnimalBehaviorAnalysis()
        >>>     overlap_roi0_events = behavior_analysis.animals_object_events('ROI0', 'overlap', bodyparts = ['all'])
        >>>     overlap_roi0_events_bouts = Event.count_bouts(overlap_roi0_events)
        >>>     return overlap_roi0_events_bouts
        """
        bout_description = count_bouts(events)
        return bout_description

    @classmethod
    def add_simultaneous_events(cls, *events_list: List[dict]):
        """
        Combine events that must hold at the same time.

        Parameters
        ----------
        events_list: List[dict]
            the positional arguments; they are handed to the delegate as one tuple
        Returns
        -------
        AnimalEvent(dict): A dictionary containing animal name and List[Event]
            - key: name of the animal
            - value: list of Event
        Examples
        --------
        >>> # get events for the animal's nose overlaps the roi0, left eye overlaps the roi0 and tail base not overlap the roi0
        >>> def task_program():
        >>>     behavior_analysis = AnimalBehaviorAnalysis()
        >>>     nose_left_eye_in_roi0_events = behavior_analysis.animals_object_events('ROI0', 'overlap', bodyparts = ['nose', 'left_eye'], negate = False)
        >>>     tail_base_not_in_roi0_events = behavior_analysis.animals_object_events('ROI0', 'overlap', bodyparts = ['tail base'], negate = True)
        >>>     return Event.add_simultaneous_events(nose_left_eye_in_roi0_events, tail_base_not_in_roi0_events)
        """
        simultaneous_events = add_simultaneous_events(events_list)
        return simultaneous_events

    @classmethod
    def add_sequential_events(cls,
                              *events_list: List[dict],
                              continuous = False)-> dict:
        """
        Combine events that must happen one after another.

        Keywords such as "then", "later" might suggest it is a sequential event
        Parameters
        ----------
        events_list: List[dict]
            the positional arguments; they are handed to the delegate as one tuple
        continuous: accepted by this method but not passed on to the delegate
        Returns
        -------
        EventList: list of events
        Examples
        --------
        >>> # Get events where the animal enters object 6 from the bottom.
        >>> def task_program():
        >>>     behavior_analysis = AnimalBehaviorAnalysis()
        >>>     bottom_to_object_6_events = behavior_analysis.animals_object_events('6','to_below', bodyparts = ['all'])
        >>>     enter_object_6_events = behavior_analysis.enter_object('6', bodyparts = ['all'])
        >>>     enter_object_6_from_bottom_events = Event.add_sequential_events(bottom_to_object_6_events, enter_object_6_events)
        >>>     return enter_object_6_from_bottom_events
        >>> # Find events where animal leaves object 6
        >>> def task_program():
        >>>     behavior_analysis = AnimalBehaviorAnalysis()
        >>>     leave_object_6_events = behavior_analysis.leave_object('6', bodyparts = ['all'])
        >>>     return leave_object_6_events
        >>> # Find events where the animal leaves object 6 then enters object 3
        >>> def task_program():
        >>>     behavior_analysis = AnimalBehaviorAnalysis()
        >>>     leave_object_6_events = behavior_analysis.leave_object('6', bodyparts = ['all'])
        >>>     enter_object_3_events = behavior_analysis.enter_object('3', bodyparts = ['all'])
        >>>     leave_object_6_and_enter_object_3_events = Event.add_sequential_events(leave_object_6_events, enter_object_3_events)
        >>>     return leave_object_6_and_enter_object_3_events
        >>> # find events where the animal moves from left of object 6 to right of object 6.
        >>> def task_program():
        >>>     behavior_analysis = AnimalBehaviorAnalysis()
        >>>     left_to_object_6_events = behavior_analysis.animals_object_events('6','to_left', bodyparts = ['all'])
        >>>     right_to_object_6_events = behavior_analysis.animals_object_events('6','to_right', bodyparts = ['all'])
        >>>     left_to_right_events = Event.add_sequential_events(left_to_object_6_events, right_to_object_6_events)
        >>>     return left_to_right_events
        >>> # find events where the animal moves from object 6 to object 12.
        >>> def task_program():
        >>>     behavior_analysis = AnimalBehaviorAnalysis()
        >>>     on_object_6_events = behavior_analysis.animals_object_events('6','overlap', bodyparts = ['all'])
        >>>     on_object_12_events = behavior_analysis.animals_object_events('12','overlap', bodyparts = ['all'])
        >>>     on_object6_to_object12_events = Event.add_sequential_events(on_object_6_events, on_object_12_events)
        >>>     return on_object6_to_object12_events
        """
        sequential_events = add_sequential_events(events_list)
        return sequential_events

