#!/usr/bin/env python
"""Integrate a molecular state with velocity Verlet and write a trajectory.

Reads an initial state file, advances it for a given number of steps with a
fixed timestep, and appends the serialized state after every step to an
output file.
"""
import sys
from util import readFile
from jax import grad, numpy as jnp
from constants import m
from state import energy
import numpy as np

# Debugging aid: uncomment to make JAX raise on the first NaN it produces.
# from jax.config import config
# config.update("jax_debug_nans", True)


def main(argv, out_path='./res/integrated.mols'):
    """Run the integrator from command-line arguments.

    Args:
        argv: ``[program, state_file, timestep, steps]`` (e.g. ``sys.argv``).
        out_path: File the trajectory is written to (overwritten). Each
            step's ``state.serialize()`` output is appended in order; the
            initial state itself is not written.

    Raises:
        SystemExit: with status 2 when ``argv`` does not have exactly
            three arguments after the program name.
    """
    if len(argv) != 4:
        print(f'usage: {argv[0]} <state-file> <timestep> <steps>',
              file=sys.stderr)
        sys.exit(2)

    state = readFile(argv[1])
    timestep = float(argv[2])
    steps = int(argv[3])

    # jax.grad differentiates w.r.t. the first positional argument only;
    # NOTE(review): assumes the first value returned by state.positions()
    # holds every coordinate, flattened to match molecules[:, 0:3].ravel().
    energy_grad = grad(energy)

    def acceleration():
        """Return a = F / m with F = -dE/dx at the current positions."""
        # presumably m is a scalar or broadcastable per-coordinate mass.
        return -energy_grad(*state.positions()) / m

    with open(out_path, 'w') as f:
        a_t = acceleration()
        for i in range(steps):
            if i % 100 == 0:
                print(f'Step {i}')

            # Columns 0:3 hold positions and 3:6 hold velocities.
            # Position update: x(t+dt) = x(t) + v(t) dt + 1/2 a(t) dt^2.
            v_t = state.molecules[:, 3:6].ravel()
            delta_x = v_t * timestep + 0.5 * a_t * timestep**2
            state.molecules[:, 0:3] += delta_x.reshape((-1, 3))

            # Wrap into the periodic box (state.size, presumably the box
            # length) *before* evaluating the new force, so a(t+dt) is
            # computed at exactly the positions that are stored, written
            # out, and carried into the next step.
            state.molecules[:, 0:3] = np.mod(state.molecules[:, 0:3],
                                             state.size)

            # Velocity update uses the mean of the old and new accelerations.
            a_tplus1 = acceleration()
            delta_v = (a_t + a_tplus1) / 2 * timestep
            state.molecules[:, 3:6] += delta_v.reshape((-1, 3))

            a_t = a_tplus1
            f.write(state.serialize())


if __name__ == '__main__':
    main(sys.argv)