#!python

# Libraries required by this script, as (module name, alias) pairs. Each
# module is bound in the global namespace under its alias by the loop below.
libnames = [('mdtraj', 'md'), ('numpy', 'np'), ('keras', 'krs'), ('argparse', 'arg'), ('datetime', 'dt'), ('sys', 'sys')]

# Import everything up front so that a missing dependency is reported with a
# short message instead of a traceback.
for (name, short) in libnames:
  try:
    lib = __import__(name)
  except ImportError:
    # Only ImportError is treated as "not installed"; any other exception
    # (including Ctrl-C) propagates unchanged instead of being misreported.
    print("Library %s is not installed, exiting" % name)
    # Non-zero status so that calling shells and scripts can detect the failure.
    # SystemExit is raised directly because 'sys' may not be bound yet.
    raise SystemExit(1)
  else:
    globals()[short] = lib

import anncolvar

# Command line interface: the parser is filled from an ordered option table
parser = arg.ArgumentParser(description='Artificial neural network learning of collective variables of molecular systems, requires numpy, keras and mdtraj')

# One row per option, in the order shown by --help:
# (flag, destination attribute, default value, converter or None, help text).
# A converter of None leaves the value as the string typed by the user.
_option_table = [
  # --- input files ---
  ('-i', 'infile', 'traj.xtc', None,
   'Input trajectory in pdb, xtc, trr, dcd, netcdf or mdcrd, WARNING: the trajectory must be 1. centered in the PBC box, 2. fitted to a reference structure and 3. must contain only atoms to be analysed!'),
  ('-p', 'intop', 'top.pdb', None,
   'Input topology in pdb, WARNING: the structure must be 1. centered in the PBC box and 2. must contain only atoms to be analysed!'),
  ('-c', 'colvar', 'colvar.txt', None,
   'Input collective variable file in text formate, must contain the same number of lines as frames in the trajectory'),
  ('-col', 'col', 2, int,
   'The index of the column containing collective variables in the input collective variable file'),
  # --- periodic box and fitting ---
  ('-boxx', 'boxx', 0.0, float,
   'Size of x coordinate of PBC box (from 0 to set value in nm)'),
  ('-boxy', 'boxy', 0.0, float,
   'Size of y coordinate of PBC box (from 0 to set value in nm)'),
  ('-boxz', 'boxz', 0.0, float,
   'Size of z coordinate of PBC box (from 0 to set value in nm)'),
  ('-nofit', 'nofit', 'False', None,
   'Disable fitting, the trajectory must be properly fited (default False)'),
  # --- training/test split ---
  ('-testset', 'testset', 0.10, float,
   'Size of test set (fraction of the trajectory, default = 0.1)'),
  ('-shuffle', 'shuffle', 'True', None,
   'Shuffle trajectory frames to obtain training and test set (default True)'),
  # --- network architecture ---
  ('-layers', 'layers', 1, int,
   'Number of hidden layers (allowed values 1-3, default = 1)'),
  ('-layer1', 'layer1', 256, int,
   'Number of neurons in the first encoding layer (default = 256)'),
  ('-layer2', 'layer2', 256, int,
   'Number of neurons in the second encoding layer (default = 256)'),
  ('-layer3', 'layer3', 256, int,
   'Number of neurons in the third encoding layer (default = 256)'),
  ('-actfun1', 'actfun1', 'sigmoid', None,
   'Activation function of the first layer (default = sigmoid, for options see keras documentation)'),
  ('-actfun2', 'actfun2', 'linear', None,
   'Activation function of the second layer (default = linear, for options see keras documentation)'),
  ('-actfun3', 'actfun3', 'linear', None,
   'Activation function of the third layer (default = linear, for options see keras documentation)'),
  # --- training setup ---
  ('-optim', 'optim', 'adam', None,
   'Optimizer (default = adam, for options see keras documentation)'),
  ('-loss', 'loss', 'mean_squared_error', None,
   'Loss function (default = mean_squared_error, for options see keras documentation)'),
  ('-epochs', 'epochs', 100, int,
   'Number of epochs (default = 100, >1000 may be necessary for real life applications)'),
  ('-batch', 'batch', 256, int,
   'Batch size (0 = no batches, default = 256)'),
  # --- output files ---
  ('-o', 'ofile', '', None,
   'Output file with original and approximated collective variables (txt, default = no output)'),
  ('-model', 'modelfile', '', None,
   'Prefix for output model files (experimental, default = no output)'),
  ('-plumed', 'plumedfile', 'plumed.dat', None,
   'Output file for Plumed (default = plumed.dat)'),
]

