#!python
#
#   Copyright 2016-2019 Blaise Frederick
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
#
#
#       $Author: frederic $
#       $Date: 2016/06/14 12:04:51 $
#       $Id: showstxcorr,v 1.11 2016/06/14 12:04:51 frederic Exp $
#
from __future__ import division, print_function

import getopt
import sys

import joblib
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
from sklearn import metrics
from sklearn.cluster import DBSCAN, KMeans, MiniBatchKMeans
from sklearn.decomposition import PCA, FastICA, IncrementalPCA
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.feature_selection import (
    RFE,
    SelectFdr,
    SelectKBest,
    SelectPercentile,
    f_classif,
)
from sklearn.preprocessing import StandardScaler

import capcalc.filter as ccalc_filt
import capcalc.fit as ccalc_fit
import capcalc.io as ccalc_io
import capcalc.miscmath as ccalc_math
import capcalc.stats as ccalc_stats

# hdbscan is an optional dependency: if it cannot be imported, only --hdbscan
# clustering is unavailable.  Catch Exception (not a bare except) so that
# KeyboardInterrupt/SystemExit still propagate, while import-time failures of
# the package itself still just disable the feature.
try:
    import hdbscan as hdbs

    hdbpresent = True
    print("hdbscan is present")
except Exception:
    hdbpresent = False

import capcalc.utils as capcalc_utils


def usage():
    """Print the command line help for capfromtcs.

    Returns:
        An empty tuple (kept so existing callers are unaffected).
    """
    print("")
    print("capfromtcs - calculate and cluster coactivation patterns for a set of timecourses")
    print("")
    print("usage: capfromtcs -i timecoursefile -o outputfile --samplefreq=FREQ --sampletime=TSTEP")
    print("                  [--nodetrend] [-s STARTTIME] [-D DURATION]")
    print("                  [-F LOWERFREQ,UPPERFREQ[,LOWERSTOP,UPPERSTOP]] [-V] [-L] [-R] [-C]")
    print(
        "                  [-m] [-n NUMCLUSTER] [-b BATCHSIZE] [-S SEGMENTSIZE] [-E SEGMENTTYPE] [-I INITIALIZATIONS]"
    )
    print(
        "                  [--noscale] [--nonorm] [--pctnorm] [--varnorm] [--stdnorm] [--ppnorm] [--quality]"
    )
    print("                  [--pca] [--ica] [-p NUMCOMPONENTS] [--modelroot=MODELROOT]")
    print("")
    print("required arguments:")
    print("    -i, --infile=TIMECOURSEFILE  - text file containing multiple timeseries")
    print("    -o, --outfile=OUTNAME        - the root name of the output files")
    print("")
    print("    --samplefreq=FREQ            - sample frequency of all timecourses is FREQ ")
    print("           or")
    print("    --sampletime=TSTEP           - time step of all timecourses is TSTEP ")
    print(
        "                                   NB: --samplefreq and --sampletime are two ways to specify"
    )
    print("                                   the same thing.")
    print("")
    print("optional arguments:")
    print("")
    print("  Data selection/partition:")
    print(
        "    -s STARTTIME                 - time of first datapoint to use in seconds in the first file"
    )
    print("    -D DURATION                  - amount of data to use in seconds")
    print("    -S SEGMENTSIZE,[SEGSIZE2,...SEGSIZEN]")
    print(
        "                                 - treat the timecourses as segments of length SEGMENTSIZE for preprocessing."
    )
    print(
        "                                   If there are multiple, comma separated numbers, treat these as subsegment lengths."
    )
    print("                                   Default segmentsize is the entire length")
    print("    -E SEGTYPE,SEGTYPE2[,...SEGTYPEN]")
    print(
        "                                 - group subsegments for summary statistics.  All subsegments in the same group must be the same length"
    )
    print("    --skippts=NUMPTS             - drop first NUMPTS points from each segment")
    print("  Clustering:")
    print(
        "    -m                           - run MiniBatch Kmeans rather than conventional - use with very large datasets"
    )
    print(
        "    -n NUMCLUSTER                - set the number of clusters to NUMCLUSTER (default is 8)"
    )
    print(
        "    -b BATCHSIZE                 - use a batchsize of BATCHSIZE if doing MiniBatch - ignored if not.  Default is 1000"
    )
    print("    --dbscan                     - perform dbscan clustering")
    print("    --hdbscan                    - perform hdbscan clustering")
    print(
        "    -I INITIALIZATIONS           - Restart KMeans INITIALIZATIONS times to find best fit (default is 100)"
    )
    print("")
    print("  Preprocessing:")
    print(
        "    -F LOWERFREQ,UPPERFREQ       - filter data and regressors from LOWERFREQ to UPPERFREQ."
    )
    print(
        "                                   LOWERSTOP and UPPERSTOP can be specified, or will be calculated automatically"
    )
    print("    -V                           - filter data and regressors to VLF band")
    print("    -L                           - filter data and regressors to LFO band")
    print("    -R                           - filter data and regressors to respiratory band")
    print("    -C                           - filter data and regressors to cardiac band")
    print("    --nodetrend                  - do not detrend the data before correlation")
    print("    --nonorm                     - don't normalize timecourses")
    print("    --pctnorm                    - scale each timecourse to its percentage of the mean")
    print(
        "    --varnorm                    - scale each timecourse to have a variance of 1.0 (default)"
    )
    print(
        "    --stdnorm                    - scale each timecourse to have a standard deviation of 1.0"
    )
    print(
        "    --ppnorm                     - scale each timecourse to have a peak to peak range of 1.0"
    )
    print(
        "    --pca                        - perform PCA dimensionality reduction prior to analysis"
    )
    print(
        "    --ica                        - perform ICA dimensionality reduction prior to analysis"
    )
    print(
        "    -p NUMCOMPONENTS             - set the number of p/ica components to NUMCOMPONENTS (default is 8).  Set to -1 to estimate"
    )
    # --noscale is documented once; it disables the StandardScaler step
    print("    --noscale                    - do not apply standard scaler before cluster fitting")
    print("    --preproconly                - do preprocessing then quit")
    print(
        "    --minout=MINOUT              - transitions out of a state shorter than MINOUT will be patched.  Default is 1"
    )
    print(
        "    --minhold=MINHOLD            - time in a state shorter than MINHOLD will be assigned to the previous state.  Default is 1"
    )
    print("")
    print("  Other:")
    print(
        "    --GBR                        - apply gradient boosting regressor testing on clusters"
    )
    print("    -d                           - display some quality metrics")
    print("    --quality                    - perform a silhouette test to evaluate fit quality")
    print("    -v                           - turn on verbose mode")
    print(
        "    --modelroot=MODELROOT        - reread trained models from a previous run - MODELROOT should\n"
        "                                   be the outputfile from the previous run (i.e. what followed -o)"
    )
    print(
        "    --initialcenters=CENTERSFILE - reread cluster centers from a previous run - CENTERSFILE should\n"
        "                                   be the XXX_clustercenters.txt file from a previous run (i.e. what\n"
        "                                   followed -o).  This overrides the number of clusters given, and must\n"
        "                                   match the number of PCA or ICA components given (if used) or the number of features."
    )
    print("")
    return ()


