"""Modalities"""
{% raw -%}
from pydantic import BaseModel, Field, ConfigDict
from typing import Literal, Union
from typing_extensions import Annotated
from enum import IntEnum
from aind_data_schema_models.pid_names import BaseName
{% endraw %}

class _ModalityModel(BaseName):
    """Base model for modality"""
    # Common parent of the per-modality classes rendered by the template loop
    # below; each of those subclasses narrows `name` and `abbreviation` to a
    # single Literal value with a matching default.
    # frozen=True makes instances immutable after construction.
    model_config = ConfigDict(frozen=True)
    name: str
    abbreviation: str

{# Renders one `_ModalityModel` subclass per row of `data` (presumably a pandas
   DataFrame, given `iterrows()` - confirm against the generator script).
   The class name comes from the row's abbreviation via the
   `to_class_name_underscored` filter; `name` and `abbreviation` are pinned to
   the row's values as Literal types with matching defaults. -#}
{% for _, row in data.iterrows() %}
class {{ row['abbreviation'] | to_class_name_underscored }}(_ModalityModel):
    """Model {{row['abbreviation']}}"""
    name: Literal["{{ row['name'] }}"] = "{{ row['name'] }}"
    abbreviation: Literal["{{ row['abbreviation'] }}"] = "{{ row['abbreviation'] }}"

{% endfor %}
class Modality:
    """Namespace holding one instance of every generated modality model"""
{% for _, row in data.iterrows() %}
    {{ row['abbreviation'] | to_class_name | upper }} = {{ row['abbreviation'] | to_class_name_underscored }}()
{%- endfor %}

    # Every direct subclass of _ModalityModel known at this point
    ALL = tuple(_ModalityModel.__subclasses__())

    # Union over all modality classes, discriminated by the "abbreviation" field
    ONE_OF = Annotated[Union[ALL], Field(discriminator="abbreviation")]

    # Lookup table: abbreviation string -> modality instance
    abbreviation_map = {instance.abbreviation: instance for instance in (model() for model in ALL)}

    @classmethod
    def from_abbreviation(cls, abbreviation: str):
        """Return the modality registered under `abbreviation`, or None if there is none"""
        lookup = cls.abbreviation_map
        return lookup.get(abbreviation)


class FileRequirement(IntEnum):
    """Whether a file is required for a specific modality"""

    # The integer values match the per-file columns of the modality table as
    # the template reads them when rendering the *_Files classes:
    # 1 -> REQUIRED, 0 -> OPTIONAL, any other value -> EXCLUDED.
    REQUIRED = 1
    OPTIONAL = 0
    EXCLUDED = -1


class _ExpectedFilesModel(BaseModel):
    """Base model for the expected files of a modality"""
    # frozen=True makes instances immutable after construction.
    model_config = ConfigDict(frozen=True)
    name: str
    # Abbreviation of the modality these requirements apply to; the generated
    # subclasses pin it (and `name`) to a single Literal value.
    modality_abbreviation: str
    # One requirement level per metadata file.
    subject: FileRequirement
    data_description: FileRequirement
    procedures: FileRequirement
    session: FileRequirement
    rig: FileRequirement
    processing: FileRequirement
    acquisition: FileRequirement
    instrument: FileRequirement
    quality_control: FileRequirement

{# Renders one `<Abbreviation>_Files` subclass of `_ExpectedFilesModel` per row
   of `data` (presumably a pandas DataFrame, given `iterrows()` - confirm).
   `name` and `modality_abbreviation` are pinned to the row's values. Each
   per-file column is turned into a FileRequirement default:
   1 -> REQUIRED, 0 -> OPTIONAL, any other value -> EXCLUDED.
   The inner loop replaces nine copies of the same if/elif/else chain; the
   `{%-` markers keep every field on its own line without blank lines or
   trailing whitespace in the rendered module. -#}
{% for _, row in data.iterrows() %}
class {{ row['abbreviation'] | to_class_name_underscored }}_Files(_ExpectedFilesModel):
    """Model {{row['abbreviation']}}_Files"""
    name: Literal["{{ row['name'] }}"] = "{{ row['name'] }}"
    modality_abbreviation: Literal["{{ row['abbreviation'] }}"] = "{{ row['abbreviation'] }}"
{%- for file_name in ['subject', 'data_description', 'procedures', 'session', 'rig', 'processing', 'acquisition', 'instrument', 'quality_control'] %}
    {{ file_name }}: FileRequirement = {% if row[file_name] == 1 %}FileRequirement.REQUIRED{% elif row[file_name] == 0 %}FileRequirement.OPTIONAL{% else %}FileRequirement.EXCLUDED{% endif %}
{%- endfor %}

{% endfor %}
class ExpectedFiles:
    """Expected files for each modality"""
{% for _, row in data.iterrows() %}
    {{ row['abbreviation'] | to_class_name | upper }} = {{ row['abbreviation'] | to_class_name_underscored }}_Files()
{%- endfor %}

    # Every direct subclass of _ExpectedFilesModel known at this point
    ALL = tuple(_ExpectedFilesModel.__subclasses__())

    # Union over all expected-files classes. The discriminator must name a
    # Literal field present on every member: these models define
    # `modality_abbreviation` (they have no `abbreviation` field), so that is
    # the field used to select the member.
    ONE_OF = Annotated[Union[tuple(_ExpectedFilesModel.__subclasses__())], Field(discriminator="modality_abbreviation")]
