#!/usr/bin/env python3
"""test seqrepo for file descriptor exhaustion, especially in threading context

https://github.com/biocommons/biocommons.seqrepo/issues/112

The idea: select RefSeq NM_ accessions from the repo's alias database (optionally a random subset) and fetch a short slice of each sequence from worker threads.

"""

import argparse
import logging
import queue
import pathlib
import random
import threading
import time

from smart_open import open

from biocommons.seqrepo import SeqRepo

_logger = logging.getLogger()


class Worker(threading.Thread):
    """Thread that drains accessions from a shared queue and fetches each one.

    Each worker repeatedly takes an accession from ``q`` without blocking and
    fetches the first five residues of it from ``sr``.  The thread finishes
    once the queue is empty.

    Attributes:
        q: queue of accession strings shared by all workers.
        sr: the SeqRepo instance used for fetching.
        n: number of accessions fetched successfully by this worker.
        n_failed: number of accessions whose fetch raised an exception.
    """

    def __init__(self, q: queue.Queue, sr: SeqRepo):
        super().__init__()
        self.q = q
        self.sr = sr
        self.n = 0
        self.n_failed = 0

    def run(self):
        """Process queue items until the queue is empty, then log a summary."""
        while True:
            try:
                ac = self.q.get(block=False)
            except queue.Empty:
                break
            try:
                # Use the instance handed to this worker, not a module global.
                self.sr.fetch(ac, 0, 5)
                self.n += 1
            except Exception:
                # Thread boundary: report the failure and keep draining so the
                # remaining accessions are still processed.
                self.n_failed += 1
                _logger.exception(f"{self}: fetch failed for {ac}")
            finally:
                # Always balance get() with task_done(); otherwise anyone
                # waiting on q.join() would block forever after a failure.
                self.q.task_done()
        _logger.info(f"{self}: Done; processed {self.n} accessions")
        if self.n_failed:
            _logger.warning(f"{self}: {self.n_failed} fetches failed")


def parse_args():
    """Parse the command line and return the resulting argparse namespace.

    Options: thread count (-n), seqrepo directory (-s, required), maximum
    number of accessions to sample (-m), and file-descriptor cache size (-f).
    """
    parser = argparse.ArgumentParser(description=__doc__)
    # (flags, keyword settings) for every supported option, in display order.
    option_specs = (
        (("-n", "--n-threads"), {"type": int, "default": 1}),
        (("-s", "--seqrepo-path"), {"type": pathlib.Path, "required": True}),
        (("-m", "--max-accessions"), {"type": int}),
        (("-f", "--fd-cache-size"), {"type": int, "default": 0}),
    )
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)
    return parser.parse_args()

if __name__ == "__main__":
    # Driver: sample NM_ accessions from the repo, hand them to worker threads
    # through a shared queue, and report the fetch throughput.
    import coloredlogs
    import sys

    coloredlogs.install(level="INFO")

    opts = parse_args()
    if opts.n_threads < 1:
        # Without at least one worker nothing would ever drain the queue.
        sys.exit(f"--n-threads must be >= 1 (got {opts.n_threads})")

    # Module-level so that every worker shares this single SeqRepo instance.
    sr = SeqRepo(opts.seqrepo_path, fd_cache_size=opts.fd_cache_size)

    acs = {a["alias"] for a in sr.aliases.find_aliases(namespace="RefSeq", alias="NM_%")}
    # A missing (or zero) --max-accessions means "use everything".  Cap the
    # request at the population size because random.sample raises ValueError
    # when asked for more items than are available.
    n_sample = min(opts.max_accessions or len(acs), len(acs))
    # sorted(): random.sample requires a sequence rather than a set.
    acs = random.sample(sorted(acs), n_sample)
    q = queue.Queue()
    for ac in acs:
        q.put(ac)
    qs = q.qsize()
    _logger.info(f"Queued {qs} accessions")

    _logger.info(f"Starting run with {opts.n_threads} threads")
    # Wall-clock timing: process_time() counts only CPU time and ignores time
    # spent waiting on I/O, which is what a threaded fetch run mostly does.
    t0 = time.perf_counter()
    workers = [Worker(q=q, sr=sr) for _ in range(opts.n_threads)]
    for w in workers:
        w.start()
    # Wait for the threads themselves rather than q.join(): a worker that dies
    # on an exception never marks its item done, which would block q.join()
    # forever.
    for w in workers:
        w.join()
    t1 = time.perf_counter()
    td = t1 - t0

    n_done = sum(w.n for w in workers)
    if n_done != qs:
        _logger.error(f"Only {n_done} of {qs} queued accessions were fetched")
    rate = n_done / td if td > 0 else 0.0
    _logger.info(f"Fetched {n_done} sequences in {td} s with {opts.n_threads} threads; {rate:.0f} seq/sec")

    # Only present when the reader-opening function is wrapped in an LRU cache.
    if hasattr(sr.sequences._open_for_reading, "cache_info"):
        print(sr.sequences._open_for_reading.cache_info())