# get the command line parameters
# Module-level defaults for every option; the option loop overrides them.
summaryonly = True  # --quality sets this False to run the full silhouette test

# preprocessing options
preprocessingtype = None  # None, "pca" or "ica"
detrendorder = 1  # only tested for > 0; --nodetrend sets 0
timenormmethod = "varnorm"

# clustering/partitioning options
minibatch = False
n_clusters = 8
n_pca = 8  # -p may turn this into a float fraction or a non-positive "estimate" request
max_iter = 250
n_init = 100
batch_size = 1000
clustertype = "kmeans"
connfilename = None
affinity = "euclidean"
linkage = "ward"
eps = 0.3
min_samples = 100
alpha = 1.0
standardscale = True
skippts = 0
minoutlength = 1
minholdlength = 1

duration = 100000000.0  # seconds; effectively "use everything"
starttime = 0.0
usebutterworthfilter = False
filtorder = 3
verbose = True  # on by default, so -v does not currently change anything

doGBR = False
display = False
preproconly = False

trainedmodelroot = None
initialcenters = None


# scan the command line arguments
try:
    opts, args = getopt.gnu_getopt(
        sys.argv[1:],
        "di:o:s:D:F:S:E:VLRCmn:p:b:I:v",
        [
            "infile=",
            "outfile=",
            "nodetrend",
            "dbscan",
            "hdbscan",
            "GBR",
            "pca",
            "ica",
            "noscale",
            "preproconly",
            "nonorm",
            "pctnorm",
            "varnorm",
            "stdnorm",
            "ppnorm",
            "skippts=",
            "minout=",
            "minhold=",
            "modelroot=",
            "initialcenters=",
            "quality",
            "samplefreq=",
            "sampletime=",
            "help",
        ],
    )
except getopt.GetoptError as err:
    # print help information and exit:
    print(str(err))  # will print something like "option -x not recognized"
    usage()
    sys.exit(2)

# every input is passed through a flag, so any positional argument is an error
# (previously a single stray argument was silently ignored)
if args:
    print("capfromtcs takes no unflagged arguments")
    print(args)
    sys.exit(2)

# unset all required arguments
infilename = []
segsize = -1
subsegs = []
subseggroupIDs = None
sampletime = None
Fs = None
outfilename = None

theprefilter = ccalc_filt.NoncausalFilter(transferfunc="butterworth")
theprefilter.setbutterorder(filtorder)

# set the default characteristics ("None" type is reported as no prefiltering)
theprefilter.settype("None")

