Skip to content

optimize.root raises error computing gradients with respect to inputs when no args are found #1466

Open
@jessegrabowski

Description

@jessegrabowski

Description

The following graph should be valid:

import pytensor
import pytensor.tensor as pt

x, y = variables = pt.tensor('variables', shape=(2, ))

eq_1 = x ** 2 - y - 1
eq_2 = x - y ** 2 + 1
solution, success = pt.optimize.root(equations=pt.stack([eq_1, eq_2]), 
                       variables=variables,
                       method='hybr',
                       optimizer_kwargs={'tol':1e-8})
pt.grad(solution.sum(), variables)

But we get an error because the gradient expression assumes there will always be args. We just need a check for the case where args is None (i.e. the objective has no extra parameters) and, if so, return pt.zeros_like(x) immediately.

Error Traceback
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[21], line 1
----> 1 pt.grad(solution.sum(), variables)

File ~/Documents/Python/pytensor/pytensor/gradient.py:747, in grad(cost, wrt, consider_constant, disconnected_inputs, add_names, known_grads, return_disconnected, null_gradients)
    744     if hasattr(g.type, "dtype"):
    745         assert g.type.dtype in pytensor.tensor.type.float_dtypes
--> 747 _rval: Sequence[Variable] = _populate_grad_dict(
    748     var_to_app_to_idx, grad_dict, _wrt, cost_name
    749 )
    751 rval: MutableSequence[Variable | None] = list(_rval)
    753 for i in range(len(_rval)):

File ~/Documents/Python/pytensor/pytensor/gradient.py:1541, in _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
   1538     # end if cache miss
   1539     return grad_dict[var]
-> 1541 rval = [access_grad_cache(elem) for elem in wrt]
   1543 return rval

File ~/Documents/Python/pytensor/pytensor/gradient.py:1496, in _populate_grad_dict.<locals>.access_grad_cache(var)
   1494 for node in node_to_idx:
   1495     for idx in node_to_idx[node]:
-> 1496         term = access_term_cache(node)[idx]
   1498         if not isinstance(term, Variable):
   1499             raise TypeError(
   1500                 f"{node.op}.grad returned {type(term)}, expected"
   1501                 " Variable instance."
   1502             )

File ~/Documents/Python/pytensor/pytensor/gradient.py:1326, in _populate_grad_dict.<locals>.access_term_cache(node)
   1318         if o_shape != g_shape:
   1319             raise ValueError(
   1320                 "Got a gradient of shape "
   1321                 + str(o_shape)
   1322                 + " on an output of shape "
   1323                 + str(g_shape)
   1324             )
-> 1326 input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
   1328 if input_grads is None:
   1329     raise TypeError(
   1330         f"{node.op}.grad returned NoneType, expected iterable."
   1331     )

File ~/Documents/Python/pytensor/pytensor/tensor/optimize.py:801, in RootOp.L_op(self, inputs, outputs, output_grads)
    798 df_dx = jacobian(inner_fx, inner_x) if not self.jac else self.fgraph.outputs[1]
    799 df_dtheta_columns = jacobian(inner_fx, inner_args, disconnected_inputs="ignore")
--> 801 grad_wrt_args = implict_optimization_grads(
    802     df_dx=df_dx,
    803     df_dtheta_columns=df_dtheta_columns,
    804     args=args,
    805     x_star=x_star,
    806     output_grad=output_grad,
    807     fgraph=self.fgraph,
    808 )
    810 return [zeros_like(x), *grad_wrt_args]

File ~/Documents/Python/pytensor/pytensor/tensor/optimize.py:334, in implict_optimization_grads(df_dx, df_dtheta_columns, args, x_star, output_grad, fgraph)
    289 r"""
    290 Compute gradients of an optimization problem with respect to its parameters.
    291 
   (...)    328     The function graph that contains the inputs and outputs of the optimization problem.
    329 """
    330 df_dx = cast(TensorVariable, df_dx)
    332 df_dtheta = concatenate(
    333     [
--> 334         atleast_2d(jac_col, left=False)
    335         for jac_col in cast(list[TensorVariable], df_dtheta_columns)
    336     ],
    337     axis=-1,
    338 )
    340 replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
    342 df_dx_star, df_dtheta_star = cast(
    343     list[TensorVariable],
    344     graph_replace([atleast_2d(df_dx), df_dtheta], replace=replace),
    345 )

File ~/Documents/Python/pytensor/pytensor/tensor/basic.py:4440, in atleast_Nd(arry, n, left)
   4435 def atleast_Nd(
   4436     arry: np.ndarray | TensorVariable, *, n: int = 1, left: bool = True
   4437 ) -> TensorVariable:
   4438     """Convert input to an array with at least `n` dimensions."""
-> 4440     arry = as_tensor(arry)
   4442     if arry.ndim >= n:
   4443         result = arry

File ~/Documents/Python/pytensor/pytensor/tensor/__init__.py:50, in as_tensor_variable(x, name, ndim, **kwargs)
     18 def as_tensor_variable(
     19     x: TensorLike, name: str | None = None, ndim: int | None = None, **kwargs
     20 ) -> "TensorVariable":
     21     """Convert `x` into an equivalent `TensorVariable`.
     22 
     23     This function can be used to turn ndarrays, numbers, `ScalarType` instances,
   (...)     48 
     49     """
---> 50     return _as_tensor_variable(x, name, ndim, **kwargs)

File ~/mambaforge/envs/pytensor-dev/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
    908 if not args:
    909     raise TypeError(f'{funcname} requires at least '
    910                     '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)

File ~/Documents/Python/pytensor/pytensor/tensor/__init__.py:57, in _as_tensor_variable(x, name, ndim, **kwargs)
     53 @singledispatch
     54 def _as_tensor_variable(
     55     x: TensorLike, name: str | None, ndim: int | None, **kwargs
     56 ) -> "TensorVariable":
---> 57     raise NotImplementedError(f"Cannot convert {x!r} to a tensor variable.")

NotImplementedError: Cannot convert None to a tensor variable.

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions