Task 2 Setup
This commit is contained in:
parent
cde4ab6a17
commit
a3d4d8a2c6
|
@ -0,0 +1,6 @@
|
|||
D_e = 1.6 # eV
|
||||
alpha = 3.028 # Å^−1
|
||||
r_e = 1.411 # Å
|
||||
dt = 0.1 # s
|
||||
k_B = 8.617333e-5 # eV/K
|
||||
m = 18.998403 # amu
|
|
@ -0,0 +1,36 @@
|
|||
#!/usr/bin/env python
|
||||
from math import sqrt
|
||||
from util import readFile, writeFile
|
||||
import sys
|
||||
from getopt import getopt
|
||||
from state import State
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
from constants import k_B, m
|
||||
|
||||
def main(argv):
|
||||
if len(argv) != 4:
|
||||
print(f'{argv[0]} <number of particles> <side length> <temperature>')
|
||||
sys.exit(2)
|
||||
|
||||
state = State(
|
||||
count=int(argv[1]),
|
||||
size=float(argv[2]),
|
||||
temperature=float(argv[3])
|
||||
)
|
||||
|
||||
state.molecules = np.hstack([np.random.rand(state.count, 3)*state.size, np.zeros((state.count, 3))])
|
||||
|
||||
state.minimize()
|
||||
|
||||
sigma = sqrt(k_B*state.temperature/m)
|
||||
state.molecules[:,3:6] = np.random.multivariate_normal([0,0,0], np.eye(3)*sigma, state.count)
|
||||
|
||||
avg = np.sum(state.molecules[:,3:6], 0)
|
||||
state.molecules[:,3:6] -= avg / state.count
|
||||
|
||||
writeFile('./res/initial.mol', state)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main(sys.argv)
|
|
@ -0,0 +1,44 @@
|
|||
#!/usr/bin/env python
|
||||
import sys
|
||||
from util import readFile
|
||||
from jax import grad, numpy as jnp
|
||||
from constants import m
|
||||
from state import energy
|
||||
import numpy as np
|
||||
# from jax.config import config
|
||||
# config.update("jax_debug_nans", True)
|
||||
|
||||
def main (argv):
|
||||
if len(argv) != 4:
|
||||
print(f'{argv[0]} <input .mol file> <timestep> <number of steps>')
|
||||
sys.exit(2)
|
||||
|
||||
state = readFile(argv[1])
|
||||
timestep = float(argv[2])
|
||||
steps = int(argv[3])
|
||||
|
||||
energy_grad = grad(energy)
|
||||
force = lambda: -energy_grad(*state.positions())
|
||||
a = lambda: force()/m
|
||||
|
||||
with open('./res/integrated.mols', 'w') as f:
|
||||
a_t = a()
|
||||
for i in range(steps):
|
||||
if i % 100 == 0: print(f'Step {i}')
|
||||
v_t = state.molecules[:,3:6].ravel()
|
||||
|
||||
# change in position
|
||||
delta_x = v_t*timestep + 0.5*a_t*timestep**2
|
||||
state.molecules[:, 0:3] += delta_x.reshape((-1,3))
|
||||
|
||||
# change in velocity
|
||||
a_tplus1 = a()
|
||||
delta_v = (a_t + a_tplus1)/2*timestep
|
||||
state.molecules[:, 3:6] += delta_v.reshape((-1,3))
|
||||
state.molecules[:, 0:3] = np.mod(state.molecules[:,0:3], state.size)
|
||||
|
||||
a_t = a_tplus1
|
||||
f.write(state.serialize())
|
||||
|
||||
if __name__ == '__main__':
|
||||
main(sys.argv)
|
|
@ -0,0 +1,38 @@
|
|||
#!/usr/bin/env python
|
||||
import sys
|
||||
from state import State
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
def main(argv):
|
||||
path = argv[1]
|
||||
with open(path) as f:
|
||||
lines = f.readlines()
|
||||
count = int(lines[0])
|
||||
size = float(lines[2])
|
||||
split = count + 3
|
||||
states = [lines[i:i+split] for i in range(0, len(lines), split)]
|
||||
states = [State.deserialize(state) for state in states]
|
||||
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111, projection='3d')
|
||||
current = 0
|
||||
while True:
|
||||
ax.cla()
|
||||
plt.xlim([0,size])
|
||||
plt.ylim([0,size])
|
||||
plt.title(f'Step {current}')
|
||||
ax.set_zlim([0,size])
|
||||
ax.scatter(
|
||||
states[current].molecules[:,0],
|
||||
states[current].molecules[:,1],
|
||||
states[current].molecules[:,2],
|
||||
c=range(count)
|
||||
)
|
||||
plt.pause(0.02)
|
||||
|
||||
current += 1
|
||||
current %= len(states)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main(sys.argv)
|
|
@ -0,0 +1,83 @@
|
|||
from dataclasses import dataclass
|
||||
from math import e
|
||||
from typing import List
|
||||
import numpy as np
|
||||
from constants import alpha, D_e, r_e
|
||||
from scipy.optimize import minimize, Bounds
|
||||
from scipy.spatial import distance_matrix as dm
|
||||
from datetime import datetime
|
||||
from jax import numpy as jnp
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
count: int = 0
|
||||
size: float = 1.0
|
||||
molecules: np.matrix = np.zeros((0,6))
|
||||
temperature: float = 0
|
||||
|
||||
def energy(self) -> float:
|
||||
return energy(self.positions())
|
||||
|
||||
def minimize(self) -> None:
|
||||
res = minimize(energy, self.positions(), method='L-BFGS-B', bounds=Bounds(0,self.size), jac='3-point')
|
||||
min = res.x
|
||||
min = np.matrix(min)
|
||||
min.shape = (self.count,3)
|
||||
self.molecules[:,0:3] = min
|
||||
|
||||
@staticmethod
|
||||
def deserialize(lines: List[str]) -> None:
|
||||
# disregard comment line
|
||||
lines.pop(1)
|
||||
|
||||
count = int(lines.pop(0))
|
||||
size = float(lines.pop(0))
|
||||
lines = (line.strip() for line in lines)
|
||||
lines = (line.split() for line in lines if line)
|
||||
molecules = np.matrix([[float(f) for f in line] for line in lines])
|
||||
|
||||
return State(count, size, molecules)
|
||||
|
||||
def serialize(self) -> str:
|
||||
serialized = ''
|
||||
serialized += f'{self.count}\n'
|
||||
serialized += f'Written at {datetime.now()}\n'
|
||||
serialized += f'{self.size}\n'
|
||||
|
||||
m = np.asmatrix(self.molecules)
|
||||
for i in range(m.shape[0]):
|
||||
serialized += '\t'.join(['{:+10.8f}'.format(f) for f in m[i,:].tolist()[0]]) + '\n'
|
||||
|
||||
return serialized
|
||||
|
||||
def positions(self) -> List[float]:
|
||||
return self.molecules[:,0:3].ravel()
|
||||
|
||||
def energy (p: np.matrix):
|
||||
P = jnp.reshape(p, (-1,3))
|
||||
V = _morse(distances(P))
|
||||
return jnp.nansum(jnp.tril(V, k=-1))
|
||||
|
||||
@jnp.vectorize
|
||||
def _morse(r: float):
|
||||
return D_e * (e**(-2*alpha*(r-r_e)) - 2*e**(-alpha*(r - r_e)))
|
||||
|
||||
def distances(P):
|
||||
P = jnp.asarray(P)
|
||||
|
||||
XYZ = jnp.asarray([
|
||||
# insert new axis with None to get an outer difference matrix for each dimension
|
||||
P[:, 0, None] - P[None, :, 0],
|
||||
P[:, 1, None] - P[None, :, 1],
|
||||
P[:, 2, None] - P[None, :, 2]
|
||||
])
|
||||
|
||||
XYZ = abs(XYZ)
|
||||
|
||||
# nearest image convention
|
||||
XYZ = XYZ.at[XYZ > 0.5].add(-0.5)
|
||||
XYZ = XYZ**2
|
||||
|
||||
D = jnp.sum(XYZ, axis=0)
|
||||
D = D.at[D <= 0].set(0.0001)
|
||||
return jnp.sqrt(D)
|
|
@ -0,0 +1,12 @@
|
|||
from state import State
|
||||
|
||||
def readFile (path: str) -> State:
|
||||
with open(path) as f:
|
||||
lines = f.readlines()
|
||||
|
||||
return State.deserialize(lines)
|
||||
|
||||
|
||||
def writeFile (path: str, state: State) -> None:
|
||||
with open(path, 'w') as f:
|
||||
f.write(state.serialize())
|
Loading…
Reference in New Issue