#!python

"""
This tool will convert one or more SAM/BAM file/s into a pair of stranded
BigWig files containing the 5' counts. The input files can be either stored
locally or remote URLs. When multiple files are provided, the reads are
concatenated. Only fully mapped reads are kept.
"""

import gzip
import numpy
import argparse
import collections

import pysam
import pyBigWig
import pyfaidx

from tqdm import tqdm

# Command-line interface. Every option below is read later through `args`.
parser = argparse.ArgumentParser(
    prog='bam2bw',
    description='This tool will convert BAM files to bigwig files without an intermediate.')

# One or more inputs; the counts from all of them are pooled together.
parser.add_argument('filename', nargs='+',
    help="""The SAM/BAM or tsv/tsv.gz file to be processed.""")
# Defines the chromosomes written to the bigWig header(s).
parser.add_argument('-s', '--sizes', required=True,
    help="""A chrom_sizes or FASTA file.""")
parser.add_argument('-u', '--unstranded', action='store_true',
    help="Have only one, unstranded, output.")
parser.add_argument('-f', '--fragments', action='store_true',
    help="The data is fragments and so both ends should be recorded.")
# Output prefix: the script writes <name>.bw, or <name>.+.bw and <name>.-.bw.
parser.add_argument('-n', '--name', required=True,
    help="""The prefix of the output file(s): <name>.bw when unstranded,
    otherwise <name>.+.bw and <name>.-.bw.""")
# The shifts are added to the start (pos_shift) and end (neg_shift) coordinate
# of every entry before it is counted.
parser.add_argument('-ps', '--pos_shift', default=0, type=int,
    help="""A shift to apply to positive strand reads.""")
parser.add_argument('-ns', '--neg_shift', default=0, type=int,
    help="""A shift to apply to negative strand reads.""")
parser.add_argument('-z', '--zooms', default=0, type=int,
    help="""The number of zooms to store in the bigwig.""")
parser.add_argument('-v', '--verbose', action='store_true',
    help="""Show progress bars and report chromosomes that appear in the
    input but not in the sizes file.""")
args = parser.parse_args()


###


# Per-chromosome counters mapping a position to the number of read ends seen
# there. When the output is unstranded both names refer to the SAME dictionary,
# so reads from either strand are pooled without any extra branching later.
pos_reads = {}
neg_reads = pos_reads if args.unstranded else {}


# Here, we are determining the chromosomes and their sizes. This is necessary
# for creating the bigWig header(s) and for figuring out which reads to filter
# out (those that do not map to the provided chromosomes).
#
# Because we allow you to provide either a two-column chrom_sizes file or a
# FASTA file, we need code to handle the two situations. We use pyfaidx to
# quickly process the FASTA file so that we do not have to scan through the
# entire thing just to get the sizes.


# Ordered list of (chromosome name, length) tuples, in input order.
chrom_sizes = []

# If provided a FASTA file, read the lengths of the sequences
if args.sizes.endswith((".fa", ".fasta", ".fna")):
    fa = pyfaidx.Fasta(args.sizes)
    try:
        for chrom, seq in fa.items():
            chrom_sizes.append((chrom, len(seq)))
    finally:
        # Release the handle on the FASTA file once the lengths are known.
        fa.close()

# If provided a chrom_sizes file, just use the provided lengths
else:
    with open(args.sizes, "r") as size_file:
        for line in size_file:
            fields = line.split()

            # Tolerate blank lines (e.g. a trailing newline at the end).
            if not fields:
                continue

            # Only the first two columns are used, so files with additional
            # columns (such as a FASTA .fai index) are accepted as well.
            chrom, size = fields[:2]
            chrom_sizes.append((chrom, int(size)))

# Create dictionaries for each chrom, regardless
for chrom, _ in chrom_sizes:
    pos_reads[chrom] = collections.defaultdict(int)
    if not args.unstranded:
        neg_reads[chrom] = collections.defaultdict(int)


###


# Here, we open the bigWigs that we will be saving data into. If the data is
# stranded, we are saving two bigWigs. If the data is not stranded, we are only
# saving one bigWig.

# The first (or only) output is always bound to `bw_pos`; `bw_neg` is created
# only for stranded output. Files are opened first and the headers are written
# afterwards, one header per opened file.
pos_suffix = ".bw" if args.unstranded else ".+.bw"
bw_pos = pyBigWig.open(args.name + pos_suffix, "w")

if not args.unstranded:
    bw_neg = pyBigWig.open(args.name + ".-.bw", "w")

bw_pos.addHeader(chrom_sizes, maxZooms=args.zooms)

if not args.unstranded:
    bw_neg.addHeader(chrom_sizes, maxZooms=args.zooms)


###


# Chromosomes seen in the input that have no entry in the sizes; kept so that
# each one is reported at most once.
missing_chroms = set()


def _chrom_is_known(chrom):
    """Return whether `chrom` has an entry in the per-chromosome counters.

    The first time an unknown chromosome is seen it is added to
    `missing_chroms` and, when running verbosely, reported on stdout.
    """
    if chrom in pos_reads:
        return True

    if chrom not in missing_chroms:
        missing_chroms.add(chrom)
        if args.verbose:
            print("{} encountered in input but not in FASTA/chrom sizes.".format(
                chrom))

    return False


