#!/usr/bin/env python

################################################################################
###                         parse script parameters                          ###
################################################################################
from optparse import OptionParser

usage = "usage: %prog [options] [<NR_PARTICLES>] [<BOX_SIZE>] [<TEMPERATURE>]"
arg_parser = OptionParser(usage = usage)
# The three physical parameters can be passed as options or as positional
# arguments. Their option defaults are non-positive sentinels which fail the
# validation further below; this is how a missing parameter is detected.
arg_parser.add_option("-M", "--particles", action = "store", type = int,
    dest = "nr_particles", default = -1,
    help = "Nr. of particles (required!)")
arg_parser.add_option("-L", "--box_size", action = "store", type = float,
    dest = "box_size", default = -1.0,
    help = "side box_size [A (angstrom)] of the simulation cube (required!)")
arg_parser.add_option("-T", "--temperature", action = "store", type = float,
    dest = "temperature", default = -1.0,
    help = "temperature [K] to generate initial conditions (required!)")
arg_parser.add_option("-o", "--output", action = "store", type = str,
    dest = "output", default = "task02.xyz",
    help = "output file path (default: 'task02.xyz')")
arg_parser.add_option("-v", action = "store_true",
    dest = "verbose", default = False,
    help = "turn verbosity mode on (default: off a.k.a. silent)")
# Parse command line arguments (as def. above) or store defaults to `config`
config, args = arg_parser.parse_args()


def _abort(message):
    """Print the help text followed by `message`, then terminate the script.

    Args:
        message: the error line printed after the help text.

    Raises:
        SystemExit: always, with exit status -1.
    """
    arg_parser.print_help()
    print(message)
    # Raise SystemExit directly instead of calling `exit()`: that helper is
    # installed by the `site` module for interactive sessions and is not
    # available in every interpreter setup (e.g. `python -S`).
    raise SystemExit(-1)


# overwrite options with positional arguments if supplied (positional
# arguments beyond the third one are ignored)
try:
    if len(args) > 0:
        config.nr_particles = int(args[0])
    if len(args) > 1:
        config.box_size = float(args[1])
    if len(args) > 2:
        config.temperature = float(args[2])
except ValueError as expression:
    _abort(f"Error: {expression}")

# quick and dirty validation: all three parameters must be strictly positive
# (phrased as `not (x > 0 ...)` such that a NaN is rejected as well)
if not (config.nr_particles > 0
        and config.box_size > 0.0
        and config.temperature > 0.0):
    _abort("Error: missing or illegal argument")

################################################################################
###                task 2 / generation of initial conditions                 ###
################################################################################
# NOTE: the numeric modules are loaded _after_ processing the script parameters
# on purpose (no need to load all of the heavy numeric modules if only the
# usage text or the like is needed)
import numpy as np
import scipy
import scipy.constants
import scipy.optimize
from jax import jit, grad
from molecular_dynamics import dump, energy, force, mass

# Draw the initial particle coordinates uniformly at random from the cube
# [0, box_size)^3, one row per particle (TODO: make this not just uniform :-})
state_shape = (config.nr_particles, 3)
position = np.random.uniform(low = 0.0, high = config.box_size,
                             size = state_shape)

# Draw every velocity component from a zero-mean normal distribution with
# standard deviation sqrt(k_B * T / mass)
thermal_energy = scipy.constants.Boltzmann * config.temperature
sd = np.sqrt(thermal_energy / mass)
velocity = np.random.normal(loc = 0.0, scale = sd, size = state_shape)
# subtract the per-axis mean so the sampled velocities are centered at zero
velocity = velocity - velocity.mean(axis = 0)

# remember energy and force statistics before optimizing for a low energy
# state: the mean of the forces along axis 0 and the mean of the row-wise
# Euclidean norms (presumably one force row per particle -- confirm against
# `molecular_dynamics.force`)
initial_energy = energy(position, config.box_size)
forces = force(position, config.box_size)
initial_mean_forces = forces.mean(axis = 0)
initial_mean_fnorm = np.linalg.norm(forces, axis = 1).mean()

# optimize energy to find low energy system state using Conjugate-Gradients.
# `scipy.optimize.minimize` expects a one dimensional `x0`: recent SciPy
# versions raise a ValueError for anything else while older versions flattened
# it implicitly. Flatten the (nr_particles, 3) positions explicitly, the
# objective and its gradient receive the same flat vector either way.
optim = scipy.optimize.minimize(energy,                     # objective func.
                                jac = jit(grad(energy)),    # jacobian
                                x0 = np.ravel(position),    # initial position
                                args = (config.box_size, ), # further args
                                method = "CG")
# the result is used even if the optimizer did not converge (best effort), but
# do not hide that fact from a user who asked for verbose output
if config.verbose and not optim.success:
    print(f"Warning: energy minimization did not converge: {optim.message}")
# extract (and reshape) optimization result back to one row per particle
position = optim.x.reshape((config.nr_particles, 3))
# ensure all particles are in the box (wrap coordinates into [0, box_size))
position = np.mod(position, config.box_size)

# recompute stats after optimizing for low energy state
final_energy = energy(position, config.box_size)
forces = force(position, config.box_size)
final_mean_forces = forces.mean(axis = 0)
final_mean_fnorm = np.linalg.norm(forces, axis = 1).mean()

# write the state snapshot to file (the target path comes from the `-o` script
# argument or its default)
dump(config.output, position, velocity, config.box_size)

# print the before/after statistics, only if requested by the `-v` script
# argument
if config.verbose:
    report = (
        f"Initial Energy:          {initial_energy:.4e}",
        "Initial Mean Forces:     {:.4e} {:.4e} {:.4e}".format(*initial_mean_forces),
        f"Initial Mean ||Forces||: {initial_mean_fnorm:.4e}",
        f"Final Energy:            {final_energy:.4e}",
        "Final Mean Forces:       {:.4e} {:.4e} {:.4e}".format(*final_mean_forces),
        f"Final Mean ||Forces||:   {final_mean_fnorm:.4e}",
        f"Done: saved inital state to '{config.output}'",
    )
    for line in report:
        print(line)