"""Lossless conversion between AIRSS .res and extended XYZ formats.

Supports round-tripping all data including forces (stored as three extra
trailing columns on each .res atom line, after the optional per-atom spin
column), per-atom spins, REM metadata, and all TITL fields. Numeric values
are limited to the fixed precision of the .res text format. Output .res
files are fully compatible with cryan and other AIRSS tools.
"""
from __future__ import annotations
from pathlib import Path
import numpy as np
from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator
from ase.geometry import cell_to_cellpar
from ase.io import read as ase_read
from ase.io import write as ase_write
from .restools import RESFile, _get_res_lines, parse_titl
# ---------------------------------------------------------------------------
# Force parsing helpers
# ---------------------------------------------------------------------------
def _parse_res_forces(lines: list[str]) -> list[list[float]] | None:
    """Parse optional force columns from the atom lines of one RES structure.

    Atom lines are the lines after ``SFAC`` and before ``END`` whose first
    token starts with a letter. Their layout is::

        Symbol index x y z occ [spin] [fx fy fz]

    The six base columns are followed by an optional per-atom spin and an
    optional force triplet. The layout is decided *per line* from the token
    count, so a spin-only line cannot shift the force columns of other lines:

    * 9 tokens   -> no spin, forces at columns 6-8 (0-based)
    * 10+ tokens -> spin at column 6, forces at columns 7-9
    * otherwise  -> no forces on that line

    Atom lines without forces (or with unparsable force values) contribute
    ``[0.0, 0.0, 0.0]`` so the result stays aligned with the atom order.

    Args:
        lines: Raw text lines of a single RES structure.

    Returns:
        One ``[fx, fy, fz]`` list per atom line, or ``None`` if no atom line
        carried parsable force columns.
    """
    n_base = 6  # Symbol, index, x, y, z, occ
    header_keywords = ("TITL", "CELL", "LATT", "REM")
    forces: list[list[float]] = []
    found_forces = False
    in_atoms = False

    for line in lines:
        tokens = line.split()
        if not tokens:
            continue
        keyword = tokens[0]
        if keyword == "SFAC":
            in_atoms = True
            continue
        if keyword == "END":
            in_atoms = False
            continue
        if keyword in header_keywords:
            continue
        if not (in_atoms and keyword[0].isalpha()):
            continue

        n_extra = len(tokens) - n_base
        if n_extra == 3:
            start = n_base  # no spin column: forces follow occ directly
        elif n_extra >= 4:
            start = n_base + 1  # skip the spin column
        else:
            # Base columns only, or base + spin: no forces on this atom.
            forces.append([0.0, 0.0, 0.0])
            continue

        try:
            forces.append([float(tok) for tok in tokens[start:start + 3]])
        except ValueError:
            forces.append([0.0, 0.0, 0.0])
            continue
        found_forces = True

    return forces if found_forces else None
# ---------------------------------------------------------------------------
# RES → extxyz
# ---------------------------------------------------------------------------
def _copies_from_titl(lines: list[str] | None) -> int:
    """Return the copies count from the first TITL line in *lines* (default 1)."""
    for line in lines or ():
        tokens = line.split()
        if tokens and tokens[0] == "TITL":
            try:
                return int(parse_titl(line).flag3)
            except (ValueError, IndexError):
                return 1
    return 1


def _resfile_to_atoms(res: RESFile) -> Atoms:
    """Convert a RESFile to an ASE Atoms object carrying its metadata.

    ``atoms.info`` receives ``label``, ``pressure``, ``spin``, ``spin_abs``,
    ``symm``, ``copies`` (from the TITL line, default 1) and, when present,
    ``rem``. Per-atom spins become initial magnetic moments. If the raw lines
    carry force columns, a ``SinglePointCalculator`` holding the enthalpy and
    forces is attached. Otherwise one holding only the enthalpy is attached
    when the enthalpy is known.

    NOTE(review): ``res.atoms`` is mutated in place. Whether it is a fresh
    object or a cached one is not visible here; confirm with RESFile.

    Raises:
        ValueError: if ``res`` has no structure loaded, cannot produce an
            Atoms object, or its force columns do not match the atom count.
    """
    if res.structure is None:
        raise ValueError(f"RESFile '{res.label}' has no structure loaded")
    atoms = res.atoms
    if atoms is None:
        raise ValueError(f"RESFile '{res.label}' could not produce Atoms")

    # Structure-level info
    atoms.info["label"] = res.label or "Unknown"
    atoms.info["pressure"] = res.pressure
    atoms.info["spin"] = res.spin
    atoms.info["spin_abs"] = res.spin_abs
    atoms.info["symm"] = res.symm or "-"
    # Always set, even if no TITL line is found among the raw lines.
    atoms.info["copies"] = _copies_from_titl(res.lines)

    # REM lines
    rem = res.rem
    if rem:
        atoms.info["rem"] = rem

    # Per-atom spins (explicit length check also works for numpy arrays)
    spins = res.spins
    if spins is not None and len(spins) > 0:
        atoms.set_initial_magnetic_moments(spins)

    # Forces from raw lines, else energy only
    forces = _parse_res_forces(res.lines) if res.lines else None
    if forces is not None:
        if len(forces) != len(atoms):
            raise ValueError(
                f"RESFile '{res.label}' has {len(forces)} force rows "
                f"but {len(atoms)} atoms"
            )
        atoms.calc = SinglePointCalculator(
            atoms, energy=res.enthalpy, forces=np.array(forces)
        )
    elif res.enthalpy is not None:
        atoms.calc = SinglePointCalculator(atoms, energy=res.enthalpy)
    return atoms
