"""
ML interatomic potential runners.
Two backends:
- **torchsim** (preferred): GPU-accelerated batch processing via ``torch_sim``.
Handles all structures in a single batch for maximum throughput.
- **ASE** (fallback): Uses any ASE-compatible calculator with ASE optimizers.
Processes structures one at a time; works without torchsim.
The ``torch_sim`` package is an optional dependency (``pip install airsspy[ml]``).
When unavailable, the ASE fallback is used automatically.
Model specification format (torchsim)::
backend:model_id
For example ``mace:medium``, ``mace:/path/to/model.pt``, ``sevennet:sevennet-mf-ompa``.
ASE calculator specification format::
module.path:ClassName@model
For example ``mace.calculators:MACECalculator@medium``.
"""
import importlib
import logging
import tempfile
from pathlib import Path
from typing import Optional, Union
import numpy as np
from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator
from ase.io import read as ase_read
from ase.io import write as ase_write
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# 1 eV/Ang^3 = 160.21766208 GPa
# Multiply a value in eV/Ang^3 by this factor to obtain GPa; divide a value
# in GPa by it to obtain eV/Ang^3.
EV_PER_ANG3_TO_GPA = 160.21766208
# A structure handed to the runners: either a string (parsed as CASTEP cell
# content) or an existing ASE ``Atoms`` object.
StructureInput = Union[str, Atoms]
def _resolve_calculator(calculator_spec: str, **kwargs):
    """Build an ASE calculator instance described by a spec string.

    Args:
        calculator_spec: Specification of the form
            ``module.path:ClassName@model``. The ``@model`` suffix is
            optional; when present it becomes the first positional argument
            of the constructor. ``module.path.ClassName`` (no colon) is also
            accepted, in which case the last dotted segment is the class.
        **kwargs: Extra keyword arguments forwarded to the constructor.

    Returns:
        The object produced by calling the resolved class.

    Raises:
        ValueError: If the spec has neither a ``:`` nor a ``.`` separating
            module and class.
    """
    # Peel the optional "@model" suffix off at the last "@".
    spec, at_sign, model_arg = calculator_spec.rpartition("@")
    if not at_sign:
        spec, model_arg = calculator_spec, None

    # Prefer the explicit "module:Class" form; otherwise fall back to treating
    # the last dot-separated segment as the class name.
    module_path, separator, class_name = spec.rpartition(":")
    if not separator:
        module_path, separator, class_name = spec.rpartition(".")
    if not separator:
        raise ValueError(
            f"Cannot parse calculator spec: {spec!r}. "
            "Use 'module.path:ClassName' or 'module.path.ClassName'."
        )

    factory = getattr(importlib.import_module(module_path), class_name)
    positional = () if model_arg is None else (model_arg,)
    return factory(*positional, **kwargs)
def _cell_to_atoms(cell_path: str):
    """Load the structure stored in a ``.cell`` file as a periodic ASE Atoms.

    Args:
        cell_path: Path of the file handed to ``ase.io.read``.

    Returns:
        The object that was read, with ``pbc`` set to ``True``.
    """
    structure = ase_read(cell_path)
    # Force periodic boundary conditions regardless of what the reader set.
    structure.pbc = True
    return structure
def _cell_content_to_atoms(cell_content: str):
    """Read CASTEP cell content into ASE Atoms without leaving a .cell file.

    The text is written to a temporary ``airsspy-ml-*.cell`` file, parsed via
    ``_cell_to_atoms`` and the file is removed again, whether or not writing
    or parsing succeeded.

    Args:
        cell_content: Text of the cell file.

    Returns:
        The Atoms object produced by ``_cell_to_atoms``.
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            "w", suffix=".cell", prefix="airsspy-ml-", delete=False
        ) as handle:
            # Record the path *before* writing: the file already exists on
            # disk (delete=False), so a failing write must still be cleaned
            # up by the ``finally`` clause below.
            tmp_path = handle.name
            handle.write(cell_content)
        # The file is closed here, so it can be reopened by the reader.
        return _cell_to_atoms(tmp_path)
    finally:
        if tmp_path is not None:
            Path(tmp_path).unlink(missing_ok=True)
def _structure_input_to_atoms(structure_input: StructureInput):
    """Turn a structure input into a periodic ASE Atoms object.

    Args:
        structure_input: Either an ``Atoms`` instance or a string that is
            parsed as cell content.

    Returns:
        For ``Atoms`` input, a copy with ``pbc`` set to ``True`` (the caller's
        object is not modified); otherwise the result of parsing the text.
    """
    if not isinstance(structure_input, Atoms):
        return _cell_content_to_atoms(structure_input)
    periodic_copy = structure_input.copy()
    periodic_copy.pbc = True
    return periodic_copy
def _get_pressure_gpa(atoms) -> float:
    """Extract scalar pressure in GPa from an ASE Atoms with stress.

    ASE stress is in eV/Ang^3 (Voigt notation).
    Hydrostatic pressure = -trace(stress) / 3, converted to GPa.

    Args:
        atoms: Object exposing ``get_stress()``.

    Returns:
        The pressure in GPa, or ``0.0`` when the stress cannot be obtained
        or processed (best-effort: the failure is logged, not raised).
    """
    try:
        stress = atoms.get_stress()  # eV/Ang^3
        # The first three Voigt components are the diagonal of the tensor.
        pressure_ev_ang3 = -(stress[0] + stress[1] + stress[2]) / 3.0
        return pressure_ev_ang3 * EV_PER_ANG3_TO_GPA
    except Exception as exc:
        # Deliberate fallback, but leave a trace for debugging instead of
        # swallowing the reason silently.
        logger.debug("Stress unavailable; reporting zero pressure: %s", exc)
        return 0.0
def _normalize_static_forces(forces):
    """Return forces for one static-batch structure.

    A rank-3 value whose leading axis has length 1 is unwrapped to its single
    entry; ``None`` and everything else is passed through unchanged.
    """
    if forces is None:
        return None
    has_singleton_leading_axis = (
        getattr(forces, "ndim", 0) == 3 and forces.shape[0] == 1
    )
    return forces[0] if has_singleton_leading_axis else forces
def _normalize_static_stress(stress, index: int):
    """Return stress for one static-batch structure.

    Args:
        stress: ``None``, or a value that may carry a leading batch axis.
        index: Position of the structure within the batch; only used when
            ``stress`` is rank 3 with more than one leading entry.

    Returns:
        ``None`` for ``None``; for rank-3 input the single entry (leading
        length 1) or entry ``index``; any other input unchanged.
    """
    if stress is None:
        return None
    if getattr(stress, "ndim", 0) != 3:
        return stress
    row = 0 if stress.shape[0] == 1 else index
    return stress[row]
def _as_numpy_array(value):
    """Convert torch/array-like values from torch-sim to numpy.

    Objects exposing ``detach`` are converted through
    ``detach().cpu().numpy()`` first; the result always goes through
    ``np.asarray``.
    """
    if not hasattr(value, "detach"):
        return np.asarray(value)
    host_copy = value.detach().cpu()
    return np.asarray(host_copy.numpy())
def _static_prop(props: dict, *names: str):
    """Return the first available torch-sim static property.

    Args:
        props: Mapping of property names to values.
        *names: Candidate keys, tried in the order given.

    Returns:
        The value stored under the first candidate present in ``props``.

    Raises:
        KeyError: If none of the candidates is present.
    """
    # A private sentinel lets a stored ``None`` count as a valid hit.
    missing = object()
    first_hit = next((props[key] for key in names if key in props), missing)
    if first_hit is not missing:
        return first_hit
    names_text = ", ".join(names)
    raise KeyError(f"torch-sim static output missing expected key(s): {names_text}")
[docs]
def compose_ml_task_doc(struct_name: str, calculator_spec: str = "") -> dict:
    """Extract results from a completed ML calculation.

    Reads ``<struct_name>.extxyz``, collects energy, forces and pressure from
    the attached calculator (when there is one), determines the space group
    with spglib, writes ``<struct_name>.res`` and returns a summary
    dictionary.

    Args:
        struct_name: Structure name (without extension); used as the stem of
            both the input ``.extxyz`` and the output ``.res`` file.
        calculator_spec: The calculator spec string (for REM metadata).

    Returns:
        Dictionary with the keys ``structure``, ``volume``,
        ``reduced_formula``, ``formula``, ``natoms``, ``label``, ``energy``,
        ``energy_per_atom``, ``pressure``, ``total_time``,
        ``relax_converged``, ``relax_status``, ``relax_steps``,
        ``res_content`` and ``rem_lines``. ``energy`` and
        ``energy_per_atom`` are ``None`` when no energy is stored.
    """
    from pymatgen.io.ase import AseAtomsAdaptor

    from ..restools import save_airss_res

    atoms = ase_read(struct_name + ".extxyz")
    natoms = len(atoms)

    energy = None
    forces = None
    pressure = 0.0
    if atoms.calc is not None:
        # Every property is optional: a missing one is recorded as absent
        # rather than aborting the whole extraction.
        try:
            energy = atoms.get_potential_energy()
        except Exception as exc:
            logger.debug("No energy available for %s: %s", struct_name, exc)
        try:
            forces = atoms.get_forces()
        except Exception as exc:
            logger.debug("No forces available for %s: %s", struct_name, exc)
        pressure = _get_pressure_gpa(atoms)
    volume = atoms.get_volume()

    # Compute symmetry via spglib; fall back to P1 when spglib is missing or
    # the symmetry search fails.
    sym = "P1"
    try:
        import spglib

        sg = spglib.get_spacegroup(
            (
                atoms.get_cell().array,
                atoms.get_scaled_positions(),
                atoms.get_atomic_numbers(),
            ),
            symprec=0.1,
        )
        if sg:
            # Keep only the first whitespace-separated token of the result.
            sym = sg.split()[0]
    except Exception as exc:
        logger.debug("Symmetry detection failed for %s: %s", struct_name, exc)

    # REM lines for ML calculation
    rem_lines = ["", f"ML Calculator {calculator_spec}"]
    relax_status = atoms.info.get("relax_status")
    if relax_status is not None:
        rem_lines.append(f"ML Relax status {relax_status}")
    if "relax_converged" in atoms.info:
        rem_lines.append(f"ML Relax converged {bool(atoms.info['relax_converged'])}")
    if atoms.info.get("relax_steps") is not None:
        rem_lines.append(f"ML Relax steps {atoms.info['relax_steps']}")
    rem_lines.append("")

    # H = E + P*V, with P converted from GPa back to eV/Ang^3.
    enthalpy = None
    if energy is not None:
        enthalpy = energy + pressure * volume / EV_PER_ANG3_TO_GPA

    info = {
        "uid": struct_name,
        "P": pressure,
        "V": volume,
        "H": enthalpy if enthalpy is not None else 0.0,
        "nat": natoms,
        "sym": sym,
        "rem": rem_lines,
    }

    # Build atom annotations from forces if available
    atom_annotations = None
    if forces is not None:
        atom_annotations = [f"{f[0]:.6f} {f[1]:.6f} {f[2]:.6f}" for f in forces]

    res_path = Path(struct_name + ".res")
    save_airss_res(
        atoms,
        info,
        fname=str(res_path),
        force_write=True,
        atom_annotations=atom_annotations,
    )

    structure = AseAtomsAdaptor.get_structure(atoms)
    # Compare against None explicitly: an energy of exactly 0.0 is a valid
    # result and must not be reported as missing. Guard the division for an
    # empty structure.
    energy_per_atom = energy / natoms if energy is not None and natoms else None
    return {
        "structure": structure,
        "volume": structure.volume,
        "reduced_formula": structure.reduced_formula,
        "formula": structure.composition.formula.replace(" ", ""),
        "natoms": natoms,
        "label": struct_name,
        "energy": energy,
        "energy_per_atom": energy_per_atom,
        "pressure": pressure,
        "total_time": None,
        "relax_converged": atoms.info.get("relax_converged"),
        "relax_status": atoms.info.get("relax_status"),
        "relax_steps": atoms.info.get("relax_steps"),
        "res_content": res_path.read_text() if res_path.is_file() else None,
        "rem_lines": rem_lines,
    }
# ---------------------------------------------------------------------------
# TorchSim batch runners (optional dependency)
# ---------------------------------------------------------------------------
[docs]
def has_torchsim() -> bool:
    """Report whether the optional ``torch_sim`` package can be imported.

    Any exception raised by the import (not only ``ImportError``) counts as
    "unavailable" and is logged at debug level.
    """
    try:
        importlib.import_module("torch_sim")
    except Exception as exc:
        logger.debug("torch_sim is unavailable: %s", exc)
        return False
    else:
        return True
[docs]
class TorchSimRunner:
    """Reusable torch-sim model context for chunked ML runs.

    The model is loaded once in ``__init__`` and reused by every
    ``relax_batch`` / ``static_batch`` call. Both methods write one
    ``<name>.extxyz`` file per structure, with energy, forces and (when the
    model provides it) stress attached through a ``SinglePointCalculator``.
    """

    def __init__(self, model_spec: str, *, device: Optional[str] = None) -> None:
        """Load the model described by ``model_spec``.

        Args:
            model_spec: ``backend:model_id`` string passed to
                ``_load_torchsim_model``.
            device: Torch device string. When empty or ``None``, CUDA is used
                if available, otherwise the CPU.
        """
        import torch

        self.model_spec = model_spec
        if device:
            self.device = torch.device(device)
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        # Single precision on CUDA devices, double precision otherwise.
        self.dtype = torch.float32 if self.device.type == "cuda" else torch.float64
        self.model = _load_torchsim_model(
            model_spec,
            device=self.device,
            dtype=self.dtype,
        )

    @staticmethod
    def _check_batch(struct_names, structures) -> None:
        """Raise ValueError unless names and structures pair up one-to-one."""
        if len(struct_names) != len(structures):
            raise ValueError(
                f"Got {len(struct_names)} structure name(s) for "
                f"{len(structures)} structure(s); the two lists must have "
                "the same length."
            )

    def _build_state(self, structures):
        """Convert structure inputs into a torch-sim state on this runner's device."""
        import torch_sim as ts

        atoms_list = [_structure_input_to_atoms(structure) for structure in structures]
        return ts.io.atoms_to_state(
            atoms_list,
            device=self.device,
            dtype=self.dtype,
        )

    @staticmethod
    def _write_extxyz(name: str, atoms, calc_kwargs: dict) -> None:
        """Attach the computed properties to ``atoms`` and write ``<name>.extxyz``."""
        atoms.calc = SinglePointCalculator(atoms, **calc_kwargs)
        ase_write(name + ".extxyz", atoms, format="extxyz")

    def relax_batch(
        self,
        struct_names: list[str],
        structures: list[StructureInput],
        *,
        max_steps: int = 300,
        force_tol: float = 0.05,
        optimizer: str = "fire",
        cell_filter: str = "frechet",
        convergence_mode: str = "force_stress",
        scalar_pressure: float = 0.0,
    ) -> dict[str, int]:
        """Relax one batch of structures using the loaded model.

        Args:
            struct_names: Output file stems, one per structure.
            structures: Cell text or ``Atoms`` objects, same order and length
                as ``struct_names``.
            max_steps: Forwarded to ``torch_sim.optimize``.
            force_tol: Force tolerance handed to the convergence function.
            optimizer: One of ``fire``, ``lbfgs``, ``bfgs``,
                ``gradient_descent`` (case-insensitive).
            cell_filter: ``frechet`` or ``unit`` (case-insensitive).
            convergence_mode: Passed to ``_resolve_torchsim_convergence``.
            scalar_pressure: External pressure in GPa; divided by
                ``EV_PER_ANG3_TO_GPA`` before being handed to torch-sim.

        Returns:
            Mapping of every structure name to ``0`` (presumably a success
            code -- confirm against the caller). Empty for an empty batch.

        Raises:
            ValueError: If ``struct_names`` and ``structures`` differ in length.
            KeyError: If ``optimizer`` or ``cell_filter`` is not recognised.
        """
        # Validate before touching torch-sim: a length mismatch would
        # otherwise be hidden by ``zip`` silently dropping entries.
        self._check_batch(struct_names, structures)
        if len(struct_names) == 0:
            return {}

        import torch_sim as ts

        opt_map = {
            "fire": ts.Optimizer.fire,
            "lbfgs": ts.Optimizer.lbfgs,
            "bfgs": ts.Optimizer.bfgs,
            "gradient_descent": ts.Optimizer.gradient_descent,
        }
        opt = opt_map[optimizer.lower()]
        cf_map = {
            "frechet": ts.CellFilter.frechet,
            "unit": ts.CellFilter.unit,
        }
        cf = cf_map[cell_filter.lower()]
        conv_fn = _resolve_torchsim_convergence(
            convergence_mode,
            force_tol=force_tol,
        )

        state = self._build_state(structures)
        converged_state = ts.optimize(
            system=state,
            model=self.model,
            optimizer=opt,
            convergence_fn=conv_fn,
            max_steps=max_steps,
            init_kwargs={
                "cell_filter": cf,
                # GPa -> eV/Ang^3
                "scalar_pressure": scalar_pressure / EV_PER_ANG3_TO_GPA,
            },
        )
        final_atoms_list = ts.io.state_to_atoms(converged_state)

        # Evaluate the model once more on the final state to obtain the
        # properties that are stored alongside each structure.
        outputs = self.model.forward(converged_state)
        energy_values = outputs["energy"].detach().cpu().reshape(-1).tolist()
        force_values = outputs["forces"].detach().cpu().numpy()
        stress_tensor = outputs.get("stress")
        stress_values = (
            stress_tensor.detach().cpu().numpy() if stress_tensor is not None else None
        )

        results: dict[str, int] = {}
        # The force array is sliced per structure with a running atom offset.
        force_offset = 0
        for index, (name, atoms) in enumerate(zip(struct_names, final_atoms_list)):
            natoms = len(atoms)
            calc_kwargs: dict = {
                "energy": float(energy_values[index]),
                "forces": force_values[force_offset : force_offset + natoms],
            }
            force_offset += natoms
            if stress_values is not None and index < len(stress_values):
                calc_kwargs["stress"] = stress_values[index]
            self._write_extxyz(name, atoms, calc_kwargs)
            results[name] = 0
        return results

    def static_batch(
        self,
        struct_names: list[str],
        structures: list[StructureInput],
    ) -> dict[str, int]:
        """Run one static batch using the loaded model.

        Args:
            struct_names: Output file stems, one per structure.
            structures: Cell text or ``Atoms`` objects, same order and length
                as ``struct_names``.

        Returns:
            Mapping of every structure name to ``0`` (presumably a success
            code -- confirm against the caller). Empty for an empty batch.

        Raises:
            ValueError: If ``struct_names`` and ``structures`` differ in length.
            KeyError: If a result lacks both ``energy`` and
                ``potential_energy``.
        """
        self._check_batch(struct_names, structures)
        if len(struct_names) == 0:
            return {}

        import torch_sim as ts

        state = self._build_state(structures)
        props_list = ts.static(system=state, model=self.model)
        final_atoms_list = ts.io.state_to_atoms(state)

        results: dict[str, int] = {}
        for index, (name, atoms) in enumerate(zip(struct_names, final_atoms_list)):
            props = props_list[index]
            energy_prop = _static_prop(props, "energy", "potential_energy")
            calc_kwargs: dict = {
                "energy": float(_as_numpy_array(energy_prop).reshape(-1)[0]),
            }
            # Forces and stress are optional in the static output.
            if "forces" in props:
                calc_kwargs["forces"] = _normalize_static_forces(
                    _as_numpy_array(props["forces"])
                )
            if "stress" in props:
                calc_kwargs["stress"] = _normalize_static_stress(
                    _as_numpy_array(props["stress"]), index
                )
            self._write_extxyz(name, atoms, calc_kwargs)
            results[name] = 0
        return results
