44 lines
		
	
	
		
			1.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			44 lines
		
	
	
		
			1.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
#!/usr/bin/env python
 | 
						|
import sys
 | 
						|
from util import readFile
 | 
						|
from jax import grad, numpy as jnp
 | 
						|
from constants import m
 | 
						|
from state import energy
 | 
						|
import numpy as np
 | 
						|
# from jax.config import config
 | 
						|
# config.update("jax_debug_nans", True)
 | 
						|
 | 
						|
def main(argv, out_path='./res/integrated.mols'):
  """Integrate the system with velocity Verlet and write one frame per step.

  Usage: ``<script> <input .mol file> <timestep> <number of steps>``.

  The initial state is loaded with ``readFile``. Accelerations are
  ``-grad(energy)(*state.positions()) / m``; note that ``jax.grad``
  differentiates with respect to the first positional argument only. After
  every step ``state.serialize()`` is written to ``out_path``.

  Args:
    argv: command-line vector. ``argv[1]`` is the input file, ``argv[2]``
      the timestep (parsed as float), ``argv[3]`` the number of steps
      (parsed as int).
    out_path: file that receives one ``state.serialize()`` record per step.
      Its parent directory is created if it does not exist. The default is
      the path the script has always used.

  Exits with status 2 (usage message on stderr) unless ``argv`` has exactly
  four entries.
  """
  import os

  if len(argv) != 4:
    print(f'{argv[0]} <input .mol file> <timestep> <number of steps>',
          file=sys.stderr)
    sys.exit(2)

  state = readFile(argv[1])
  timestep = float(argv[2])
  steps = int(argv[3])

  energy_grad = grad(energy)

  def acceleration():
    """Return F/m with F = -dE/d(first arg), evaluated at the current state."""
    return -energy_grad(*state.positions()) / m

  # The original crashed with FileNotFoundError when ./res was missing.
  out_dir = os.path.dirname(out_path)
  if out_dir:
    os.makedirs(out_dir, exist_ok=True)

  with open(out_path, 'w') as f:
    a_t = acceleration()
    for i in range(steps):
      if i % 100 == 0:
        print(f'Step {i}')

      # Columns 3:6 of state.molecules hold the velocities; flatten them so
      # they line up element-wise with the acceleration vector.
      v_t = state.molecules[:, 3:6].ravel()

      # Position update: x(t+dt) = x(t) + v(t)*dt + 0.5*a(t)*dt^2
      delta_x = v_t * timestep + 0.5 * a_t * timestep**2
      state.molecules[:, 0:3] += delta_x.reshape((-1, 3))
      # Wrap positions into [0, state.size) BEFORE evaluating the new force.
      # Otherwise a(t+dt) is computed at unwrapped coordinates but reused as
      # a(t) in the next step, where the stored positions are already wrapped.
      state.molecules[:, 0:3] = np.mod(state.molecules[:, 0:3], state.size)

      # Velocity update: v(t+dt) = v(t) + 0.5*(a(t) + a(t+dt))*dt
      a_tplus1 = acceleration()
      delta_v = (a_t + a_tplus1) / 2 * timestep
      state.molecules[:, 3:6] += delta_v.reshape((-1, 3))

      a_t = a_tplus1
      f.write(state.serialize())

if __name__ == "__main__":
  # Script entry point: forward the raw command-line vector unchanged.
  main(sys.argv)