#! /usr/bin/env python

#  block2: Efficient MPO implementation of quantum chemistry DMRG
#  Copyright (C) 2023 Huanchen Zhai <hczhai@caltech.edu>
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License
#  along with this program. If not, see <https://www.gnu.org/licenses/>.
#
#

"""
block2 wrapper.

Author:
    Huanchen Zhai
    Zhi-Hao Cui
"""

import os
import sys

# On Windows with Python 3.8+, extension-module DLL dependencies are no longer
# resolved through PATH, so every existing PATH directory is registered
# explicitly via os.add_dll_directory before block2 is imported.
# A full (major, minor) comparison is used rather than inspecting only the
# minor version number, and a missing PATH variable is treated as empty.
if os.name == "nt" and sys.version_info >= (3, 8):
    for path in os.environ.get("PATH", "").split(";"):
        if path != "" and os.path.exists(os.path.abspath(path)):
            os.add_dll_directory(os.path.abspath(path))

from block2 import SZ, SU2, SZK, SU2K, SGF, DoubleFPCodec as FPCodec
from block2 import Global, OpNamesSet, NoiseTypes, DecompositionTypes, Threading, ThreadingTypes
from block2 import init_memory, release_memory, set_mkl_num_threads, read_occ, TruncationTypes
from block2 import VectorUInt8, VectorUBond, VectorVectorUBond, VectorDouble, PointGroup, ParallelSimpleTypes
from block2 import Random, FCIDUMP, QCTypes, SeqTypes, TETypes, OpNames, VectorInt, VectorUInt16, VectorUInt32
from block2 import MatrixFunctions, KuhnMunkres, Matrix, DyallFCIDUMP, FinkFCIDUMP, ConvergenceTypes
from block2 import HubbardKSpaceFCIDUMP, HubbardFCIDUMP, HeisenbergFCIDUMP, ExpectationAlgorithmTypes
from block2 import SpinOrbitalFCIDUMP, MRCISFCIDUMP, VectorVectorInt, GeneralFCIDUMP, ElemOpTypes
from block2 import MPOAlgorithmTypes, EquationTypes, VectorActTypes, ActiveTypes
from block2 import IntVectorAllocator, DoubleVectorAllocator
import numpy as np
import time
import os
import sys

# Default floating-point vector type (double precision); the single_prec
# branches further down rebind it to VectorFloat.
VectorFP = VectorDouble

# Prefer the installed pyblock2 package; when that import fails, fall back to
# a top-level ``parser`` module on sys.path.
try:
    from pyblock2.driver.parser import (
        parse, orbital_reorder, read_integral, format_schedule)
except ImportError:
    from parser import (
        parse, orbital_reorder, read_integral, format_schedule)

# When true, the parsed input and run mode are echoed at start-up.
DEBUG = True

# Command-line handling. argv[1] is the input file, or "-v" to print the
# version and exit. An optional argv[2] selects the reduced-memory workflow
# from the usage text below: "pre"/"para-pre" is step 1 and "run"/"para-run"
# is step 2 (the "para-" variants save/load a parallel MPO instead of a
# serial one).
if len(sys.argv) > 1:
    fin = sys.argv[1]
    if len(sys.argv) == 2 and fin == '-v':
        print('Block 2.0')
        # quit() is a site-module helper intended for the interactive prompt
        # and is absent when site is disabled (python -S); sys.exit() is the
        # supported way for a program to terminate.
        sys.exit()
    elif len(sys.argv) > 2 and sys.argv[2] in ["pre", "para-pre"]:
        pre_run = True
        para_pre_run = sys.argv[2] == "para-pre"
    else:
        pre_run = para_pre_run = False
    if len(sys.argv) > 2 and sys.argv[2] in ["run", "para-run"]:
        no_pre_run = True
        para_no_pre_run = sys.argv[2] == "para-run"
    else:
        no_pre_run = para_no_pre_run = False
else:
    raise ValueError("""
        Usage: any of:
            (A) python block2main dmrg.conf
            (B) reduced memory mode (save/load serial mpo):
                Step 1: python block2main dmrg.conf pre
                Step 2: python block2main dmrg.conf run
            (C) extra reduced memory mode (save/load parallel mpo):
                Step 1: python block2main dmrg.conf para-pre
                Step 2: python block2main dmrg.conf para-run
            (D) python block2main FCIDUMP
    """)

# Parse the input file into a dict of keyword -> value.
dic = parse(fin)

# Bind the precision-, domain- and symmetry-specific block2 classes to the
# generic names (MPS, MPO, DMRG, Expect, ...) used by the rest of the script.
# Module chosen for the operator/MPO/DMRG classes:
#   single_prec, real    -> block2.sp.<sym>
#   single_prec, complex -> block2.sp.cpx.<sym>
#   double, real         -> block2.<sym>
#   double, complex      -> block2.cpx.<sym>
# MPSInfo-type classes and CG always come from block2.<sym>.
# <sym> is sz ("nonspinadapted"), su2 (default), szk/su2k ("k_symmetry"),
# or sgf ("use_general_spin").
# SX is the selected symmetry class. TrSX, trans_si, trans_mi, TrMPSInfo and
# TrVectorStateInfo refer to the spin-adapted <-> non-spin-adapted partner
# symmetry; they are not bound for general spin (SGF). trans_mps is bound only
# for the spin-adapted symmetries (SU2 / SU2K).
if "single_prec" in dic and "use_complex" not in dic:
    # Single precision, real domain.
    # Raise rather than assert: asserts are stripped under ``python -O``.
    if "k_symmetry" in dic:
        raise RuntimeError("k_symmetry is not supported with single_prec")
    # Switch codec, FP vector type and memory initializer to float variants.
    from block2 import FloatFPCodec as FPCodec, init_memory_float
    from block2 import VectorFloat
    VectorFP = VectorFloat
    init_memory = init_memory_float
    from block2.sp import FCIDUMP, SpinOrbitalFCIDUMP, MRCISFCIDUMP, GeneralFCIDUMP
    if "nonspinadapted" in dic and "use_general_spin" not in dic:
        # SZ; partner symmetry SU2. ComplexExpect is just Expect here.
        from block2 import VectorSZ as VectorSL
        from block2.sz import MultiMPSInfo, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, CG
        from block2.sp.sz import MultiMPS, CSRSparseMatrix, FusedMPO, CondensedMPO, CSROperatorFunctions
        from block2.sp.sz import HamiltonianQC, GeneralHamiltonian, GeneralNPDMMPO, MPS, ParallelRuleSimple, ParallelFCIDUMP
        from block2.sp.sz import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule, GeneralMPO
        from block2.sp.sz import Expect, DMRG, MovingEnvironment, OperatorFunctions, TensorFunctions, MPO, IdentityAddedMPO
        from block2.sp.sz import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.sp.sz import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.sp.sz import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, Expect as ComplexExpect
        from block2.sz import trans_state_info_to_su2 as trans_si
        from block2.su2 import MPSInfo as TrMPSInfo
        from block2.su2 import trans_mps_info_to_sz as trans_mi, VectorStateInfo as TrVectorStateInfo
        SX = SZ
        TrSX = SU2
    elif "nonspinadapted" not in dic and "use_general_spin" not in dic:
        # SU2 (default); partner symmetry SZ.
        from block2 import VectorSU2 as VectorSL
        from block2.su2 import MultiMPSInfo, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, CG
        from block2.sp.su2 import MultiMPS, CSRSparseMatrix, FusedMPO, CondensedMPO, CSROperatorFunctions
        from block2.sp.su2 import HamiltonianQC, GeneralHamiltonian, GeneralNPDMMPO, MPS, ParallelRuleSimple, ParallelFCIDUMP
        from block2.sp.su2 import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule, GeneralMPO
        from block2.sp.su2 import Expect, DMRG, MovingEnvironment, OperatorFunctions, TensorFunctions, MPO, IdentityAddedMPO
        from block2.sp.su2 import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.sp.su2 import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.sp.su2 import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, Expect as ComplexExpect
        from block2.su2 import trans_state_info_to_sz as trans_si, trans_unfused_mps_to_sz as trans_mps
        from block2.sz import MPSInfo as TrMPSInfo
        from block2.sz import trans_mps_info_to_su2 as trans_mi, VectorStateInfo as TrVectorStateInfo
        SX = SU2
        TrSX = SZ
    elif "use_general_spin" in dic:
        # SGF (general spin); no partner symmetry.
        from block2 import VectorSGF as VectorSL
        from block2.sgf import MultiMPSInfo, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, CG
        from block2.sp.sgf import MultiMPS, CSRSparseMatrix, FusedMPO, CondensedMPO, CSROperatorFunctions
        from block2.sp.sgf import HamiltonianQC, GeneralHamiltonian, GeneralNPDMMPO, MPS, ParallelRuleSimple, ParallelFCIDUMP
        from block2.sp.sgf import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule, GeneralMPO
        from block2.sp.sgf import Expect, DMRG, MovingEnvironment, OperatorFunctions, TensorFunctions, MPO, IdentityAddedMPO
        from block2.sp.sgf import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.sp.sgf import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.sp.sgf import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, Expect as ComplexExpect
        SX = SGF
    else:
        assert False
elif "single_prec" in dic and "use_complex" in dic:
    # Single precision, complex domain.
    # Raise rather than assert: asserts are stripped under ``python -O``.
    if "k_symmetry" in dic:
        raise RuntimeError("k_symmetry is not supported with single_prec")
    # Switch codec, FP vector type and memory initializer to float variants.
    from block2 import FloatFPCodec as FPCodec, init_memory_float
    from block2 import VectorFloat
    VectorFP = VectorFloat
    init_memory = init_memory_float
    from block2.sp.cpx import FCIDUMP, SpinOrbitalFCIDUMP, MRCISFCIDUMP, GeneralFCIDUMP
    if "nonspinadapted" in dic and "use_general_spin" not in dic:
        # SZ; partner symmetry SU2. ComplexExpect is just Expect here.
        from block2 import VectorSZ as VectorSL
        from block2.sz import MultiMPSInfo, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, CG
        from block2.sp.cpx.sz import MultiMPS, CSRSparseMatrix, FusedMPO, CondensedMPO, CSROperatorFunctions
        from block2.sp.cpx.sz import HamiltonianQC, GeneralHamiltonian, GeneralNPDMMPO, MPS, ParallelRuleSimple, ParallelFCIDUMP
        from block2.sp.cpx.sz import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule, GeneralMPO
        from block2.sp.cpx.sz import Expect, DMRG, MovingEnvironment, OperatorFunctions, TensorFunctions, MPO, IdentityAddedMPO
        from block2.sp.cpx.sz import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.sp.cpx.sz import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.sp.cpx.sz import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, Expect as ComplexExpect
        from block2.sz import trans_state_info_to_su2 as trans_si
        from block2.su2 import MPSInfo as TrMPSInfo
        from block2.su2 import trans_mps_info_to_sz as trans_mi, VectorStateInfo as TrVectorStateInfo
        SX = SZ
        TrSX = SU2
    elif "nonspinadapted" not in dic and "use_general_spin" not in dic:
        # SU2 (default); partner symmetry SZ.
        from block2 import VectorSU2 as VectorSL
        from block2.su2 import MultiMPSInfo, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, CG
        from block2.sp.cpx.su2 import MultiMPS, CSRSparseMatrix, FusedMPO, CondensedMPO, CSROperatorFunctions
        from block2.sp.cpx.su2 import HamiltonianQC, GeneralHamiltonian, GeneralNPDMMPO, MPS, ParallelRuleSimple, ParallelFCIDUMP
        from block2.sp.cpx.su2 import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule, GeneralMPO
        from block2.sp.cpx.su2 import Expect, DMRG, MovingEnvironment, OperatorFunctions, TensorFunctions, MPO, IdentityAddedMPO
        from block2.sp.cpx.su2 import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.sp.cpx.su2 import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.sp.cpx.su2 import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, Expect as ComplexExpect
        from block2.su2 import trans_state_info_to_sz as trans_si, trans_unfused_mps_to_sz as trans_mps
        from block2.sz import MPSInfo as TrMPSInfo
        from block2.sz import trans_mps_info_to_su2 as trans_mi, VectorStateInfo as TrVectorStateInfo
        SX = SU2
        TrSX = SZ
    elif "use_general_spin" in dic:
        # SGF (general spin); no partner symmetry.
        from block2 import VectorSGF as VectorSL
        from block2.sgf import MultiMPSInfo, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, CG
        from block2.sp.cpx.sgf import MultiMPS, CSRSparseMatrix, FusedMPO, CondensedMPO, CSROperatorFunctions
        from block2.sp.cpx.sgf import HamiltonianQC, GeneralHamiltonian, GeneralNPDMMPO, MPS, ParallelRuleSimple, ParallelFCIDUMP
        from block2.sp.cpx.sgf import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule, GeneralMPO
        from block2.sp.cpx.sgf import Expect, DMRG, MovingEnvironment, OperatorFunctions, TensorFunctions, MPO, IdentityAddedMPO
        from block2.sp.cpx.sgf import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.sp.cpx.sgf import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.sp.cpx.sgf import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, Expect as ComplexExpect
        SX = SGF
    else:
        assert False
elif "use_complex" not in dic:
    # Double precision, real domain: everything comes from block2.<sym>,
    # including a distinct ComplexExpect class.
    if "nonspinadapted" in dic and "k_symmetry" not in dic and "use_general_spin" not in dic:
        # SZ; partner symmetry SU2.
        from block2 import VectorSZ as VectorSL
        from block2.sz import MultiMPS, MultiMPSInfo, CSRSparseMatrix, FusedMPO, CSROperatorFunctions
        from block2.sz import HamiltonianQC, GeneralHamiltonian, GeneralNPDMMPO, MPS, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo
        from block2.sz import ParallelRuleSimple, ParallelFCIDUMP, CondensedMPO, GeneralMPO
        from block2.sz import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule
        from block2.sz import Expect, DMRG, MovingEnvironment, OperatorFunctions, CG, TensorFunctions, MPO, IdentityAddedMPO
        from block2.sz import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.sz import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.sz import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, ComplexExpect
        from block2.sz import trans_state_info_to_su2 as trans_si
        from block2.su2 import MPSInfo as TrMPSInfo
        from block2.su2 import trans_mps_info_to_sz as trans_mi, VectorStateInfo as TrVectorStateInfo
        SX = SZ
        TrSX = SU2
    elif "nonspinadapted" not in dic and "k_symmetry" not in dic and "use_general_spin" not in dic:
        # SU2 (default); partner symmetry SZ.
        from block2 import VectorSU2 as VectorSL
        from block2.su2 import MultiMPS, MultiMPSInfo, CSRSparseMatrix, FusedMPO, CSROperatorFunctions
        from block2.su2 import HamiltonianQC, GeneralHamiltonian, GeneralNPDMMPO, MPS, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo
        from block2.su2 import ParallelRuleSimple, ParallelFCIDUMP, CondensedMPO, GeneralMPO
        from block2.su2 import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule
        from block2.su2 import Expect, DMRG, MovingEnvironment, OperatorFunctions, CG, TensorFunctions, MPO, IdentityAddedMPO
        from block2.su2 import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.su2 import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.su2 import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, ComplexExpect
        from block2.su2 import trans_state_info_to_sz as trans_si, trans_unfused_mps_to_sz as trans_mps
        from block2.sz import MPSInfo as TrMPSInfo
        from block2.sz import trans_mps_info_to_su2 as trans_mi, VectorStateInfo as TrVectorStateInfo
        SX = SU2
        TrSX = SZ
    elif "nonspinadapted" in dic and "k_symmetry" in dic:
        # SZK; partner symmetry SU2K.
        from block2 import VectorSZK as VectorSL
        from block2.szk import MultiMPS, MultiMPSInfo, CSRSparseMatrix, FusedMPO, CSROperatorFunctions
        from block2.szk import HamiltonianQC, GeneralHamiltonian, GeneralNPDMMPO, MPS, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo
        from block2.szk import ParallelRuleSimple, ParallelFCIDUMP, CondensedMPO, GeneralMPO
        from block2.szk import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule
        from block2.szk import Expect, DMRG, MovingEnvironment, OperatorFunctions, CG, TensorFunctions, MPO, IdentityAddedMPO
        from block2.szk import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.szk import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.szk import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, ComplexExpect
        from block2.szk import trans_state_info_to_su2k as trans_si
        from block2.su2k import MPSInfo as TrMPSInfo
        from block2.su2k import trans_mps_info_to_szk as trans_mi, VectorStateInfo as TrVectorStateInfo
        SX = SZK
        TrSX = SU2K
    elif "nonspinadapted" not in dic and "k_symmetry" in dic:
        # SU2K; partner symmetry SZK.
        from block2 import VectorSU2K as VectorSL
        from block2.su2k import MultiMPS, MultiMPSInfo, CSRSparseMatrix, FusedMPO, CSROperatorFunctions
        from block2.su2k import HamiltonianQC, GeneralHamiltonian, GeneralNPDMMPO, MPS, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo
        from block2.su2k import ParallelRuleSimple, ParallelFCIDUMP, CondensedMPO, GeneralMPO
        from block2.su2k import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule
        from block2.su2k import Expect, DMRG, MovingEnvironment, OperatorFunctions, CG, TensorFunctions, MPO, IdentityAddedMPO
        from block2.su2k import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.su2k import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.su2k import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, ComplexExpect
        from block2.su2k import trans_state_info_to_szk as trans_si, trans_unfused_mps_to_szk as trans_mps
        from block2.szk import MPSInfo as TrMPSInfo
        from block2.szk import trans_mps_info_to_su2k as trans_mi, VectorStateInfo as TrVectorStateInfo
        SX = SU2K
        TrSX = SZK
    elif "use_general_spin" in dic:
        # SGF (general spin); no partner symmetry.
        from block2 import VectorSGF as VectorSL
        from block2.sgf import MultiMPS, MultiMPSInfo, CSRSparseMatrix, FusedMPO, CSROperatorFunctions
        from block2.sgf import HamiltonianQC, GeneralHamiltonian, GeneralNPDMMPO, MPS, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo
        from block2.sgf import ParallelRuleSimple, ParallelFCIDUMP, CondensedMPO, GeneralMPO
        from block2.sgf import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule
        from block2.sgf import Expect, DMRG, MovingEnvironment, OperatorFunctions, CG, TensorFunctions, MPO, IdentityAddedMPO
        from block2.sgf import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.sgf import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.sgf import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, ComplexExpect
        SX = SGF
    else:
        assert False
else:
    # Double precision, complex domain. ComplexExpect is just Expect here.
    from block2.cpx import FCIDUMP, SpinOrbitalFCIDUMP, MRCISFCIDUMP, GeneralFCIDUMP
    if "nonspinadapted" in dic and "k_symmetry" not in dic and "use_general_spin" not in dic:
        # SZ; partner symmetry SU2.
        from block2 import VectorSZ as VectorSL
        from block2.sz import MultiMPSInfo, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, CG
        from block2.cpx.sz import MultiMPS, CSRSparseMatrix, FusedMPO, CondensedMPO, CSROperatorFunctions
        from block2.cpx.sz import HamiltonianQC, GeneralHamiltonian, GeneralNPDMMPO, MPS, ParallelRuleSimple, ParallelFCIDUMP
        from block2.cpx.sz import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule, GeneralMPO
        from block2.cpx.sz import Expect, DMRG, MovingEnvironment, OperatorFunctions, TensorFunctions, MPO, IdentityAddedMPO
        from block2.cpx.sz import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.cpx.sz import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.cpx.sz import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, Expect as ComplexExpect
        from block2.sz import trans_state_info_to_su2 as trans_si
        from block2.su2 import MPSInfo as TrMPSInfo
        from block2.su2 import trans_mps_info_to_sz as trans_mi, VectorStateInfo as TrVectorStateInfo
        SX = SZ
        TrSX = SU2
    elif "nonspinadapted" not in dic and "k_symmetry" not in dic and "use_general_spin" not in dic:
        # SU2 (default); partner symmetry SZ.
        from block2 import VectorSU2 as VectorSL
        from block2.su2 import MultiMPSInfo, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, CG
        from block2.cpx.su2 import MultiMPS, CSRSparseMatrix, FusedMPO, CondensedMPO, CSROperatorFunctions
        from block2.cpx.su2 import HamiltonianQC, GeneralHamiltonian, GeneralNPDMMPO, MPS, ParallelRuleSimple, ParallelFCIDUMP
        from block2.cpx.su2 import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule, GeneralMPO
        from block2.cpx.su2 import Expect, DMRG, MovingEnvironment, OperatorFunctions, TensorFunctions, MPO, IdentityAddedMPO
        from block2.cpx.su2 import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.cpx.su2 import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.cpx.su2 import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, Expect as ComplexExpect
        from block2.su2 import trans_state_info_to_sz as trans_si, trans_unfused_mps_to_sz as trans_mps
        from block2.sz import MPSInfo as TrMPSInfo
        from block2.sz import trans_mps_info_to_su2 as trans_mi, VectorStateInfo as TrVectorStateInfo
        SX = SU2
        TrSX = SZ
    elif "nonspinadapted" in dic and "k_symmetry" in dic:
        # SZK; partner symmetry SU2K.
        from block2 import VectorSZK as VectorSL
        from block2.szk import MultiMPSInfo, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, CG
        from block2.cpx.szk import MultiMPS, CSRSparseMatrix, FusedMPO, CondensedMPO, CSROperatorFunctions
        from block2.cpx.szk import HamiltonianQC, GeneralHamiltonian, GeneralNPDMMPO, MPS, ParallelRuleSimple, ParallelFCIDUMP
        from block2.cpx.szk import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule, GeneralMPO
        from block2.cpx.szk import Expect, DMRG, MovingEnvironment, OperatorFunctions, TensorFunctions, MPO, IdentityAddedMPO
        from block2.cpx.szk import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.cpx.szk import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.cpx.szk import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, Expect as ComplexExpect
        from block2.szk import trans_state_info_to_su2k as trans_si
        from block2.su2k import MPSInfo as TrMPSInfo
        from block2.su2k import trans_mps_info_to_szk as trans_mi, VectorStateInfo as TrVectorStateInfo
        SX = SZK
        TrSX = SU2K
    elif "nonspinadapted" not in dic and "k_symmetry" in dic:
        # SU2K; partner symmetry SZK.
        from block2 import VectorSU2K as VectorSL
        from block2.su2k import MultiMPSInfo, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, CG
        from block2.cpx.su2k import MultiMPS, CSRSparseMatrix, FusedMPO, CondensedMPO, CSROperatorFunctions
        from block2.cpx.su2k import HamiltonianQC, GeneralHamiltonian, GeneralNPDMMPO, MPS, ParallelRuleSimple, ParallelFCIDUMP
        from block2.cpx.su2k import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule, GeneralMPO
        from block2.cpx.su2k import Expect, DMRG, MovingEnvironment, OperatorFunctions, TensorFunctions, MPO, IdentityAddedMPO
        from block2.cpx.su2k import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.cpx.su2k import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.cpx.su2k import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, Expect as ComplexExpect
        from block2.su2k import trans_state_info_to_szk as trans_si, trans_unfused_mps_to_szk as trans_mps
        from block2.szk import MPSInfo as TrMPSInfo
        from block2.szk import trans_mps_info_to_su2k as trans_mi, VectorStateInfo as TrVectorStateInfo
        SX = SU2K
        TrSX = SZK
    elif "use_general_spin" in dic:
        # SGF (general spin); no partner symmetry.
        from block2 import VectorSGF as VectorSL
        from block2.sgf import MultiMPSInfo, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, CG
        from block2.cpx.sgf import MultiMPS, CSRSparseMatrix, FusedMPO, CondensedMPO, CSROperatorFunctions
        from block2.cpx.sgf import HamiltonianQC, GeneralHamiltonian, GeneralNPDMMPO, MPS, ParallelRuleSimple, ParallelFCIDUMP
        from block2.cpx.sgf import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule, GeneralMPO
        from block2.cpx.sgf import Expect, DMRG, MovingEnvironment, OperatorFunctions, TensorFunctions, MPO, IdentityAddedMPO
        from block2.cpx.sgf import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.cpx.sgf import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.cpx.sgf import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, Expect as ComplexExpect
        SX = SGF
    else:
        assert False

# Optional big-site extensions. They are always taken from the double-precision
# real modules block2.sz / block2.su2, independent of the branches above. If
# this block2 build lacks them, the names are simply left undefined. (This
# single copy replaces four identical per-branch copies; each branch used to
# run it as its final step, so behavior is unchanged.)
try:
    if "nonspinadapted" in dic:
        from block2.sz import HamiltonianQCBigSite, SimplifiedBigSite, ParallelBigSite
        from block2.sz import DMRGBigSite, LinearBigSite
        from block2.sz import SCIFockBigSite, DRTBigSite, DRT
    else:
        from block2.su2 import HamiltonianQCBigSite, SimplifiedBigSite, ParallelBigSite
        from block2.su2 import DMRGBigSite, LinearBigSite
        from block2.su2 import CSFBigSite, DRTBigSite, DRT
except ImportError:
    pass

try:
    # Import the MPICommunicator that matches the active symmetry class.
    if "k_symmetry" in dic:
        if "nonspinadapted" in dic:
            from block2.szk import MPICommunicator
        else:
            from block2.su2k import MPICommunicator
    elif "use_general_spin" in dic:
        from block2.sgf import MPICommunicator
    elif "nonspinadapted" in dic:
        from block2.sz import MPICommunicator
    else:
        from block2.su2 import MPICommunicator

    if "restart_mps_nevpt" not in dic:
        MPI = MPICommunicator()

        def _print(*args, **kwargs):
            """Print (always flushed) only on rank 0 and only when outputlevel > -1."""
            if MPI.rank != 0 or outputlevel <= -1:
                return
            print(*args, **{**kwargs, "flush": True})
    else:
        # restart_mps_nevpt runs without a communicator and print unfiltered.
        MPI = None
        _print = print
except ImportError:
    # No MPICommunicator available (presumably a non-MPI build): run serially.
    MPI = None
    _print = print


tx = time.perf_counter()

# input parameters
Random.rand_seed(1234)  # fixed seed for block2's random generator
outputlevel = int(dic.get("outputlevel", 2))
if DEBUG:
    # Echo every parsed keyword as an aligned "key value" row. The schedule is
    # expanded over several rows, with the key shown only on the first one.
    _print("\n%s INPUT START %s" % ("*" * 34, "*" * 34))
    for key, val in dic.items():
        if key != "schedule":
            _print("%-25s %40s" % (key, val))
            continue
        pval = format_schedule(val)
        for ipv, pv in enumerate(pval):
            _print("%-25s %40s" % ("" if ipv else key, pv))
    _print("%s INPUT END   %s\n" % ("*" * 34, "*" * 34))
    # One-line summary of symmetry, domain and precision.
    if "use_general_spin" in dic:
        _print("GENERAL SPIN - ", end='')
    elif "nonspinadapted" in dic:
        _print("NON SPIN ADAPTED - ", end='')
    else:
        _print("SPIN ADAPTED - ", end='')
    _print("COMPLEX DOMAIN - " if "use_complex" in dic else "REAL DOMAIN - ", end='')
    _print("SINGLE PREC" if "single_prec" in dic else "DOUBLE PREC")

# ---- General run parameters read from the parsed input dict ----
prefix = dic.get("prefix", "./nodex/")
restart_dir = dic.get("restart_dir", None)
mps_dir = dic.get("mps_dir", None)
stackblock_compat = dic.get("hf_occ", None) == "integral"
openmolcas_compat = "openmolcas" in dic
restart_dir_per_sweep = dic.get("restart_dir_per_sweep", None)
# In StackBlock-compatible mode (hf_occ integral, not OpenMolcas) the thread
# count defaults to 1; otherwise it defaults to block2's global thread count.
if stackblock_compat and not openmolcas_compat:
    n_threads = int(dic.get("num_thrds", 1))
else:
    n_threads = int(dic.get("num_thrds", Global.threading.n_threads_global))
mkl_threads = int(dic.get("mkl_thrds", 1))
# The parsed schedule unpacks into four components.
bond_dims, dav_thrds, noises, site_dependent_bdims = dic["schedule"]
# Discard site-dependent bond dimensions when no entry specifies any.
# all() is equivalent to the former ``max(len) == 0`` test but does not raise
# ValueError on an empty schedule.
if all(len(x) == 0 for x in site_dependent_bdims):
    site_dependent_bdims = []
site_dependent_bdims = VectorVectorUBond(
    [VectorUBond(x) for x in site_dependent_bdims])
store_wfn_spectra = "store_wfn_spectra" in dic
sweep_tol = float(dic.get("sweep_tol", 1e-6))
cached_contraction = int(dic.get("cached_contraction", 1)) == 1
singlet_embedding = "singlet_embedding" in dic
integral_rescale = dic.get("integral_rescale", "auto")
siv = dic.get("symmetrize_ints", 1E-10)
init_mps_center = int(dic.get("init_mps_center", 0))
# A bare "symmetrize_ints" keyword (empty value) selects the default 1E-10.
symmetrize_ints_tol = 1E-10 if siv == "" else float(siv)
nevpt_symmetrize_ints = float(dic.get("nevpt_symmetrize_ints", 1E-12))
dynamic_corr_method = None
condense_mpo = int(dic.get("condense_mpo", "1"))
# Dynamic-correlation method: the first keyword from this list that appears in
# the input wins. Its value is parsed as whitespace-separated integers, and the
# short aliases mrci / nevpt2 / mrrept2 are expanded to their "sd" forms.
sub_spaces = ['ijrs', 'ij', 'rs', 'ijr', 'rsi', 'ir', 'i', 'r']
for dyn_key in ["dmrgfci", "mrci", "mrcis", "mrcisd", "mrcisdt",
                "casci", "nevpt2", "nevpt2s", "nevpt2sd",
                "mrrept2", "mrrept2s", "mrrept2sd",
                *["nevpt2-" + x for x in sub_spaces],
                *["mrrept2-" + x for x in sub_spaces]]:
    if dyn_key in dic:
        dynamic_corr_method = [dyn_key, [int(x) for x in dic[dyn_key].split()]]
        if dynamic_corr_method[0] == 'mrci':
            dynamic_corr_method[0] = 'mrcisd'
        elif dynamic_corr_method[0] == 'nevpt2':
            dynamic_corr_method[0] = 'nevpt2sd'
        elif dynamic_corr_method[0] == 'mrrept2':
            dynamic_corr_method[0] = 'mrrept2sd'
        break
ghamil = None
big_site_method = dic.get("big_site", None)
n_cas = 0
qc_mpo_trans_center = -1
# QC MPO construction type, in order of precedence:
#   1. an explicit qc_mpo_type (conventional / nc / cn);
#   2. simple_parallel -> Conventional, with qc_mpo_trans_center set to -2;
#   3. plain DMRG or dmrgfci -> Conventional;
#   4. any other dynamic-correlation method -> NC.
if dic.get("qc_mpo_type", "auto") != "auto":
    qctstr = dic.get("qc_mpo_type", "auto")
    if qctstr == "conventional":
        qc_type = QCTypes.Conventional
    elif qctstr == "nc":
        qc_type = QCTypes.NC
    elif qctstr == "cn":
        qc_type = QCTypes.CN
    else:
        raise RuntimeError("invalid qc_mpo_type: %s" % qctstr)
elif "simple_parallel" in dic:
    qc_type = QCTypes.Conventional
    if qc_mpo_trans_center == -1:
        qc_mpo_trans_center = -2
elif dynamic_corr_method is None or dynamic_corr_method[0] == "dmrgfci":
    qc_type = QCTypes.Conventional
else:
    qc_type = QCTypes.NC

_print('qc mpo type = ', qc_type)

# Truncation scheme: "physical" (default) -> Physical; "keep <n>" -> KeepOne
# multiplied by n; any other value -> Reduced. The "real_density_matrix"
# keyword additionally ORs in the RealDensityMatrix flag.
if dic.get("trunc_type", "physical") != "physical":
    # Past the default check, so the key is necessarily present in dic.
    if dic["trunc_type"].startswith("keep "):
        trunc_type = TruncationTypes.KeepOne * int(
            dic["trunc_type"][len("keep "):].strip())
    else:
        trunc_type = TruncationTypes.Reduced
else:
    trunc_type = TruncationTypes.Physical
if "real_density_matrix" in dic:
    trunc_type = trunc_type | TruncationTypes.RealDensityMatrix

# Decomposition: density matrix by default, any other value means SVD.
decomp_type = (DecompositionTypes.DensityMatrix
               if dic.get("decomp_type", "density_matrix") == "density_matrix"
               else DecompositionTypes.SVD)
# Time evolution: RK4 by default, any other value means TangentSpace.
te_type = (TETypes.RK4 if dic.get("te_type", "rk4") == "rk4"
           else TETypes.TangentSpace)

# Expectation-value algorithm. The "-npy" variants match their plain
# counterparts but omit the Compressed flag.
expt_algo_name = dic.get("expt_algo_type", "auto")
if expt_algo_name == "auto":
    algo_type = ExpectationAlgorithmTypes.Automatic
elif expt_algo_name == "fast":
    algo_type = ExpectationAlgorithmTypes.Fast
elif expt_algo_name == "normal":
    algo_type = ExpectationAlgorithmTypes.Normal
elif expt_algo_name == "symbolfree":
    algo_type = (ExpectationAlgorithmTypes.SymbolFree
                 | ExpectationAlgorithmTypes.Compressed)
elif expt_algo_name == "lowmem":
    algo_type = (ExpectationAlgorithmTypes.SymbolFree
                 | ExpectationAlgorithmTypes.LowMem
                 | ExpectationAlgorithmTypes.Compressed)
elif expt_algo_name == "symbolfree-npy":
    algo_type = ExpectationAlgorithmTypes.SymbolFree
elif expt_algo_name == "lowmem-npy":
    algo_type = (ExpectationAlgorithmTypes.SymbolFree
                 | ExpectationAlgorithmTypes.LowMem)
else:
    raise RuntimeError("Unknown expectation algo type: %s" % expt_algo_name)

# has_tran: a transition density matrix, "oh" or compression job is requested.
has_tran = any(key in dic for key in (
    "restart_tran_onepdm", "tran_onepdm",
    "restart_tran_twopdm", "tran_twopdm",
    "restart_tran_threepdm", "tran_threepdm",
    "restart_tran_fourpdm", "tran_fourpdm",
    "restart_tran_oh", "tran_oh", "compression"))
