44 lines
1.2 KiB
Python
44 lines
1.2 KiB
Python
#!/usr/bin/env python
|
|
import sys
|
|
from util import readFile
|
|
from jax import grad, numpy as jnp
|
|
from constants import m
|
|
from state import energy
|
|
import numpy as np
|
|
# from jax.config import config
|
|
# config.update("jax_debug_nans", True)
|
|
|
|
def main(argv):
    """Integrate the system with velocity Verlet and write every frame.

    Usage: ``<prog> <input .mol file> <timestep> <number of steps>``.

    The state read from the input file is advanced ``steps`` times with step
    size ``timestep``. After each step the serialized state is written to
    ``./res/integrated.mols`` (the ``res`` directory is created if missing).

    Args:
        argv: Full command-line vector, program name included (``sys.argv``).

    Raises:
        SystemExit: Status 2 when the argument count is wrong or the timestep
            or step count cannot be parsed as a number.
    """
    import os

    prog = argv[0] if argv else 'main'
    usage = f'{prog} <input .mol file> <timestep> <number of steps>'

    if len(argv) != 4:
        print(usage, file=sys.stderr)
        sys.exit(2)

    # Parse the numeric arguments before reading the input file, so a typo
    # fails fast with the usage message instead of a traceback.
    try:
        timestep = float(argv[2])
        steps = int(argv[3])
    except ValueError:
        print(usage, file=sys.stderr)
        sys.exit(2)

    state = readFile(argv[1])

    energy_grad = grad(energy)

    def acceleration():
        """Return F / m, where F = -dE/dx is evaluated at the current state."""
        # jax.grad differentiates only w.r.t. the first positional argument
        # (argnums=0) of whatever state.positions() unpacks to.
        # NOTE(review): assumes state.positions() reflects in-place edits to
        # state.molecules[:, 0:3] -- confirm in the state module.
        force = -energy_grad(*state.positions())
        # `m` comes from the constants module (scalar or broadcastable array).
        return force / m

    os.makedirs('./res', exist_ok=True)
    with open('./res/integrated.mols', 'w', encoding='utf-8') as f:
        a_t = acceleration()

        for i in range(steps):
            if i % 100 == 0:
                print(f'Step {i}')

            # Columns 3:6 hold velocities; flatten to match the acceleration.
            v_t = state.molecules[:, 3:6].ravel()

            # Position update: x(t+dt) = x(t) + v(t) dt + a(t) dt^2 / 2.
            delta_x = v_t * timestep + 0.5 * a_t * timestep ** 2
            state.molecules[:, 0:3] += delta_x.reshape((-1, 3))

            # Wrap into the periodic box *before* evaluating the new force.
            # Every force evaluation and every written frame then uses the
            # same in-box coordinates.
            state.molecules[:, 0:3] = np.mod(state.molecules[:, 0:3], state.size)

            # Velocity update: v(t+dt) = v(t) + (a(t) + a(t+dt)) dt / 2.
            a_tplus1 = acceleration()
            delta_v = (a_t + a_tplus1) / 2 * timestep
            state.molecules[:, 3:6] += delta_v.reshape((-1, 3))

            a_t = a_tplus1
            f.write(state.serialize())
|
|
|
|
if __name__ == "__main__":
    # Script entry point: main() expects the complete argv, program name included.
    main(argv=sys.argv)