# This is the main loop that goes through the reads and records them in
# one or two dictionaries (depending on if the data is stranded). The
# processing of BAM/SAM files relies on pysam whereas tsv/tsv.gz files use
# basic file iteration. The processing of reads in both cases is largely the
# same except that, usually, the -f flag will be passed in for .tsv/.tsv.gz
# files because those come from fragments from ATAC-seq-like experiments,
# whereas BAM/SAM files are usually just reads.

for filename in args.filename:

    if filename.endswith((".bam", ".sam")):
        # BAM is a binary format ("rb") whereas SAM is plain text ("r").
        mode = "rb" if filename.endswith(".bam") else "r"

        # The context manager closes the file even if a read fails to process.
        with pysam.AlignmentFile(filename, mode) as bam:
            for read in tqdm(bam.fetch(until_eof=True), disable=not args.verbose):
                if read.is_unmapped:
                    continue

                # Check whether the chrom is in the allowable chroms.
                # Otherwise, discard the read.
                chrom = read.reference_name
                if not _chrom_is_known(chrom):
                    continue

                start = read.reference_start + args.pos_shift
                end = read.reference_end + args.neg_shift

                # Here, we need to deal with two related issues.
                #
                #    (1) Does the read map to the fwd or rev strand?
                #    (2) Are we recording only the start, or the start and
                #        the end (fragments)?
                #
                # Accordingly, we first check to see the strand the read is
                # on and take the start of the read (start for fwd, end-1 for
                # bwd reads). Then, we need to check whether we want both
                # starts and ends and record both if so. This strategy works
                # even if the underlying data is not stranded because
                # pos_reads and neg_reads are the same dictionary in that
                # case.

                if read.is_forward:
                    pos_reads[chrom][start] += 1
                    if args.fragments:
                        pos_reads[chrom][end-1] += 1

                else:
                    neg_reads[chrom][end-1] += 1
                    if args.fragments:
                        neg_reads[chrom][start] += 1

    elif filename.endswith((".tsv", ".tsv.gz")):
        # Open the file using the correct opener -- the standard one if the
        # file is not compressed, otherwise the gzip opener if gzipped. Both
        # are opened in text mode and closed automatically by the `with`.
        opener = gzip.open if filename.endswith(".gz") else open

        # Here, we process the entries in a similar manner to using pysam
        # except that we assume the coordinates are all fwd strand. We do not
        # explicitly assume that we want both the start and the end of the
        # entry, which is controlled using the -f flag, but we do not handle
        # strandedness here.

        with opener(filename, "rt") as f:
            for line in tqdm(f, disable=not args.verbose):
                fields = line.split()

                # Skip blank lines and '#'-prefixed header/comment lines,
                # neither of which describes an entry.
                if not fields or fields[0].startswith("#"):
                    continue

                # Check whether the chrom is in the allowable chroms.
                # Otherwise, discard the entry.
                chrom, start, end = fields[:3]
                if not _chrom_is_known(chrom):
                    continue

                # Going through float() accepts coordinates written such as
                # "100.0" or "1e5" in addition to plain integers.
                start = int(float(start)) + args.pos_shift
                end = int(float(end)) + args.neg_shift

                pos_reads[chrom][start] += 1
                if args.fragments:
                    pos_reads[chrom][end-1] += 1

    else:
        raise ValueError("Input filename must end in .bam, .sam, .tsv, or .tsv.gz")


###

# Now that we have our dictionary(ies) of reads, we need to create bigWig
# objects and store them. Because the entries need to be sorted along the
# length of each chromosome we have to convert the dictionaries to numpy arrays
# and then sort them, but this usually is not that big of a hassle. It is much
# faster to sort the arrays at the end like this than it is to try to keep
# everything in order as you see the reads.
#
# We use pyBigWig for our tool to create bigWig files. We choose to save
# entries using only the coordinate and the value, so two numbers per non-zero
# position, rather than as spans of size 1 which would be three numbers per
# non-zero position. This reduces file size and also I/O time.


def _write_counts(bw, chrom, size, counts):
    """Write the position -> count mapping of one chromosome to `bw`.

    Positions outside [0, size) are dropped before writing; these can arise
    when a shift moves a read end past either end of the chromosome. Nothing
    is written when no position remains.
    """
    if not counts:
        return

    positions = numpy.array(list(counts.keys()), dtype='int64')
    values = numpy.array(list(counts.values()), dtype='float64')

    in_bounds = (positions >= 0) & (positions < size)
    positions, values = positions[in_bounds], values[in_bounds]
    if len(positions) == 0:
        return

    # Entries are written in increasing coordinate order.
    order = numpy.argsort(positions)
    bw.addEntries(chrom, positions[order], values=values[order], span=1)


# Chromosomes are written in the same order as they were given to the header.
for chrom, size in chrom_sizes:
    _write_counts(bw_pos, chrom, size, pos_reads[chrom])

    # In unstranded mode neg_reads is pos_reads, which has just been written.
    if not args.unstranded:
        _write_counts(bw_neg, chrom, size, neg_reads[chrom])

bw_pos.close()
if not args.unstranded:
    bw_neg.close()
