NSSC/Exercise_02/integrate.py

44 lines
1.2 KiB
Python

#!/usr/bin/env python
import os
import sys

import numpy as np
from jax import grad, numpy as jnp

from constants import m
from state import energy
from util import readFile
# from jax.config import config
# config.update("jax_debug_nans", True)
def _velocity_verlet_step(state, a_t, accel, timestep):
    """Advance ``state`` in place by one velocity-Verlet step.

    Columns 0:3 of ``state.molecules`` are treated as positions and columns
    3:6 as velocities.

    Args:
        state: simulation state whose ``molecules`` array is updated in place.
        a_t: acceleration at the current positions; must broadcast against
            the flattened velocity array (presumably shape ``(3*N,)`` --
            TODO confirm).
        accel: zero-argument callable returning the acceleration for the
            current contents of ``state``.
        timestep: integration time step.

    Returns:
        The acceleration evaluated at the updated positions, so the caller can
        reuse it as ``a_t`` next step (one force evaluation per step).
    """
    v_t = state.molecules[:, 3:6].ravel()
    # x(t+dt) = x(t) + v(t)*dt + a(t)*dt^2/2
    delta_x = v_t * timestep + 0.5 * a_t * timestep**2
    state.molecules[:, 0:3] += delta_x.reshape((-1, 3))
    # a(t+dt) is evaluated at the freshly updated (not yet wrapped) positions.
    a_tplus1 = accel()
    # v(t+dt) = v(t) + (a(t) + a(t+dt))*dt/2
    delta_v = (a_t + a_tplus1) / 2 * timestep
    state.molecules[:, 3:6] += delta_v.reshape((-1, 3))
    # Map positions into [0, state.size); presumably periodic boundaries.
    state.molecules[:, 0:3] = np.mod(state.molecules[:, 0:3], state.size)
    return a_tplus1


def main(argv):
    """Integrate a molecule system with velocity Verlet and record the trajectory.

    Reads the initial state from the ``.mol`` file named in ``argv[1]``,
    advances it ``argv[3]`` times with time step ``argv[2]``, and writes the
    serialized state after every step to ``./res/integrated.mols``.

    Args:
        argv: full command line, ``[program, mol_file, timestep, steps]``.

    Exits with status 2, printing usage to stderr, when the argument count
    is wrong.
    """
    if len(argv) != 4:
        print(f'{argv[0]} <input .mol file> <timestep> <number of steps>',
              file=sys.stderr)
        sys.exit(2)
    state = readFile(argv[1])
    timestep = float(argv[2])
    steps = int(argv[3])

    energy_grad = grad(energy)

    def accel():
        # a = F/m with F = -dE/dx. jax.grad differentiates w.r.t. the first
        # positional argument, i.e. the first value from state.positions().
        # NOTE(review): unless jax_enable_x64 is set, JAX computes this in
        # float32, which may lose precision versus a float64 state -- confirm.
        return -energy_grad(*state.positions()) / m

    out_path = './res/integrated.mols'
    # Create the output directory up front instead of failing on open().
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, 'w') as f:
        a_t = accel()
        for i in range(steps):
            if i % 100 == 0:
                print(f'Step {i}')
            a_t = _velocity_verlet_step(state, a_t, accel, timestep)
            f.write(state.serialize())
if __name__ == '__main__':
    # Script entry point: hand the raw command line (program name included)
    # to main(), which does its own argument-count check.
    main(sys.argv)