#!/usr/bin/env python3
"""Trains up a residual model to remove an uninteresting signal from an experiment.

BNF
---

.. highlight:: none

.. literalinclude:: ../../doc/bnf/trainCombinedModel.bnf


Parameter Notes
---------------
Most of the parameters for the combined model are the same as for a solo
model, and they are described at
:py:mod:`trainSoloModel<bpreveal.trainSoloModel>`.

use-bias-counts
    Selects if you want to add the counts prediction from the transformation
    model, and the appropriateness of this flag will depend on the nature of
    your bias. If the bias is a constant background signal, then it makes sense
    to subtract the bias contribution to the counts. However, if your bias is
    multiplied by the underlying biology, then you probably shouldn't add in
    the bias counts since they won't affect the actual experiment.

transformation-model-file
    The name of the Keras model file generated by
    :py:mod:`trainTransformationModel<bpreveal.trainTransformationModel>`.

    .. note::
        As of BPReveal 5.0, the extension of model files changed from ``.model``
        to ``.keras``.

input-length
    The input size of the *residual* model, not the *solo* model. The solo
    model, having already been created, knows its own input length. If the solo
    model's input length is smaller than the ``input-length`` setting in this
    config file, the sequence input to the solo model will automatically be
    cropped down to match.

HISTORY
-------

Before BPReveal 3.0.0, the solo model had to have the same input length as
the residual model. An auto-cropdown feature was implemented by Melanie Weilert
to remove this restriction.

API
---
"""
import sys
import bpreveal.schema
import bpreveal.internal.disableTensorflowLogging  # pylint: disable=unused-import # noqa
from bpreveal import utils
# When this file is run as a script, call utils.setMemoryGrowth() before the
# keras import that follows.
# NOTE(review): presumably the setting has to be applied before keras is
# initialized - confirm against utils.setMemoryGrowth.
if __name__ == "__main__":
    utils.setMemoryGrowth()
import keras  # pylint: disable=wrong-import-order
import bpreveal.training
from bpreveal import models
from bpreveal import logUtils
from bpreveal.internal import interpreter
# pylint: disable=duplicate-code


def trainCombinedModel(config: dict) -> None:
    """Build a combined model around a frozen transformation model and train it.

    The transformation model named in the configuration is loaded and marked
    as non-trainable. A combined model and a residual model are then built by
    ``models.combinedModel``, both are compiled with the losses derived from
    the configured heads, and the combined model is trained. When training
    finishes, the two models are saved as ``<output-prefix>_combined.keras``
    and ``<output-prefix>_residual.keras``.

    :param config: A config dict, per the spec
    """
    logUtils.setVerbosity(config["verbosity"])
    logUtils.debug("Initializing")
    settings = config["settings"]
    architecture = settings["architecture"]
    inputLength = architecture["input-length"]
    outputLength = architecture["output-length"]

    transformationFile = settings["transformation-model"]["transformation-model-file"]
    transformationModel = utils.loadModel(transformationFile)
    # The trainable flag is cleared on both sides of compile().
    # NOTE(review): presumably so that the freeze is in effect both for the
    # compiled model and for any later use - confirm before removing either.
    transformationModel.trainable = False
    transformationModel.compile()
    transformationModel.trainable = False
    logUtils.debug("Loaded regression model.")

    # The architecture values are passed positionally, in this order.
    architectureArgs = [
        architecture[key]
        for key in ("filters", "layers", "input-filter-width", "output-filter-width")
    ]
    # The third value returned by models.combinedModel is not used here.
    combinedModel, residualModel, _ = models.combinedModel(
        inputLength, outputLength, *architectureArgs,
        config["heads"], transformationModel)
    losses, lossWeights = bpreveal.training.buildLosses(config["heads"])

    # Each model gets its own freshly constructed optimizer; the residual
    # model is compiled first, then the combined model.
    learningRate = settings["learning-rate"]
    for model in (residualModel, combinedModel):
        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=learningRate),
            loss=losses, loss_weights=lossWeights,
            metrics=losses
        )
    logUtils.debug("Models compiled.")

    # Only the combined model is handed to the training routine.
    bpreveal.training.trainWithGenerators(combinedModel, config, inputLength, outputLength)
    outputPrefix = settings["output-prefix"]
    combinedModel.save(outputPrefix + "_combined.keras")
    residualModel.save(outputPrefix + "_residual.keras")
    logUtils.info("Training job completed successfully.")


def main() -> None:
    """Run a training job using the config file named on the command line.

    The first command-line argument is evaluated with
    ``interpreter.evalFile``, the result is checked against the
    ``trainCombinedModel`` schema, and then it is passed to
    :py:func:`trainCombinedModel`.
    """
    configPath = sys.argv[1]
    loadedConfig = interpreter.evalFile(configPath)
    # The evaluated file must produce a dict before it is validated.
    assert isinstance(loadedConfig, dict)
    configSchema = bpreveal.schema.trainCombinedModel
    configSchema.validate(loadedConfig)
    trainCombinedModel(loadedConfig)


# Run the training job only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
# Copyright 2022-2025 Charles McAnany. This file is part of BPReveal. BPReveal is free software: You can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 2 of the License, or (at your option) any later version. BPReveal is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with BPReveal. If not, see <https://www.gnu.org/licenses/>.  # noqa
