Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drop the support for PyTorch<2.0 #3272

Merged
merged 5 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ To run a single test from the command line
```sh
pytest -vs {path_to_test}::{test_name}
# or in cuda mode
CUDA_TEST=1 PYRO_TENSOR_TYPE=torch.cuda.DoubleTensor pytest -vs {path_to_test}::{test_name}
CUDA_TEST=1 PYRO_DTYPE=float64 PYRO_DEVICE=cuda pytest -vs {path_to_test}::{test_name}
```

To ensure documentation builds correctly, run
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ test-all: lint FORCE
| xargs pytest -vx --nbval-lax

test-cuda: lint FORCE
CUDA_TEST=1 PYRO_TENSOR_TYPE=torch.cuda.DoubleTensor pytest -vx --stage unit
CUDA_TEST=1 PYRO_DTYPE=float64 PYRO_DEVICE=cuda pytest -vx --stage unit
CUDA_TEST=1 pytest -vx tests/test_examples.py::test_cuda

test-cuda-lax: lint FORCE
CUDA_TEST=1 PYRO_TENSOR_TYPE=torch.cuda.DoubleTensor pytest -vx --stage unit --lax
CUDA_TEST=1 PYRO_DTYPE=float64 PYRO_DEVICE=cuda pytest -vx --stage unit --lax
CUDA_TEST=1 pytest -vx tests/test_examples.py::test_cuda

test-jit: FORCE
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,6 @@ def setup(app):
if "READTHEDOCS" in os.environ:
os.system("pip install numpy")
os.system(
"pip install torch==1.11.0+cpu torchvision==0.12.0+cpu "
"pip install torch==2.0+cpu torchvision==0.15.0+cpu "
"-f https://download.pytorch.org/whl/torch_stable.html"
)
2 changes: 1 addition & 1 deletion examples/baseball.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,6 @@ def main(args):
torch.multiprocessing.set_sharing_strategy("file_system")

if args.cuda:
torch.set_default_tensor_type(torch.cuda.FloatTensor)
torch.set_default_device("cuda")

main(args)
2 changes: 1 addition & 1 deletion examples/contrib/cevae/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def generate_data(args):

def main(args):
if args.cuda:
torch.set_default_tensor_type("torch.cuda.FloatTensor")
torch.set_default_device("cuda")

# Generate synthetic data.
pyro.set_rng_seed(args.seed)
Expand Down
9 changes: 3 additions & 6 deletions examples/contrib/epidemiology/regional.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,9 @@ def main(args):
if args.warmup_steps is None:
args.warmup_steps = args.num_samples
if args.double:
if args.cuda:
torch.set_default_tensor_type(torch.cuda.DoubleTensor)
else:
torch.set_default_dtype(torch.float64)
elif args.cuda:
torch.set_default_tensor_type(torch.cuda.FloatTensor)
torch.set_default_dtype(torch.float64)
if args.cuda:
torch.set_default_device("cuda")

main(args)

Expand Down
9 changes: 3 additions & 6 deletions examples/contrib/epidemiology/sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,12 +391,9 @@ def main(args):
if args.warmup_steps is None:
args.warmup_steps = args.num_samples
if args.double:
if args.cuda:
torch.set_default_tensor_type(torch.cuda.DoubleTensor)
else:
torch.set_default_dtype(torch.float64)
elif args.cuda:
torch.set_default_tensor_type(torch.cuda.FloatTensor)
torch.set_default_dtype(torch.float64)
if args.cuda:
torch.set_default_device("cuda")

main(args)

Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/funsor/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@ def model_7(sequences, lengths, args, batch_size=None, include_prior=True):

def main(args):
if args.cuda:
torch.set_default_tensor_type("torch.cuda.FloatTensor")
torch.set_default_device("cuda")

logging.info("Loading data")
data = poly.load_data(poly.JSB_CHORALES)
Expand Down
5 changes: 2 additions & 3 deletions examples/contrib/mue/FactorMuE.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,9 +427,8 @@ def main(args):
)
args = parser.parse_args()

torch.set_default_dtype(torch.float64)
if args.cuda:
torch.set_default_tensor_type(torch.cuda.DoubleTensor)
else:
torch.set_default_dtype(torch.float64)
torch.set_default_device("cuda")

main(args)
5 changes: 2 additions & 3 deletions examples/contrib/mue/ProfileHMM.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,9 +316,8 @@ def main(args):
)
args = parser.parse_args()

torch.set_default_dtype(torch.float64)
if args.cuda:
torch.set_default_tensor_type(torch.cuda.DoubleTensor)
else:
torch.set_default_dtype(torch.float64)
torch.set_default_device("cuda")

main(args)
4 changes: 1 addition & 3 deletions examples/einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,7 @@ def time_fn(fn, equation, *operands, **kwargs):