# has_2pdm: a (transition) two-particle density matrix is requested.
has_2pdm = any(key in dic for key in (
    "restart_tran_twopdm", "tran_twopdm", "restart_twopdm", "twopdm"))
# has_1npc: correlation / diagonal-2pdm quantities are requested.
has_1npc = any(key in dic for key in (
    "restart_correlation", "correlation", "restart_diag_twopdm", "diag_twopdm"))
anti_herm = "orbital_rotation" in dic
one_body_only = anti_herm or "one_body_parallel_rule" in dic
complex_mps = "complex_mps" in dic
full_integral = "full_integral" in dic
XExpect = Expect if not complex_mps else ComplexExpect

# MPO simplification rule: wrapped in NoTransposeRule for transition
# quantities and (hybrid) complex runs, and in AntiHermitianRuleQC when an
# orbital rotation is requested.
simpl_rule = RuleQC()
if has_tran or "use_complex" in dic or "use_hybrid_complex" in dic:
    simpl_rule = NoTransposeRule(simpl_rule)
if anti_herm:
    simpl_rule = AntiHermitianRuleQC(simpl_rule)

if "use_hybrid_complex" in dic:
    # Hybrid mode: a real (double) MPO is used together with a complex MPO
    # built from the complex FCIDUMP `fd_cpx_h1e` read later. The block2.cpx
    # classes needed for that second MPO are bound here under *CPX names.
    assert "use_complex" not in dic
    assert "single_prec" not in dic # todo
    _print("USE HYBRID COMPLEX MPO")
    cached_contraction = False
    from block2.cpx import FCIDUMP as FCIDUMPCPX
    import importlib
    # Select the symmetry sub-module once instead of repeating the same
    # import list per branch. Precedence is unchanged: k symmetry first
    # (szk / su2k), then general spin (sgf), otherwise sz / su2.
    if "k_symmetry" in dic:
        _cpx_sym_name = "szk" if "nonspinadapted" in dic else "su2k"
    elif "use_general_spin" in dic:
        _cpx_sym_name = "sgf"
    else:
        _cpx_sym_name = "sz" if "nonspinadapted" in dic else "su2"
    _cpx_sym_mod = importlib.import_module("block2.cpx." + _cpx_sym_name)
    RuleQCCPX = _cpx_sym_mod.RuleQC
    NoTransposeRuleCPX = _cpx_sym_mod.NoTransposeRule
    AntiHermitianRuleQCCPX = _cpx_sym_mod.AntiHermitianRuleQC
    MovingEnvironmentX = _cpx_sym_mod.MovingEnvironmentX
    HamiltonianQCCPX = _cpx_sym_mod.HamiltonianQC
    ParallelMPOCPX = _cpx_sym_mod.ParallelMPO
    MPOQCCPX = _cpx_sym_mod.MPOQC
    MPOCPX = _cpx_sym_mod.MPO
    SimplifiedMPOCPX = _cpx_sym_mod.SimplifiedMPO
    ParallelRuleQCCPX = _cpx_sym_mod.ParallelRuleQC
    ParallelRuleOneBodyQCCPX = _cpx_sym_mod.ParallelRuleOneBodyQC
    ParallelRuleSimpleCPX = _cpx_sym_mod.ParallelRuleSimple
    ParallelFCIDUMPCPX = _cpx_sym_mod.ParallelFCIDUMP
    # Simplification rule for the complex part (NoTranspose only for
    # transition quantities here, as in the original setup).
    simpl_rule_cpx_h1e = RuleQCCPX()
    if has_tran:
        simpl_rule_cpx_h1e = NoTransposeRuleCPX(simpl_rule_cpx_h1e)
    if anti_herm:
        simpl_rule_cpx_h1e = AntiHermitianRuleQCCPX(simpl_rule_cpx_h1e)

# Scratch directory for MPS files: mps_dir if given, else prefix. In
# StackBlock / OpenMolcas compatibility mode it is a "node0" subdirectory
# ("./node0" when neither mps_dir nor the "prefix" keyword was given).
scratch = prefix if mps_dir is None else mps_dir
if stackblock_compat or openmolcas_compat:
    if mps_dir is None and "prefix" not in dic:
        scratch = "./node0"
    else:
        scratch = scratch + "/node0"
os.environ['TMPDIR'] = scratch

# The root rank creates every directory first; the other ranks then ensure
# the directories they need exist (e.g. on node-local file systems).
# exist_ok=True removes the check-then-create race between ranks sharing a
# node or file system (the old isdir()/makedirs() pair could raise
# FileExistsError when two ranks created the same path concurrently).
if MPI is None or MPI.rank == 0:
    os.makedirs(prefix, exist_ok=True)
    for _opt_dir in (restart_dir, mps_dir):
        if _opt_dir is not None:
            os.makedirs(_opt_dir, exist_ok=True)
    os.makedirs(scratch, exist_ok=True)
if MPI is not None:
    MPI.barrier()

if MPI is not None and MPI.rank != 0:
    os.makedirs(prefix, exist_ok=True)
    os.makedirs(scratch, exist_ok=True)

if MPI is not None:
    MPI.barrier()

# global settings
def _parse_memory_gb(mem_str):
    """Convert a memory keyword value given in GB to a number of bytes.

    Only the first whitespace-separated token is used and any 'g'/'G' in it
    is removed, so "2", "2g", "2G" and "2 g" all mean 2 GB. Fractional values
    such as "1.5g" are accepted.
    """
    return int(float(mem_str.split()[0].lower().replace('g', '')) * 1e9)


memory = _parse_memory_gb(dic.get("mem", "2"))
mem_ratio = float(dic.get("mem_ratio", 0.4))
min_mpo_mem = dic.get("min_mpo_mem", "auto")
fp_cps_cutoff = float(dic.get("fp_cps_cutoff", 1E-16))
if "intmem" in dic:
    # same unit handling as "mem" (previously a 'g' suffix crashed here)
    intmemory = _parse_memory_gb(dic["intmem"])
    init_memory(isize=int(intmemory), dsize=int(memory),
                save_dir=prefix, dmain_ratio=mem_ratio)
else:
    # without "intmem", 10% of "mem" goes to the integer stack
    init_memory(isize=int(memory * 0.1),
                dsize=int(memory * 0.9), save_dir=prefix,
                dmain_ratio=mem_ratio)
# ZHC NOTE nglobal_threads, nop_threads, MKL_NUM_THREADS
Global.threading = Threading(
    ThreadingTypes.OperatorBatchedGEMM | ThreadingTypes.Global,
    n_threads * mkl_threads, n_threads, mkl_threads)
Global.threading.seq_type = SeqTypes.Tasked if big_site_method is None else SeqTypes.Nothing
gframe = Global.frame if "single_prec" not in dic else Global.frame_float
gframe.fp_codec = FPCodec(fp_cps_cutoff, 1024)
gframe.load_buffering = False
gframe.save_buffering = False
gframe.use_main_stack = False
gframe.minimal_disk_usage = True
if mps_dir is not None:
    gframe.mps_dir = mps_dir
if restart_dir is not None:
    gframe.restart_dir = restart_dir
if restart_dir_per_sweep is not None:
    gframe.restart_dir_per_sweep = restart_dir_per_sweep
_print(gframe)
_print(Global.threading)

# MPI distribution rules. The Hamiltonian rule is either the "simple" scheme
# (mode taken from the "simple_parallel" value, "IJ" when empty) or the
# standard QC rule; the one-body, PDM and identity rules are always built.
if MPI is not None:
    if "simple_parallel" not in dic:
        prule = ParallelRuleQC(MPI)
    else:
        spara_name = dic["simple_parallel"].upper() or "IJ"
        spara_type = getattr(ParallelSimpleTypes, spara_name)
        prule = ParallelRuleSimple(spara_type, MPI)
    prule_one_body, prule_pdm1, prule_pdm2, prule_ident = (
        ParallelRuleOneBodyQC(MPI),
        ParallelRulePDM1QC(MPI),
        ParallelRulePDM2QC(MPI),
        ParallelRuleIdentity(MPI),
    )

# Complex counterparts of the Hamiltonian / one-body MPI rules, needed only
# for the complex MPO in hybrid-complex runs.
if MPI is not None and "use_hybrid_complex" in dic:
    prule_cpx_h1e = (ParallelRuleSimpleCPX(spara_type, MPI)
                     if "simple_parallel" in dic
                     else ParallelRuleQCCPX(MPI))
    prule_one_body_cpx_h1e = ParallelRuleOneBodyQCCPX(MPI)

def read_fcidump_header(fn):
    """Split a text FCIDUMP file into its parsed header and the integral text.

    The whole file is lower-cased. The header ends at the first namelist
    terminator, ``/`` or ``&end`` (whichever appears first). Its leading
    ``&fci`` token is dropped and the remaining ``key=value`` entries are
    collected. A value spread over several comma-separated tokens (e.g.
    ``orbsym=1,1,2,``) is returned as a list of strings; a single value stays
    a string.

    Args:
        fn : str
            Path of the FCIDUMP file.

    Returns:
        cont_dict : dict
            Lower-case header keys mapped to a str or a list of str.
        ints : str
            Lower-cased text following the header terminator.

    Raises:
        RuntimeError : if the file has no ``/`` or ``&end`` terminator, or if
            a header value appears before any ``key=``.
    """
    with open(fn, 'r') as f:
        ff = f.read().lower()

    # Split only at the earliest terminator. A '/' appearing after an '&end'
    # header (e.g. in a trailing comment) cannot corrupt the header, and
    # several '/' no longer break the unpacking.
    ends = [(ff.find(term), term) for term in ('/', '&end') if term in ff]
    if len(ends) == 0:
        raise RuntimeError(
            "FCIDUMP header terminator ('/' or '&end') not found in %s" % fn)
    pos, term = min(ends)
    pars, ints = ff[:pos], ff[pos + len(term):]

    # Drop the leading "&fci" token and treat whitespace like commas.
    cont = ','.join(pars.split()[1:]).split(',')
    cont_dict = {}
    p_key = None
    for c in cont:
        c = c.strip()
        if '=' in c:
            p_key, b = c.split('=', 1)
            p_key = p_key.strip()
            cont_dict[p_key] = b.strip()
        elif len(c) == 0:
            continue
        elif p_key is None:
            raise RuntimeError(
                "FCIDUMP header value '%s' has no key in %s" % (c, fn))
        elif len(cont_dict[p_key]) != 0:
            # continuation of a multi-valued entry such as orbsym
            cont_dict[p_key] += ',' + c
        else:
            # "key=" whose value is the next token
            cont_dict[p_key] = c

    for k, v in cont_dict.items():
        if ',' in v:
            cont_dict[k] = v.split(',')

    return cont_dict, ints

def read_fock_fcidump(fn):
    """Read a one-body (Fock-like) matrix stored in FCIDUMP text format.

    The matrix dimension is the ``nact`` header entry. Each non-empty line of
    the integral section that does not start with '!' is either
    ``value i j`` (real) or ``re im i j`` (complex), with 1-based indices.
    If any complex line is present the result is complex128, otherwise float.
    Every entry is written to both ``[i-1, j-1]`` and ``[j-1, i-1]``.
    NOTE(review): for complex entries the transposed element receives the same
    value rather than its conjugate -- confirm this is the intended convention.

    Returns:
        (n_sites, fock) with ``fock`` of shape ``(n_sites, n_sites)``.
    """
    header, body = read_fcidump_header(fn)
    n_sites = int(header.get('nact'))

    entries = []
    is_complex = False
    for raw_line in body.split('\n'):
        line = raw_line.strip()
        if not line or line.startswith('!'):
            continue
        fields = line.split()
        assert len(fields) in (3, 4)
        if len(fields) == 4:
            is_complex = True
            value = float(fields[0]) + float(fields[1]) * 1j
            idx_fields = fields[2:]
        else:
            value = float(fields[0])
            idx_fields = fields[1:]
        p, q = (int(x) for x in idx_fields)
        entries.append((p, q, value))

    fock = np.zeros((n_sites, n_sites), dtype=np.complex128 if is_complex else float)
    for p, q, value in entries:
        fock[p - 1, q - 1] = value
        fock[q - 1, p - 1] = value

    return n_sites, fock

def read_nevpt2_compress_fcidump(fn, ncas, ncore, nvirt):
    """Read the sectioned FCIDUMP used when restarting a compressed NEVPT2 MPS.

    The header is parsed with ``read_fcidump_header``. Integral lines are
    ``value i j k l`` (real) or ``re im i j k l`` (complex); a single complex
    line makes all returned arrays complex. Lines whose four indices are all
    zero act as section separators. The first one supplies ``const_e``; later
    ones must carry a zero value (asserted). ``ip`` counts the separators
    seen so far and selects how a non-separator line is stored:

    * ``ip == 0``: one-body terms (``k == l == 0``) with ``i <= ncas`` go to
      ``h1e`` (filled symmetrically). Other diagonal one-body terms
      (``i == j``) go to ``orbe[i - 1 - ncas]``. Everything else must have
      nonzero ``k`` and ``l`` (asserted) and is stored in ``g2e`` with 8-fold
      permutational symmetry.
    * ``ip == 1``: ``g2e_sr``, first index offset by ``ncas + ncore``.
    * ``ip == 2``: ``g2e_si``, second index offset by ``ncas``.
    * ``ip == 3``: ``h1e_sr``, first index offset by ``ncas + ncore``.
    * ``ip == 4``: ``h1e_si``, second index offset by ``ncas``.
    * ``ip == 5``: any non-separator line fails an assertion.
    * ``ip > 5``: non-separator lines are ignored.

    The index offsets suggest orbitals are numbered active (1..ncas), then
    core, then virtual. NOTE(review): confirm against the writer of this file.

    Args:
        fn : str
            Path to the FCIDUMP file.
        ncas, ncore, nvirt : int
            Number of active, core and virtual orbitals.

    Returns:
        Tuple ``(n_sites, twos, ipg, n_elec, orb_sym, h1e, g2e, const_e, orbe,
        h1e_si, h1e_sr, g2e_si, g2e_sr)``.

        * ``n_sites``/``twos``/``ipg``/``n_elec`` come from the header keys
          ``norb``/``ms2``/``isym``/``nelec``; the last three default to 0.
        * ``orb_sym`` is the ``orbsym`` entry converted to ints.
        * The arrays have shapes ``(ncas, ncas)``, ``(ncas,) * 4``,
          ``(ncore + nvirt,)``, ``(ncas, ncore)``, ``(nvirt, ncas)``,
          ``(ncas, ncore, ncas, ncas)`` and ``(nvirt, ncas, ncas, ncas)``.
        * ``const_e`` is a scalar.

        NOTE(review): a single-valued ``orbsym`` is a str and would be iterated
        per character, and a missing ``orbsym`` raises TypeError.
    """

    cont_dict, ints = read_fcidump_header(fn)

    n_sites = int(cont_dict.get('norb'))
    twos = int(cont_dict.get('ms2', 0))
    ipg = int(cont_dict.get('isym', 0))
    n_elec = int(cont_dict.get('nelec', 0))
    orb_sym = [int(i) for i in cont_dict.get('orbsym')]

    # parse all integral lines first; dtype switches to complex on the
    # first 6-field (re, im, i, j, k, l) line
    data = []
    dtype = float
    for l in ints.split('\n'):
        ll = l.strip()
        if len(ll) == 0 or ll.strip()[0] == '!':
            continue
        ll = ll.split()
        assert len(ll) == 5 or len(ll) == 6
        if len(ll) == 5:
            d = float(ll[0])
            i, j, k, l = [int(x) for x in ll[1:]]
        else:
            dtype = np.complex128
            d = float(ll[0]) + float(ll[1]) * 1j
            i, j, k, l = [int(x) for x in ll[2:]]
        data.append((i, j, k, l, d))
    # ip: number of all-zero-index separator lines seen so far
    ip = 0
    h1e = np.zeros((ncas, ) * 2, dtype=dtype)
    g2e = np.zeros((ncas, ) * 4, dtype=dtype)
    h1e_sr = np.zeros((nvirt, ncas), dtype=dtype)
    h1e_si = np.zeros((ncas, ncore), dtype=dtype)
    g2e_sr = np.zeros((nvirt, ncas, ncas, ncas), dtype=dtype)
    g2e_si = np.zeros((ncas, ncore, ncas, ncas), dtype=dtype)
    orbe = np.zeros((ncore + nvirt, ) * 1, dtype=dtype)
    const_e = 0.0
    for i, j, k, l, d in data:
        if i + j + k + l == 0:
            # separator: only the first one carries a (constant) value
            if ip == 0:
                const_e = d
            else:
                assert d == 0
            ip += 1
        elif ip == 0:
            if k + l == 0 and i - 1 < ncas:
                h1e[i - 1, j - 1] = d
                h1e[j - 1, i - 1] = d
            elif k + l == 0 and i == j:
                orbe[i - 1 - ncas] = d
            else:
                assert k != 0 and l != 0
                # all 8 index permutations of (ij|kl)
                g2e[i - 1, j - 1, k - 1, l - 1] = d
                g2e[j - 1, i - 1, k - 1, l - 1] = d
                g2e[j - 1, i - 1, l - 1, k - 1] = d
                g2e[i - 1, j - 1, l - 1, k - 1] = d
                g2e[k - 1, l - 1, i - 1, j - 1] = d
                g2e[k - 1, l - 1, j - 1, i - 1] = d
                g2e[l - 1, k - 1, j - 1, i - 1] = d
                g2e[l - 1, k - 1, i - 1, j - 1] = d
        elif ip == 1:
            g2e_sr[i - 1 - ncas - ncore, j - 1, k - 1, l - 1] = d
        elif ip == 2:
            g2e_si[i - 1, j - 1 - ncas, k - 1, l - 1] = d
        elif ip == 3:
            assert k == 0 and l == 0
            h1e_sr[i - 1 - ncas - ncore, j - 1] = d
        elif ip == 4:
            assert k == 0 and l == 0
            h1e_si[i - 1, j - 1 - ncas] = d
        elif ip == 5:
            # separators were handled above, so this always fails: no
            # integrals are allowed after the fifth separator
            assert i + j + k + l == 0
    
    return n_sites, twos, ipg, n_elec, orb_sym, h1e, g2e, const_e, orbe, h1e_si, h1e_sr, g2e_si, g2e_sr

# prepare hamiltonian
if pre_run or not no_pre_run:
    nelec = [int(x) for x in dic["nelec"].split()]
    spin = [int(x) for x in dic.get("spin", "0").split()]
    isym = [int(x) for x in dic.get("irrep", "1").split()]
    iksym = [int(x) for x in dic.get("k_irrep", "0").split()]
    # Build `fcidump` (and `orb_sym` when it is not derived from the FCIDUMP
    # later) from one of: a saved orbital-rotation generator, a model
    # Hamiltonian keyword, or an integral file ("orbitals").
    if "orbital_rotation" in dic:
        orb_sym = np.load(scratch + "/nat_orb_sym.npy")
        if "k_symmetry" in dic:
            orb_sym = VectorUInt32(orb_sym)
        else:
            orb_sym = VectorUInt8(orb_sym)
        kappa = np.load(scratch + "/nat_kappa.npy")
        kappa = kappa.ravel()
        n_sites = len(orb_sym)
        fcidump = FCIDUMP()
        fcidump.initialize_h1e(n_sites, nelec[0], spin[0], isym[0], 0.0, kappa)
        assert "nofiedler" in dic or "noreorder" in dic
        if "target_t" not in dic:
            dic["target_t"] = "1"
    elif "model" in dic:  # model hamiltonians
        fmods = dic["model"].split()
        if fmods[0] in ["hubbard", "hubbard_periodic", "hubbard_kspace", "hubbard_rspace"]:
            assert len(fmods) in [4, 5]
            n_sites, const_t, const_u = int(
                fmods[1]), float(fmods[2]), float(fmods[3])
            if len(fmods) == 5 and fmods[4] == "per-site":
                const_t /= n_sites
                const_u /= n_sites
            _print("1D %s model : L = %d T = %.5f U = %.5f" %
                   (fmods[0], n_sites, const_t, const_u))
            if fmods[0] == "hubbard_kspace":
                fcidump = HubbardKSpaceFCIDUMP(n_sites, const_t, const_u)
            else:
                fcidump = HubbardFCIDUMP(n_sites, const_t, const_u,
                                         fmods[0] in ["hubbard_periodic", "hubbard_rspace"])
            orb_sym = None
        else:
            # fmods[0] is the model name (a str); '%d' here used to raise
            # TypeError instead of this error
            raise RuntimeError("Model %s not supported!" % fmods[0])
    else:
        orb_sym = None
        fints = dic["orbitals"]
        # Files starting with b'\x89HDF' go through read_integral; everything
        # else is read as a text FCIDUMP. The probe uses a context manager so
        # the file handle is closed.
        with open(fints, 'rb') as fints_probe:
            fints_is_hdf = fints_probe.read(4) == b'\x89HDF'
        if not fints_is_hdf:
            # separate fcidump into real and complex parts
            if "use_hybrid_complex" in dic:
                fd_cpx = FCIDUMPCPX()
                fd_cpx.read(fints)
                fh1e = np.array(fd_cpx.h1e_matrix())
                rg2e = np.array(fd_cpx.g2e_1fold())
                assert np.abs(np.linalg.norm(np.imag(rg2e))) < 1E-20
                assert np.abs(np.imag(fd_cpx.const_e)) < 1E-20
                rg2e = np.real(rg2e).copy() # make contig
                rh1e = np.real(fh1e).copy()
                ch1e = fh1e.copy()
                # h1e elements with an imaginary part go (as full complex
                # values) into the complex FCIDUMP; the rest stay real
                rh1e[np.abs(np.imag(fh1e)) >= 1E-20] = 0.0
                ch1e[np.abs(np.imag(fh1e)) < 1E-20] = 0.0
                assert not fd_cpx.uhf
                fcidump = FCIDUMP()
                if fd_cpx.uhf:
                    fcidump.initialize_sz(
                        fd_cpx.n_sites, fd_cpx.n_elec, fd_cpx.twos, fd_cpx.isym, np.real(fd_cpx.const_e), rh1e, rg2e)
                else:
                    fcidump.initialize_su2(
                        fd_cpx.n_sites, fd_cpx.n_elec, fd_cpx.twos, fd_cpx.isym, np.real(fd_cpx.const_e), rh1e, rg2e)
                fcidump.orb_sym = fd_cpx.orb_sym
                fd_cpx_h1e = FCIDUMPCPX()
                fd_cpx_h1e.initialize_h1e(fd_cpx.n_sites, fd_cpx.n_elec, fd_cpx.twos, fd_cpx.isym, 1j * np.imag(fd_cpx.const_e), ch1e)
                fd_cpx_h1e.orb_sym = fd_cpx.orb_sym
            elif "restart_mps_nevpt" in dic:
                nevpt_ncas, nevpt_ncore, nevpt_nvirt = [int(x) for x in dic["restart_mps_nevpt"].split()]
                fcidump = FCIDUMP()
                n_sites, twos, ipg, n_elec, orb_sym, nevpt_h1e, nevpt_g2e, nevpt_const_e, orbe, h1e_si, h1e_sr, g2e_si, g2e_sr \
                    = read_nevpt2_compress_fcidump(fints, nevpt_ncas, nevpt_ncore, nevpt_nvirt)
                fcidump.initialize_su2(
                    nevpt_ncas, n_elec, twos, ipg, np.real(nevpt_const_e), nevpt_h1e.ravel(), nevpt_g2e.flatten())
                orb_sym = VectorUInt8(orb_sym[:nevpt_ncas])
                nevpt_orb_sym = orb_sym
                orb_sym = None
                fcidump.orb_sym = nevpt_orb_sym
                import shutil
                # copy the files of the run two directories up into this
                # run's scratch / prefix (root rank only, others wait)
                if MPI is None or MPI.rank == 0:
                    for k in os.listdir(scratch + "/../../node0"):
                        shutil.copy(scratch + "/../../node0/" + k, scratch + "/" + k)
                    for k in os.listdir(scratch + "/../.."):
                        if not os.path.isdir(scratch + "/../../" + k):
                            shutil.copy(scratch + "/../../" + k, prefix + "/" + k)
                if MPI is not None:
                    MPI.barrier()
            else:
                fcidump = FCIDUMP()
                fcidump.read(fints)
            integral_tol = float(dic.get("integral_tol", 0.0))
            if integral_tol != 0.0:
                int_tc_error = fcidump.truncate_small(integral_tol)
                _print("integral truncation error = ", int_tc_error)
            fcidump.params["nelec"] = str(nelec[0])
            fcidump.params["ms2"] = str(spin[0])
            fcidump.params["isym"] = str(isym[0])
        else:
            integral_tol = float(dic.get("integral_tol", 1E-12))
            fcidump = read_integral(fints, nelec[0], spin[0], isym=isym[0],
                                    tol=integral_tol, is_sp="single_prec" in dic)
        if integral_rescale == "auto" and "single_prec" in dic:
            _print("original integral const = %20.10f" % fcidump.e())
            fcidump.rescale(0)
            _print("rescaled integral const = %20.10f" % fcidump.e())
        elif integral_rescale != "none" and integral_rescale != "auto":
            _print("original integral const = %20.10f" % fcidump.e())
            fcidump.rescale(float(integral_rescale))
            _print("rescaled integral const = %20.10f" % fcidump.e())
    n_orbs = fcidump.n_sites
    if fcidump.uhf and not "nonspinadapted" in dic and not "use_general_spin" in dic:
        _print("WARN: A non-spin-adapted FCIDUMP is given but there is no keyword 'nonspinadapted'!")
    if n_orbs == 2 and qc_type == QCTypes.Conventional:
        qc_type = QCTypes.NC
        _print("WARN: changed qc_mpo_type to NC because there are only 2 sites!")
    if "trans_integral_to_spin_orbital" in dic:
        n_orbs = n_orbs * 2
    if dynamic_corr_method is not None:
        if len(dynamic_corr_method[1]) == 2:
            assert len(nelec) == 1
            n_cas, n_elec_cas = dynamic_corr_method[1]
            assert (nelec[0] - n_elec_cas) % 2 == 0
            n_inactive = (nelec[0] - n_elec_cas) // 2
            n_external = n_orbs - n_inactive - n_cas
        else:
            n_inactive, n_cas, n_external = dynamic_corr_method[1]
            assert n_orbs == n_inactive + n_cas + n_external
        _print("dynamic correlation space : inactive = %d, cas = %d, external = %d"
               % (n_inactive, n_cas, n_external))
    if "fullrestart" in dic and os.path.isfile(scratch + '/orbital_reorder.npy'):
        orb_idx = np.load(scratch + '/orbital_reorder.npy')
        _print("loading reorder for restarting = ", orb_idx)
        fcidump.reorder(VectorUInt16(orb_idx))
    elif "nofiedler" in dic or "noreorder" in dic:
        if dynamic_corr_method is not None and big_site_method == "bigdrt":
            orb_idx = np.arange(0, fcidump.n_sites, dtype=int)
            orb_idx = np.concatenate((orb_idx[(orb_idx >= n_inactive)
                & (orb_idx < n_cas + n_inactive)], orb_idx[orb_idx < n_inactive],
                orb_idx[orb_idx >= n_cas + n_inactive]), axis=0)
            _print("reorder indices adjusted for dynamic correlation = ", orb_idx)
            fcidump.reorder(VectorUInt16(orb_idx))
            np.save(scratch + '/orbital_reorder.npy', orb_idx)
        else:
            orb_idx = None
            np.save(scratch + '/orbital_reorder.npy',
                    np.arange(0, fcidump.n_sites, dtype=int))
    else:
        if "gaopt" in dic:
            orb_idx = orbital_reorder(fcidump, method='gaopt ' + dic["gaopt"])
            _print("using gaopt reorder = ", orb_idx)
        elif "reorder" in dic:
            orb_idx = orbital_reorder(
                fcidump, method='manual ' + dic["reorder"])
            _print("using manual reorder = ", orb_idx)
        elif "irrep_reorder" in dic:
            orb_idx = orbital_reorder(
                fcidump, method='irrep ' + dic.get("sym", "d2h"))
            _print("using irrep reorder = ", orb_idx)
            _print("reordered irrep = ", fcidump.orb_sym)
        else:
            orb_idx = orbital_reorder(fcidump, method='fiedler')
            _print("using fiedler reorder = ", orb_idx)
        if dynamic_corr_method is not None:
            if big_site_method != "bigdrt":
                orb_idx = np.concatenate((orb_idx[orb_idx < n_inactive],
                    orb_idx[(orb_idx >= n_inactive) & (orb_idx < n_cas + n_inactive)],
                    orb_idx[orb_idx >= n_cas + n_inactive]), axis=0)
            else:
                orb_idx = np.concatenate((orb_idx[(orb_idx >= n_inactive)
                    & (orb_idx < n_cas + n_inactive)], orb_idx[orb_idx < n_inactive],
                    orb_idx[orb_idx >= n_cas + n_inactive]), axis=0)
            _print("reorder indices adjusted for dynamic correlation = ", orb_idx)
        fcidump.reorder(VectorUInt16(orb_idx))
        np.save(scratch + '/orbital_reorder.npy', orb_idx)
    if "use_hybrid_complex" in dic and orb_idx is not None:
        fd_cpx_h1e.reorder(VectorUInt16(orb_idx))
    if "full_integral" not in dic and dynamic_corr_method is not None and \
        dynamic_corr_method[0] in ["nevpt2s", "mrcis", "mrrept2s",
            "nevpt2-i", "nevpt2-r", "mrrept2-i", "mrrept2-r"]:
        _print("use mrcis ficdump")
        fcidump = MRCISFCIDUMP(fcidump, n_inactive, n_external)
    if "trans_integral_to_spin_orbital" in dic:
        fcidump = SpinOrbitalFCIDUMP(fcidump)
    if "heisenberg" in dic:
        fcidump = HeisenbergFCIDUMP(fcidump)

    swap_pg = getattr(PointGroup, "swap_" + dic.get("sym", "d2h"))

    _print("read integral finished", time.perf_counter() - tx)

    vacuum = SX(0)
    if "k_symmetry" in dic:
        if "k_mod" in dic:
            fcidump.k_mod = int(dic["k_mod"])
            if fcidump.k_mod != 0:
                fcidump.k_sym = VectorInt(
                    [x % fcidump.k_mod for x in fcidump.k_sym])
            fcidump.k_isym = fcidump.k_isym % fcidump.k_mod
            iksym = [x % fcidump.k_mod for x in iksym]
        target = SX(fcidump.n_elec, fcidump.twos,
                    SX.pg_combine(swap_pg(fcidump.isym), fcidump.k_isym, fcidump.k_mod))
    else:
        target = SX(fcidump.n_elec, fcidump.twos, swap_pg(fcidump.isym))
    targets = []
    for inelec in nelec:
        for ispin in spin:
            for iisym in isym:
                if "k_symmetry" in dic:
                    for iiksym in iksym:
                        targets.append(SX(inelec, ispin,
                                          SX.pg_combine(swap_pg(iisym), iiksym, fcidump.k_mod)))
                else:
                    targets.append(SX(inelec, ispin, swap_pg(iisym)))
    targets = VectorSL(targets)
    if len(targets) == 0:
        targets = VectorSL([target])
    if singlet_embedding and SX == SU2:
        for it, targ in enumerate(targets):
            if targ.twos != 0:
                targets[it] = SX(targ.n + targ.twos, 0, targ.pg)
        assert len(spin) == 1
        singlet_embedding_spin = spin[0]
    if len(targets) == 1:
        target = targets[0]
    n_sites = n_orbs
    if orb_sym is None:
        orb_sym = VectorUInt8(map(swap_pg, fcidump.orb_sym))
        for x in orb_sym:
            if x == 8:
                raise RuntimeError("Wrong point group symmetry : ", dic.get("sym", "d2h"))
    sym_error = fcidump.symmetrize(orb_sym)
    _print("integral sym error = %12.4g" % sym_error)
    if "k_symmetry" in dic:
        pure_pg_sym = orb_sym
        k_sym = fcidump.k_sym
        k_mod = fcidump.k_mod
        k_sym_error = fcidump.symmetrize(k_sym, k_mod)
        _print("integral k sym error = %12.4g" % k_sym_error)
        sym_error += k_sym_error
        orb_sym = HamiltonianQC.combine_orb_sym(orb_sym, k_sym, k_mod)
    if sym_error > symmetrize_ints_tol:
        raise RuntimeError(("Integral symmetrization error larger than %10.5g, "
                            + "please check point group symmetry and FCIDUMP or set"
                            + " a higher tolerance for the keyword '%s'") % (
            symmetrize_ints_tol, "symmetrize_ints"))
    hamil_np = None
    hamil_cpx_h1e = None
    if big_site_method is None:
        if "simple_parallel" in dic:
            if qc_mpo_trans_center == -2:
                qc_mpo_trans_center = fcidump.n_sites / (1 + 1.0 / np.sqrt(prule.comm.size))
                qc_mpo_trans_center = int(qc_mpo_trans_center)
                if qc_mpo_trans_center >= fcidump.n_sites - 2:
                    qc_type = QCTypes.NC if spara_name == "IJ" else QCTypes.CN
                elif spara_name == "KL":
                    qc_mpo_trans_center = fcidump.n_sites - 1 - qc_mpo_trans_center
            hamil = HamiltonianQC(vacuum, n_sites, orb_sym, ParallelFCIDUMP(fcidump, prule))
            _print('simple parallel : mode =', spara_name, 'mpo type =', qc_type,
                'center =', qc_mpo_trans_center, "/", fcidump.n_sites)
        else:
            hamil = HamiltonianQC(vacuum, n_sites, orb_sym, fcidump)
        rhamil = hamil
        ghamil = GeneralHamiltonian(vacuum, n_sites, orb_sym)
        if "use_hybrid_complex" in dic:
            if "simple_parallel" in dic:
                hamil_cpx_h1e = HamiltonianQCCPX(vacuum, n_sites, orb_sym, ParallelFCIDUMPCPX(fd_cpx_h1e, prule))
            else:
                hamil_cpx_h1e = HamiltonianQCCPX(vacuum, n_sites, orb_sym, fd_cpx_h1e)
    elif big_site_method == "folding":
        hamil = HamiltonianQC(vacuum, n_sites, orb_sym, fcidump)
        mpo_fold = MPOQC(hamil, qc_type)
        assert dynamic_corr_method is not None
        if dynamic_corr_method[0] in ["casci"]:
            mps_info_fold = CASCIMPSInfo(
                n_orbs, vacuum, target, hamil.basis, n_inactive, n_cas, n_external)
        elif dynamic_corr_method[0] in ["mrcis", "mrcisd", "mrcisdt"]:
            ci_order = len(dynamic_corr_method[0]) - 4
            mps_info_fold = MRCIMPSInfo(
                n_orbs, n_inactive, n_external, ci_order, vacuum, target, hamil.basis)
        elif dynamic_corr_method[0] in ["nevpt2sd", "nevpt2s"]:
            ci_order = len(dynamic_corr_method[0]) - 6
            mps_info_fold = MRCIMPSInfo(
                n_orbs, n_inactive, n_external, ci_order, vacuum, target, hamil.basis)
        elif dynamic_corr_method[0] in ["mrrept2sd", "mrrept2s"]:
            ci_order = len(dynamic_corr_method[0]) - 7
            mps_info_fold = MRCIMPSInfo(
                n_orbs, n_inactive, n_external, ci_order, vacuum, target, hamil.basis)
        elif dynamic_corr_method[0] in ["nevpt2-ijrs", "nevpt2-ij", "nevpt2-rs", "nevpt2-ijr", "nevpt2-rsi",
                                        "nevpt2-ir", "nevpt2-i", "nevpt2-r"]:
            sub_space = dynamic_corr_method[0][7:]
            n_ex_inactive = sub_space.count('i') + sub_space.count('j')
            n_ex_external = sub_space.count('r') + sub_space.count('s')
            mps_info_fold = NEVPTMPSInfo(
                n_orbs, n_inactive, n_external, n_ex_inactive, n_ex_external, vacuum, target, hamil.basis)
        elif dynamic_corr_method[0] in ["mrrept2-ijrs", "mrrept2-ij", "mrrept2-rs", "mrrept2-ijr", "mrrept2-rsi",
                                        "mrrept2-ir", "mrrept2-i", "mrrept2-r"]:
            sub_space = dynamic_corr_method[0][8:]
            n_ex_inactive = sub_space.count('i') + sub_space.count('j')
            n_ex_external = sub_space.count('r') + sub_space.count('s')
            mps_info_fold = NEVPTMPSInfo(
                n_orbs, n_inactive, n_external, n_ex_inactive, n_ex_external, vacuum, target, hamil.basis)
        for i in range(n_external - 1):
            _print("fold right %d / %d" % (i, n_external))
            mpo_fold = FusedMPO(mpo_fold, hamil.basis, mpo_fold.n_sites - 2,
                                mpo_fold.n_sites - 1, mps_info_fold.right_dims_fci[mpo_fold.n_sites - 2])
            hamil.basis = mpo_fold.basis
            hamil.n_sites = mpo_fold.n_sites
        for i in range(n_inactive - 1):
            _print("fold left %d / %d" % (i, n_inactive))
            mpo_fold = FusedMPO(mpo_fold, hamil.basis, 0, 1,
                                mps_info_fold.left_dims_fci[i + 2])
            hamil.basis = mpo_fold.basis
            hamil.n_sites = mpo_fold.n_sites
        for k, op in mpo_fold.tensors[0].ops.items():
            smat = CSRSparseMatrix()
            if op.sparsity() > 0.75:
                smat.from_dense(op)
                op.deallocate()
            else:
                smat.wrap_dense(op)
            mpo_fold.tensors[0].ops[k] = smat
        mpo_fold.sparse_form = 'S' + mpo_fold.sparse_form[1:]
        mpo_fold.tf = TensorFunctions(CSROperatorFunctions(hamil.opf.cg))
        for k, op in mpo_fold.tensors[-1].ops.items():
            smat = CSRSparseMatrix()
            if op.sparsity() > 0.75:
                smat.from_dense(op)
                op.deallocate()
            else:
                smat.wrap_dense(op)
            mpo_fold.tensors[-1].ops[k] = smat
        mpo_fold.sparse_form = mpo_fold.sparse_form[:-1] + 'S'
        mpo_fold.tf = TensorFunctions(CSROperatorFunctions(hamil.opf.cg))
        rhamil = hamil
    else:
        assert dynamic_corr_method is not None
        # Excitation limits for the big sites. Each is a 3-tuple whose entries
        # (judging from the final fallback branch) bound the alpha, beta and
        # total electron changes: xl (<= 0) for the inactive space and
        # xr (>= 0) for the external space.
        if dynamic_corr_method[0] in ['mrcisdt']:
            # excitation level 3
            xl = -3, -3, -3
            xr = 3, 3, 3
        elif dynamic_corr_method[0] in ['mrcisd', 'nevpt2sd', 'mrrept2sd']:
            # excitation level 2
            xl = -2, -2, -2
            xr = 2, 2, 2
        elif dynamic_corr_method[0] in ['mrcis', 'nevpt2s', 'mrrept2s']:
            # excitation level 1
            xl = -1, -1, -1
            xr = 1, 1, 1
        elif dynamic_corr_method[0] in ["nevpt2-ijrs", "nevpt2-ij", "nevpt2-rs", "nevpt2-ijr",
                                        "nevpt2-rsi", "nevpt2-ir", "nevpt2-i", "nevpt2-r",
                                        "mrrept2-ijrs", "mrrept2-ij", "mrrept2-rs", "mrrept2-ijr",
                                        "mrrept2-rsi", "mrrept2-ir", "mrrept2-i", "mrrept2-r"]:
            # this is not correct yet
            # The suffix after "nevpt2-" / "mrrept2-" labels the subspace:
            # every 'i'/'j' raises the inactive limit by one and every
            # 'r'/'s' raises the external limit by one.
            if dynamic_corr_method[0].startswith('nevpt2-'):
                sub_space = dynamic_corr_method[0][7:]  # drop the "nevpt2-" prefix
            else:
                sub_space = dynamic_corr_method[0][8:]  # drop the "mrrept2-" prefix
            n_ex_inactive = sub_space.count('i') + sub_space.count('j')
            n_ex_external = sub_space.count('r') + sub_space.count('s')
            xl = -n_ex_inactive, -n_ex_inactive, -n_ex_inactive
            xr = n_ex_external, n_ex_external, n_ex_external
        elif dynamic_corr_method[0] in ['casci']:
            # no excitations out of the inactive or into the external space
            xl = 0, 0, 0
            xr = 0, 0, 0
        else:
            # Any other method: limits are only bounded by the space sizes and
            # the alpha / beta / total electron counts. Requires exactly one
            # nelec and one spin value with nelec + spin even.
            assert len(nelec) == 1 and len(spin) == 1
            assert (nelec[0] + spin[0]) % 2 == 0
            n_alpha = (nelec[0] + spin[0]) // 2
            n_beta = (nelec[0] - spin[0]) // 2
            xl = -min(n_inactive, n_alpha), -min(n_inactive, n_beta), -min(n_inactive, nelec[0])
            xr = min(n_external, n_alpha), min(n_external, n_beta), min(n_external, nelec[0])
        # Build the raw (not yet simplified) left/right big sites for the
        # chosen big_site_method, using the excitation limits xl / xr.
        if big_site_method == "fock":
            # determinant-based big sites; requires the non-spin-adapted mode
            assert "nonspinadapted" in dic
            # special treatment when all electrons have the same spin
            # (nelec == |spin|): the limit for the other spin is set to 0
            if len(nelec) == 1 and len(spin) == 1 and nelec[0] == abs(spin[0]):
                if nelec[0] == spin[0]:
                    # even indices 2 * i of the inactive orbitals
                    # (presumably the alpha spin-orbitals -- confirm)
                    ref = VectorInt([i + i for i in range(n_inactive)])
                    xl = abs(xl[0]), 0, abs(xl[2])
                else:
                    # odd indices 2 * i + 1 (presumably the beta spin-orbitals)
                    ref = VectorInt([i + i + 1 for i in range(n_inactive)])
                    xl = 0, abs(xl[1]), abs(xl[2])
                poccl = SCIFockBigSite.ras_space(False, n_inactive, *xl, ref)
            else:
                poccl = SCIFockBigSite.ras_space(False, n_inactive, *[abs(x) for x in xl], VectorInt([]))
            poccr = SCIFockBigSite.ras_space(True, n_external, *xr, VectorInt([]))
            # need to include casci ref state for the first step even for nevpt2-r
            big_left_orig = SCIFockBigSite(n_orbs, n_inactive, False, fcidump, orb_sym, poccl, True)
            big_right_orig = SCIFockBigSite(n_orbs, n_external, True, fcidump, orb_sym, poccr, True)
        elif big_site_method == "csf":
            # CSF big sites; requires spin adaptation. Only the total
            # excitation limit (last tuple entry) is used.
            assert "nonspinadapted" not in dic
            big_left_orig = CSFBigSite(n_inactive, abs(xl[-1]), False, fcidump, orb_sym[:n_inactive])
            big_right_orig = CSFBigSite(n_external, abs(xr[-1]), True, fcidump, orb_sym[-n_external:])
        elif big_site_method == "drt":
            # DRT big sites whose target quanta are limited by the total
            # excitation level of each side
            left_iqs = DRTBigSite.get_target_quanta(False, n_inactive, abs(xl[-1]), orb_sym[:n_inactive])
            right_iqs = DRTBigSite.get_target_quanta(True, n_external, abs(xr[-1]), orb_sym[-n_external:])
            big_left_orig = DRTBigSite(left_iqs, False, n_inactive,
                orb_sym[:n_inactive], fcidump, max(outputlevel, 0))
            big_right_orig = DRTBigSite(right_iqs, True, n_external,
                orb_sym[-n_external:], fcidump, max(outputlevel, 0))
        elif big_site_method == "bigdrt":
            # The left big site is empty (0 orbitals); a single right big site
            # spans the inactive + external orbitals, with nc_ref=n_inactive.
            left_iqs = DRTBigSite.get_target_quanta(False, 0, 0, orb_sym[:0])
            right_iqs = DRTBigSite.get_target_quanta(True, n_inactive + n_external, abs(xr[-1]),
                orb_sym[-(n_inactive + n_external):], nc_ref=n_inactive)
            big_left_orig = DRTBigSite(left_iqs, False, 0, orb_sym[:0], fcidump, max(outputlevel, 0))
            big_right_orig = DRTBigSite(right_iqs, True, n_inactive + n_external,
                orb_sym[-(n_inactive + n_external):], fcidump, max(outputlevel, 0))
            # replace the DRT with one that knows the inactive / external
            # split and the excitation limit
            big_right_orig.drt = DRT(
                big_right_orig.drt.n_sites,
                big_right_orig.drt.get_init_qs(),
                big_right_orig.drt.orb_sym, n_inactive, n_external, n_ex=abs(xr[-1]), nc_ref=n_inactive,
            )
            # a second right big site restricted to zero excitations (n_ex=0)
            right_iqs_z = DRTBigSite.get_target_quanta(True, n_inactive + n_external, 0,
                orb_sym[-(n_inactive + n_external):], nc_ref=n_inactive)
            big_right_orig_z = DRTBigSite(right_iqs_z, True, n_inactive + n_external,
                orb_sym[-(n_inactive + n_external):], fcidump, max(outputlevel, 0))
            big_right_orig_z.drt = DRT(
                big_right_orig_z.drt.n_sites,
                big_right_orig_z.drt.get_init_qs(),
                big_right_orig_z.drt.orb_sym, n_inactive, n_external, n_ex=0, nc_ref=n_inactive,
            )
        else:
            raise NotImplementedError
        # Apply the simplification rule simpl_rule to the raw big sites.
        big_left = SimplifiedBigSite(big_left_orig, simpl_rule)
        big_right = SimplifiedBigSite(big_right_orig, simpl_rule)
        if big_site_method == "bigdrt":
            # simplified copy of the zero-excitation site big_right_orig_z
            big_right_z = SimplifiedBigSite(big_right_orig_z, simpl_rule)
        if MPI is not None:
            # Keep the non-parallel sites (*_np) for hamil_np (the property
            # MPOs are later built from `hamil_np or hamil`), then wrap the
            # working sites for MPI; one-body-only runs use prule_one_body.
            big_left_np = big_left
            big_right_np = big_right
            if one_body_only:
                big_left = ParallelBigSite(big_left, prule_one_body)
                big_right = ParallelBigSite(big_right, prule_one_body)
            else:
                big_left = ParallelBigSite(big_left, prule)
                big_right = ParallelBigSite(big_right, prule)
            if big_site_method == "bigdrt":
                big_right_np_z = big_right_z
                if one_body_only:
                    big_right_z = ParallelBigSite(big_right_z, prule_one_body)
                else:
                    big_right_z = ParallelBigSite(big_right_z, prule)
            # A big site is dropped (None) when its orbital space is empty;
            # with bigdrt the left site is always dropped and the right site
            # spans the inactive + external orbitals.
            hamil_np = HamiltonianQCBigSite(vacuum, n_orbs, orb_sym, fcidump,
                    None if n_inactive == 0 or big_site_method == "bigdrt" else big_left_np,
                    None if (n_external == 0 and big_site_method != "bigdrt") or n_inactive + n_external == 0 else big_right_np)
        hamil = HamiltonianQCBigSite(vacuum, n_orbs, orb_sym, fcidump,
                    None if n_inactive == 0 or big_site_method == "bigdrt" else big_left,
                    None if (n_external == 0 and big_site_method != "bigdrt") or n_inactive + n_external == 0 else big_right)
        rhamil = hamil
        if big_site_method == "bigdrt" and dynamic_corr_method[0] in ["casci", "nevpt2s", "nevpt2sd", "nevpt2-ijrs", "nevpt2-ij",
                        "nevpt2-rs", "nevpt2-ijr", "nevpt2-rsi", "nevpt2-ir", "nevpt2-i", "nevpt2-r",
                        "mrrept2s", "mrrept2sd", "mrrept2-ijrs", "mrrept2-ij",
                        "mrrept2-rs", "mrrept2-ijr", "mrrept2-rsi", "mrrept2-ir", "mrrept2-i", "mrrept2-r"]:
            # For CASCI and the NEVPT2 / MRREPT2 variants listed above, rebuild
            # the Hamiltonian on the zero-excitation right big site
            # (big_right_z); rhamil keeps the Hamiltonian built just before
            # with the full excitation window.
            #
            # In this branch big_site_method is always "bigdrt": there is no
            # left big site, and the right big site spans the inactive +
            # external orbitals, so it is only dropped when that combined
            # space is empty.
            if MPI is not None:
                # big_right_np_z only exists in MPI runs. In serial runs
                # hamil_np is left untouched, as for the other big-site
                # methods, instead of failing with NameError here.
                hamil_np = HamiltonianQCBigSite(vacuum, n_orbs, orb_sym, fcidump, None,
                    None if n_inactive + n_external == 0 else big_right_np_z)
            hamil = HamiltonianQCBigSite(vacuum, n_orbs, orb_sym, fcidump, None,
                None if n_inactive + n_external == 0 else big_right_z)
    n_sites = hamil.n_sites

