#!/usr/bin/env python
# BPNet command-line tool
# Author: Jacob Schreiber <jmschreiber91@gmail.com>

import os
os.environ['TORCH_CUDNN_V8_API_ENABLED'] = '1'

import sys
import numpy
import torch
import argparse
import subprocess


from bpnetlite.io import PeakGenerator

from bpnetlite.bpnet import BPNet
from bpnetlite.bpnet import CountWrapper
from bpnetlite.bpnet import ProfileWrapper
from bpnetlite.bpnet import ControlWrapper
from bpnetlite.bpnet import _ProfileLogitScaling

from bpnetlite.chrombpnet import ChromBPNet
from bpnetlite.chrombpnet import _Log, _Exp

from bpnetlite.marginalize import marginalization_report

from tangermeme.io import extract_loci
from tangermeme.match import extract_matching_loci
from tangermeme.predict import predict

from tangermeme.deep_lift_shap import _nonlinear
from tangermeme.deep_lift_shap import deep_lift_shap

import json

# Turn on cuDNN benchmark mode for this process (see the PyTorch
# `torch.backends.cudnn` documentation for what the flag controls).
torch.backends.cudnn.benchmark = True


# Top-level description printed by `bpnet --help`.
desc = """BPNet is a neural network primarily composed of dilated residual
	convolution layers for modeling the associations between biological
	sequences and biochemical readouts. This tool will take in a fasta
	file for the sequence, a bed file for signal peak locations, and bigWig
	files for the signal to predict and the control signal, and train a
	BPNet model for you."""

# Help text attached to the sub-command selector.
_help = """Must be either 'negatives', 'fit', 'predict', 'attribute',
	'marginalize', or 'pipeline'."""


# Read in the arguments. The chosen sub-command is stored in `args.cmd`.
parser = argparse.ArgumentParser(description=desc)
subparsers = parser.add_subparsers(help=_help, required=True, dest='cmd')

# `negatives` is the only sub-command configured through individual flags; 
# every other sub-command takes a single JSON parameter file.
negatives_parser = subparsers.add_parser("negatives", 
	help="Sample GC-matched negatives.")
negatives_parser.add_argument("-i", "--peaks", required=True, 
	help="Peak bed file.")
negatives_parser.add_argument("-f", "--fasta", help="Genome FASTA file.")
negatives_parser.add_argument("-b", "--bigwig", help="Optional signal bigwig.")
negatives_parser.add_argument("-o", "--output", required=True, 
	help="Output bed file.")
negatives_parser.add_argument("-l", "--bin_width", type=float, default=0.02, 
	help="GC bin width to match.")
negatives_parser.add_argument("-n", "--max_n_perc", type=float, default=0.1, 
	help="Maximum percentage of Ns allowed in each locus.")
negatives_parser.add_argument("-a", "--beta", type=float, default=0.5, 
	help="Multiplier on the minimum counts in peaks.")
negatives_parser.add_argument("-w", "--in_window", type=int, default=2114, 
	help="Width for calculating GC content.")
negatives_parser.add_argument("-x", "--out_window", type=int, default=1000, 
	help="Non-overlapping stride to use for loci.")
negatives_parser.add_argument("-v", "--verbose", default=False, 
	action='store_true')

fit_parser = subparsers.add_parser("fit", help="Fit a BPNet model.")
fit_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for fitting the model.")

predict_parser = subparsers.add_parser("predict", 
	help="Make predictions using a trained BPNet model.")
predict_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for making predictions.")

attribute_parser = subparsers.add_parser("attribute", 
	help="Calculate attributions using a trained BPNet model.")
attribute_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for calculating attributions.")

marginalize_parser = subparsers.add_parser("marginalize", 
	help="Run marginalizations given motifs.")
marginalize_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for running marginalizations.")

pipeline_parser = subparsers.add_parser("pipeline", 
	help="Run each step on the given files.")
pipeline_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters used for each step.")


###
# Default Parameters
###

# Chromosomes used for training by default. Note that chr1 and chrY appear in
# neither this list nor the validation list below.
training_chroms = [
	'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr11', 'chr12',
	'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20',
	'chr21', 'chr22', 'chrX',
]

# Chromosomes used for validation by default.
validation_chroms = ['chr8', 'chr10']

# In the dictionaries below a default of None marks a value that 
# `merge_parameters` expects the user to supply, apart from the keys that 
# function explicitly treats as optional.

