#!python
#
#   Copyright 2016-2019 Blaise Frederick
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
#
#
#       $Author: frederic $
#       $Date: 2016/06/14 12:04:51 $
#       $Id: showstxcorr,v 1.11 2016/06/14 12:04:51 frederic Exp $
#
import argparse
import sys

import joblib
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
from sklearn import metrics
from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.decomposition import PCA, FastICA, IncrementalPCA
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.feature_selection import SelectPercentile, f_classif
from sklearn.preprocessing import StandardScaler

import capcalc.filter as ccalc_filt
import capcalc.fit as ccalc_fit
import capcalc.io as ccalc_io
import capcalc.miscmath as ccalc_math
import capcalc.parser_funcs as pf
import capcalc.stats as ccalc_stats
import capcalc.utils as ccalc_utils
from capcalc.niftidecomp import niftidecomp_workflow

# Default values for the command line options defined in _get_parser().
DEFAULT_BATCH_SIZE = 100  # batch size for batched k-means clustering (--batch_size)
DEFAULT_MINOUTLENGTH = 1  # transitions out of a state shorter than this are patched (--minout)
DEFAULT_MINHOLDLENGTH = 1  # residencies in a state shorter than this are patched (--minhold)
DEFAULT_NCLUSTERS = 8  # number of clusters to find (--n_clusters)
DEFAULT_NPCA = 8  # presumably the default number of PCA components to retain
DEFAULT_MAXITER = 250  # number of iterations for cluster fitting (--max_iter)
DEFAULT_NINIT = 100  # number of starting states to try for k-means (--n_init)


def niftitotimecourse(
    infilename,
    datamaskname,
    outputroot,
    n_pca,
    theprefilter,
    normmethod="None",
    trainedmodelroot=None,
    skippts=0,
    starttime=0.0,
    duration=100000000.0,
    sigma=0.0,
    maskthresh=0.25,
    segmentnorm=True,
    debug=False,
):
    """Decompose NIFTI input with PCA and return the trimmed coefficient timecourses.

    The mask file is read first (it is mandatory), then ``niftidecomp_workflow``
    is run with ``decomptype="pca"``.  The explained variance, the component
    images, the coefficients, the prenormalization factors and the premeans
    are written to files whose names start with ``outputroot``.  Finally the
    coefficient array is trimmed along its second axis.

    Parameters
    ----------
    infilename
        Input file specification, handed unchanged to ``niftidecomp_workflow``.
    datamaskname
        Name of the mask NIFTI file.  When None, a message is printed and the
        program exits.
    outputroot : str
        Root of every output file name.
    n_pca
        Handed to ``niftidecomp_workflow`` as ``pcacomponents``.
    theprefilter
        Filter object handed to ``niftidecomp_workflow``.
    normmethod, trainedmodelroot, sigma, maskthresh, segmentnorm
        Handed unchanged to ``niftidecomp_workflow``.
    skippts : int
        Number of points added to the first index of the retained window.
    starttime, duration : float
        Start and length of the retained window; each is multiplied by ``Fs``
        to obtain a point count (presumably both are in seconds).
    debug : bool
        When True, print array shapes and progress messages.

    Returns
    -------
    tuple
        ``(trimmeddata, outputcomponents, Fs, prenormfacs, premeans, theheader)``
        where ``Fs`` is ``1.0 / datafilesizes[4]`` and ``theheader`` is a copy
        of the data file header whose ``dim[4]`` is set to
        ``outputcomponents.shape[3]``.
    """
    # NIFTI input cannot be processed without a mask, so give up right away
    if datamaskname is None:
        print("using nifti files requires the use of a mask file")
        sys.exit()

    # only the mask data and its dimensions are needed from the mask file
    _, datamask_data, _, datamaskdims, _ = ccalc_io.readfromnifti(datamaskname)
    nx, ny, nz, _ = ccalc_io.parseniftidims(datamaskdims)
    numspatialpoints = int(nx) * int(ny) * int(nz)
    print(f"number of spatial points from mask = {numspatialpoints}")
    # NOTE(review): the flattened mask and the above-threshold voxel indices
    # are computed here but are not used again in this function.
    flatmask = datamask_data.reshape((numspatialpoints))
    _ = np.where(flatmask > maskthresh)

    # Load the data and reduce its dimensionality with PCA
    decompresults = niftidecomp_workflow(
        infilename,
        outputroot,
        datamaskname=datamaskname,
        decomptype="pca",
        pcacomponents=n_pca,
        icacomponents=None,
        trainedmodelroot=trainedmodelroot,
        normmethod=normmethod,
        demean=True,
        theprefilter=theprefilter,
        sigma=sigma,
        maskthresh=maskthresh,
        segmentnorm=segmentnorm,
    )
    (
        outputcomponents,
        inputdata,
        outinvtrans,
        exp_var_pct,
        datafile_hdr,
        datafiledims,
        datafilesizes,
        prenormfacs,
        premeans,
    ) = decompresults
    Fs = 1.0 / datafilesizes[4]

    if debug:
        shapereport = (
            ("outputcomponents", outputcomponents),
            ("inputdata", inputdata),
            ("outinvtrans", outinvtrans),
            ("exp_var_pct", exp_var_pct),
            ("datafiledims", datafiledims),
            ("datafilesizes", datafilesizes),
        )
        for label, thearray in shapereport:
            print(f"{label}.shape={thearray.shape!r}")
        print("variance explained by components:", exp_var_pct)

    # save the eigenvalues
    ccalc_io.writenpvecs(exp_var_pct, f"{outputroot}_explained_variance_pct.txt")

    # save the component images, one volume per component
    if debug:
        print("writing component images")
    theheader = datafile_hdr.copy()
    theheader["dim"][4] = outputcomponents.shape[3]
    ccalc_io.savetonifti(outputcomponents, theheader, f"{outputroot}_components")

    print("writing out the coefficients")
    ccalc_io.writenpvecs(inputdata, f"{outputroot}_coefficients.txt")
    if debug:
        print("input data shape is ", inputdata.shape)
    ccalc_io.writenpvecs(prenormfacs, f"{outputroot}_prenormfacs.txt")
    ccalc_io.writenpvecs(premeans, f"{outputroot}_premeans.txt")

    # trim the coefficient timecourses to the requested window
    numpoints = inputdata.shape[1]
    startpoint = max(int(starttime * Fs), 0) + skippts
    endpoint = min(startpoint + int(duration * Fs), numpoints)
    return (
        inputdata[:, startpoint:endpoint],
        outputcomponents,
        Fs,
        prenormfacs,
        premeans,
        theheader,
    )