def _torchsim_relax_batch(
    model_spec: str,
    struct_names: list[str],
    structures: list[StructureInput],
    *,
    device: Optional[str] = None,
    max_steps: int = 300,
    force_tol: float = 0.05,
    optimizer: str = "fire",
    cell_filter: str = "frechet",
    convergence_mode: str = "force_stress",
    scalar_pressure: float = 0.0,
) -> dict[str, int]:
    """One-shot relaxation: build a ``TorchSimRunner`` and relax a single batch.

    A fresh runner (and therefore a fresh model load) is created on every
    call; all relaxation options are forwarded unchanged to
    ``TorchSimRunner.relax_batch``, whose result is returned.
    """
    runner = TorchSimRunner(model_spec, device=device)
    relax_options = {
        "max_steps": max_steps,
        "force_tol": force_tol,
        "optimizer": optimizer,
        "cell_filter": cell_filter,
        "convergence_mode": convergence_mode,
        "scalar_pressure": scalar_pressure,
    }
    return runner.relax_batch(struct_names, structures, **relax_options)
def _torchsim_static_batch(
    model_spec: str,
    struct_names: list[str],
    structures: list[StructureInput],
    *,
    device: Optional[str] = None,
) -> dict[str, int]:
    """One-shot single-point run: build a ``TorchSimRunner`` and evaluate a batch.

    A fresh runner (and therefore a fresh model load) is created on every
    call; the result of ``TorchSimRunner.static_batch`` is returned.
    """
    runner = TorchSimRunner(model_spec, device=device)
    return runner.static_batch(struct_names, structures)
