# -----------------------------------------------------------------------------
# Copyright (c) 2018 The Regents of the University of California
#
# This file is part of kevlar (http://github.com/dib-lab/kevlar) and is
# licensed under the MIT license: see LICENSE.
# -----------------------------------------------------------------------------
import os


# -----------------------------------------------------------------------------
# Ancillary code for parsing the sample configuration and organizing files in
# the working directory.
# -----------------------------------------------------------------------------

def unpack(list_of_lists):
    """Concatenate the items of every sublist, in order, into one new list."""
    flattened = []
    for sublist in list_of_lists:
        flattened.extend(sublist)
    return flattened


class SimplexDesign(object):
    """Class for managing files related to samples from a simplex case.

    A simplex case is an isolated manifestation of a condition in a family. The
    sample of interest is designated the "case" sample (sometimes referred to
    as the "proband" or "focal sample"), and all other samples (parents,
    siblings) are control samples.

    This class manages files associated with each sample. It creates symlinks
    to input files and creates standardized filenames to facilitate processing
    in a designated working directory. The workflow enforces the following
    constraints.

        - There is a single case sample.
        - There are 1 or more control samples.
        - There are 1 or more Fasta/Fastq files associated with each sample.

    The `config` mapping passed to the constructor is read with these keys:
    `samples.case.fastx`, `samples.case.label`, `samples.controls` (a list of
    mappings with `fastx` and `label`), `reference.fasta`, `mask.fastx`,
    `varfilter` and `numsplit`. Only the `samples` entries are read at
    construction time; the others are read lazily by the properties.
    """

    # Extensions retained on internal filenames. Each one is also recognized
    # with a trailing '.gz'.
    _EXTENSIONS = ('fastq', 'fq', 'fasta', 'fa', 'fna', 'gz')

    def __init__(self, config):
        """Compute original and internal sequence filenames for all samples."""
        self._config = config
        self.case_seqfiles = config['samples']['case']['fastx']
        self.control_seqfiles = unpack(
            [cfg['fastx'] for cfg in config['samples']['controls']]
        )
        self.case_seqfiles_internal = list(
            self.internal_filenames(self.case_seqfiles, sampleclass='case')
        )
        # One list of internal names per control, plus a flattened copy in the
        # same order as `control_seqfiles` so the two can be zipped together.
        self.control_seqfiles_internal = list()
        self.control_seqfiles_internal_flat = list()
        for n, cfg in enumerate(config['samples']['controls']):
            seqlist = list(
                self.internal_filenames(
                    cfg['fastx'], sampleclass='ctrl', index=n
                )
            )
            self.control_seqfiles_internal.append(seqlist)
            self.control_seqfiles_internal_flat.extend(seqlist)

    @property
    def numcontrols(self):
        """Number of control samples declared in the configuration."""
        return len(self._config['samples']['controls'])

    def internal_filenames(self, filelist, sampleclass='case', index=0):
        """Designate a standardized filename for each input for internal use.

        If the input filename extension is recognized, it is retained.
        Otherwise, the internal file has no extension.

        `sampleclass` must be 'case' or 'ctrl'; for 'ctrl', `index` is appended
        to form the sample prefix (e.g. 'ctrl0'). Yields one name per entry of
        `filelist`, in order.
        """
        extensions = list(self._EXTENSIONS)
        gzextensions = [ext + '.gz' for ext in extensions]
        # Compound extensions go first so that 'x.fq.gz' keeps 'fq.gz' rather
        # than just 'gz'.
        candidates = gzextensions + extensions
        theclass = sampleclass
        if sampleclass != 'case':
            assert sampleclass == 'ctrl'
            theclass += str(index)
        for n, filename in enumerate(filelist):
            prefix = 'Reads/{cls}.inseq.{n}'.format(cls=theclass, n=n)
            for ext in candidates:
                # Anchor on the dot so only a real extension matches (a file
                # named 'sofa' does not have the 'fa' extension).
                if filename.endswith('.' + ext):
                    yield prefix + '.' + ext
                    break
            else:
                yield prefix

    @property
    def refr_files(self):
        """Internal path of the reference Fasta followed by its index files."""
        name = os.path.basename(self._config['reference']['fasta'])
        files = [os.path.join('Reference', name)]
        for ext in ['amb', 'ann', 'bwt', 'pac', 'sa']:
            files.append(os.path.join('Reference', name + '.' + ext))
        return files

    @property
    def mask_file(self):
        """Path of the mask nodetable."""
        return 'Mask/mask.nodetable'

    @property
    def mask_input_files(self):
        """Internal paths (under 'Mask/') of the mask input sequence files."""
        return [
            os.path.join('Mask', os.path.basename(infile))
            for infile in self._config['mask']['fastx']
        ]

    @property
    def case_sample_counts(self):
        """Path of the case sample's k-mer count table."""
        return 'Sketches/case-counts.counttable'

    @property
    def control_sample_counts(self):
        """Paths of the control samples' k-mer count tables, by index."""
        pattern = 'Sketches/ctrl{i}-counts.counttable'
        return [pattern.format(i=i) for i in range(self.numcontrols)]

    @property
    def all_sample_counts(self):
        """Case count table followed by all control count tables."""
        return [self.case_sample_counts] + self.control_sample_counts

    @property
    def sample_labels(self):
        """Case label followed by the control labels, in config order."""
        labels = [self._config['samples']['case']['label']]
        for control in self._config['samples']['controls']:
            labels.append(control['label'])
        return labels

    @property
    def _prelim_calls(self):
        """Preliminary call files, one per partition split (0..numsplit-1)."""
        pattern = 'calls.{num}.prelim.vcf.gz'
        return [
            pattern.format(num=num)
            for num in range(self._config['numsplit'])
        ]

    @property
    def varfilter_input(self):
        """The configured varfilter file (if any) then all preliminary calls."""
        infiles = list()
        if self._config['varfilter']:
            infiles.append(self._config['varfilter'])
        return infiles + self._prelim_calls

    @property
    def simlike_input(self):
        """Filtered calls when a varfilter is configured, else raw calls."""
        if self._config['varfilter']:
            return ['calls.filtered.vcf.gz']
        else:
            return self._prelim_calls