def readtcsfromtext(
    infilename,
    Fs,
    outputroot,
    n_pca,
    theprefilter,
    normmethod="None",
    trainedmodelroot=None,
    skippts=0,
    starttime=0.0,
    duration=100000000.0,
    sigma=0.0,
    maskthresh=0.25,
    debug=False,
):
    """Read timecourses from one or two text files and trim them in time.

    Parameters
    ----------
    infilename : list of str
        Either one multicolumn file (read with ``ccalc_io.readvecs``) or two
        single column files (each read with ``ccalc_io.readvec``).  Any other
        count prints a message and exits.
    Fs : float
        Multiplier that converts ``starttime`` and ``duration`` into point
        counts (presumably the sample rate in Hz).
    skippts : int
        Number of points added to the first retained index.
    starttime : float
        Start of the retained window.  With two files it only offsets the
        first file; the second file is always read from its first point.
    duration : float
        Length of the retained window.
    debug : bool
        When True, print the shape of the data that was read.
    outputroot, n_pca, theprefilter, normmethod, trainedmodelroot, sigma, maskthresh
        Not used by this function.

    Returns
    -------
    numpy.ndarray
        Two dimensional array whose second axis is the trimmed time axis.
        With two input files it has two rows of identical length.
    """
    # skippts is applied exactly once, as part of the starting index
    startpoint = max(int(starttime * Fs), 0) + skippts
    maxpoints = int(duration * Fs)
    if len(infilename) == 1:
        print("processing single input file")
        # The array is indexed below with time on its second axis, so the
        # points to skip must not be sliced off the first axis here.
        inputdata = ccalc_io.readvecs(infilename[0])
        if debug:
            print("input data shape is ", inputdata.shape)
        numpoints = inputdata.shape[1]
        endpoint = min(startpoint + maxpoints, numpoints)
        trimmeddata = inputdata[:, startpoint:endpoint]
    elif len(infilename) == 2:
        print("processing two input files")
        inputdata1 = ccalc_io.readvec(infilename[0])
        inputdata2 = ccalc_io.readvec(infilename[1])
        shortest = min(int(len(inputdata1)), int(len(inputdata2)))
        endpoint1 = min(startpoint + maxpoints, shortest)
        endpoint2 = min(maxpoints, shortest)
        # Both rows must have the same length, so keep only as many points as
        # both trimmed windows can supply (never a negative count).
        numpoints = max(min(endpoint1 - startpoint, endpoint2), 0)
        trimmeddata = np.zeros((2, numpoints), dtype="float")
        trimmeddata[0, :] = inputdata1[startpoint : startpoint + numpoints]
        trimmeddata[1, :] = inputdata2[0:numpoints]
    else:
        print(
            "showstxcorr requires 1 multicolumn timecourse file or two single column timecourse files as input"
        )
        sys.exit()
    return trimmeddata


