Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support user-provided CyIpopt callbacks with 13 arguments #3289

Merged
merged 14 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
# We don't raise unittest.SkipTest if not cyipopt_available as there is a
# test below that tests an exception when cyipopt is unavailable.
cyipopt_ge_1_3 = hasattr(cyipopt, "CyIpoptEvaluationError")
ipopt_ge_3_14 = cyipopt.IPOPT_VERSION >= (3, 14, 0)


def create_model1():
Expand Down Expand Up @@ -326,3 +327,91 @@ def test_solve_without_objective(self):
res = solver.solve(m, tee=True)
pyo.assert_optimal_termination(res)
self.assertAlmostEqual(m.x[1].value, 9.0)

def test_solve_13arg_callback(self):
m = create_model1()

iterate_data = []

def intermediate(
nlp,
alg_mod,
iter_count,
obj_value,
inf_pr,
inf_du,
mu,
d_norm,
regularization_size,
alpha_du,
alpha_pr,
ls_trials,
):
x = nlp.get_primals()
y = nlp.get_duals()
iterate_data.append((x, y))

x_sol = np.array([3.85958688, 4.67936007, 3.10358931])
y_sol = np.array([-1.0, 53.90357665])

solver = pyo.SolverFactory("cyipopt", intermediate_callback=intermediate)
res = solver.solve(m, tee=True)
pyo.assert_optimal_termination(res)

# Make sure iterate vectors have the right shape and that the final
# iterate contains the primal solution we expect.
for x, y in iterate_data:
self.assertEqual(x.shape, (3,))
self.assertEqual(y.shape, (2,))
x, y = iterate_data[-1]
self.assertTrue(np.allclose(x_sol, x))
# Note that we can't assert that dual variables in the NLP are those
# at the solution because, at this point in the algorithm, the NLP
# only has access to the *previous iteration's* dual values.

# The 13-arg callback works with cyipopt < 1.3, but we will use the
# get_current_iterate method, which is only available in 1.3+ and IPOPT 3.14+
@unittest.skipIf(
not cyipopt_available or not cyipopt_ge_1_3 or not ipopt_ge_3_14,
"cyipopt version < 1.3.0",
)
def test_solve_get_current_iterate(self):
m = create_model1()

iterate_data = []

def intermediate(
nlp,
problem,
alg_mod,
iter_count,
obj_value,
inf_pr,
inf_du,
mu,
d_norm,
regularization_size,
alpha_du,
alpha_pr,
ls_trials,
):
iterate = problem.get_current_iterate()
x = iterate["x"]
y = iterate["mult_g"]
iterate_data.append((x, y))

x_sol = np.array([3.85958688, 4.67936007, 3.10358931])
y_sol = np.array([-1.0, 53.90357665])

solver = pyo.SolverFactory("cyipopt", intermediate_callback=intermediate)
res = solver.solve(m, tee=True)
pyo.assert_optimal_termination(res)

# Make sure iterate vectors have the right shape and that the final
# iterate contains the primal and dual solution we expect.
for x, y in iterate_data:
self.assertEqual(x.shape, (3,))
self.assertEqual(y.shape, (2,))
x, y = iterate_data[-1]
self.assertTrue(np.allclose(x_sol, x))
self.assertTrue(np.allclose(y_sol, y))
106 changes: 92 additions & 14 deletions pyomo/contrib/pynumero/interfaces/cyipopt_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
objects for the matrices (e.g., AmplNLP and PyomoNLP)
"""
import abc
import inspect

from pyomo.common.dependencies import attempt_import, numpy as np, numpy_available
from pyomo.contrib.pynumero.exceptions import PyNumeroEvaluationError
Expand Down Expand Up @@ -309,6 +310,49 @@
# cyipopt.Problem.__init__
super(CyIpoptNLP, self).__init__()

# Pre-Pyomo 6.7.4.dev0, we had no way to pass the cyipopt.Problem object
# to the user in an intermediate callback. This prevented them from calling
# the useful get_current_iterate and get_current_violations methods. Now,
# we support this by adding the Problem object to the args we pass to a user's
# callback. To preserve backwards compatibility, we inspect the user's
# callback to infer whether they want this argument. To preserve backwards
# compatibility if the user asked for variable-length *args, we do not pass
# the Problem object as an argument in this case.
Comment on lines +318 to +320
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a big fan of backwards compatibility, but in this case, I think I disagree: if the user defined the callback using *args, then it is their responsibility to track any changes to our callback API.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The counterpoint is that they may reasonably expect the callback API to be stable. Personally, I would rather we didn't support *args at all. Maybe we should raise a deprecation warning if a 12-arg callback or *args is provided?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The counter-counterpoint is this is in contrib, so this is where we make the weakest (i.e., no) guarantee on backwards compatibility.

With so many arguments, I like the model where we pass everything by name, and deprecate all use of positional arguments. We could be clever and even allow callbacks with subsets of named arguments (which would support backwards compatibility), future-proof us to passing new arguments, and remove the current reliance on a large set of ordered arguments.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While CyIpoptInterface is in contrib, a user can provide a callback via pyo.SolverFactory("cyipopt", intermediate_callback=callback). I've always had the impression that solvers accessible via SolverFactory (without some prefix e.g. "contrib.cyipopt"), should remain stable. To me, it's a gray area. That said, I do like the idea of only supporting named arguments.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Robbybp how do you want to proceed with this? Should we merge it as-is to get it into the August release and open an issue to deprecate positional arguments in favor of named arguments? Or wait and modify this PR directly?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's merge this as-is to get it in for the release. I created issue #3354 to track this discussion.

# A more maintainable solution may be to force users to accept **kwds if they
# want "extra info." If we find ourselves continuing to augment this callback,
# this may be worth considering. -RBP
self._use_13arg_callback = None
if self._intermediate_callback is not None:
signature = inspect.signature(self._intermediate_callback)
positional_kinds = {
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.POSITIONAL_ONLY,
}
positional = [
param
for param in signature.parameters.values()
if param.kind in positional_kinds
]
has_var_args = any(
p.kind is inspect.Parameter.VAR_POSITIONAL
for p in signature.parameters.values()
)

if len(positional) == 13 and not has_var_args:
# If *args is expected, we do not use the new callback
# signature.
self._use_13arg_callback = True
elif len(positional) == 12 or has_var_args:
# If *args is expected, we use the old callback signature
# for backwards compatibility.
self._use_13arg_callback = False
else:
raise ValueError(

Check warning on line 350 in pyomo/contrib/pynumero/interfaces/cyipopt_interface.py

View check run for this annotation

Codecov / codecov/patch

pyomo/contrib/pynumero/interfaces/cyipopt_interface.py#L350

Added line #L350 was not covered by tests
"Invalid intermediate callback. A function with either 12 or 13"
" positional arguments, or a variable number of arguments, is"
" expected."
)

def _set_primals_if_necessary(self, x):
if not np.array_equal(x, self._cached_x):
self._nlp.set_primals(x)
Expand Down Expand Up @@ -436,19 +480,53 @@
alpha_pr,
ls_trials,
):
"""Calls user's intermediate callback

