#!/usr/bin/env python
# BPNet command-line tool
# Author: Jacob Schreiber <jmschreiber91@gmail.com>

import os
os.environ['TORCH_CUDNN_V8_API_ENABLED'] = '1'

import sys
import numpy
import torch
import pandas
import argparse
import subprocess


from bpnetlite.io import PeakGenerator

from bpnetlite.bpnet import BPNet
from bpnetlite.bpnet import CountWrapper
from bpnetlite.bpnet import ProfileWrapper
from bpnetlite.bpnet import ControlWrapper

from bpnetlite.chrombpnet import ChromBPNet

from bpnetlite.attribute import deep_lift_shap
from bpnetlite.marginalize import marginalization_report

from tangermeme.io import _interleave_loci
from tangermeme.io import extract_loci
from tangermeme.utils import example_to_fasta_coords
from tangermeme.match import extract_matching_loci
from tangermeme.seqlet import recursive_seqlets
from tangermeme.predict import predict

import json

torch.backends.cudnn.benchmark = True


# Top-level description shown by `bpnet --help`.
desc = """BPNet is a neural network primarily composed of dilated residual
	convolution layers for modeling the associations between biological
	sequences and biochemical readouts. This tool will take in a fasta
	file for the sequence, a bed file for signal peak locations, and bigWig
	files for the signal to predict and the control signal, and train a
	BPNet model for you."""

# Help text for the sub-command selector. Keep this list in sync with the
# sub-parsers registered below.
_help = """Must be either 'negatives', 'pipeline-json', 'fit', 'predict',
	'attribute', 'seqlets', 'marginalize', or 'pipeline'."""


# Read in the arguments. The chosen sub-command name is stored in `args.cmd`.
parser = argparse.ArgumentParser(description=desc)
subparsers = parser.add_subparsers(help=_help, required=True, dest='cmd')

# negatives: options are passed straight from the command line.

negatives_parser = subparsers.add_parser("negatives",
	help="Sample GC-matched negatives.")
negatives_parser.add_argument("-i", "--peaks", required=True,
	help="Peak bed file.")
negatives_parser.add_argument("-f", "--fasta", help="Genome FASTA file.")
negatives_parser.add_argument("-b", "--bigwig", help="Optional signal bigwig.")
negatives_parser.add_argument("-o", "--output", required=True,
	help="Output bed file.")
negatives_parser.add_argument("-l", "--bin_width", type=float, default=0.02,
	help="GC bin width to match.")
negatives_parser.add_argument("-n", "--max_n_perc", type=float, default=0.1,
	help="Maximum percentage of Ns allowed in each locus.")
negatives_parser.add_argument("-a", "--beta", type=float, default=0.5,
	help="Multiplier on the minimum counts in peaks.")
negatives_parser.add_argument("-w", "--in_window", type=int, default=2114,
	help="Width for calculating GC content.")
negatives_parser.add_argument("-x", "--out_window", type=int, default=1000,
	help="Non-overlapping stride to use for loci.")
negatives_parser.add_argument("-v", "--verbose", default=False,
	action='store_true')

# pipeline-json: writes a parameter file for the `pipeline` command.

json_parser = subparsers.add_parser("pipeline-json",
	help="Make a pipeline JSON file given the provided information.")
json_parser.add_argument("-s", "--sequences", type=str,
	help="The FASTA file of sequences.")
json_parser.add_argument("-i", "--inputs", type=str, action='append',
	help="A BAM or bigwig file. Repeatable.")
json_parser.add_argument("-c", "--controls", type=str, action='append',
	help="A BAM or bigwig file. Repeatable.")
json_parser.add_argument("-l", "--loci", type=str, action='append',
	help="A BED-formatted file of loci to use. Repeatable.")
json_parser.add_argument("-n", "--name", type=str,
	help="Name to use as a suffix in intermediary files.")
json_parser.add_argument("-u", "--unstranded", action='store_true',
	default=False, help="Whether the input is unstranded.")
json_parser.add_argument("-f", "--fragments", action='store_true',
	default=False, help='Whether the input are fragments or reads.')
# The description belongs in `help`. Passing it as `default` made the help
# sentence itself the motif database path whenever -m was omitted.
json_parser.add_argument("-m", "--motifs", type=str,
	help="A motif database for marginalization and TF-MoDISco.")
# Required because the command opens this path unconditionally for writing.
json_parser.add_argument("-o", "--output", type=str, required=True,
	help="The filename for the pipeline JSON.")

# The remaining commands each take a single JSON parameter file.

fit_parser = subparsers.add_parser("fit", help="Fit a BPNet model.")
fit_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for fitting the model.")

#