# Module-level design object that the rules below query for their file lists.
# `config` is not defined in this file; presumably it is the workflow
# configuration mapping supplied by Snakemake -- confirm.
simplex = SimplexDesign(config)


# -----------------------------------------------------------------------------
# Create standardized named links to input Fasta/Fastq files for internal use.
# -----------------------------------------------------------------------------

rule link_input_seqs:
    # Symlink every case and control sequence file to its standardized
    # internal name. Inputs and outputs are listed in the same order (case
    # files first, then control files), so they are paired positionally.
    input:
        simplex.case_seqfiles,
        simplex.control_seqfiles,
    output:
        simplex.case_seqfiles_internal,
        simplex.control_seqfiles_internal_flat,
    threads: 1
    message: 'Create internal links for sample sequence data.'
    run:
        for original, link in zip(input, output):
            # Link to the absolute path so the link resolves from inside the
            # output subdirectory.
            target = os.path.abspath(original)
            os.symlink(target, link)


rule link_reference:
    # Symlink the reference Fasta into the working directory and provide its
    # BWA index: link the existing index when it is complete, otherwise build
    # a fresh one next to the link.
    input: config['reference']['fasta']
    output: simplex.refr_files
    threads: 1
    message: 'Create internal links for reference genome, and index if needed.'
    run:
        refr = os.path.abspath(input[0])
        os.symlink(refr, output[0])
        extensions = ['amb', 'ann', 'bwt', 'pac', 'sa']
        indexfiles = [refr + '.' + ext for ext in extensions]
        # Decide before creating any index links: if a partial index were
        # linked first, `bwa index` would write through those symlinks and
        # clobber the original files.
        if all(os.path.isfile(indexfile) for indexfile in indexfiles):
            for indexfile, outfile in zip(indexfiles, output[1:]):
                os.symlink(indexfile, outfile)
        else:
            shell('bwa index {output[0]}')