# apply each parsed option, in command line order
for o, a in opts:
    if o == "--help":
        usage()
        sys.exit()
    elif o in ("--infile", "-i"):
        # may be given once (multicolumn file) or twice (two single-column files)
        infilename.append(a)
        if verbose:
            print("will use", infilename[-1], "as an input file")
    elif o in ("--outfile", "-o"):
        outfilename = a
        if verbose:
            print("will use", outfilename, "as output file")
    elif o == "-S":
        for seg in a.split(","):
            subsegs.append(int(seg))
        # the segment length is the sum of its subsegment lengths
        segsize = np.sum(np.asarray(subsegs))
        print("SUBSEGS:", subsegs)
        if verbose:
            print("Setting segment size to ", segsize)
    elif o == "-E":
        subseggroupIDs = a.split(",")
        print("SUBSEGGROUPIDS:", subseggroupIDs)
    elif o == "--samplefreq":
        Fs = float(a)
        sampletime = 1.0 / Fs
        linkchar = "="
        if verbose:
            print("Setting sample frequency to ", Fs)
    elif o == "--sampletime":
        sampletime = float(a)
        Fs = 1.0 / sampletime
        linkchar = "="
        if verbose:
            print("Setting sample time step to ", sampletime)
    elif o == "-d":
        # getopt reports the short option as "-d"; comparing against "-display"
        # made -d fall through to the "unhandled option" assertion
        display = True
        if verbose:
            print("will display quality metrics")
    elif o == "--preproconly":
        preproconly = True
        if verbose:
            print("only do preprocessing through PCA/ICA")
    elif o == "--noscale":
        standardscale = False
        if verbose:
            print("will not magnitude scale feature vectors")
    elif o == "--nonorm":
        timenormmethod = "none"
        if verbose:
            print("will do no normalization")
    elif o == "--pctnorm":
        timenormmethod = "pctnorm"
        if verbose:
            print("will do percent normalization")
    elif o == "--stdnorm":
        timenormmethod = "stdnorm"
        if verbose:
            print("will do std dev normalization")
    elif o == "--varnorm":
        timenormmethod = "varnorm"
        if verbose:
            print("will do variance normalization")
    elif o == "--ppnorm":
        timenormmethod = "ppnorm"
        if verbose:
            print("will do p-p normalization")
    elif o == "--modelroot":
        trainedmodelroot = a
        if verbose:
            print("will read trained models from " + trainedmodelroot + "_*.joblib")
    elif o == "--initialcenters":
        initialcenters = a
        if verbose:
            print("will read in the initial cluster centers from", initialcenters)
    elif o == "--minhold":
        minholdlength = int(a)
        if verbose:
            print("residency in a state shorter than", minholdlength, "will be patched")
    elif o == "--minout":
        minoutlength = int(a)
        if verbose:
            print(
                "transitions out of a state shorter than",
                minoutlength,
                "will be patched",
            )
    elif o == "--skippts":
        skippts = int(a)
        if verbose:
            print("will drop first", skippts, "points from each segment")
    elif o == "--quality":
        summaryonly = False
        if verbose:
            print("will do silhouette test")
    elif o == "-v":
        verbose = True
        if verbose:
            print("verbose mode enabled")
    elif o == "--ica":
        preprocessingtype = "ica"
        if verbose:
            print("will perform ica dimensionality reduction step")
    elif o == "--GBR":
        doGBR = True
        if verbose:
            print("will do GBR on clusters")
    elif o == "--pca":
        preprocessingtype = "pca"
        if verbose:
            print("will perform pca dimensionality reduction step")
    elif o == "--hdbscan":
        clustertype = "hdbscan"
        if not hdbpresent:
            print("hdbs is not installed, cannot perform hdbscan clustering.  Exiting")
            sys.exit()
        if verbose:
            print("switching to hdbscan clustering")
    elif o == "--dbscan":
        clustertype = "dbscan"
        if verbose:
            print("switching to dbscan clustering")
    elif o == "--nodetrend":
        detrendorder = 0
        if verbose:
            print("disabling detrending")
    elif o == "-D":
        duration = float(a)
        if verbose:
            print("duration set to", duration)
    elif o == "-s":
        starttime = float(a)
        if verbose:
            print("starttime set to", starttime)
    elif o == "-V":
        theprefilter.settype("vlf")
        if verbose:
            print("prefiltering to vlf band")
    elif o == "-L":
        theprefilter.settype("lfo")
        if verbose:
            print("prefiltering to lfo band")
    elif o == "-R":
        theprefilter.settype("resp")
        if verbose:
            print("prefiltering to respiratory band")
    elif o == "-C":
        theprefilter.settype("card")
        if verbose:
            print("prefiltering to cardiac band")
    elif o == "-F":
        # LOWER,UPPER (stops placed 10% outside the passband) or LOWER,UPPER,LOWERSTOP,UPPERSTOP
        arbvec = a.split(",")
        if len(arbvec) != 2 and len(arbvec) != 4:
            usage()
            sys.exit()
        if len(arbvec) == 2:
            arb_lower = float(arbvec[0])
            arb_upper = float(arbvec[1])
            arb_lowerstop = 0.9 * float(arbvec[0])
            arb_upperstop = 1.1 * float(arbvec[1])
        if len(arbvec) == 4:
            arb_lower = float(arbvec[0])
            arb_upper = float(arbvec[1])
            arb_lowerstop = float(arbvec[2])
            arb_upperstop = float(arbvec[3])
        theprefilter.settype("arb")
        theprefilter.setfreqs(arb_lowerstop, arb_lower, arb_upper, arb_upperstop)
        if verbose:
            print(
                "prefiltering to ",
                arb_lower,
                arb_upper,
                "(stops at ",
                arb_lowerstop,
                arb_upperstop,
                ")",
            )
    elif o == "-m":
        minibatch = True
        print("will perform MiniBatchKMeans")
    elif o == "-b":
        batch_size = int(a)
        print("will use", batch_size, "as batch_size")
    elif o == "-I":
        n_init = int(a)
        print("will do", n_init, "initializations")
    elif o == "-n":
        n_clusters = int(a)
        print("will use", n_clusters, "clusters")
    elif o == "-p":
        # <= 0: estimate the count; (0, 1): variance fraction to explain; >= 1: component count
        n_pca = float(a)
        if n_pca <= 0.0:
            print("will estimate the number of pca components for dimensionality reduction")
        elif n_pca < 1.0:
            print(
                "will use enough pca components to explain at least",
                100.0 * n_pca,
                "% of the variance",
            )
        else:
            # convert the parsed float so inputs like "8.0" do not crash int()
            n_pca = int(n_pca)
            print("will use", n_pca, "pca components for dimensionality reduction")
    else:
        assert False, "unhandled option"

# check that required arguments are set
if outfilename is None:
    print("outfile must be set")
    usage()
    sys.exit()

# check to make sure groups and subsegments are in agreement, and index the
# subsegments by group ID: groups[ID] holds parallel lists of each member's
# length ("seglen"), offset within the segment ("segstart") and index ("segnum")
groups = {}
print("subsegs:", subsegs)
if subseggroupIDs is None:
    # no -E given: every subsegment is its own group
    subseggroupIDs = ["group_" + str(i) for i in range(len(subsegs))]