def main(args):
if args.cuda:
torch.set_default_tensor_type("torch.cuda.FloatTensor")
else:
torch.set_default_tensor_type("torch.FloatTensor")
torch.set_default_device("cuda")

if args.method == "all":
for method in ["prob", "logprob", "gradient", "marginal", "map", "sample"]:
Expand Down
2 changes: 1 addition & 1 deletion examples/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ def model_7(sequences, lengths, args, batch_size=None, include_prior=True):

def main(args):
if args.cuda:
torch.set_default_tensor_type("torch.cuda.FloatTensor")
torch.set_default_device("cuda")

logging.info("Loading data")
data = poly.load_data(poly.JSB_CHORALES)
Expand Down
9 changes: 3 additions & 6 deletions examples/sir_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,12 +663,9 @@ def main(args):
args = parser.parse_args()

if args.double:
if args.cuda:
torch.set_default_tensor_type(torch.cuda.DoubleTensor)
else:
torch.set_default_tensor_type(torch.DoubleTensor)
elif args.cuda:
torch.set_default_tensor_type(torch.cuda.FloatTensor)
torch.set_default_dtype(torch.float64)
if args.cuda:
torch.set_default_device("cuda")

main(args)

Expand Down
2 changes: 1 addition & 1 deletion examples/sparse_gamma_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from pyro.infer import SVI, TraceMeanField_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal, init_to_feasible

torch.set_default_tensor_type("torch.FloatTensor")
torch.set_default_dtype(torch.float32)
pyro.util.set_rng_seed(0)


Expand Down
2 changes: 1 addition & 1 deletion examples/sparse_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
"""


torch.set_default_tensor_type("torch.FloatTensor")
torch.set_default_dtype(torch.float32)


def dot(X, Z):
Expand Down
2 changes: 1 addition & 1 deletion examples/svi_horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def main(args):
if args.cuda:
torch.cuda.set_device(hvd.local_rank())
if args.cuda:
torch.set_default_tensor_type("torch.cuda.FloatTensor")
torch.set_default_device("cuda")
device = torch.tensor(0).device

if args.horovod:
Expand Down
2 changes: 1 addition & 1 deletion examples/svi_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import argparse

import pytorch_lightning as pl
import lightning.pytorch as pl
import torch

import pyro
Expand Down
2 changes: 1 addition & 1 deletion profiler/gaussianhmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def random_mvn(batch_shape, dim, requires_grad=False):

def main(args):
if args.cuda:
torch.set_default_tensor_type("torch.cuda.FloatTensor")
torch.set_default_device("cuda")

hidden_dim = args.hidden_dim
obs_dim = args.obs_dim
Expand Down
2 changes: 1 addition & 1 deletion pyro/contrib/gp/parameterized.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class Parameterized(PyroModule):
>>> assert "b_scale_unconstrained" in dict(linear.named_parameters())

Note that by default, data of a parameter is a float :class:`torch.Tensor`
(unless we use :func:`torch.set_default_tensor_type` to change default
(unless we use :func:`torch.set_default_dtype` to change default
tensor type). To cast these parameters to a correct data type or GPU device,
we can call methods such as :meth:`~torch.nn.Module.double` or
:meth:`~torch.nn.Module.cuda`. See :class:`torch.nn.Module` for more
Expand Down
6 changes: 4 additions & 2 deletions pyro/infer/mcmc/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,15 @@ def __init__(
self.rng_seed = (torch.initial_seed() + chain_id) % MAX_SEED
self.log_queue = log_queue
self.result_queue = result_queue
self.default_tensor_type = torch.Tensor().type()
self.default_dtype = torch.Tensor().dtype
self.default_device = torch.Tensor().device
self.hook = hook
self.event = event

def run(self, *args, **kwargs):
pyro.set_rng_seed(self.rng_seed)
torch.set_default_tensor_type(self.default_tensor_type)
torch.set_default_dtype(self.default_dtype)
torch.set_default_device(self.default_device)
kwargs = kwargs
logger = logging.getLogger("pyro.infer.mcmc")
logger_id = "CHAIN:{}".format(self.chain_id)
Expand Down
10 changes: 0 additions & 10 deletions pyro/ops/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,6 @@ def _track_provenance_set(x, provenance: frozenset):
@track_provenance.register(tuple)
@track_provenance.register(dict)
def _track_provenance_pytree(x, provenance: frozenset):
# avoid max-recursion depth error for torch<=2.0
flat_args, _ = tree_flatten(x)
if not flat_args or flat_args[0] is x:
return x

return tree_map(partial(track_provenance, provenance=provenance), x)


Expand Down Expand Up @@ -143,11 +138,6 @@ def _extract_provenance_set(x):
@extract_provenance.register(tuple)
@extract_provenance.register(dict)
def _extract_provenance_pytree(x):
# avoid max-recursion depth error for torch<=2.0
flat_args, _ = tree_flatten(x)
if not flat_args or flat_args[0] is x:
return x, frozenset()

flat_args, spec = tree_flatten(x)
xs = []
provenance = frozenset()
Expand Down
6 changes: 1 addition & 5 deletions pyro/optim/pytorch_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,7 @@
del _PyroOptim

# Load all schedulers from PyTorch
# breaking change in torch >= 1.14: LRScheduler is new base class
if hasattr(torch.optim.lr_scheduler, "LRScheduler"):
_torch_scheduler_base = torch.optim.lr_scheduler.LRScheduler # type: ignore
else: # for torch < 1.13, _LRScheduler is base class
_torch_scheduler_base = torch.optim.lr_scheduler._LRScheduler # type: ignore
_torch_scheduler_base = torch.optim.lr_scheduler.LRScheduler # type: ignore

for _name, _Optim in torch.optim.lr_scheduler.__dict__.items():
if not isinstance(_Optim, type):
Expand Down
8 changes: 4 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@
"jupyter>=1.0.0",
"graphviz>=0.8",
"matplotlib>=1.3",
"torchvision>=0.12.0",
"torchvision>=0.15.0",
"visdom>=0.1.4,<0.2.2", # FIXME visdom.utils is unavailable >=0.2.2
"pandas",
"pillow==8.2.0", # https://github.com/pytorch/pytorch/issues/61125
"pillow>=8.3.1", # https://github.com/pytorch/pytorch/issues/61125
"scikit-learn",
"seaborn>=0.11.0",
"wget",
Expand Down Expand Up @@ -102,7 +102,7 @@
"numpy>=1.7",
"opt_einsum>=2.3.2",
"pyro-api>=0.1.1",
"torch>=1.11.0",
"torch>=2.0",
"tqdm>=4.36",
],
extras_require={
Expand Down Expand Up @@ -135,7 +135,7 @@
"yapf",
],
"horovod": ["horovod[pytorch]>=0.19"],
"lightning": ["pytorch_lightning"],
"lightning": ["lightning"],
"funsor": [
# This must be a released version when Pyro is released.
# "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@7bb52d0eae3046d08a20d1b288544e1a21b4f461",
Expand Down
23 changes: 3 additions & 20 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ def wrapper(*args, **kwargs):
)

try:
import pytorch_lightning
import lightning
except ImportError:
pytorch_lightning = None
lightning = None
requires_lightning = pytest.mark.skipif(
pytorch_lightning is None, reason="pytorch lightning is not available"
lightning is None, reason="pytorch lightning is not available"
)

try:
Expand All @@ -93,23 +93,6 @@ def get_gpu_type(t):
return getattr(torch.cuda, t.__name__)


@contextlib.contextmanager
def tensors_default_to(host):
"""
Context manager to temporarily use Cpu or Cuda tensors in PyTorch.