def doclustering(
    X,
    outputroot,
    minibatch=True,
    batch_size=DEFAULT_BATCH_SIZE,
    n_clusters=DEFAULT_NCLUSTERS,
    max_iter=DEFAULT_MAXITER,
    n_init=DEFAULT_NINIT,
    trainedmodelroot=None,
    initialcenters=None,
):
    """Fit a k-means model to X, or load a previously trained one, and save its centers.

    Parameters
    ----------
    X
        Data handed to the model's ``fit`` method (sklearn expects
        samples by features).
    outputroot : str
        Root of the output file names.
    minibatch : bool
        Use ``MiniBatchKMeans`` when True, ``KMeans`` otherwise.
    batch_size : int
        Batch size for ``MiniBatchKMeans``.
    n_clusters : int
        Number of clusters to fit.
    max_iter : int
        Maximum number of iterations.  Forced to 1 when ``initialcenters`` is
        given.
    n_init : int
        Number of initializations for ``KMeans`` (not passed to
        ``MiniBatchKMeans``).
    trainedmodelroot : str or None
        When set, no fitting is done; the model is loaded from
        ``trainedmodelroot + "_kmeans.joblib"`` instead.  A load failure
        prints a message and exits.
    initialcenters : str, array-like or None
        Initial cluster centers.  A file name is read with
        ``ccalc_io.readvecs`` and transposed; anything else (including the
        sklearn method names "k-means++" and "random") is handed to sklearn
        as ``init`` unchanged.

    Returns
    -------
    The fitted (or loaded) k-means model.

    Notes
    -----
    A newly fitted model is saved to ``outputroot + "_kmeans.joblib"``.  The
    transposed cluster centers are always written to
    ``outputroot + "_clustercenters.txt"`` and a per-cluster normalized
    version to ``outputroot + "_norm_clustercenters.txt"``.
    """
    print("setting up kmeans")
    if trainedmodelroot is None:
        if initialcenters is None:
            theinit = "k-means++"
        else:
            if isinstance(initialcenters, str) and initialcenters not in (
                "k-means++",
                "random",
            ):
                # The command line supplies a file name, which sklearn cannot
                # use as init; load the centers into (clusters, features) order.
                theinit = np.transpose(ccalc_io.readvecs(initialcenters))
            else:
                theinit = initialcenters
            # with supplied centers only a single iteration is performed
            max_iter = 1

        print("training model")
        if minibatch:
            kmeans = MiniBatchKMeans(
                n_clusters=n_clusters, batch_size=batch_size, max_iter=max_iter, init=theinit
            ).fit(X)
        else:
            kmeans = KMeans(
                n_clusters=n_clusters, max_iter=max_iter, n_init=n_init, init=theinit
            ).fit(X)

        # save the model
        joblib.dump(kmeans, outputroot + "_kmeans.joblib")
    else:
        modelfilename = trainedmodelroot + "_kmeans.joblib"
        print("reading kmeans model from", modelfilename)
        try:
            # NOTE: joblib.load unpickles the file, which can execute arbitrary
            # code - only load model files from trusted sources.
            kmeans = joblib.load(modelfilename)
        except Exception as ex:
            template = (
                "An exception of type {0} occurred when trying to open {1}. Arguments:\n{2!r}"
            )
            message = template.format(type(ex).__name__, modelfilename, ex.args)
            print(message)
            sys.exit()

    # save the clusters, one column per cluster
    theclusters = np.transpose(kmeans.cluster_centers_)
    ccalc_io.writenpvecs(theclusters, outputroot + "_clustercenters.txt")

    # make normalized clusters: every column gets zero mean and unit std
    themeans = np.mean(theclusters, axis=0)
    thestds = np.std(theclusters, axis=0)
    print("themeans:", themeans)
    print("thestds:", thestds)
    print("shape:", theclusters.shape)
    # the per-column statistics broadcast across the rows
    thenormclusters = (theclusters - themeans) / thestds
    ccalc_io.writenpvecs(thenormclusters, outputroot + "_norm_clustercenters.txt")

    print("kmeans done")
    return kmeans


