Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 31 additions & 25 deletions docs/source/learn/core_notebooks/dims_module.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"source": [
"(dims_module)=\n",
"\n",
"# PyMC dims module"
"# Dims module"
]
},
{
Expand All @@ -17,24 +17,30 @@
"source": [
"## A short history of dims in PyMC\n",
"\n",
"PyMC introduced the ability to specify model variable `dims` in version 3.9 in June 2020 (5 years as of the time of writing). In the release notes, it was mentioned only after [14 other new features](https://github.com/pymc-devs/pymc/blob/1d00f3eb81723523968f3610e81a0c42fd96326f/RELEASE-NOTES.md?plain=1#L236), but over time it became a foundation of the library.\n",
"PyMC introduced the ability to specify model variable `dims` in version 3.9 (June 2020 5 years as of the time of writing). In the release notes it appeared only after [14 other new features](https://github.com/pymc-devs/pymc/blob/1d00f3eb81723523968f3610e81a0c42fd96326f/RELEASE-NOTES.md?plain=1#L236), but over time it has become a foundation of the library.\n",
"\n",
"It allows users to more naturally specify the dimensions of model variables with string names, and provides a \"seamless\" conversion to arviz {doc}`InferenceData <arviz:xarray_for_arviz>` objects, which have become the standard for storing and investigating results from probabilistic programming languages.\n",
"It allows users to specify the dimensions of model variables with string names, and provides a seamless conversion to ArviZ {doc}`InferenceData <arviz:xarray_for_arviz>` objects, which are the standard for storing and investigating results from probabilistic programming languages.\n",
"\n",
"However, the behavior of dims is rather limited. It can only be used to specify the shape of new random variables and label existing dimensions (e.g., in {func}`~pymc.Deterministic`). Otherwise it has no effect on the computation, unlike operations done with {class}`~arviz.InferenceData` variables, which are based on {lib}`xarray` and where dims inform array selection, alignment, and broadcasting behavior.\n",
"However, the behavior of dims is limited. It can only be used to specify the shape of new random variables and label existing dimensions (e.g., in {func}`~pymc.Deterministic`). Otherwise it has no effect on computation, unlike operations done with {class}`~arviz.InferenceData` variables, which are based on {lib}`xarray` and where dims inform array selection, alignment, and broadcasting behavior.\n",
"\n",
"As a result, in PyMC models users have to write computations that follow NumPy semantics, which often requires transpositions, reshapes, new axis (`None`) and numerical axis arguments sprinkled everywhere. It can be hard to get these right and in the end it's often hard to make sense of the written model.\n",
"As a result, in PyMC models users often have to write computations that follow NumPy semantics — with transpositions, reshapes, new axes (`None`), and numerical axis arguments sprinkled throughout. These can be hard to get right and the resulting models can be difficult to interpret.\n",
"\n",
"### Expanding the role of dims\n",
"\n",
"Now we are introducing an experimental {mod}`pymc.dims` module that allows users to define data, distributions, and math operations that respect dim semantics, following {mode}`xarray` operations **without coordinates** as closely as possible.\n",
"We are now introducing an experimental {mod}`pymc.dims` module that allows users to define data, distributions, and math operations that respect dim semantics, following {mod}`xarray` operations **without coordinates** as closely as possible.\n",
"\n",
":::{warning}The `dims` module is experimental, not exhaustively tested and the API is being iteratively worked on, and is therefore subject to changes between any two PyMC releases. We welcome users to test it and provide feedback, but we don't yet endorse its use for production.:::\n",
":::{warning}\n",
"The `dims` module is experimental, not exhaustively tested, and the API is still evolving. It may change between PyMC releases. We welcome users to try it and provide feedback, but we do not yet recommend it for production use.\n",
":::\n",
"\n",
"**Related API reference:** \n",
"- [Dims — distributions](https://www.pymc.io/projects/docs/en/stable/api/dims/distributions.html) \n",
"- [Dims — transforms](https://www.pymc.io/projects/docs/en/stable/api/dims/transforms.html) \n",
"- [Dims — math](https://www.pymc.io/projects/docs/en/stable/api/dims/math.html)\n",
"\n",
"## A simple example\n",
"\n",
"We'll start with a model written in current PyMC style, using synthetic data."
"We'll start with a model written in the current PyMC style, using synthetic data."
]
},
{
Expand Down Expand Up @@ -305,21 +311,21 @@
"id": "3985ee1b52708c0c",
"metadata": {},
"source": [
"Note we still use the same {class}`~pymc.Model` constructor, but everything else was now defined with an equivalent function or class defined in the {mod}`pymc.dims` module.\n",
"Note we still use the same {class}`~pymc.Model` constructor, but everything else is now defined with an equivalent function or class in the {mod}`pymc.dims` module.\n",
"\n",
"There are some notable differences:\n",
"\n",
"1. `ZeroSumNormal` takes a `core_dims` argument instead of `n_zerosum_axes`. This tells PyMC which of the `dims` that define the distribution are constrained to be zero-summed. All distributions that take non-scalar parameters now require a `core_dims` argument. Previously, they were assumed to be right-aligned by the user (see more in {doc}`dimensionality`). Now you don't have to worry about the order of the dimensions in your model, just their meaning!\n",
"1. [`ZeroSumNormal`](https://www.pymc.io/projects/docs/en/stable/api/dims/distributions.html#pymc.distributions.ZeroSumNormal) takes a `core_dims` argument instead of `n_zerosum_axes`. This tells PyMC which of the `dims` that define the distribution are constrained to be zero-summed. All distributions that take non-scalar parameters now require a `core_dims` argument. Previously, they were assumed to be right-aligned by the user (see {doc}`dimensionality`). Now you don't have to worry about the order of dimensions in your model, only their meaning.\n",
"\n",
"2. The `trial_preference` computation aligns dimensions for broadcasting automatically. Note we use {func}`pymc.dims.Deterministic` and not {func}`pymc.Deterministic`, which automatically propagates the `dims` to the model object.\n",
"2. The `trial_preference` computation aligns dimensions for broadcasting automatically. Note we use {func}`pymc.dims.Deterministic` (not {func}`pymc.Deterministic`), which automatically propagates the `dims` to the model object.\n",
"\n",
"3. The `softmax` operation specifies the `dim` argument, not the positional axis. Note: The parameter is called `dim` and not `core_dims` because we try to stay as close as possible to the Xarray API, which uses `dim` throughout. But we make an exception for distributions because they already have the `dims` argument.\n",
"3. The `softmax` operation specifies the `dim` argument, not the positional axis. *Note:* The parameter is called `dim` and not `core_dims` because we try to stay as close as possible to the Xarray API (which uses `dim` throughout). We make an exception for distributions because they already have the `dims` argument.\n",
"\n",
"4. The `Categorical` observed variable, like `ZeroSumNormal`, requires a `core_dims` argument to specify which dimension corresponds to the probability vector. Previously, it was necessary to place this dimension explicitly on the rightmost axis -- not any more!\n",
"4. [`Categorical`](https://www.pymc.io/projects/docs/en/stable/api/dims/distributions.html#pymc.distributions.Categorical) observed variables, like `ZeroSumNormal`, require a `core_dims` argument to specify which dimension corresponds to the probability vector. Previously, it was necessary to place this dimension explicitly on the rightmost axis not anymore!\n",
"\n",
"5. Even though dims were not specified for either `trial_preference` or `response`, PyMC automatically infers them.\n",
"5. Even when `dims` are not specified for either `trial_preference` or `response`, PyMC automatically infers them.\n",
"\n",
"The graphviz representation looks the same as before."
"The Graphviz representation looks the same as before."
]
},
{
Expand Down Expand Up @@ -492,9 +498,9 @@
"id": "e360b5f5e9e8ca1e",
"metadata": {},
"source": [
"The {mod}`pymc.dims` module functionality is built on top of the experimental {mod}`pytensor.xtensor` module in PyTensor, which is the {lib}`xarray` analogoue of the {mod}`pytensor.tensor` module you may be familiar with (see {doc}`pymc_and_pytensor`).\n",
"The {mod}`pymc.dims` module functionality is built on top of the experimental {mod}`pytensor.xtensor` module in PyTensor, which is the {lib}`xarray` analogue of the {mod}`pytensor.tensor` module you may be familiar with (see {doc}`pymc_and_pytensor`).\n",
"\n",
"Whereas regular distributions and math operations return {class}`~pytensor.tensor.TensorVariable` objects, the corresponding functions in the {mod}`pymc.dims` module returns {class}`~pytensor.xtensor.type.XTensorVariable` objects. These are very similar to {class}`~pytensor.tensor.TensorVariable`, but they have a `dims` attribute that determines their behavior."
"Whereas regular distributions and math operations return {class}`~pytensor.tensor.TensorVariable` objects, the corresponding functions in the {mod}`pymc.dims` module return {class}`~pytensor.xtensor.type.XTensorVariable` objects. These are very similar to {class}`~pytensor.tensor.TensorVariable`, but include a `dims` attribute that determines their behavior.\n"
]
},
{
Expand Down Expand Up @@ -806,7 +812,7 @@
}
},
"source": [
"This happens with `Deterministic`, `Potential` and every distribution in the `dims` module.\n",
"This happens with {func}`pymc.dims.Deterministic`, {func}`pymc.dims.Potential`, and every distribution in the {mod}`pymc.dims` module. \n",
"Any time you specify `dims`, you will get back a variable with dimensions in the same order."
]
},
Expand Down Expand Up @@ -871,7 +877,7 @@
"\n",
"#### Model constructors\n",
"\n",
"The following PyMC model constructors are available in the `dims` module.\n",
"The following PyMC model constructors are available in the {mod}`pymc.dims` module.\n",
"\n",
" * {func}`~pymc.dims.Data`\n",
" * {func}`~pymc.dims.Deterministic`\n",
Expand All @@ -885,7 +891,7 @@
"\n",
" * All vector arguments (and observed values) must have known dims. An error is raised otherwise.\n",
"\n",
" * Distributions with non-scalar inputs will require a `core_dims` argument. The meaning of the `core_dims` argument will be denoted in the docstrings of each distribution. For example, for the MvNormal, the `core_dims` are the two dimensions of the covariance matrix, one (and only one) of which must also be present in the mean parameter. The shared `core_dim` is the one that persists in the output. Sometimes the order of `core_dims` will be important!\n",
" * Distributions with non-scalar inputs will require a `core_dims` argument. The meaning of the `core_dims` argument will be denoted in the docstrings of each distribution. For example, for the [`MvNormal`](https://www.pymc.io/projects/docs/en/stable/api/dims/generated/pymc.dims.MvNormal.html#pymc.dims.MvNormal), the `core_dims` are the two dimensions of the covariance matrix, one (and only one) of which must also be present in the mean parameter. The shared `core_dim` is the one that persists in the output. Sometimes the order of `core_dims` will be important!\n",
"\n",
" * `dims` accept ellipsis, and variables are transposed to match the user-specified `dims` argument.\n",
"\n",
Expand All @@ -901,7 +907,7 @@
"\n",
"The expectation is that every {class}`xarray.DataArray` method in Xarray should have an equivalent version for XTensorVariables. So if you can do `x.diff(dim=\"a\")` in Xarray, you should be able to do `x.diff(dim=\"a\")` with XTensorVariables as well.\n",
"\n",
"In addition, many numerical operations are available in the {mod}`pymc.dims.math` module, which provides a superset of `ufuncs` functions found in Xarray (like `exp`). It also includes submodules such as `linalg` that provide counterpart to libraries like {lib}`xarray_einstats` (such as `linalg.solve`).\n",
"In addition, many numerical operations are available in the {mod}`pymc.dims.math` module, which provides a superset of `ufunc` functions found in Xarray (like `exp`). It also includes submodules such as `linalg` that provide counterpart functions to libraries like {lib}`xarray_einstats` (such as `linalg.solve`).\n",
"\n",
"Finally, functions that are available at the module level in Xarray (like `concat`) are also available in the {mod}`pymc.dims` namespace.\n",
"\n",
Expand Down Expand Up @@ -2888,7 +2894,7 @@
},
{
"cell_type": "code",
"execution_count": 42,
"execution_count": null,
"id": "f994c42a-8468-4e98-96e2-eec1086cc7d4",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -3302,9 +3308,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "pymc",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "pymc"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -3316,7 +3322,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
"version": "3.9.18"
}
},
"nbformat": 4,
Expand Down