# Defaults for `bpnet fit`.
default_fit_parameters = {
	# Model
	'n_filters': 64,
	'n_layers': 8,
	'profile_output_bias': True,
	'count_output_bias': True,
	'name': None,

	# Training
	'batch_size': 64,
	'in_window': 2114,
	'out_window': 1000,
	'max_jitter': 128,
	'reverse_complement': True,
	'max_epochs': 50,
	'validation_iter': 100,
	'lr': 0.001,
	'alpha': 1,
	'verbose': False,

	# Count filters passed to the training data generator
	'min_counts': 0,
	'max_counts': 99999999,

	# Data
	'training_chroms': training_chroms,
	'validation_chroms': validation_chroms,
	'sequences': None,
	'loci': None,
	'signals': None,
	'controls': None,
	'random_state': None,
}

# Defaults for `bpnet predict`.
default_predict_parameters = {
	'batch_size': 64,
	'in_window': 2114,
	'out_window': 1000,
	'verbose': False,
	'chroms': training_chroms,
	'sequences': None,
	'loci': None,
	'controls': None,
	'model': None,
	'profile_filename': 'y_profile.npz',
	'counts_filename': 'y_counts.npz',
}

# Defaults for `bpnet attribute`.
default_attribute_parameters = {
	'batch_size': 64,
	'in_window': 2114,
	'out_window': 1000,
	'verbose': False,
	'chroms': training_chroms,
	'sequences': None,
	'loci': None,
	'model': None,
	'output': 'counts',
	'ohe_filename': 'ohe.npz',
	'attr_filename': 'attr.npz',
	'n_shuffles': 20,
	'random_state': 0,
	'warning_threshold': 1e-4,
}

# Defaults for `bpnet marginalize`.
default_marginalize_parameters = {
	'batch_size': 64,
	'in_window': 2114,
	'out_window': 1000,
	'verbose': False,
	'chroms': training_chroms,
	'sequences': None,
	'motifs': None,
	'loci': None,
	'n_loci': None,
	'shuffle': False,
	'model': None,
	'output_filename': 'marginalize/',
	'random_state': 0,
	'minimal': True,
}

# Defaults for `bpnet pipeline`. Top-level values are shared across steps; the
# nested `*_parameters` dictionaries hold per-step settings. In the pipeline 
# code, nested values that are left as None are either dropped in favour of 
# the shared/top-level value or filled in with a name derived from `name`.
default_pipeline_parameters = {
	# Model architecture and shared settings
	'n_filters': 64,
	'n_layers': 8,
	'profile_output_bias': True,
	'count_output_bias': True,
	'in_window': 2114,
	'out_window': 1000,
	'name': None,
	'model': None,
	'verbose': False,

	# Training parameters (the original literal repeated 'verbose' here; a 
	# repeated key in a dict literal is silently overwritten, so it is listed
	# only once above)
	'batch_size': 64,
	'max_jitter': 128,
	'reverse_complement': True,
	'max_epochs': 50,
	'validation_iter': 100,
	'lr': 0.001,
	'alpha': 1,
	'min_counts': 0,
	'max_counts': 99999999,

	# Data parameters
	'sequences': None,
	'loci': None,
	'signals': None,
	'controls': None,

	# Fit parameters
	'fit_parameters': {
		'batch_size': 64,
		'training_chroms': training_chroms,
		'sequences': None,
		'loci': None,
		'signals': None,
		'controls': None,
		'verbose': None,
		'random_state': None,
	},

	# Predict parameters
	'predict_parameters': {
		'batch_size': 64,
		'chroms': validation_chroms,
		'profile_filename': None,
		'counts_filename': None,
		'sequences': None,
		'loci': None,
		'signals': None,
		'controls': None,
		'verbose': None,
	},

	# Attribution parameters
	'attribute_parameters': {
		'batch_size': 64,
		'chroms': validation_chroms,
		'output': 'counts',
		'loci': None,
		'ohe_filename': None,
		'attr_filename': None,
		'n_shuffles': None,
		'warning_threshold': 1e-4,
		'random_state': None,
		'verbose': None,
	},

	# Modisco parameters
	'modisco_motifs_parameters': {
		'n_seqlets': 100000,
		'output_filename': None,
		'verbose': None,
	},

	# Modisco report parameters
	'modisco_report_parameters': {
		'motifs': None,
		'output_folder': None,
		'verbose': None,
	},

	# Marginalization parameters
	# NOTE(review): the `marginalize` command reads its destination from
	# 'output_filename', not 'output_folder' -- confirm which key users are
	# meant to set here.
	'marginalize_parameters': {
		'loci': None,
		'n_loci': 100,
		'batch_size': 64,
		'shuffle': False,
		'random_state': None,
		'output_folder': None,
		'motifs': None,
		'minimal': True,
		'verbose': None,
	},
}


