from dataclasses import dataclass, field
from datetime import datetime
from math import e
from typing import List

import numpy as np
from jax import numpy as jnp
from scipy.optimize import minimize, Bounds
from scipy.spatial import distance_matrix as dm

from constants import alpha, D_e, r_e


@dataclass
class State:
    """Snapshot of a particle system that can be relaxed and (de)serialized.

    Attributes are documented inline below.
    """

    # Number of particles; used to reshape the optimizer result to (count, 3).
    count: int = 0
    # Written to / read from the third header line of the serialized form.
    # NOTE(review): not used by energy()/distances() in this file.
    size: float = 1.0
    # One row per particle. Columns 0..2 are the coordinates used by
    # positions(); the remaining columns are not interpreted in this file
    # (presumably velocities -- confirm against the writer of the data).
    # A default_factory is required: an ndarray is unhashable and is rejected
    # as a plain dataclass default on Python 3.11+.
    molecules: np.matrix = field(default_factory=lambda: np.zeros((0, 6)))
    temperature: float = 0

    def energy(self) -> float:
        """Return the potential energy of the current positions."""
        # Resolves to the module-level energy(); methods do not see class scope.
        return energy(self.positions())

    def minimize(self) -> None:
        """Relax the coordinates in place with a conjugate-gradient search.

        Raises:
            ValueError: if the optimizer result cannot be reshaped to
                ``(self.count, 3)``.
        """
        # Original note: "report with jax gradient". The gradient is currently
        # approximated with 3-point finite differences.
        result = minimize(energy, self.positions(), method='CG', jac='3-point')
        relaxed = np.asarray(result.x).reshape(self.count, 3)
        self.molecules[:, 0:3] = relaxed

    @staticmethod
    def deserialize(lines: List[str]) -> "State":
        """Build a State from the lines produced by serialize().

        Expected layout: line 0 is the particle count, line 1 is a free-form
        comment (ignored), line 2 is the size, and every following non-blank
        line is one whitespace-separated row of floats.

        The input list is not modified.

        Raises:
            IndexError: if fewer than three header lines are supplied.
            ValueError: if a header or data field is not numeric.
        """
        count_text, size_text = lines[0], lines[2]
        count = int(count_text)
        size = float(size_text)

        rows = [line.split() for line in lines[3:] if line.strip()]
        if rows:
            molecules = np.matrix([[float(tok) for tok in row] for row in rows])
        else:
            # Keep the same empty shape as the dataclass default so an empty
            # state survives a serialize/deserialize round trip.
            molecules = np.asmatrix(np.zeros((0, 6)))
        return State(count, size, molecules)

    def serialize(self) -> str:
        """Return the textual form: count, timestamp comment, size, then rows.

        Each row is tab-separated with every value formatted as ``+10.8f``;
        every line, including the last, ends with a newline.
        """
        header = [
            f'{self.count}',
            f'Written at {datetime.now()}',
            f'{self.size}',
        ]
        rows = [
            '\t'.join('{:+10.8f}'.format(value) for value in row)
            for row in np.asmatrix(self.molecules).tolist()
        ]
        return ''.join(f'{line}\n' for line in header + rows)

    def positions(self) -> np.ndarray:
        """Return the coordinates (columns 0..2) flattened to a 1-D array.

        The result is laid out as ``x0, y0, z0, x1, y1, z1, ...``. It is a
        plain ndarray rather than a matrix because scipy's ``minimize``
        requires a one-dimensional start vector.
        """
        return np.asarray(self.molecules[:, 0:3]).ravel()


def energy(p: np.matrix):
    """Return the total pair energy for flattened coordinates ``p``.

    ``p`` is reshaped to ``(-1, 3)``; the Morse potential is evaluated on the
    full pairwise distance matrix and only the strict lower triangle is
    summed, so each pair contributes once and the diagonal is excluded.
    NaN entries are ignored by ``nansum``.
    """
    P = jnp.reshape(p, (-1, 3))
    V = _morse(distances(P))
    return jnp.nansum(jnp.tril(V, k=-1))


@jnp.vectorize
def _morse(r: float):
    """Morse pair potential at separation ``r`` (uses D_e, alpha, r_e)."""
    return D_e * (e**(-2*alpha*(r-r_e)) - 2*e**(-alpha*(r - r_e)))


def distances(P):
    """Return the ``(n, n)`` matrix of periodic pairwise distances.

    Args:
        P: array-like of shape ``(n, 3+)``; only the first three columns are
            used.

    The minimum-image convention is applied per axis with period 1, i.e. the
    coordinates are assumed to be expressed in units of the box length
    (TODO confirm how ``State.size`` is meant to enter here). Entries whose
    squared distance is not positive (the diagonal, or coincident particles)
    are replaced by ``0.0001`` before the square root, giving ``0.01``
    instead of ``0``.
    """
    xyz = jnp.asarray(P)[:, :3]
    # Outer difference: delta[i, j, k] = xyz[i, k] - xyz[j, k].
    delta = xyz[:, None, :] - xyz[None, :, :]
    # Minimum image for a unit period: wrap every component into [-0.5, 0.5].
    # The previous code subtracted 0.5 from separations above 0.5, which is
    # not the nearest-image distance (0.9 became 0.4 instead of 0.1).
    delta = delta - jnp.round(delta)
    squared = jnp.sum(delta ** 2, axis=-1)
    # jnp.where instead of a boolean-mask update keeps this traceable by jax.
    squared = jnp.where(squared <= 0, 0.0001, squared)
    return jnp.sqrt(squared)