diff --git a/Exercise_02/constants.py b/Exercise_02/constants.py new file mode 100644 index 0000000..ccd77e8 --- /dev/null +++ b/Exercise_02/constants.py @@ -0,0 +1,6 @@ +D_e = 1.6 # eV +alpha = 3.028 # Å^−1 +r_e = 1.411 # Å +dt = 0.1 # s +k_B = 8.617333e-5 # eV/K +m = 18.998403 # amu \ No newline at end of file diff --git a/Exercise_02/generate_conditions.py b/Exercise_02/generate_conditions.py new file mode 100644 index 0000000..fc0827b --- /dev/null +++ b/Exercise_02/generate_conditions.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python +from math import sqrt +from util import readFile, writeFile +import sys +from getopt import getopt +from state import State +import numpy as np +from matplotlib import pyplot as plt +from constants import k_B, m + +def main(argv): + if len(argv) != 4: + print(f'{argv[0]} ') + sys.exit(2) + + state = State( + count=int(argv[1]), + size=float(argv[2]), + temperature=float(argv[3]) + ) + + state.molecules = np.hstack([np.random.rand(state.count, 3)*state.size, np.zeros((state.count, 3))]) + + state.minimize() + + sigma = sqrt(k_B*state.temperature/m) + state.molecules[:,3:6] = np.random.multivariate_normal([0,0,0], np.eye(3)*sigma, state.count) + + avg = np.sum(state.molecules[:,3:6], 0) + state.molecules[:,3:6] -= avg / state.count + + writeFile('./res/initial.mol', state) + + +if __name__ == '__main__': + main(sys.argv) \ No newline at end of file diff --git a/Exercise_02/integrate.py b/Exercise_02/integrate.py new file mode 100644 index 0000000..d0bd843 --- /dev/null +++ b/Exercise_02/integrate.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python +import sys +from util import readFile +from jax import grad, numpy as jnp +from constants import m +from state import energy +import numpy as np +# from jax.config import config +# config.update("jax_debug_nans", True) + +def main (argv): + if len(argv) != 4: + print(f'{argv[0]} ') + sys.exit(2) + + state = readFile(argv[1]) + timestep = float(argv[2]) + steps = int(argv[3]) + + energy_grad = grad(energy) + force = lambda: -energy_grad(*state.positions()) + a = lambda: force()/m + + with open('./res/integrated.mols', 'w') as f: + a_t = a() + for i in range(steps): + if i % 100 == 0: print(f'Step {i}') + v_t = state.molecules[:,3:6].ravel() + + # change in position + delta_x = v_t*timestep + 0.5*a_t*timestep**2 + state.molecules[:, 0:3] += delta_x.reshape((-1,3)) + + # change in velocity + a_tplus1 = a() + delta_v = (a_t + a_tplus1)/2*timestep + state.molecules[:, 3:6] += delta_v.reshape((-1,3)) + state.molecules[:, 0:3] = np.mod(state.molecules[:,0:3], state.size) + + a_t = a_tplus1 + f.write(state.serialize()) + +if __name__ == '__main__': + main(sys.argv) \ No newline at end of file diff --git a/Exercise_02/plot.py b/Exercise_02/plot.py new file mode 100644 index 0000000..28fbef5 --- /dev/null +++ b/Exercise_02/plot.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python +import sys +from state import State +from matplotlib import pyplot as plt + +def main(argv): + path = argv[1] + with open(path) as f: + lines = f.readlines() + count = int(lines[0]) + size = float(lines[2]) + split = count + 3 + states = [lines[i:i+split] for i in range(0, len(lines), split)] + states = [State.deserialize(state) for state in states] + + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + current = 0 + while True: + ax.cla() + plt.xlim([0,size]) + plt.ylim([0,size]) + plt.title(f'Step {current}') + ax.set_zlim([0,size]) + ax.scatter( + states[current].molecules[:,0], + states[current].molecules[:,1], + states[current].molecules[:,2], + c=range(count) + ) + plt.pause(0.02) + + current += 1 + current %= len(states) + + +if __name__ == '__main__': + main(sys.argv) \ No newline at end of file diff --git a/Exercise_02/state.py b/Exercise_02/state.py new file mode 100644 index 0000000..fbd1df3 --- /dev/null +++ b/Exercise_02/state.py @@ -0,0 +1,83 @@ +from dataclasses import dataclass +from math import e +from typing import List +import numpy as np +from constants import alpha, D_e, r_e +from scipy.optimize import minimize, Bounds +from scipy.spatial import distance_matrix as dm +from datetime import datetime +from jax import numpy as jnp + +@dataclass +class State: + count: int = 0 + size: float = 1.0 + molecules: np.matrix = np.zeros((0,6)) + temperature: float = 0 + + def energy(self) -> float: + return energy(self.positions()) + + def minimize(self) -> None: + res = minimize(energy, self.positions(), method='L-BFGS-B', bounds=Bounds(0,self.size), jac='3-point') + min = res.x + min = np.matrix(min) + min.shape = (self.count,3) + self.molecules[:,0:3] = min + + @staticmethod + def deserialize(lines: List[str]) -> None: + # disregard comment line + lines.pop(1) + + count = int(lines.pop(0)) + size = float(lines.pop(0)) + lines = (line.strip() for line in lines) + lines = (line.split() for line in lines if line) + molecules = np.matrix([[float(f) for f in line] for line in lines]) + + return State(count, size, molecules) + + def serialize(self) -> str: + serialized = '' + serialized += f'{self.count}\n' + serialized += f'Written at {datetime.now()}\n' + serialized += f'{self.size}\n' + + m = np.asmatrix(self.molecules) + for i in range(m.shape[0]): + serialized += '\t'.join(['{:+10.8f}'.format(f) for f in m[i,:].tolist()[0]]) + '\n' + + return serialized + + def positions(self) -> List[float]: + return self.molecules[:,0:3].ravel() + +def energy (p: np.matrix): + P = jnp.reshape(p, (-1,3)) + V = _morse(distances(P)) + return jnp.nansum(jnp.tril(V, k=-1)) + +@jnp.vectorize +def _morse(r: float): + return D_e * (e**(-2*alpha*(r-r_e)) - 2*e**(-alpha*(r - r_e))) + +def distances(P): + P = jnp.asarray(P) + + XYZ = jnp.asarray([ + # insert new axis with None to get an outer difference matrix for each dimension + P[:, 0, None] - P[None, :, 0], + P[:, 1, None] - P[None, :, 1], + P[:, 2, None] - P[None, :, 2] + ]) + + XYZ = abs(XYZ) + + # nearest image convention + XYZ = XYZ.at[XYZ > 0.5].add(-0.5) + XYZ = XYZ**2 + + D = jnp.sum(XYZ, axis=0) + D = D.at[D <= 0].set(0.0001) + return jnp.sqrt(D) diff --git a/Exercise_02/util.py b/Exercise_02/util.py new file mode 100644 index 0000000..ab25657 --- /dev/null +++ b/Exercise_02/util.py @@ -0,0 +1,12 @@ +from state import State + +def readFile (path: str) -> State: + with open(path) as f: + lines = f.readlines() + + return State.deserialize(lines) + + +def writeFile (path: str, state: State) -> None: + with open(path, 'w') as f: + f.write(state.serialize()) \ No newline at end of file