Skip to content

Commit 22aeb41

Browse files
committed
Warn when getting energy then forces with the ASE interface
This forces us to call the model twice, adding cost to the inference
1 parent cb25b17 commit 22aeb41

File tree

2 files changed

+48
-8
lines changed

2 files changed

+48
-8
lines changed

python/metatomic_torch/metatomic/torch/ase_calculator.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -296,19 +296,19 @@ def calculate(
296296
system_changes=system_changes,
297297
)
298298

299+
ask_for_energy_gradient = (
300+
"forces" in properties or "stress" in properties or "stresses" in properties
301+
)
302+
ask_for_energy = "energy" in properties or "energies" in properties
303+
299304
# In the next few lines, we decide which properties to calculate among energy,
300305
# forces and stress. In addition to the requested properties, we calculate the
301306
# energy if any of the three is requested, as it is an intermediate step in the
302307
# calculation of the other two. We also calculate the forces if the stress is
303308
# requested, and vice-versa. The overhead for the latter operation is also
304309
# small, assuming that the majority of the model computes forces and stresses
305310
# by backward propagation as opposed to forward-mode differentiation.
306-
calculate_energy = (
307-
"energy" in properties
308-
or "energies" in properties
309-
or "forces" in properties
310-
or "stress" in properties
311-
)
311+
calculate_energy = ask_for_energy or ask_for_energy_gradient
312312
calculate_energies = "energies" in properties
313313
calculate_forces = "forces" in properties or "stress" in properties
314314
calculate_stress = "stress" in properties
@@ -318,14 +318,28 @@ def calculate(
318318
"periodic in all directions",
319319
stacklevel=2,
320320
)
321+
321322
if "forces" in properties and atoms.pbc.all():
322323
# we have PBCs, and, since the user/integrator requested forces, we will run
323324
# backward anyway, so let's do the stress as well for free (this saves
324325
# another forward-backward call later if the stress is requested)
325326
calculate_stress = True
327+
326328
if "stresses" in properties:
327329
raise NotImplementedError("'stresses' are not implemented yet")
328330

331+
if ask_for_energy_gradient and not ask_for_energy:
332+
# check if the user already computed energies in a previous call.
333+
energy = self.get_property("energy", atoms=atoms, allow_calculation=False)
334+
if energy is not None:
335+
# when requesting energy first and then forces, the strategy above will
336+
# force us to run the model twice
337+
warnings.warn(
338+
"forces or stress requested after having already computed the "
339+
"energy, this is slower than requesting the forces/stress first",
340+
stacklevel=2,
341+
)
342+
329343
with record_function("ASECalculator::prepare_inputs"):
330344
outputs = self._ase_properties_to_metatensor_outputs(properties)
331345
outputs.update(self._additional_output_requests)

python/metatomic_torch/tests/ase_calculator.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,36 @@ def check_against_ase_lj(atoms, calculator):
8989

9090
atoms.calc = calculator
9191

92-
assert np.allclose(ref.get_potential_energy(), atoms.get_potential_energy())
93-
assert np.allclose(ref.get_potential_energies(), atoms.get_potential_energies())
9492
assert np.allclose(ref.get_forces(), atoms.get_forces())
9593
assert np.allclose(ref.get_stress(), atoms.get_stress())
94+
assert np.allclose(ref.get_potential_energy(), atoms.get_potential_energy())
95+
assert np.allclose(ref.get_potential_energies(), atoms.get_potential_energies())
96+
97+
98+
def test_energy_force_order_warning(atoms, model):
99+
copy = atoms.copy()
100+
copy.calc = MetatomicCalculator(model)
101+
102+
message = (
103+
"forces or stress requested after having already computed the energy, "
104+
"this is slower than requesting the forces/stress first"
105+
)
106+
with pytest.warns(UserWarning, match=message):
107+
copy.get_potential_energy()
108+
copy.get_forces()
109+
110+
copy = atoms.copy()
111+
copy.calc = MetatomicCalculator(model)
112+
113+
with pytest.warns(UserWarning, match=message):
114+
copy.get_potential_energy()
115+
copy.get_stress()
116+
117+
# no warning
118+
atoms.calc = MetatomicCalculator(model)
119+
atoms.get_forces()
120+
atoms.get_stress()
121+
atoms.get_potential_energy()
96122

97123

98124
def test_python_model(model, model_different_units, atoms):

0 commit comments

Comments
 (0)