rule link_mask:
    # Symlink each configured mask sequence file into the working directory;
    # inputs and outputs are paired positionally.
    input: config['mask']['fastx']
    output: simplex.mask_input_files
    threads: 1
    message: 'Create internal links for mask sequence data.'
    run:
        for original, link in zip(input, output):
            target = os.path.abspath(original)
            os.symlink(target, link)


# -----------------------------------------------------------------------------
# Create sketches for the reference, mask, and each sample.
# -----------------------------------------------------------------------------

rule create_mask:
    # Run `kevlar count` with a counter size of 1 over the linked mask
    # sequence files. output[0] is the mask table and output[1] the log;
    # k-mer size, memory and maximum FPR are read from the configuration.
    input: simplex.mask_input_files
    output:
        simplex.mask_file,
        'Logs/mask.log'
    threads: 32
    message: 'Generate a mask of sequences to ignore while k-mer counting.'
    shell: 'kevlar --tee --logfile {output[1]} count --ksize {config[ksize]} --counter-size 1 --memory {config[mask][memory]} --max-fpr {config[mask][max_fpr]} --threads {threads} {output[0]} {input}'


rule count_reference:
    # Run `kevlar count` with a counter size of 4 over the linked reference
    # Fasta. output[0] is the count table and output[1] the log; memory and
    # maximum FPR come from the `reference` section of the configuration.
    input: simplex.refr_files[0]
    output:
        'Reference/refr-counts.smallcounttable',
        'Logs/refrcount.log'
    threads: 32
    message: 'Count k-mers in the reference genome.'
    shell: 'kevlar --tee --logfile {output[1]} count --ksize {config[ksize]} --counter-size 4 --memory {config[reference][memory]} --max-fpr {config[reference][max_fpr]} --threads {threads} {output[0]} {input}'


rule count_all_samples:
    # Aggregate target: requires the case and every control count table, then
    # touches a marker file. It runs no command of its own.
    input: simplex.all_sample_counts
    output: touch('Sketches/counting.complete')


rule count_case:
    # Run `kevlar count` on the case sample's internal sequence files.
    # NOTE: the input list is the mask file followed by the sequence files, so
    # `--mask {input}` expands to the mask path and then the sequence paths.
    # Presumably kevlar takes the first as the `--mask` value and the rest as
    # positional sequence inputs -- keep the mask first in `input`.
    input:
        simplex.mask_file,
        simplex.case_seqfiles_internal
    output:
        simplex.case_sample_counts,
        'Logs/casecount.log'
    threads: 32
    message: 'Count k-mers in the case sample'
    shell: 'kevlar --tee --logfile {output[1]} count --ksize {config[ksize]} --memory {config[samples][case][memory]} --max-fpr {config[samples][case][max_fpr]} --threads {threads} {output[0]} --mask {input}'


rule count_control:
    # Run `kevlar count` for the control sample selected by the `index`
    # wildcard. input[0] is the mask file; the remaining inputs are that
    # control's internal sequence files.
    input: lambda wildcards: [simplex.mask_file] + simplex.control_seqfiles_internal[int(wildcards.index)]
    output:
        'Sketches/ctrl{index}-counts.counttable',
        'Logs/ctrl{index}count.log'
    threads: 32
    message: 'Count k-mers in a control sample'
    run:
        control = config['samples']['controls'][int(wildcards.index)]
        # Doubled braces survive this first format() and are filled in by
        # shell(); only the per-control values are substituted here.
        template = 'kevlar --tee --logfile {{output[1]}} count --ksize {{config[ksize]}} --memory {mem} --max-fpr {fpr} --mask {{input[0]}} --threads {{threads}} {{output[0]}} {seqs}'
        command = template.format(
            mem=control['memory'],
            fpr=control['max_fpr'],
            seqs=' '.join(input[1:]),
        )
        shell(command)