def _get_parser():
    """Build the argparse parser for the capfromany command line.

    Options are grouped into preprocessing, filtering, normalization,
    clustering, postprocessing and miscellaneous sections.  The version,
    filter and normalization options are added by helpers in
    ``capcalc.parser_funcs``.

    ``--display`` and ``--doGBR`` are on by default; ``--nodisplay`` and
    ``--noGBR`` switch them off.

    Returns
    -------
    argparse.ArgumentParser
        The configured parser.
    """
    parser = argparse.ArgumentParser(
        prog="capfromany",
        description="Calculate and cluster coactivation patterns for a set of timecourses",
        # inputs are given as options, not positionals, so say so in the usage line
        usage="%(prog)s --infile FILENAME [--infile FILENAME ...] --outputroot FILEROOT [options]",
    )

    parser.add_argument(
        "--infile",
        help="Input file - at least one is required",
        action="append",
        required=True,
        dest="infilename",
        type=lambda x: pf.is_valid_file(parser, x),
        metavar="FILENAME",
    )
    parser.add_argument(
        "--outputroot",
        help="The root of the output file names.",
        action="store",
        required=True,
        dest="outputroot",
        type=str,
        metavar="FILEROOT",
    )
    parser.add_argument(
        "--maskname",
        help="Mask file - required if input is a nifti file",
        action="store",
        dest="datamaskname",
        type=lambda x: pf.is_valid_file(parser, x),
        metavar="FILENAME",
    )

    # version options
    pf.addversionopts(parser)

    # preprocessing
    preproc_opts = parser.add_argument_group("Preprocessing options")
    preproc_opts.add_argument(
        "--duration",
        type=float,
        metavar="TIME",
        help="Amount of data to use, in seconds.",
        default=100000000.0,
    )
    preproc_opts.add_argument(
        "--starttime",
        type=float,
        metavar="TIME",
        help="Time of first datapoint to use in seconds in the first file",
        default=0.0,
    )
    preproc_opts.add_argument(
        "--sigma",
        dest="sigma",
        type=lambda x: pf.is_float(parser, x),
        action="store",
        metavar="SIGMA",
        help=("Spatially smooth the input data with a SIGMA mm kernel."),
        default=0.0,
    )
    preproc_opts.add_argument(
        "--n_pca",
        type=int,
        metavar="N_PCA",
        help=f"Number of PCA components to retain.  Default is {DEFAULT_NPCA}.",
        default=DEFAULT_NPCA,
    )
    preproc_opts.add_argument(
        "--nostandardscaler",
        dest="standardscale",
        action="store_false",
        help=("Do not use StandardScaler on input timecourses."),
        default=True,
    )
    preproc_opts.add_argument(
        "--nosegmentnorm",
        dest="segmentnorm",
        action="store_false",
        help=(
            "Do not normalize file timecourses individually.  This is probably a bad idea.  Don't use this option."
        ),
        default=True,
    )
    preproc_opts.add_argument(
        "--skip",
        type=int,
        metavar="PTS",
        help="Number of points to skip at the beginning of each segment.  Default is 0.",
        default=0,
    )
    preproc_opts.add_argument(
        "--detrendorder",
        type=int,
        metavar="ORDER",
        help="Order of detrending to apply to timecourses.  Default is 1.",
        default=1,
    )
    preproc_opts.add_argument(
        "--preproconly",
        action="store_true",
        help=("Stop after preprocessing."),
        default=False,
    )

    # filtering
    pf.addfilteropts(parser)

    # normalization
    pf.addnormalizationopts(parser, phases=["pre", "post"], defaultmethods=["percent", "z"])

    # clustering
    cluster_opts = parser.add_argument_group("Clustering options")
    cluster_opts.add_argument(
        "--n_clusters",
        type=int,
        metavar="N_CLUSTERS",
        help=f"Number of clusters to find.  Default is {DEFAULT_NCLUSTERS}.",
        default=DEFAULT_NCLUSTERS,
    )
    cluster_opts.add_argument(
        "--max_iter",
        type=int,
        metavar="MAX_ITER",
        help=f"Number of iterations to perform for cluster fitting.  Default is {DEFAULT_MAXITER}.",
        default=DEFAULT_MAXITER,
    )
    cluster_opts.add_argument(
        "--n_init",
        type=int,
        metavar="N_INIT",
        help=f"Number of starting states to try for k-means clustering. Default is {DEFAULT_NINIT}.",
        default=DEFAULT_NINIT,
    )
    cluster_opts.add_argument(
        "--batch_size",
        type=int,
        metavar="SIZE",
        help=f"Batch size for batched k-means clustering.  Default is {DEFAULT_BATCH_SIZE}",
        default=DEFAULT_BATCH_SIZE,
    )
    cluster_opts.add_argument(
        "--minibatch",
        action="store_true",
        help=("Run MiniBatch Kmeans rather than conventional - use with very large datasets."),
        default=False,
    )
    cluster_opts.add_argument(
        "--trainedmodelroot",
        type=str,
        metavar="DIRECTORY",
        help="Root name where previously trained model joblibs are stored (dirname and name stub).",
        default=None,
    )
    cluster_opts.add_argument(
        "--initialcenters",
        type=str,
        metavar="FILE",
        help="Name of a file containing previously calculated cluster centers.",
        default=None,
    )

    # postprocessing
    postproc_opts = parser.add_argument_group("Postprocessing options")
    postproc_opts.add_argument(
        "--segsize",
        type=str,
        metavar="SEGSIZE[,SEGSIZE...]",
        help="Segment sizes in TRs.  One or more integers, separated by commas.",
        default=None,
    )
    postproc_opts.add_argument(
        "--segIDs",
        type=str,
        metavar="SEGID[,SEGID,...]",
        help="Segment identifiers.  One or more strings, separated by commas.  Must match number of segment sizes.",
        default=None,
    )
    postproc_opts.add_argument(
        "--quality",
        dest="summaryonly",
        action="store_false",
        help=("Run silhouette calculations on fits."),
        default=True,
    )
    postproc_opts.add_argument(
        "--minout",
        dest="minoutlength",
        type=int,
        metavar="PTS",
        help=f"Transitions out of a state shorter than PTS will be patched.  Default is {DEFAULT_MINOUTLENGTH}",
        default=DEFAULT_MINOUTLENGTH,
    )
    postproc_opts.add_argument(
        "--minhold",
        dest="minholdlength",
        type=int,
        metavar="PTS",
        help=f"Residency in a state shorter than PTS will be patched.  Default is {DEFAULT_MINHOLDLENGTH}",
        default=DEFAULT_MINHOLDLENGTH,
    )
    # --display and --doGBR default to True, so on their own they could never
    # be switched off; each keeps its flag for backward compatibility and
    # gains a negating counterpart that writes to the same destination.
    postproc_opts.add_argument(
        "--display",
        dest="display",
        action="store_true",
        help=("Plot out quality metrics.  This is the default."),
        default=True,
    )
    postproc_opts.add_argument(
        "--nodisplay",
        dest="display",
        action="store_false",
        help=("Do not plot out quality metrics."),
        default=True,
    )
    postproc_opts.add_argument(
        "--doGBR",
        dest="doGBR",
        action="store_true",
        help=(
            "Perform gradient boosting regression after fitting to find feature importances.  "
            "This is the default."
        ),
        default=True,
    )
    postproc_opts.add_argument(
        "--noGBR",
        dest="doGBR",
        action="store_false",
        help=("Do not perform gradient boosting regression after fitting."),
        default=True,
    )

    misc_opts = parser.add_argument_group("Miscellaneous options")
    misc_opts.add_argument(
        "--debug",
        dest="debug",
        action="store_true",
        help=("Output debugging information."),
        default=False,
    )

    return parser


