2022-04-10 09:17:15 +00:00
|
|
|
from dataclasses import dataclass, field
from datetime import datetime
from math import e
from typing import List

import numpy as np
from jax import numpy as jnp
from scipy.optimize import minimize, Bounds
from scipy.spatial import distance_matrix as dm

from constants import alpha, D_e, r_e
|
|
|
|
|
|
|
|
@dataclass
class State:
    """Snapshot of a particle system: count, box size and per-molecule rows.

    The text format handled by ``serialize`` / ``deserialize`` is::

        <count>
        <free-form comment line>
        <size>
        <one whitespace-separated row of floats per molecule>
    """

    count: int = 0
    size: float = 1.0
    # One row per molecule; columns 0..2 are the coordinates returned by
    # ``positions``.  NOTE(review): the default is six columns wide; the
    # meaning of columns 3..5 is not visible here -- confirm with the caller.
    # A default_factory is required: a bare ndarray default is an unhashable
    # mutable default, which dataclasses reject on Python 3.11+ and which
    # would otherwise be shared by every instance.
    molecules: np.matrix = field(default_factory=lambda: np.zeros((0, 6)))
    temperature: float = 0

    def energy(self) -> float:
        """Return the module-level ``energy`` evaluated at the current positions."""
        return energy(self.positions())

    def minimize(self) -> None:
        """Minimise ``energy`` over the positions and store the result in place.

        Runs SciPy's conjugate-gradient minimiser with a 3-point
        finite-difference gradient and overwrites columns 0..2 of
        ``molecules`` with the optimised coordinates.
        """
        # TODO: the original note here reads "report with jax gradient" --
        # presumably the finite-difference jac should be replaced by a jax
        # gradient; confirm.
        # ``positions`` may return a 2-D (1, 3N) np.matrix; SciPy requires a
        # one-dimensional starting point, so flatten explicitly.
        start = np.asarray(self.positions()).ravel()
        result = minimize(energy, start, method='CG', jac='3-point')
        self.molecules[:, 0:3] = np.asarray(result.x).reshape(-1, 3)

    @staticmethod
    def deserialize(lines: List[str]) -> 'State':
        """Build a ``State`` from the lines produced by ``serialize``.

        Args:
            lines: count line, comment line, size line, then molecule rows.
                Blank rows are ignored.  The list is not modified.

        Returns:
            A new ``State`` with ``count``, ``size`` and ``molecules`` set.

        Raises:
            IndexError: if fewer than three header lines are present.
            ValueError: if a number cannot be parsed.
        """
        count = int(lines[0])
        # lines[1] is the comment line and is disregarded
        size = float(lines[2])

        rows = [
            [float(token) for token in line.split()]
            for line in lines[3:]
            if line.strip()
        ]
        # np.matrix([]) would yield a (1, 0) matrix; keep the empty system
        # shaped like the dataclass default instead.
        molecules = np.matrix(rows) if rows else np.matrix(np.zeros((0, 6)))

        return State(count, size, molecules)

    def serialize(self) -> str:
        """Return the textual form of this state, one line per record.

        The comment line carries the current local time, so two calls do not
        produce byte-identical output.
        """
        out = [
            f'{self.count}',
            f'Written at {datetime.now()}',
            f'{self.size}',
        ]
        for row in np.asmatrix(self.molecules).tolist():
            out.append('\t'.join('{:+10.8f}'.format(value) for value in row))
        # every record, including the last, is newline-terminated
        return ''.join(record + '\n' for record in out)

    def positions(self) -> List[float]:
        """Return columns 0..2 of ``molecules`` flattened in row-major order.

        When ``molecules`` is an ``np.matrix`` the result is itself a matrix
        of shape (1, 3N) rather than a 1-D array.
        """
        return self.molecules[:, 0:3].ravel()
|
|
|
|
|
|
|
|
def energy(p: np.matrix):
    """Sum the pair potential over every unordered pair of points.

    ``p`` is reshaped to rows of three coordinates, the pairwise distance
    matrix is passed through ``_morse``, and the strictly lower triangle is
    summed with NaN entries ignored.
    """
    coords = jnp.reshape(p, (-1, 3))
    pair_potential = _morse(distances(coords))
    # k=-1 drops the diagonal and the upper triangle, so each pair is counted
    # once and self-interactions are excluded
    lower_triangle = jnp.tril(pair_potential, k=-1)
    return jnp.nansum(lower_triangle)
|
|
|
|
|
|
|
|
@jnp.vectorize
def _morse(r: float):
    """Evaluate ``D_e * (e**(-2*alpha*(r - r_e)) - 2*e**(-alpha*(r - r_e)))``.

    ``D_e``, ``alpha`` and ``r_e`` are the constants imported at module level;
    the decorator applies the scalar formula element-wise to array input.
    """
    displacement = r - r_e
    first_term = e**(-2*alpha*displacement)
    second_term = 2*e**(-alpha*displacement)
    return D_e * (first_term - second_term)
|
|
|
|
|
|
|
|
def distances(P, box: float = 1.0):
    """Return the matrix of pairwise minimum-image distances between points.

    Args:
        P: array-like of shape (N, 3) or wider; only columns 0..2 are used.
        box: edge length of the periodic cubic box.  Defaults to 1.0, the
            unit box the previous hard-coded 0.5 cutoff implied.

    Returns:
        An (N, N) array whose entry [i, j] is the distance between point i
        and the nearest periodic image of point j.  Entries whose squared
        distance is not positive (the diagonal and coincident points) are
        reported as sqrt(0.0001) instead of 0.
    """
    P = jnp.asarray(P)

    # insert new axes with None to get an outer difference for each dimension:
    # delta[i, j, k] = P[i, k] - P[j, k]
    delta = P[:, None, :3] - P[None, :, :3]

    # nearest image convention: remove whole box lengths so that every
    # component lies within half a box of zero.  (Subtracting 0.5 from
    # components above 0.5, as was done before, maps a separation of 0.9 to
    # 0.4 instead of the correct 0.1.)
    delta = delta - box * jnp.round(delta / box)

    squared = jnp.sum(delta**2, axis=-1)
    # jnp.where rather than boolean-mask indexing keeps the function traceable
    squared = jnp.where(squared <= 0, 0.0001, squared)
    return jnp.sqrt(squared)
|