Skip to content

Commit

Permalink
Tidy neb (#438)
Browse files Browse the repository at this point in the history
* Refactor NEB

* Fix setting NEB interpolator when passing band

* Rename NEB interpolate function

* Fix NEB optimize when using band

* Add NEB tests
  • Loading branch information
ElliottKasoar authored Feb 20, 2025
1 parent 04e544b commit bb587cc
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 33 deletions.
77 changes: 44 additions & 33 deletions janus_core/calculations/neb.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def __init__(

# Identify whether interpolating
if band_structs or band_path:
self.inerpolator = None
self.interpolator = None
if init_struct or init_struct_path or final_struct or final_struct_path:
raise ValueError(
"Band cannot be specified in combination with an initial or final "
Expand All @@ -293,7 +293,7 @@ def __init__(
# Read all image by default for band
read_kwargs.setdefault("index", ":")
else:
if interpolator is None:
if self.interpolator is None:
raise ValueError(
"An interpolator must be specified when using an initial and final "
"structure"
Expand All @@ -320,7 +320,7 @@ def __init__(
file_prefix=file_prefix,
)

if interpolator:
if self.interpolator:
if not isinstance(self.struct, Atoms):
raise ValueError("`init_struct` must be a single structure.")
if not self.struct.calc:
Expand Down Expand Up @@ -426,6 +426,9 @@ def plot(self) -> Figure | None:
Figure | None
Plotted NEB band.
"""
if not hasattr(self.neb, "nebtools"):
self.run_nebtools()

if self.plot_band:
fig = self.nebtools.plot_band()
fig.savefig(self.plot_file)
Expand All @@ -434,7 +437,7 @@ def plot(self) -> Figure | None:

return fig

def set_interpolator(self) -> None:
def interpolate(self) -> None:
"""Interpolate images to create initial band."""
match self.interpolator:
case "ase":
Expand Down Expand Up @@ -482,6 +485,40 @@ def set_interpolator(self) -> None:
case _:
raise ValueError("Invalid interpolator selected")

def optimize(self):
"""Run NEB optimization."""
if not hasattr(self, "neb"):
self.interpolate()

optimizer = self.optimizer(self.neb, **self.optimizer_kwargs)
optimizer.run(fmax=self.fmax, steps=self.steps)
if self.logger:
self.logger.info("Optimization steps: %s", optimizer.nsteps)

# Optionally write band images to file
output_structs(
images=self.images,
struct_path=self.struct_path,
write_results=self.write_band,
write_kwargs=self.write_kwargs,
)

def run_nebtools(self):
"""Run NEBTools analysis."""
self.nebtools = NEBTools(self.images[1:-1])
barrier, delta_E = self.nebtools.get_barrier() # noqa: N806
max_force = self.nebtools.get_fmax()
self.results = {
"barrier": barrier,
"delta_E": delta_E,
"max_force": max_force,
}

if self.write_results:
with open(self.results_file, "w", encoding="utf8") as out:
print("#Barrier [eV] | delta E [eV] | Max force [eV/Å] ", file=out)
print(*self.results.values(), file=out)

def run(self) -> dict[str, float]:
"""
Run Nudged Elastic Band method.
Expand All @@ -507,30 +544,9 @@ def run(self) -> dict[str, float]:
)
GeomOpt(self.final_struct, **self.minimize_kwargs).run()

self.set_interpolator()

optimizer = self.optimizer(self.neb, **self.optimizer_kwargs)
optimizer.run(fmax=self.fmax, steps=self.steps)
if self.logger:
self.logger.info("Optimization steps: %s", optimizer.nsteps)

# Optionally write band images to file
output_structs(
images=self.images,
struct_path=self.struct_path,
write_results=self.write_band,
write_kwargs=self.write_kwargs,
)

self.nebtools = NEBTools(self.images[1:-1])
barrier, delta_E = self.nebtools.get_barrier() # noqa: N806
max_force = self.nebtools.get_fmax()
self.results = {
"barrier": barrier,
"delta_E": delta_E,
"max_force": max_force,
}

self.interpolate()
self.optimize()
self.run_nebtools()
self.plot()

if self.logger:
Expand All @@ -544,9 +560,4 @@ def run(self) -> dict[str, float]:
self.struct.info["emissions"] = emissions
self.tracker.stop()

if self.write_results:
with open(self.results_file, "w", encoding="utf8") as out:
print("#Barrier [eV] | delta E [eV] | Max force [eV/Å] ", file=out)
print(*self.results.values(), file=out)

return self.results
45 changes: 45 additions & 0 deletions tests/test_neb.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,48 @@ def test_set_calc(tmp_path, LFPO_start_b, LFPO_end_b):
assert neb.results["barrier"] == pytest.approx(7817.150960456944)
assert neb.results["delta_E"] == pytest.approx(78.34034136421758)
assert neb.results["max_force"] == pytest.approx(148695.846153840771)


def test_neb_functions(tmp_path, LFPO_start_b, LFPO_end_b):
"""Test individual NEB functions."""
file_prefix = tmp_path / "LFPO"

neb = NEB(
init_struct=LFPO_start_b,
final_struct=LFPO_end_b,
arch="mace",
model_path=MODEL_PATH,
n_images=5,
interpolator="ase",
file_prefix=file_prefix,
)
neb.interpolate()
neb.optimize()
neb.run_nebtools()

assert len(neb.images) == 7
assert all(key in neb.results for key in ("barrier", "delta_E", "max_force"))
assert neb.results["barrier"] == pytest.approx(7817.150960456944)
assert neb.results["delta_E"] == pytest.approx(78.34034136421758)
assert neb.results["max_force"] == pytest.approx(148695.846153840771)


def test_neb_plot(tmp_path):
"""Test plotting NEB before running NEBTools."""
file_prefix = tmp_path / "LFPO"

neb = NEB(
band_path=DATA_PATH / "LiFePO4-neb-band.xyz",
arch="mace",
model_path=MODEL_PATH,
steps=2,
file_prefix=file_prefix,
)
neb.optimize()
neb.plot()

assert len(neb.images) == 7
assert all(key in neb.results for key in ("barrier", "delta_E", "max_force"))
assert neb.results["barrier"] == pytest.approx(0.67567742247752)
assert neb.results["delta_E"] == pytest.approx(5.002693796996027e-07)
assert neb.results["max_force"] == pytest.approx(1.5425684122118983)

0 comments on commit bb587cc

Please sign in to comment.