predict_parser = subparsers.add_parser("predict",
	help="Make predictions using a trained BPNet model.")
predict_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for making predictions.")

#

attribute_parser = subparsers.add_parser("attribute",
	help="Calculate attributions using a trained BPNet model.")
attribute_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for calculating attributions.")

#

seqlet_parser = subparsers.add_parser("seqlets",
	help="Identify seqlets from attributions.")
seqlet_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for identifying seqlets.")

#

marginalize_parser = subparsers.add_parser("marginalize",
	help="Run marginalizations given motifs.")
marginalize_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for running marginalizations.")

#

pipeline_parser = subparsers.add_parser("pipeline",
	help="Run each step on the given files.")
pipeline_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters used for each step.")


###
# Default Parameters
###

# Chromosomes used by default as the `fit` training set and as the default
# `chroms` of the stand-alone predict/attribute/seqlets/marginalize commands.
# chr8 and chr10 are kept out for validation (below); chr1 and chrY appear in
# neither list.
training_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 
	'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 
	'chr19', 'chr20', 'chr21', 'chr22', 'chrX']

# Chromosomes used for validation during `fit`, and as the default `chroms`
# of the predict/attribute steps when run through the pipeline.
validation_chroms = ['chr8', 'chr10']


# Defaults for the `fit` command. `merge_parameters` fills any key missing
# from the user's JSON with the value here; a default of None means the key
# must be present in the JSON ('controls' and 'early_stopping' are exempt
# from that check).
default_fit_parameters = {
	# Model construction / optimisation settings
	'n_filters': 64,
	'n_layers': 8,
	'profile_output_bias': True,
	'count_output_bias': True,
	'name': None,
	'batch_size': 64,
	'in_window': 2114,
	'out_window': 1000,
	'max_jitter': 128,
	'reverse_complement': True,
	'max_epochs': 50,
	'validation_iter': 100,
	'lr': 0.001,
	'alpha': 100,
	'dtype': 'float32',
	'device': 'cuda',
	'early_stopping': None,
	'verbose': False,

	# Count bounds forwarded to the training data generator
	'min_counts': 0,
	'max_counts': 99999999,

	# Data sources and chromosome splits
	'training_chroms': training_chroms,
	'validation_chroms': validation_chroms,
	'sequences': None,
	'loci': None,
	'signals': None,
	'controls': None,
	'random_state': None
}


# Defaults for the `predict` command. Keys defaulting to None must be present
# in the JSON, except 'controls', which `merge_parameters` allows to be absent.
default_predict_parameters = {
	'batch_size': 64,
	'in_window': 2114,
	'out_window': 1000,
	'verbose': False,
	'chroms': training_chroms,
	'device': 'cuda',
	'dtype': 'float32',
	'sequences': None,
	'loci': None,
	'controls': None,
	'model': None,
	# Output files written by the predict command
	'profile_filename': 'predictions.profile.npz',
	'counts_filename': 'predictions.counts.npz',
	'idx_filename': 'predictions.idx.npy'
}


# Defaults for the `attribute` command. 'sequences', 'loci' and 'model' have
# no default and must be present in the JSON.
default_attribute_parameters = {
	'batch_size': 64,
	'in_window': 2114,
	'out_window': 1000,
	'verbose': False,
	'chroms': training_chroms,
	'sequences': None,
	'loci': None,
	'model': None,
	# Which model head to attribute: the command accepts 'counts' or 'profile'
	'output': 'counts',
	# Output files written by the attribute command
	'ohe_filename': 'attributions.ohe.npz',
	'attr_filename': 'attributions.attr.npz',
	'idx_filename': 'attributions.idx.npy',
	'n_shuffles': 20,
	'random_state': 0,
	'device': 'cuda',
	'warning_threshold': 1e-3
}


# Defaults for the `seqlets` command. 'loci' and the three input filenames
# have no default and must be present in the JSON.
default_seqlet_parameters = {
	# Settings forwarded to the seqlet caller
	'threshold': 0.01,
	'min_seqlet_len': 4,
	'max_seqlet_len': 25,
	'additional_flanks': 3,
	# Window passed to the example -> FASTA coordinate conversion
	'in_window': 2114,
	'chroms': training_chroms,
	'verbose': False,
	'loci': None,
	# Files produced by the attribute command
	'ohe_filename': None,
	'attr_filename': None,
	'idx_filename': None,
	'output_filename': 'seqlets.bed',
}


# Defaults for the seqlet annotation step of the pipeline. There is no
# stand-alone sub-command; the pipeline turns these into arguments of the
# external `ttl` command.
default_annotation_parameters = {
	'motifs': None,
	'sequences': None,
	'seqlet_filename': None,
	'n_score_bins': 100,
	'n_median_bins': 1000,
	'n_target_bins': 100,
	'n_cache': 250,
	'reverse_complement': True,
	'n_jobs': -1,
	'output_filename': 'seqlets_annotated.bed'
}


