Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pymc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __set_compiler_flags():
from pymc.printing import *
from pymc.pytensorf import *
from pymc.sampling import *
from pymc.sampling import external
from pymc.smc import *
from pymc.stats import *
from pymc.step_methods import *
Expand Down
15 changes: 15 additions & 0 deletions pymc/sampling/external/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2025 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pymc.sampling.external.jax import Blackjax, Numpyro
from pymc.sampling.external.nutpie import Nutpie
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from pymc.sampling.external.nutpie import Nutpie
from pymc.sampling.external.nutpie import Nutpie
__all__ = ["Blackjax", "Numpyro", "Nutpie"]

You probably don't need to export the ABC right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not importing it anymore

51 changes: 51 additions & 0 deletions pymc/sampling/external/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright 2025 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any

from pytensor.scalar import discrete_dtypes

from pymc.model.core import modelcontext
from pymc.util import RandomSeed


class ExternalSampler(ABC):
def __init__(self, model=None):
model = modelcontext(model)
self.model = model

@abstractmethod
def sample(
self,
*,
tune: int,
draws: int,
chains: int,
initvals: dict[str, Any] | Sequence[dict[str, Any]],
random_seed: RandomSeed,
progressbar: bool,
var_names: Sequence[str] | None = None,
idata_kwargs: dict[str, Any] | None = None,
compute_convergence_checks: bool,
**kwargs,
):
pass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
pass
raise NotImplementedError

make it fail fast

Copy link
Member Author

@ricardoV94 ricardoV94 Aug 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's an abstractmethod, so the classes can't be instantiated if not implemented

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but someone could (mistakenly) implement like this, which would fail silently:

class baseclass(ABC):

    @abstractmethod
    def requiredmethod(self):
        pass


class derivedclass(baseclass):

    def requiredmethod(self):
        return super().mymethod()

Just a defensive programming thing I guess



class NUTSExternalSampler(ExternalSampler):
def __init__(self, model=None):
super().__init__(model)
if any(var.dtype in discrete_dtypes for var in model.free_RVs):
raise ValueError("External NUTS samplers can only sample continuous variables")
86 changes: 86 additions & 0 deletions pymc/sampling/external/jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2025 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from typing import Literal

from arviz import InferenceData

from pymc.sampling.external.base import NUTSExternalSampler
from pymc.util import RandomState


class JAXNUTSSampler(NUTSExternalSampler):
nuts_sampler: Literal["numpyro", "blackjax"]

def __init__(
self,
model=None,
postprocessing_backend: Literal["cpu", "gpu"] | None = None,
chain_method: Literal["parallel", "vectorized"] = "parallel",
jitter: bool = True,
keep_untransformed: bool = False,
nuts_kwargs: dict | None = None,
):
super().__init__(model)
self.postprocessing_backend = postprocessing_backend
self.chain_method = chain_method
self.jitter = jitter
self.keep_untransformed = keep_untransformed
self.nuts_kwargs = nuts_kwargs or {}

def sample(
self,
*,
tune: int = 1000,
draws: int = 1000,
chains: int = 4,
initvals=None,
random_seed: RandomState | None = None,
progressbar: bool = True,
var_names: Sequence[str] | None = None,
idata_kwargs: dict | None = None,
compute_convergence_checks: bool = True,
target_accept: float = 0.8,
**kwargs,
) -> InferenceData:
from pymc.sampling.jax import sample_jax_nuts

return sample_jax_nuts(
tune=tune,
draws=draws,
chains=chains,
target_accept=target_accept,
random_seed=random_seed,
var_names=var_names,
progressbar=progressbar,
idata_kwargs=idata_kwargs,
compute_convergence_checks=compute_convergence_checks,
initvals=initvals,
jitter=self.jitter,
model=self.model,
chain_method=self.chain_method,
postprocessing_backend=self.postprocessing_backend,
keep_untransformed=self.keep_untransformed,
nuts_kwargs=self.nuts_kwargs,
nuts_sampler=self.nuts_sampler,
**kwargs,
)