:param str host: Either "cuda" or "cpu".
"""
assert host in ("cpu", "cuda"), host
old_module, name = torch.Tensor().type().rsplit(".", 1)
new_module = "torch.cuda" if host == "cuda" else "torch"
torch.set_default_tensor_type("{}.{}".format(new_module, name))
try:
yield
finally:
torch.set_default_tensor_type("{}.{}".format(old_module, name))


@contextlib.contextmanager
def default_dtype(dtype):
"""
Expand Down
4 changes: 3 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

import pyro

torch.set_default_tensor_type(os.environ.get("PYRO_TENSOR_TYPE", "torch.DoubleTensor"))
DTYPE = getattr(torch, os.environ.get("PYRO_DTYPE", "float64"))
torch.set_default_dtype(DTYPE)
torch.set_default_device(os.environ.get("PYRO_DEVICE", "cpu"))


def pytest_configure(config):
Expand Down
2 changes: 1 addition & 1 deletion tests/contrib/timeseries/test_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
)
@pytest.mark.parametrize("T", [11, 37])
def test_timeseries_models(model, nu_statedim, obs_dim, T):
torch.set_default_tensor_type("torch.DoubleTensor")
torch.set_default_dtype(torch.float64)
dt = 0.1 + torch.rand(1).item()

if model == "lcmgp":
Expand Down
2 changes: 1 addition & 1 deletion tests/contrib/timeseries/test_lgssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
@pytest.mark.parametrize("obs_dim", [2, 4])
@pytest.mark.parametrize("T", [11, 17])
def test_generic_lgssm_forecast(model_class, state_dim, obs_dim, T):
torch.set_default_tensor_type("torch.DoubleTensor")
torch.set_default_dtype(torch.float64)

if model_class == "lgssm":
model = GenericLGSSM(
Expand Down
Loading
Loading