#!/usr/bin/env python
# ChromBPNet command-line tool
# Author: Jacob Schreiber <jmschreiber91@gmail.com>

import sys
import numpy
import torch
import argparse

from bpnetlite.bpnet import BPNet
from bpnetlite.chrombpnet import ChromBPNet

from bpnetlite.io import PeakGenerator
from bpnetlite.io import extract_loci
from bpnetlite.attributions import calculate_attributions

import json

# Long description shown in the `--help` output of this tool.
desc = """ChromBPNet is a neural network that builds off the original BPNet
	architecture by explicitly learning bias in the signal tracks themselves.
	Specifically, for ATAC-seq and DNase-seq experiments, the cutting enzymes
	have a soft sequence bias (though this is much stronger for Tn5, the
	enzyme for ATAC-seq). Accordingly, ChromBPNet is a pair of neural networks
	where one models the bias explicitly and one models the accessibility
	explicitly. This tool provides functionality for fitting the bias model,
	training the combination of the bias model and accessibility model, and
	making predictions using it. After training, the accessibility model can
	be used on its own with the `bpnet` tool."""


# Read in the arguments: one subcommand per mode of operation, each of which
# takes only the path to a JSON file of parameters.
parser = argparse.ArgumentParser(description=desc)
subparsers = parser.add_subparsers(
	help="Must be either 'bias' or 'fit' or 'predict'.",
	required=True,
	dest='cmd'
)


def _add_subcommand(command, command_help, parameters_help):
	"""Register `command` with a single required `-p/--parameters` option.

	Returns the newly created sub-parser.
	"""
	subcommand = subparsers.add_parser(command, help=command_help)
	subcommand.add_argument("-p", "--parameters", type=str, required=True,
		help=parameters_help)
	return subcommand


fit_parser = _add_subcommand("fit", "Fit a ChromBPNet model.",
	"A JSON file containing the parameters for training the model.")

predict_parser = _add_subcommand("predict",
	"Make predictions using a trained ChromBPNet model.",
	"A JSON file containing the parameters for making predictions.")

bias_parser = _add_subcommand("bias", "Fit a bias model.",
	"A JSON file containing the parameters for fitting the model.")

###
# Default Parameters
###

# Default values used to fill in any parameter missing from the JSON file
# passed to each subcommand. Entries set to None have no usable default; the
# command handlers below check for them when loading a parameter file.

# Parameters for `fit`: training the full ChromBPNet (bias + accessibility).
default_fit_parameters = {
	'n_filters': 64,
	'n_layers': 8,
	'batch_size': 64,
	'in_window': 2114,
	'out_window': 1000,
	'max_jitter': 128,
	'reverse_complement': True,
	'max_epochs': 250,
	'validation_iter': 100,
	'lr': 0.001,
	'alpha': 1,
	'verbose': False,

	'min_counts': 0,
	'max_counts': 99999999,
	'bias_model': None,

	'training_chroms': ['chr1', 'chr2', 'chr3', 'chr5', 'chr6', 'chr7', 
		'chr8', 'chr9', 'chr10', 'chr12', 'chr13', 'chr14', 'chr16', 
		'chr18', 'chr19', 'chr20', 'chr22'],
	'validation_chroms': ['chr4', 'chr15', 'chr21'],
	'sequences': None,
	'loci': None,
	'signals': None,
	'random_state': None,

	# Used to name the ChromBPNet model and its accessibility sub-model;
	# declared here so a missing value is reported instead of a KeyError.
	'name': None
}

# Parameters for `predict`: running a trained ChromBPNet model over loci.
default_predict_parameters = {
	'batch_size': 64,
	'in_window': 2114,
	'out_window': 1000,
	'verbose': False,
	'chroms': ['chr1', 'chr2', 'chr3', 'chr5', 'chr6', 'chr7', 
		'chr8', 'chr9', 'chr10', 'chr12', 'chr13', 'chr14', 'chr16', 
		'chr18', 'chr19', 'chr20', 'chr22'],
	'sequences': None,
	'loci': None,
	'model': None,
	'profile_filename': 'y_profile.npz',
	'count_filename': 'y_count.npz'
}