This method has the call signature expected by CyIpopt. We then extend
this call signature to provide users of this interface class additional
functionality. Additional arguments are:

- The ``NLP`` object that was used to construct this class instance.
This is useful for querying the variables, constraints, and
derivatives during the callback.
- The class instance itself. This is useful for calling the
``get_current_iterate`` and ``get_current_violations`` methods, which
query Ipopt's internal data structures to provide this information.

"""
if self._intermediate_callback is not None:
return self._intermediate_callback(
self._nlp,
alg_mod,
iter_count,
obj_value,
inf_pr,
inf_du,
mu,
d_norm,
regularization_size,
alpha_du,
alpha_pr,
ls_trials,
)
if self._use_13arg_callback:
# This is the callback signature expected as of Pyomo 6.7.4.dev0
return self._intermediate_callback(
self._nlp,
self,
alg_mod,
iter_count,
obj_value,
inf_pr,
inf_du,
mu,
d_norm,
regularization_size,
alpha_du,
alpha_pr,
ls_trials,
)
else:
# This is the callback signature expected pre-Pyomo 6.7.4.dev0 and
# is supported for backwards compatibility.
return self._intermediate_callback(
self._nlp,
alg_mod,
iter_count,
obj_value,
inf_pr,
inf_du,
mu,
d_norm,
regularization_size,
alpha_du,
alpha_pr,
ls_trials,
)
return True
66 changes: 64 additions & 2 deletions pyomo/contrib/pynumero/interfaces/tests/test_cyipopt_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,17 @@ def hessian(self, x, y, obj_factor):
problem.solve(x0)


def _get_model_nlp_interface(halt_on_evaluation_error=None):
def _get_model_nlp_interface(halt_on_evaluation_error=None, intermediate_callback=None):
m = pyo.ConcreteModel()
m.x = pyo.Var([1, 2, 3], initialize=1.0)
m.obj = pyo.Objective(expr=m.x[1] * pyo.sqrt(m.x[2]) + m.x[1] * m.x[3])
m.eq1 = pyo.Constraint(expr=m.x[1] * pyo.sqrt(m.x[2]) == 1.0)
nlp = PyomoNLP(m)
interface = CyIpoptNLP(nlp, halt_on_evaluation_error=halt_on_evaluation_error)
interface = CyIpoptNLP(
nlp,
halt_on_evaluation_error=halt_on_evaluation_error,
intermediate_callback=intermediate_callback,
)
bad_primals = np.array([1.0, -2.0, 3.0])
indices = nlp.get_primal_indices([m.x[1], m.x[2], m.x[3]])
bad_primals = bad_primals[indices]
Expand Down Expand Up @@ -219,6 +223,64 @@ def test_error_in_hessian_halt(self):
with self.assertRaisesRegex(PyNumeroEvaluationError, msg):
interface.hessian(bad_x, [1.0], 0.0)

def test_intermediate_12arg(self):
iterate_data = []

def intermediate(
nlp,
alg_mod,
iter_count,
obj_value,
inf_pr,
inf_du,
mu,
d_norm,
regularization_size,
alpha_du,
alpha_pr,
ls_trials,
):
self.assertIsInstance(nlp, PyomoNLP)
iterate_data.append((inf_pr, inf_du))

m, nlp, interface, bad_x = _get_model_nlp_interface(
intermediate_callback=intermediate
)
# The interface's callback is always called with 11 arguments (by CyIpopt/Ipopt)
# but we add the NLP object to the arguments.
interface.intermediate(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)
self.assertEqual(iterate_data, [(4, 5)])

def test_intermediate_13arg(self):
iterate_data = []

def intermediate(
nlp,
problem,
alg_mod,
iter_count,
obj_value,
inf_pr,
inf_du,
mu,
d_norm,
regularization_size,
alpha_du,
alpha_pr,
ls_trials,
):
self.assertIsInstance(nlp, PyomoNLP)
self.assertIsInstance(problem, cyipopt.Problem)
iterate_data.append((inf_pr, inf_du))

m, nlp, interface, bad_x = _get_model_nlp_interface(
intermediate_callback=intermediate
)
# The interface's callback is always called with 11 arguments (by CyIpopt/Ipopt)
# but we add the NLP object *and the cyipopt.Problem object* to the arguments.
interface.intermediate(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)
self.assertEqual(iterate_data, [(4, 5)])


if __name__ == "__main__":
unittest.main()
Loading