#!/usr/bin/env python

import numpy as np
import pdb
import pickle
from argparse import ArgumentParser
from provenance_tools.write_provenance_data import write_provenance_data
from bayes_traj.fit_stats import ave_pp, odds_correct_classification

# Command-line description shown by --help. It summarizes exactly what this
# script prints below.
desc = """Print a summary of a fitted Bayesian trajectory model: the
significant trajectories, the model's predictors and targets, the BIC and
WAIC2 information criteria, and per-trajectory posterior statistics (average
posterior probability of assignment and odds of correct classification)."""

parser = ArgumentParser(description=desc)
parser.add_argument('--model',
    help='Pickled Bayesian trajectory model to summarize',
    type=str, required=True)

# Parsed command-line options; op.model is the path to the pickled model.
op = parser.parse_args()

# Read the pickled container and extract the fitted model stored under the
# 'MultDPRegression' key. NOTE: pickle.load can run arbitrary code during
# unpickling, so only point --model at files from a trusted source.
with open(op.model, 'rb') as model_file:
    mm = pickle.load(model_file)['MultDPRegression']

# Information criterion as returned by the model (one value or a pair; see
# the Information Criteria section below).
bic = mm.bic()

# Integer indices of the entries of sig_trajs_ that are truthy, i.e. the
# trajectories the model considers significant.
traj_ids = np.nonzero(mm.sig_trajs_)[0]

# ---- Summary: which trajectories survived, and what the model regresses ----
print(" ")
print("---------- Summary ----------")

num_sig_trajs = np.sum(mm.sig_trajs_)
print("Number of trajectories: {}".format(num_sig_trajs))

traj_id_text = ', '.join(traj_ids.astype('str'))
print("Trajectories: {}".format(traj_id_text))

predictor_text = ', '.join(mm.predictor_names_)
print("Predictors: {}".format(predictor_text))

target_text = ', '.join(mm.target_names_)
print("Targets: {}".format(target_text))

print(" ")
print("---------- Information Criteria ----------")

# mm.bic() may return either a single value or several values (the original
# code expected a pair) -- NOTE(review): confirm the exact contract against
# MultDPRegression.bic. Calling len() on a bare scalar raises TypeError, so
# normalize to a 1-D array first, then label the values by how many there are.
bic_vals = np.atleast_1d(bic)
if bic_vals.size == 1:
    print("BIC: {}".format(int(bic_vals[0])))
else:
    # Two values print as BIC1/BIC2, matching the previous output format.
    for bic_idx, bic_val in enumerate(bic_vals, start=1):
        print("BIC{}: {}".format(bic_idx, int(bic_val)))

# int() truncates toward zero; values are shown as whole numbers.
print("WAIC2: {}".format(int(mm.compute_waic2())))

def _report_per_traj(template, values):
    """Print one line per significant trajectory.

    `template` is formatted with the trajectory id and the entry of `values`
    at that id; iteration follows the module-level `traj_ids` order.
    """
    for traj_id in traj_ids:
        print(template.format(traj_id, values[traj_id]))


# ---- Posterior stats: per-trajectory assignment quality ----
print(" ")
print("---------- Posterior Stats ----------")

print("Average posterior probability of assignment:")
ave_pps = ave_pp(mm)
_report_per_traj("\t Trajectory {}: {:.3f}", ave_pps)

print("Odds of correct classification:")
occs = odds_correct_classification(mm)
_report_per_traj("\t Trajectory {}: {:.2f}", occs)