# Parameters for `bias`: training the standalone bias model on negatives.
default_bias_parameters = {
	'n_filters': 256,
	'n_layers': 4,
	'batch_size': 64,
	'in_window': 2114,
	'out_window': 1000,
	'max_jitter': 128,
	'reverse_complement': True,
	'max_epochs': 250,
	'validation_iter': 100,
	'lr': 0.001,
	'early_stopping': None,
	'alpha': 1,
	'beta': 0.5,
	'verbose': False,

	'min_counts': 0,
	# None here means "derive from the peaks" in the `bias` handler.
	'max_counts': None,

	'training_chroms': ['chr1', 'chr2', 'chr3', 'chr5', 'chr6', 'chr7', 
		'chr8', 'chr9', 'chr10', 'chr12', 'chr13', 'chr14', 'chr16', 
		'chr18', 'chr19', 'chr20', 'chr22'],
	'validation_chroms': ['chr4', 'chr15', 'chr21'],
	'sequences': None,
	'peaks': None,
	'negatives': None,
	'signals': None,
	'random_state': None,

	# Used to name the bias model; declared here so a missing value is
	# reported instead of a KeyError.
	'name': None
}

###
# Commands
###

def _load_parameters(filename, defaults, required=(), optional=()):
	"""Load a JSON parameter file and fill in any missing defaults.

	Parameters
	----------
	filename: str
		Path to a JSON file containing an object of parameters.

	defaults: dict
		Default values. A default of None marks the parameter as required,
		unless the parameter is listed in `optional`.

	required: tuple of str, optional
		Additional parameters that must be supplied even if they have no
		entry in `defaults`.

	optional: tuple of str, optional
		Parameters whose default of None is a legitimate value to pass on,
		rather than a marker meaning "the user must provide this".

	Returns
	-------
	parameters: dict
		The user-supplied parameters merged with `defaults`.

	Raises
	------
	ValueError
		If a required parameter is missing or explicitly set to null.
	"""

	with open(filename, "r") as infile:
		parameters = json.load(infile)

	for parameter, value in defaults.items():
		parameters.setdefault(parameter, value)

	needed = [p for p, v in defaults.items() if v is None and p not in optional]
	needed += [p for p in required if p not in needed]

	# Checking the merged value (not just key presence) also rejects a
	# required parameter that was written as `null` in the JSON file.
	for parameter in needed:
		if parameters.get(parameter) is None:
			raise ValueError("Must provide value for '{}'".format(parameter))

	return parameters


# Pull the arguments
args = parser.parse_args()