else:
    # Reuse the orbital ordering stored in the scratch directory (presumably
    # written by an earlier pre-run).
    orb_idx = np.load(scratch + '/orbital_reorder.npy')
    n_orbs = len(orb_idx)
    if qc_type == QCTypes.Conventional and n_orbs == 2:
        # fall back to the NC MPO type for a two-site system (see warning)
        qc_type = QCTypes.NC
        _print("WARN: changed qc_mpo_type to NC because there are only 2 sites!")
    if any(key in dic for key in ("nofiedler", "noreorder")):
        # reordering disabled by the user: keep the input orbital order
        orb_idx = None
    # integrals and orbital symmetries are not loaded in this branch
    orb_sym = fcidump = None
    spin = list(map(int, dic.get("spin", "0").split()))
    if SX == SU2 and singlet_embedding:
        assert len(spin) == 1
        singlet_embedding_spin = spin[0]

# Set gframe.minimal_memory_usage: "auto" enables it for 120 or more orbitals.
# Any other value enables it when it starts with '1' or 't'
# (case-insensitive), e.g. "1", "true", "T".
if min_mpo_mem == "auto":
    gframe.minimal_memory_usage = n_orbs >= 120
else:
    # startswith() also accepts an empty value (indexing [0] raised IndexError)
    gframe.minimal_memory_usage = min_mpo_mem.lower().startswith(('1', 't'))

_print('MinMPOMemUsage = ', gframe.minimal_memory_usage)

if no_pre_run:
    # Without a fresh pre-run, recover the site count and the target
    # quantum numbers from the identity MPO saved in the scratch directory.
    impo = MPO(0)
    impo.load_data(scratch + '/mpo-ident.bin', minimal=True)
    n_sites = impo.n_sites
    if "k_symmetry" in dic:
        # bitwise-OR of pg_k_mod over every quantum of every site basis
        k_mod = 0
        for b in impo.basis:
            for bb in b.quanta:
                k_mod = k_mod | bb.pg_k_mod
    nelec = [int(x) for x in dic["nelec"].split()]
    spin = [int(x) for x in dic.get("spin", "0").split()]
    isym = [int(x) for x in dic.get("irrep", "1").split()]
    iksym = [int(x) for x in dic.get("k_irrep", "0").split()]
    targets = []
    # PointGroup.swap_<sym> presumably maps the input irrep label to
    # block2's internal numbering for that point group
    swap_pg = getattr(PointGroup, "swap_" + dic.get("sym", "d2h"))
    # one target per combination of nelec x spin x irrep (x k_irrep)
    for inelec in nelec:
        for ispin in spin:
            for iisym in isym:
                if "k_symmetry" in dic:
                    for iiksym in iksym:
                        targets.append(SX(inelec, ispin,
                                          SX.pg_combine(swap_pg(iisym), iiksym, k_mod)))
                else:
                    targets.append(SX(inelec, ispin, swap_pg(iisym)))
    assert len(targets) != 0
    if singlet_embedding and SX == SU2:
        # singlet embedding: add twos to the particle number and target
        # total spin zero
        for it, targ in enumerate(targets):
            if targ.twos != 0:
                targets[it] = SX(targ.n + targ.twos, 0, targ.pg)
    if len(targets) == 1:
        target = targets[0]