# Register the options in table order; type=None is argparse's own default.
for _flag, _dest, _default, _converter, _helptext in _option_table:
  parser.add_argument(_flag, dest=_dest, default=_default, type=_converter,
                      help=_helptext)

args = parser.parse_args()

# Activation function names accepted for -actfun1, -actfun2 and -actfun3
_ACTIVATIONS = ('softmax', 'elu', 'selu', 'softplus', 'softsign', 'relu',
                'tanh', 'sigmoid', 'hard_sigmoid', 'linear')


def _fail(message):
  """Print an error message and stop the script with a non-zero exit status."""
  print(message)
  raise SystemExit(1)


def _parse_flag(option, value):
  """Convert the literal string 'True' or 'False' given for option to 1 or 0.

  Any other value is reported as an error and terminates the script.
  """
  if value == "True":
    return 1
  if value == "False":
    return 0
  _fail("ERROR: %s %s not understood" % (option, value))


# Input files, column index and box dimensions are passed through unchanged
infilename = args.infile
intopname = args.intop
colvarname = args.colvar
column = args.col
boxx = args.boxx
boxy = args.boxy
boxz = args.boxz

# The test set may take at most half of the trajectory
if args.testset < 0.0 or args.testset > 0.5:
  _fail("ERROR: -testset must be 0.0 - 0.5")
atestset = float(args.testset)

# Shuffling the trajectory before splitting
shuffle = _parse_flag("-shuffle", args.shuffle)
nofit = _parse_flag("-nofit", args.nofit)

if args.layers < 1 or args.layers > 3:
  _fail("ERROR: -layers must be 1-3, for deeper learning contact authors")

# Warn about large layers; only layers that are actually used are checked.
# The second layer is used with -layers 2 as well as with -layers 3.
if args.layer1 > 1024:
  print("WARNING: You plan to use %i neurons in the first layer, could be slow" % args.layer1)
if args.layers >= 2 and args.layer2 > 1024:
  print("WARNING: You plan to use %i neurons in the second layer, could be slow" % args.layer2)
if args.layers == 3 and args.layer3 > 1024:
  print("WARNING: You plan to use %i neurons in the third layer, could be slow" % args.layer3)

# Activation functions of the layers in use must be known names
if args.actfun1 not in _ACTIVATIONS:
  _fail("ERROR: cannot understand -actfun1 %s" % args.actfun1)
if args.layers >= 2 and args.actfun2 not in _ACTIVATIONS:
  _fail("ERROR: cannot understand -actfun2 %s" % args.actfun2)
if args.layers == 3 and args.actfun3 not in _ACTIVATIONS:
  _fail("ERROR: cannot understand -actfun3 %s" % args.actfun3)

# The activation function given for the layer after the last hidden one
# must be left at 'linear'
if args.layers == 1 and args.actfun2 != 'linear':
  _fail("ERROR: actfun2 must be linear for -layers 1")
if args.layers == 2 and args.actfun3 != 'linear':
  _fail("ERROR: actfun3 must be linear for -layers 2")

layers = args.layers
layer1 = args.layer1
layer2 = args.layer2
layer3 = args.layer3
actfun1 = args.actfun1
actfun2 = args.actfun2
actfun3 = args.actfun3
epochs = args.epochs
optim = args.optim
batch = args.batch
loss = args.loss

# Output file name: append '.txt' when missing; an empty name means no output
if args.ofile.endswith('.txt'):
  ofilename = args.ofile
elif len(args.ofile) > 0:
  ofilename = args.ofile + '.txt'
else:
  ofilename = ''
modelfile = args.modelfile

# Plumed output file name always ends with '.dat'
plumedfile = args.plumedfile
if not plumedfile.endswith('.dat'):
  plumedfile = plumedfile + '.dat'

anncolvar.anncollectivevariable(infilename, intopname, colvarname, column,
                                boxx, boxy, boxz, atestset, shuffle, nofit,
                                layers, layer1, layer2, layer3,
                                actfun1, actfun2, actfun3,
                                optim, loss, epochs, batch,
                                ofilename, modelfile, plumedfile)