class Numpyro(JAXNUTSSampler):
nuts_sampler = "numpyro"


class Blackjax(JAXNUTSSampler):
nuts_sampler = "blackjax"
146 changes: 146 additions & 0 deletions pymc/sampling/external/nutpie.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright 2025 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings

from arviz import InferenceData, dict_to_dataset

from pymc.backends.arviz import coords_and_dims_for_inferencedata, find_constants, find_observations
from pymc.sampling.external.base import NUTSExternalSampler
from pymc.stats.convergence import log_warnings, run_convergence_checks
from pymc.util import _get_seeds_per_chain


class Nutpie(NUTSExternalSampler):
def __init__(
self,
model=None,
backend="numba",
gradient_backend="pytensor",
compile_kwargs=None,
sample_kwargs=None,
):
super().__init__(model)
self.backend = backend
self.gradient_backend = gradient_backend
self.compile_kwargs = compile_kwargs or {}
self.sample_kwargs = sample_kwargs or {}
self.compiled_model = None

def sample(
self,
*,
tune,
draws,
chains,
initvals,
random_seed,
progressbar,
var_names,
idata_kwargs,
compute_convergence_checks,
**kwargs,
):
try:
import nutpie
except ImportError as err:
raise ImportError(
"nutpie not found. Install it with conda install -c conda-forge nutpie"
) from err

from nutpie.sample import _BackgroundSampler

if initvals:
warnings.warn(
"initvals are currently ignored by the nutpie sampler.",
UserWarning,
)
if idata_kwargs:
warnings.warn(
"idata_kwargs are currently ignored by the nutpie sampler.",
UserWarning,
)

compiled_model = nutpie.compile_pymc_model(
self.model,
var_names=var_names,
backend=self.backend,
gradient_backend=self.gradient_backend,
**self.compile_kwargs,
)

result = nutpie.sample(
compiled_model,
tune=tune,
draws=draws,
chains=chains,
seed=_get_seeds_per_chain(random_seed, 1)[0],
progress_bar=progressbar,
**self.sample_kwargs,
**kwargs,
)
if isinstance(result, _BackgroundSampler):
# Wrap _BackgroundSampler so that when sampling is finished we run post_process_sampler
class NutpieBackgroundSamplerWrapper(_BackgroundSampler):
def __init__(self, *args, pymc_model, compute_convergence_checks, **kwargs):
self.pymc_model = pymc_model
self.compute_convergence_checks = compute_convergence_checks
super().__init__(*args, **kwargs, return_raw_trace=False)

def _extract(self, *args, **kwargs):
idata = super()._extract(*args, **kwargs)
return Nutpie._post_process_sample(
model=self.pymc_model,
idata=idata,
compute_convergence_checks=self.compute_convergence_checks,
)

# non-blocked sampling
return NutpieBackgroundSamplerWrapper(
result,
pymc_model=self.model,
compute_convergence_checks=compute_convergence_checks,
)
else:
return self._post_process_sample(self.model, result, compute_convergence_checks)

@staticmethod
def _post_process_sample(
model, idata: InferenceData, compute_convergence_checks
) -> InferenceData:
# Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed
# gather observed and constant data as nutpie.sample() has no access to the PyMC model
if compute_convergence_checks:
log_warnings(run_convergence_checks(idata, model))

coords, dims = coords_and_dims_for_inferencedata(model)
constant_data = dict_to_dataset(
find_constants(model),
library=idata.attrs.get("library", None),
coords=coords,
dims=dims,
default_dims=[],
)
observed_data = dict_to_dataset(
find_observations(model),
library=idata.attrs.get("library", None),
coords=coords,
dims=dims,
default_dims=[],
)
idata.add_groups(
{"constant_data": constant_data, "observed_data": observed_data},
coords=coords,
dims=dims,
)
return idata
Loading
Loading