def capfromany_main():
    # get the command line parameters
    try:
        args = _get_parser().parse_args()
    except SystemExit:
        _get_parser().print_help()
        raise

    args, theprefilter = pf.postprocessfilteropts(args, debug=args.debug)

    if args.debug:
        print(args)

    # check that required arguments are set
    if args.outputroot is None:
        print("outfile must be set")
        sys.exit()

    # check to make sure groups and subsegments are in agreement
    if args.debug:
        print("checking groups and subsegments")
    groups = {}
    subsegs = []
    if args.segsize is not None:
        print("subsegs:", args.segsize)
        for seg in args.segsize.split(","):
            subsegs.append(int(seg))
        segsize = np.sum(np.asarray(subsegs))
        print("SUBSEGS:", subsegs)
    else:
        segsize = -1

    subseggroupIDs = []
    if args.segIDs is None:
        if len(subsegs) > 0:
            for i in range(len(subsegs)):
                subseggroupIDs.append("group_" + str(i))
    else:
        for segID in args.segIDs.split(","):
            subseggroupIDs.append(segID)
    print("subseggroupIDs:", subseggroupIDs)
    if len(subsegs) != len(subseggroupIDs):
        print("number of subsegment group IDs must match number of subsegs")
        sys.exit()
    for segnum in range(len(subseggroupIDs)):
        try:
            groups[subseggroupIDs[segnum]]["segnum"].append(int(segnum))
        except:
            groups[subseggroupIDs[segnum]] = {}
            groups[subseggroupIDs[segnum]]["seglen"] = []
            groups[subseggroupIDs[segnum]]["segstart"] = []
            groups[subseggroupIDs[segnum]]["segnum"] = [int(segnum)]
        groups[subseggroupIDs[segnum]]["seglen"].append(subsegs[segnum])
        groups[subseggroupIDs[segnum]]["segstart"].append(int(np.sum(subsegs[0:segnum])))
    for key in groups:
        groupsegs = groups[key]["seglen"]
        if not all(x == groupsegs[0] for x in groupsegs):
            print("all subsegments in a group must have the same length")
            sys.exit()
    print(groups)

    # read in cluster centers if specified
    if args.initialcenters is not None:
        theclustercenters = np.transpose(ccalc_io.readvecs(args.initialcenters))
        ccentershape = theclustercenters.shape
        print("clustercenter shape:", ccentershape)
        n_clusters = ccentershape[0]
        targetfeatures = ccentershape[1]
        print("will use", n_clusters, "clusters and", targetfeatures, "features")

    # save the command line
    ccalc_io.writevec([" ".join(sys.argv)], args.outputroot + "_commandline.txt")

    # describe the spatial prefiltering
    if args.sigma == 0.0:
        print("no spatial prefiltering will be performed")
    else:
        print(f"data will be spatially prefiltered with a {args.sigma}mm gaussian kernel")

    # describe spectral prefiltering
    if theprefilter.gettype() != "None":
        print("filtering to ", theprefilter.gettype(), " band")
    else:
        print("no prefiltering will be applied")

    # describe the normalization
    if args.prenormmethod == "None":
        print("will not prenormalize timecourses")
    elif args.prenormmethod == "percent":
        print("will prenormalize timecourses to percentage of mean")
    elif args.prenormmethod == "stddev":
        print("will prenormalize timecourses to standard deviation of 1.0")
    elif args.prenormmethod == "z":
        print("will prenormalize timecourses to variance of 1.0")
    elif args.prenormmethod == "p2p":
        print("will prenormalize timecourses to p-p deviation of 1.0")
    elif args.prenormmethod == "mad":
        print("will prenormalize timecourses to median average deviate of 1.0")
    else:
        print("illegal prenormalization type")
        sys.exit()

    # read in the files and get everything trimmed to the right length
    if ccalc_io.checkifnifti(args.infilename[0]):
        inputisnifti = True
        (
            trimmeddata,
            outputcomponents,
            Fs,
            prenormfacs,
            premeans,
            theniftiheader,
        ) = niftitotimecourse(
            args.infilename,
            args.datamaskname,
            args.outputroot,
            args.n_pca,
            theprefilter,
            trainedmodelroot=args.trainedmodelroot,
            normmethod=args.prenormmethod,
            starttime=args.starttime,
            duration=args.duration,
            sigma=args.sigma,
            skippts=args.skip,
            segmentnorm=args.segmentnorm,
        )
        sampletime = 1.0 / Fs
    else:
        pass

    # save the means and normalization factors

    # we now have all the imaging data read in preprocessed, and reduced to pca coefficient timecourses
    thedims = trimmeddata.shape
    # print("original file dimensions:", origdims)
    print("trimmed data dimensions:", thedims)
    n_features = thedims[0]
    n_samples = thedims[1]
    if segsize < 0:
        segsize = n_samples
        subsegs.append(segsize)
    reformdata = np.reshape(trimmeddata, (n_features, n_samples))
    if n_samples % segsize > 0:
        print(
            "segment size (",
            segsize,
            ") is not an even divisor of the total length (",
            n_samples,
            ")- exiting",
        )
        sys.exit()
    else:
        numsegs = int(n_samples // segsize)

    print(
        f"input dataset has {n_features} features and {n_samples} samples in {numsegs} segments of size {segsize}"
    )
    if len(subsegs) > 1:
        print(f"    each segment is broken into {len(subsegs)} subsegments of length {subsegs}")

    for feature in range(n_features):
        if args.debug:
            print("second stage feature preprocessing ", feature)
        for segment in range(numsegs):
            subsegstart = segment * segsize
            for subseglen in subsegs:
                if args.detrendorder > 0:
                    segdata = ccalc_fit.detrend(
                        reformdata[feature, subsegstart : subsegstart + subseglen]
                    )
                else:
                    segdata = reformdata[feature, subsegstart : subsegstart + subseglen]

                if args.postnormmethod == "None":
                    segnorm = segdata - np.mean(segdata)
                elif args.postnormmethod == "percent":
                    segnorm = ccalc_math.pcnormalize(segdata)
                elif args.postnormmethod == "z":
                    segnorm = ccalc_math.varnormalize(segdata)
                elif args.postnormmethod == "stddev":
                    segnorm = ccalc_math.stdnormalize(segdata)
                elif args.postnormmethod == "p2p":
                    segnorm = ccalc_math.ppnormalize(segdata)
                elif args.postnormmethod == "mad":
                    segnorm = ccalc_math.madnormalize(segdata)
                else:
                    segnorm = segdata

                reformdata[feature, subsegstart : subsegstart + subseglen] = theprefilter.apply(
                    Fs, segnorm
                )
                subsegstart += subseglen
    X = np.nan_to_num(np.transpose(reformdata))
    if args.debug:
        print(f"X matrix has dimensions {X.shape}")

    if args.standardscale:
        X = StandardScaler().fit_transform(X)

    """if preprocessingtype == "pca":
        print("running PCA")
        print("shape going in:", X.shape)
        if trainedmodelroot is None:
            print("running PCA")
            if n_pca <= 0:
                thepca = PCA(n_components="mle", svd_solver="full").fit(X)
            else:
                thepca = PCA(n_components=n_pca).fit(X)
    
            # save the model
            joblib.dump(thepca, outputroot + "_pca.joblib")
        else:
            modelfilename = trainedmodelroot + "_pca.joblib"
            print("reading PCA from", modelfilename)
            try:
                thepca = joblib.load(modelfilename)
            except Exception as ex:
                template = (
                    "An exception of type {0} occurred when trying to open {1}. Arguments:\n{2!r}"
                )
                message = template.format(type(ex).__name__, modelfilename, ex.args)
                print(message)
                sys.exit()
    
        thetransform = thepca.transform(X)
        X = thepca.inverse_transform(thetransform)
        print("shape coming out:", X.shape)
        ccalc_io.writenpvecs(np.transpose(thetransform), outputroot + "_pcadenoiseddata.txt")
        for i in range(thepca.n_components_):
            print(
                "component",
                i,
                "explained variance:",
                thepca.explained_variance_[i],
                "explained variance %:",
                100.0 * thepca.explained_variance_ratio_[i],
            )
        ccalc_io.writenpvecs(thepca.components_, outputroot + "_pcacomponents.txt")
        ccalc_io.writenpvecs(
            np.transpose(thepca.components_), outputroot + "_pcacomponents_transpose.txt"
        )
    elif preprocessingtype == "ica":
        print("running FastICA")
        if n_pca <= 1.0:
            n_pca = int(0)
        if trainedmodelroot is None:
            theica = FastICA(n_components=n_pca, algorithm="deflation").fit(X)
    
            # save the model
            joblib.dump(theica, outputroot + "_ica.joblib")
        else:
            modelfilename = trainedmodelroot + "_ica.joblib"
            print("reading ICA from", modelfilename)
            try:
                theica = joblib.load(modelfilename)
            except Exception as ex:
                template = (
                    "An exception of type {0} occurred when trying to open {1}. Arguments:\n{2!r}"
                )
                message = template.format(type(ex).__name__, modelfilename, ex.args)
                print(message)
                sys.exit()
    
        thetransform = theica.transform(X)
        X = theica.inverse_transform(thetransform)
        ccalc_io.writenpvecs(theica.components_, outputroot + "_icacomponents.txt")
        ccalc_io.writenpvecs(
            np.transpose(theica.components_), outputroot + "_icacomponents_transpose.txt"
        )"""

    ccalc_io.writenpvecs(reformdata, args.outputroot + "_preprocessed.txt")
    if args.preproconly:
        print("preprocessing done - quitting")
        sys.exit()

    kmeans = doclustering(
        X,
        args.outputroot,
        minibatch=args.minibatch,
        batch_size=args.batch_size,
        n_clusters=args.n_clusters,
        max_iter=args.max_iter,
        n_init=args.n_init,
        trainedmodelroot=args.trainedmodelroot,
        initialcenters=args.initialcenters,
    )

    # project the clusters to image space and write them out
    theclusters = kmeans.cluster_centers_
    numclusters, numfeatures = theclusters.shape
    projclusterheader = theniftiheader.copy()
    projclusterheader["dim"][4] = numclusters
    xdim, ydim, zdim, dummy = ccalc_io.parseniftidims(projclusterheader["dim"])
    projclusters = np.zeros((xdim, ydim, zdim, theclusters.shape[0]), dtype=float)
    for clusternum in range(numclusters):
        for pcacomp in range(numfeatures):
            projclusters[:, :, :, clusternum] += (
                theclusters[clusternum, pcacomp] * outputcomponents[:, :, :, pcacomp]
            )
    ccalc_io.savetonifti(projclusters, projclusterheader, args.outputroot + "_clustercenters")

    # generate the state labels
    thestatelabels = kmeans.predict(X)
    print("thestatelabels shape", thestatelabels.shape)

    # save the states
    ccalc_io.writenpvecs(thestatelabels, args.outputroot + "_statelabels.txt")

    # find most important features
    print("finding most important features")
    print(
        "calling SelectPercentiles with X and y of dimensions",
        X.shape,
        thestatelabels.shape,
    )
    selector = SelectPercentile(f_classif, percentile=10)
    selector.fit(X, thestatelabels)
    print(selector.get_params())
    X_indices = np.arange(X.shape[-1])
    scores = -np.nan_to_num(np.log10(np.nan_to_num(selector.pvalues_)))
    scores /= scores.max()
    sortedscores = np.sort(np.nan_to_num(selector.scores_))[::-1]
    print(sortedscores)
    if args.display:
        plt.bar(
            X_indices - 0.45,
            scores,
            width=0.2,
            label=r"Univariate score ($-Log(p_{value})$)",
            color="darkorange",
        )
        print(selector.get_support(indices=True))
        fig = plt.subplots(1, 1)
        plt.plot(sortedscores)
        plt.show()

    # now do some stats!
    thesilavgs, thesilclusterstats = ccalc_utils.silhouette_test(
        X, kmeans, args.n_clusters, numsegs, segsize, args.summaryonly
    )
    ccalc_io.writenpvecs(thesilavgs, args.outputroot + "_silhouettesegmentstats.txt")

    silinfo = []
    for state in range(args.n_clusters):
        silinfo.append([])
    print("shape going in:", thestatelabels.shape)
    statelabelsbysegment = np.reshape(thestatelabels, (-1, segsize))
    print("shape coming out:", statelabelsbysegment.shape)

    # do the subsegment summaries
    for key in groups:
        groups[key]["meaninstate"] = np.zeros(
            (args.n_clusters, groups[key]["seglen"][0]), dtype="float"
        )
        groups[key]["stdinstate"] = np.zeros(
            (args.n_clusters, groups[key]["seglen"][0]), dtype="float"
        )
        for state in range(args.n_clusters):
            tcbyseg = []
            for seginstance in range(len(groups[key]["segnum"])):
                startpos = groups[key]["segstart"][seginstance]
                endpos = startpos + groups[key]["seglen"][seginstance]
                tcbyseg.append(np.where(statelabelsbysegment[:, startpos:endpos] == state, 1, 0))
            groups[key]["meaninstate"][state, :] = np.mean(np.concatenate(tcbyseg, axis=0), axis=0)
            groups[key]["stdinstate"][state, :] = np.std(np.concatenate(tcbyseg, axis=0), axis=0)
        ccalc_io.writenpvecs(
            groups[key]["meaninstate"],
            args.outputroot + "_" + str(key) + "_meaninstate.txt",
        )
        ccalc_io.writenpvecs(
            groups[key]["stdinstate"], args.outputroot + "_" + str(key) + "_stdinstate.txt"
        )
    allstatestats = []
    allrawtransmats = []
    alllenlists = []
    for i in range(args.n_clusters):
        alllenlists.append([])
    for segment in range(numsegs):
        thesestatelabels = thestatelabels[segment * segsize : (segment + 1) * segsize]

        outputaffine = np.eye(4)
        rawtransmat, thestats, lenlist = ccalc_utils.statestats(
            thesestatelabels,
            args.n_clusters,
            0,
            minout=args.minoutlength,
            minhold=args.minholdlength,
        )
        allrawtransmats.append(rawtransmat * 1.0)
        allstatestats.append(thestats)
        for i in range(args.n_clusters):
            alllenlists[i] += lenlist[i]
        normtransmat, offdiagtransmat = ccalc_utils.calcmats(rawtransmat, args.n_clusters)
        init_img = nib.Nifti1Image(normtransmat, outputaffine)
        init_hdr = init_img.header
        init_sizes = init_hdr["pixdim"]
        ccalc_io.savetonifti(
            np.transpose(rawtransmat),
            init_hdr,
            args.outputroot + "_seg_" + str(segment).zfill(4) + "_rawtransmat",
        )
        ccalc_io.savetonifti(
            np.transpose(normtransmat),
            init_hdr,
            args.outputroot + "_seg_" + str(segment).zfill(4) + "_normtransmat",
        )
        ccalc_io.savetonifti(
            np.transpose(offdiagtransmat),
            init_hdr,
            args.outputroot + "_seg_" + str(segment).zfill(4) + "_offdiagtransmat",
        )

        # write as text as well
        rows = []
        cols = []
        for i in range(args.n_clusters):
            rows.append("from state " + str(i + 1))
            cols.append("to state " + str(i + 1))
        df = pd.DataFrame(data=rawtransmat, columns=cols)
        df.insert(0, "sources", pd.Series(rows))
        df.to_csv(
            args.outputroot + "_seg_" + str(segment).zfill(4) + "_rawtransmat.csv",
            index=False,
        )
        df = pd.DataFrame(data=normtransmat, columns=cols)
        df.insert(0, "sources", pd.Series(rows))
        df.to_csv(
            args.outputroot + "_seg_" + str(segment).zfill(4) + "_normtransmat.csv",
            index=False,
        )
        df = pd.DataFrame(data=offdiagtransmat, columns=cols)
        df.insert(0, "sources", pd.Series(rows))
        df.to_csv(
            args.outputroot + "_seg_" + str(segment).zfill(4) + "_offdiagtransmat.csv",
            index=False,
        )
        # rawtransmat files are an args.n_clusters by args.n_clusters matrix with the total number of transitions from each state to each other state.
        # normtransmat files are an args.n_clusters by args.n_clusters matrix with the normalized transitions from each state to each other state, as computed from rawtransmat by ccalc_utils.calcmats.

        cols = [
            "% TRs in state",
            "Number of runs in state",
            "Total TRs in state",
            "Min run (TRs)",
            "Max run (TRs)",
            "Mean run (TRs)",
            "Median run (TRs)",
            "StdDev run (TRs)",
        ]
        df = pd.DataFrame(data=thestats, columns=cols)
        df.to_csv(
            args.outputroot + "_seg_" + str(segment).zfill(4) + "_statestats.csv",
            index=False,
        )
        # ccalc_io.writenpvecs(np.transpose(thestats), outputroot + '_seg_' + str(segment).zfill(4) + '_statestats.txt')
        thetimestats = 1.0 * thestats
        thetimestats[:, 2:] *= sampletime
        cols = [
            "% Seconds in state",
            "Number of runs in state",
            "Total seconds in state",
            "Min run (sec)",
            "Max run (sec)",
            "Mean run (sec)",
            "Median run (sec)",
            "StdDev run (sec)",
        ]
        df = pd.DataFrame(data=thetimestats, columns=cols)
        df.to_csv(
            args.outputroot + "_seg_" + str(segment).zfill(4) + "_statetimestats.csv",
            index=False,
        )
        # ccalc_io.writenpvecs(np.transpose(thetimestats), outputroot + '_seg_' + str(segment).zfill(4) + '_statetimestats.txt')

        ccalc_io.writenpvecs(
            thesestatelabels,
            args.outputroot + "_seg_" + str(segment).zfill(4) + "_statelabels.txt",
        )
        print("Segment %d average silhouette Coefficient: %0.3f" % (segment, thesilavgs[segment]))
        for state in range(args.n_clusters):
            tc = np.where(thesestatelabels == state, 1, 0)
            ccalc_io.writenpvecs(
                tc,
                args.outputroot
                + "_seg_"
                + str(segment).zfill(4)
                + "_instate_"
                + str(state).zfill(2)
                + ".txt",
            )
        if not args.summaryonly:
            cols = ["Mean", "Median", "Min", "Max"]
            df = pd.DataFrame(data=np.transpose(thesilclusterstats[segment, :, :]), columns=cols)
            df.to_csv(
                args.outputroot + "_seg_" + str(segment).zfill(4) + "_silhouetteclusterstats.csv",
                index=False,
            )
            # ccalc_io.writenpvecs(thesilclusterstats[segment, :, :],
            #             args.outputroot + '_seg_' + str(segment).zfill(4) + '_silhouetteclusterstats.txt')

        for state in range(args.n_clusters):
            if thestats[state, 2] > 0:
                silinfo[state].append(thesilclusterstats[segment, 0, state])

        # now generate some summary information
        overallstatestats = []
        thetimestats = []
        alllens = 0
        for i in range(args.n_clusters):
            alllens += np.sum(np.asarray(alllenlists[i], dtype="float"))

        for i in range(args.n_clusters):
            lenarray = np.asarray(alllenlists[i], dtype="float")
            if len(lenarray) > 2:
                overallstatestats.append(
                    [
                        100.0 * np.sum(lenarray) / alllens,
                        len(lenarray),
                        np.sum(lenarray),
                        np.min(lenarray),
                        np.max(lenarray),
                        np.mean(lenarray),
                        np.median(lenarray),
                        np.std(lenarray),
                    ]
                )
                thetimestats.append(
                    [
                        100.0 * np.sum(lenarray) / alllens,
                        sampletime * len(lenarray),
                        sampletime * np.sum(lenarray),
                        sampletime * np.min(lenarray),
                        sampletime * np.max(lenarray),
                        sampletime * np.mean(lenarray),
                        sampletime * np.median(lenarray),
                        sampletime * np.std(lenarray),
                    ]
                )

        cols = [
            "% TRs in state",
            "Number of runs in state",
            "Total TRs in state",
            "Min run (TRs)",
            "Max run (TRs)",
            "Mean run (TRs)",
            "Median run (TRs)",
            "StdDev run (TRs)",
        ]
        df = pd.DataFrame(data=overallstatestats, columns=cols)
        df.to_csv(
            args.outputroot + "_overall_statestats.csv",
            index=False,
        )

        cols = [
            "% Seconds in state",
            "Number of runs in state",
            "Total seconds in state",
            "Min run (sec)",
            "Max run (sec)",
            "Mean run (sec)",
            "Median run (sec)",
            "StdDev run (sec)",
        ]
        df = pd.DataFrame(data=thetimestats, columns=cols)
        df.to_csv(
            args.outputroot + "_overall_statetimestats.csv",
            index=False,
        )

    overallrawtransmat = allrawtransmats[0] * 0.0
    if args.debug:
        print(f"{len(allrawtransmats)=}")
    for segment in range(numsegs):
        overallrawtransmat += allrawtransmats[segment]
    overallnormtransmat, overalloffdiagtransmat = ccalc_utils.calcmats(
        overallrawtransmat, args.n_clusters
    )
    init_img = nib.Nifti1Image(overallnormtransmat, outputaffine)
    init_hdr = init_img.header
    init_sizes = init_hdr["pixdim"]
    ccalc_io.savetonifti(
        np.transpose(overallrawtransmat),
        init_hdr,
        args.outputroot + "_overall_rawtransmat",
    )
    ccalc_io.savetonifti(
        np.transpose(overallnormtransmat),
        init_hdr,
        args.outputroot + "_overall_normtransmat",
    )
    ccalc_io.savetonifti(
        np.transpose(overalloffdiagtransmat),
        init_hdr,
        args.outputroot + "_overall_offdiagtransmat",
    )
    themaxlen = 0
    for i in range(args.n_clusters):
        themaxlen = int(np.max([themaxlen, np.max(alllenlists[i])]))
    for i in range(args.n_clusters):
        thishist = ccalc_stats.makeandsavehistogram(
            np.array(alllenlists[i]),
            themaxlen,
            0,
            args.outputroot + "_" + str(i).zfill(2) + "_lenhist",
            therange=[1, themaxlen],
        )
    silavgs = []
    if not args.summaryonly:
        for state in range(args.n_clusters):
            silavgs.append(np.mean(np.asarray(silinfo[state], dtype="float")))
        ccalc_io.writenpvecs(
            np.asarray(silavgs, dtype="float"),
            args.outputroot + "_overallsilhouettemean.txt",
        )
    pctarray = np.asarray(allstatestats[:], dtype="float")
    cols = [
        "% TRs in state",
        "Number of runs in state",
        "Total TRs in state",
        "Min run (TRs)",
        "Max run (TRs)",
        "Mean run (TRs)",
        "Median run (TRs)",
        "StdDev run (TRs)",
    ]
    df = pd.DataFrame(data=np.mean(pctarray, axis=0), columns=cols)
    df.to_csv(
        args.outputroot + "_seg_" + str(segment).zfill(4) + "_overallmeanstats.csv",
        index=False,
    )
    # ccalc_io.writenpvecs(np.transpose(np.mean(pctarray, axis=0)), args.outputroot + '_overallmeanstats.txt')

    if args.doGBR:
        clf = GradientBoostingRegressor().fit(X, thestatelabels)
        print("GBR fitting score is:", clf.score(X, thestatelabels))
        ccalc_io.writenpvecs(
            np.reshape(clf.feature_importances_, (n_features, 1)),
            args.outputroot + "_featureimportances.txt",
        )


# Script entry point: run the workflow only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    capfromany_main()
