#!/usr/bin/env python3
"""A little utility to add some noise to training data.

.. warning::
    This tool will be removed in BPReveal 6.0.
    It turns out that it's not very useful.

BNF
---

.. highlight:: none

.. literalinclude:: ../../doc/bnf/addNoise.bnf

Parameter notes
---------------

input-h5
    The name of the hdf5 file generated by
    :py:mod:`prepareTrainingData<bpreveal.prepareTrainingData>`.
output-h5
    The name of the output file that will be generated.
num-output-samples
    How many training examples do you want in the output file?
    Mutually exclusive with output-size-ratio.
output-size-ratio
    How many times larger should the output be than the input?
    Mutually exclusive with num-output-samples.
keep-original-data
    Should the original data be kept in the output, or just
    the altered data? If true, the output must contain more
    samples than the input.
sequence-mutation-fraction
    What fraction of the input bases should be randomly mutated?
    For example, 0.05 means that one in twenty bases will be randomly
    mutated to a different base.
profile-mutation-types
    A list of the types of mutation you want to apply to the profile
    outputs.

profile-mutation-fraction
    What fraction of the output bases should be mutated?
"""
import json
import multiprocessing
import h5py
import numpy as np
from bpreveal import logUtils
from bpreveal.tools import addNoiseUtils


def _storeMutatedSamples(samples, firstRow: int, outputSequences: np.ndarray,
                         outputHeads: list) -> None:
    """Copy mutated samples into the output arrays as they are produced.

    Consuming the samples one at a time means the full set of mutated data
    never has to be held in a separate list before being copied.

    :param samples: An iterable of ``(sequence, headData)`` pairs, in the order
        they should appear in the output. ``headData`` is indexed by head number.
    :param firstRow: The row of the output arrays that receives the first sample.
    :param outputSequences: The array that receives each mutated sequence.
    :param outputHeads: One array per head; each receives that head's mutated data.
    """
    for outPos, (mutatedSeq, mutatedHeadDats) in enumerate(samples, start=firstRow):
        outputSequences[outPos] = mutatedSeq
        for headIdx, headArray in enumerate(outputHeads):
            headArray[outPos] = mutatedHeadDats[headIdx]


def main(config: dict):
    """Run the program.

    Reads the dataset shapes (and, optionally, the original data) from
    ``input-h5``, generates mutated samples with
    :py:mod:`addNoiseUtils<bpreveal.tools.addNoiseUtils>`, and writes
    everything to ``output-h5``.

    :param config: The configuration json.
    """
    logUtils.setVerbosity(config["verbosity"])
    keepOriginal = config["keep-original-data"]
    # The input file is only needed here for the dataset shapes and (optionally)
    # the original data, so it is closed as soon as those have been read.
    # The sample generator is given the file name separately through
    # addNoiseUtils.loadFile below.
    with h5py.File(config["input-h5"], "r") as inH5:
        # Heads are stored as head_0, head_1, ...; count them until one is missing.
        numHeads = 0
        tasksPerHead = []
        while f"head_{numHeads}" in inH5:
            tasksPerHead.append(inH5[f"head_{numHeads}"].shape[2])
            numHeads += 1

        # Note that this is not actually the model input length:
        # This dataset has shape (input-length + 2*jitter)
        inputLength = inH5["sequence"].shape[1]
        outputLength = inH5["head_0"].shape[1]
        numInputRegions = inH5["sequence"].shape[0]
        if "num-output-samples" in config:
            numOutputRegions = config["num-output-samples"]
        else:
            numOutputRegions = int(numInputRegions * config["output-size-ratio"])
        logUtils.debug("Setup complete. Creating datasets.")
        outputSequences = np.empty((numOutputRegions, inputLength, 4))
        outputHeads = [np.empty((numOutputRegions, outputLength, numTasks))
                       for numTasks in tasksPerHead]
        if keepOriginal:
            logUtils.debug("Copying over original data.")
            # Strictly larger: with an equal size there would be no rows left
            # over for mutated samples.
            assert numOutputRegions > numInputRegions, \
                "Cannot keep old regions with an output file smaller than the input."
            outputSequences[:numInputRegions] = np.array(inH5["sequence"])
            for i in range(numHeads):
                outputHeads[i][:numInputRegions] = np.array(inH5[f"head_{i}"])
            # Mutated samples go after the rows holding the original data.
            writeHead = numInputRegions
        else:
            writeHead = 0

    logUtils.debug("Starting to write mutated regions.")
    # Now, generate the randomized data.
    # The serialized configuration is the same for every sample, so build it once.
    configStr = json.dumps(config)
    args = [(configStr, tasksPerHead, i) for i in range(writeHead, numOutputRegions)]
    wrappedArgs = logUtils.wrapTqdm(args, "INFO")
    if config["num-threads"] > 1:
        with multiprocessing.Pool(config["num-threads"], initializer=addNoiseUtils.loadFile,
                                  initargs=[config["input-h5"], numHeads]) as p:
            # imap yields results in submission order, so they can be stored
            # directly at consecutive output rows.
            _storeMutatedSamples(p.imap(addNoiseUtils.gmstar, wrappedArgs),
                                 writeHead, outputSequences, outputHeads)
    else:
        addNoiseUtils.loadFile(config["input-h5"], numHeads)
        _storeMutatedSamples(map(addNoiseUtils.gmstar, wrappedArgs),
                             writeHead, outputSequences, outputHeads)
    logUtils.debug("Writing output file.")
    addNoiseUtils.writeOutput(config["output-h5"], outputSequences, outputHeads)
    logUtils.debug("Done with adding noise.")


if __name__ == "__main__":
    # Script entry point: the only command-line argument is the path
    # to the configuration json.
    import sys
    configPath = sys.argv[1]
    with open(configPath, "r") as fp:
        loadedConfig = json.load(fp)
    # The configuration file is closed before the run starts.
    main(loadedConfig)
# Copyright 2022-2025 Charles McAnany. This file is part of BPReveal. BPReveal is free software: You can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 2 of the License, or (at your option) any later version. BPReveal is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with BPReveal. If not, see <https://www.gnu.org/licenses/>.  # noqa