print("subseggroupIDs:", subseggroupIDs)
if len(subsegs) != len(subseggroupIDs):
    print("number of subsegment group IDs must match number of subsegs")
    sys.exit()
segstart = 0  # running offset of the current subsegment within its segment
for segnum, (groupID, seglen) in enumerate(zip(subseggroupIDs, subsegs)):
    thegroup = groups.setdefault(groupID, {"seglen": [], "segstart": [], "segnum": []})
    thegroup["segnum"].append(segnum)
    thegroup["seglen"].append(seglen)
    thegroup["segstart"].append(segstart)
    segstart += seglen
for key in groups:
    # group summaries are averaged point by point, so member lengths must agree
    groupsegs = groups[key]["seglen"]
    if not all(x == groupsegs[0] for x in groupsegs):
        print("all subsegments in a group must have the same length")
        sys.exit()
print(groups)

if sampletime is None:
    print("sampletime must be set")
    usage()
    sys.exit()

# Optionally load cluster centers from a previous run.  After the transpose,
# rows are clusters and columns are features, so the file fixes both counts.
if initialcenters is not None:
    theclustercenters = np.transpose(ccalc_io.readvecs(initialcenters))
    ccentershape = theclustercenters.shape
    print("clustercenter shape:", ccentershape)
    n_clusters, targetfeatures = ccentershape[0], ccentershape[1]
    print("will use", n_clusters, "clusters and", targetfeatures, "features")

# Report the timecourse normalization that will be applied (unknown names print nothing).
timenormdescriptions = {
    "none": "will not normalize timecourses",
    "pctnorm": "will normalize timecourses to percentage of mean",
    "stdnorm": "will normalize timecourses to standard deviation of 1.0",
    "varnorm": "will normalize timecourses to variance of 1.0",
    "ppnorm": "will normalize timecourses to p-p deviation of 1.0",
}
if timenormmethod in timenormdescriptions:
    print(timenormdescriptions[timenormmethod])

# save the command line
ccalc_io.writevec([" ".join(sys.argv)], outfilename + "_commandline.txt")

# read in the files and get everything trimmed to the right length.
# startpoint already includes skippts, so the point skip is applied exactly once,
# along the time axis.
startpoint = max([int(starttime * Fs), 0]) + skippts
if len(infilename) == 1:
    # each column of the file is a timecourse; as used below, the array returned
    # by readvecs has timecourses (features) on axis 0 and timepoints on axis 1.
    # (The old "[skippts:]" slice dropped features, not timepoints.)
    print("processing single input file")
    matrixoutput = True
    inputdata = ccalc_io.readvecs(infilename[0])
    if verbose:
        print("input data shape is ", inputdata.shape)
    numpoints = inputdata.shape[1]
    endpoint = min([startpoint + int(duration * Fs), numpoints])
    trimmeddata = inputdata[:, startpoint:endpoint]
elif len(infilename) == 2:
    # two single-column files: startpoint (starttime + skippts) offsets the first
    # file only, the second is read from its beginning, and both are cut to a
    # common length so they fit in one (2, npoints) array
    print("processing two input files")
    inputdata1 = ccalc_io.readvec(infilename[0])
    inputdata2 = ccalc_io.readvec(infilename[1])
    numpoints = len(inputdata1)
    trimlen = max(0, min(int(duration * Fs), len(inputdata1) - startpoint, len(inputdata2)))
    trimmeddata = np.zeros((2, trimlen), dtype="float")
    trimmeddata[0, :] = inputdata1[startpoint : startpoint + trimlen]
    trimmeddata[1, :] = inputdata2[0:trimlen]
    # stack the raw inputs at their common length so the original dimensions can
    # be reported later (inputdata was previously undefined on this path)
    commonlen = min(len(inputdata1), len(inputdata2))
    inputdata = np.vstack((inputdata1[:commonlen], inputdata2[:commonlen]))
else:
    print(
        "capfromtcs requires 1 multicolumn timecourse file or two single column timecourse files as input"
    )
    usage()
    sys.exit()

# Report the selected prefilter band.  Nothing is filtered here: the filter is
# applied to each subsegment inside the preprocessing loop.
if theprefilter.gettype() == "None":
    if verbose:
        print("no prefiltering applied")
elif verbose:
    print("filtering to ", theprefilter.gettype(), " band")

# Record the data dimensions.  Without -S the whole dataset is one segment made
# of a single subsegment.  The segments must tile the samples exactly.
origdims = inputdata.shape
thedims = trimmeddata.shape
print(f"original file dimensions: {origdims}")
print(f"trimmed file dimensions: {thedims}")
n_features, n_samples = thedims[0], thedims[1]
if segsize < 0:
    segsize = n_samples
    subsegs.append(segsize)
print(
    f"input dataset has {n_features} features and {n_samples} samples in segments of size {segsize}"
)
if len(subsegs) > 1:
    print(f"    segment is broken into {len(subsegs)} subsegments of length {subsegs}")
reformdata = np.reshape(trimmeddata, (n_features, n_samples))
if n_samples % segsize > 0:
    print(
        f"segment size ( {segsize} ) is not an even divisor of the total length ( {n_samples} )- exiting"
    )
    sys.exit()