def res_to_extxyz(res_path: str | Path, extxyz_path: str | Path) -> int:
    """Convert a (packed) .res file to a multi-structure extxyz file.

    Every structure in ``res_path`` is loaded with full structure data,
    converted with ``_resfile_to_atoms`` and written, in file order, to
    ``extxyz_path``. The parent directory of ``extxyz_path`` is created if
    it does not exist, matching how ``extxyz_to_res`` treats its output
    directory.

    Args:
        res_path: Path to a single or packed .res file.
        extxyz_path: Destination extxyz file.

    Returns:
        The number of structures converted.

    Raises:
        ValueError: propagated from ``_resfile_to_atoms`` when a structure
            cannot be converted.
    """
    res_path = Path(res_path)
    extxyz_path = Path(extxyz_path)
    extxyz_path.parent.mkdir(parents=True, exist_ok=True)
    res_objs = RESFile.from_packed(str(res_path), include_structure=True)
    atoms_list = [_resfile_to_atoms(r) for r in res_objs]
    ase_write(str(extxyz_path), atoms_list, format="extxyz")
    return len(atoms_list)
# ---------------------------------------------------------------------------
# extxyz → RES
# ---------------------------------------------------------------------------
def _extract_energy(atoms: Atoms) -> float:
    """Return the energy of *atoms*, preferring the attached calculator.

    The calculator is queried first. If there is none, or querying it (or
    converting its result to ``float``) raises anything, the value of
    ``atoms.info["energy"]`` is used instead, defaulting to ``0.0``.
    """
    from_calc = None
    if atoms.calc is not None:
        try:
            from_calc = float(atoms.get_potential_energy())
        except Exception:
            # Best effort: e.g. a calculator without an energy result.
            from_calc = None
    if from_calc is not None:
        return from_calc
    return float(atoms.info.get("energy", 0.0))
def _extract_forces(atoms: Atoms) -> np.ndarray | None:
    """Return per-atom forces for *atoms*, or ``None`` if none are available.

    The attached calculator is queried first. Any exception it raises is
    swallowed and ``atoms.arrays["forces"]`` is used instead. The result is
    always a new array built with ``np.array``.
    """
    from_calc = None
    if atoms.calc is not None:
        try:
            from_calc = np.array(atoms.get_forces())
        except Exception:
            # Best effort: fall through to the stored per-atom arrays.
            from_calc = None
    if from_calc is not None:
        return from_calc
    return np.array(atoms.arrays["forces"]) if "forces" in atoms.arrays else None
def _append_force_columns(lines: list[str], forces: np.ndarray) -> list[str]:
    """Return *lines* with ``fx fy fz`` appended to each atom line, in order.

    Atom lines are those after ``SFAC`` and before ``END`` that start with a
    letter. Atom lines beyond ``len(forces)`` are left unchanged.
    """
    out: list[str] = []
    atom_idx = 0
    in_atoms = False
    for line in lines:
        if line.startswith("SFAC"):
            in_atoms = True
        elif line.strip() == "END":
            in_atoms = False
        elif in_atoms and line and line[0].isalpha():
            if atom_idx < len(forces):
                fx, fy, fz = forces[atom_idx]
                line = f"{line} {fx:>12.6f} {fy:>12.6f} {fz:>12.6f}"
            atom_idx += 1
        out.append(line)
    return out


