Skip to content

Commit

Permalink
add mattersim calculator fixes stfc#425
Browse files Browse the repository at this point in the history
  • Loading branch information
alinelena committed Feb 14, 2025
1 parent f3d7f66 commit a8da9a9
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 1 deletion.
2 changes: 1 addition & 1 deletion janus_core/helpers/janus_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class CorrelationKwargs(TypedDict, total=True):

# Janus specific
Architectures = Literal[
"mace", "mace_mp", "mace_off", "m3gnet", "chgnet", "alignn", "sevennet"
"mace", "mace_mp", "mace_off", "m3gnet", "chgnet", "alignn", "sevennet", "mattersim"
]
Devices = Literal["cpu", "cuda", "mps", "xpu"]
Ensembles = Literal["nph", "npt", "nve", "nvt", "nvt-nh", "nvt-csvr", "npt-mtk"]
Expand Down
11 changes: 11 additions & 0 deletions janus_core/helpers/mlip_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,17 @@ def choose_calculator(
kwargs.setdefault("sevennet_config", None)
calculator = SevenNetCalculator(model=model_path, device=device, **kwargs)

elif arch == "mattersim":
from mattersim.forcefield import MatterSimCalculator
from mattersim import __version__

if isinstance(model_path, Path):
model_path = str(model_path)
elif not isinstance(model_path, str):
model_path = "mattersim-v1.0.0-5M"

calculator = MatterSimCalculator(load_path=model_path, device=device, **kwargs)

else:
raise ValueError(
f"Unrecognized {arch=}. Suported architectures "
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,15 @@ m3gnet = [
sevennet = [
"sevenn == 0.10.3",
]
mattersim = [
"mattersim == 1.1.1",
]
all = [
"janus-core[alignn]",
"janus-core[chgnet]",
"janus-core[m3gnet]",
"janus-core[sevennet]",
"janus-core[mattersim]",
]

[project.scripts]
Expand Down
1 change: 1 addition & 0 deletions tests/test_mlip_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def test_invalid_device(arch):
("sevennet", "cpu", {"model_path": SEVENNET_PATH}),
("sevennet", "cpu", {}),
("sevennet", "cpu", {"model": "sevennet-0"}),
("mattersim", "cpu", {}),
],
)
def test_extra_mlips(arch, device, kwargs):
Expand Down
1 change: 1 addition & 0 deletions tests/test_single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ def test_mlips(arch, device, expected_energy):
("sevennet", "cpu", -27.061979293823242, {"model_path": SEVENNET_PATH}),
("sevennet", "cpu", -27.061979293823242, {}),
("sevennet", "cpu", -27.061979293823242, {"model_path": "SevenNet-0_11July2024"}),
("mattersim", "cpu", -27.06208038330078, {}),
]


Expand Down

0 comments on commit a8da9a9

Please sign in to comment.