#!python

import argparse
import os
import shutil
import subprocess
import sys
from importlib.util import find_spec

if __name__ == "__main__":
    # Script entry point: update the core conda packages and, when
    # --install_3p is given, install/update the third-party packages as well.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--install_3p",
        action="store_true",
        help="Install/update third party dependencies (Speechbrain and WhisperX)",
    )
    args = parser.parse_args()

    # Probe the current interpreter's environment for optional packages.
    anchor_found = find_spec("anchor") is not None
    whisperx_found = find_spec("whisperx") is not None

    conda_path = shutil.which("conda")
    if conda_path is None:
        print("Please install conda before running this command.")
        sys.exit(1)

    mamba_path = shutil.which("mamba")
    if mamba_path is None:
        print("No mamba found, installing first...")
        subprocess.call(
            [conda_path, "install", "-c", "conda-forge", "-y", "mamba"], env=os.environ
        )
        # Look the executable up again now that it should be installed;
        # otherwise the commands below would be built with None as the program.
        mamba_path = shutil.which("mamba")
    # If mamba still cannot be located (failed install, PATH not refreshed),
    # fall back to conda, which accepts the same install/update arguments.
    installer = mamba_path if mamba_path is not None else conda_path

    package_list = ["montreal-forced-aligner", "kalpy", "kaldi=*=cpu*"]
    if anchor_found:
        # Only update the annotator GUI when it is already installed.
        package_list.append("anchor-annotator")
    subprocess.call(
        [installer, "update", "-c", "conda-forge", "-y", *package_list], env=os.environ
    )

    if args.install_3p:
        channels = ["conda-forge", "pytorch", "nvidia", "anaconda"]
        conda_packages = ["pytorch", "torchaudio"]
        if not whisperx_found:
            conda_packages.extend(["cudnn=8", "transformers"])
        command = [installer, "install", "-y"]
        for channel in channels:
            command.extend(["-c", channel])
        command.extend(conda_packages)
        subprocess.call(command, env=os.environ)

        # Run pip through the current interpreter so the packages land in the
        # same environment that was probed above, and pass the package names
        # (a bare "pip install -U" has nothing to install).
        pip_packages = ["whisperx", "speechbrain", "pygtrie"]
        subprocess.call(
            [sys.executable, "-m", "pip", "install", "-U", *pip_packages],
            env=os.environ,
        )