# Defaults for the `marginalize` command. 'sequences', 'motifs', 'loci' and
# 'model' have no default and must be present in the JSON.
default_marginalize_parameters = {
	'batch_size': 64,
	'in_window': 2114,
	'out_window': 1000,
	'verbose': False,
	'chroms': training_chroms,
	'sequences': None,
	'motifs': None,
	'loci': None,
	'attributions': False,
	'n_loci': 100,
	'shuffle': False,
	'model': None,
	# Passed to marginalization_report as the output location
	'output_filename':'marginalize/',
	'random_state':0,
	'minimal': True,
	'device': 'cuda'
}


# Defaults for the `pipeline` command; also the template that `pipeline-json`
# copies and serialises (so key order here is the key order of that file).
#
# Top-level keys are shared settings. Each nested '*_parameters' dict holds
# per-step overrides: `_extract_set` starts from the top-level values and then
# applies every nested value that is not None, so None in a nested dict means
# "inherit the top-level value or use a name derived from the run name".
default_pipeline_parameters = {
	# Model architecture parameters
	'n_filters': 64,
	'n_layers': 8,
	'profile_output_bias': True,
	'count_output_bias': True,
	'in_window': 2114,
	'out_window': 1000,
	'name': None,
	'model': None,
	'dtype': 'float32',
	'device': 'cuda',
	'early_stopping': None,

	# Data parameters
	'batch_size': 64,
	'max_jitter': 128,
	'reverse_complement': True,
	'max_epochs': 20,
	'validation_iter': 100,
	'lr': 0.001,
	'alpha': 100,
	'verbose': True,
	'min_counts': 0,
	'max_counts': 99999999,
	'random_state': None,

	'sequences': None,
	'loci': None,
	'signals': None,
	'controls': None,
	'find_negatives': False,
	'unstranded': False,
	'fragments': False,


	# Fit parameters
	'fit_parameters': {
		'batch_size': 64,
		'training_chroms': training_chroms,
		'sequences': None,
		'loci': None,
		'signals': None,
		'controls': None,
		'verbose': None,
		'random_state': None,
	},


	# Predict parameters
	'predict_parameters': {
		'batch_size': 64,
		'chroms': validation_chroms,
		'profile_filename': None,
		'counts_filename': None,
		'idx_filename': None,
		'sequences': None,
		'loci': None,
		'signals': None,
		'controls': None,
		'dtype': None,
		'device': None,
		'verbose': None
	},


	# Attribution parameters
	'attribute_parameters': {
		'batch_size': 64,
		'chroms': validation_chroms,
		'output': 'counts',
		'loci': None,
		'device': None,
		'ohe_filename': None,
		'attr_filename': None,
		'idx_filename': None,
		'n_shuffles': 20,
		'warning_threshold': 1e-3,
		'random_state': None,
		'verbose': None
	},


	# Seqlet Parameters
	'seqlet_parameters': {
		'threshold': 0.01,
		'min_seqlet_len': 4,
		'max_seqlet_len': 25,
		'additional_flanks': 3,
		'in_window': None,
		'chroms': None,
		'verbose': None,
		'loci': None,
		'ohe_filename': None,
		'attr_filename': None,
		'idx_filename': None,
		'output_filename': None
	},


	# Seqlet Annotation Parameters
	'annotation_parameters': {
		'motifs': None,
		'sequences': None,
		'seqlet_filename': None,
		'n_score_bins': 100,
		'n_median_bins': 1000,
		'n_target_bins': 100,
		'n_cache': 250,
		'reverse_complement': True,
		'n_jobs': -1,
		'output_filename': None
	},


	# Modisco parameters
	'modisco_motifs_parameters': {
		'n_seqlets': 100000,
		'output_filename': None,
		'verbose': None
	},


	# Modisco report parameters
	'modisco_report_parameters': {
		'motifs': None,
		'output_folder': None,
		'verbose': None
	},


	# Marginalization parameters
	# NOTE(review): the marginalize command reads 'output_filename', not
	# 'output_folder' -- confirm which key is intended here.
	'marginalize_parameters': {
		'loci': None,
		'n_loci': 100,
		'attributions': False,
		'batch_size': 64,
		'shuffle': False,
		'random_state': None,
		'output_folder': None,
		'motifs': None,
		'minimal': True,
		'device': None,
		'verbose': None
	} 
}


###
# COMMANDS
###