def _atoms_to_res_lines(atoms: Atoms) -> list[str]:
    """Convert an ASE Atoms object to RES lines (with optional force columns).

    Metadata is read from ``atoms.info``: ``label`` (falling back to
    ``name``, then ``"Unknown"``), ``pressure``, ``spin``, ``spin_abs``,
    ``symm`` (falling back to ``spacegroup``, then ``"-"``), ``copies`` and
    ``rem``. The energy comes from ``_extract_energy``. Non-zero initial
    magnetic moments are written as per-atom spins. Forces from
    ``_extract_forces`` are appended as three trailing columns on each atom
    line.

    Whitespace in the label is replaced with ``_`` because TITL fields are
    whitespace-separated. An empty label becomes ``"Unknown"``.
    """
    natoms = len(atoms)
    species = atoms.get_chemical_symbols()
    scaled = atoms.get_scaled_positions()
    cellpar = cell_to_cellpar(atoms.cell)

    # Extract metadata
    raw_label = str(atoms.info.get("label", atoms.info.get("name", "Unknown")))
    # A label containing whitespace would shift every following TITL field.
    label = "_".join(raw_label.split()) or "Unknown"
    pressure = float(atoms.info.get("pressure", 0.0))
    energy = _extract_energy(atoms)
    volume = atoms.get_volume()
    spin = float(atoms.info.get("spin", 0.0))
    spin_abs = float(atoms.info.get("spin_abs", 0.0))
    symm = str(atoms.info.get("symm", atoms.info.get("spacegroup", "-")))
    copies = int(atoms.info.get("copies", 1))

    # Per-atom spins: only written when at least one moment is non-zero
    spins = None
    if atoms.has("initial_magmoms"):
        magmoms = atoms.get_initial_magnetic_moments()
        if np.any(magmoms != 0):
            spins = [float(m) for m in magmoms]

    # REM lines
    rem_lines = atoms.info.get("rem")
    if isinstance(rem_lines, str):
        rem_lines = [rem_lines]

    # TITL list matching _get_res_lines format:
    # "{} {:.3f} {:.3f} {:.4f} {:.2f} {:.2f} {} ({}) {} {} {}"
    # symm goes into "({})", so pass the bare name.
    symm_bare = symm.strip("()")
    titl_list = [
        label,
        pressure,
        volume,
        energy,
        spin,
        spin_abs,
        natoms,
        symm_bare,
        "n",
        "-",
        str(copies),
    ]

    # Base RES lines (without forces)
    lines = _get_res_lines(
        titl_list, species, scaled.tolist(), cellpar, rem_lines, spins
    )

    forces = _extract_forces(atoms)
    if forces is not None:
        lines = _append_force_columns(lines, forces)
    return lines
def extxyz_to_res(extxyz_path: str | Path, output_dir: str | Path) -> int:
    """Convert an extxyz file to individual .res files (one per structure).

    Each structure is written to ``<output_dir>/<label>.res``. The label is
    taken from ``info["label"]``, then ``info["name"]``, then
    ``struct_<index>``. Every whitespace character and path separator
    (``/`` or ``\\``) in the label is replaced by ``_``. If several
    structures map to the same filename, later ones get a ``-<n>`` suffix
    instead of silently overwriting earlier files. ``output_dir`` is created
    if it does not exist. Files are written as UTF-8.

    Returns:
        The number of structures converted (one file is written for each).
    """
    extxyz_path = Path(extxyz_path)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    try:
        atoms_list = ase_read(str(extxyz_path), index=":")
    except Exception:
        atoms_list = [ase_read(str(extxyz_path))]

    used_stems: set[str] = set()
    for i, atoms in enumerate(atoms_list):
        res_lines = _atoms_to_res_lines(atoms)
        fallback = f"struct_{i:04d}"
        label = str(atoms.info.get("label", atoms.info.get("name", fallback)))
        # Keep the filename a single, whitespace-free path component.
        safe_label = "".join(
            "_" if ch.isspace() or ch in "/\\" else ch for ch in label
        ) or fallback
        stem = safe_label
        suffix = 1
        while stem in used_stems:
            stem = f"{safe_label}-{suffix}"
            suffix += 1
        used_stems.add(stem)
        res_file = output_dir / f"{stem}.res"
        res_file.write_text("\n".join(res_lines) + "\n", encoding="utf-8")
    return len(atoms_list)
# ---------------------------------------------------------------------------
# Extract a single structure by label
# ---------------------------------------------------------------------------
def _find_in_extxyz(source_path: Path, label: str) -> Atoms | None:
    """Return the first structure in an extxyz file whose label equals *label*.

    A structure's label is ``info["label"]``, falling back to ``info["name"]``
    and then ``""``, compared as a string. If reading all frames raises, a
    single-frame read of the same file is attempted instead.

    Returns:
        The matching Atoms object, or ``None`` if no structure matches.
    """
    path_str = str(source_path)
    try:
        frames = ase_read(path_str, index=":")
    except Exception:
        frames = [ase_read(path_str)]
    matches = (
        frame
        for frame in frames
        if str(frame.info.get("label", frame.info.get("name", ""))) == label
    )
    return next(matches, None)
def _find_in_packed_res(source_path: Path, label: str) -> RESFile | None:
    """Find a structure by label in a packed .res file.

    A fast TITL-only pass first checks whether *label* is present, so a
    missing label never triggers a full parse. When the label is found, the
    file is parsed again with full structure data. Note that this second
    pass loads every structure in the file, not just the match.

    The index found in the first pass is only used after confirming that
    the fully loaded entry at that index carries the same label. Otherwise
    the full list is searched by label.

    Returns:
        The first RESFile with the given label, or ``None`` if there is none.
    """
    path_str = str(source_path)
    titl_only = RESFile.from_packed(path_str, include_structure=False, only_titl=True)
    target_idx = next(
        (i for i, res in enumerate(titl_only) if res.label == label), None
    )
    if target_idx is None:
        return None

    full = RESFile.from_packed(path_str, include_structure=True)
    try:
        candidate = full[target_idx]
    except IndexError:
        candidate = None
    if candidate is not None and candidate.label == label:
        return candidate
    # The two passes disagreed on ordering/count: search by label instead of
    # returning a structure that belongs to a different label.
    return next((res for res in full if res.label == label), None)