# -----------------------------------------------------------------------------
# Find, filter, and organize reads containing novel k-mers.
# -----------------------------------------------------------------------------

rule novel:
    # Run `kevlar novel` on the case sequence files using the case and control
    # count tables. output[0] is the augmented Fastq of reads and output[1]
    # the log.
    input:
        simplex.all_sample_counts,
        simplex.case_seqfiles_internal
    output:
        'NovelReads/novel.augfastq.gz',
        'Logs/novel.log'
    threads: 1
    message: 'Find reads containing k-mers that are abundant in the proband and low/zero abundance in each control.'
    run:
        samples = config['samples']
        case_template = '--case {seqs} --case-counts Sketches/case-counts.counttable --case-min {casemin}'
        caseconfig = case_template.format(
            seqs=' '.join(simplex.case_seqfiles_internal),
            casemin=str(samples['casemin']),
        )
        # The leading space is kept so the assembled command is unchanged.
        ctrl_template = ' --control-counts {counts} --ctrl-max {ctrlmax}'
        ctrlconfig = ctrl_template.format(
            counts=' '.join(simplex.control_sample_counts),
            ctrlmax=str(samples['ctrlmax']),
        )
        pieces = ['kevlar', '--tee', '--logfile', output[1], 'novel']
        pieces += [caseconfig, ctrlconfig]
        pieces += ['--ksize', str(config['ksize'])]
        pieces += ['--threads', str(threads)]
        pieces += ['--max-fpr', str(samples['case']['max_fpr'])]
        pieces += ['--out', output[0]]
        shell(' '.join(pieces))


rule filter_novel:
    # Run `kevlar filter` on the novel reads with the mask, the recount memory
    # and the case/control abundance thresholds from the configuration.
    # The mask is declared as a named input so the dependency is explicit and
    # its path comes from `simplex.mask_file`, not a hard-coded copy.
    input:
        reads='NovelReads/novel.augfastq.gz',
        mask=simplex.mask_file
    output:
        'NovelReads/filtered.augfastq.gz',
        'Logs/filter.log'
    threads: 1
    message: 'Filter out k-mers that are masked or fail the abundance thresholds after recounting.'
    shell: 'kevlar --tee --logfile {output[1]} filter --mask {input.mask} --memory {config[recountmem]} --case-min {config[samples][casemin]} --ctrl-max {config[samples][ctrlmax]} --out {output[0]} {input.reads}'


rule partition:
    # Run `kevlar partition` on the filtered reads, passing the configured
    # case minimum abundance as `--min-abund`. output[0] is the partitioned
    # read file and output[1] the log.
    input: 'NovelReads/filtered.augfastq.gz'
    output:
        'NovelReads/partitioned.augfastq.gz',
        'Logs/partition.log'
    threads: 1
    message: 'Partition reads by shared novel k-mers.'
    shell: 'kevlar --tee --logfile {output[1]} partition --min-abund {config[samples][casemin]} --out {output[0]} {input}'


rule split:
    # Run `kevlar split` with the configured number of splits and the output
    # prefix 'NovelReads/partitioned'. The declared outputs are one
    # 'partitioned.{num}.augfastx.gz' file per split; presumably these match
    # the names kevlar derives from the prefix -- verify against kevlar.
    input: 'NovelReads/partitioned.augfastq.gz'
    output: expand('NovelReads/partitioned.{num}.augfastx.gz', num=range(config['numsplit']))
    threads: 1
    shell: 'kevlar split {input} {config[numsplit]} NovelReads/partitioned'


# -----------------------------------------------------------------------------
# Assemble, localize, align, make preliminary calls, and score/sort calls.
# -----------------------------------------------------------------------------