def _extract_set(parameters, defaults, name):
	subparameters = {
		key: parameters.get(key, None) for key in defaults if key in parameters
	}
	
	for parameter, value in parameters[name].items():
		if value is not None:
			subparameters[parameter] = value

	return subparameters

def _check_set(parameters, parameter, value):
	if parameters.get(parameter, None) == None:
		parameters[parameter] = value


def merge_parameters(parameters, default_parameters):
	"""Merge the provided parameters with the default parameters.

	Every key of `default_parameters` that is missing from `parameters` is
	filled in with a copy of its default value. Keys already present are
	left untouched, even when their value is None.

	
	Parameters
	----------
	parameters: str or dict
		Either the path of a JSON file with the provided parameters or an
		already-loaded dictionary. A dictionary is updated in place.

	default_parameters: dict
		The default parameters for the operation. It is not modified, and
		the returned dictionary does not share mutable values with it.


	Returns
	-------
	params: dict
		The merged set of parameters.


	Raises
	------
	ValueError
		If a key is missing from `parameters` and its default is None,
		unless the key is one of 'controls', 'warning_threshold' or
		'early_stopping'.
	"""

	import copy

	if isinstance(parameters, str):
		with open(parameters, "r") as infile:
			parameters = json.load(infile)

	# Keys that may legitimately stay None when the user omits them.
	unset_parameters = ("controls", "warning_threshold", "early_stopping")
	for parameter, value in default_parameters.items():
		if parameter in parameters:
			continue

		if value is None and parameter not in unset_parameters:
			raise ValueError("Must provide value for '{}'".format(parameter))

		# Copy so that later in-place edits of the merged parameters (nested
		# dicts, chromosome lists) cannot change the module-level defaults.
		parameters[parameter] = copy.deepcopy(value)

	return parameters


# Parse the command line. `args.cmd` holds the chosen sub-command name
# (dest='cmd' on the sub-parsers) and selects which section below runs.
args = parser.parse_args()


##########
# NEGATIVES
##########


if args.cmd == 'negatives':
	# Map the command-line options onto the keyword arguments of
	# extract_matching_loci; no chromosome filter is applied.
	matching_kwargs = {
		'loci': args.peaks,
		'fasta': args.fasta,
		'gc_bin_width': args.bin_width,
		'max_n_perc': args.max_n_perc,
		'bigwig': args.bigwig,
		'signal_beta': args.beta,
		'in_window': args.in_window,
		'out_window': args.out_window,
		'chroms': None,
		'verbose': args.verbose,
	}
	negative_loci = extract_matching_loci(**matching_kwargs)

	# Tab-separated, without header or index column.
	negative_loci.to_csv(args.output, sep='\t', header=False, index=False)


##########
# PIPELINE-JSON
##########


if args.cmd == 'pipeline-json':
	# Start from a shallow copy of the pipeline defaults and overlay what was
	# given on the command line. Negative sampling is always switched on in
	# the generated file.
	pipeline_config = default_pipeline_parameters.copy()
	pipeline_config.update({
		'sequences': args.sequences,
		'loci': args.loci,
		'signals': args.inputs,
		'controls': args.controls,
		'name': args.name,
		'unstranded': args.unstranded,
		'motifs': args.motifs,
		'find_negatives': True,
		'fragments': args.fragments,
	})

	with open(args.output, 'w') as outfile:
		json.dump(pipeline_config, outfile, indent=4)


##########
# FIT
##########