###
# Commands
###


def _extract_set(parameters, defaults, name):
	"""Assemble the parameters for a single step of the pipeline.

	Parameters
	----------
	parameters: dict
		The full set of pipeline parameters, including the nested dictionary
		stored under `name`.

	defaults: dict
		The default parameters of the step. Only its keys are used: they
		select which top-level values in `parameters` are carried over.

	name: str
		The key in `parameters` holding the step-specific dictionary.


	Returns
	-------
	subparameters: dict
		The top-level values whose keys appear in both `defaults` and
		`parameters`, overridden by every step-specific value that is not
		None.
	"""

	# Start from the shared, top-level values that this step understands.
	subparameters = {}
	for key in defaults:
		if key in parameters:
			subparameters[key] = parameters[key]

	# Step-specific values win, but None means "not set" and is skipped.
	overrides = {
		key: value for key, value in parameters[name].items() 
		if value is not None
	}
	subparameters.update(overrides)
	return subparameters

def _check_set(parameters, parameter, value):
	"""Fill in `parameter` with `value` if it is missing or None.

	Parameters
	----------
	parameters: dict
		The dictionary to update in place.

	parameter: str
		The key to check.

	value: object
		The value to store when the key is absent or maps to None. Values 
		such as False, 0 or an empty string count as set and are kept.
	"""

	# Identity check: only a literal None counts as "unset".
	if parameters.get(parameter) is None:
		parameters[parameter] = value


def merge_parameters(parameters, default_parameters):
	"""Load parameters from a JSON file and fill in missing defaults.

	Values present in the file always win. For every key in the defaults that
	the file does not contain, the default is used. A default of None marks a
	value the user has to supply, except for 'controls' and 
	'warning_threshold', which are allowed to stay None.

	
	Parameters
	----------
	parameters: str
		Path to the JSON file with the user-provided parameters.

	default_parameters: dict
		The default parameters for the operation.


	Returns
	-------
	params: dict
		The merged set of parameters.


	Raises
	------
	ValueError
		If a key whose default is None (and which is not optional) is missing
		from the file.
	"""

	with open(parameters, "r") as infile:
		merged = json.load(infile)

	# Keys that may legitimately remain None when the user omits them.
	optional = ("controls", "warning_threshold")

	for key in default_parameters:
		if key in merged:
			continue

		default = default_parameters[key]
		if default is None and key not in optional:
			raise ValueError("Must provide value for '{}'".format(key))

		merged[key] = default

	return merged


# Parse the command line; `args.cmd` names the requested sub-command.
args = parser.parse_args()


##########
# NEGATIVES
##########

if args.cmd == 'negatives':
	# Map the command-line flags onto the keyword arguments used to sample
	# regions that match the GC content of the peaks.
	matching_kwargs = {
		'loci': args.peaks,
		'fasta': args.fasta,
		'gc_bin_width': args.bin_width,
		'max_n_perc': args.max_n_perc,
		'bigwig': args.bigwig,
		'signal_beta': args.beta,
		'in_window': args.in_window,
		'out_window': args.out_window,
		'chroms': None,
		'verbose': args.verbose,
	}

	matched_loci = extract_matching_loci(**matching_kwargs)

	# Write the result as a tab-separated file without header or index.
	matched_loci.to_csv(args.output, sep='\t', header=False, index=False)


##########
# FIT
##########