rule assemble:
    # Run `kevlar assemble` on one split of the partitioned reads, selected
    # by the `num` wildcard. output[0] is the contig file and output[1] the
    # log.
    input: 'NovelReads/partitioned.{num}.augfastx.gz'
    output:
        'contigs.{num}.augfasta.gz',
        'Logs/assemble.{num}.log'
    threads: 1
    message: 'Assemble reads into contigs partition by partition.'
    shell: 'kevlar --tee --logfile {output[1]} assemble --out {output[0]} {input}'


rule localize:
    # Run `kevlar localize` once over all contig files. `{input}` expands to
    # the reference Fasta followed by every contigs file, in that order.
    # Delta, seed size and the sequence include pattern come from the
    # `localize` section of the configuration.
    input:
        simplex.refr_files[0],
        expand('contigs.{num}.augfasta.gz', num=range(config['numsplit'])),
    output:
        'reference.gdnas.fa.gz',
        'Logs/localize.log'
    threads: 1
    message: 'Compute reference target sequences for each partition.'
    shell: 'kevlar --tee --logfile {output[1]} localize --delta {config[localize][delta]} --seed-size {config[localize][seedsize]} --include \'{config[localize][seqpattern]}\' --out {output[0]} {input}'


rule call:
    # Run `kevlar call` for one split (the `num` wildcard). input[0] is that
    # split's contigs, input[1] the reference target sequences and input[2]
    # the reference Fasta passed as `--refr`. The mask memory is fixed at 10M.
    input:
        'contigs.{num}.augfasta.gz',
        'reference.gdnas.fa.gz',
        simplex.refr_files[0]
    output:
        'calls.{num}.prelim.vcf.gz',
        'Logs/call.{num}.log'
    threads: 1
    message: 'Align contigs to reference targets and make preliminary calls.'
    shell: 'kevlar --tee --logfile {output[1]} call --mask-mem 10M --refr {input[2]} --ksize {config[ksize]} --out {output[0]} {input[0]} {input[1]}'


rule varfilter:
    # Run `kevlar varfilter` over the preliminary calls. `{input}` expands to
    # the configured varfilter file (when one is set) followed by the
    # per-split preliminary call files, in that order.
    input: simplex.varfilter_input
    output:
        'calls.filtered.vcf.gz',
        'Logs/varfilter.log'
    threads: 1
    message: 'Filter out calls from the provided regions (usually common variants, repetitive DNA, etc)'
    shell: 'kevlar --tee --logfile {output[1]} varfilter --out {output[0]} {input}'


rule like_scores:
    # Run `kevlar simlike` on the call files chosen by `simplex.simlike_input`,
    # using the reference, case and control count tables. output[0] is the
    # scored call file and output[1] the log.
    input:
        'Reference/refr-counts.smallcounttable',
        simplex.all_sample_counts,
        simplex.simlike_input
    output:
        'calls.scored.sorted.vcf.gz',
        'Logs/simlike.log'
    threads: 1
    message: 'Filter calls, compute likelihood scores, and sort calls by score.'
    run:
        coverage = config['samples']['coverage']
        args = ['kevlar', '--tee', '--logfile', output[1], 'simlike']
        args += ['--mu', coverage['mean']]
        args += ['--sigma', coverage['stdev']]
        args += ['--epsilon', '0.001']
        args += ['--case-min', config['samples']['casemin']]
        args += ['--refr', 'Reference/refr-counts.smallcounttable']
        args += ['--sample-labels'] + list(simplex.sample_labels)
        args += ['--out', output[0]]
        args += ['--controls'] + list(simplex.control_sample_counts)
        args += ['--case', simplex.case_sample_counts]
        args += list(simplex.simlike_input)
        # Configuration values may be numbers, so stringify before joining.
        shell(' '.join(map(str, args)))


rule calls:
    # Final target: requires the scored, sorted calls and touches the marker
    # file 'complete'. It runs no command of its own.
    input: 'calls.scored.sorted.vcf.gz'
    output: touch('complete')