if args.cmd == "fit":
	parameters = merge_parameters(args.parameters, default_fit_parameters)

	###

	controls = parameters['controls']
	device = parameters['device']
	verbose = parameters['verbose']

	# Arguments shared by the training generator and the validation extraction.
	data_kwargs = {
		'loci': parameters['loci'],
		'sequences': parameters['sequences'],
		'signals': parameters['signals'],
		'in_window': parameters['in_window'],
		'out_window': parameters['out_window'],
		'verbose': verbose,
	}

	# Training batches come from the training chromosomes, with jitter and
	# reverse-complement settings taken from the parameters.
	training_data = PeakGenerator(
		controls=controls,
		chroms=parameters['training_chroms'],
		max_jitter=parameters['max_jitter'],
		reverse_complement=parameters['reverse_complement'],
		min_counts=parameters['min_counts'],
		max_counts=parameters['max_counts'],
		random_state=parameters['random_state'],
		batch_size=parameters['batch_size'],
		**data_kwargs
	)

	# The validation examples are extracted once, without jitter.
	validation = extract_loci(
		in_signals=controls,
		chroms=parameters['validation_chroms'],
		max_jitter=0,
		ignore=list('QWERYUIOPSDFHJKLZXVBNM'),
		**data_kwargs
	)

	# extract_loci returns a third element only when controls were given.
	if controls is None:
		valid_sequences, valid_signals = validation
		valid_controls, n_control_tracks = None, 0
	else:
		valid_sequences, valid_signals, valid_controls = validation
		n_control_tracks = len(controls)

	# Amount by which the input window exceeds the output window on each side.
	trimming = (parameters['in_window'] - parameters['out_window']) // 2

	model = BPNet(
		n_filters=parameters['n_filters'],
		n_layers=parameters['n_layers'],
		n_outputs=len(parameters['signals']),
		n_control_tracks=n_control_tracks,
		profile_output_bias=parameters['profile_output_bias'],
		count_output_bias=parameters['count_output_bias'],
		alpha=parameters['alpha'],
		trimming=trimming,
		name=parameters['name'],
		verbose=verbose
	)
	model = model.to(device)

	optimizer = torch.optim.AdamW(model.parameters(), lr=parameters['lr'])

	if verbose:
		print("Training Set Size: ", training_data.dataset.sequences.shape[0])
		print("Validation Set Size: ", valid_sequences.shape[0])

	model.fit(
		training_data,
		optimizer,
		X_valid=valid_sequences,
		X_ctl_valid=valid_controls,
		y_valid=valid_signals,
		max_epochs=parameters['max_epochs'],
		validation_iter=parameters['validation_iter'],
		batch_size=parameters['batch_size'],
		early_stopping=parameters['early_stopping'],
		dtype=parameters['dtype'],
		device=device
	)


##########
# PREDICT
##########