def _load_torchsim_model(model_spec: str, *, device, dtype):
    """Load a torchsim model from a ``backend:model_id`` spec string.

    Supported backends: mace, fairchem, mattersim, sevennet, orb,
    graphpes, metatomic, nequix.

    Args:
        model_spec: ``backend:model_id``. Only the first ``:`` separates the
            two parts, so ``model_id`` may itself contain colons. The backend
            name is matched case-insensitively.
        device: Device forwarded to the model wrapper.
        dtype: Dtype forwarded to the wrappers that accept one.

    Returns:
        The torch-sim model wrapper for the requested backend.

    Raises:
        ValueError: If the spec has no ``:`` separator or names an
            unsupported backend.
    """
    backend, separator, model_id = model_spec.partition(":")
    if not separator:
        raise ValueError(
            f"Invalid torchsim model spec: {model_spec!r}. "
            "Expected 'backend:model_id', e.g. 'mace:medium'."
        )
    backend_key = backend.strip().lower()

    if backend_key == "mace":
        from torch_sim.models.mace import MaceModel

        # A model_id that points at an existing file is loaded from disk;
        # anything else is resolved through ``mace_mp``.
        model_path = Path(model_id)
        if model_path.is_file():
            mace_source = model_path
        else:
            from mace.calculators.foundations_models import mace_mp

            mace_source = mace_mp(
                model=model_id,
                return_raw_model=True,
                # e.g. torch.float64 -> "float64"
                default_dtype=str(dtype).removeprefix("torch."),
                device=str(device),
            )
        return MaceModel(
            model=mace_source,
            device=device,
            dtype=dtype,
            compute_forces=True,
            compute_stress=True,
        )

    if backend_key == "fairchem":
        from torch_sim.models.fairchem import FairChemModel

        return FairChemModel(
            model=model_id,
            device=device,
            dtype=dtype,
            compute_stress=True,
        )

    if backend_key == "mattersim":
        from mattersim.forcefield import Potential
        from torch_sim.models.mattersim import MatterSimModel

        raw_model = Potential.from_checkpoint(
            load_path=model_id,
            model_name="m3gnet",
            device=str(device),
            load_training_state=False,
        )
        return MatterSimModel(model=raw_model, device=device, dtype=dtype)

    if backend_key == "sevennet":
        from torch_sim.models.sevennet import SevenNetModel

        return SevenNetModel(model=model_id, device=device, dtype=dtype)

    if backend_key == "orb":
        from torch_sim.models.orb import OrbModel

        return OrbModel(model=model_id, device=device)

    if backend_key == "graphpes":
        from torch_sim.models.graphpes_framework import GraphPESWrapper

        return GraphPESWrapper(
            model=model_id,
            device=device,
            dtype=dtype,
            compute_forces=True,
            compute_stress=True,
        )

    if backend_key == "metatomic":
        from torch_sim.models.metatomic import MetatomicModel

        return MetatomicModel(
            model=model_id,
            device=device,
            compute_forces=True,
            compute_stress=True,
        )

    if backend_key == "nequix":
        from torch_sim.models.nequix import NequixModel

        return NequixModel(model=model_id, device=device, dtype=dtype)

    raise ValueError(
        f"Unsupported torchsim backend: {backend!r}. "
        "Expected: mace, fairchem, mattersim, sevennet, orb, "
        "graphpes, metatomic, nequix."
    )