if args.cmd == "fit":
	parameters = merge_parameters(args.parameters, default_fit_parameters)

	###

	# Values that are needed in more than one place below.
	verbose = parameters['verbose']
	in_window = parameters['in_window']
	out_window = parameters['out_window']
	controls = parameters['controls']

	# Batches for training, drawn from the training chromosomes.
	training_data = PeakGenerator(
		loci=parameters['loci'], 
		sequences=parameters['sequences'],
		signals=parameters['signals'],
		controls=controls,
		chroms=parameters['training_chroms'],
		in_window=in_window,
		out_window=out_window,
		max_jitter=parameters['max_jitter'],
		reverse_complement=parameters['reverse_complement'],
		min_counts=parameters['min_counts'],
		max_counts=parameters['max_counts'],
		random_state=parameters['random_state'],
		batch_size=parameters['batch_size'],
		verbose=verbose
	)

	# Validation examples from the validation chromosomes, without jitter.
	valid_data = extract_loci(
		sequences=parameters['sequences'],
		signals=parameters['signals'],
		in_signals=controls,
		loci=parameters['loci'],
		chroms=parameters['validation_chroms'],
		in_window=in_window,
		out_window=out_window,
		max_jitter=0,
		verbose=verbose
	)

	# The extraction yields a third element only when controls were given.
	if controls is None:
		valid_sequences, valid_signals = valid_data
		valid_controls, n_control_tracks = None, 0
	else:
		valid_sequences, valid_signals, valid_controls = valid_data
		n_control_tracks = 2

	model = BPNet(
		n_filters=parameters['n_filters'], 
		n_layers=parameters['n_layers'],
		n_outputs=len(parameters['signals']),
		n_control_tracks=n_control_tracks,
		profile_output_bias=parameters['profile_output_bias'],
		count_output_bias=parameters['count_output_bias'],
		alpha=parameters['alpha'],
		# Half of the difference between the input and output widths.
		trimming=(in_window - out_window) // 2,
		name=parameters['name'],
		verbose=verbose
	).cuda()

	optimizer = torch.optim.AdamW(model.parameters(), lr=parameters['lr'])

	if verbose:
		print("Training Set Size: ", training_data.dataset.sequences.shape[0])
		print("Validation Set Size: ", valid_sequences.shape[0])

	model.fit(
		training_data, 
		optimizer, 
		X_valid=valid_sequences, 
		X_ctl_valid=valid_controls, 
		y_valid=valid_signals, 
		max_epochs=parameters['max_epochs'], 
		validation_iter=parameters['validation_iter'], 
		batch_size=parameters['batch_size']
	)


##########
# PREDICT
##########

elif args.cmd == 'predict':
	parameters = merge_parameters(args.parameters, default_predict_parameters)

	###

	model = torch.load(parameters['model']).cuda()

	examples = extract_loci(
		sequences=parameters['sequences'],
		in_signals=parameters['controls'],
		loci=parameters['loci'],
		chroms=parameters['chroms'],
		max_jitter=0,
		verbose=parameters['verbose']
	)

	if parameters['controls'] == None:
		X = examples
		if model.n_control_tracks > 0:
			X_ctl = torch.zeros(X.shape[0], model.n_control_tracks, X.shape[-1])
		else:
			X_ctl = None
	else:
		X, X_ctl = examples

	if X_ctl is not None:
		X_ctl = (X_ctl,)

	y_profiles, y_counts = predict(model, X, args=X_ctl, 
		batch_size=parameters['batch_size'], device='cuda', 
		verbose=parameters['verbose'])

	numpy.savez_compressed(parameters['profile_filename'], y_profiles)
	numpy.savez_compressed(parameters['counts_filename'], y_counts)


##########
# ATTRIBUTE
##########

elif args.cmd == 'attribute':
	parameters = merge_parameters(args.parameters, default_attribute_parameters)

	###

	model = torch.load(parameters['model']).cuda()

	dtype = torch.float32
	if parameters['output'] == 'profile' or isinstance(model, ChromBPNet):
		dtype = torch.float64

	X = extract_loci(
		sequences=parameters['sequences'],
		loci=parameters['loci'],
		chroms=parameters['chroms'],
		max_jitter=0,
		verbose=parameters['verbose']
	)

	X = X[X.sum(dim=(1, 2)) == X.shape[-1]]

	model = ControlWrapper(model)
	if parameters['output'] == 'counts':
		wrapper = CountWrapper(model)
	elif parameters['output'] == 'profile':
		wrapper = ProfileWrapper(model)
	else:
		raise ValueError("output must be either `counts` or `profile`.")

	X_attr = deep_lift_shap(wrapper.type(dtype), X.type(dtype),
		hypothetical=True,
		additional_nonlinear_ops={
			_ProfileLogitScaling: _nonlinear,
			_Log: _nonlinear,
			_Exp: _nonlinear
		},
		n_shuffles=parameters['n_shuffles'],
		batch_size=parameters['batch_size'],
		random_state=parameters['random_state'],
		verbose=parameters['verbose'],
		warning_threshold=parameters['warning_threshold'])

	numpy.savez_compressed(parameters['ohe_filename'], X)
	numpy.savez_compressed(parameters['attr_filename'], X_attr)


##########
# MARGINALIZE
##########