elif args.cmd == 'predict':
	parameters = merge_parameters(args.parameters, default_predict_parameters)

	###

	model = torch.load(parameters['model'], weights_only=False).to(
		parameters['device'])

	examples = extract_loci(
		sequences=parameters['sequences'],
		in_signals=parameters['controls'],
		loci=parameters['loci'],
		chroms=parameters['chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=0,
		ignore=list('QWERYUIOPSDFHJKLZXVBNM'),
		return_filtered=True,
		verbose=parameters['verbose']
	)

	if parameters['controls'] == None:
		X, idxs = examples
		if model.n_control_tracks > 0:
			X_ctl = torch.zeros(X.shape[0], model.n_control_tracks, X.shape[-1])
		else:
			X_ctl = None
	else:
		X, X_ctl, idxs = examples

	if X_ctl is not None:
		X_ctl = (X_ctl,)

	y_profiles, y_counts = predict(model, X, args=X_ctl, 
		batch_size=parameters['batch_size'], device=parameters['device'], 
		dtype=parameters['dtype'], verbose=parameters['verbose'])

	numpy.savez_compressed(parameters['profile_filename'], y_profiles)
	numpy.savez_compressed(parameters['counts_filename'], y_counts)
	numpy.save(parameters['idx_filename'], idxs)


##########
# ATTRIBUTE
##########


elif args.cmd == 'attribute':
	parameters = merge_parameters(args.parameters, default_attribute_parameters)

	###

	model = torch.load(parameters['model'], weights_only=False).to(
		parameters['device'])

	dtype = torch.float32
	if parameters['output'] == 'profile' or isinstance(model, ChromBPNet):
		dtype = torch.float64

	X, idxs = extract_loci(
		sequences=parameters['sequences'],
		loci=parameters['loci'],
		chroms=parameters['chroms'],
		max_jitter=0,
		ignore=list('QWERYUIOPSDFHJKLZXVBNM'),
		return_filtered=True,
		verbose=parameters['verbose']
	)

	n_idxs = X.sum(dim=(1, 2)) == X.shape[-1]
	X = X[n_idxs]
	
	idxs = idxs[n_idxs]

	model = ControlWrapper(model)
	if parameters['output'] == 'counts':
		wrapper = CountWrapper(model)
	elif parameters['output'] == 'profile':
		wrapper = ProfileWrapper(model)
	else:
		raise ValueError("output must be either `counts` or `profile`.")

	X_attr = deep_lift_shap(wrapper.type(dtype), X.type(dtype),
		hypothetical=True,
		n_shuffles=parameters['n_shuffles'],
		batch_size=parameters['batch_size'],
		device=parameters['device'],
		random_state=parameters['random_state'],
		verbose=parameters['verbose'],
		warning_threshold=parameters['warning_threshold'])

	numpy.savez_compressed(parameters['ohe_filename'], X)
	numpy.savez_compressed(parameters['attr_filename'], X_attr)
	numpy.save(parameters['idx_filename'], idxs)


##########
# SEQLETS
##########


elif args.cmd == 'seqlets':
	parameters = merge_parameters(args.parameters, default_seqlet_parameters)

	###

	idxs = numpy.load(parameters['idx_filename'])

	loci = _interleave_loci(parameters['loci'], parameters['chroms'])
	loci = loci.iloc[idxs]

	X = numpy.load(parameters['ohe_filename'])['arr_0']
	X = torch.from_numpy(X)
	
	X_attr = numpy.load(parameters['attr_filename'])['arr_0']
	X_attr = torch.from_numpy(X_attr)
	X_attr = (X_attr * X).sum(dim=1)

	seqlets = recursive_seqlets(
		X_attr, 
		threshold=parameters['threshold'],
		min_seqlet_len=parameters['min_seqlet_len'],
		max_seqlet_len=parameters['max_seqlet_len'],
		additional_flanks=parameters['additional_flanks']
	).sort_values("attribution", ascending=False)

	seqlets = example_to_fasta_coords(seqlets, loci, parameters['in_window'])
	seqlets.to_csv(parameters['output_filename'], sep='\t', index=False, 
		header=False)


##########
# MARGINALIZE
##########


elif args.cmd == 'marginalize':
	parameters = merge_parameters(args.parameters, 
		default_marginalize_parameters)

	###

	model = torch.load(parameters['model'], weights_only=False).to(
		parameters['device'])
	model = ControlWrapper(model)

	X = extract_loci(
		sequences=parameters['sequences'],
		loci=parameters['loci'],
		chroms=parameters['chroms'],
		max_jitter=0,
		ignore=list('QWERYUIOPSDFHJKLZXVBNM'), 
		n_loci=parameters['n_loci'],
		verbose=parameters['verbose']
	).float()

	if parameters['shuffle'] == True:
		idxs = numpy.arange(X.shape[0])
		numpy.random.shuffle(idxs)
		X = X[idxs]

	if parameters['n_loci'] is not None:
		X = X[:parameters['n_loci']]

	marginalization_report(model, parameters['motifs'], X, 
		parameters['output_filename'], 
		attributions=parameters['attributions'],
		batch_size=parameters['batch_size'], 
		minimal=parameters['minimal'],
		device=parameters['device'],
		verbose=parameters['verbose'])


##########
# PIPLEINE
##########


elif args.cmd == 'pipeline':
	parameters = merge_parameters(args.parameters, default_pipeline_parameters)
	pname = parameters['name']


	###
	# Step 0.1: Convert from SAM/BAMs to bigwigs if provided
	###

	ftypes = '.sam', '.bam', '.tsv', '.tsv.gz'

	if parameters['signals'][0].endswith(ftypes):
		if parameters['verbose']:
			print("Step 0.1: Convert data to bigWigs")
		
		args = [
			"bam2bw", 
			"-s", parameters['sequences'], 
			"-n", pname,
		]

		if parameters["unstranded"]:
			args += ["-u"]

		if parameters['fragments']:
			args += ["-f"]

		if parameters["verbose"]:
			args += ["-v"]

		args += parameters['signals']
		subprocess.run(args, check=True)

		if parameters["unstranded"]:
			parameters['signals'] = [pname + ".bw"]
		else:
			parameters['signals'] = [pname + ".+.bw", pname + ".-.bw"]

	if parameters['controls'] is not None:
		if parameters['controls'][0].endswith(ftypes):
			args = [
				"bam2bw", 
				"-s", parameters['sequences'], 
				"-n", pname + ".control",
			]

			if parameters["unstranded"]:
				args += ["-u"]

			if parameters["fragments"]:
				args += ["-f"]

			if parameters["verbose"]:
				args += ["-v"]

			args += parameters['controls']
			subprocess.run(args, check=True)

			if parameters["unstranded"]:
				parameters['controls'] = [pname + ".control.bw"]
			else:
				parameters['controls'] = [pname + ".control.+.bw", pname + ".control.-.bw"]


	###
	# Step 0.2: Identify GC-matched negative regions
	###

	if parameters['find_negatives']:
		if parameters['verbose']:
			print("\nStep 0.2: Find GC-matched negative regions.")

		# Only the first loci file is used as the peak set to match against.
		negatives_cmd = [
			"bpnet", "negatives", 
			"-i", parameters["loci"][0],
			"-f", parameters["sequences"],
			"-o", pname + ".negatives.bed"
		]

		if parameters['verbose']:
			negatives_cmd += ['-v']

		parameters["negatives"] = [pname + ".negatives.bed"]

		subprocess.run(negatives_cmd, check=True)
	else:
		# Later steps read parameters['negatives'] unconditionally. Define the
		# key (keeping any user-supplied value) so that a run without negative
		# sampling does not fail with a KeyError.
		parameters.setdefault("negatives", None)


	###
	# Step 1: Fit a BPNet model to the provided data
	###

	if parameters['verbose']:
		print("\nStep 1: Fitting a BPNet model")

	fit_parameters = _extract_set(parameters, default_fit_parameters, 
		'fit_parameters')

	# Training is skipped when a model file was supplied.
	if parameters.get('model') is None:
		name = pname + '.bpnet.fit.json'
		parameters['model'] = pname + '.torch'
		
		# `.get` because 'negatives' only exists when negatives were sampled
		# or given explicitly.
		# NOTE(review): `+=` extends the list in place. When the fit loci were
		# inherited from the top-level 'loci', that is the same list object,
		# so later steps also see the negatives -- confirm this is intended.
		if parameters.get('negatives') is not None:
			fit_parameters['loci'] += parameters['negatives']

		with open(name, 'w') as outfile:
			outfile.write(json.dumps(fit_parameters, sort_keys=True, indent=4))

		subprocess.run(["bpnet", "fit", "-p", name], check=True)


	###
	# Step 2: Make predictions for the entire validation set
	###

	if parameters['verbose']:
		print("\nStep 2: Making predictions")

	predict_parameters = _extract_set(parameters, default_predict_parameters,
		'predict_parameters')

	# Any output file not named explicitly gets a name derived from the run.
	for key, suffix in (
		('profile_filename', '.predictions.profiles.npz'),
		('counts_filename', '.predictions.counts.npz'),
		('idx_filename', '.predictions.idxs.npy'),
	):
		_check_set(predict_parameters, key, pname + suffix)

	# Write the step's parameter file and run the stand-alone command on it.
	name = pname + '.bpnet.predict.json'
	with open(name, 'w') as outfile:
		json.dump(predict_parameters, outfile, sort_keys=True, indent=4)

	subprocess.run(["bpnet", "predict", "-p", name], check=True)


	###
	# Step 3: Calculate attributions
	###

	if parameters['verbose']:
		print("\nStep 3: Calculating attributions")

	attribute_parameters = _extract_set(parameters,
		default_attribute_parameters, 'attribute_parameters')

	# Any output file not named explicitly gets a name derived from the run.
	for key, suffix in (
		('ohe_filename', '.attributions.ohe.npz'),
		('attr_filename', '.attributions.attr.npz'),
		('idx_filename', '.attributions.idxs.npy'),
	):
		_check_set(attribute_parameters, key, pname + suffix)

	# Write the step's parameter file and run the stand-alone command on it.
	name = pname + '.bpnet.attribute.json'
	with open(name, 'w') as outfile:
		json.dump(attribute_parameters, outfile, sort_keys=True, indent=4)

	subprocess.run(["bpnet", "attribute", "-p", name], check=True)


	###
	# Step 4.1: Identify seqlets from attributions
	###

	if parameters['verbose']:
		print("\nStep 4.1: Seqlet identification")

	seqlet_parameters = _extract_set(parameters, default_seqlet_parameters,
		'seqlet_parameters')

	# Unless overridden, read the default attribution file names and reuse the
	# chromosomes of the attribution step.
	seqlet_fallbacks = (
		("ohe_filename", pname + '.attributions.ohe.npz'),
		("attr_filename", pname + '.attributions.attr.npz'),
		("idx_filename", pname + '.attributions.idxs.npy'),
		("output_filename", pname + ".seqlets.bed"),
		("chroms", attribute_parameters['chroms']),
	)
	for key, fallback in seqlet_fallbacks:
		_check_set(seqlet_parameters, key, fallback)

	# Write the step's parameter file and run the stand-alone command on it.
	name = pname + '.bpnet.seqlets.json'
	with open(name, 'w') as outfile:
		json.dump(seqlet_parameters, outfile, sort_keys=True, indent=4)

	subprocess.run(["bpnet", "seqlets", "-p", name], check=True)


	###
	# Step 4.2: Annotate seqlets using motif database
	###

	if parameters['verbose']:
		print("\nStep 4.2: Seqlet annotation")

	annotation_parameters = _extract_set(parameters, 
		default_annotation_parameters, "annotation_parameters")
	_check_set(annotation_parameters, "seqlet_filename", pname+".seqlets.bed")
	_check_set(annotation_parameters, "output_filename", pname+".seqlets_annotated.bed")
	# `.get` so that a parameter file without a top-level 'motifs' key is
	# handled like motifs=None (see the guard below) instead of a KeyError.
	_check_set(annotation_parameters, "motifs", parameters.get("motifs"))

	annotation_parameters = merge_parameters(annotation_parameters, 
		default_annotation_parameters)

	# Build the argument list for the external `ttl` command.
	cmd = [
		"ttl",
		"-f", annotation_parameters["sequences"],
		"-b", annotation_parameters["seqlet_filename"],
		"-s", str(annotation_parameters["n_score_bins"]),
		"-m", str(annotation_parameters["n_median_bins"]),
		"-a", str(annotation_parameters["n_target_bins"]),
		"-c", str(annotation_parameters["n_cache"]),
		"-j", str(annotation_parameters["n_jobs"]),
	]

	# "-r" is passed only when reverse-complement handling is switched off.
	if not annotation_parameters["reverse_complement"]:
		cmd += ["-r"]

	if annotation_parameters['motifs'] is not None:
		cmd += ["-t", annotation_parameters["motifs"]]

	# The annotated seqlets are whatever the command prints to stdout.
	with open(annotation_parameters['output_filename'], "w") as f:
		subprocess.run(cmd, check=True, stdout=f)

	# Count how often each value of column 3 occurs in the annotated output.
	annotated_seqlets = pandas.read_csv(annotation_parameters['output_filename'],
		sep="\t", header=None, usecols=(3,), names=['motifs'])

	seqlet_count = annotated_seqlets.value_counts()
	seqlet_count.to_csv(pname+".motif_seqlet_count.tsv", sep="\t")


	###
	# Step 5.1: Run TF-MoDISco
	###

	if parameters['verbose']:
		print("\nStep 5.1: TF-MoDISco motifs")

	modisco_parameters = parameters['modisco_motifs_parameters']

	_check_set(modisco_parameters, "output_filename", 
		pname+'_modisco_results.h5')
	_check_set(modisco_parameters, "verbose",
		parameters['verbose'])

	modisco_parameters = merge_parameters(modisco_parameters,
		default_pipeline_parameters['modisco_motifs_parameters'])

	# Build the argument vector directly. Formatting one string and splitting
	# it on whitespace broke any file name that contained a space.
	cmd = [
		"modisco", "motifs",
		"-s", str(attribute_parameters['ohe_filename']),
		"-a", str(attribute_parameters['attr_filename']),
		"-n", str(modisco_parameters['n_seqlets']),
		"-o", str(modisco_parameters['output_filename']),
	]

	# Verbose if either the step or the pipeline as a whole asks for it.
	if modisco_parameters.get('verbose') or parameters['verbose']:
		cmd.append('-v')

	subprocess.run(cmd, check=True)


	###
	# Step 5.2: Generate the tf-modisco report
	###
	
	report_parameters = parameters['modisco_report_parameters']
	_check_set(report_parameters, "verbose", parameters["verbose"])
	_check_set(report_parameters, "output_folder", pname+"_modisco/")
	# `.get` so that a parameter file without a top-level 'motifs' key is
	# treated as "no motif database" rather than raising a KeyError.
	_check_set(report_parameters, "motifs", parameters.get('motifs'))

	if report_parameters['verbose']:
		print("\nStep 5.2: TF-MoDISco reports")

	report_cmd = ["modisco", "report", 
		"-i", modisco_parameters['output_filename'], 
		"-o", report_parameters['output_folder'],
		"-s", './'
	]

	# Only pass -m when a motif database is set; None is not a valid
	# subprocess argument.
	if report_parameters['motifs'] is not None:
		report_cmd += ["-m", report_parameters['motifs']]
	
	subprocess.run(report_cmd, check=True)


	###
	# Step 6: Marginalization experiments
	###
	
	if parameters['verbose']:
		print("\nStep 6: Run marginalizations")

	marginalize_parameters = _extract_set(parameters, 
		default_marginalize_parameters, "marginalize_parameters")

	# The pipeline parameters name the report destination 'output_folder',
	# but `bpnet marginalize` reads 'output_filename'. Carry a user-supplied
	# folder over so it is no longer silently ignored.
	output_folder = marginalize_parameters.pop('output_folder', None)
	_check_set(marginalize_parameters, 'output_filename', output_folder)

	# `.get` because 'negatives' and 'motifs' may be absent from the
	# parameter file; a direct lookup raised a KeyError.
	# NOTE(review): 'loci' is normally already inherited from the top-level
	# parameters by _extract_set, so this fallback to the negatives rarely
	# applies -- confirm which loci are meant to be marginalized over.
	_check_set(marginalize_parameters, "loci", parameters.get("negatives"))
	_check_set(marginalize_parameters, 'output_filename', pname+"_marginalize/")
	_check_set(marginalize_parameters, 'motifs', parameters.get('motifs'))
	_check_set(marginalize_parameters, 'negatives', parameters.get('negatives'))

	name = '{}.bpnet.marginalize.json'.format(parameters['name'])

	with open(name, 'w') as outfile:
		outfile.write(json.dumps(marginalize_parameters, sort_keys=True, 
			indent=4))

	subprocess.run(["bpnet", "marginalize", "-p", name], check=True)
