#!/usr/bin/env python

"""
This tool converts one or more SAM/BAM files into a pair of stranded
BigWig files containing the 5' end counts. The input files can be either
stored locally or given as remote URLs. When multiple files are provided,
their reads are pooled together. Reads flagged as unmapped are skipped.
"""

import argparse
import collections
import sys

import numpy
import pyBigWig
import pysam

from tqdm import tqdm

# Command-line interface.
#
# The two shift options are declared with type=int: argparse hands back raw
# strings for values given on the command line, and the shifts are later added
# to integer read coordinates, so without the conversion any user-supplied
# shift would fail with a TypeError (only the integer default would work).
parser = argparse.ArgumentParser(
    prog='bam2bw',
    description='This tool will convert BAM files to bigwig files without an intermediate.')

parser.add_argument('filename', nargs='+',
    help="""The SAM/BAM file to be processed.""")
parser.add_argument('-s', '--sizes', required=True,
    help="""A chromosome sizes file.""")
parser.add_argument('-n', '--name', required=True,
    help="""Prefix of the output files; <name>.+.bw and <name>.-.bw are written.""")
parser.add_argument('-ps', '--pos_shift', type=int, default=0,
    help="""A shift to apply to positive strand reads.""")
parser.add_argument('-ns', '--neg_shift', type=int, default=0,
    help="""A shift to apply to negative strand reads.""")
parser.add_argument('-v', '--verbose', action='store_true',
    help="""Show a progress bar while reading the input files.""")
args = parser.parse_args()


###


# Per-strand counters: chromosome name -> {0-based position -> read count}.
pos_reads = {}
neg_reads = {}

# (chromosome, length) pairs in file order; this list is also used as the
# BigWig header below.
chrom_sizes = []

# Parse the sizes file before creating any output, so that a missing or
# malformed sizes file does not leave empty BigWig files behind.
with open(args.sizes, "r") as size_file:
	for line_number, line in enumerate(size_file, start=1):
		fields = line.split()

		# Tolerate blank lines (e.g. a trailing newline at the end of the file).
		if not fields:
			continue

		if len(fields) < 2:
			raise ValueError(
				"{}:{}: expected '<chrom> <size>', got {!r}".format(
					args.sizes, line_number, line.strip("\r\n")))

		# Only the first two columns are used, so files with additional
		# columns (such as a FASTA .fai index) are accepted as well.
		chrom, size = fields[0], int(fields[1])
		chrom_sizes.append((chrom, size))

		pos_reads[chrom] = collections.defaultdict(int)
		neg_reads[chrom] = collections.defaultdict(int)

# One output file per strand, both sharing the same chromosome header.
bw_pos = pyBigWig.open(args.name + ".+.bw", "w")
bw_neg = pyBigWig.open(args.name + ".-.bw", "w")

bw_pos.addHeader(chrom_sizes, maxZooms=0)
bw_neg.addHeader(chrom_sizes, maxZooms=0)


###


# Tally the 5' end of every mapped read, per strand and chromosome.
#
# Forward reads are counted at their (shifted) leftmost aligned base; reverse
# reads at their (shifted) rightmost aligned base, i.e. reference_end - 1,
# since reference_end is one past the last aligned base.

# Coerce the shifts once, outside the hot loop, so that values arriving as
# strings from the command line can still be added to integer coordinates.
pos_shift = int(args.pos_shift)
neg_shift = int(args.neg_shift)

# Chromosome lengths, used to reject positions that cannot be written later.
chrom_lengths = dict(chrom_sizes)
n_skipped = 0

for filename in args.filename:
	# The context manager closes the file even if iteration raises.
	with pysam.AlignmentFile(filename, "rb") as bam:
		for read in tqdm(bam.fetch(until_eof=True), disable=not args.verbose):
			if read.is_unmapped:
				continue

			if read.is_forward:
				position = read.reference_start + pos_shift
				strand_reads = pos_reads
			else:
				end = read.reference_end
				# reference_end is None when no alignment is available.
				if end is None:
					n_skipped += 1
					continue

				position = end - 1 + neg_shift
				strand_reads = neg_reads

			# Skip reads on references that are absent from the sizes file, and
			# positions shifted off either end of the chromosome: neither can
			# be stored under the header built from the sizes file.
			chrom = read.reference_name
			length = chrom_lengths.get(chrom)
			if length is None or not 0 <= position < length:
				n_skipped += 1
				continue

			strand_reads[chrom][position] += 1

if n_skipped:
	print("bam2bw: skipped {} mapped read(s) falling outside the chromosomes "
		"listed in {}".format(n_skipped, args.sizes), file=sys.stderr)


###


# Write the tallies out, visiting chromosomes in the same order in which they
# were passed to addHeader. For each chromosome the positive strand is written
# first, then the negative strand, each to its own file.
for chrom, _ in chrom_sizes:
	for strand_counts, bigwig in ((pos_reads[chrom], bw_pos), (neg_reads[chrom], bw_neg)):
		# Nothing to write for a strand with no counted positions.
		if len(strand_counts) == 0:
			continue

		positions = numpy.array(list(strand_counts.keys()), dtype='int64')
		counts = numpy.array(list(strand_counts.values()), dtype='float64')

		# Dict order is insertion order, so sort entries by position before
		# handing them to the writer; each entry covers a single base.
		order = numpy.argsort(positions)
		bigwig.addEntries(chrom, positions[order], values=counts[order], span=1)

bw_pos.close()
bw_neg.close()