def _resolve_torchsim_convergence(
    mode: str,
    *,
    force_tol: float = 0.05,
    stress_tol_gpa: float = 1.0,
):
    """Build a convergence function for torchsim.optimize().

    Args:
        mode: ``"force"``, ``"force_stress"`` or ``"energy"``.
        force_tol: Force tolerance used by the ``force`` and ``force_stress``
            modes.
        stress_tol_gpa: Stress tolerance in GPa used by ``force_stress``;
            converted to eV/Ang^3 internally. Defaults to 1 GPa.

    Returns:
        A callable suitable for the ``convergence_fn`` argument of
        ``torch_sim.optimize``.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes. This is
            checked before ``torch_sim`` is imported, so a bad mode is
            reported as such even when the optional dependency is missing.
    """
    if mode not in ("force", "force_stress", "energy"):
        raise ValueError(f"Unsupported convergence mode: {mode!r}")

    import torch_sim as ts

    if mode == "force":
        return ts.generate_force_convergence_fn(force_tol=force_tol)
    if mode == "energy":
        return ts.generate_energy_convergence_fn()

    # mode == "force_stress"
    stress_tol_ev = stress_tol_gpa / EV_PER_ANG3_TO_GPA  # GPa -> eV/Ang^3

    def convergence_fn(state, _last_energy=None):
        """Converged where both max force and max |stress| are below tolerance."""
        if state.stress is None:
            raise ValueError("Stress required for force_stress convergence.")
        force_conv = ts.system_wise_max_force(state) < force_tol
        # NOTE(review): the stress criterion compares the raw stress against
        # zero; with a non-zero external pressure the target stress is
        # presumably non-zero as well -- confirm against torch-sim's
        # conventions.
        stress_conv = state.stress.abs().amax(dim=(1, 2)) < stress_tol_ev
        return force_conv & stress_conv

    return convergence_fn