NSSC/Exercise_02/state.py

86 lines
2.3 KiB
Python

from dataclasses import dataclass, field
from datetime import datetime
from math import e
from typing import List

import numpy as np
from jax import numpy as jnp
from scipy.optimize import minimize, Bounds
from scipy.spatial import distance_matrix as dm

from constants import alpha, D_e, r_e
@dataclass
class State:
    """Container for a molecular configuration.

    Holds the header values (``count``, ``size``), one row per molecule in
    ``molecules`` and a ``temperature`` value. ``serialize``/``deserialize``
    convert to and from a plain-text layout: count line, comment line, size
    line, then one whitespace-separated row of floats per molecule.
    """

    # number of molecules, as stored in the first header line
    count: int = 0
    # box size, as stored in the third header line
    size: float = 1.0
    # one row per molecule; columns 0-2 are the positions (see positions()).
    # The remaining columns are carried through unchanged; their meaning is
    # not visible in this class (TODO confirm).
    # default_factory gives each instance its own array: a bare ndarray default
    # would be shared by all instances and is rejected by dataclasses on 3.11+.
    molecules: np.matrix = field(default_factory=lambda: np.zeros((0, 6)))
    # not read by any method in this class
    temperature: float = 0

    def energy(self) -> float:
        """Return the module-level energy() of the current positions."""
        return energy(self.positions())

    def minimize(self) -> None:
        """Relax the positions in place.

        Minimizes the module-level energy() with scipy's conjugate-gradient
        method, using a 3-point finite-difference gradient. The optimum is
        written back into columns 0-2 of ``molecules``.
        """
        # TODO: supply an analytic gradient (e.g. via jax.grad) instead of
        # finite differences.
        res = minimize(energy, self.positions(), method='CG', jac='3-point')
        # Reshape by the actual row count (-1), not self.count, so a header
        # count that disagrees with the data cannot break the write-back.
        optimum = np.reshape(res.x, (-1, 3))
        self.molecules[:, 0:3] = optimum

    @staticmethod
    def deserialize(lines: List[str]) -> 'State':
        """Build a State from text lines in the layout written by serialize().

        Line 0 is the molecule count and line 1 a free-form comment
        (ignored). Line 2 is the size, and every following non-blank line is
        one whitespace-separated row of floats. The caller's list is not
        modified.

        Raises:
            IndexError: fewer than three header lines.
            ValueError: a header value or data token is not numeric.
        """
        lines = list(lines)  # work on a copy; never consume the caller's list
        count = int(lines[0])
        size = float(lines[2])
        rows = [[float(tok) for tok in line.split()]
                for line in lines[3:] if line.strip()]
        # np.matrix([]) would be a (1, 0) matrix; keep an empty body
        # shape-consistent with the dataclass default instead.
        molecules = np.matrix(rows) if rows else np.asmatrix(np.zeros((0, 6)))
        return State(count, size, molecules)

    def serialize(self) -> str:
        """Return the text form read back by deserialize().

        The output has three header lines (count, a timestamp comment, size),
        then one tab-separated line per molecule with values formatted as
        ``{:+10.8f}``. Every line, including the last, ends with a newline.
        """
        m = np.asmatrix(self.molecules)
        header = [f'{self.count}', f'Written at {datetime.now()}', f'{self.size}']
        body = ['\t'.join('{:+10.8f}'.format(f) for f in row) for row in m.tolist()]
        return '\n'.join(header + body) + '\n'

    def positions(self) -> np.ndarray:
        """Return columns 0-2 of all molecules as one flat 1-D array.

        The order is row-major: x0, y0, z0, x1, ...
        """
        # np.asarray drops the np.matrix subclass: matrix.ravel() returns a
        # (1, N) matrix, but scipy.optimize.minimize requires a 1-D x0.
        return np.asarray(self.molecules)[:, 0:3].ravel()
def energy(p: np.matrix):
    """Return the summed Morse pair energy of a set of 3-D coordinates.

    ``p`` is reshaped into rows of three coordinates. Any array-like whose
    size is a multiple of 3 is accepted, e.g. the flattened output of
    State.positions().

    The pair-energy matrix is symmetric, so only its strictly lower triangle
    is summed. This counts each pair once and drops the self-pairs on the
    diagonal. NaN entries are skipped by the sum. The result is a 0-d jax
    array.

    NOTE(review): jax computes in float32 unless 64-bit mode is enabled,
    which may limit finite-difference gradient accuracy; confirm whether
    that is acceptable.
    """
    coords = jnp.reshape(p, (-1, 3))
    pair_energy = _morse(distances(coords))
    below_diagonal = jnp.tril(pair_energy, k=-1)
    return jnp.nansum(below_diagonal)
def _morse(r):
    """Evaluate the Morse pair potential elementwise.

    V(r) = D_e * (exp(-2*alpha*(r - r_e)) - 2*exp(-alpha*(r - r_e)))

    ``r`` may be a scalar or an array of any shape; the expression broadcasts
    natively, so no vectorize/vmap wrapper is needed. Returns a jax array of
    the same shape.
    """
    # Compute the decaying exponential once: exp(-2a(r - r_e)) == x**2, so
    # V = D_e * (x**2 - 2x) = D_e * x * (x - 2).
    x = jnp.exp(-alpha * (jnp.asarray(r) - r_e))
    return D_e * x * (x - 2.0)
def distances(P, size: float = 1.0):
    """Return the matrix of minimum-image pair distances between points.

    Args:
        P: array-like of shape (N, 3), one point per row.
        size: edge length of the periodic box, applied to all three axes.
            Defaults to 1.0, the box length previously hard-coded here.

    Returns:
        (N, N) jax array whose entry [i, j] is the distance between points
        i and j under the minimum image convention. Zero distances (the
        diagonal and coincident points) are replaced by sqrt(1e-4) = 0.01.
        This presumably keeps downstream potentials and gradients finite;
        confirm with callers.
    """
    P = jnp.asarray(P)
    # (N, N, 3) array of coordinate differences P[i] - P[j]
    delta = P[:, None, :] - P[None, :, :]
    # Minimum image convention: shift every component by a whole number of
    # box lengths so it lies in [-size/2, size/2]. Subtracting a fixed 0.5
    # was wrong: 0.9 must map to -0.1, not 0.4, and separations larger than
    # one box length must wrap as well.
    delta = delta - size * jnp.round(delta / size)
    D = jnp.sum(delta ** 2, axis=-1)
    # jnp.where (rather than boolean-mask .at[] updates) stays jit/grad friendly
    D = jnp.where(D <= 0, 0.0001, D)
    return jnp.sqrt(D)