numsegs = int(n_samples // segsize)

# Map each normalization name to a callable.  The lambdas defer the ccalc_math
# attribute lookups until a normalizer is actually used; unknown names pass the
# data through unchanged.
normalizers = {
    "none": lambda thedata: thedata - np.mean(thedata),
    "pctnorm": lambda thedata: ccalc_math.pcnormalize(thedata),
    "varnorm": lambda thedata: ccalc_math.varnormalize(thedata),
    "stdnorm": lambda thedata: ccalc_math.stdnormalize(thedata),
    "ppnorm": lambda thedata: ccalc_math.ppnormalize(thedata),
}
thenormalizer = normalizers.get(timenormmethod, lambda thedata: thedata)

# For every feature, segment and subsegment: optionally detrend, normalize,
# prefilter, and write the result back into reformdata in place.
for feature in range(n_features):
    if verbose:
        print("preprocessing feature", feature)
    for segment in range(numsegs):
        subsegstart = segment * segsize
        for subseglen in subsegs:
            segdata = reformdata[feature, subsegstart : subsegstart + subseglen]
            if detrendorder > 0:
                segdata = ccalc_fit.detrend(segdata)
            segnorm = thenormalizer(segdata)
            reformdata[feature, subsegstart : subsegstart + subseglen] = theprefilter.apply(
                Fs, segnorm
            )
            subsegstart += subseglen

# samples become rows and features columns, as sklearn expects; NaN/inf are replaced
X = np.nan_to_num(np.transpose(reformdata))

if standardscale:
    X = StandardScaler().fit_transform(X)

# Optional dimensionality reduction: fit (or reload) a PCA or ICA model, then
# replace X with its reconstruction from the retained components.
if preprocessingtype == "pca":
    print("running PCA")
    print("shape going in:", X.shape)
    if trainedmodelroot is None:
        if n_pca <= 0:
            # let sklearn estimate the dimensionality (the "mle" option needs the full solver)
            thepca = PCA(n_components="mle", svd_solver="full").fit(X)
        elif n_pca < 1:
            # a fraction in (0, 1) keeps enough components to explain that share of
            # the variance; sklearn documents this for the full solver
            thepca = PCA(n_components=n_pca, svd_solver="full").fit(X)
        else:
            thepca = PCA(n_components=n_pca).fit(X)

        # save the model
        joblib.dump(thepca, outfilename + "_pca.joblib")
    else:
        modelfilename = trainedmodelroot + "_pca.joblib"
        print("reading PCA from", modelfilename)
        try:
            thepca = joblib.load(modelfilename)
        except Exception as ex:
            template = (
                "An exception of type {0} occurred when trying to open {1}. Arguments:\n{2!r}"
            )
            message = template.format(type(ex).__name__, modelfilename, ex.args)
            print(message)
            sys.exit()

    thetransform = thepca.transform(X)
    X = thepca.inverse_transform(thetransform)
    print("shape coming out:", X.shape)
    for i in range(thepca.n_components_):
        print(
            "component",
            i,
            "explained variance:",
            thepca.explained_variance_[i],
            "explained variance %:",
            100.0 * thepca.explained_variance_ratio_[i],
        )
    ccalc_io.writenpvecs(thepca.components_, outfilename + "_pcacomponents.txt")
    ccalc_io.writenpvecs(
        np.transpose(thepca.components_), outfilename + "_pcacomponents_transpose.txt"
    )
elif preprocessingtype == "ica":
    print("running FastICA")
    # FastICA needs an integer component count.  "Estimate" (<= 0) and
    # variance-fraction (< 1) requests fall back to None, meaning all components.
    # Previously these (and an explicit 1) became n_components=0.
    n_ica = int(n_pca) if n_pca >= 1 else None
    if trainedmodelroot is None:
        theica = FastICA(n_components=n_ica, algorithm="deflation").fit(X)

        # save the model
        joblib.dump(theica, outfilename + "_ica.joblib")
    else:
        modelfilename = trainedmodelroot + "_ica.joblib"
        print("reading ICA from", modelfilename)
        try:
            theica = joblib.load(modelfilename)
        except Exception as ex:
            template = (
                "An exception of type {0} occurred when trying to open {1}. Arguments:\n{2!r}"
            )
            message = template.format(type(ex).__name__, modelfilename, ex.args)
            print(message)
            sys.exit()

    thetransform = theica.transform(X)
    X = theica.inverse_transform(thetransform)
    ccalc_io.writenpvecs(theica.components_, outfilename + "_icacomponents.txt")
    ccalc_io.writenpvecs(
        np.transpose(theica.components_), outfilename + "_icacomponents_transpose.txt"
    )

# NOTE(review): this writes reformdata (detrended/normalized/filtered), not X, so
# standard scaling and the PCA/ICA reconstruction are not included - confirm intent.
ccalc_io.writenpvecs(reformdata, outfilename + "_preprocessed.txt")
if preproconly:
    print("preprocessing done - quitting")
    sys.exit()

if clustertype == "kmeans":
    print("setting up kmeans")
    if trainedmodelroot is None:
        if initialcenters is None:
            theinit = "k-means++"
        else:
            theinit = theclustercenters
            max_iter = 1

        print("training model")
        if minibatch:
            kmeans = MiniBatchKMeans(
                n_clusters=n_clusters, batch_size=batch_size, max_iter=max_iter, init=theinit
            ).fit(X)
        else:
            kmeans = KMeans(
                n_clusters=n_clusters, max_iter=max_iter, n_init=n_init, init=theinit
            ).fit(X)

        # save the model
        joblib.dump(kmeans, outfilename + "_kmeans.joblib")
    else:
        modelfilename = trainedmodelroot + "_kmeans.joblib"
        print("reading kmeans model from", modelfilename)
        try:
            kmeans = joblib.load(modelfilename)
        except Exception as ex:
            template = (
                "An exception of type {0} occurred when trying to open {1}. Arguments:\n{2!r}"
            )
            message = template.format(type(ex).__name__, modelfilename, ex.args)
            print(message)
            sys.exit()

    theclusters = np.transpose(kmeans.cluster_centers_)
    thestatelabels = kmeans.predict(X)
    # thestatelabels = kmeans.labels_
    print("thestatelabels shape", thestatelabels.shape)
    print("kmeans done")
    ccalc_io.writenpvecs(theclusters, outfilename + "_clustercenters.txt")

    # make normalized clusters
    thenormclusters = theclusters * 0.0
    themeans = np.mean(theclusters, axis=0)
    thestds = np.std(theclusters, axis=0)
    print("themeans:", themeans)
    print("thestds:", thestds)
    print("shape:", theclusters.shape)
    for i in range(theclusters.shape[1]):
        thenormclusters[:, i] = (theclusters[:, i] - themeans[i]) / thestds[i]
    ccalc_io.writenpvecs(thenormclusters, outfilename + "_norm_clustercenters.txt")

    # save the states
    ccalc_io.writenpvecs(thestatelabels, outfilename + "_statelabels.txt")

    # find most important features
    print("finding most important features")
    # rfe = RFE(kmeans, 10)
    # rfe.fit(X, thestatelabels)
    # print(rfe.support_)
    # print(rfe.ranking_)

    print(
        "calling SelectPercentiles with X and y of dimensions",
        X.shape,
        thestatelabels.shape,
    )
    selector = SelectPercentile(f_classif, percentile=10)
    selector.fit(X, thestatelabels)
    print(selector.get_params())
    X_indices = np.arange(X.shape[-1])
    scores = -np.nan_to_num(np.log10(np.nan_to_num(selector.pvalues_)))
    scores /= scores.max()
    sortedscores = np.sort(np.nan_to_num(selector.scores_))[::-1]
    print(sortedscores)
    if display:
        plt.bar(
            X_indices - 0.45,
            scores,
            width=0.2,
            label=r"Univariate score ($-Log(p_{value})$)",
            color="darkorange",
        )
        print(selector.get_support(indices=True))
        fig = plt.subplots(1, 1)
        plt.plot(sortedscores)
        plt.show()

    # now do some stats!
    thesilavgs, thesilclusterstats = capcalc_utils.silhouette_test(
        X, kmeans, n_clusters, numsegs, segsize, summaryonly
    )
    ccalc_io.writenpvecs(thesilavgs, outfilename + "_silhouettesegmentstats.txt")

    silinfo = []
    for state in range(n_clusters):
        silinfo.append([])
    print("shape going in:", thestatelabels.shape)
    statelabelsbysegment = np.reshape(thestatelabels, (-1, segsize))
    print("shape coming out:", statelabelsbysegment.shape)
    meaninstate = np.zeros((n_clusters, segsize), dtype="float")
    stdinstate = np.zeros((n_clusters, segsize), dtype="float")

    # do the subsegment summaries
    for key in groups:
        groups[key]["meaninstate"] = np.zeros(
            (n_clusters, groups[key]["seglen"][0]), dtype="float"
        )
        groups[key]["stdinstate"] = np.zeros((n_clusters, groups[key]["seglen"][0]), dtype="float")
        for state in range(n_clusters):
            tcbyseg = []
            for seginstance in range(len(groups[key]["segnum"])):
                startpos = groups[key]["segstart"][seginstance]
                endpos = startpos + groups[key]["seglen"][seginstance]
                tcbyseg.append(np.where(statelabelsbysegment[:, startpos:endpos] == state, 1, 0))
            groups[key]["meaninstate"][state, :] = np.mean(np.concatenate(tcbyseg, axis=0), axis=0)
            groups[key]["stdinstate"][state, :] = np.std(np.concatenate(tcbyseg, axis=0), axis=0)
        ccalc_io.writenpvecs(
            groups[key]["meaninstate"],
            outfilename + "_" + str(key) + "_meaninstate.txt",
        )
        ccalc_io.writenpvecs(
            groups[key]["stdinstate"], outfilename + "_" + str(key) + "_stdinstate.txt"
        )
    # Per-segment analysis.  The label sequence is cut into numsegs segments of
    # segsize samples.  For each segment, the state transition counts and the
    # run-length statistics are computed and written out (NIfTI + CSV).  The
    # results are also accumulated for the pooled summary computed after the loop.
    allstatestats = []  # statestats array from each segment
    allrawtransmats = []  # raw transition-count matrix (as float) from each segment
    alllenlists = [[] for _ in range(n_clusters)]  # run lengths per state, pooled over segments

    # Loop-invariant setup, hoisted out of the segment loop.  outputaffine is
    # defined here, rather than inside the loop, so that it also exists for the
    # overall outputs written after the loop.
    outputaffine = np.eye(4)
    rows = ["from state " + str(i + 1) for i in range(n_clusters)]
    transcols = ["to state " + str(i + 1) for i in range(n_clusters)]
    statcols = [
        "% TRs in state",
        "Number of runs in state",
        "Total TRs in state",
        "Min run (TRs)",
        "Max run (TRs)",
        "Mean run (TRs)",
        "Median run (TRs)",
        "StdDev run (TRs)",
    ]
    timestatcols = [
        "% Seconds in state",
        "Number of runs in state",
        "Total seconds in state",
        "Min run (sec)",
        "Max run (sec)",
        "Mean run (sec)",
        "Median run (sec)",
        "StdDev run (sec)",
    ]
    silcols = ["Mean", "Median", "Min", "Max"]

    for segment in range(numsegs):
        segprefix = outfilename + "_seg_" + str(segment).zfill(4)
        thesestatelabels = thestatelabels[segment * segsize : (segment + 1) * segsize]

        rawtransmat, thestats, lenlist = capcalc_utils.statestats(
            thesestatelabels, n_clusters, 0, minout=minoutlength, minhold=minholdlength
        )
        allrawtransmats.append(rawtransmat * 1.0)
        allstatestats.append(thestats)
        for i in range(n_clusters):
            alllenlists[i] += lenlist[i]
        normtransmat, offdiagtransmat = capcalc_utils.calcmats(rawtransmat, n_clusters)

        # Transition matrices are n_clusters x n_clusters.  In the CSV files,
        # row i is "from state i+1" and column j is "to state j+1".  rawtransmat
        # holds transition counts.  normtransmat and offdiagtransmat come from
        # capcalc_utils.calcmats; presumably they are a normalized version and
        # an off-diagonal-only version of the counts (TODO confirm).  The NIfTI
        # copies are written transposed.
        transmats = (
            ("rawtransmat", rawtransmat),
            ("normtransmat", normtransmat),
            ("offdiagtransmat", offdiagtransmat),
        )
        init_hdr = nib.Nifti1Image(normtransmat, outputaffine).header
        for matname, themat in transmats:
            ccalc_io.savetonifti(np.transpose(themat), init_hdr, segprefix + "_" + matname)
        for matname, themat in transmats:
            df = pd.DataFrame(data=themat, columns=transcols)
            df.insert(0, "sources", pd.Series(rows))
            df.to_csv(segprefix + "_" + matname + ".csv", index=False)

        # Run-length statistics per state, in TRs and in seconds.  Column 0 (a
        # percentage) and column 1 (a run count) have no time unit.  The
        # remaining columns are durations and are scaled by sampletime.
        df = pd.DataFrame(data=thestats, columns=statcols)
        df.to_csv(segprefix + "_statestats.csv", index=False)
        thetimestats = 1.0 * thestats
        thetimestats[:, 2:] *= sampletime
        df = pd.DataFrame(data=thetimestats, columns=timestatcols)
        df.to_csv(segprefix + "_statetimestats.csv", index=False)

        ccalc_io.writenpvecs(thesestatelabels, segprefix + "_statelabels.txt")
        print("Segment %d average silhouette Coefficient: %0.3f" % (segment, thesilavgs[segment]))

        # one 0/1 occupancy timecourse per state
        for state in range(n_clusters):
            tc = np.where(thesestatelabels == state, 1, 0)
            ccalc_io.writenpvecs(tc, segprefix + "_instate_" + str(state).zfill(2) + ".txt")
        if not summaryonly:
            df = pd.DataFrame(
                data=np.transpose(thesilclusterstats[segment, :, :]), columns=silcols
            )
            df.to_csv(segprefix + "_silhouetteclusterstats.csv", index=False)

        # For the overall average, collect each occupied state's silhouette
        # value: index 0 of the cluster stats, which is the "Mean" column.
        for state in range(n_clusters):
            if thestats[state, 2] > 0:
                silinfo[state].append(thesilclusterstats[segment, 0, state])

    # now generate some summary information
    # Pool the run lengths from every segment and summarize them per state.
    # A row is written for every state (all zeros when the state had no runs),
    # so row i always describes state i, as in the per-segment statestats
    # arrays.  In the time table, only the duration columns are scaled by
    # sampletime.  The percentage and the "Number of runs" count are not.
    lenarrays = [np.asarray(alllenlists[i], dtype="float") for i in range(n_clusters)]
    alllens = sum(np.sum(lenarray) for lenarray in lenarrays)

    overallstatestats = []
    thetimestats = []
    for lenarray in lenarrays:
        if len(lenarray) == 0:
            # empty state: np.min/np.max would raise on an empty array
            overallstatestats.append([0.0, 0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
            thetimestats.append([0.0, 0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
            continue
        pctinstate = 100.0 * np.sum(lenarray) / alllens if alllens > 0 else 0.0
        overallstatestats.append(
            [
                pctinstate,
                len(lenarray),
                np.sum(lenarray),
                np.min(lenarray),
                np.max(lenarray),
                np.mean(lenarray),
                np.median(lenarray),
                np.std(lenarray),
            ]
        )
        thetimestats.append(
            [
                pctinstate,
                len(lenarray),
                sampletime * np.sum(lenarray),
                sampletime * np.min(lenarray),
                sampletime * np.max(lenarray),
                sampletime * np.mean(lenarray),
                sampletime * np.median(lenarray),
                sampletime * np.std(lenarray),
            ]
        )

    cols = [
        "% TRs in state",
        "Number of runs in state",
        "Total TRs in state",
        "Min run (TRs)",
        "Max run (TRs)",
        "Mean run (TRs)",
        "Median run (TRs)",
        "StdDev run (TRs)",
    ]
    df = pd.DataFrame(data=overallstatestats, columns=cols)
    df.to_csv(
        outfilename + "_overall_statestats.csv",
        index=False,
    )

    cols = [
        "% Seconds in state",
        "Number of runs in state",
        "Total seconds in state",
        "Min run (sec)",
        "Max run (sec)",
        "Mean run (sec)",
        "Median run (sec)",
        "StdDev run (sec)",
    ]
    df = pd.DataFrame(data=thetimestats, columns=cols)
    df.to_csv(
        outfilename + "_overall_statetimestats.csv",
        index=False,
    )
    # Pool the raw transition counts of all segments.  Then derive the overall
    # normalized and off-diagonal matrices from the pooled counts, not by
    # averaging the per-segment normalized matrices.
    overallrawtransmat = np.sum(allrawtransmats, axis=0)
    overallnormtransmat, overalloffdiagtransmat = capcalc_utils.calcmats(
        overallrawtransmat, n_clusters
    )
    # The header template only needs an image of the right shape.  It uses an
    # identity affine, the same as the per-segment outputs.
    init_hdr = nib.Nifti1Image(overallnormtransmat, np.eye(4)).header
    for suffix, themat in (
        ("rawtransmat", overallrawtransmat),
        ("normtransmat", overallnormtransmat),
        ("offdiagtransmat", overalloffdiagtransmat),
    ):
        # matrices are written transposed, as in the per-segment NIfTI files
        ccalc_io.savetonifti(np.transpose(themat), init_hdr, outfilename + "_overall_" + suffix)
    # One run-length histogram per state.  All histograms share the range
    # [1, longest run seen in any state].  States with no runs are skipped when
    # finding the longest run, because np.max raises on an empty sequence.
    themaxlen = 0
    for lens in alllenlists:
        if len(lens) > 0:
            themaxlen = int(max(themaxlen, np.max(lens)))
    for i in range(n_clusters):
        # Presumably themaxlen is the bin count here, giving one bin per run
        # length.  NOTE(review): confirm against ccalc_stats, including its
        # behavior for an empty array (a state with no runs).
        ccalc_stats.makeandsavehistogram(
            np.array(alllenlists[i]),
            themaxlen,
            0,
            outfilename + "_" + str(i).zfill(2) + "_lenhist",
            therange=[1, themaxlen],
        )
    # Mean silhouette value per state, taken over the per-segment values that
    # were collected in silinfo.  Computed and written only when detailed
    # (non-summary) output is requested.
    silavgs = []
    if not summaryonly:
        silavgs = [
            np.mean(np.asarray(silinfo[state], dtype="float")) for state in range(n_clusters)
        ]
        silavgarray = np.asarray(silavgs, dtype="float")
        ccalc_io.writenpvecs(silavgarray, outfilename + "_overallsilhouettemean.txt")
    # Per-state statistics averaged element-wise across segments (the mean
    # over axis 0 of the stacked per-segment statestats arrays).  This is an
    # overall summary, so the file name carries no segment number.
    pctarray = np.asarray(allstatestats, dtype="float")
    cols = [
        "% TRs in state",
        "Number of runs in state",
        "Total TRs in state",
        "Min run (TRs)",
        "Max run (TRs)",
        "Mean run (TRs)",
        "Median run (TRs)",
        "StdDev run (TRs)",
    ]
    df = pd.DataFrame(data=np.mean(pctarray, axis=0), columns=cols)
    df.to_csv(
        outfilename + "_overallmeanstats.csv",
        index=False,
    )

    if doGBR:
        # Fit a gradient-boosting regressor that predicts the state label from
        # the feature matrix.  Report its training score, and save each
        # feature's importance as a single column.
        gbrmodel = GradientBoostingRegressor()
        gbrmodel.fit(X, thestatelabels)
        print("GBR fitting score is:", gbrmodel.score(X, thestatelabels))
        importances = np.reshape(gbrmodel.feature_importances_, (n_features, 1))
        ccalc_io.writenpvecs(importances, outfilename + "_featureimportances.txt")

elif clustertype == "dbscan":
    if trainedmodelroot is None:
        db = DBSCAN(eps=eps, min_samples=min_samples, n_jobs=-1).fit(X)

        # save the model
        joblib.dump(db, outfilename + "_dbscan.joblib")
    else:
        modelfilename = trainedmodelroot + "_dbscan.joblib"
        print("reading dbscan model from", modelfilename)
        try:
            db = joblib.load(modelfilename)
        except Exception as ex:
            template = (
                "An exception of type {0} occurred when trying to open {1}. Arguments:\n{2!r}"
            )
            message = template.format(type(ex).__name__, modelfilename, ex.args)
            print(message)
            sys.exit()

        db.predict(X)

    print("dbscan done")

    # core_samples_mask = np.zeros_like(db.labels_, dtype=bool)
    # core_samples_mask[db.core_sample_indices_] = True

    thestatelabels = db.labels_
    print(thestatelabels)
    print("thestatelabels shape", thestatelabels.shape)
    ccalc_io.writenpvecs(thestatelabels, outfilename + "_statelabels.txt")

    print("core_sample_indices:", db.core_sample_indices_)
    core_centers = np.transpose(X[db.core_sample_indices_, :])
    ccalc_io.writenpvecs(core_centers, outfilename + "_core_centers.txt")

    # Number of clusters in labels, ignoring noise if present.
    n_clusters_ = len(set(thestatelabels)) - (1 if -1 in thestatelabels else 0)
    print("Estimated number of clusters: %d" % n_clusters_)

    methodname = "dbscan_" + str(n_clusters).zfill(2)

elif clustertype == "hdbscan":
    if trainedmodelroot is None:
        hdb = hdbs.HDBSCAN(
            min_samples=min_samples,
            alpha=alpha,
            memory="/Users/frederic/Documents/MR_data/connectome/movies",
        ).fit(X)

        # save the model
        joblib.dump(hdb, outfilename + "_hdbscan.joblib")
    else:
        modelfilename = trainedmodelroot + "_hdbscan.joblib"
        print("reading hdbscan model from", modelfilename)
        try:
            hdb = joblib.load(modelfilename)
        except Exception as ex:
            template = (
                "An exception of type {0} occurred when trying to open {1}. Arguments:\n{2!r}"
            )
            message = template.format(type(ex).__name__, modelfilename, ex.args)
            print(message)
            sys.exit()

        hdb.predict(X)

    thestatelabels = hdb.labels_
    print(thestatelabels)

    # Number of clusters in labels, ignoring noise if present.
    n_clusters_ = len(set(thestatelabels)) - (1 if -1 in thestatelabels else 0)

    print("Estimated number of clusters: %d" % n_clusters_)
    methodname = "hdbscan_" + str(n_clusters).zfill(2)

else:
    print("unknown clustering type")
    sys.exit()
