#!/usr/bin/env python
# ChromBPNet command-line tool
# Author: Jacob Schreiber <jmschreiber91@gmail.com>

import os
import sys
import json
import torch
import argparse
import subprocess

from bpnetlite.bpnet import BPNet
from bpnetlite.chrombpnet import ChromBPNet

from bpnetlite.io import PeakGenerator
from bpnetlite.marginalize import marginalization_report

from tangermeme.io import extract_loci


desc = """ChromBPNet is a neural network that builds off the original BPNet
	architecture by explicitly learning bias in the signal tracks themselves.
	Specifically, for ATAC-seq and DNase-seq experiments, the cutting enzymes
	have a soft sequence bias (though this is much stronger for Tn5, the
	enzyme used in ATAC-seq). Accordingly, ChromBPNet is a pair of neural
	networks where one explicitly models the bias and the other explicitly
	models accessibility. This tool provides functionality for training the
	combination of the bias model and the accessibility model and for making
	predictions with it. After training, the accessibility model can be used
	on its own with the `bpnet` tool."""
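
# Typical end-to-end usage, with hypothetical file names, is to generate a
# pipeline JSON and then run the full pipeline from it:
#
#   chrombpnet pipeline-json -s hg38.fa -i exp1.bam -p peaks.bed \
#       -n exp1 -m motifs.meme -o exp1.json
#   chrombpnet pipeline -p exp1.json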

_help = """Must be either 'pipeline-json', 'fit', 'evaluate', 'attribute',
	'seqlets', 'marginalize', or 'pipeline'."""


# Read in the arguments
parser = argparse.ArgumentParser(description=desc)
subparsers = parser.add_subparsers(help=_help, required=True, dest='cmd')


#

json_parser = subparsers.add_parser("pipeline-json",
	help="Make a pipeline JSON file given the provided information.")
json_parser.add_argument("-s", "--sequences", type=str,
	help="The FASTA file of sequences.")
json_parser.add_argument("-i", "--inputs", type=str, action='append', 
	help="A BAM or bigwig file. Repeatable.")
json_parser.add_argument("-p", "--peaks", type=str, action='append',
	help="A BED-formatted file of peaks to use. Repeatable.")
json_parser.add_argument("-neg", "--negatives", type=str, action='append',
	help="A BED-formatted file of negative loci to use. Repeatable.")
json_parser.add_argument("-n", "--name", type=str,
	help="Name to use as a suffix in intermediary files.")
json_parser.add_argument("-ps", "--pos_shift", type=int,
	default=0, help="How many bp to shift the + strand reads.")
json_parser.add_argument("-ns", "--neg_shift", type=int,
	default=0, help="how many bp to shift the - strand reads.")
json_parser.add_argument("-m", "--motifs", type=str, default=None,
	help="A motif database for marginalization and TF-MoDISco.")
json_parser.add_argument("-o", "--output", type=str,
	help="The filename for the pipeline JSON.")

#

fit_parser = subparsers.add_parser("fit", help="Fit a ChromBPNet model.")
fit_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for fitting the model.")

#

evaluate_parser = subparsers.add_parser("evaluate", 
	help="Evaluate a trained BPNet model.")
evaluate_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for making predictions.")

#

attribute_parser = subparsers.add_parser("attribute", 
	help="Calculate attributions using a trained BPNet model.")
attribute_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for calculating attributions.")

#

seqlet_parser = subparsers.add_parser("seqlets", 
	help="Identify seqlets from attributions.")
seqlet_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for identifying seqlets.")

#

marginalize_parser = subparsers.add_parser("marginalize", 
	help="Run marginalizations given motifs.")
marginalize_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for running marginalizations.")

#

pipeline_parser = subparsers.add_parser("pipeline", 
	help="Run each step on the given files.")
pipeline_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters used for each step.")


###
# Default Parameters
###


# Default chromosome splits. chr1, chr3, and chr6 appear in neither list, and
# so are held out entirely, e.g. for final evaluation.
training_chroms = ["chr2", "chr4", "chr5", "chr7", "chr9", "chr10", "chr11",
	"chr12", "chr13", "chr14", "chr15", "chr16", "chr17", "chr18", "chr19",
	"chr21", "chr22", "chrX", "chrY"]

validation_chroms = ['chr8', 'chr20']


default_fit_parameters = {
	'n_filters': 256,
	'n_layers': 8,
	'profile_output_bias': True,
	'count_output_bias': True,
	'name': None,
	'batch_size': 64,
	'in_window': 2114,
	'out_window': 1000,
	'max_jitter': 500,
	'reverse_complement': True,
	'reverse_complement_average': False,
	'summits': True,
	'max_epochs': 50,
	'validation_iter': 100,
	'lr': 0.001,
	'negative_ratio': 0.1,
	'count_loss_weight': None,
	'dtype': 'float32',
	'device': 'cuda',
	'scheduler': True,
	'early_stopping': 5,
	'verbose': False,

	'min_counts': 0,
	'max_counts': 99999999,

	'training_chroms': training_chroms,
	'validation_chroms': validation_chroms,
	'sequences': None,
	'loci': None,
	'exclusion_lists': None,
	'negatives': None,
	'signals': None,
	'controls': None,
	'random_state': None,
	'performance_filename': 'performance.tsv',

	# Bias-model settings: when 'bias_model' is None, a bias model is first
	# trained on background regions whose total counts are at most `beta`
	# times the 1% quantile of total counts across the provided peaks. Any
	# non-None entries in 'bias_fit_parameters' override the shared values
	# above for that bias-model fit.
	'bias_model': None,
	'beta': 0.5,
	'bias_fit_parameters': {},

	'skip': False,
}
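
# A minimal fit JSON only needs the entries above without usable defaults,
# e.g. (hypothetical file names):
#
# {
#     "name": "exp1",
#     "sequences": "hg38.fa",
#     "loci": ["peaks.bed"],
#     "negatives": ["negatives.bed"],
#     "signals": ["exp1.+.bw", "exp1.-.bw"]
# }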


default_pipeline_parameters = {
	# Shared parameters
	'in_window': 2114,
	'out_window': 1000,
	'name': None,
	'model': None,
	'bias_model': None,
	'accessibility_model': None,
	'dtype': 'float32',
	'device': 'cuda',

	# Data parameters
	'batch_size': 64,
	'verbose': True,
	'min_counts': 0,
	'max_counts': 99999999,
	'random_state': None,

	'exclusion_lists': None,
	'sequences': None,
	'loci': None,
	'negatives': None,
	'motifs': None,
	'signals': None,
	'controls': None,

	'skip': False,
	'dry_run': False,

	# Data processing parameters
	'preprocessing_parameters': {
		'unstranded': False,
		'fragments': False,
		'paired_end': False,
		'pos_shift': 0,
		'neg_shift': 0,
		'callpeaks_format': None,
		'callpeaks_gsize': 'hs',
		'callpeaks_q': 0.01,
		'verbose': True
	},

	# Bias fit parameters
	'bias_fit_parameters': {
		'n_filters': 128,
		'n_layers': 4,
		'profile_output_bias': True,
		'count_output_bias': True,
		'batch_size': 64,
		'lr': 0.001,
		'scheduler': False,
		'early_stopping': 5,
		'max_jitter': 0,
		'reverse_complement': True,
		'reverse_complement_average': False,
		'max_epochs': 50,
		'training_chroms': training_chroms,
		'validation_chroms': validation_chroms,
		'beta': 0.5,
		'sequences': None,
		'loci': None,
		'signals': None,
		'verbose': None,
		'random_state': None,
		'summits': False,
		'performance_filename': None,
	},
	
	# ChromBPNet fit parameters
	'chrombpnet_fit_parameters': {
		'n_filters': 256,
		'n_layers': 8,
		'profile_output_bias': True,
		'count_output_bias': True,
		'batch_size': 64,
		'lr': 0.001,
		'scheduler': False,
		'negative_ratio': 0.1,
		'count_loss_weight': None,
		'early_stopping': 5,
		'max_jitter': 500,
		'reverse_complement': True,
		'reverse_complement_average': False,
		'max_epochs': 50,
		'training_chroms': training_chroms,
		'validation_chroms': validation_chroms,
		'sequences': None,
		'loci': None,
		'signals': None,
		'verbose': None,
		'random_state': None,
		'summits': False,
		'performance_filename': None,
	},

	# Attribution parameters
	'attribute_parameters': {
		'batch_size': 64,
		'chroms': validation_chroms,
		'output': 'counts',
		'loci': None,
		'device': None,
		'ohe_filename': None,
		'attr_filename': None,
		'idx_filename': None,
		'n_shuffles': 20,
		'warning_threshold': 1e-3,
		'random_state': None,
		'verbose': None
	},


	# Seqlet Parameters
	'seqlet_parameters': {
		'threshold': 0.01,
		'min_seqlet_len': 4,
		'max_seqlet_len': 25,
		'additional_flanks': 3,
		'in_window': None,
		'chroms': None,
		'verbose': None,
		'loci': None,
		'ohe_filename': None,
		'attr_filename': None,
		'idx_filename': None,
		'output_filename': None
	},


	# Seqlet Annotation Parameters
	'annotation_parameters': {
		'motifs': None,
		'sequences': None,
		'seqlet_filename': None,
		'n_score_bins': 100,
		'n_median_bins': 1000,
		'n_target_bins': 100,
		'n_cache': 250,
		'reverse_complement': True,
		'n_jobs': -1,
		'output_filename': None
	},


	# Modisco parameters
	'modisco_motifs_parameters': {
		'n_seqlets': 100000,
		'output_filename': None,
		'verbose': None
	},


	# Modisco report parameters
	'modisco_report_parameters': {
		'motifs': None,
		'output_folder': None,
		'verbose': None
	},


	# Marginalization parameters
	'marginalize_parameters': {
		'loci': None,
		'n_loci': 100,
		'attributions': False,
		'batch_size': 64,
		'shuffle': False,
		'random_state': None,
		'output_folder': None,
		'motifs': None,
		'minimal': True,
		'device': None,
		'verbose': None
	} 
}
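
# The nested dictionaries above override the shared values for one step only.
# For example, a pipeline JSON containing (other required entries omitted)
#
# {
#     "batch_size": 64,
#     "chrombpnet_fit_parameters": {"batch_size": 128}
# }
#
# uses a batch size of 128 when fitting the ChromBPNet model and 64 for every
# other step.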


##########
# COMMANDS
##########


def _extract_set(parameters, defaults, name):
	"""Build the parameters for one step: the shared values from `defaults`
	that appear in `parameters`, overridden by any non-None entries in the
	nested dictionary `parameters[name]`."""

	subparameters = {
		key: parameters.get(key, None) for key in defaults if key in parameters
	}

	for parameter, value in parameters[name].items():
		if value is not None:
			subparameters[parameter] = value

	return subparameters
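
# For example, _extract_set({'lr': 0.001, 'fit': {'lr': 0.01, 'name': None}},
# {'lr': None}, 'fit') returns {'lr': 0.01}: shared values are copied first,
# and only the non-None entries of the nested dictionary override them.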

def _check_set(parameters, parameter, value):
	"""Set `parameters[parameter]` to `value` if it is currently unset."""

	if parameters.get(parameter, None) is None:
		parameters[parameter] = value


def merge_parameters(parameters, default_parameters):
	"""Merge the provided parameters with the default parameters.

	
	Parameters
	----------
	parameters: str or dict
		The name of a JSON file containing the provided parameters, or an
		already-loaded dictionary of them.

	default_parameters: dict
		The default parameters for the operation.


	Returns
	-------
	params: dict
		The merged set of parameters.
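
	Examples
	--------
	A dictionary can be passed directly in place of a filename:

	>>> merge_parameters({'name': 'exp1'}, {'name': None, 'lr': 0.001})
	{'name': 'exp1', 'lr': 0.001}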
	"""

	if isinstance(parameters, str):
		with open(parameters, "r") as infile:
			parameters = json.load(infile)

	unset_parameters = ("controls", "warning_threshold", "early_stopping", 
		"count_loss_weight", "exclusion_lists", "bias_model",
		"accessibility_model", "model", "negatives", "motifs", "random_state")
	for parameter, value in default_parameters.items():
		if parameter not in parameters:
			if value is None and parameter not in unset_parameters:
				raise ValueError("Must provide value for '{}'".format(parameter))

			parameters[parameter] = value

	return parameters


# Pull the arguments
args = parser.parse_args()



##########
# PIPELINE-JSON
##########


if args.cmd == 'pipeline-json':
	pp = 'preprocessing_parameters'
	parameters = default_pipeline_parameters.copy()
	
	parameters['sequences'] = args.sequences
	parameters['loci'] = args.peaks
	parameters['negatives'] = args.negatives
	parameters['signals'] = args.inputs
	parameters['controls'] = None
	parameters['name'] = args.name
	parameters['motifs'] = args.motifs

	# The generated JSON assumes unstranded, paired-end fragment input, as is
	# typical for ATAC-seq experiments.
	parameters[pp]['unstranded'] = True
	parameters[pp]['fragments'] = True
	parameters[pp]['pos_shift'] = args.pos_shift
	parameters[pp]['neg_shift'] = args.neg_shift
	parameters[pp]['paired_end'] = True
	
	with open(args.output, 'w') as outfile:
		outfile.write(json.dumps(parameters, indent=4))
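
	# The written file is a complete pipeline JSON along the lines of
	#
	# {
	#     "name": "exp1",
	#     "sequences": "hg38.fa",
	#     "signals": ["exp1.bam"],
	#     "loci": ["peaks.bed"],
	#     "preprocessing_parameters": {"unstranded": true, ...},
	#     ...
	# }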
	
	sys.exit()


# Settings that speed up the remaining commands, all of which run models on
# the GPU.
os.environ['TORCH_CUDNN_V8_API_ENABLED'] = '1'
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('high')


##########
# FIT
##########

if args.cmd == "fit":
	parameters = merge_parameters(args.parameters, default_fit_parameters)

	###

	# Load the peak loci; this pass is reused for calibrating the bias model
	# below and is replaced when negative regions are provided.
	training_data = PeakGenerator(
		loci=parameters['loci'], 
		sequences=parameters['sequences'],
		signals=parameters['signals'],
		controls=None,
		chroms=parameters['training_chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=parameters['max_jitter'],
		reverse_complement=parameters['reverse_complement'],
		min_counts=parameters['min_counts'],
		max_counts=parameters['max_counts'],
		random_state=parameters['random_state'],
		batch_size=parameters['batch_size'],
		verbose=parameters['verbose']
	)

	trimming = (parameters['in_window'] - parameters['out_window']) // 2

	if parameters['bias_model'] is None:
		# No pre-trained bias model was given, so fit one first. Cap the
		# total counts of its training loci at `beta` times the 1% quantile
		# of per-locus counts in the peaks, so that the bias model only sees
		# regions with little real signal.
		counts = training_data.dataset.signals.sum(dim=(1, 2))
		min_counts = torch.quantile(counts, 0.01).item()

		# Start from the shared fit parameters and apply any bias-specific
		# overrides.
		name = 'bias_fit_parameters'
		bias_fit_parameters = {key: parameters[key] for key in 
			default_fit_parameters if key in parameters}
		for parameter, value in parameters[name].items():
			if value is not None:
				bias_fit_parameters[parameter] = value

		# Unless bias-specific loci are given, train the bias model on the
		# negative (background) regions rather than on the peaks.
		if parameters[name].get('loci') is None:
			bias_fit_parameters['loci'] = parameters['negatives']

		# Drop the keys that `bpnet fit` does not understand.
		del bias_fit_parameters['negatives'], bias_fit_parameters['beta']
		del bias_fit_parameters['bias_model'], bias_fit_parameters[name]

		name = '{}.chrombpnet.bias.fit.json'.format(parameters['name'])
		bias_fit_parameters['max_counts'] = min_counts * parameters['beta']
		bias_fit_parameters['name'] = parameters['name'] + '.bias'
		parameters['bias_model'] = bias_fit_parameters['name'] + '.torch'

		with open(name, 'w') as outfile:
			outfile.write(json.dumps(bias_fit_parameters, sort_keys=True, 
				indent=4))

		subprocess.run(["bpnet", "fit", "-p", name], check=True)


	if parameters['negatives'] is not None:
		training_data = PeakGenerator(
			loci=[parameters['loci'], parameters['negatives']], 
			sequences=parameters['sequences'],
			signals=parameters['signals'],
			controls=None,
			chroms=parameters['training_chroms'],
			in_window=parameters['in_window'],
			out_window=parameters['out_window'],
			max_jitter=parameters['max_jitter'],
			reverse_complement=parameters['reverse_complement'],
			min_counts=parameters['min_counts'],
			max_counts=parameters['max_counts'],
			random_state=parameters['random_state'],
			batch_size=parameters['batch_size'],
			verbose=parameters['verbose']
		)

	# Extract the validation loci once so they can be reused every epoch.
	valid_sequences, valid_signals = extract_loci(
		sequences=parameters['sequences'],
		signals=parameters['signals'],
		in_signals=None,
		loci=parameters['loci'],
		chroms=parameters['validation_chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=0,
		ignore=list('QWERYUIOPSDFHJKLZXVBNM'),  # every letter except A/C/G/T
		verbose=parameters['verbose']
	)

	bias = torch.load(parameters['bias_model'], weights_only=False,
		map_location='cpu').cuda().eval()

	# 'count_loss_weight' maps onto BPNet's `alpha`, the weight on the count
	# loss; fall back to 1 when it is left unset.
	alpha = parameters['count_loss_weight']
	if alpha is None:
		alpha = 1

	accessibility = BPNet(n_filters=parameters['n_filters'], 
		n_layers=parameters['n_layers'], n_control_tracks=0, n_outputs=1,
		alpha=alpha, name=parameters['name'] + '.accessibility',
		trimming=trimming).cuda()

	model = ChromBPNet(bias=bias, accessibility=accessibility,
		name=parameters['name'])

	optimizer = torch.optim.AdamW(model.parameters(), lr=parameters['lr'])

	model.fit(training_data, optimizer, X_valid=valid_sequences, 
		y_valid=valid_signals, max_epochs=parameters['max_epochs'], 
		validation_iter=parameters['validation_iter'], 
		batch_size=parameters['batch_size'])
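
	# The fit loop is expected to save the best model to '<name>.torch'; the
	# pipeline command below assumes this naming when chaining the fitted
	# models together.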


##########
# PIPELINE
##########

elif args.cmd == 'pipeline':
	parameters = merge_parameters(args.parameters, default_pipeline_parameters)
	model_name = parameters['name']

	# Step 1: Fit a ChromBPNet model to the provided data
	if parameters['verbose']:
		print("Step 1: Fitting a ChromBPNet model")

	if parameters['model'] is None:
		name = '{}.chrombpnet.fit.json'.format(parameters['name'])
		parameters['model'] = parameters['name'] + '.torch'

		# Start from the shared parameters that apply to fitting and layer on
		# any ChromBPNet-specific overrides.
		fit_parameters = {key: parameters[key] for key in 
			default_fit_parameters if key in parameters}
		for parameter, value in parameters['chrombpnet_fit_parameters'].items():
			if value is not None:
				fit_parameters[parameter] = value

		# Keep only the non-None bias overrides; `chrombpnet fit` applies
		# them on top of the shared values when training the bias model.
		fit_parameters['bias_fit_parameters'] = {
			parameter: value for parameter, value in
			parameters['bias_fit_parameters'].items() if value is not None
		}

		with open(name, 'w') as outfile:
			outfile.write(json.dumps(fit_parameters, sort_keys=True, indent=4))

		subprocess.run(["chrombpnet", "fit", "-p", name], check=True)


	if parameters['bias_model'] is None:
		parameters['bias_model'] = model_name + '.bias.torch'

	if parameters['accessibility_model'] is None:
		parameters['accessibility_model'] = (model_name + 
			'.accessibility.torch')

	del parameters['bias_fit_parameters']
	del parameters['chrombpnet_fit_parameters']

	# Run pipeline with ChromBPNet model
	name = '{}.chrombpnet.pipeline.json'.format(parameters['name'])
	with open(name, 'w') as outfile:
		outfile.write(json.dumps(parameters, sort_keys=True, indent=4))

	subprocess.run(["bpnet", "pipeline", "-p", name], check=True)


	# Run pipeline with accessibility model
	name = '{}.chrombpnet.accessibility.pipeline.json'.format(
		model_name)

	parameters['model'] = parameters['accessibility_model']
	parameters['name'] = model_name + '.accessibility'

	with open(name, 'w') as outfile:
		outfile.write(json.dumps(parameters, sort_keys=True, indent=4))

	subprocess.run(["bpnet", "pipeline", "-p", name], check=True)


	# Run pipeline with bias model
	name = '{}.chrombpnet.bias.pipeline.json'.format(model_name)

	parameters['model'] = parameters['bias_model']
	parameters['name'] = model_name + '.bias'

	with open(name, 'w') as outfile:
		outfile.write(json.dumps(parameters, sort_keys=True, indent=4))

	subprocess.run(["bpnet", "pipeline", "-p", name], check=True)