# Fit ChromBPNet model
if args.cmd == "fit":
	# `random_state` is passed straight to PeakGenerator, where None is a
	# valid value, so it must not be treated as required. `name` is used to
	# name both models below.
	parameters = _load_parameters(args.parameters, default_fit_parameters,
		required=('name',), optional=('random_state',))

	training_data = PeakGenerator(
		loci=parameters['loci'], 
		sequences=parameters['sequences'],
		signals=parameters['signals'],
		controls=None,
		chroms=parameters['training_chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=parameters['max_jitter'],
		reverse_complement=parameters['reverse_complement'],
		min_counts=parameters['min_counts'],
		max_counts=parameters['max_counts'],
		random_state=parameters['random_state'],
		batch_size=parameters['batch_size'],
		verbose=parameters['verbose']
	)

	# Validation loci are extracted once, without jitter.
	valid_sequences, valid_signals = extract_loci(
		sequences=parameters['sequences'],
		signals=parameters['signals'],
		controls=None,
		loci=parameters['loci'],
		chroms=parameters['validation_chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=0,
		verbose=parameters['verbose']
	)

	# Number of positions cropped from each side of the input window so the
	# output matches `out_window`.
	trimming = (parameters['in_window'] - parameters['out_window']) // 2

	bias = torch.load(parameters['bias_model'], map_location='cpu').cuda().eval()
	accessibility = BPNet(n_filters=parameters['n_filters'], 
		n_layers=parameters['n_layers'], n_control_tracks=0, n_outputs=1,
		alpha=parameters['alpha'],
		name=parameters['name'] + '.accessibility',
		trimming=trimming).cuda()

	model = ChromBPNet(bias=bias, accessibility=accessibility,
		name=parameters['name'])

	optimizer = torch.optim.AdamW(model.parameters(), lr=parameters['lr'])

	model.fit_generator(training_data, optimizer, X_valid=valid_sequences, 
		y_valid=valid_signals, max_epochs=parameters['max_epochs'], 
		validation_iter=parameters['validation_iter'], 
		batch_size=parameters['batch_size'])

# Make predictions from the full ChromBPNet model
elif args.cmd == 'predict':
	parameters = _load_parameters(args.parameters, default_predict_parameters)

	# Put the model in inference mode, matching how the bias model is loaded
	# in the `fit` command.
	model = torch.load(parameters['model']).cuda().eval()

	# Forward the configured windows; previously they were omitted, so the
	# `in_window`/`out_window` parameters had no effect here.
	X = extract_loci(
		sequences=parameters['sequences'],
		loci=parameters['loci'],
		chroms=parameters['chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=0,
		verbose=parameters['verbose']
	).cuda()

	y_profiles, y_counts = model.predict(X, batch_size=parameters['batch_size'])

	numpy.savez_compressed(parameters['profile_filename'], y_profiles)
	numpy.savez_compressed(parameters['count_filename'], y_counts)

# Fit the bias model
elif args.cmd == 'bias':
	# None is meaningful for these three: `max_counts` is derived from the
	# peaks below; `early_stopping` and `random_state` are passed through
	# as-is (presumably meaning "disabled" / "unseeded" -- confirm in bpnetlite).
	parameters = _load_parameters(args.parameters, default_bias_parameters,
		required=('name',),
		optional=('max_counts', 'early_stopping', 'random_state'))

	if parameters['max_counts'] is None:
		_, train_signals = extract_loci(
			sequences=parameters['sequences'],
			signals=parameters['signals'],
			controls=None,
			loci=parameters['peaks'],
			chroms=parameters['training_chroms'],
			in_window=parameters['in_window'],
			out_window=parameters['out_window'],
			max_jitter=0,
			verbose=parameters['verbose']
		)

		# Smallest summed signal among the training peaks. `.item()` stores
		# it as a plain number rather than a 0-d tensor.
		parameters['max_counts'] = train_signals.sum(dim=-1).min().item()

	# Train only on negatives whose counts fall below `beta` times the
	# threshold above.
	training_data = PeakGenerator(
		loci=parameters['negatives'], 
		sequences=parameters['sequences'],
		signals=parameters['signals'],
		chroms=parameters['training_chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=parameters['max_jitter'],
		reverse_complement=parameters['reverse_complement'],
		min_counts=parameters['min_counts'],
		max_counts=parameters['max_counts']*parameters['beta'],
		random_state=parameters['random_state'],
		batch_size=parameters['batch_size'],
		verbose=parameters['verbose']
	)

	valid_sequences, valid_signals = extract_loci(
		sequences=parameters['sequences'],
		signals=parameters['signals'],
		controls=None,
		loci=parameters['negatives'],
		chroms=parameters['validation_chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=0,
		verbose=parameters['verbose']
	)

	trimming = (parameters['in_window'] - parameters['out_window']) // 2

	model = BPNet(n_filters=parameters['n_filters'], 
		n_layers=parameters['n_layers'], n_outputs=1, n_control_tracks=0,
		alpha=parameters['alpha'], 
		name=parameters['name'],
		trimming=trimming).cuda()

	optimizer = torch.optim.AdamW(model.parameters(), lr=parameters['lr'])

	model.fit(training_data, optimizer, X_valid=valid_sequences, 
		y_valid=valid_signals, max_epochs=parameters['max_epochs'], 
		validation_iter=parameters['validation_iter'],
		early_stopping=parameters['early_stopping'],
		batch_size=parameters['batch_size'])
