"""Matthew Shumway
05/03/2024

Trimmed-down, optimized NBRW class meant to calculate Knb_v(G)
and Kv(G) quickly. Made in an effort to collect data supporting the conjecture
that Knb_v(G) < Kv(G) for all G. Compared with the full class, this one simply initializes fewer attributes."""

import numpy as np

class Faster_NBRW:
    """A class to store several relevant attributes of a graph G in Non-Backtracking Random Walks"""
    
    def __init__(self, G):
        """Store the graph and precompute the matrices and Kemeny constants.

        Accepts a Sage graph ``G``.

        The assignments below are order-sensitive: ``B`` is built from ``S``,
        ``T`` and ``tau``; ``Pnb`` from ``B``; ``M`` from ``T``; ``Mv_nb`` from
        ``M`` and ``Pnb``; ``Knb_v`` from ``Mv_nb``; ``P`` from ``Dinv``; and
        ``K_vertex`` from ``P``.
        """
        
        self.G = G  # store graph
        self.m = len(G.edges())  # number of edges
        self.n = len(G.vertices())  # number of vertices
        
        # Edge/vertex operators and the non-backtracking edge adjacency matrix.
        self.S = self.S_mat()
        self.T = self.T_mat()
        self.tau = self.tau_mat()
        self.B = self.B_mat()
        self.Dinv = self.D_inv()
        self.Pnb = self.nb_trans_mat()
        
        # Vertex-space NBRW hitting times and the Kemeny constant derived from them.
        self.M = self.M_mat()
        self.Mv_nb = self.nb_vertex_mfpt()
        self.Knb_v = self.nb_vertex_kemeny_mfpt()
        # Simple-random-walk transition matrix and its Kemeny constant.
        self.P = self.trans_mat_sage()
        self.K_vertex = self.kemeny_vertex()

        
    
    
    def S_mat(self):
        """Return the 2m x n matrix S such that ST = C (TS = A).

        S maps from edge space to vertex space. For the j-th undirected edge
        ``(u, v, ...)`` of ``self.G.edges()``, row ``j`` has a 1 in column
        ``v`` and row ``j + m`` has a 1 in column ``u``.

        Vertex labels are compared against the indices ``0 .. n-1``; an
        endpoint whose label is not one of those indices contributes no entry.
        """
        S = matrix.zero(2 * self.m, self.n)

        # One pass over the edges instead of scanning every (vertex, edge)
        # pair. The lookup table keeps the "label equals column index" rule.
        vertex_ids = range(self.n)
        index_of = dict(zip(vertex_ids, vertex_ids))

        for j, edge in enumerate(self.G.edges()):
            tail = index_of.get(edge[0])
            head = index_of.get(edge[1])
            if head is not None:
                S[j, head] = 1
            if tail is not None:
                S[j + self.m, tail] = 1
        return S
    
    def T_mat(self):
        """Return the n x 2m matrix T such that ST = C (TS = A).

        T maps from vertex space to edge space. For the j-th undirected edge
        ``(u, v, ...)`` of ``self.G.edges()``, column ``j`` has a 1 in row
        ``u`` and column ``j + m`` has a 1 in row ``v``.

        Vertex labels are compared against the indices ``0 .. n-1``; an
        endpoint whose label is not one of those indices contributes no entry.
        """
        T = matrix.zero(self.n, 2 * self.m)

        # One pass over the edges instead of scanning every (vertex, edge)
        # pair. The lookup table keeps the "label equals row index" rule.
        vertex_ids = range(self.n)
        index_of = dict(zip(vertex_ids, vertex_ids))

        for j, edge in enumerate(self.G.edges()):
            tail = index_of.get(edge[0])
            head = index_of.get(edge[1])
            if tail is not None:
                T[tail, j] = 1
            if head is not None:
                T[head, j + self.m] = 1

        return T
    
    def tau_mat(self):
        """Return the 2m x 2m matrix tau such that ST - tau = B.

        tau is the block matrix ``[[0, I], [I, 0]]`` with m x m blocks, so it
        swaps index ``j`` with index ``j + m``, i.e. it switches each directed
        edge with its reverse.
        """
        zero = matrix.zero(self.m)  # m x m zeros
        I = identity_matrix(self.m)  # m x m identity

        # Like an identity matrix with its two diagonal blocks exchanged.
        return block_matrix([[zero, I], [I, zero]])
    
    def B_mat(self):
        """Return the non-backtracking edge adjacency matrix B = ST - tau."""
        edge_to_edge = self.S * self.T
        return edge_to_edge - self.tau
    
    def nb_trans_mat(self):
        """Return the NBRW edge-space transition probability matrix.

        Documented formula: P_{nb} = (D-I)^{-1}B. The code takes ``self.B``,
        computes ``sum(b)``, inverts each entry, and multiplies ``b`` on the
        right by the diagonal matrix of those inverses.

        NOTE(review): an entry of ``sum(b)`` equal to zero would make the
        inversion fail -- confirm callers only pass graphs where this cannot
        happen.
        """
        b = self.B
        # NOTE(review): presumably iterating a Sage matrix yields its rows, so
        # sum(b) adds the rows together and holds COLUMN sums despite the name;
        # confirm this right-scaling matches the row normalisation in the
        # docstring formula.
        row_sums = sum(b)
        # NOTE(review): `^` is exponentiation only under the Sage preparser; in
        # plain Python it is bitwise XOR -- confirm how this file is loaded.
        row_sums = [i ^ -1 for i in row_sums]
        return b * diagonal_matrix(row_sums)
    
    def D_inv(self):
        """D^{-1}"""
        return np.diag([1 / d for d in self.G.degree()])
    
    def trans(self):
        """vertex space transition matrix of a simple random walk"""
        D = self.D
        return D @ self.G.adjacency_matrix()
    
    def nb_hitting_times(self, j):
        """
        return list (np.array()) m_{i, j} where m is
        NB hitting time from i to j for all i, with j fixed.

        THIS IS ACTUALLY I THINK nb hitting time m_{e, j}. 2m vector of hitting times from
        the edge e to the vertex j. (SEE Theorem 4.3 I believe.)

        i.e. the columns of M_{ev}
        """
        # Ask Adam what's going on here!
        m = len(self.G.edges())
        n = len(self.G)
        P = self.Pnb
        Gedges = [(g[0], g[1]) for g in self.G.edges()]
        Gedges.extend([(g[1], g[0]) for g in self.G.edges()])
        entries_to_delete = []

        # Find the rows/columns to delete
        for i in range(len(Gedges)):
            if Gedges[i][0] == j:
                entries_to_delete.append(i)
        P_new = np.delete(P, entries_to_delete, axis=0)
        P_new = np.delete(P_new, entries_to_delete, axis=1)

        # Keep a list of what was NOT deleted, to construct the full vector later
        remaining_entries = [i for i in range(2 * m)]
        for k in entries_to_delete:
            remaining_entries.remove(k)

        # Solve (I - P)x = 1
        ones = np.array([1] * len(P_new))
        tau_vec = np.linalg.solve(np.identity(len(P_new)) - P_new, ones)

        # recreate full tau_vec, putting the 0's back in where necessary
        full_tau = np.array([0.0] * (2 * m))
        for i in range(len(remaining_entries)):
            full_tau[remaining_entries[i]] = tau_vec[i]

        return full_tau
    
    def M_ev_NBRW(self):
        """
        return the matrix $M_{ev}$, that is, the hitting times (mean first passage times)
        of the nonbacktracking random walk
        """
        m = len(self.G.edges())
        n = len(self.G)
        Mev = np.zeros((2 * m, n))
        for i in range(n):
            Mvec = self.nb_hitting_times(i)
            for j in range(2 * m):
                Mev[j][i] = Mvec[j]

        return Mev
    
    def P_hat(self):
        """
        (D-I)^{-1}A.
        I should probably not take an inverse since its so easy but whatever. change as needed
        """
        return np.diag([1 / (d - 1) for d in G.degree()]) @ G.adjacency_matrix()


    def P_mat(self):
        """
        D^{-1}A
        """
        return np.diag([1 / d for d in G.degree()]) @ G.adjacency_matrix()

    
    def Znb_mat(self):
        """
        A guess at what Znb should be. Seems to work pretty close to what we want numerically. Matches exactly with
        edge transitive graphs. Usually close with all other graphs we've tried.
        """
        n = len(self.G)
        m = len(self.G.edges())
        Dinv = np.diag([1 / d for d in self.G.degree()])
        T = self.T
        S = self.S
        Wnb = np.ones((2 * m, 2 * m)) / (2 * m)
        Wv = np.outer(np.ones(n), [d / (2 * m) for d in self.G.degree()])

        return np.eye(n) + Dinv @ T @ (np.linalg.inv(np.eye(2 * m) - self.nb_trans_mat() + Wnb)) @ S - Wv

    
    def Wnb_mat(self):
        return np.ones((2 * self.m, 2 * self.m)) / (2 * self.m)
    
    
    def Znb_e_mat(self):
        I = identity_matrix(2*self.m)
        return np.linalg.inv(I - self.Pnb + self.Wnb)
    
    
    def M_mat(self):
        """
        Matrix M as in (4.4) in the Hitting Times Paper
        (Basically the start point operator divided by the row sums)
        """
        M = self.T.numpy()
        # Divide by row sums
        return M / M.sum(axis=1, keepdims=1)
    
    
    def nb_vertex_mfpt(self):
        """
        Get the matrix (np.array probably) where M_{i,j} is the mfpt i -> j.
        That is, $M_v^{nb}$

        For now, assumes vertices are {0, 1, ..., n-1}
        """
        n = self.n
        MFPT = np.zeros((n, n))

        for i in range(n):
            MFPT_col_i = self.M @ self.nb_hitting_times(i)
            for j in range(n):
                MFPT[j][i] = MFPT_col_i[j]

        return MFPT
    
    def Wv_mat(self):
        return np.outer(np.ones(self.n), [d / (2 * self.m) for d in self.G.degree()])
    
    def mrt_NBRW_vertex(self):
        """
        Diagonal matrix of mean return times of a NBRW.
        See Equations below (4.8) in hitting times paper.

        Note: This should be the same Mean Return Times as SRW I believe which might just be 1/\pi_j
              if I remember correctly as I'm typing this note
        """
        Dinv = np.diag([1 / d for d in self.G.degree()])
        T = self.T
        Pnb = self.Pnb
        Mev = self.Mev
        return np.eye(len(self.G)) + np.diag(np.diag(Dinv @ T @ Pnb @ Mev))
    
    def trans_mat_sage(self):
        """vertex space transition matrix of a simple random walk"""
        return self.Dinv @ self.G.adjacency_matrix()
    
    
    def Pv_NBRW(self, k):
        """
        Return the vertex space transition probability matrix for the nonbacktracking random walk for the kth step
        """
        if k == 0:
            return np.eye(len(self.G))

        if k == 1:
            return trans_mat_sage()

        # If k > 1
        return self.Dinv @ self.T @ (np.linalg.matrix_power(self.Pnb, k - 1)) @ self.S

    def nb_vertex_kemeny_mfpt(self):
        """
        Use these hitting times to get Kemeny's constant
        """
        volG = 2 * len(self.G.edges())
        steady = np.array([d / volG for d in self.G.degree()])
        return steady @ self.Mv_nb @ steady
    
    def kemeny_vertex(self):
        """SRW Vertex Kemeny"""
        eigvals = np.array(sorted(np.linalg.eigvals(self.P))[:-1])  # eigenvalues of P transition matrix without \lambda_max = 1
        return np.sum(1/(1-eigvals))
        
    def kemeney_edge(self):
        """Return the SRW edge-space Kemeny constant (Theorem 2.9): K_vertex + 2m - n."""
        shifted = self.K_vertex + 2 * self.m
        return shifted - self.n
    
    def kemeney_nb_edge(self):
        """nb kemeny's for a d-regular graph in the edge space"""
        eigvals = np.array(sorted(np.linalg.eigvals(self.Pnb))[:-1])  # eigenvalues of P transition matrix without \lambda_max = 1
        return np.sum(1/(1-eigvals)).real
    
    def De_mat(self):
        De = np.zeros((2*self.m, 2*self.m))
        deg = self.G.degree()
        for idx, tup in enumerate(self.G.edges()):
            i, j, _ = tup
            De[idx, idx] = deg[j]
            De[self.m + idx, self.m + idx] = deg[i]
        return De
    
    def nb_vertex_kemeny_trace(self):
        return self.Knb_e - np.trace(self.De_inv @ self.Znb_e) + self.n * (1 - 1/(2*self.m)) - 2*self.m + self.n + np.trace(self.De_inv @ self.tau @ self.Znb_e)

    def nb_fund_mat(self, i):
        """compute (I-Q)^-1 where Q = P[i,i], the transition probability matrix
        with deleted cols/rows i"""
        G = self.G
        P = self.nb_trans_mat()
        P = np.delete(P, i, 0)
        P = np.delete(P, i, 1)
        return np.linalg.inv(np.identity(len(P)) - P)


    def nb_mfpt_mat(self):
        """Mean First Passing Time (Hitting Time) matrix in the edge space for
        a Non-Backtracking Random Walk."""
        G = self.G
        P = np.array(self.nb_trans_mat())
        volG = len(P)

        M = np.zeros((volG, volG))
        for i in range(volG):
            #get N for all directed edges
            Ni = self.nb_fund_mat(i)
            for j in range(len(Ni)):
            #get all row sums, careful to put in correct place
                if j < i:
                    #if j index before i, no problem
                    M[j,i] = np.sum(Ni[j,:])
                if j >= i:
                    #takes care of reindexing
                    M[j+1,i] = np.sum(Ni[j,:])

        return M