# parallelization over sites
# use keyword: conn_centers auto 5      (5 is the number of site segments;
#                                        the MPI size must be a multiple of it)
#          or  conn_centers 10 20 30 40 (list of connection site indices)
if "conn_centers" in dic:
    assert MPI is not None
    if "twodot_to_onedot" in dic:
        raise RuntimeError("twodot_to_onedot with conn_centers is not supported!")
    cc = dic["conn_centers"].split()
    if cc[0] == "auto":
        # ncc evenly spaced segments -> ncc - 1 connection sites at
        # floor(k * n_sites / ncc) for k = 1 .. ncc - 1
        ncc = int(cc[1])
        conn_centers = list(
            np.arange(0, n_sites * ncc, n_sites, dtype=int) // ncc)[1:]
        assert len(conn_centers) == ncc - 1
    else:
        conn_centers = [int(xcc) for xcc in cc]
    _print("using connection sites: ", conn_centers)
    # The MPI ranks must divide evenly into len(conn_centers) + 1 segments.
    # mps_prule keeps the original rule; prule is replaced by
    # prule.split(ranks per segment).
    assert MPI.size % (len(conn_centers) + 1) == 0
    mps_prule = prule
    prule = prule.split(MPI.size // (len(conn_centers) + 1))
else:
    conn_centers = None

# An explicit Hartree-Fock occupation list with one entry per site (anything
# other than "integral") becomes the occupation warmup, unless "occ" is set.
if dic.get("hf_occ", "integral") != "integral" \
        and len(dic["hf_occ"].split()) == n_sites and "occ" not in dic:
    dic["warmup"] = "occ"
    dic["occ"] = dic["hf_occ"]
    dic.setdefault("cbias", 0.2)

if dic.get("warmup", None) == "occ":
    # Initial MPS bond dimensions will be set from orbital occupations.
    _print("using occ init")
    assert "occ" in dic
    if len(dic["occ"].split()) == 1:
        # a single token is treated as a file name; its first line holds the
        # occupation numbers
        with open(dic["occ"], 'r') as ofin:
            dic["occ"] = ofin.readlines()[0]
    occs = VectorDouble([float(occ)
                         for occ in dic["occ"].split() if len(occ) != 0])
    if orb_idx is not None:
        # apply the orbital reordering orb_idx to the occupations
        occs = FCIDUMP.array_reorder(occs, VectorUInt16(orb_idx))
        _print("using reordered occ init")
    # accepted lengths: 1, 2 or 4 values per site
    assert len(occs) == n_sites or len(occs) == n_sites * \
        2 or len(occs) == n_sites * 4
    # cbias: occupations at or above a threshold are lowered by cbias, the
    # others raised by cbias
    cbias = float(dic.get("cbias", 0.0))
    if cbias != 0.0:
        if len(occs) == n_sites:
            if "use_general_spin" in dic:
                # threshold 0.5 with use_general_spin
                occs = VectorDouble(
                    [c - cbias if c >= 0.5 else c + cbias for c in occs])
            else:
                # threshold 1 otherwise
                occs = VectorDouble(
                    [c - cbias if c >= 1 else c + cbias for c in occs])
        elif len(occs) == 2 * n_sites:
            occs = VectorDouble(
                [c - cbias if c >= 0.5 else c + cbias for c in occs])
        elif len(occs) == 4 * n_sites:
            # four values per site: rescale each row to sum to 1 - cbias and
            # add cbias / 4 to every entry, so each row sums to 1.
            # NOTE(review): a row summing to zero divides by zero here.
            moccs = np.array(occs).reshape((n_sites, 4))
            f = (1 - cbias) / moccs.sum(axis=1)[:, None]
            moccs = moccs * f + cbias / 4
            occs = VectorDouble(moccs.ravel())
        else:
            assert False
    bias = float(dic.get("bias", 1.0))
else:
    occs = None

# One-site window for "onedot", and for "zerodot" unless it is combined with
# "twodot_to_onedot"; two-site window otherwise.
if "onedot" in dic or ("zerodot" in dic and "twodot_to_onedot" not in dic):
    dot = 1
else:
    dot = 2
nroots = int(dic.get("nroots", 1))
# output, input and projection MPS tag lists (whitespace separated)
mps_tags, read_tags, proj_tags = [
    dic.get(key, default).split()
    for key, default in (("mps_tags", "KET"), ("read_mps_tags", "KET"),
                         ("proj_mps_tags", ""))]
soc = "soc" in dic
overlap = "overlap" in dic
# "soc" always forces the conventional NPDM path
conv_npdm = soc or "conventional_npdm" in dic

def fmt_size(i, suffix='B'):
    """Format the quantity ``i`` (e.g. a byte count) as a short string.

    Values below 1000 are printed as integers (``"512 B"``). Larger values
    are divided by successive powers of 1024 (prefixes K, M, G, ..., Y) and
    printed with 2, 1 or 0 decimals when the scaled value is below 10, 100
    or 1000 respectively (``"1.50 KB"``, ``"20.0 KB"``, ``"500 KB"``).
    Values of ``1000 * 1024 ** 8`` or more give ``"??? " + suffix``.
    """
    if i < 1000:
        return "%d %s" % (i, suffix)
    scale = 1
    for prefix in "KMGTPEZY":
        scale *= 1024
        for digits, limit in zip((2, 1, 0), (10, 100, 1000)):
            if i < limit * scale:
                # '*' takes the precision from the argument tuple
                return "%.*f %s%s" % (digits, i / scale, prefix, suffix)
    return "??? " + suffix


# Compression, stochastic-perturbation compression and time evolution reject
# identical input ("read_mps_tags") and output ("mps_tags") tag lists.
if mps_tags == read_tags and any(
        key in dic for key in ("compression", "stopt_compression", "delta_t")):
    raise RuntimeError("""For compression and time evolution, the input MPS
            tags "read_mps_tags" and the output MPS tags "mps_tags" cannot
            be the same!""")

if "statespecific" in dic and "proj_weights" in dic:
    # The full-restart path loads a saved MPSInfo file. When neither the
    # tag-specific nor the generic info file exists, cancel the restart.
    if not (len(mps_tags) == 1 and os.path.isfile(scratch + "/%s-mps_info.bin" % mps_tags[0])) \
        and not os.path.isfile(scratch + "/mps_info.bin"):
        # pop() instead of del: "fullrestart" may not have been given at all,
        # in which case del raised KeyError
        dic.pop("fullrestart", None)

# prepare mps
# Multiple tags, compression (without random init), stochastic sampling and
# time evolution: no MPS is created here; nroots follows the tag count.
if len(mps_tags) > 1 or ("compression" in dic and "random_mps_init" not in dic) \
   or "stopt_sampling" in dic or "delta_t" in dic:
    nroots = len(mps_tags)
    mps = None
    mps_info = None
    forward = False
elif "fullrestart" in dic:
    # Full restart: reload the saved MPS and move its center so the next
    # sweep can continue from it.
    _print("full restart")
    mps_info = MPSInfo(0) if nroots == 1 and len(
        targets) == 1 and not complex_mps and "use_hybrid_complex" not in dic else MultiMPSInfo(0)
    # prefer the tag-specific info file; fall back to the generic one
    if len(mps_tags) == 1 and os.path.isfile(scratch + "/%s-mps_info.bin" % mps_tags[0]):
        mps_info.load_data(scratch + "/%s-mps_info.bin" % mps_tags[0])
    else:
        mps_info.load_data(scratch + "/mps_info.bin")
    mps_info.tag = mps_tags[0]
    mps_info.load_mutable()
    # raise bond_dim to at least the largest stored left / right bond dimension
    max_bdim = max([x.n_states_total for x in mps_info.left_dims])
    if mps_info.bond_dim < max_bdim:
        mps_info.bond_dim = max_bdim
    max_bdim = max([x.n_states_total for x in mps_info.right_dims])
    if mps_info.bond_dim < max_bdim:
        mps_info.bond_dim = max_bdim
    mps = MPS(mps_info) if nroots == 1 and len(
        targets) == 1 and not complex_mps and "use_hybrid_complex" not in dic else MultiMPS(mps_info)
    mps.load_data()
    if mps.dot != dot:
        # store the requested dot; the barriers presumably keep MPI ranks
        # from touching the files while they are rewritten
        if MPI is not None:
            MPI.barrier()
        mps.dot = dot
        mps.save_data()
        if MPI is not None:
            MPI.barrier()
    # keep only the requested roots (two wavefunctions per root in
    # hybrid-complex mode)
    if "use_hybrid_complex" in dic:
        mps.nroots = nroots * 2
        mps.wfns = mps.wfns[:nroots * 2]
        mps.weights = mps.weights[:nroots * 2]
    elif nroots != 1 and not complex_mps:
        mps.nroots = nroots
        mps.wfns = mps.wfns[:nroots]
        mps.weights = mps.weights[:nroots]
    weights = dic.get("weights", None)
    if weights is not None:
        mps.weights = VectorFP([float(x) for x in weights.split()])
    mps.load_mutable()
    # Choose the sweep direction and shift the center to a position valid for
    # mps.dot, flipping fused forms with flip_fused_form where needed.
    forward = mps.center == 0
    if mps.canonical_form[mps.center] == 'L' and mps.center != mps.n_sites - mps.dot:
        # center site in 'L' form (not the last window): advance, sweep forward
        mps.center += 1
        forward = True
        if mps.canonical_form[mps.center] in "ST" and mps.dot == 2:
            if MPI is not None:
                MPI.barrier()
            mps.flip_fused_form(
                mps.center, CG(), prule if MPI is not None else None)
            mps.save_data()
            if MPI is not None:
                MPI.barrier()
            mps.load_mutable()
            mps.info.load_mutable()
            if MPI is not None:
                MPI.barrier()
    elif mps.canonical_form[mps.center] in "CMKJST" and mps.center != 0:
        # center-type form away from site 0: sweep backward
        if mps.canonical_form[mps.center] in "KJ" and mps.dot == 2:
            if MPI is not None:
                MPI.barrier()
            mps.flip_fused_form(
                mps.center, CG(), prule if MPI is not None else None)
            mps.save_data()
            if MPI is not None:
                MPI.barrier()
            mps.load_mutable()
            mps.info.load_mutable()
            if MPI is not None:
                MPI.barrier()
        # two-dot: move back one site unless sites center, center + 1 are "CC"
        if not mps.canonical_form[mps.center:mps.center + 2] == "CC" and mps.dot == 2:
            mps.center -= 1
        forward = False
    elif mps.center == mps.n_sites - 1 and mps.dot == 2:
        # center on the last site: the two-dot window starts at n_sites - 2
        if MPI is not None:
            MPI.barrier()
        if mps.canonical_form[mps.center] in "KJ":
            mps.flip_fused_form(
                mps.center, CG(), prule if MPI is not None else None)
        mps.center = mps.n_sites - 2
        mps.save_data()
        forward = False
        if MPI is not None:
            MPI.barrier()
        mps.load_mutable()
        mps.info.load_mutable()
        if MPI is not None:
            MPI.barrier()
    elif mps.center == 0 and mps.dot == 2:
        # center on the first site: sweep forward
        if MPI is not None:
            MPI.barrier()
        if mps.canonical_form[mps.center] in "ST":
            mps.flip_fused_form(
                mps.center, CG(), prule if MPI is not None else None)
        mps.save_data()
        forward = True
        if MPI is not None:
            MPI.barrier()
        mps.load_mutable()
        mps.info.load_mutable()
        if MPI is not None:
            MPI.barrier()
elif pre_run or not no_pre_run:
    # Fresh start: build an MPSInfo, save it, and create a randomly
    # initialized, normalized MPS on it.
    if "trans_mps_info" in dic:
        # build the info in the TrSX symmetry (twos replaced by |twos|) and
        # convert it back with trans_mi
        assert nroots == 1 and len(targets) == 1
        tr_vacuum = TrSX(vacuum.n, abs(vacuum.twos), vacuum.pg)
        tr_target = TrSX(target.n, abs(target.twos), target.pg)
        tr_basis = TrVectorStateInfo([trans_si(b) for b in hamil.basis])
        tr_mps_info = TrMPSInfo(n_sites, tr_vacuum, tr_target, tr_basis)
        assert "full_fci_space" not in dic
        tr_mps_info.tag = mps_tags[0]
        if occs is None:
            tr_mps_info.set_bond_dimension(bond_dims[0])
        else:
            tr_mps_info.set_bond_dimension_using_occ(
                bond_dims[0], occs, bias=bias)
        mps_info = trans_mi(tr_mps_info, target)
    else:
        basis = hamil.basis
        # condense_basis is applied once per doubling of icd until it reaches
        # condense_mpo (so condense_mpo is presumably a power of two)
        assert n_sites % condense_mpo == 0
        mps_n_sites = n_sites // condense_mpo
        icd = 1
        while icd < condense_mpo:
            basis = MPSInfo.condense_basis(basis)
            icd *= 2
        if nroots == 1 and len(targets) == 1 and "use_hybrid_complex" not in dic:
            if dynamic_corr_method is not None:
                if dynamic_corr_method[0] in ["casci", "nevpt2s", "nevpt2sd", "nevpt2-ijrs", "nevpt2-ij",
                        "nevpt2-rs", "nevpt2-ijr", "nevpt2-rsi", "nevpt2-ir", "nevpt2-i", "nevpt2-r",
                        "mrrept2s", "mrrept2sd", "mrrept2-ijrs", "mrrept2-ij",
                        "mrrept2-rs", "mrrept2-ijr", "mrrept2-rsi", "mrrept2-ir", "mrrept2-i", "mrrept2-r"]:
                    if big_site_method == "bigdrt":
                        # n_cas active sites followed by one frozen big site
                        acts = VectorActTypes([ActiveTypes.Active] * n_cas + [ActiveTypes.Frozen])
                        basis = basis.__class__(list(basis))
                        mps_info = CASCIMPSInfo(mps_n_sites, vacuum, target, basis, acts)
                    elif big_site_method is not None:
                        # inactive / external spaces are single big sites (or absent)
                        mps_info = CASCIMPSInfo(mps_n_sites, vacuum, target, basis,
                                                1 if n_inactive != 0 else 0, n_cas, 1 if n_external != 0 else 0)
                    else:
                        mps_info = CASCIMPSInfo(
                            mps_n_sites, vacuum, target, basis, n_inactive, n_cas, n_external)
                elif dynamic_corr_method[0] in ["mrcis", "mrcisd", "mrcisdt"]:
                    if big_site_method is not None:
                        mps_info = MPSInfo(
                            mps_n_sites, vacuum, target, basis)
                    else:
                        # "mrcis" -> 1, "mrcisd" -> 2, "mrcisdt" -> 3
                        ci_order = len(dynamic_corr_method[0]) - 4
                        mps_info = MRCIMPSInfo(
                            mps_n_sites, n_inactive, n_external, ci_order, vacuum, target, basis)
                else:
                    mps_info = MPSInfo(mps_n_sites, vacuum, target, basis)
            else:
                mps_info = MPSInfo(mps_n_sites, vacuum, target, basis)
        else:
            assert dynamic_corr_method is None
            _print('TARGETS = ', list(targets))
            mps_info = MultiMPSInfo(mps_n_sites, vacuum, targets, basis)
        if singlet_embedding and SX == SU2:
            # the left vacuum carries the embedded spin
            left_vacuum = SX(singlet_embedding_spin, singlet_embedding_spin, 0)
            right_vacuum = vacuum
            if "full_fci_space" in dic:
                mps_info.set_bond_dimension_full_fci(left_vacuum, right_vacuum)
            else:
                mps_info.set_bond_dimension_fci(left_vacuum, right_vacuum)
        else:
            if "full_fci_space" in dic:
                mps_info.set_bond_dimension_full_fci()
        mps_info.tag = mps_tags[0]
        if occs is None:
            mps_info.set_bond_dimension(bond_dims[0])
        else:
            mps_info.set_bond_dimension_using_occ(
                bond_dims[0], occs, bias=bias)
    if "skip_inact_ext_sites" in dic:
        assert dynamic_corr_method is not None
        mps_info.set_bond_dimension_inact_ext_fci(bond_dims[0], n_inactive, n_external)
    # only one process writes the info files
    if MPI is None or MPI.rank == 0:
        mps_info.save_data(scratch + '/mps_info.bin')
        mps_info.save_data(scratch + '/%s-mps_info.bin' % mps_tags[0])

    if conn_centers is not None:
        assert nroots == 1
        mps = ParallelMPS(mps_info.n_sites, init_mps_center, dot, mps_prule)
        if "svd_eps" in dic:
            mps.svd_eps = float(dic["svd_eps"])
        if "svd_cutoff" in dic:
            mps.svd_cutoff = float(dic["svd_cutoff"])
    elif nroots != 1 or len(targets) != 1 or "use_hybrid_complex" in dic:
        if "use_hybrid_complex" in dic:
            mps = MultiMPS(mps_info.n_sites, init_mps_center, dot, nroots * 2)
        else:
            mps = MultiMPS(mps_info.n_sites, init_mps_center, dot, nroots)
        weights = dic.get("weights", None)
        if weights is not None:
            mps.weights = VectorFP([float(x) for x in weights.split()])
    else:
        mps = MPS(mps_info.n_sites, init_mps_center, dot)
    mps.initialize(mps_info)
    mps.random_canonicalize()
    # NOTE(review): with nroots == 1 but several targets, mps is a MultiMPS
    # here; confirm that its center tensor (not only wfns) is populated.
    if nroots == 1 and "use_hybrid_complex" not in dic:
        mps.tensors[mps.center].normalize()
    else:
        for xwfn in mps.wfns:
            xwfn.normalize()
    if "skip_inact_ext_sites" in dic:
        mps.set_inact_ext_identity(n_inactive, n_external)
    forward = mps.center == 0
else:
    mps_info = MPSInfo(0) if nroots == 1 and len(
        targets) == 1 and "use_hybrid_complex" not in dic else MultiMPSInfo(0)
    if len(mps_tags) == 1 and os.path.isfile(scratch + "/%s-mps_info.bin" % mps_tags[0]):
        mps_info.load_data(scratch + "/%s-mps_info.bin" % mps_tags[0])
    else:
        mps_info.load_data(scratch + "/mps_info.bin")
    mps_info.tag = mps_tags[0]
    if occs is None:
        mps_info.set_bond_dimension(bond_dims[0])
    else:
        mps_info.set_bond_dimension_using_occ(
            bond_dims[0], occs, bias=bias)
    if "skip_inact_ext_sites" in dic:
        assert dynamic_corr_method is not None
        mps_info.set_bond_dimension_inact_ext_fci(bond_dims[0], n_inactive, n_external)

    if conn_centers is not None:
        assert nroots == 1
        mps = ParallelMPS(mps_info.n_sites, init_mps_center, dot, mps_prule)
        if "svd_eps" in dic:
            mps.svd_eps = float(dic["svd_eps"])
        if "svd_cutoff" in dic:
            mps.svd_cutoff = float(dic["svd_cutoff"])
    elif nroots != 1 or len(targets) != 1 or "use_hybrid_complex" in dic:
        if "use_hybrid_complex" in dic:
            mps = MultiMPS(n_sites, init_mps_center, dot, nroots * 2)
        else:
            mps = MultiMPS(n_sites, init_mps_center, dot, nroots)
        weights = dic.get("weights", None)
        if weights is not None:
            mps.weights = VectorFP([float(x) for x in weights.split()])
    else:
        mps = MPS(mps_info.n_sites, init_mps_center, dot)
    mps.initialize(mps_info)
    mps.random_canonicalize()
    if nroots == 1 and "use_hybrid_complex" not in dic:
        mps.tensors[mps.center].normalize()
    else:
        for xwfn in mps.wfns:
            xwfn.normalize()
    if "skip_inact_ext_sites" in dic:
        mps.set_inact_ext_identity(n_inactive, n_external)
    forward = mps.center == 0

if mps is not None:
    # Report the initial MPS: canonical form, center, dot, target(s), and the
    # total number of states on every left bond.
    _print("MPS = ", mps.canonical_form, mps.center, mps.dot,
        mps.info.target if nroots == 1 else list(mps.info.targets))
    _print("GS INIT MPS BOND DIMS = ", ''.join(
        ["%6d" % x.n_states_total for x in mps_info.left_dims]))

if conn_centers is not None and "fullrestart" in dic:
    # Wrap the restarted two-dot MPS in a ParallelMPS for the site-parallel
    # (conn_centers) algorithm.
    assert mps.dot == 2
    mps = ParallelMPS(mps, mps_prule)
    if "svd_eps" in dic:
        mps.svd_eps = float(dic["svd_eps"])
    if "svd_cutoff" in dic:
        mps.svd_cutoff = float(dic["svd_cutoff"])
    # A 'C' center at either chain end is relabeled: 'K' at the left end,
    # 'S' at the right end (with the center moved to the last site).
    if mps.canonical_form[0] == 'C' and mps.canonical_form[1] == 'R':
        mps.canonical_form = 'K' + mps.canonical_form[1:]
    elif mps.canonical_form[-1] == 'C' and mps.canonical_form[-2] == 'L':
        mps.canonical_form = mps.canonical_form[:-1] + 'S'
        mps.center = mps.n_sites - 1

# Report the resident memory before MPO construction; skipped silently when
# psutil is not installed.
try:
    import psutil
    mem = psutil.Process(os.getpid()).memory_info().rss
    _print("pre-mpo memory usage = ", fmt_size(mem))
except ImportError:
    pass

# prepare mpo
if pre_run or not no_pre_run:
    # Build every MPO (Hamiltonian, optional complex h1e, 1PDM, optional
    # 2PDM / 1NPC, identity) and save the serial versions from rank 0. With
    # para_pre_run, also save reduced per-rank copies ('<name>.bin.<rank>')
    # for the para_no_pre_run load path.
    # mpo for dmrg
    _print("build mpo start ...")
    txx = time.perf_counter()
    if big_site_method == "folding":
        # reuse the folded MPO prepared earlier for this big-site method
        mpo = mpo_fold
    else:
        if qc_mpo_trans_center == -1:
            # default: transition center in the middle of the chain
            qc_mpo_trans_center = hamil.n_sites // 2
        if condense_mpo == 1:
            mpo = MPOQC(hamil, qc_type, "HQC", qc_mpo_trans_center)
        else:
            # align the transition center to a multiple of condense_mpo, then
            # condense once per doubling of icd until it reaches condense_mpo
            mpo = MPOQC(hamil, qc_type, "HQC", qc_mpo_trans_center // condense_mpo * condense_mpo, condense_mpo)
            mpo.basis = hamil.basis
            icd = 1
            while icd < condense_mpo:
                mpo = CondensedMPO(mpo, mpo.basis)
                icd *= 2
    _print("build mpo finished ... Tread = %.3f Twrite = %.3f T = %.3f" % (mpo.tread, mpo.twrite, time.perf_counter() - txx))
    _print("simpl mpo start ...")
    txx = time.perf_counter()
    mpo = SimplifiedMPO(mpo, simpl_rule, True, condense_mpo == 1,
                        OpNamesSet((OpNames.R, OpNames.RD)))
    _print("simpl mpo finished ... Tread = %.3f Twrite = %.3f T = %.3f" % (mpo.tread, mpo.twrite, time.perf_counter() - txx))

    # MPO bond dimension (m * n of the left operator names) at every bond;
    # the operators are loaded and unloaded one bond at a time
    mpo_bdims = [None] * len(mpo.left_operator_names)
    for ix in range(len(mpo.left_operator_names)):
        mpo.load_left_operators(ix)
        x = mpo.left_operator_names[ix]
        mpo_bdims[ix] = x.m * x.n
        mpo.unload_left_operators(ix)
    _print('GS MPO BOND DIMS = ', ''.join(["%6d" % x for x in mpo_bdims]))

    if MPI is None or MPI.rank == 0:
        mpo.save_data(scratch + '/mpo.bin')

    if "use_hybrid_complex" in dic:
        # separate complex MPO built from hamil_cpx_h1e
        txx = time.perf_counter()
        mpo_cpx_h1e = MPOQCCPX(hamil_cpx_h1e, qc_type)
        mpo_cpx_h1e = SimplifiedMPOCPX(mpo_cpx_h1e, simpl_rule_cpx_h1e, True, True,
                        OpNamesSet((OpNames.R, OpNames.RD)))
        _print("cpx h1e mpo finished ... Tread = %.3f Twrite = %.3f T = %.3f"
            % (mpo_cpx_h1e.tread, mpo_cpx_h1e.twrite, time.perf_counter() - txx))

        if MPI is None or MPI.rank == 0:
            mpo_cpx_h1e.save_data(scratch + '/mpo_cpx_h1e.bin')

    # mpo for 1pdm
    # (the property MPOs use the non-parallel Hamiltonian hamil_np when set)
    _print("build 1pdm mpo", time.perf_counter() - tx)
    pmpo = PDM1MPOQC(hamil_np or hamil, 1 if soc else 0)
    pmpo.basis = hamil.basis
    icd = 1
    while icd < condense_mpo:
        pmpo = CondensedMPO(pmpo, pmpo.basis, True)
        icd *= 2
    pmpo = SimplifiedMPO(pmpo,
                         NoTransposeRule(RuleQC()) if has_tran else RuleQC(),
                         True, condense_mpo == 1, OpNamesSet((OpNames.R, OpNames.RD)))

    if MPI is None or MPI.rank == 0:
        pmpo.save_data(scratch + '/mpo-1pdm.bin')

    if has_2pdm and conv_npdm:
        # mpo for 2pdm
        _print("build 2pdm mpo", time.perf_counter() - tx)
        p2mpo = PDM2MPOQC(hamil_np or hamil)
        p2mpo = SimplifiedMPO(p2mpo,
                              NoTransposeRule(
                                  RuleQC()) if has_tran else RuleQC(),
                              True, True, OpNamesSet((OpNames.R, OpNames.RD)))

        if MPI is None or MPI.rank == 0:
            p2mpo.save_data(scratch + '/mpo-2pdm.bin')

    if has_1npc:
        # mpo for particle number correlation
        _print("build 1npc mpo", time.perf_counter() - tx)
        nmpo = NPC1MPOQC(hamil_np or hamil)
        nmpo = SimplifiedMPO(nmpo, RuleQC(), True, True,
                             OpNamesSet((OpNames.R, OpNames.RD)))

        if MPI is None or MPI.rank == 0:
            nmpo.save_data(scratch + '/mpo-1npc.bin')

    # mpo for identity operator
    _print("build identity mpo", time.perf_counter() - tx)
    impo = IdentityMPO(hamil_np or hamil)
    impo.basis = hamil.basis
    icd = 1
    while icd < condense_mpo:
        impo = CondensedMPO(impo, impo.basis)
        icd *= 2
    impo = SimplifiedMPO(impo,
                         NoTransposeRule(RuleQC()) if has_tran else RuleQC(),
                         True, condense_mpo == 1, OpNamesSet((OpNames.R, OpNames.RD)))

    if MPI is None or MPI.rank == 0:
        impo.save_data(scratch + '/mpo-ident.bin')

    if para_pre_run:
        if MPI is not None:
            if one_body_only:
                mpo = ParallelMPO(mpo, prule_one_body)
            else:
                mpo = ParallelMPO(mpo, prule)
            pmpo = ParallelMPO(pmpo, prule_pdm1)
            if has_2pdm and conv_npdm:
                p2mpo = ParallelMPO(p2mpo, prule_pdm2)
            if has_1npc:
                nmpo = ParallelMPO(nmpo, prule_pdm1)
            impo = ParallelMPO(impo, prule_ident)
            if "use_hybrid_complex" in dic:
                # same rules that the para_no_pre_run path uses when it
                # loads the per-rank complex h1e MPO back
                if one_body_only:
                    mpo_cpx_h1e = ParallelMPOCPX(mpo_cpx_h1e, prule_one_body_cpx_h1e)
                else:
                    mpo_cpx_h1e = ParallelMPOCPX(mpo_cpx_h1e, prule_cpx_h1e)

        _print("para mpo finished", time.perf_counter() - tx)
        try:
            import psutil
            mem = psutil.Process(os.getpid()).memory_info().rss
            _print("memory usage = ", fmt_size(mem))
        except ImportError:
            pass

        mrank = MPI.rank if MPI is not None else 0
        mpo.reduce_data()
        mpo.save_data(scratch + '/mpo.bin.%d' % mrank)
        pmpo.reduce_data()
        pmpo.save_data(scratch + '/mpo-1pdm.bin.%d' % mrank)
        if has_2pdm and conv_npdm:
            p2mpo.reduce_data()
            p2mpo.save_data(scratch + '/mpo-2pdm.bin.%d' % mrank)
        if has_1npc:
            nmpo.reduce_data()
            nmpo.save_data(scratch + '/mpo-1npc.bin.%d' % mrank)
        impo.reduce_data()
        impo.save_data(scratch + '/mpo-ident.bin.%d' % mrank)
        if "use_hybrid_complex" in dic:
            # the para_no_pre_run path loads 'mpo_cpx_h1e.bin.<rank>', which
            # was previously never written
            mpo_cpx_h1e.reduce_data()
            mpo_cpx_h1e.save_data(scratch + '/mpo_cpx_h1e.bin.%d' % mrank)
else:

    if not para_no_pre_run:

        mpo = MPO(0)
        mpo.load_data(scratch + '/mpo.bin')

        _print('GS MPO BOND DIMS = ', ''.join(
            ["%6d" % (x.m * x.n) for x in mpo.left_operator_names]))

        if "use_hybrid_complex" in dic:
            mpo_cpx_h1e = MPOCPX(0)
            mpo_cpx_h1e.load_data(scratch + '/mpo_cpx_h1e.bin')

        pmpo = MPO(0)
        pmpo.load_data(scratch + '/mpo-1pdm.bin')

        _print('1PDM MPO BOND DIMS = ', ''.join(
            ["%6d" % (x.m * x.n) for x in pmpo.left_operator_names]))

        if has_2pdm and conv_npdm:
            p2mpo = MPO(0)
            p2mpo.load_data(scratch + '/mpo-2pdm.bin')

            _print('2PDM MPO BOND DIMS = ', ''.join(
                ["%6d" % (x.m * x.n) for x in p2mpo.left_operator_names]))

        if has_1npc:
            nmpo = MPO(0)
            nmpo.load_data(scratch + '/mpo-1npc.bin')

            _print('1NPC MPO BOND DIMS = ', ''.join(
                ["%6d" % (x.m * x.n) for x in nmpo.left_operator_names]))

        impo = MPO(0)
        impo.load_data(scratch + '/mpo-ident.bin')

        _print('IDENT MPO BOND DIMS = ', ''.join(
            ["%6d" % (x.m * x.n) for x in impo.left_operator_names]))

    else:

        if MPI is not None:
            if one_body_only:
                mpo = ParallelMPO(0, prule_one_body)
            else:
                mpo = ParallelMPO(0, prule)
            pmpo = ParallelMPO(0, prule_pdm1)
            if has_2pdm and conv_npdm:
                p2mpo = ParallelMPO(0, prule_pdm2)
            if has_1npc:
                nmpo = ParallelMPO(0, prule_pdm1)
            impo = ParallelMPO(0, prule_ident)
        else:
            mpo = MPO(0)
            pmpo = MPO(0)
            if has_2pdm and conv_npdm:
                p2mpo = MPO(0)
            if has_1npc:
                nmpo = MPO(0)
            impo = MPO(0)

        mrank = MPI.rank if MPI is not None else 0
        mpo.load_data(scratch + '/mpo.bin.%d' % mrank, minimal=True)
        pmpo.load_data(scratch + '/mpo-1pdm.bin.%d' % mrank, minimal=True)
        if has_2pdm and conv_npdm:
            p2mpo.load_data(scratch + '/mpo-2pdm.bin.%d' % mrank, minimal=True)
        if has_1npc:
            nmpo.load_data(scratch + '/mpo-1npc.bin.%d' % mrank, minimal=True)
        impo.load_data(scratch + '/mpo-ident.bin.%d' % mrank, minimal=True)

        if "use_hybrid_complex" in dic:
            if MPI is not None:
                if one_body_only:
                    mpo_cpx_h1e = ParallelMPOCPX(0, prule_one_body_cpx_h1e)
                else:
                    mpo_cpx_h1e = ParallelMPOCPX(0, prule_cpx_h1e)
            else:
                mpo_cpx_h1e = MPO(0)
            mpo_cpx_h1e.load_data(scratch + '/mpo_cpx_h1e.bin.%d' % mrank, minimal=True)

# Report the resident memory after MPO preparation; skipped silently when
# psutil is not installed.
try:
    import psutil
    mem = psutil.Process(os.getpid()).memory_info().rss
    _print("memory usage = ", fmt_size(mem))
except ImportError:
    pass

# Optionally free the integrals now that the MPOs have been prepared, and
# report the memory afterwards (when psutil is available).
if "release_integral" in dic and fcidump is not None:
    fcidump.deallocate()
    fcidump = None

    try:
        import psutil
        mem = psutil.Process(os.getpid()).memory_info().rss
        _print("integral deallocated memory usage = ", fmt_size(mem))
    except ImportError:
        pass

def split_mps(iroot, mps, mps_info, mpi=MPI):
    """Split root ``iroot`` out of a multi-root state into a standalone state.

    When ``mps_info`` has several targets the result is a one-root
    ``MultiMPS`` whose single weight is 1. With a single target it is a plain
    ``MPS`` whose empty center slot receives ``mps.wfns[iroot][0]``. The new
    info is tagged ``<mps_info.tag>-<iroot>`` and copies the left/right bond
    dimensions of the source.

    Args:
        iroot: index of the root to extract (selects ``mps.wfns[iroot]``).
        mps: the source state; it is deallocated before returning.
        mps_info: info of ``mps``; its mutable data is deallocated before
            returning.
        mpi: communicator used for barriers, or ``None`` for serial runs.

    Returns:
        ``(smps, smps_info, forward)``: the new state, its info, and the sweep
        direction flag derived from the (possibly shifted) canonical center.
    """
    mps.load_data()  # this will avoid memory sharing
    mps_info.load_mutable()
    mps.load_mutable()

    # break up a MultiMPS to single MPSs
    multi_target = len(mps_info.targets) != 1
    if multi_target:
        smps_info = MultiMPSInfo(mps_info.n_sites, mps_info.vacuum,
                                 mps_info.targets, mps_info.basis)
    else:
        smps_info = MPSInfo(mps_info.n_sites, mps_info.vacuum,
                            mps_info.targets[0], mps_info.basis)

    # Bond-dimension setup, identical for both kinds of info.
    if singlet_embedding and SX == SU2:
        left_vacuum = SX(singlet_embedding_spin, singlet_embedding_spin, 0)
        right_vacuum = vacuum
        if "full_fci_space" in dic:
            smps_info.set_bond_dimension_full_fci(
                left_vacuum, right_vacuum)
        else:
            smps_info.set_bond_dimension_fci(left_vacuum, right_vacuum)
    elif "full_fci_space" in dic:
        smps_info.set_bond_dimension_full_fci()
    smps_info.tag = mps_info.tag + "-%d" % iroot
    smps_info.bond_dim = mps_info.bond_dim
    # the actual bond dimensions are taken over from the source state
    for i in range(0, smps_info.n_sites + 1):
        smps_info.left_dims[i] = mps_info.left_dims[i]
        smps_info.right_dims[i] = mps_info.right_dims[i]
    smps_info.save_mutable()

    if multi_target:
        smps = MultiMPS(smps_info)
        smps.n_sites = mps.n_sites
        smps.center = mps.center
        smps.dot = mps.dot
        smps.canonical_form = '' + mps.canonical_form
        smps.tensors = mps.tensors[:]
        # keep only the selected root, with unit weight
        smps.wfns = mps.wfns[iroot:iroot + 1]
        smps.weights = mps.weights[iroot:iroot + 1]
        smps.weights[0] = 1
        smps.nroots = 1
    else:
        smps = MPS(smps_info)
        smps.n_sites = mps.n_sites
        smps.center = mps.center
        smps.dot = mps.dot
        smps.canonical_form = '' + mps.canonical_form
        # rename the form letters T -> S and J -> K for the single-root MPS
        smps.canonical_form = smps.canonical_form.replace(
            'T', 'S').replace('J', 'K')
        smps.tensors = mps.tensors[:]
        # put the selected wavefunction into the empty center slot
        if smps.tensors[smps.center] is None:
            smps.tensors[smps.center] = mps.wfns[iroot][0]
        else:
            assert smps.center + 1 < smps.n_sites
            assert smps.tensors[smps.center + 1] is None
            smps.tensors[smps.center + 1] = mps.wfns[iroot][0]
    smps.save_mutable()

    if smps.center == 0 and dot == 2:
        if mpi is not None:
            mpi.barrier()
        if smps.canonical_form[smps.center] in "ST":
            smps.flip_fused_form(
                smps.center, CG(), prule if mpi is not None else None)
        smps.save_data()
        if mpi is not None:
            mpi.barrier()
        smps.load_mutable()
        smps.info.load_mutable()
        if mpi is not None:
            mpi.barrier()

    # Adapt center and canonical form to the requested `dot`; `forward` is
    # decided here (an earlier `forward = True` store was dead code).
    smps.dot = dot
    forward = smps.center == 0
    if smps.canonical_form[smps.center] == 'L' and smps.center != smps.n_sites - smps.dot:
        smps.center += 1
        forward = True
    elif (smps.canonical_form[smps.center] == 'C' or smps.canonical_form[smps.center] == 'M') and smps.center != 0:
        smps.center -= 1
        forward = False
    if smps.canonical_form[smps.center] == 'M' and not isinstance(smps, MultiMPS):
        smps.canonical_form = smps.canonical_form[:smps.center] + \
            'C' + smps.canonical_form[smps.center + 1:]
    if smps.canonical_form[-1] == 'M' and not isinstance(smps, MultiMPS):
        smps.canonical_form = smps.canonical_form[:-1] + 'C'
    if dot == 1:
        if smps.canonical_form[0] == 'C' and smps.canonical_form[1] == 'R':
            smps.canonical_form = 'K' + smps.canonical_form[1:]
        elif smps.canonical_form[-1] == 'C' and smps.canonical_form[-2] == 'L':
            smps.canonical_form = smps.canonical_form[:-1] + 'S'
            smps.center = smps.n_sites - 1
        if smps.canonical_form[0] == 'M' and smps.canonical_form[1] == 'R':
            smps.canonical_form = 'J' + smps.canonical_form[1:]
        elif smps.canonical_form[-1] == 'M' and smps.canonical_form[-2] == 'L':
            smps.canonical_form = smps.canonical_form[:-1] + 'T'
            smps.center = smps.n_sites - 1

    mps.deallocate()
    mps_info.deallocate_mutable()
    smps.save_data()
    return smps, smps_info, forward


def get_mps_from_tags(iroot, proj_mps=False, ref_center=0):
    """Load a stored MPS by tag and adapt it to the settings of this run.

    The tag is ``proj_tags[iroot]`` when ``proj_mps`` is set,
    ``mps_tags[iroot]`` when ``iroot >= 0`` and ``read_tags[0]`` otherwise.
    The state is deep-copied under ``<stored info tag>-<iroot>``. Its
    ``bond_dim`` is raised to at least the largest stored left/right bond
    dimension, and its ``dot`` is adapted to the global ``dot``. When exactly
    one of ``smps.center`` and ``ref_center`` is 0, an identity-MPO sweep
    (``Expect``/``ComplexExpect`` on ``impo``) changes its canonical form.

    Returns:
        ``(smps, smps.info, forward)`` with ``forward = smps.center == 0``.
    """
    if proj_mps:
        tag = proj_tags[iroot]
        _print('----- proj = %3d tag = %s -----' % (iroot, tag))
    elif iroot >= 0:
        tag = mps_tags[iroot]
        _print('----- root = %3d tag = %s -----' % (iroot, tag))
    else:
        tag = read_tags[0]
        _print('----- cps/te init tag = %s -----' % tag)

    def sync():
        # barrier across MPI ranks; no-op for serial runs
        if MPI is not None:
            MPI.barrier()

    def par_rule():
        # parallel rule for move/flip operations; None for serial runs
        return prule if MPI is not None else None

    if complex_mps:
        smps_info = MultiMPSInfo(0)
    else:
        smps_info = MPSInfo(0)
    smps_info.load_data(scratch + "/%s-mps_info.bin" % tag)
    sync()
    copy_tag = smps_info.tag + "-%d" % iroot
    if complex_mps:
        smps = MultiMPS(smps_info).deep_copy(copy_tag)
    else:
        smps = MPS(smps_info).deep_copy(copy_tag)
    sync()

    smps_info = smps.info
    smps_info.load_mutable()
    # bond_dim must not be smaller than any stored bond dimension
    for dims in (smps_info.left_dims, smps_info.right_dims):
        largest = max([x.n_states_total for x in dims])
        if smps_info.bond_dim < largest:
            smps_info.bond_dim = largest
    smps.load_data()
    sync()

    at_last_site = smps.center == smps.n_sites - 1
    if smps.dot == 2 and dot == 1 and (at_last_site or smps.center == 0):
        smps.dot = 1
    elif smps.dot == 2 and at_last_site and complex_mps:
        _print('change canonical form ...')
        old_form = str(smps.canonical_form)
        smps.dot = 1
        ime = MovingEnvironment(impo, smps, smps, "IEX")
        ime.delayed_contraction = OpNamesSet.normal_ops()
        ime.cached_contraction = cached_contraction
        ime.init_environments(False)
        expect = ComplexExpect(ime, smps.info.bond_dim, smps.info.bond_dim)
        expect.iprint = max(min(outputlevel, 3), 0)
        expect.solve(True, smps.center == 0)
        sync()
        smps.dot = 2
        smps.save_data()
        sync()
        if smps.canonical_form[smps.center] in "ST":
            smps.flip_fused_form(smps.center, CG(), par_rule())
        smps.save_data()
        sync()
        _print(old_form + ' -> ' + smps.canonical_form)

    # one-site state requested as two-site: absorb the edge site
    if smps.dot == 1 and dot == 2:
        if smps.center == 0 and smps.canonical_form[0] == 'S':
            smps.move_right(CG(), par_rule())
            smps.center = 0
        elif (smps.center == smps.n_sites - 1
              and smps.canonical_form[smps.center] == 'K'):
            smps.move_left(CG(), par_rule())
            smps.center = smps.n_sites - 2
        smps.dot = dot
        sync()
        smps.save_data()
        sync()

    if (smps.center == 0) != (ref_center == 0):
        _print('change canonical form ...')
        old_form = str(smps.canonical_form)
        ime = MovingEnvironment(impo, smps, smps, "IEX")
        ime.delayed_contraction = OpNamesSet.normal_ops()
        ime.cached_contraction = cached_contraction
        ime.init_environments(False)
        expect_cls = ComplexExpect if complex_mps else Expect
        expect = expect_cls(ime, smps.info.bond_dim, smps.info.bond_dim)
        expect.iprint = max(min(outputlevel, 3), 0)
        expect.solve(True, smps.center == 0)
        sync()
        smps.save_data()
        sync()
        _print(old_form + ' -> ' + smps.canonical_form)

    return smps, smps.info, smps.center == 0


def get_state_specific_mps(iroot, mps_info):
    """Load the state-specific MPS of root ``iroot`` and adapt it to ``dot``.

    Reads ``mps_info-ss-<iroot>.bin`` from ``scratch`` and deep-copies the
    state under the tag ``<mps_info.tag>-<iroot>``. It then raises the info's
    ``bond_dim`` to at least the largest stored left/right bond dimension and
    converts between one- and two-site forms when the stored ``dot`` differs
    from the requested global ``dot``.

    Returns:
        ``(smps, smps_info, forward)`` with ``forward = smps.center == 0``.
    """
    smps_info = MPSInfo(0)
    smps_info.load_data(scratch + "/mps_info-ss-%d.bin" % iroot)
    if MPI is not None:
        MPI.barrier()
    smps = MPS(smps_info).deep_copy(mps_info.tag + "-%d" % iroot)
    if MPI is not None:
        MPI.barrier()
    smps_info = smps.info
    smps_info.load_mutable()
    max_bdim = max([x.n_states_total for x in smps_info.left_dims])
    if smps_info.bond_dim < max_bdim:
        smps_info.bond_dim = max_bdim
    max_bdim = max([x.n_states_total for x in smps_info.right_dims])
    if smps_info.bond_dim < max_bdim:
        smps_info.bond_dim = max_bdim
    smps.load_data()
    if smps.dot == 2 and smps.center == smps.n_sites - 1 and dot == 1:
        smps.dot = 1
    elif smps.dot == 2 and smps.center == 0 and dot == 1:
        smps.dot = 1
    elif (smps.dot == 1 or smps.center == smps.n_sites - 1) and dot == 2:
        # The parallel rule is passed only under MPI, as in every other
        # move_left / move_right call in this file (previously `prule` was
        # referenced unconditionally, also in serial runs).
        if smps.center == 0 and smps.canonical_form[0] == 'S':
            smps.move_right(CG(), prule if MPI is not None else None)
            smps.center = 0
        elif smps.center == smps.n_sites - 1 and smps.canonical_form[smps.center] == 'K':
            smps.move_left(CG(), prule if MPI is not None else None)
            smps.center = smps.n_sites - 2
        elif smps.center == smps.n_sites - 1 and smps.canonical_form[smps.center] == 'S':
            smps.center = smps.n_sites - 2
        smps.dot = dot
        if MPI is not None:
            MPI.barrier()
        smps.save_data()
        if MPI is not None:
            MPI.barrier()
    forward = smps.center == 0
    _print('ss-mps', smps.center, smps.dot, dot, smps.canonical_form)
    return smps, smps_info, forward


if not pre_run:

    # Under MPI, wrap the serial MPOs into parallel ones with their rules.
    if not para_no_pre_run:
        if MPI is not None:
            mpo = ParallelMPO(mpo, prule_one_body if one_body_only else prule)
            pmpo = ParallelMPO(pmpo, prule_pdm1)
            if has_2pdm and conv_npdm:
                p2mpo = ParallelMPO(p2mpo, prule_pdm2)
            if has_1npc:
                # the 1-NPC MPO shares the 1-PDM parallel rule
                nmpo = ParallelMPO(nmpo, prule_pdm1)
            impo = ParallelMPO(impo, prule_ident)

            if "use_hybrid_complex" in dic:
                mpo_cpx_h1e = ParallelMPOCPX(
                    mpo_cpx_h1e,
                    prule_one_body_cpx_h1e if one_body_only else prule_cpx_h1e)

        _print("para mpo finished", time.perf_counter() - tx)

    # Write the current state to disk and release its in-memory data.
    if mps is not None:
        mps.save_data()
        mps.save_mutable()
        mps.deallocate()
        mps_info.save_mutable()
        mps_info.deallocate_mutable()

    # NOTE(review): assumes `mps` is not None whenever conn_centers is given.
    if conn_centers is not None:
        mps.conn_centers = VectorInt(conn_centers)

    # state-specific DMRG
    # Roots of the MultiMPS are optimized one after another; root `iroot` is
    # optimized against the already optimized roots ext_mpss[:iroot].
    if "statespecific" in dic and "restart_onepdm" not in dic \
            and "restart_correlation" not in dic and "restart_tran_twopdm" not in dic \
            and "restart_oh" not in dic and "restart_twopdm" not in dic \
            and "restart_threepdm" not in dic and "restart_fourpdm" not in dic \
            and "restart_tran_threepdm" not in dic and "restart_tran_fourpdm" not in dic \
            and "restart_fock_fourpdm" not in dic \
            and "restart_nevpt2_npdm" not in dic and "restart_mps_nevpt" not in dic \
            and "restart_tran_onepdm" not in dic and "restart_tran_oh" not in dic \
            and "restart_copy_mps" not in dic and "restart_sample" not in dic:
        assert isinstance(mps, MultiMPS)
        assert nroots != 1

        ext_mpss = []
        dmrg_energies = []
        for iroot in range(nroots):
            tx = time.perf_counter()
            _print('----- root = %3d / %3d -----' % (iroot, nroots))
            ext_mpss.append(mps.extract(iroot, mps.info.tag + "-%d" % iroot)
                               .make_single(mps.info.tag + "-S%d" % iroot))
            # Bring every extracted state into the requested `dot` form.
            for iex, ext_mps in enumerate(ext_mpss):
                _print(iex, ext_mpss[iex].canonical_form, ext_mpss[iex].center)
                if (ext_mps.dot == 1 or ext_mps.center == ext_mps.n_sites - 1) and dot == 2:
                    if ext_mps.center == 0 and ext_mps.canonical_form[0] == 'S':
                        ext_mps.move_right(CG(), prule if MPI is not None else None)
                        ext_mps.center = 0
                    elif ext_mps.center == ext_mps.n_sites - 1 and ext_mps.canonical_form[ext_mps.center] == 'K':
                        ext_mps.move_left(CG(), prule if MPI is not None else None)
                        ext_mps.center = ext_mps.n_sites - 2
                    elif ext_mps.center == ext_mps.n_sites - 1 and ext_mps.canonical_form[ext_mps.center] == 'S':
                        ext_mps.center = ext_mps.n_sites - 2
                    ext_mps.dot = dot
                    ext_mps.save_data()
                _print(iex, ext_mpss[iex].canonical_form, ext_mpss[iex].center)
            # Align the center of the new root with root 0 via an identity sweep.
            if ext_mpss[0].center != ext_mpss[iroot].center:
                _print('change canonical form ...')
                cf = str(ext_mpss[iroot].canonical_form)
                ime = MovingEnvironment(
                    impo, ext_mpss[iroot], ext_mpss[iroot], "IEX")
                ime.delayed_contraction = OpNamesSet.normal_ops()
                ime.cached_contraction = cached_contraction
                ime.init_environments(False)
                expect = Expect(
                    ime, ext_mpss[iroot].info.bond_dim, ext_mpss[iroot].info.bond_dim)
                expect.iprint = max(min(outputlevel, 3), 0)
                expect.solve(True, ext_mpss[iroot].center == 0)
                ext_mpss[iroot].save_data()
                _print(cf + ' -> ' + ext_mpss[iroot].canonical_form)

            me = MovingEnvironment(
                mpo, ext_mpss[iroot], ext_mpss[iroot], "DMRG")
            me.delayed_contraction = OpNamesSet.normal_ops()
            me.cached_contraction = cached_contraction
            me.save_partition_info = True
            me.init_environments(outputlevel >= 2)

            _print("env init finished", time.perf_counter() - tx)

            dmrg = DMRG(me, VectorUBond(bond_dims), VectorFP(noises))
            dmrg.ext_mpss = VectorMPS(ext_mpss[:iroot])
            dmrg.state_specific = True
            # optional weights for the previous roots; a single value is
            # repeated for all of them
            proj_weights = dic.get("proj_weights", None)
            if proj_weights is not None:
                proj_weights = [float(x) for x in proj_weights.split()][:iroot]
                if len(proj_weights) == 1:
                    proj_weights = proj_weights * iroot
                dmrg.projection_weights = VectorFP(proj_weights)
            dmrg.iprint = max(min(outputlevel, 3), 0)
            for ext_mps in dmrg.ext_mpss:
                ext_me = MovingEnvironment(
                    impo, ext_mpss[iroot], ext_mps, "EX" + ext_mps.info.tag)
                ext_me.delayed_contraction = OpNamesSet.normal_ops()
                ext_me.init_environments(outputlevel >= 2)
                dmrg.ext_mes.append(ext_me)
            # Exactly one noise type is selected. "dm_noise" keeps its
            # precedence; "lowmem_noise" is now honored otherwise (before, a
            # separate `if` chain always overwrote it).
            if "dm_noise" in dic:
                dmrg.noise_type = NoiseTypes.DensityMatrix
            elif "lowmem_noise" in dic:
                dmrg.noise_type = NoiseTypes.ReducedPerturbativeCollectedLowMem
            elif decomp_type != DecompositionTypes.SVD:
                dmrg.noise_type = NoiseTypes.ReducedPerturbativeCollected
            else:
                dmrg.noise_type = NoiseTypes.ReducedPerturbative
            dmrg.cutoff = float(dic.get("cutoff", 1E-14))
            dmrg.decomp_type = decomp_type
            dmrg.trunc_type = trunc_type
            dmrg.davidson_conv_thrds = VectorFP(dav_thrds)
            dmrg.davidson_max_iter = int(dic.get("davidson_max_iter", 5000))
            dmrg.davidson_soft_max_iter = int(
                dic.get("davidson_soft_max_iter", 4000))
            dmrg.davidson_def_max_size = int(
                dic.get("davidson_def_max_size", 50))
            dmrg.store_wfn_spectra = store_wfn_spectra
            dmrg.site_dependent_bond_dims = site_dependent_bdims

            sweep_energies = []
            discarded_weights = []
            # NOTE(review): `forward` is the module-level direction flag and is
            # not recomputed from ext_mpss[iroot].center -- confirm for iroot > 0.
            if "twodot_to_onedot" not in dic:
                E_dmrg = dmrg.solve(len(bond_dims), forward, sweep_tol)
            else:
                tto = int(dic["twodot_to_onedot"])
                assert len(bond_dims) > tto
                dmrg.solve(tto, forward, 0)
                # save the twodot part energies and discarded weights
                sweep_energies.append(np.array(dmrg.energies))
                discarded_weights.append(np.array(dmrg.discarded_weights))
                dmrg.me.dot = 1
                for ext_me in dmrg.ext_mes:
                    ext_me.dot = 1
                dmrg.bond_dims = VectorUBond(bond_dims[tto:])
                dmrg.noises = VectorFP(noises[tto:])
                dmrg.davidson_conv_thrds = VectorFP(dav_thrds[tto:])
                # keep the per-sweep site-dependent bond dimensions aligned
                # with the remaining sweeps, as in the ground-state branch
                dmrg.site_dependent_bond_dims = site_dependent_bdims[tto:]
                E_dmrg = dmrg.solve(len(bond_dims) - tto,
                                    ext_mpss[iroot].center == 0, sweep_tol)
                ext_mpss[iroot].dot = 1

            if MPI is None or MPI.rank == 0:
                for ir in range(iroot + 1):
                    ext_mpss[ir].save_data()

            if conn_centers is not None:
                me.finalize_environments()

            sweep_energies.append(np.array(dmrg.energies))
            discarded_weights.append(np.array(dmrg.discarded_weights))
            sweep_energies = np.vstack(sweep_energies)
            discarded_weights = np.hstack(discarded_weights)

            if MPI is None or MPI.rank == 0:
                np.save(scratch + "/E_dmrg-%d.npy" % iroot, E_dmrg)
                dmrg_energies.append(E_dmrg)
                np.save(scratch + "/bond_dims-%d.npy" %
                        iroot, bond_dims[:len(discarded_weights)])
                np.save(scratch + "/sweep_energies-%d.npy" %
                        iroot, sweep_energies)
                np.save(scratch + "/discarded_weights-%d.npy" %
                        iroot, discarded_weights)
            _print("DMRG Energy for root %4d = %20.15f" % (iroot, E_dmrg))

            if MPI is None or MPI.rank == 0:
                ext_mpss[iroot].info.save_data(
                    scratch + '/mps_info-ss-%d.bin' % iroot)
                ext_mpss[iroot].info.save_data(
                    scratch + '/%s-mps_info-ss-%d.bin' % (mps_tags[0], iroot))

        if MPI is None or MPI.rank == 0:
            if stackblock_compat:
                with open(os.path.join(scratch + "/dmrg.e"), "wb") as f:
                    import struct
                    f.write(struct.pack('d' * nroots, *dmrg_energies))
            if openmolcas_compat:
                with open(os.path.join(scratch + "/../block.energy"), "w") as f:
                    for enx in dmrg_energies:
                        f.write("%25.12f\n" % enx)

        if "twodot_to_onedot" in dic:
            dot = 1

    # GS DMRG
    # Skipped when any of the listed restart / other-task keys is present.
    if "restart_onepdm" not in dic and "restart_twopdm" not in dic \
            and "restart_correlation" not in dic and "restart_tran_twopdm" not in dic \
            and "restart_oh" not in dic and "statespecific" not in dic \
            and "restart_tran_onepdm" not in dic and "restart_tran_oh" not in dic \
            and "restart_threepdm" not in dic and "restart_tran_threepdm" not in dic \
            and "restart_fourpdm" not in dic and "restart_tran_fourpdm" not in dic \
            and "restart_fock_fourpdm" not in dic \
            and "restart_nevpt2_npdm" not in dic and "restart_mps_nevpt" not in dic \
            and "restart_copy_mps" not in dic and "restart_sample" not in dic \
            and "delta_t" not in dic and "compression" not in dic and "stopt_sampling" not in dic:

        me = MovingEnvironment(mpo, mps, mps, "DMRG")
        if "use_hybrid_complex" not in dic:
            if condense_mpo == 1:
                me.delayed_contraction = OpNamesSet.normal_ops()
            me.cached_contraction = cached_contraction
        me.save_partition_info = True
        me.init_environments(outputlevel >= 2)

        # second environment built from the complex MPO (hybrid complex mode)
        if "use_hybrid_complex" in dic:
            cpx_me = MovingEnvironmentX(mpo_cpx_h1e, mps, mps, "DMRG-CPX")
            cpx_me.cached_contraction = False
            cpx_me.init_environments(outputlevel >= 2)

        if conn_centers is not None:
            forward = mps.center == 0

        _print("env init finished", time.perf_counter() - tx)

        if big_site_method is not None and n_cas != 0:
            dmrg = DMRGBigSite(me, VectorUBond(bond_dims), VectorFP(noises))
            dmrg.last_site_svd = True
            dmrg.last_site_1site = dot == 2
            dmrg.decomp_last_site = False
        else:
            dmrg = DMRG(me, VectorUBond(bond_dims), VectorFP(noises))
        dmrg.iprint = max(min(outputlevel, 3), 0)

        if "skip_inact_ext_sites" in dic:
            dmrg.sweep_start_site = n_inactive
            dmrg.sweep_end_site = me.n_sites - n_external

        if "use_hybrid_complex" in dic:
            dmrg.cpx_me = cpx_me

        # projection
        # one weight per tag in proj_tags, or a single weight repeated for all
        if len(proj_tags) != 0:
            proj_weights = dic.get("proj_weights", None)
            assert proj_weights is not None
            proj_weights = VectorFP([float(x) for x in proj_weights.split()])
            if len(proj_weights) == 1:
                proj_weights = VectorFP(list(proj_weights) * len(proj_tags))
            assert len(proj_weights) == len(proj_tags)
            ext_mpss = []
            for ipj in range(len(proj_weights)):
                xmps, xmps_info, _ = get_mps_from_tags(ipj, True, mps.center)
                ext_mpss.append(xmps)
            dmrg.projection_weights = proj_weights
            dmrg.ext_mpss = VectorMPS(ext_mpss)
            for ext_mps in dmrg.ext_mpss:
                ext_me = MovingEnvironment(impo, mps, ext_mps, "PJ" + ext_mps.info.tag)
                ext_me.delayed_contraction = OpNamesSet.normal_ops()
                ext_me.init_environments(outputlevel >= 2)
                dmrg.ext_mes.append(ext_me)

        if "lowmem_noise" in dic:
            dmrg.noise_type = NoiseTypes.ReducedPerturbativeCollectedLowMem
        elif decomp_type != DecompositionTypes.SVD:
            dmrg.noise_type = NoiseTypes.ReducedPerturbativeCollected
        else:
            dmrg.noise_type = NoiseTypes.ReducedPerturbative
        dmrg.cutoff = float(dic.get("cutoff", 1E-14))
        dmrg.davidson_max_iter = int(dic.get("davidson_max_iter", 5000))
        dmrg.davidson_soft_max_iter = int(
            dic.get("davidson_soft_max_iter", 4000))
        dmrg.davidson_def_max_size = int(
            dic.get("davidson_def_max_size", 50))
        dmrg.decomp_type = decomp_type
        dmrg.trunc_type = trunc_type
        dmrg.davidson_conv_thrds = VectorFP(dav_thrds)
        dmrg.store_wfn_spectra = store_wfn_spectra
        dmrg.site_dependent_bond_dims = site_dependent_bdims
        sweep_energies = []
        discarded_weights = []
        if big_site_method is not None and n_cas == 0:
            if "twodot_to_onedot" in dic:
                _print("WARNING: twodot_to_onedot is ignored for n_cas = 0")
            E_dmrg = dmrg.solve(len(bond_dims), forward, sweep_tol)
        elif "twodot_to_onedot" not in dic:
            E_dmrg = dmrg.solve(len(bond_dims), forward, sweep_tol)
        else:
            tto = int(dic["twodot_to_onedot"])
            assert len(bond_dims) > tto
            dmrg.solve(tto, forward, 0)
            # save the twodot part energies and discarded weights
            for x in dmrg.energies:
                while len(x) < nroots:
                    x.append(0.0)
            sweep_energies.append(np.array(dmrg.energies))
            discarded_weights.append(np.array(dmrg.discarded_weights))
            if big_site_method is not None and n_cas != 0:
                dmrg.last_site_1site = False
                dmrg.me.center = mps.center
            dmrg.bond_dims = VectorUBond(bond_dims[tto:])
            dmrg.noises = VectorFP(noises[tto:])
            dmrg.davidson_conv_thrds = VectorFP(dav_thrds[tto:])
            dmrg.site_dependent_bond_dims = site_dependent_bdims[tto:]
            dmrg.me.dot = 1
            for ext_me in dmrg.ext_mes:
                ext_me.dot = 1
            E_dmrg = dmrg.solve(len(bond_dims) - tto,
                                mps.center == dmrg.sweep_start_site, sweep_tol)
            mps.dot = 1
            dot = 1
            if MPI is None or MPI.rank == 0:
                mps.save_data()

        if conn_centers is not None:
            me.finalize_environments()
            mps = MPS(mps)
            if prule.comm.group != 0:
                quit()

        dmrg.me.remove_partition_files()
        for xme in dmrg.ext_mes:
            xme.remove_partition_files()

        if mps.center == mps.n_sites - 1 and mps.dot == 2 and dot == 2:
            mps.center = mps.n_sites - 2

        _print("Final canonical form = ", mps.canonical_form, mps.center)
        # pad each sweep's energy list to nroots entries before stacking
        for x in dmrg.energies:
            while len(x) < nroots:
                x.append(0.0)
        sweep_energies.append(np.array(dmrg.energies))
        discarded_weights.append(np.array(dmrg.discarded_weights))
        sweep_energies = np.vstack(sweep_energies)
        discarded_weights = np.hstack(discarded_weights)

        if MPI is None or MPI.rank == 0:
            bdims = bond_dims[:len(discarded_weights)]
            if len(bdims) < len(discarded_weights):
                bdims = bdims + bdims[-1:] * \
                    (len(discarded_weights) - len(bdims))
            np.save(scratch + "/E_dmrg.npy", E_dmrg)
            if stackblock_compat:
                dmrg_energies = [E_dmrg] if nroots == 1 else list(sweep_energies[-1])
                with open(os.path.join(scratch + "/dmrg.e"), "wb") as f:
                    import struct
                    f.write(struct.pack('d' * nroots, *dmrg_energies))
            if openmolcas_compat:
                dmrg_energies = [E_dmrg] if nroots == 1 else list(sweep_energies[-1])
                with open(os.path.join(scratch + "/../block.energy"), "w") as f:
                    for enx in dmrg_energies:
                        f.write("%25.12f\n" % enx)
            np.save(scratch + "/bond_dims.npy", bdims)
            np.save(scratch + "/sweep_energies.npy", sweep_energies)
            np.save(scratch + "/discarded_weights.npy", discarded_weights)
            if store_wfn_spectra:
                np.save(scratch + "/sweep_wfn_spectra.npy",
                        np.array([np.array(x) for x in dmrg.sweep_wfn_spectra], dtype=object))
                bip_ent = np.zeros(len(dmrg.sweep_wfn_spectra), dtype=np.float64)
                for ix, x in enumerate(dmrg.sweep_wfn_spectra):
                    # np.float128 is missing on some platforms (e.g. Windows);
                    # np.longdouble is the portable extended-precision type.
                    ldsq = np.array(x, dtype=np.longdouble) ** 2
                    ldsq = ldsq[ldsq != 0]
                    bip_ent[ix] = float(np.sum(-ldsq * np.log(ldsq)))
                np.save(scratch + "/sweep_wfn_entropy.npy", bip_ent)
                _print('WFN BIP Entanglement = ', ''.join(["%10.5f" % x for x in bip_ent]))
            if "extrapolation" in dic:
                ext_eners = []
                ext_dws = []
                ext_bdims = []
                # Only the two-site sweeps are used when the run switched to
                # one-site sweeps. The big-site n_cas == 0 path ignores
                # twodot_to_onedot and does not set `tto`, so use all sweeps.
                if "twodot_to_onedot" not in dic or \
                        (big_site_method is not None and n_cas == 0):
                    llsw = len(sweep_energies)
                else:
                    llsw = tto
                # keep the last sweep for each distinct bond dimension
                for iext in range(llsw):
                    if bdims[iext] not in ext_bdims:
                        ext_bdims.append(bdims[iext])
                        ext_dws.append(discarded_weights[iext])
                        ext_eners.append(sweep_energies[iext, 0])
                    else:
                        ii = ext_bdims.index(bdims[iext])
                        ext_dws[ii] = discarded_weights[iext]
                        ext_eners[ii] = sweep_energies[iext, 0]
                ext_eners = np.array(ext_eners)
                ext_dws = np.array(ext_dws)
                ext_bdims = np.array(ext_bdims)
                _print('EXTRAP discarded weights = ', ext_dws)
                _print('EXTRAP oh energies (au) = ', ext_eners)
                _print('EXTRAP bond dimensions = ', ext_bdims)
                import scipy.stats
                reg = scipy.stats.linregress(ext_dws, ext_eners)
                _print('EXTRAP Energy = %20.15f (+/-) %20.15f' % (reg.intercept,
                                                                  np.min(np.abs(reg.intercept - ext_eners)) / 5))
                _print('EXTRAP R^2 = %20.15f' % (reg.rvalue ** 2))
                emin, emax = min(ext_eners), max(ext_eners)
                de = emax - emin
                xmin, xmax = min(ext_dws), max(ext_dws)
                ddw = xmax - xmin
                import matplotlib.pyplot as plt
                x_reg = np.array([0, xmax + ddw / 12])
                plt.plot(x_reg, reg.intercept + reg.slope * x_reg,
                         '--', linewidth=1, color='#5FA8AB')
                plt.plot(ext_dws, ext_eners, 'o', color='#38686A',
                         markerfacecolor='white', markersize=5)
                plt.text(ddw / 12, emax, "$E(M=\\infty) = %.6f \\pm %.6f \\mathrm{\\ Hartree}$" %
                         (reg.intercept, abs(reg.intercept - emin) / 5), color='#38686A', fontsize=12)
                plt.text(ddw / 12, emax - de / 12, "$R^2 = %.6f$" % (reg.rvalue ** 2),
                         color='#38686A', fontsize=12)
                plt.xlim((0, xmax + ddw / 12))
                plt.ylim((emin - de / 12, emax + de / 12))
                plt.xlabel("Largest Discarded Weight")
                plt.ylabel("Sweep Energy (Hartree)")
                plt.subplots_adjust(left=0.16, bottom=0.1,
                                    right=0.95, top=0.95)
                plt.savefig(scratch + "/extrapolation.png", dpi=600)

        if dynamic_corr_method is not None:
            if dynamic_corr_method[0] in ['casci', 'nevpt2s', 'nevpt2sd', "nevpt2-ijrs", "nevpt2-ij",
                        "nevpt2-rs", "nevpt2-ijr", "nevpt2-rsi", "nevpt2-ir", "nevpt2-i", "nevpt2-r",
                        "mrrept2s", "mrrept2sd", "mrrept2-ijrs", "mrrept2-ij",
                        "mrrept2-rs", "mrrept2-ijr", "mrrept2-rsi", "mrrept2-ir", "mrrept2-i", "mrrept2-r"]:
                _print("DMRG-CASCI Energy = %20.15f" % E_dmrg)
            elif dynamic_corr_method[0] in ['dmrgfci']:
                _print("DMRG-FCI Energy = %20.15f" % E_dmrg)
            elif dynamic_corr_method[0] in ['mrcis', 'mrcisd', 'mrcisdt']:
                _print("DMRG-%s Energy = %20.15f" %
                       (dynamic_corr_method[0].upper(), E_dmrg))
        else:
            _print("DMRG Energy = %20.15f" % E_dmrg)

        if MPI is None or MPI.rank == 0:
            mps_info.save_data(scratch + '/mps_info.bin')
            mps_info.save_data(scratch + '/%s-mps_info.bin' % mps_tags[0])

    # Compression
    # Runs the `Linear` solver on a MovingEnvironment built from `mps` and the
    # state loaded from read_tags[0] (`lmps`). It uses the identity MPO when
    # `overlap` is set and `mpo` otherwise, then reports and saves the overlap.
    if "compression" in dic:
        lmps, lmps_info, _ = get_mps_from_tags(-1)
        # Unless a random initial guess is requested, start from a copy of the
        # loaded state, tagged mps_tags[0].
        if "random_mps_init" not in dic:
            mps = lmps.deep_copy(mps_tags[0])
            mps_info = mps.info

        # With "stopt_compression", the energy stored in E_dmrg.npy is
        # subtracted from mpo.const_e for this step and added back at the end.
        if "stopt_compression" in dic:
            E_dmrg = float(np.load(scratch + "/E_dmrg.npy"))
            mpo.const_e -= E_dmrg

        me = MovingEnvironment(impo if overlap else mpo, mps, lmps, "CPS")
        me.delayed_contraction = OpNamesSet.normal_ops()
        me.cached_contraction = cached_contraction
        me.save_partition_info = True
        me.init_environments(outputlevel >= 2)

        # bra bond dimensions follow the schedule; the ket keeps its own
        cps = Linear(me, VectorUBond(bond_dims),
                     VectorUBond([lmps.info.bond_dim]))
        cps.iprint = max(min(outputlevel, 3), 0)
        cps.cutoff = float(dic.get("cutoff", 1E-14))
        cps.decomp_type = decomp_type
        cps.trunc_type = trunc_type
        if "stopt_compression" in dic:
            cps.conv_type = ConvergenceTypes.LastMaximal
        if "twodot_to_onedot" not in dic:
            ovl = cps.solve(len(bond_dims), mps.center == 0, sweep_tol)
        else:
            # first `tto` sweeps with the two-site environment, the rest with
            # the environment switched to one site
            tto = int(dic["twodot_to_onedot"])
            assert len(bond_dims) > tto
            cps.solve(tto, mps.center == 0, 0)
            cps.bra_bond_dims = VectorUBond(bond_dims[tto:])
            cps.rme.dot = 1
            ovl = cps.solve(len(bond_dims) - tto, mps.center == 0, sweep_tol)
            mps.dot = 1
            lmps.dot = 1
            if MPI is None or MPI.rank == 0:
                mps.save_data()
                lmps.save_data()
        _print("Final canonical form = ", mps.canonical_form, mps.center)
        _print("Compression overlap = %20.15f" % ovl)

        if MPI is None or MPI.rank == 0:
            np.save(scratch + "/cps_overlap.npy", ovl)
            mps_info.save_data(scratch + '/mps_info.bin')
            mps_info.save_data(scratch + '/%s-mps_info.bin' % mps_tags[0])

        # restore the MPO constant shifted above
        if "stopt_compression" in dic:
            mpo.const_e += E_dmrg

    # Time Evolution
    # Propagate the MPS with `mpo` in n_steps steps of "delta_t" up to
    # "target_t". A purely real delta_t is reported as imaginary-time
    # evolution ("imag TE"), a complex one as real-time evolution.
    # Per-step times, energies, norms and discarded weights are saved
    # (root rank only) as te_*.npy in the scratch directory.
    if "delta_t" in dic:

        if len(read_tags) == 0:
            _print("Time Evolution START FROM RANDOM MPS !!!")
        else:
            mps, mps_info, _ = get_mps_from_tags(-1)

        # the input may use "i" for the imaginary unit; complex() needs "j"
        dt = complex(dic["delta_t"].replace(' ', '').replace('i', 'j'))
        tt = complex(dic["target_t"].replace(' ', '').replace('i', 'j'))
        n_steps = int(abs(tt) / abs(dt) + 0.1)
        # |target_t| must be an integer multiple of |delta_t|
        assert np.abs(abs(n_steps * dt) - abs(tt)) < 1E-10
        is_imag_te = abs(np.imag(dt)) < 1E-10
        if is_imag_te:
            dt = np.real(dt)
            tt = np.real(tt)
            _print("Time Evolution  DELTA T = %15.8f" % dt)
            _print("Time Evolution TARGET T = %15.8f" % tt)
        else:
            _print("Time Evolution  DELTA T = RE %15.8f + IM %15.8f" %
                   (np.real(dt), np.imag(dt)))
            _print("Time Evolution TARGET T = RE %15.8f + IM %15.8f" %
                   (np.real(tt), np.imag(tt)))

        # a complex wavefunction is carried as a two-wavefunction MultiMPS;
        # in both cases the evolved state is a deep copy under mps_tags[0]
        # (asserted to differ from the input tag)
        if isinstance(mps, MultiMPS):
            assert len(mps.wfns) == 2
            assert mps.info.tag != mps_tags[0]
            assert complex_mps
            mps = mps.deep_copy(mps_tags[0])
            mps_info = mps.info
        else:
            assert not complex_mps
            assert mps.info.tag != mps_tags[0]
            mps = mps.deep_copy(mps_tags[0])
            mps_info = mps.info

        _print("Time Evolution   NSTEPS = %d" % n_steps)
        _print("    with %s wavefunction" %
               ("complex" if complex_mps else "real"))
        _print("    with %s step (%s TE)" % (
            "real" if is_imag_te else "complex", "imag" if is_imag_te else "real"))
        _print("    init canonical form = %s" % mps.canonical_form)

        me = MovingEnvironment(mpo, mps, mps, "DMRG")
        me.delayed_contraction = OpNamesSet.normal_ops()
        me.cached_contraction = cached_contraction
        me.save_partition_info = True
        me.init_environments(outputlevel >= 2)

        te = TimeEvolution(me, VectorUBond(bond_dims), te_type)
        te.cutoff = float(dic.get("cutoff", 1E-14))
        te.hermitian = not anti_herm
        te.iprint = max(min(outputlevel, 3), 0)
        te.n_sub_sweeps = 1
        if te.mode != TETypes.TangentSpace:
            te.n_sub_sweeps = int(dic.get("n_sub_sweeps", 2))
        te.normalize_mps = "normalize_mps" in dic
        te_times = []
        te_energies = []
        te_normsqs = []
        te_discarded_weights = []
        for i in range(n_steps):
            # TangentSpace: two sweeps of dt / 2 per step;
            # other modes: one sweep of dt (with n_sub_sweeps sub-sweeps)
            if te.mode == TETypes.TangentSpace:
                te.solve(2, dt / 2, mps.center == 0)
            else:
                te.solve(1, dt, mps.center == 0)
            if is_imag_te:
                _print("T = %10.5f <E> = %20.15f <Norm^2> = %20.15f" %
                       ((i + 1) * dt, te.energies[-1], te.normsqs[-1]))
            else:
                _print("T = RE %10.5f + IM %10.5f <E> = %20.15f <Norm^2> = %20.15f" %
                       ((i + 1) * np.real(dt), (i + 1) * np.imag(dt), te.energies[-1], te.normsqs[-1]))
            te_times.append((i + 1) * dt)
            te_energies.append(te.energies[-1])
            te_normsqs.append(te.normsqs[-1])
            te_discarded_weights.append(te.discarded_weights[-1])
        # n_steps can be 0 (target_t == 0); max() of an empty list would raise
        _print("Max Discarded Weight = %9.5g" % max(te_discarded_weights, default=0.0))

        _print("   mps final tag = %s" % mps_tags[0])
        _print("   mps final canonical form = %s" % mps.canonical_form)

        # like the other result files in this driver, write from the root
        # rank only so that MPI ranks do not race on the same files
        if MPI is None or MPI.rank == 0:
            np.save(scratch + "/te_times.npy", np.array(te_times))
            np.save(scratch + "/te_energies.npy", np.array(te_energies))
            np.save(scratch + "/te_normsqs.npy", np.array(te_normsqs))
            np.save(scratch + "/te_discarded_weights.npy",
                    np.array(te_discarded_weights))
            mps_info.save_data(scratch + '/mps_info.bin')
            mps_info.save_data(scratch + '/%s-mps_info.bin' % mps_tags[0])

    def do_npdm(bmps, kmps, pdm_type):
        """Compute an n-particle (transition) density matrix between the bra
        ``bmps`` and the ket ``kmps`` with the general NPDM MPO built from
        ``ghamil``.

        Args:
            bmps: bra MPS.
            kmps: ket MPS.
            pdm_type: order of the density matrix as an int (1-4 in the
                spin-adapted branch), or the string 'nevpt', which (in the
                spin-adapted branch only) requests the 1- to 4-particle
                spin-free matrices in one sweep.

        Returns:
            Computed on every MPI rank (no root-rank check here):
            - "use_general_spin": the NPDM with a leading axis of length 1;
            - "nonspinadapted": ``np.array`` of the components in ``op_str``
              order (for pdm_type 2: "ccdd", "cCDd", "CCDD");
            - spin-adapted pdm_type 1: ``[dm / 2, dm / 2]``;
            - spin-adapted pdm_type 2: ``[daa, dab, daa]``;
            - spin-adapted 'nevpt': list of the spin-free matrices;
            - otherwise: the single spin-free NPDM.
            Orbital axes are permuted back to the original order when
            ``orb_idx`` is set.
        """

        if MPI is not None:
            MPI.barrier()

        # make bond_dim at least the actual maximal bond dimension stored in
        # each MPS info (the mutable data is loaded only for this query)
        bmps.info.load_mutable()
        bmps.info.bond_dim = max(bmps.info.bond_dim, bmps.info.get_max_bond_dimension())
        bmps.info.deallocate_mutable()
        kmps.info.load_mutable()
        kmps.info.bond_dim = max(kmps.info.bond_dim, kmps.info.get_max_bond_dimension())
        kmps.info.deallocate_mutable()

        if MPI is not None:
            MPI.barrier()

        import block2 as b

        # operator strings / permutation schemes defining which NPDM elements
        # are computed, depending on the spin treatment
        if "use_general_spin" in dic:
            op_str = "C" * pdm_type + "D" * pdm_type
            perm = b.SpinPermScheme.initialize_sz(pdm_type * 2, op_str, True)
            perms = b.VectorSpinPermScheme([perm])
        elif "nonspinadapted" in dic:
            # builds e.g. ["cd", "CD"] -> ["ccdd", "cCDd", "CCDD"] -> ...
            op_str = ["cd", "CD"]
            for _ in range(pdm_type - 1):
                op_str = ["c%sd" % x for x in op_str] + ["C%sD" % op_str[-1]]
            perms = b.VectorSpinPermScheme([
                b.SpinPermScheme.initialize_sz(pdm_type * 2, cd, True)
                for cd in op_str])
        elif pdm_type == 'nevpt':
            # SU2 coupled strings for the 1-, 2-, 3- and 4-particle terms
            op_str = [
                "(C+D)0",
                "((C+D)0+(C+D)0)0",
                "((C+D)0+((C+D)0+(C+D)0)0)0",
                "((C+D)0+((C+D)0+((C+D)0+(C+D)0)0)0)0"
            ]
            perms = b.VectorSpinPermScheme([
                b.SpinPermScheme.initialize_su2((ip + 1) * 2, cd, True)
                for ip, cd in enumerate(op_str)])
        else:
            # SU2 coupled string of order pdm_type (1..4)
            op_str = [
                ["(C+D)0"],
                ["((C+(C+D)0)1+D)0"],
                ["((C+((C+(C+D)0)1+D)0)1+D)0"],
                ["((C+((C+((C+(C+D)0)1+D)0)1+D)0)1+D)0"]
            ][pdm_type - 1]
            perms = b.VectorSpinPermScheme([
                b.SpinPermScheme.initialize_su2(pdm_type * 2, cd, True)
                for cd in op_str])
        
        _print("npdm string =", op_str)
        if MPI is not None:
            npdm_prule = ParallelRuleSimple(b.ParallelSimpleTypes.Nothing, MPI)

        scheme = b.NPDMScheme(perms)
        # symbol-free MPO construction when algo_type includes SymbolFree
        # or Automatic
        symbol_free = (algo_type & ExpectationAlgorithmTypes.SymbolFree) or \
            (algo_type & ExpectationAlgorithmTypes.Automatic)
        pmpo = GeneralNPDMMPO(ghamil, scheme, symbol_free)
        pmpo.iprint = 2 if outputlevel >= 4 else min(outputlevel, 1)
        if MPI is not None:
            pmpo.parallel_rule = npdm_prule
        # transition quantum number between bra and ket targets
        pmpo.delta_quantum = (bmps.info.target - kmps.info.target)[0]
        pmpo.build()

        pmpo = SimplifiedMPO(pmpo, Rule(), False, False)
        if MPI is not None:
            pmpo = ParallelMPO(pmpo, npdm_prule)

        pme = MovingEnvironment(pmpo, bmps, kmps, "NPDM")
        pme.cached_contraction = False
        pme.fused_contraction_rotation = True
        pme.save_partition_info = True
        pme.init_environments(outputlevel >= 2)

        expect = XExpect(pme, bmps.info.bond_dim, kmps.info.bond_dim)
        expect.zero_dot_algo = True
        expect.algo_type = algo_type
        expect.cutoff = float(dic.get("cutoff", 1E-24))
        expect.iprint = max(min(outputlevel, 3), 0)
        expect.solve(True, kmps.center == 0)
        expect.me.remove_partition_files()

        npdms = list(expect.get_npdm())
        for ip in range(len(npdms)):
            npdms[ip] = np.asarray(npdms[ip])

        # spin-adapted results are scaled by sqrt(2) per index pair
        # (ndim // 2 pairs); presumably a change of SU2 normalization
        # convention -- NOTE(review): confirm
        if "use_general_spin" not in dic and "nonspinadapted" not in dic:
            for ip in range(len(npdms)):
                npdms[ip] *= np.array(np.sqrt(2.0)) ** (npdms[ip].ndim // 2)

        # undo the orbital reordering on every axis (spin-orbital indices
        # 2i, 2i+1 kept adjacent when integrals were made spin-orbital)
        if orb_idx is not None:
            rev_idx = np.argsort(orb_idx)
            if "trans_integral_to_spin_orbital" in dic:
                rev_idx = np.array(list(zip(rev_idx*2, rev_idx*2+1))).ravel()
            for ip in range(len(npdms)):
                for i in range(npdms[ip].ndim):
                    npdms[ip] = npdms[ip][(slice(None),) * i + (rev_idx,)]

        if "use_general_spin" in dic:
            dm = npdms[0][None]
        elif "nonspinadapted" in dic:
            dm = np.array(npdms)
        else:
            if pdm_type == 1:
                # split the spin-free 1PDM evenly into two spin components
                dm = np.array([npdms[0] / 2, ] * 2)
            elif pdm_type == 2:
                # same-spin (daa) and opposite-spin (dab) parts from the
                # spin-free 2PDM and its last-two-index transpose
                daa = (npdms[0] - npdms[0].transpose(0, 1, 3, 2)) / 6
                dab = (2 * npdms[0] + npdms[0].transpose(0, 1, 3, 2)) / 6
                dm = np.array([daa, dab, daa])
            elif pdm_type == 'nevpt':
                dm = npdms
            else:
                dm = npdms[0]

        return dm

    def do_onepdm(bmps, kmps):
        """Conventional-MPO one-particle (transition) density matrix.

        Builds a "1PDM" moving environment of ``pmpo`` between the bra
        ``bmps`` and the ket ``kmps``, sweeps it with ``XExpect`` and, on the
        root MPI rank, returns the 1PDM as a numpy array with a leading spin
        axis: length 1 when "use_general_spin" is set, otherwise the two
        diagonal spin blocks (0, 0) and (1, 1) of the spin-orbital matrix.
        Orbitals are permuted back to the original order when ``orb_idx`` is
        set. Non-root MPI ranks get None.
        """
        env = MovingEnvironment(pmpo, bmps, kmps, "1PDM")
        # delayed contraction does not work with ExpectationAlgorithmTypes.Fast,
        # so it is only switched on for the Normal algorithm
        if algo_type == ExpectationAlgorithmTypes.Normal:
            env.delayed_contraction = OpNamesSet.normal_ops()
        env.cached_contraction = cached_contraction
        env.save_partition_info = True
        env.init_environments(outputlevel >= 2)

        _print("env init finished", time.perf_counter() - tx)

        solver = XExpect(env, bmps.info.bond_dim, kmps.info.bond_dim)
        solver.zero_dot_algo = "zerodot" in dic
        solver.algo_type = algo_type
        solver.iprint = max(min(outputlevel, 3), 0)
        solver.solve(True, kmps.center == 0)
        _print("Final canonical form = ", bmps.canonical_form, bmps.center)

        # only the root rank assembles and returns the matrix
        if MPI is not None and MPI.rank != 0:
            return None

        raw = solver.get_1pdm(n_orbs)
        pdm = np.array(raw).copy()
        raw.deallocate()

        if "use_general_spin" in dic:
            pdm = pdm[None]
        else:
            # split each spin-orbital index into (orbital, spin) and keep the
            # two spin-diagonal blocks (spin index 0/1 presumably alpha/beta)
            nrow, ncol = pdm.shape[0] // 2, pdm.shape[1] // 2
            spin_blocks = pdm.reshape((nrow, 2, ncol, 2)).transpose((0, 2, 1, 3))
            pdm = np.stack([spin_blocks[:, :, 0, 0], spin_blocks[:, :, 1, 1]])

        if orb_idx is not None:
            order = np.argsort(orb_idx)
            if "trans_integral_to_spin_orbital" in dic:
                order = np.array(list(zip(order*2, order*2+1))).ravel()
            assert pdm.shape[-1] == len(order)
            pdm[:, :, :] = pdm[:, order, :][:, :, order]
        return pdm

    # ONEPDM
    # One-particle density matrix of the single root (nroots == 1) or of each
    # root. Results are spin-resolved arrays saved (root rank only) as
    # 1pdm.npy or 1pdm-<i>-<i>.npy. With "nat_orbs" (nroots == 1 only) the
    # natural orbitals, their rotation and kappa are saved too, and the
    # integrals are optionally rotated and written to the given file.
    if "restart_onepdm" in dic or "onepdm" in dic:

        if nroots == 1:

            # "skip_inact_ext_sites" requires the MPS center at one end of the
            # chain: sweep an "IEX" expectation of `impo` to move it there
            # (direction chosen by comparing the center with n_inactive)
            if "skip_inact_ext_sites" in dic and mps.center != 0 and mps.center != mps.n_sites - mps.dot:
                if MPI is not None:
                    MPI.barrier()
                _print('change canonical form ...')
                _print('original cf = ', mps.canonical_form)
                ime = MovingEnvironment(impo, mps, mps, "IEX")
                ime.delayed_contraction = OpNamesSet.normal_ops()
                ime.cached_contraction = cached_contraction
                ime.init_environments(False)
                expect = XExpect(ime, mps.info.bond_dim, mps.info.bond_dim)
                expect.iprint = max(min(outputlevel, 3), 0)
                expect.solve(True, mps.center != n_inactive)
                _print('final cf = ', mps.canonical_form)
                if MPI is not None:
                    MPI.barrier()

            # do_npdm needs the general Hamiltonian `ghamil`; fall back to the
            # conventional 1PDM MPO when it is unavailable
            if conv_npdm or ghamil is None:
                dm = do_onepdm(mps, mps)
            else:
                dm = do_npdm(mps, mps, pdm_type=1)
            if MPI is None or MPI.rank == 0:
                dmocc = np.diag(np.sum(dm, axis=0))
                if dmocc.dtype == np.complex128:
                    _print("DMRG OCC = ", "".join([" (%6.3f,%6.3f)" % (np.real(x), np.imag(x)) for x in dmocc]))
                else:
                    _print("DMRG OCC = ", "".join(["%6.3f" % x for x in dmocc]))
                if big_site_method != "folding":
                    np.save(scratch + "/1pdm.npy", dm)
                mps_info.save_data(scratch + '/mps_info.bin')
                mps_info.save_data(scratch + '/%s-mps_info.bin' % mps_tags[0])

                # natural orbital generation
                if "nat_orbs" in dic:
                    spdm = np.sum(dm, axis=0)
                    # need pdm after orbital rotation
                    if orb_idx is not None:
                        spdm[:, :] = spdm[orb_idx, :][:, orb_idx]
                    xpdm = spdm.copy()
                    _print("REORDERED OCC = ", "".join(
                        ["%6.3f" % x for x in np.diag(spdm)]))
                    spdm = spdm.ravel()
                    nat_occs = np.zeros((n_sites, ))
                    if orb_sym is None:
                        raise ValueError(
                            "Need FCIDUMP construction (namely, not a pre run) for 'nat_orbs'!")
                    # symmetry-blocked eigendecomposition: eigenvalues go to
                    # nat_occs, eigenvectors overwrite spdm (checked below by
                    # rot @ diag(nat_occs) @ rot.T == xpdm)
                    MatrixFunctions.block_eigs(spdm, nat_occs, orb_sym)
                    _print("NAT OCC = ", "".join(
                        ["%9.6f" % x for x in nat_occs]))
                    # (old, new)
                    rot = np.array(spdm.reshape(
                        (n_sites, n_sites)).T, copy=True)
                    np.save(scratch + "/nat_orb_sym.npy", np.array(orb_sym))
                    for isym in set(orb_sym):
                        mask = np.array(orb_sym) == isym
                        # optionally reorder natural orbitals within the irrep
                        # to best match the original orbitals (assignment
                        # problem on 1 - rot**2)
                        if "nat_km_reorder" in dic:
                            kmidx = np.argsort(KuhnMunkres(
                                1 - rot[mask, :][:, mask] ** 2).solve()[1])
                            _print("init = ", np.sum(
                                np.diag(rot[mask, :][:, mask]) ** 2))
                            rot[:, mask] = rot[:, mask][:, kmidx]
                            nat_occs[mask] = nat_occs[mask][kmidx]
                            _print("final = ", np.sum(
                                np.diag(rot[mask, :][:, mask]) ** 2))
                        # fix column signs so that the rotation block has a
                        # positive determinant (every leading minor with
                        # "nat_positive_def", otherwise only the full block)
                        if "nat_positive_def" in dic:
                            for j in range(len(nat_occs[mask])):
                                mrot = rot[mask, :][:j + 1,
                                                    :][:, mask][:, :j + 1]
                                mrot_det = np.linalg.det(mrot)
                                _print("ISYM = %d J = %d MDET = %15.10f" %
                                       (isym, j, mrot_det))
                                if mrot_det < 0:
                                    mask0 = np.arange(len(mask), dtype=int)[
                                        mask][j]
                                    rot[:, mask0] = -rot[:, mask0]
                        else:
                            mrot = rot[mask, :][:, mask]
                            mrot_det = np.linalg.det(mrot)
                            _print("ISYM = %d MDET = %15.10f" %
                                   (isym, mrot_det))
                            if mrot_det < 0:
                                mask0 = np.arange(len(mask), dtype=int)[
                                    mask][0]
                                rot[:, mask0] = -rot[:, mask0]
                    if "nat_km_reorder" in dic:
                        _print("REORDERED NAT OCC = ", "".join(
                            ["%9.6f" % x for x in nat_occs]))
                    assert np.linalg.norm(rot @ np.diag(
                        nat_occs) @ rot.T - xpdm) < 1E-10
                    np.save(scratch + "/nat_occs.npy", nat_occs)
                    rot_det = np.linalg.det(rot)
                    _print("DET = %15.10f" % rot_det)
                    assert rot_det > 0
                    np.save(scratch + "/nat_rotation.npy", rot)

                    def my_logm(mrot):
                        """Real antisymmetric logarithm of the rotation block
                        ``mrot``, built from 2x2 angles in the eigenbasis of
                        ``mrot + mrot.T``. Assumes ``rv.T @ mrot @ rv`` is
                        (close to) block diagonal with 2x2 rotations; the
                        result is validated by the expm check below."""
                        rs = mrot + mrot.T
                        rl, rv = np.linalg.eigh(rs)
                        assert np.linalg.norm(
                            rs - rv @ np.diag(rl) @ rv.T) < 1E-10
                        rd = rv.T @ mrot @ rv
                        # determinant of the tridiagonal part of rd via the
                        # three-term recurrence
                        ra, rdet = 1, rd[0, 0]
                        for i in range(1, len(rd)):
                            ra, rdet = rdet, rd[i, i] * rdet - \
                                rd[i - 1, i] * rd[i, i - 1] * ra
                        assert rdet > 0
                        ld = np.zeros_like(rd)
                        for i in range(0, len(rd) // 2 * 2, 2):
                            xcos = (rd[i, i] + rd[i + 1, i + 1]) / 2
                            xsin = (rd[i, i + 1] - rd[i + 1, i]) / 2
                            theta = np.arctan2(xsin, xcos)
                            ld[i, i + 1] = theta
                            ld[i + 1, i] = -theta
                        return rv @ ld @ rv.T

                    import scipy.linalg
                    # kappa = scipy.linalg.logm(rot)
                    kappa = np.zeros_like(rot)
                    for isym in set(orb_sym):
                        mask = np.array(orb_sym) == isym
                        mrot = rot[mask, :][:, mask]
                        mkappa = my_logm(mrot)
                        # mkappa = scipy.linalg.logm(mrot)
                        # assert mkappa.dtype == float
                        gkappa = np.zeros((kappa.shape[0], mkappa.shape[1]))
                        gkappa[mask, :] = mkappa
                        kappa[:, mask] = gkappa
                    assert np.linalg.norm(
                        scipy.linalg.expm(kappa) - rot) < 1E-10
                    assert np.linalg.norm(kappa + kappa.T) < 1E-10

                    # rot is (old, new) => kappa should be minus
                    np.save(scratch + "/nat_kappa.npy", kappa)

                    # integral rotation
                    nat_fname = dic["nat_orbs"].strip()
                    if len(nat_fname) > 0:
                        if fcidump is None:
                            raise ValueError(
                                "Need FCIDUMP construction (namely, not a pre run) for 'nat_orbs'!")
                        # the following code will not check values inside fcidump
                        # since all MPOs are already constructed
                        _print("rotating integrals to natural orbitals ...")
                        # (old, new)
                        fcidump.rotate(VectorFP(rot.ravel()))
                        _print("finished.")
                        rot_sym_error = fcidump.symmetrize(orb_sym)
                        _print("rotated integral sym error = %12.4g" %
                               rot_sym_error)
                        if rot_sym_error > symmetrize_ints_tol:
                            raise RuntimeError(("Integral symmetrization error larger than %10.5g, "
                                                + "please check point group symmetry and FCIDUMP or set"
                                                + " a higher tolerance for the keyword '%s'") % (
                                symmetrize_ints_tol, "symmetrize_ints"))
                        _print("writing natural orbital integrals ...")
                        fcidump.write(nat_fname)
                        _print("finished.")

        else:
            for iroot in range(nroots):
                _print('----- root = %3d / %3d -----' % (iroot, nroots))
                if len(mps_tags) > 1:
                    smps, smps_info, forward = get_mps_from_tags(iroot)
                elif "statespecific" in dic:
                    smps, smps_info, forward = get_state_specific_mps(
                        iroot, mps_info)
                else:
                    smps, smps_info, forward = split_mps(iroot, mps, mps_info)
                # same fallback rule as the single-root branch: do_npdm
                # cannot run without the general Hamiltonian `ghamil`
                if conv_npdm or ghamil is None:
                    dm = do_onepdm(smps, smps)
                else:
                    dm = do_npdm(smps, smps, pdm_type=1)
                if MPI is None or MPI.rank == 0:
                    dmocc = np.diag(np.sum(dm, axis=0))
                    if dmocc.dtype == np.complex128:
                        _print("DMRG OCC = ", "".join(["(%6.3f,%6.3f)" % (np.real(x), np.imag(x)) for x in dmocc]))
                    else:
                        _print("DMRG OCC = ", "".join(["%6.3f" % x for x in dmocc]))
                    np.save(scratch + "/1pdm-%d-%d.npy" % (iroot, iroot), dm)
                    smps_info.save_data(scratch + '/mps_info-%d.bin' % iroot)

    # Transition ONEPDM
    # note that there can be a undetermined +1/-1 factor due to the relative phase in two MPSs
    # Computes <iroot| ... |jroot> for bra roots in "tran_bra_range" and ket
    # roots in "tran_ket_range" (default: all roots); saved on the root rank
    # as 1pdm-<i>-<j>.npy.
    if "restart_tran_onepdm" in dic or "tran_onepdm" in dic:

        assert nroots != 1
        brar, ketr = range(nroots), range(nroots)
        # optional "start stop [step]" sub-ranges of bra / ket roots
        if "tran_bra_range" in dic:
            tbr = [int(x) for x in dic["tran_bra_range"].split()]
            brar = range(*tbr)
        if "tran_ket_range" in dic:
            tkr = [int(x) for x in dic["tran_ket_range"].split()]
            ketr = range(*tkr)
        for iroot in brar:
            for jroot in ketr:
                _print('----- root = %3d -> %3d / %3d -----' %
                       (jroot, iroot, nroots))
                if "tran_triangular" in dic:
                    if iroot < jroot:
                        continue
                tx = time.perf_counter()
                if len(mps_tags) > 1:
                    simps, simps_info, _ = get_mps_from_tags(iroot)
                    sjmps, sjmps_info, _ = get_mps_from_tags(jroot)
                elif "statespecific" in dic:
                    simps, simps_info, _ = get_state_specific_mps(
                        iroot, mps_info)
                    sjmps, sjmps_info, _ = get_state_specific_mps(
                        jroot, mps_info)
                else:
                    simps, simps_info, _ = split_mps(iroot, mps, mps_info)
                    sjmps, sjmps_info, _ = split_mps(jroot, mps, mps_info)
                if jroot == iroot:
                    sjmps = simps
                # same fallback as the ONEPDM section: do_npdm needs `ghamil`
                if conv_npdm or ghamil is None:
                    dm = do_onepdm(simps, sjmps)
                else:
                    dm = do_npdm(simps, sjmps, pdm_type=1)
                # do_onepdm returns None on non-root MPI ranks; there is
                # nothing to rescale there (and None cannot be scaled)
                if soc and dm is not None:
                    if SX == SU2:
                        if hasattr(simps.info, "targets"):
                            qsbra = simps.info.targets[0].twos
                        else:
                            qsbra = simps.info.target.twos
                        # fix different Wigner–Eckart theorem convention
                        dm *= np.sqrt(qsbra + 1)
                    dm = dm / np.sqrt(2)
                if MPI is None or MPI.rank == 0:
                    np.save(scratch + "/1pdm-%d-%d.npy" % (iroot, jroot), dm)
                if (MPI is None or MPI.rank == 0) and iroot == jroot:
                    # occupations from the spin-summed diagonal: works for any
                    # number of spin components (only one with
                    # "use_general_spin") and for complex density matrices
                    dmocc = np.diag(np.sum(dm, axis=0))
                    if np.iscomplexobj(dmocc):
                        _print("DMRG OCC (state %4d) = " % iroot, "".join(
                            [" (%6.3f,%6.3f)" % (np.real(x), np.imag(x)) for x in dmocc]))
                    else:
                        _print("DMRG OCC (state %4d) = " % iroot, "".join(
                            ["%6.3f" % x for x in dmocc]))
                    simps_info.save_data(scratch + '/mps_info-%d.bin' % iroot)
                _print("tran 1pdm finished", time.perf_counter() - tx)

    # Particle Number Correlation
    # Sweep a "1NPC" expectation of `nmpo` over the single-root MPS. On the
    # root rank, the matrices from get_1npc_spatial(0) ("pure") and
    # get_1npc_spatial(1) ("mix") are stacked along a new leading axis,
    # reordered back to the original orbital order, and saved as 1npc.npy.
    if "restart_correlation" in dic or "correlation" in dic:
        assert nroots == 1
        me = MovingEnvironment(nmpo, mps, mps, "1NPC")
        # delayed contraction only for the Normal expectation algorithm
        if algo_type == ExpectationAlgorithmTypes.Normal:
            me.delayed_contraction = OpNamesSet.normal_ops()
        me.cached_contraction = cached_contraction
        me.save_partition_info = True
        me.init_environments(outputlevel >= 2)

        _print("env init finished", time.perf_counter() - tx)

        expect = XExpect(me, mps.info.bond_dim, mps.info.bond_dim)
        expect.algo_type = algo_type
        expect.iprint = max(min(outputlevel, 3), 0)
        expect.solve(True, mps.center == 0)
        _print("Final canonical form = ", mps.canonical_form, mps.center)

        if MPI is None or MPI.rank == 0:

            # copy each result into numpy, then release the block2 matrix
            dmr = expect.get_1npc_spatial(0, n_orbs)
            dm_pure = np.array(dmr).copy()
            dmr.deallocate()
            dmr = expect.get_1npc_spatial(1, n_orbs)
            dm_mix = np.array(dmr).copy()
            dmr.deallocate()
            dm = np.concatenate(
                [dm_pure[None, :, :], dm_mix[None, :, :]], axis=0)
            if orb_idx is not None:
                rev_idx = np.argsort(orb_idx)
                dm[:, :, :] = dm[:, rev_idx, :][:, :, rev_idx]

            np.save(scratch + "/1npc.npy", dm)
            mps_info.save_data(scratch + '/mps_info.bin')
            mps_info.save_data(scratch + '/%s-mps_info.bin' % mps_tags[0])

    # diag twopdm
    # Build e_pqqp / e_pqpq (root rank only) from 1npc.npy (written by the
    # "correlation" step) and the spin-summed 1pdm.npy (written by the
    # "onepdm" step) in the same scratch directory.
    if "restart_diag_twopdm" in dic or "diag_twopdm" in dic:
        if MPI is None or MPI.rank == 0:
            assert nroots == 1
            dm_npc = np.load(scratch + "/1npc.npy")
            dm_pdm = np.load(scratch + "/1pdm.npy").sum(axis=0)
            # np.diag(np.diag(x)) keeps only the diagonal of x as a matrix
            dm_e_pqqp = dm_npc[0] - np.diag(np.diag(dm_pdm))
            dm_e_pqpq = -dm_npc[1] + 2 * np.diag(np.diag(dm_pdm))
            np.save(scratch + "/e_pqqp.npy", dm_e_pqqp)
            np.save(scratch + "/e_pqpq.npy", dm_e_pqpq)

    def do_twopdm(bmps, kmps):
        """Conventional-MPO two-particle (transition) density matrix.

        Builds a "2PDM" moving environment of ``p2mpo`` between the bra
        ``bmps`` and the ket ``kmps`` and sweeps it with ``XExpect``. On the
        root MPI rank the spin-orbital 2PDM from ``get_2pdm`` is viewed as
        (n_orbs, 2) x 4 and returned as a stack of three
        (n_orbs, n_orbs, n_orbs, n_orbs) spin components, taken at spin
        indices (0, 0, 0, 0), (0, 1, 1, 0) and (1, 1, 1, 1). Orbitals are
        permuted back to the original order when ``orb_idx`` is set.
        Non-root MPI ranks get None.
        """
        env = MovingEnvironment(p2mpo, bmps, kmps, "2PDM")
        # delayed contraction only for the Normal expectation algorithm
        if algo_type == ExpectationAlgorithmTypes.Normal:
            env.delayed_contraction = OpNamesSet.normal_ops()
        env.cached_contraction = cached_contraction
        env.save_partition_info = True
        env.init_environments(outputlevel >= 2)

        _print("env init finished", time.perf_counter() - tx)

        solver = XExpect(env, bmps.info.bond_dim, kmps.info.bond_dim)
        solver.algo_type = algo_type
        solver.zero_dot_algo = "zerodot" in dic
        solver.iprint = max(min(outputlevel, 3), 0)
        solver.solve(True, kmps.center == 0)
        _print("Final canonical form = ", bmps.canonical_form, bmps.center)

        # only the root rank assembles and returns the tensor
        if MPI is not None and MPI.rank != 0:
            return None

        raw = solver.get_2pdm(n_orbs)
        pdm = np.array(raw, copy=True)
        # (p, sp, q, sq, r, sr, s, ss) -> orbital axes first, spin axes last
        spin_view = np.transpose(pdm.reshape((n_orbs, 2) * 4),
                                 (0, 2, 4, 6, 1, 3, 5, 7))
        pdm = np.stack([spin_view[..., 0, 0, 0, 0],
                        spin_view[..., 0, 1, 1, 0],
                        spin_view[..., 1, 1, 1, 1]])

        if orb_idx is not None:
            order = np.argsort(orb_idx)
            pdm[:, :, :, :, :] = pdm[:, order][:, :, order][:, :, :, order][:, :, :, :, order]
        return pdm

    def save_npdm_stackblock_format(dm, fn):
        """Write the density matrix ``dm`` as text to ``scratch + "/" + fn``
        in the StackBlock-compatible layout.

        The first line holds ``len(dm)``. Every following line holds the
        ``dm.ndim`` indices (``"%2d "`` each) and the value (``"%20.14f"``,
        or real and imaginary part for complex data), in row-major order.

        Args:
            dm: numpy array of any rank, real or complex.
            fn: file name relative to the scratch directory.
        """
        n_orbs, n_ops = len(dm), dm.ndim
        idx_fmt = "%2d " * n_ops
        # np.ndindex yields the index tuples lazily, in the same C order as
        # dm.ravel(), and follows dm's actual shape. The previous np.mgrid
        # table needed n_ops * n_orbs ** n_ops integers in memory (huge for
        # 3PDM/4PDM), and it assumed a hypercubic dm.
        # np.iscomplexobj also catches complex64, which "%f" would reject.
        with open(scratch + "/" + fn, "w") as f:
            f.write("%s\n" % n_orbs)
            if np.iscomplexobj(dm):
                f.writelines((idx_fmt + "%20.14f %20.14f\n") % (ix + (x.real, x.imag))
                             for ix, x in zip(np.ndindex(*dm.shape), dm.ravel()))
            else:
                f.writelines((idx_fmt + "%20.14f\n") % (ix + (x, ))
                             for ix, x in zip(np.ndindex(*dm.shape), dm.ravel()))

    # TWOPDM
    # Two-particle density matrix of the single root (nroots == 1) or of each
    # root. Saved on the root rank as 2pdm.npy / 2pdm-<i>-<i>.npy and, with
    # StackBlock/OpenMolcas compatibility, also as spatial_twopdm text files.
    if "restart_twopdm" in dic or "twopdm" in dic:

        if nroots == 1:
            if conv_npdm:
                dm = do_twopdm(mps, mps)
            else:
                dm = do_npdm(mps, mps, pdm_type=2)
            if MPI is None or MPI.rank == 0:
                np.save(scratch + "/2pdm.npy", dm)
                if stackblock_compat or openmolcas_compat:
                    # combine the three spin components into one spatial
                    # 2PDM; assumes dm has 3 components (not the case with
                    # "use_general_spin") -- NOTE(review): confirm
                    xdm = (dm[0] + dm[2] + 2 * dm[1]) / 2
                    save_npdm_stackblock_format(xdm, "spatial_twopdm.0.0.txt")
                mps_info.save_data(scratch + '/mps_info.bin')
                mps_info.save_data(scratch + '/%s-mps_info.bin' % mps_tags[0])
        else:
            for iroot in range(nroots):
                _print('----- root = %3d / %3d -----' % (iroot, nroots))
                if len(mps_tags) > 1:
                    smps, smps_info, _ = get_mps_from_tags(iroot)
                elif "statespecific" in dic:
                    smps, smps_info, forward = get_state_specific_mps(
                        iroot, mps_info)
                else:
                    smps, smps_info, forward = split_mps(iroot, mps, mps_info)
                if conv_npdm:
                    dm = do_twopdm(smps, smps)
                else:
                    dm = do_npdm(smps, smps, pdm_type=2)
                if MPI is None or MPI.rank == 0:
                    np.save(scratch + "/2pdm-%d-%d.npy" % (iroot, iroot), dm)
                    if stackblock_compat or openmolcas_compat:
                        # same spin combination as in the single-root branch
                        xdm = (dm[0] + dm[2] + 2 * dm[1]) / 2
                        save_npdm_stackblock_format(xdm, "spatial_twopdm.%d.%d.txt" % (iroot, iroot))
                    smps_info.save_data(scratch + '/mps_info-%d.bin' % iroot)

    # Transition TWOPDM
    # note that there can be a undetermined +1/-1 factor due to the relative phase in two MPSs
    # Computes <iroot| ... |jroot> for bra roots in "tran_bra_range" and ket
    # roots in "tran_ket_range" (default: all roots); saved on the root rank
    # as 2pdm-<i>-<j>.npy, plus mps_info-<i>.bin for every bra root used.
    if "restart_tran_twopdm" in dic or "tran_twopdm" in dic:

        assert nroots != 1
        brar, ketr = range(nroots), range(nroots)
        # optional "start stop [step]" sub-ranges of bra / ket roots
        if "tran_bra_range" in dic:
            tbr = [int(x) for x in dic["tran_bra_range"].split()]
            brar = range(*tbr)
        if "tran_ket_range" in dic:
            tkr = [int(x) for x in dic["tran_ket_range"].split()]
            ketr = range(*tkr)
        for iroot in brar:
            # mps info of bra root `iroot`; stays None when every ket root is
            # skipped (empty ket range or "tran_triangular"), so no stale info
            # from a previous bra root is saved under this root's file name
            bra_info = None
            for jroot in ketr:
                _print('----- root = %3d -> %3d / %3d -----' %
                       (jroot, iroot, nroots))
                if "tran_triangular" in dic:
                    if iroot < jroot:
                        continue
                tx = time.perf_counter()
                if len(mps_tags) > 1:
                    simps, simps_info, _ = get_mps_from_tags(iroot)
                    sjmps, sjmps_info, _ = get_mps_from_tags(jroot)
                elif "statespecific" in dic:
                    simps, simps_info, _ = get_state_specific_mps(
                        iroot, mps_info)
                    sjmps, sjmps_info, _ = get_state_specific_mps(
                        jroot, mps_info)
                else:
                    simps, simps_info, _ = split_mps(iroot, mps, mps_info)
                    sjmps, sjmps_info, _ = split_mps(jroot, mps, mps_info)
                bra_info = simps_info
                if jroot == iroot:
                    sjmps = simps
                if conv_npdm:
                    dm = do_twopdm(simps, sjmps)
                else:
                    dm = do_npdm(simps, sjmps, pdm_type=2)
                if MPI is None or MPI.rank == 0:
                    np.save(scratch + "/2pdm-%d-%d.npy" % (iroot, jroot), dm)
                _print("tran 2pdm finished", time.perf_counter() - tx)
            if bra_info is not None and (MPI is None or MPI.rank == 0):
                bra_info.save_data(scratch + '/mps_info-%d.bin' % iroot)

    # THREEPDM / FOURPDM / FOCK_FOURPDM
    # Compute the requested 3-/4-particle density matrices for each root and save them as
    # <n>pdm.npy (single root) or <n>pdm-<i>-<i>.npy (multiple roots) on the root MPI rank.
    # In stackblock/openmolcas compatibility mode, extra text/binary files are also written;
    # for "fock_fourpdm" the 4pdm is contracted with the Fock matrix read from "fock_matrix"
    # and the raw 4pdm is released (dm = None), so nothing is saved to npy for that entry.
    if "restart_threepdm" in dic or "threepdm" in dic or "restart_fourpdm" in dic or "fourpdm" in dic \
         or "restart_fock_fourpdm" in dic or "fock_fourpdm" in dic:

        pdm_types, pdm_names = [], []
        if "restart_fourpdm" in dic or "fourpdm" in dic:
            pdm_types.append(4)
            pdm_names.append("fourpdm")
        if "restart_fock_fourpdm" in dic or "fock_fourpdm" in dic:
            pdm_types.append(4)
            pdm_names.append("fock_fourpdm")
            assert "fock_matrix" in dic
            fock = read_fock_fcidump(dic["fock_matrix"])[1]
        if "restart_threepdm" in dic or "threepdm" in dic:
            pdm_types.append(3)
            pdm_names.append("threepdm")

        if nroots == 1:
            for pdm_type, pdm_name in zip(pdm_types, pdm_names):
                dm = do_npdm(mps, mps, pdm_type=pdm_type)
                if MPI is None or MPI.rank == 0:
                    if (stackblock_compat or openmolcas_compat) and pdm_type == 3:
                        # with open(scratch + "/spatial_%s.%d.%d.bin" % (pdm_name, 0, 0), "wb") as f:
                        #     np.save(f, dm)
                        save_npdm_stackblock_format(dm, "spatial_%s.%d.%d.txt" % (pdm_name, 0, 0))
                    elif (stackblock_compat or openmolcas_compat) and pdm_type == 4:
                        if pdm_name == "fourpdm":
                            # raw array written after a 109-byte gap at the start of the file
                            with open(scratch + "/spatial_%s.%d.%d.bin" % (pdm_name, 0, 0), 'wb') as f:
                                f.seek(109)
                                dm.tofile(f)
                        elif pdm_name == "fock_fourpdm":
                            fdm = np.einsum('ijklmnop,lm', dm, fock, optimize=True)
                            dm = None  # release the 8-index array before writing text output
                            save_npdm_stackblock_format(fdm, "%s.%d.%d.txt" % (pdm_name, 0, 0))
                    # "fourpdm" and "fock_fourpdm" share pdm_type 4 (same file name); never let the
                    # released fock_fourpdm placeholder (None) overwrite a real 4pdm.npy
                    if dm is not None:
                        np.save(scratch + "/%dpdm.npy" % pdm_type, dm)
            if MPI is None or MPI.rank == 0:
                mps_info.save_data(scratch + '/mps_info.bin')
                mps_info.save_data(scratch + '/%s-mps_info.bin' % mps_tags[0])
        else:
            for iroot in range(nroots):
                _print('----- root = %3d / %3d -----' % (iroot, nroots))
                if len(mps_tags) > 1:
                    smps, smps_info, _ = get_mps_from_tags(iroot)
                elif "statespecific" in dic:
                    smps, smps_info, forward = get_state_specific_mps(
                        iroot, mps_info)
                else:
                    smps, smps_info, forward = split_mps(iroot, mps, mps_info)
                for pdm_type, pdm_name in zip(pdm_types, pdm_names):
                    dm = do_npdm(smps, smps, pdm_type=pdm_type)
                    if MPI is None or MPI.rank == 0:
                        if (stackblock_compat or openmolcas_compat) and pdm_type == 3:
                            # with open(scratch + "/spatial_%s.%d.%d.bin" % (pdm_name, iroot, iroot), "wb") as f:
                            #     np.save(f, dm)
                            save_npdm_stackblock_format(dm, "spatial_%s.%d.%d.txt" % (pdm_name, iroot, iroot))
                        elif (stackblock_compat or openmolcas_compat) and pdm_type == 4:
                            if pdm_name == "fourpdm":
                                # raw array written after a 109-byte gap at the start of the file
                                with open(scratch + "/spatial_%s.%d.%d.bin" % (pdm_name, iroot, iroot), 'wb') as f:
                                    f.seek(109)
                                    dm.tofile(f)
                            elif pdm_name == "fock_fourpdm":
                                fdm = np.einsum('ijklmnop,lm', dm, fock, optimize=True)
                                dm = None  # release the 8-index array before writing text output
                                save_npdm_stackblock_format(fdm, "%s.%d.%d.txt" % (pdm_name, iroot, iroot))
                        # save every requested PDM (previously only the last one of the loop was
                        # saved because this call sat outside the pdm loop)
                        if dm is not None:
                            np.save(scratch + "/%dpdm-%d-%d.npy" % (pdm_type, iroot, iroot), dm)
                if MPI is None or MPI.rank == 0:
                    smps_info.save_data(scratch + '/mps_info-%d.bin' % iroot)

    # Transition THREEPDM / FOURPDM
    # note that there can be an undetermined +1/-1 factor due to the relative phase in two MPSs
    # For every (bra root i, ket root j) pair in the requested ranges, compute the requested
    # 3-/4-particle (transition) density matrices and save them as <n>pdm-<i>-<j>.npy
    # (root MPI rank only). The bra MPS info is saved once per bra root as mps_info-<i>.bin.
    if "restart_tran_threepdm" in dic or "tran_threepdm" in dic or "restart_tran_fourpdm" in dic or "tran_fourpdm" in dic:

        pdm_types = []
        if "restart_tran_fourpdm" in dic or "tran_fourpdm" in dic:
            pdm_types.append(4)
        if "restart_tran_threepdm" in dic or "tran_threepdm" in dic:
            pdm_types.append(3)

        assert nroots != 1
        # optional ranges: whitespace-separated integers forwarded to range()
        brar, ketr = range(nroots), range(nroots)
        if "tran_bra_range" in dic:
            tbr = [int(x) for x in dic["tran_bra_range"].split()]
            brar = range(*tbr)
        if "tran_ket_range" in dic:
            tkr = [int(x) for x in dic["tran_ket_range"].split()]
            ketr = range(*tkr)
        for iroot in brar:
            # reset per bra root: if every ket of this bra is skipped below, we must neither
            # save the previous bra's info under this root's file name nor hit an unbound name
            simps_info = None
            for jroot in ketr:
                _print('----- root = %3d -> %3d / %3d -----' %
                       (jroot, iroot, nroots))
                # only compute the lower triangle (iroot >= jroot) when requested
                if "tran_triangular" in dic:
                    if iroot < jroot:
                        continue
                if len(mps_tags) > 1:
                    simps, simps_info, _ = get_mps_from_tags(iroot)
                    sjmps, sjmps_info, _ = get_mps_from_tags(jroot)
                elif "statespecific" in dic:
                    simps, simps_info, _ = get_state_specific_mps(
                        iroot, mps_info)
                    sjmps, sjmps_info, _ = get_state_specific_mps(
                        jroot, mps_info)
                else:
                    simps, simps_info, _ = split_mps(iroot, mps, mps_info)
                    sjmps, sjmps_info, _ = split_mps(jroot, mps, mps_info)
                # diagonal element: use the very same MPS object for bra and ket
                if jroot == iroot:
                    sjmps = simps
                for pdm_type in pdm_types:
                    tx = time.perf_counter()
                    dm = do_npdm(simps, sjmps, pdm_type=pdm_type)
                    if MPI is None or MPI.rank == 0:
                        np.save(scratch + "/%dpdm-%d-%d.npy" % (pdm_type, iroot, jroot), dm)
                    _print("tran %dpdm finished" % pdm_type, time.perf_counter() - tx)
            if simps_info is not None and (MPI is None or MPI.rank == 0):
                simps_info.save_data(scratch + '/mps_info-%d.bin' % iroot)

    # SC-NEVPT2 intermediates (A16 and A22)
    # adapted from pyscf/mrpt/nevpt2.py
    # Computes the 2-, 3- and 4-PDMs of each root (do_npdm with pdm_type='nevpt'), contracts
    # them with the one- and two-electron integrals into the A16 and A22 intermediates, and
    # writes them as text files "A16_matrix.<i>.<i>.txt" / "A22_matrix.<i>.<i>.txt"
    # via save_npdm_stackblock_format (root MPI rank only). Requires stackblock_compat.
    if "restart_nevpt2_npdm" in dic:

        assert stackblock_compat

        # integrals are only needed where the files are written (root rank)
        if MPI is None or MPI.rank == 0:
            h1e = np.asarray(fcidump.h1e_matrix()).reshape((n_orbs, n_orbs))
            h2e = np.asarray(fcidump.g2e_1fold())
            h2e = h2e.reshape(n_orbs, n_orbs, n_orbs, n_orbs)
            # swap axes 1 and 2: h2e[i, k, j, l] = g2e[i, j, k, l]
            # (presumably chemists' -> physicists' index order, as in pyscf; confirm)
            h2e = h2e.transpose(0, 2, 1, 3)
            if orb_idx is not None:
                # argsort of the orbital permutation is its inverse: undo the reordering
                # along every axis so integrals match the PDM orbital order
                rev_idx = np.argsort(orb_idx)
                if "trans_integral_to_spin_orbital" in dic:
                    # spatial index i -> spin-orbital indices 2i, 2i+1 (interleaved)
                    rev_idx = np.array(list(zip(rev_idx*2, rev_idx*2+1))).ravel()
                for i in range(h1e.ndim):
                    h1e = h1e[(slice(None),) * i + (rev_idx,)]
                for i in range(h2e.ndim):
                    h2e = h2e[(slice(None),) * i + (rev_idx,)]
        else:
            h1e, h2e = None, None

        def save_nevpt2_npdm_stackblock_format(h1e, h2e, dms, iroot):
            """Build the A16 and A22 tensors for root `iroot` and write them to text files.

            Args:
                h1e: one-electron integrals, shape (n_orbs, n_orbs) (reordered above).
                h2e: two-electron integrals, 4-index, transposed/reordered above.
                dms: sequence whose entries 1, 2, 3 are unpacked as dm2, dm3, dm4
                    (entry 0 is not used here).
                iroot: root index used in the output file names.
            """

            dm2, dm3, dm4 = dms[1:]

            # the only two contractions that need the 4-PDM; drop dm4 right after to free memory
            f3ac = np.einsum('ijka,rpqbjcik->rpqbac', h2e, dm4, optimize=True)
            f3ca = np.einsum('kcij,rpqbajki->rpqbac', h2e, dm4, optimize=True)
            dm4 = None

            a16 = -np.einsum('ib,rpqiac->pqrabc', h1e, dm3, optimize=True)
            a16 += np.einsum('ia,rpqbic->pqrabc', h1e, dm3, optimize=True)
            a16 -= np.einsum('ci,rpqbai->pqrabc', h1e, dm3, optimize=True)

            a16 -= f3ca.transpose(1, 4, 0, 2, 5, 3)
            a16 -= np.einsum('kbia,rpqcki->pqrabc', h2e, dm3, optimize=True)
            a16 -= np.einsum('kbaj,rpqjkc->pqrabc', h2e, dm3, optimize=True)
            a16 += np.einsum('cbij,rpqjai->pqrabc', h2e, dm3, optimize=True)
            fdm2 = np.einsum('kbij,rpajki->prab', h2e, dm3, optimize=True)
            # Kronecker-delta term: the two separated index arrays (axes q and c) broadcast to a
            # leading axis, so this adds fdm2[p, r, a, b] to a16[p, i, r, a, b, i] for every i
            a16[:, np.mgrid[:n_orbs], :, :, :, np.mgrid[:n_orbs]] += fdm2[None]

            a16 += f3ac.transpose(1, 2, 0, 4, 3, 5)
            a16 -= f3ca.transpose(1, 2, 0, 4, 3, 5)
            a16 += np.einsum('jbij,rpqiac->pqrabc', h2e, dm3, optimize=True)
            a16 -= np.einsum('cjka,rpqbjk->pqrabc', h2e, dm3, optimize=True)
            a16 += np.einsum('jcij,rpqbai->pqrabc', h2e, dm3, optimize=True)

            save_npdm_stackblock_format(a16, "%s_matrix.%d.%d.txt" % ('A16', iroot, iroot))
            a16 = None

            a22 = -np.einsum('pb,kipjac->ijkabc', h1e, dm3, optimize=True)
            a22 -= np.einsum('pa,kibjpc->ijkabc', h1e, dm3, optimize=True)
            a22 += np.einsum('cp,kibjap->ijkabc', h1e, dm3, optimize=True)
            a22 += np.einsum('cqra,kibjqr->ijkabc', h2e, dm3, optimize=True)
            a22 -= np.einsum('qcpq,kibjap->ijkabc', h2e, dm3, optimize=True)

            a22 -= f3ac.transpose(1, 5, 0, 2, 4, 3)
            fdm2 = np.einsum('pqrb,kiqcpr->ikbc', h2e, dm3, optimize=True)
            # delta term over axes j and a: a22[i, m, k, m, b, c] -= fdm2[i, k, b, c] for every m
            a22[:, np.mgrid[:n_orbs], :, np.mgrid[:n_orbs], :, :] -= fdm2[None]
            a22 -= np.einsum('pqab,kiqjpc->ijkabc', h2e, dm3, optimize=True)
            a22 += np.einsum('pcrb,kiajpr->ijkabc', h2e, dm3, optimize=True)
            a22 += np.einsum('cqrb,kiqjar->ijkabc', h2e, dm3, optimize=True)

            a22 -= f3ac.transpose(1, 3, 0, 4, 2, 5)
            a22 += f3ca.transpose(1, 3, 0, 4, 2, 5)
            a22 += 2.0 * np.einsum('jb,kiac->ijkabc', h1e, dm2, optimize=True)
            a22 += 2.0 * np.einsum('pjrb,kiprac->ijkabc', h2e, dm3, optimize=True)
            fdm2  = np.einsum('pa,kipc->ikac', h1e, dm2, optimize=True)
            fdm2 -= np.einsum('cp,kiap->ikac', h1e, dm2, optimize=True)
            fdm2 -= np.einsum('cqra,kiqr->ikac', h2e, dm2, optimize=True)
            fdm2 += np.einsum('qcpq,kiap->ikac', h2e, dm2, optimize=True)
            fdm2 += np.einsum('pqra,kiqcpr->ikac', h2e, dm3, optimize=True)
            fdm2 -= np.einsum('rcpq,kiaqrp->ikac', h2e, dm3, optimize=True)
            # delta term over axes j and b: a22[i, m, k, a, m, c] += 2 * fdm2[i, k, a, c] for every m
            a22[:, np.mgrid[:n_orbs], :, :, np.mgrid[:n_orbs], :] += fdm2[None] * 2

            save_npdm_stackblock_format(a22, "%s_matrix.%d.%d.txt" % ('A22', iroot, iroot))
            a22 = None

        if nroots == 1:
            dms = do_npdm(mps, mps, pdm_type='nevpt')
            if MPI is None or MPI.rank == 0:
                # always true after the assert above (unless asserts are disabled with -O)
                if stackblock_compat:
                    save_nevpt2_npdm_stackblock_format(h1e, h2e, dms, 0)
                mps_info.save_data(scratch + '/mps_info.bin')
                mps_info.save_data(scratch + '/%s-mps_info.bin' % mps_tags[0])
            dms = None
        else:
            for iroot in range(nroots):
                _print('----- root = %3d / %3d -----' % (iroot, nroots))
                if len(mps_tags) > 1:
                    smps, smps_info, _ = get_mps_from_tags(iroot)
                elif "statespecific" in dic:
                    smps, smps_info, forward = get_state_specific_mps(
                        iroot, mps_info)
                else:
                    smps, smps_info, forward = split_mps(iroot, mps, mps_info)
                dms = do_npdm(smps, smps, pdm_type='nevpt')
                if MPI is None or MPI.rank == 0:
                    if stackblock_compat:
                        save_nevpt2_npdm_stackblock_format(h1e, h2e, dms, iroot)
                    smps_info.save_data(scratch + '/mps_info-%d.bin' % iroot)
                # release the PDMs before processing the next root
                dms = None

    # SC-NEVPT2 compress approximation
    # J. Chem. Phys. 146, 244102 (2017) Appendix B
    # For one reference root: transform it to singlet embedding, then for every core (Vi) and
    # virtual (Va) orbital compress the perturber |bra> = P |ref> into an MPS, evaluate its norm
    # and its energy with `mpo`, and accumulate the Vi / Va energy corrections and amplitudes.
    # Results are printed and written to scratch/Vi_<root> and scratch/Va_<root>.
    if "restart_mps_nevpt" in dic:

        assert "nonspinadapted" not in dic

        if nroots == 1:
            iroot = 0
            orig_mps = mps
        else:
            assert "nevpt_state_num" in dic
            iroot = int(dic["nevpt_state_num"])
            _print('----- root = %3d / %3d -----' % (iroot, nroots))
            if len(mps_tags) > 1:
                smps, smps_info, _ = get_mps_from_tags(iroot)
            elif "statespecific" in dic:
                smps, smps_info, forward = get_state_specific_mps(
                    iroot, mps_info)
            else:
                smps, smps_info, forward = split_mps(iroot, mps, mps_info)
            orig_mps = smps

        # dmrg.e holds nroots energies stored as raw native-order doubles
        with open(scratch + "/dmrg.e", "rb") as f:
            import struct
            dmrg_energies = struct.unpack('d' * nroots, f.read())

        e_casci = dmrg_energies[iroot]

        orig_tag = orig_mps.info.tag

        _print('transform ref state to singlet embedding ...')

        ref_mps = orig_mps.deep_copy(orig_tag + "@NEVPT-SE")

        if not singlet_embedding:
            # normalize the canonical form, then move the center to site 0 before converting
            if ref_mps.canonical_form[0] == 'C' and ref_mps.canonical_form[1] == 'R':
                ref_mps.canonical_form = 'K' + ref_mps.canonical_form[1:]
                ref_mps.center = 0
            elif ref_mps.canonical_form[-1] == 'C' and ref_mps.canonical_form[-2] == 'L':
                ref_mps.canonical_form = ref_mps.canonical_form[:-1] + 'S'
                ref_mps.center = ref_mps.n_sites - 1
            elif ref_mps.center == ref_mps.n_sites - 2 and ref_mps.canonical_form[-2] == 'L':
                ref_mps.center = ref_mps.n_sites - 1
            while ref_mps.center > 0:
                ref_mps.move_left(CG(), prule if MPI is not None else None)
            ref_mps.to_singlet_embedding_wfn(CG(), SX.invalid, prule if MPI is not None else None)
            if MPI is None or MPI.rank == 0:
                ref_mps.save_data()

        if MPI is not None:
            MPI.barrier()

        _print('transform ref state to singlet embedding finshed.')

        # per-orbital correlation energies / amplitudes for core (i) and virtual (a) subspaces
        vi_ener = np.zeros((nevpt_ncore, ))
        vi_norm = np.zeros((nevpt_ncore, ))
        va_ener = np.zeros((nevpt_nvirt, ))
        va_norm = np.zeros((nevpt_nvirt, ))

        # subtract the partial trace sum_b g2e_sr[m, b, b, n] from h1e_sr
        h1e_sr = h1e_sr - np.einsum('mbbn->mn', g2e_sr, optimize=True)
        mpo.const_e = 0.0

        # bring the active-space indices into the MPS orbital order
        if orb_idx is not None:
            h1e_sr = h1e_sr[:, orb_idx]
            h1e_si = h1e_si[orb_idx]
            g2e_sr = g2e_sr[:, orb_idx][:, :, orb_idx][:, :, :, orb_idx]
            g2e_si = g2e_si[orb_idx][:, :, orb_idx][:, :, :, orb_idx]
            nevpt_h1e = nevpt_h1e[orb_idx][:, orb_idx]

        # one-body MPO built from nevpt_h1e; only used for the perturbative-noise environment
        gfd = GeneralFCIDUMP()
        gfd.elem_type = ElemOpTypes.SU2
        gfd.exprs.append("(C+D)0")
        gfd.add_sum_term(np.ascontiguousarray(np.sqrt(2) * nevpt_h1e), cutoff=nevpt_symmetrize_ints)
        gfd = gfd.adjust_order(merge=True)

        xmpo = GeneralMPO(ghamil, gfd, MPOAlgorithmTypes.FastBipartite, 1E-12, -1, 0)
        xmpo.iprint = 2 if outputlevel >= 4 else min(outputlevel, 1)
        xmpo.build()
        xmpo = SimplifiedMPO(xmpo, Rule(), False, False)
        xmpo = IdentityAddedMPO(xmpo)

        for icv in range(nevpt_ncore + nevpt_nvirt):

            is_core = icv < nevpt_ncore
            ix = icv if is_core else icv - nevpt_ncore

            _print('=== nevpt compress %s subspace %d ===\n' % ("core" if is_core else "virtual", ix))

            # perturber operator for this core / virtual orbital
            gfd = GeneralFCIDUMP()
            gfd.elem_type = ElemOpTypes.SU2
            if is_core:
                gfd.exprs.append("(C+(C+D)0)1")
                gfd.exprs.append("C")
                gfd.add_sum_term(np.ascontiguousarray(2 * g2e_si[:, ix]), cutoff=nevpt_symmetrize_ints)
                gfd.add_sum_term(np.ascontiguousarray(np.sqrt(2) * h1e_si[:, ix]), cutoff=nevpt_symmetrize_ints)
            else:
                gfd.exprs.append("(D+(C+D)0)1")
                gfd.exprs.append("D")
                gfd.add_sum_term(np.ascontiguousarray(2 * g2e_sr[ix]), cutoff=nevpt_symmetrize_ints)
                gfd.add_sum_term(np.ascontiguousarray(np.sqrt(2) * h1e_sr[ix]), cutoff=nevpt_symmetrize_ints)
            gfd = gfd.adjust_order(merge=True)
            # all terms below the cutoff: nothing to compress for this orbital
            if len(gfd.exprs) == 0:
                continue

            pmpo = GeneralMPO(ghamil, gfd, MPOAlgorithmTypes.FastBipartite, 1E-12, -1, 0)
            pmpo.iprint = 2 if outputlevel >= 4 else min(outputlevel, 1)
            pmpo.build()
            pmpo = SimplifiedMPO(pmpo, Rule(), False, False)
            pmpo = IdentityAddedMPO(pmpo)

            bra_q = pmpo.op.q_label + ref_mps.info.target
            ref_left_vacuum = ref_mps.info.left_dims_fci[0].quanta[0]
            bra_left_vacuum = ref_left_vacuum + pmpo.left_vacuum

            normsq = 0
            ener = 0

            # one compressed bra per possible left-vacuum quantum number
            for j in range(bra_left_vacuum.count):

                bra_info = MPSInfo(nevpt_ncas, vacuum, bra_q, ghamil.basis)
                bra_info.tag = orig_tag + "@BRA"
                if "full_fci_space" in dic:
                    bra_info.set_bond_dimension_full_fci(bra_left_vacuum[j], vacuum)
                else:
                    bra_info.set_bond_dimension_fci(bra_left_vacuum[j], vacuum)
                bra_info.set_bond_dimension(bond_dims[0])
                bra_info.bond_dim = bond_dims[0]

                if bra_info.get_max_bond_dimension() == 0:
                    continue

                bra = MPS(nevpt_ncas, ref_mps.center, ref_mps.dot)
                bra.initialize(bra_info)
                bra.random_canonicalize()
                bra.tensors[bra.center].normalize()
                bra.save_mutable()
                bra_info.save_mutable()
                bra.save_data()

                ref_tmp = ref_mps.deep_copy("NEVPT@TMP")
                if len(noises) > 0 and noises[0] != 0:
                    pme = MovingEnvironment(xmpo, bra, bra, "NEVPT-PERT")
                    pme.init_environments(outputlevel >= 2)
                else:
                    pme = None
                me = MovingEnvironment(pmpo, bra, ref_tmp, "NEVPT-CPS")
                me.delayed_contraction = OpNamesSet.normal_ops()
                me.cached_contraction = False # not allowed by perturbative noise
                me.save_partition_info = True
                me.init_environments(outputlevel >= 2)
                cps = Linear(pme, me, VectorUBond(bond_dims),
                    VectorUBond([ref_tmp.info.bond_dim]), VectorFP(noises))
                if pme is not None:
                    cps.noise_type = NoiseTypes.ReducedPerturbative
                    cps.eq_type = EquationTypes.PerturbativeCompression
                cps.iprint = max(min(outputlevel, 3), 0)
                cps.cutoff = float(dic.get("cutoff", 1E-14))
                cps.decomp_type = decomp_type
                cps.trunc_type = trunc_type
                if "twodot_to_onedot" not in dic:
                    norm = cps.solve(len(bond_dims), bra.center == 0, sweep_tol)
                else:
                    tto = int(dic["twodot_to_onedot"])
                    assert len(bond_dims) > tto
                    # sweep direction follows the bra being compressed (was the unrelated
                    # outer `mps`, inconsistent with the other solve calls)
                    cps.solve(tto, bra.center == 0, 0)
                    cps.bra_bond_dims = VectorUBond(bond_dims[tto:])
                    me.dot = 1
                    if pme is not None:
                        pme.dot = 1
                    norm = cps.solve(len(bond_dims) - tto, bra.center == 0, sweep_tol)
                    bra.dot = 1
                    ref_tmp.dot = 1
                    if MPI is None or MPI.rank == 0:
                        bra.save_data()
                if cps.lme is not None:
                    cps.lme.remove_partition_files()
                cps.rme.remove_partition_files()

                # no need to add any cg factor since bra and ket factors cancelled
                normsq += norm ** 2

                # unnormalized <bra| H |bra> with the main Hamiltonian MPO
                me = MovingEnvironment(mpo, bra, bra, "NEVPT-EX")
                me.delayed_contraction = OpNamesSet.normal_ops()
                me.cached_contraction = cached_contraction
                me.save_partition_info = True
                me.init_environments(outputlevel >= 2)

                expect = XExpect(me, bra.info.bond_dim, bra.info.bond_dim)
                expect.iprint = max(min(outputlevel, 3), 0)
                ener += expect.solve(False, bra.center == 0)
                expect.me.remove_partition_files()

            if abs(normsq) > 1E-12:
                ener_cas = ener / normsq + nevpt_const_e
                ener = orbe[icv] * (-1 if is_core else 1) + ener_cas
                ecorr = normsq / (e_casci - ener)
                eamp = normsq / (e_casci - ener) ** 2
            else:
                ener_cas = ener = ecorr = eamp = 0

            _print("Norm^2 = %20.15f" % normsq)
            _print("-- Energy (cas)  = %20.15f" % ener_cas)
            _print("-- Energy        = %20.15f" % ener)
            _print("-- Energy (ref)  = %20.15f" % e_casci)
            _print("-- Amplitude     = %20.15f" % eamp)
            _print("-- Energy (corr) = %20.15f" % ecorr)
            _print()

            if is_core:
                vi_ener[ix] += ecorr
                vi_norm[ix] += eamp
            else:
                va_ener[ix] += ecorr
                va_norm[ix] += eamp

        _print("Core subspace Energy (corr) = %20.15f" % sum(vi_ener))
        _print("Core subspace Amplitude     = %20.15f" % sum(vi_norm))
        _print("Virt subspace Energy (corr) = %20.15f" % sum(va_ener))
        _print("Virt subspace Amplitude     = %20.15f" % sum(va_norm))

        # write from a single rank only, like every other output file in this script
        if MPI is None or MPI.rank == 0:
            with open(scratch + "/Vi_%d" % iroot, "w") as f:
                f.write("%20.15f\n%20.15f" % (sum(vi_ener), sum(vi_norm)))
            with open(scratch + "/Va_%d" % iroot, "w") as f:
                f.write("%20.15f\n%20.15f" % (sum(va_ener), sum(va_norm)))

    def do_oh(bmps, kmps):
        """Evaluate the operator expectation value <bmps| O |kmps>.

        O is ``impo`` when ``overlap`` is set (presumably an identity MPO, giving the
        overlap — confirm) and the Hamiltonian ``mpo`` otherwise.

        Args:
            bmps: bra MPS.
            kmps: ket MPS; its center decides the sweep direction.

        Returns:
            The expectation value on the root MPI rank (or when MPI is not used),
            ``None`` on all other ranks.
        """
        # local start time: previously this read a stale/unrelated `tx` from the
        # enclosing scope, so the reported environment-init time was meaningless
        tx = time.perf_counter()
        me = MovingEnvironment(impo if overlap else mpo, bmps, kmps, "OH")
        me.delayed_contraction = OpNamesSet.normal_ops()
        me.cached_contraction = cached_contraction
        me.save_partition_info = True
        me.init_environments(outputlevel >= 2)

        _print("env init finished", time.perf_counter() - tx)

        expect = XExpect(me, bmps.info.bond_dim, kmps.info.bond_dim)
        expect.iprint = max(min(outputlevel, 3), 0)
        E_oh = expect.solve(False, kmps.center == 0)

        if MPI is None or MPI.rank == 0:
            return E_oh
        else:
            return None

    # OH (Hamiltonian expectation on MPS)
    # Evaluates do_oh(<root>, <root>) for every root, prints the value(s) and saves them to
    # E_oh.npy on the root MPI rank (a scalar for one root, a vector for several roots).
    if "restart_oh" in dic or "oh" in dic:

        if nroots == 1:
            E_oh = do_oh(mps, mps)
            if MPI is None or MPI.rank == 0:
                np.save(scratch + "/E_oh.npy", E_oh)
                # "%f" cannot format a complex value; print both parts, as done for tran_oh
                if complex_mps:
                    _print("OH Energy = RE %20.15f + IM %20.15f" %
                           (np.real(E_oh), np.imag(E_oh)))
                else:
                    _print("OH Energy = %20.15f" % E_oh)
                mps_info.save_data(scratch + '/mps_info.bin')
                mps_info.save_data(scratch + '/%s-mps_info.bin' % mps_tags[0])
        else:
            # complex storage for complex MPS, consistent with the tran_oh matrix; a float
            # array would reject (or silently truncate) complex expectation values
            mat_oh = np.zeros((nroots, ),
                              dtype=np.complex128 if complex_mps else float)
            for iroot in range(nroots):
                _print('----- root = %3d / %3d -----' % (iroot, nroots))
                if len(mps_tags) > 1:
                    smps, smps_info, _ = get_mps_from_tags(iroot)
                elif "statespecific" in dic:
                    smps, smps_info, forward = get_state_specific_mps(
                        iroot, mps_info)
                else:
                    smps, smps_info, forward = split_mps(iroot, mps, mps_info)
                E_oh = do_oh(smps, smps)
                if MPI is None or MPI.rank == 0:
                    mat_oh[iroot] = E_oh
                    if complex_mps:
                        print("OH Energy %4d - %4d = RE %20.15f + IM %20.15f" %
                              (iroot, iroot, np.real(E_oh), np.imag(E_oh)))
                    else:
                        print("OH Energy %4d - %4d = %20.15f" %
                              (iroot, iroot, E_oh))
                    smps_info.save_data(scratch + '/mps_info-%d.bin' % iroot)
            if MPI is None or MPI.rank == 0:
                np.save(scratch + "/E_oh.npy", mat_oh)

    # Transition OH (OH between different MPS roots)
    # note that there can be an undetermined +1/-1 factor due to the relative phase in two MPSs
    # only mat_oh[i, j] with i >= j are filled
    # Builds the (nroots, nroots) matrix of do_oh(<root i>, <root j>) values (complex when
    # complex_mps), prints each element and saves the matrix to E_oh.npy on the root MPI rank.
    if "restart_tran_oh" in dic or "tran_oh" in dic:

        assert nroots != 1
        mat_oh = np.zeros((nroots, nroots),
                          dtype=np.complex128 if complex_mps else float)
        for iroot in range(nroots):
            # lower triangle including the diagonal (jroot <= iroot)
            for jroot in range(iroot + 1):
                _print('----- root = %3d -> %3d / %3d -----' %
                       (jroot, iroot, nroots))
                if len(mps_tags) > 1:
                    simps, simps_info, _ = get_mps_from_tags(iroot)
                    sjmps, sjmps_info, _ = get_mps_from_tags(jroot)
                elif "statespecific" in dic:
                    simps, simps_info, _ = get_state_specific_mps(
                        iroot, mps_info)
                    sjmps, sjmps_info, _ = get_state_specific_mps(
                        jroot, mps_info)
                else:
                    simps, simps_info, _ = split_mps(iroot, mps, mps_info)
                    sjmps, sjmps_info, _ = split_mps(jroot, mps, mps_info)
                E_oh = do_oh(simps, sjmps)
                # do_oh returns None on non-root ranks, so only the root rank stores/prints
                if MPI is None or MPI.rank == 0:
                    mat_oh[iroot, jroot] = E_oh
                    if complex_mps:
                        print("OH Energy %4d - %4d = RE %20.15f + IM %20.15f" %
                              (iroot, jroot, np.real(E_oh), np.imag(E_oh)))
                    else:
                        print("OH Energy %4d - %4d = %20.15f" %
                              (iroot, jroot, E_oh))
            # the jroot loop always runs at least once (jroot = 0), so simps_info is bound here
            if MPI is None or MPI.rank == 0:
                simps_info.save_data(scratch + '/mps_info-%d.bin' % iroot)
        if MPI is None or MPI.rank == 0:
            np.save(scratch + "/E_oh.npy", mat_oh)

    # Uncontracted NEVPT2 / MRREPT2
    if dynamic_corr_method is not None and dynamic_corr_method[0] in ["nevpt2s", "nevpt2sd",
            "nevpt2-ijrs", "nevpt2-ij", "nevpt2-rs", "nevpt2-ijr", "nevpt2-rsi", "nevpt2-ir",
            "nevpt2-i", "nevpt2-r", "mrrept2s", "mrrept2sd", "mrrept2-ijrs", "mrrept2-ij",
            "mrrept2-rs", "mrrept2-ijr", "mrrept2-rsi", "mrrept2-ir", "mrrept2-i", "mrrept2-r"]:

        if fcidump is None:
            raise ValueError(
                "Need FCIDUMP construction (namely, not a pre run) for 'nevpt2/mrrept2'!")
        if MPI is not None:
            MPI.barrier()
        if dynamic_corr_method[0].startswith("nevpt2"):
            if big_site_method == "bigdrt":
                fd_dyall = DyallFCIDUMP(fcidump, n_inactive, n_external, True)
            else:
                fd_dyall = DyallFCIDUMP(fcidump, n_inactive, n_external)
            dm = np.load(scratch + "/1pdm.npy")
            if orb_idx is not None:
                dm = dm[:, orb_idx][:, :, orb_idx]
                dm = np.ascontiguousarray(dm)
            if fcidump.uhf:
                dmx = np.zeros((dm.shape[0] * 2, dm.shape[1] * 2))
                dmx[0::2, 0::2] = dm[0]
                dmx[1::2, 1::2] = dm[1]
                dmx = np.ascontiguousarray(dmx)
                fd_dyall.initialize_from_1pdm_sz(dmx)
            else:
                fd_dyall.initialize_from_1pdm_su2(dm[0] + dm[1])
            fd_zero = fd_dyall
        else:
            if big_site_method == "bigdrt":
                fd_fink = FinkFCIDUMP(fcidump, n_inactive, n_external, True)
            else:
                fd_fink = FinkFCIDUMP(fcidump, n_inactive, n_external)
            fd_zero = fd_fink
        e_casci = float(np.load(scratch + "/E_dmrg.npy"))

        sym_error = fd_zero.symmetrize(orb_sym)
        _print("integral sym error = %12.4g" % sym_error)
        if sym_error > symmetrize_ints_tol:
            raise RuntimeError(("Integral symmetrization error larger than %10.5g, "
                                + "please check point group symmetry and FCIDUMP or set"
                                + " a higher tolerance for the keyword '%s'") % (
                symmetrize_ints_tol, "symmetrize_ints"))

        if big_site_method == "folding":
            pass
        elif big_site_method is not None:
            big_left = SimplifiedBigSite(
                big_left_orig, NoTransposeRule(simpl_rule))
            big_right = SimplifiedBigSite(
                big_right_orig, NoTransposeRule(simpl_rule))
            if MPI is not None:
                if one_body_only:
                    big_left = ParallelBigSite(big_left, prule_one_body)
                    big_right = ParallelBigSite(big_right, prule_one_body)
                else:
                    big_left = ParallelBigSite(big_left, prule)
                    big_right = ParallelBigSite(big_right, prule)
            hamil = HamiltonianQCBigSite(vacuum, n_orbs, orb_sym, fcidump,
                                         None if n_inactive == 0 or big_site_method == "bigdrt" else big_left,
                                         None if (n_external == 0 and big_site_method != "bigdrt")
                                         or n_inactive + n_external == 0 else big_right)
            rhamil = hamil
        if big_site_method is None:
            hm_zero = HamiltonianQC(vacuum, n_sites, orb_sym, fd_zero)
        elif big_site_method == "folding":
            assert dynamic_corr_method[0] in ['nevpt2sd', 'nevpt2s', "nevpt2-ijrs", "nevpt2-ij",
                "nevpt2-rs", "nevpt2-ijr", "nevpt2-rsi", "nevpt2-ir", "nevpt2-i", "nevpt2-r",
                "mrrept2s", "mrrept2sd", "mrrept2-ijrs", "mrrept2-ij",
                "mrrept2-rs", "mrrept2-ijr", "mrrept2-rsi", "mrrept2-ir", "mrrept2-i", "mrrept2-r"]
            hm_zero = HamiltonianQC(vacuum, n_orbs, orb_sym, fd_zero)
            lmpo_fold = MPOQC(hm_zero, qc_type)
            if dynamic_corr_method[0] in ['nevpt2sd', 'nevpt2s']:
                ci_order = len(dynamic_corr_method[0]) - 6
                lmps_info_fold = MRCIMPSInfo(
                    n_orbs, n_inactive, n_external, ci_order, vacuum, target, hm_zero.basis)
            elif dynamic_corr_method[0] in ['mrrept2sd', 'mrrept2s']:
                ci_order = len(dynamic_corr_method[0]) - 7
                lmps_info_fold = MRCIMPSInfo(
                    n_orbs, n_inactive, n_external, ci_order, vacuum, target, hm_zero.basis)
            else:
                if dynamic_corr_method[0].startswith('nevpt2-'):
                    sub_space = dynamic_corr_method[0][7:]
                else:
                    sub_space = dynamic_corr_method[0][8:]
                n_ex_inactive = sub_space.count('i') + sub_space.count('j')
                n_ex_external = sub_space.count('r') + sub_space.count('s')
                lmps_info_fold = NEVPTMPSInfo(n_orbs, n_inactive, n_external, n_ex_inactive, n_ex_external, vacuum,
                                              target, hm_zero.basis)
            # Fold the external sites: fuse the last two sites of lmpo_fold
            # (n_external - 1 times), keeping hm_zero's basis and site count
            # in sync with the fused MPO after every step.
            for fold_step in range(n_external - 1):
                _print("fold right %d / %d" % (fold_step, n_external))
                last_site = lmpo_fold.n_sites - 1
                fused_dims = lmps_info_fold.right_dims_fci[last_site - 1]
                lmpo_fold = FusedMPO(lmpo_fold, hm_zero.basis, last_site - 1, last_site, fused_dims)
                hm_zero.basis = lmpo_fold.basis
                hm_zero.n_sites = lmpo_fold.n_sites
            # Fold the inactive sites the same way from the left end, fusing
            # sites 0 and 1 (n_inactive - 1 times).
            for fold_step in range(n_inactive - 1):
                _print("fold left %d / %d" % (fold_step, n_inactive))
                fused_dims = lmps_info_fold.left_dims_fci[fold_step + 2]
                lmpo_fold = FusedMPO(lmpo_fold, hm_zero.basis, 0, 1, fused_dims)
                hm_zero.basis = lmpo_fold.basis
                hm_zero.n_sites = lmpo_fold.n_sites
            # Store the operators of the first and the last MPO tensor as CSR
            # matrices: operators with sparsity() > 0.75 are converted (and the
            # dense storage released); all others are only wrapped.
            for edge_site in (0, -1):
                for op_key, dense_op in lmpo_fold.tensors[edge_site].ops.items():
                    csr_op = CSRSparseMatrix()
                    if dense_op.sparsity() > 0.75:
                        csr_op.from_dense(dense_op)
                        dense_op.deallocate()
                    else:
                        csr_op.wrap_dense(dense_op)
                    lmpo_fold.tensors[edge_site].ops[op_key] = csr_op
            # Mark both boundary tensors as sparse ('S') and switch the tensor
            # functions to CSR operator arithmetic.
            inner_form = lmpo_fold.sparse_form[1:-1]
            lmpo_fold.sparse_form = 'S%sS' % inner_form
            csr_opf = CSROperatorFunctions(hm_zero.opf.cg)
            lmpo_fold.tf = TensorFunctions(csr_opf)
        else:
            assert dynamic_corr_method is not None
            if dynamic_corr_method[0] in ['nevpt2sd', 'mrrept2sd']:
                xl = -2, -2, -2
                xr = 2, 2, 2
            elif dynamic_corr_method[0] in ['nevpt2s', 'mrrept2s']:
                xl = -1, -1, -1
                xr = 1, 1, 1
            elif dynamic_corr_method[0] in ["nevpt2-ijrs", "nevpt2-ij", "nevpt2-rs", "nevpt2-ijr",
                                            "nevpt2-rsi", "nevpt2-ir", "nevpt2-i", "nevpt2-r",
                                            "mrrept2-ijrs", "mrrept2-ij", "mrrept2-rs", "mrrept2-ijr",
                                            "mrrept2-rsi", "mrrept2-ir", "mrrept2-i", "mrrept2-r"]:
                # this is not correct yet
                if dynamic_corr_method[0].startswith('nevpt2-'):
                    sub_space = dynamic_corr_method[0][7:]
                else:
                    sub_space = dynamic_corr_method[0][8:]
                n_ex_inactive = sub_space.count('i') + sub_space.count('j')
                n_ex_external = sub_space.count('r') + sub_space.count('s')
                xl = -n_ex_inactive, -n_ex_inactive, -n_ex_inactive
                xr = n_ex_external, n_ex_external, n_ex_external
            else:
                assert False
            if big_site_method == "fock":
                assert "nonspinadapted" in dic
                poccl = SCIFockBigSite.ras_space(False, n_inactive, *[abs(x) for x in xl], VectorInt([]))
                poccr = SCIFockBigSite.ras_space(True, n_external, *xr, VectorInt([]))
                if '-' in dynamic_corr_method[0]:
                    maxl = max([len(x) for x in poccl])
                    poccl = VectorVectorInt([x for x in poccl if len(x) == maxl - abs(xl[0])])
                    poccr = VectorVectorInt([x for x in poccr if len(x) == abs(xr[0])])
                big_left_orig = SCIFockBigSite(n_orbs, n_inactive, False, fd_zero, orb_sym, poccl, True)
                big_right_orig = SCIFockBigSite(n_orbs, n_external, True, fd_zero, orb_sym, poccr, True)
            elif big_site_method == "csf":
                assert "nonspinadapted" not in dic
                big_left_orig = CSFBigSite(n_inactive, abs(
                    xl[-1]), False, fd_zero, orb_sym[:n_inactive])
                big_right_orig = CSFBigSite(n_external, abs(
                    xr[-1]), True, fd_zero, orb_sym[-n_external:])
            elif big_site_method == "drt":
                left_iqs = DRTBigSite.get_target_quanta(False, n_inactive, abs(xl[-1]), orb_sym[:n_inactive])
                right_iqs = DRTBigSite.get_target_quanta(True, n_external, abs(xr[-1]), orb_sym[-n_external:])
                big_left_orig = DRTBigSite(left_iqs, False, n_inactive,
                    orb_sym[:n_inactive], fd_zero, max(outputlevel, 0))
                big_right_orig = DRTBigSite(right_iqs, True, n_external,
                    orb_sym[-n_external:], fd_zero, max(outputlevel, 0))
            elif big_site_method == "bigdrt":
                left_iqs = DRTBigSite.get_target_quanta(False, 0, 0, orb_sym[:0])
                right_iqs = DRTBigSite.get_target_quanta(True, n_inactive + n_external, abs(xr[-1]),
                    orb_sym[-(n_inactive + n_external):], nc_ref=n_inactive)
                big_left_orig = DRTBigSite(left_iqs, False, 0, orb_sym[:0], fd_zero, max(outputlevel, 0))
                big_right_orig = DRTBigSite(right_iqs, True, n_inactive + n_external,
                    orb_sym[-(n_inactive + n_external):], fd_zero, max(outputlevel, 0))
                big_right_orig.drt = DRT(
                    big_right_orig.drt.n_sites,
                    big_right_orig.drt.get_init_qs(),
                    big_right_orig.drt.orb_sym, n_inactive, n_external, n_ex=abs(xr[-1]), nc_ref=n_inactive,
                )
            else:
                raise NotImplementedError
            big_left = SimplifiedBigSite(big_left_orig, simpl_rule)
            big_right = SimplifiedBigSite(big_right_orig, simpl_rule)
            if MPI is not None:
                if one_body_only:
                    big_left = ParallelBigSite(big_left, prule_one_body)
                    big_right = ParallelBigSite(big_right, prule_one_body)
                else:
                    big_left = ParallelBigSite(big_left, prule)
                    big_right = ParallelBigSite(big_right, prule)
            hm_zero = HamiltonianQCBigSite(vacuum, n_orbs, orb_sym, fd_zero,
                                            None if n_inactive == 0 or big_site_method == "bigdrt" else big_left,
                                            None if (n_external == 0 and big_site_method != "bigdrt") or n_inactive + n_external == 0 else big_right)

        # left mpo: taken from the folded MPO or built from hm_zero, then
        # simplified, shifted by the CASCI energy and negated
        _print("build left mpo", time.perf_counter() - tx)
        lmpo = lmpo_fold if big_site_method == "folding" else MPOQC(hm_zero, qc_type, "LQC")
        _print("simpl left mpo", time.perf_counter() - tx)
        lmpo = SimplifiedMPO(lmpo, simpl_rule, True, True,
                             OpNamesSet((OpNames.R, OpNames.RD)))
        _print("simpl left mpo finished", time.perf_counter() - tx)
        lmpo.const_e -= e_casci
        lmpo = lmpo * -1

        # report m * n of the left operator names at every site
        mpo_bdims = []
        for site_idx in range(len(lmpo.left_operator_names)):
            lmpo.load_left_operators(site_idx)
            names_mat = lmpo.left_operator_names[site_idx]
            mpo_bdims.append(names_mat.m * names_mat.n)
            lmpo.unload_left_operators(site_idx)
        _print('LEFT MPO BOND DIMS = ', ''.join("%6d" % bdim for bdim in mpo_bdims))

        if MPI is None or MPI.rank == 0:
            lmpo.save_data(scratch + '/lmpo.bin')

        # right mpo: simplified with NoTransposeRule and shifted by the CASCI
        # energy (no negation).
        # NOTE(review): the folding path uses `mpo_fold` here but `lmpo_fold`
        # for the left MPO; confirm `mpo_fold` is the intended right MPO.
        _print("build right mpo", time.perf_counter() - tx)
        rmpo = mpo_fold if big_site_method == "folding" else MPOQC(rhamil, qc_type, "RQC")
        _print("simpl right mpo", time.perf_counter() - tx)
        rmpo = SimplifiedMPO(rmpo, NoTransposeRule(simpl_rule),
                             True, True, OpNamesSet((OpNames.R, OpNames.RD)))
        _print("simpl right mpo finished", time.perf_counter() - tx)
        rmpo.const_e -= e_casci

        mpo_bdims = []
        for site_idx in range(len(rmpo.left_operator_names)):
            rmpo.load_left_operators(site_idx)
            names_mat = rmpo.left_operator_names[site_idx]
            mpo_bdims.append(names_mat.m * names_mat.n)
            rmpo.unload_left_operators(site_idx)
        _print('RIGHT MPO BOND DIMS = ', ''.join("%6d" % bdim for bdim in mpo_bdims))

        if MPI is None or MPI.rank == 0:
            rmpo.save_data(scratch + '/rmpo.bin')

        if not para_no_pre_run:
            # distribute both MPOs with the same parallel rule
            if MPI is not None:
                para_rule = prule_one_body if one_body_only else prule
                lmpo = ParallelMPO(lmpo, para_rule)
                rmpo = ParallelMPO(rmpo, para_rule)
            _print("para left/right mpo finished", time.perf_counter() - tx)

        # a two-site dot cannot start on the last site: shift the center left
        mps.dot = dot
        if mps.dot == 2 and mps.center == mps.n_sites - 1:
            mps.center = mps.n_sites - 2

        # bra mps
        # MPSInfo of the bra optimized by the linear solve below: a plain
        # MPSInfo on hm_zero's basis for big-site methods, otherwise an MRCI-
        # or NEVPT-restricted MPSInfo on hamil's basis chosen from the name.
        if big_site_method is not None:
            bra_info = MPSInfo(n_sites, vacuum, target, hm_zero.basis)
        else:
            assert dynamic_corr_method[0] in ['nevpt2sd', 'nevpt2s', "nevpt2-ijrs", "nevpt2-ij",
                "nevpt2-rs", "nevpt2-ijr", "nevpt2-rsi", "nevpt2-ir", "nevpt2-i", "nevpt2-r",
                "mrrept2sd", "mrrept2s", "mrrept2-ijrs", "mrrept2-ij", "mrrept2-rs", "mrrept2-ijr",
                "mrrept2-rsi", "mrrept2-ir", "mrrept2-i", "mrrept2-r"]
            if dynamic_corr_method[0] in ['nevpt2sd', 'nevpt2s']:
                # ci_order = suffix length after 'nevpt2': 2 for 'sd', 1 for 's'
                ci_order = len(dynamic_corr_method[0]) - 6
                bra_info = MRCIMPSInfo(
                    n_sites, n_inactive, n_external, ci_order, vacuum, target, hamil.basis)
            elif dynamic_corr_method[0] in ["mrrept2sd", "mrrept2s"]:
                # ci_order = suffix length after 'mrrept2': 2 for 'sd', 1 for 's'
                ci_order = len(dynamic_corr_method[0]) - 7
                bra_info = MRCIMPSInfo(
                    n_sites, n_inactive, n_external, ci_order, vacuum, target, hamil.basis)
            else:
                # Sub-space label after 'nevpt2-' (7 chars) / 'mrrept2-' (8
                # chars); 'i'/'j' count inactive and 'r'/'s' external indices.
                if dynamic_corr_method[0].startswith('nevpt2-'):
                    sub_space = dynamic_corr_method[0][7:]
                else:
                    sub_space = dynamic_corr_method[0][8:]
                n_ex_inactive = sub_space.count('i') + sub_space.count('j')
                n_ex_external = sub_space.count('r') + sub_space.count('s')
                # NOTE(review): this branch passes n_orbs while the MRCI
                # branches above pass n_sites; confirm they are equal here.
                bra_info = NEVPTMPSInfo(n_orbs, n_inactive, n_external, n_ex_inactive, n_ex_external, vacuum,
                                        target, hamil.basis)

        # FCI-space bond-dimension initialization; with singlet embedding
        # (SU2 only) the left vacuum carries singlet_embedding_spin.
        if singlet_embedding and SX == SU2:
            left_vacuum = SX(singlet_embedding_spin, singlet_embedding_spin, 0)
            right_vacuum = vacuum
            if "full_fci_space" in dic:
                bra_info.set_bond_dimension_full_fci(left_vacuum, right_vacuum)
            else:
                bra_info.set_bond_dimension_fci(left_vacuum, right_vacuum)
        else:
            if "full_fci_space" in dic:
                bra_info.set_bond_dimension_full_fci()
        bra_info.tag = 'BRA'

        # The tag is used in the saved file names below, so make it differ
        # from the reference MPS tag.
        while bra_info.tag == mps_info.tag:
            bra_info.tag += 'X'
        bra_info.set_bond_dimension(bond_dims[0])
        if "skip_inact_ext_sites" in dic:
            bra_info.set_bond_dimension_inact_ext_fci(bond_dims[0], n_inactive, n_external)
        # Prefix for all output files of this MRPT run.
        if dynamic_corr_method[0].startswith('nevpt2'):
            mrpt_method_name = 'nevpt2'
        else:
            mrpt_method_name = 'mrrept2'
        if MPI is None or MPI.rank == 0:
            bra_info.save_data(scratch + '/%s_mps_info.bin' % mrpt_method_name)
            bra_info.save_data(scratch + '/%s-mps_info.bin' % bra_info.tag)
        # Random initial bra with the same center and dot as mps; the center
        # tensor is normalized, then data is saved and memory released.
        bra = MPS(n_sites, mps.center, dot)
        bra.initialize(bra_info)
        bra.random_canonicalize()
        bra.tensors[bra.center].normalize()
        if "skip_inact_ext_sites" in dic:
            bra.set_inact_ext_identity(n_inactive, n_external)
        bra.save_mutable()
        bra.deallocate()
        bra_info.save_mutable()
        bra_info.deallocate_mutable()

        # For a two-site dot with the center at either end (0 or
        # n_sites - 2), call move_left / move_right on the bra, then reset the
        # recorded center to mps.center.
        if MPI is not None:
            MPI.barrier()
        if bra.center == 0 and bra.dot == 2:
            bra.move_left(hamil.opf.cg, prule if MPI is not None else None)
        elif bra.center == bra.n_sites - 2 and bra.dot == 2:
            bra.move_right(hamil.opf.cg, prule if MPI is not None else None)
        bra.center = mps.center
        if MPI is not None:
            MPI.barrier()
        
        if big_site_method == "bigdrt":
            # Re-express the last (big) site of the reference mps in the DRT
            # basis of big_right_orig, so that mps and bra share the same
            # last-site basis and FCI dims on the last two bonds.
            bra.load_mutable()
            # presumably maps the reference DRT state(s) of big_right_orig_z
            # to indices in big_right_orig's DRT; exactly one is expected.
            xidx = big_right_orig.drt >> big_right_orig_z.drt
            xlen = 0
            assert len(xidx) == 1
            # Turn the global index into an offset inside the init-quanta
            # range that contains it; xlen is the size of that range.
            for i in range(big_right_orig.drt.n_init_qs):
                xa, xb = big_right_orig.drt.q_range(i)
                if xidx[0] >= xa and xidx[0] < xb:
                    xidx[0] = xidx[0] - xa
                    xlen = xb - xa
                    break
            mps.load_mutable()
            mps.info.load_mutable()
            mps.info.basis[-1] = big_right_orig.drt.get_basis()
            i_alloc = IntVectorAllocator()
            d_alloc = DoubleVectorAllocator()
            # New info/tensor for the last site, of the same classes as the
            # existing ones, laid out according to its canonical form.
            minfo = mps.tensors[-1].info.__class__(i_alloc)
            if mps.canonical_form[-1] in 'SC':
                minfo.initialize(mps.info.left_dims[-2], mps.info.basis[-1], mps.info.target, False, True)
            elif mps.canonical_form[-1] in 'R':
                minfo.initialize(mps.info.right_dims[-2], mps.info.basis[-1], mps.info.vacuum, False, False)
            else:
                raise NotImplementedError
            mat = mps.tensors[-1].__class__(d_alloc)
            mat.allocate(minfo)
            # Only block 0 is transferred: column 0 of the old last-site
            # tensor is placed at column xidx[0] of the new one.
            # NOTE(review): assumes a single relevant block; confirm.
            assert np.array(mat[0]).shape[1] == xlen
            xmat = np.zeros(np.array(mat[0]).shape, dtype=np.array(mat[0]).dtype)
            xmat[:, xidx[0]] = np.array(mps.tensors[-1][0])[:, 0]
            mat[0] = xmat
            mps.tensors[-1] = mat
            mps.info.right_dims_fci[-2:] = bra.info.right_dims_fci[-2:]
            mps.info.left_dims_fci[-2:] = bra.info.left_dims_fci[-2:]

            mps.save_mutable()
            mps.deallocate()
            mps.info.save_mutable()
            mps.info.deallocate_mutable()

        _print("BRA MPS = ", bra.canonical_form,
               bra.center, bra.dot, bra.info.target)
        _print("BRA INIT MPS BOND DIMS = ", ''.join(
            ["%6d" % x.n_states_total for x in bra_info.left_dims]))

        tx = time.perf_counter()

        # Environments for the linear solve: lme pairs lmpo with the bra on
        # both sides, rme pairs rmpo with the bra and the reference mps.
        lme = MovingEnvironment(lmpo, bra, bra, "LME")
        lme.init_environments(outputlevel >= 2)
        # NOTE(review): delayed/cached contraction are set after
        # init_environments here, while the "IEX" environment of the
        # copy_mps/trans_mps_to_sz path sets them before init; confirm the
        # initial environments are meant to be built without them.
        lme.delayed_contraction = OpNamesSet.normal_ops()
        lme.cached_contraction = False
        rme = MovingEnvironment(rmpo, bra, mps, "RME")
        rme.init_environments(outputlevel >= 2)

        _print("env init finished", time.perf_counter() - tx)

        # Bond dimension for the right-hand (rme) side: the reference MPS bond
        # dimension plus a fixed margin of 400 (presumably its truncation
        # limit; confirm against Linear's documentation).
        right_bdims = VectorUBond([mps.info.bond_dim + 400])
        if big_site_method is not None and n_cas != 0:
            linear = LinearBigSite(lme, rme, None, VectorUBond(
                bond_dims), right_bdims, VectorFP(noises))
            linear.last_site_svd = True
            linear.last_site_1site = dot == 2
            linear.decomp_last_site = False
        else:
            linear = Linear(lme, rme, None, VectorUBond(
                bond_dims), right_bdims, VectorFP(noises))
        linear.iprint = max(min(outputlevel, 3), 0)

        # Restrict the sweep to the sites between the inactive and external
        # blocks when those are kept fixed.
        if "skip_inact_ext_sites" in dic:
            linear.sweep_start_site = n_inactive
            linear.sweep_end_site = lme.n_sites - n_external

        if "lowmem_noise" in dic:
            linear.noise_type = NoiseTypes.ReducedPerturbativeCollectedLowMem
        elif decomp_type != DecompositionTypes.SVD:
            linear.noise_type = NoiseTypes.ReducedPerturbativeCollected
        else:
            linear.noise_type = NoiseTypes.ReducedPerturbative
        linear.cutoff = float(dic.get("cutoff", 1E-14))
        linear.decomp_type = decomp_type
        linear.trunc_type = trunc_type
        linear.linear_soft_max_iter = int(
                dic.get("linear_soft_max_iter", -1))
        # Linear-solver thresholds are the Davidson thresholds divided by 50.
        linear.linear_conv_thrds = VectorFP([x / 50 for x in dav_thrds])

        # Second argument is True when mps starts at the sweep start site
        # (presumably the initial sweep direction).
        e_corr = linear.solve(len(bond_dims),
            mps.center == linear.sweep_start_site, sweep_tol)
        nevpt_sweep_energies = np.array(linear.targets)
        nevpt_discarded_weights = np.array(linear.discarded_weights)

        if MPI is None or MPI.rank == 0:
            # Pad the bond-dimension schedule with its last entry so there is
            # one entry per recorded sweep.
            n_recorded = len(nevpt_discarded_weights)
            bdims = bond_dims[:n_recorded]
            if len(bdims) < n_recorded:
                bdims = bdims + bdims[-1:] * (n_recorded - len(bdims))
            mrpt_outputs = [
                ("/E_%s.npy", e_casci + e_corr),
                ("/%s_e_corr.npy", e_corr),
                ("/%s_bond_dims.npy", bdims),
                ("/%s_sweep_energies.npy", nevpt_sweep_energies),
                ("/%s_discarded_weights.npy", nevpt_discarded_weights),
            ]
            for out_fmt, out_val in mrpt_outputs:
                np.save(scratch + out_fmt % mrpt_method_name, out_val)

        method_label = mrpt_method_name.upper()
        _print("DMRG-CASCI  Energy     = %20.15f" % e_casci)
        _print("DMRG-%s Correction = %20.15f" % (method_label, e_corr))
        _print("DMRG-%s Energy     = %20.15f" % (method_label, e_corr + e_casci))

    # MPS split
    if ("restart_copy_mps" in dic or "copy_mps" in dic) and (
            "split_states" in dic or "trans_mps_to_complex" in dic):

        copy_tag = dic["restart_copy_mps"] if "restart_copy_mps" in dic else dic["copy_mps"]
        to_complex = "trans_mps_to_complex" in dic

        if MPI is None or MPI.rank == 0:
            if "split_states" in dic:
                # Write one single-root MPS per root, optionally made complex
                # and re-tagged as "<copy_tag>-<root>".
                assert nroots != 1

                for iroot in range(nroots):
                    _print('----- root = %3d / %3d -----' % (iroot, nroots))
                    simps, simps_info, _ = split_mps(iroot, mps, mps_info, mpi=None)
                    if to_complex:
                        simps = MultiMPS.make_complex(
                            simps, "%s-CPX-%d" % (mps_info.tag, iroot))
                        simps_info = simps.info
                    requested_tag = "%s-%d" % (copy_tag, iroot)
                    if copy_tag == '' or requested_tag == simps_info.tag:
                        final_tag = simps_info.tag
                    else:
                        final_tag = requested_tag
                        simps = simps.deep_copy(final_tag)
                        simps_info = simps.info
                    _print("   final tag = %s" % final_tag)
                    _print("   final canonical form = %s" % simps.canonical_form)
                    simps_info.save_data(scratch + '/%s-mps_info.bin' % final_tag)

            elif to_complex:
                # Complex copy of the whole MPS; defaults to "<tag>-CPX".
                if copy_tag == '':
                    copy_tag = "%s-CPX" % mps_info.tag
                assert copy_tag != mps_info.tag
                simps = MultiMPS.make_complex(mps, copy_tag)
                simps_info = simps.info
                _print("   final tag = %s" % copy_tag)
                _print("   final canonical form = %s" % simps.canonical_form)
                simps_info.save_data(scratch + '/%s-mps_info.bin' % copy_tag)

    # MPS copy/transform
    elif "restart_copy_mps" in dic or "copy_mps" in dic:

        if "restart_copy_mps" in dic:
            copy_tag = dic["restart_copy_mps"]
        else:
            copy_tag = dic["copy_mps"]

        mps.dot = dot
        if MPI is not None:
            MPI.barrier()
        if MPI is None or MPI.rank == 0:
            if (mps.center == 0 and mps.canonical_form[0] == 'S') or \
                    (mps.center == mps.n_sites - 1 and mps.canonical_form[-1] == 'K'):
                mps.flip_fused_form(mps.center, CG(), None)
                mps.save_data()
        if MPI is not None:
            MPI.barrier()
        if mps.center == mps.n_sites - 1 and mps.dot == 2:
            mps.center = mps.n_sites - 2
        _print("Copy-init canonical form = ", mps.canonical_form, mps.center)

        if copy_tag == '':
            raise ValueError(
                "A tag name must be given for the keyword copy_mps/restart_copy_mps!")
        # Convert the spin-adapted MPS into an SZ MPS tagged copy_tag;
        # otherwise just deep-copy it under copy_tag.
        if "trans_mps_to_sz" in dic:
            assert "nonspinadapted" not in dic

            if mps.center != 0:
                # Work on a temporary copy (tag "<copy_tag>@TMP") and run one
                # Expect sweep with impo to change its canonical form —
                # presumably to bring the center to site 0, since this branch
                # is skipped when mps.center is already 0.
                xmps = mps.deep_copy(copy_tag + "@TMP")
                _print('change canonical form ...')
                cf = str(xmps.canonical_form)
                ime = MovingEnvironment(impo, xmps, xmps, "IEX")
                ime.delayed_contraction = OpNamesSet.normal_ops()
                ime.cached_contraction = cached_contraction
                ime.init_environments(False)
                if not complex_mps:
                    expect = Expect(ime, xmps.info.bond_dim, xmps.info.bond_dim)
                else:
                    expect = ComplexExpect(ime, xmps.info.bond_dim, xmps.info.bond_dim)
                expect.iprint = max(min(outputlevel, 3), 0)
                expect.solve(True, xmps.center == 0)
                if MPI is not None:
                    MPI.barrier()
                xmps.save_data()
                if MPI is not None:
                    MPI.barrier()
                _print(cf + ' -> ' + xmps.canonical_form)
            else:
                xmps = mps

            # Transform the unfused MPS to an SZ target with the same n,
            # twos and pg, optionally resolving singlet embedding.
            xmps.info.load_mutable()
            xmps.load_mutable()
            targetz = TrSX(xmps.info.target.n, xmps.info.target.twos, xmps.info.target.pg)
            cp_mps = trans_mps(UnfusedMPS(xmps), copy_tag, mpo.tf.opf.cg, targetz)
            if "resolve_twosz" in dic:
                res_twosz = int(dic["resolve_twosz"])
                cp_mps.resolve_singlet_embedding(res_twosz)
            cp_mps = cp_mps.finalize()
            # dot/center are changed temporarily for dynamic_canonicalize and
            # restored afterwards.
            dot_bk, center_bk = cp_mps.dot, cp_mps.center
            if MPI is not None:
                MPI.barrier()
            last_flipped = None
            if dot == 2:
                # At either boundary switch to one-site form. A fused last-site
                # form ('C' or 'S') is flipped with the SZ CG / parallel rule
                # and remembered in last_flipped so it can be restored below.
                if cp_mps.center == 0 and cp_mps.canonical_form[cp_mps.center + 1] == 'R':
                    cp_mps.dot = 1
                elif cp_mps.center == cp_mps.n_sites - 2 and cp_mps.canonical_form[cp_mps.center] == 'L':
                    cp_mps.center = cp_mps.n_sites - 1
                    cp_mps.dot = 1
                    if cp_mps.canonical_form[cp_mps.center] in ['C', 'S']:
                        last_flipped = cp_mps.canonical_form[cp_mps.center]
                        from block2.sz import MPICommunicator as MPICommunicatorX
                        from block2.sz import CG as CGX, ParallelRuleQC as ParallelRuleQCX
                        prulex = ParallelRuleQCX(MPICommunicatorX())
                        cp_mps.canonical_form = cp_mps.canonical_form[:-1] + 'S'
                        cp_mps.flip_fused_form(cp_mps.center, CGX(), prulex)
            cp_mps.info.load_mutable()
            cp_mps.load_mutable()
            cp_mps.dynamic_canonicalize()
            if "normalize_mps" in dic:
                cp_mps.tensors[cp_mps.center].normalize()
            cp_mps.dot, cp_mps.center = dot_bk, center_bk
            if MPI is not None:
                MPI.barrier()
            if MPI is None or MPI.rank == 0:
                cp_mps.info.save_mutable()
                cp_mps.save_mutable()
                cp_mps.save_data()
            if MPI is not None:
                MPI.barrier()
            # Undo the temporary flip of the last site and save again.
            if last_flipped is not None:
                cp_mps.flip_fused_form(cp_mps.n_sites - 1, CGX(), prulex)
                cp_mps.canonical_form = cp_mps.canonical_form[: -
                                                              1] + last_flipped
                if MPI is None or MPI.rank == 0:
                    cp_mps.save_data()
                if MPI is not None:
                    MPI.barrier()
        else:
            cp_mps = mps.deep_copy(copy_tag)

        if "trans_mps_to_singlet_embedding" in dic or "trans_mps_from_singlet_embedding" in dic:
            assert "nonspinadapted" not in dic
            # Work on a copy of the original mps under copy_tag. Without
            # trans_mps_to_sz, cp_mps already is mps.deep_copy(copy_tag) (the
            # plain-copy branch just above) and nothing has modified it since,
            # so a second deep copy would only rewrite the same MPS files.
            # With trans_mps_to_sz, cp_mps holds the transformed MPS, so a
            # fresh copy of mps is taken.
            if "trans_mps_to_sz" in dic:
                cp_mps = mps.deep_copy(copy_tag)
            # Normalize fused/edge forms at either end, then sweep the center
            # to site 0.
            if cp_mps.canonical_form[0] == 'C' and cp_mps.canonical_form[1] == 'R':
                cp_mps.canonical_form = 'K' + cp_mps.canonical_form[1:]
                cp_mps.center = 0
            elif cp_mps.canonical_form[-1] == 'C' and cp_mps.canonical_form[-2] == 'L':
                cp_mps.canonical_form = cp_mps.canonical_form[:-1] + 'S'
                cp_mps.center = cp_mps.n_sites - 1
            elif cp_mps.center == cp_mps.n_sites - 2 and cp_mps.canonical_form[-2] == 'L':
                cp_mps.center = cp_mps.n_sites - 1
            while cp_mps.center > 0:
                cp_mps.move_left(CG(), prule if MPI is not None else None)
            if "trans_mps_to_singlet_embedding" in dic:
                cp_mps.to_singlet_embedding_wfn(CG(), SX.invalid, prule if MPI is not None else None)
            if "trans_mps_from_singlet_embedding" in dic:
                cp_mps.from_singlet_embedding_wfn(CG(), prule if MPI is not None else None)
            if MPI is None or MPI.rank == 0:
                cp_mps.save_data()

        if MPI is None or MPI.rank == 0:
            cp_mps.info.save_data(scratch + '/mps_info.bin')
            cp_mps.info.save_data(scratch + '/%s-mps_info.bin' % copy_tag)

        mps = cp_mps
        _print("Copy-final canonical form = ", mps.canonical_form, mps.center)

    # stoptDMRG sampling (sampling CSF is not supported yet)
    if "stopt_sampling" in dic:

        if fcidump is None:
            raise ValueError(
                "Need FCIDUMP construction (namely, not a pre run) for 'stoptDMRG'!")
        if MPI is not None:
            MPI.barrier()

        try:
            from pyblock2.driver.stopt import SPDMRG
        except ImportError:
            from stopt import SPDMRG

        # Number of samples, 10000 when the keyword is given without a value.
        # The key is known to be present, so dic.get(..., 10000) never used its
        # default; a bare keyword is presumably stored as '' (the copy_mps and
        # sample_phase handling in this file also check for an empty value),
        # which made int('') raise.
        nsample_str = str(dic["stopt_sampling"]).strip()
        nsample = int(nsample_str) if nsample_str != '' else 10000
        sp_dmrg = SPDMRG(su2="nonspinadapted" not in dic, scratch=scratch, fcidump=fcidump,
                         orb_idx=orb_idx, mps_tags=mps_tags, verbose=outputlevel)
        e_corr, std_corr = sp_dmrg.kernel(nsample)

        if MPI is None or MPI.rank == 0:
            np.save(scratch + "/E_stopt.npy", sp_dmrg.Edmrg + e_corr)
            np.save(scratch + "/stopt_e_corr.npy", e_corr)
            np.save(scratch + "/stopt_std_corr.npy", std_corr)

        _print("            DMRG Energy     = %25.15f" % sp_dmrg.Edmrg)
        _print("stochastic PDMRG Correction = %25.15f (%20.15f)" %
               (e_corr, std_corr))
        _print("stochastic PDMRG Energy     = %25.15f (%20.15f)" %
               (sp_dmrg.Edmrg + e_corr, std_corr))

    # CSF/DET coefficients
    if "restart_sample" in dic or "sample" in dic:

        # Coefficient cutoff for sampling. A bare keyword (presumably stored as
        # '' by the input parser, as for copy_mps) means a cutoff of 0 instead
        # of failing in float('').
        sample_key = "restart_sample" if "restart_sample" in dic else "sample"
        sample_cutoff_str = str(dic[sample_key]).strip()
        sample_cutoff = float(sample_cutoff_str) if sample_cutoff_str != '' else 0.0

        if "trans_mps_to_sz" in dic:
            # The MPS was converted to SZ: sample determinants with SZ classes.
            from block2.sz import DeterminantTRIE, UnfusedMPS
            su2mps = False
        else:
            su2mps = SX == SU2

        if "sample_reference" not in dic:
            max_rank = nelec[0]
            sample_ref = []
        else:
            # Format: "<max_rank> <occupation digits, one per site>"; digit 3
            # counts as two electrons, 1 and 2 as one, anything else as zero.
            sample_params = dic["sample_reference"].split()
            if len(sample_params) < 2:
                raise ValueError(
                    "sample_reference needs a max rank and an occupation string "
                    "of length %d!" % n_sites)
            max_rank = int(sample_params[0])
            sample_ref = [int(occ) for occ in sample_params[1]]
            if len(sample_ref) != n_sites:
                raise ValueError(
                    "sample_reference occupation string has length %d, expected %d!"
                    % (len(sample_ref), n_sites))
            nelec_ref = sum(2 if occ == 3 else 1 if occ in (1, 2) else 0
                            for occ in sample_ref)
            if nelec_ref != nelec[0]:
                raise ValueError(
                    "sample_reference has %d electrons, expected %d!"
                    % (nelec_ref, nelec[0]))

        tx = time.perf_counter()
        # Collect CSFs/determinants of the MPS above the cutoff into a TRIE.
        dtrie = DeterminantTRIE(n_sites, True)
        mps.info.load_mutable()
        mps.load_mutable()
        if mps.center != 0:
            for warn_line in ("Warning: sample an MPS with center != 0 will be highly inefficient!",
                              "One can load the MPS and do one extra sweep to change the center."):
                _print(warn_line)
        dtrie.evaluate(UnfusedMPS(mps), sample_cutoff, max_rank, VectorUInt8(sample_ref))

        if "sample_phase" in dic:
            # Phase convention: an explicit site permutation, or orb_idx
            # (empty when unset) if no indices are given.
            phase_tokens = dic["sample_phase"].split()
            if phase_tokens:
                ref_idx = [int(tok) for tok in phase_tokens]
                assert len(ref_idx) == n_sites
                check_idx = [i for i in ref_idx if 0 <= i < n_sites]
                assert len(set(check_idx)) == n_sites
            else:
                ref_idx = orb_idx if orb_idx is not None else []
            dtrie.convert_phase(VectorInt(ref_idx))
        _print("dtrie finished", time.perf_counter() - tx)

        if MPI is None or MPI.rank == 0:
            # Report and save the sampled coefficients (rank 0 only).
            dname = "CSF" if su2mps else "DET"
            _print("Number of %s = %10d (cutoff = %9.5g, max_rank = %3d)" %
                   (dname, len(dtrie), sample_cutoff, max_rank))
            # Per-site occupation symbols indexed by the values stored in dtrie.
            ddstr = "0+-2" if su2mps else "0ab2"
            dvals = np.array(dtrie.vals)
            # Indices of the 50 largest |coefficient| entries, largest first.
            gidx = np.argsort(np.abs(dvals))[::-1][:50]
            _print("Sum of weights of sampled %s = %20.15f\n" %
                   (dname, (dvals ** 2).sum()))
            for ii, idx in enumerate(gidx):
                det = ''.join([ddstr[x] for x in dtrie[idx]])
                val = dvals[idx]
                _print(dname, "%10d" % ii, det, " = %20.15f" % val)
            if len(dvals) > 50:
                _print(" ... and more ... ")
            np.save(scratch + "/sample-vals.npy", dvals)
            # One row of per-site occupation values per sampled CSF/DET.
            dets = np.zeros((len(dtrie), n_sites), dtype=np.uint8)
            for i in range(len(dtrie)):
                dets[i] = np.array(dtrie[i])
            np.save(scratch + "/sample-dets.npy", dets)

            nq = 4 if "use_general_spin" not in dic else 2
            state_occ = np.array(
                dtrie.get_state_occupation()).reshape(n_sites, nq)
            # Normalize each site's row to sum to 1.
            # NOTE(review): a row summing to 0 (e.g. nothing sampled) gives
            # inf/nan here.
            state_occ *= (1 / state_occ.sum(axis=1))[:, None]
            _print("STATE OCC = ", "".join(
                ["%8.5f" % x for x in state_occ.ravel()]))
            np.save(scratch + "/sample-stocc.npy", state_occ)