elif args.cmd == 'marginalize':
	parameters = merge_parameters(args.parameters, 
		default_marginalize_parameters)

	###

	model = torch.load(parameters['model']).cuda()
	model = ControlWrapper(model)

	X = extract_loci(
		sequences=parameters['sequences'],
		loci=parameters['loci'],
		chroms=parameters['chroms'],
		max_jitter=0,
		n_loci=parameters['n_loci'],
		verbose=parameters['verbose']
	).float()

	if parameters['shuffle'] == True:
		idxs = numpy.arange(X.shape[0])
		numpy.random.shuffle(idxs)
		X = X[idxs]

	if parameters['n_loci'] is not None:
		X = X[:parameters['n_loci']]

	marginalization_report(model, parameters['motifs'], X, 
		parameters['output_filename'], batch_size=parameters['batch_size'], 
		minimal=parameters['minimal'], verbose=parameters['verbose'])


##########
# PIPLEINE
##########

elif args.cmd == 'pipeline':
	parameters = merge_parameters(args.parameters, default_pipeline_parameters)
	pname = parameters['name']

	# Step 1: Fit a BPNet model to the provided data
	if parameters['verbose']:
		print("Step 1: Fitting a BPNet model")

	fit_parameters = _extract_set(parameters, default_fit_parameters, 
		'fit_parameters')

	if parameters.get('model', None) == None:
		name = pname + '.bpnet.fit.json'
		parameters['model'] = pname + '.torch'

		with open(name, 'w') as outfile:
			outfile.write(json.dumps(fit_parameters, sort_keys=True, indent=4))

		subprocess.run(["bpnet", "fit", "-p", name], check=True)

	
	# Step 2: Make predictions for the entire validation set
	if parameters['verbose']:
		print("\nStep 2: Making predictions")

	predict_parameters = _extract_set(parameters, 
		default_predict_parameters, 'predict_parameters')
	_check_set(predict_parameters, 'profile_filename', pname+'.y_profiles.npz')
	_check_set(predict_parameters, 'counts_filename', pname+'.y_counts.npz')

	name = '{}.bpnet.predict.json'.format(parameters['name'])
	with open(name, 'w') as outfile:
		outfile.write(json.dumps(predict_parameters, sort_keys=True, indent=4))

	subprocess.run(["bpnet", "predict", "-p", name], check=True)
	

	# Step 3: Calculate attributions
	if parameters['verbose']:
		print("\nStep 3: Calculating attributions")


	attribute_parameters = _extract_set(parameters, 
		default_attribute_parameters, 'attribute_parameters')
	_check_set(attribute_parameters, 'ohe_filename', pname+'.ohe.npz')
	_check_set(attribute_parameters, 'attr_filename', pname+'.attr.npz')

	name = '{}.bpnet.attribute.json'.format(parameters['name'])
	with open(name, 'w') as outfile:
		outfile.write(json.dumps(attribute_parameters, sort_keys=True, 
			indent=4))

	subprocess.run(["bpnet", "attribute", "-p", name], check=True)


	# Step 4: Calculate tf-modisco motifs
	if parameters['verbose']:
		print("\nStep 4: TF-MoDISco motifs")

	modisco_parameters = parameters['modisco_motifs_parameters']
	_check_set(modisco_parameters, "output_filename", 
		pname+'_modisco_results.h5')

	cmd = "modisco motifs -s {} -a {} -n {} -o {}".format(
		attribute_parameters['ohe_filename'], 
		attribute_parameters['attr_filename'],
		modisco_parameters['n_seqlets'],
		modisco_parameters['output_filename'])
	if modisco_parameters['verbose']:
		cmd += ' -v'

	subprocess.run(cmd.split(), check=True)


	# Step 5: Generate the tf-modisco report
	modisco_name = "{}_modisco_results.h5".format(parameters['name'])

	report_parameters = parameters['modisco_report_parameters']
	_check_set(report_parameters, "verbose", parameters["verbose"])
	_check_set(report_parameters, "output_folder", pname+"_modisco/")

	if report_parameters['verbose']:
		print("\nStep 5: TF-MoDISco reports")

	subprocess.run(["modisco", "report", 
		"-i", modisco_parameters['output_filename'], 
		"-o", report_parameters['output_folder'],
		"-s", report_parameters['output_folder'],
		"-m", report_parameters['motifs']
		], check=True)
	

	# Step 6: Marginalization experiments
	if parameters['verbose']:
		print("\nStep 6: Run marginalizations")

	marginalize_parameters = _extract_set(parameters, 
		default_marginalize_parameters, "marginalize_parameters")
	_check_set(marginalize_parameters, 'output_filename', pname+"_marginalize/")

	name = '{}.bpnet.marginalize.json'.format(parameters['name'])

	with open(name, 'w') as outfile:
		outfile.write(json.dumps(marginalize_parameters, sort_keys=True, 
			indent=4))

	subprocess.run(["bpnet", "marginalize", "-p", name], check=True)
