Skip to content

Commit

Permalink
Merge branch 'dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo committed Jun 2, 2024
2 parents e23602b + 64e71ee commit 45341b1
Show file tree
Hide file tree
Showing 109 changed files with 2,398 additions and 469 deletions.
18 changes: 9 additions & 9 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ on:
branches: [dev, master]

env:
CXX: g++-8
CC: gcc-8
CXX: g++-9
CC: gcc-9
# See coveralls-python - Github Actions support:
# https://github.com/TheKevJames/coveralls-python/blob/master/docs/usage/configuration.rst#github-actions-support
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
Expand Down Expand Up @@ -50,7 +50,7 @@ jobs:
run: |
sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
sudo apt-get update
sudo apt-get install gcc-8 g++-8 ninja-build graphviz
sudo apt-get install gcc-9 g++-9 ninja-build graphviz
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60'
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
Expand Down Expand Up @@ -78,7 +78,7 @@ jobs:
run: |
sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
sudo apt-get update
sudo apt-get install gcc-8 g++-8 ninja-build graphviz pandoc
sudo apt-get install gcc-9 g++-9 ninja-build graphviz pandoc
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60'
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
Expand Down Expand Up @@ -112,7 +112,7 @@ jobs:
run: |
sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
sudo apt-get update
sudo apt-get install gcc-8 g++-8 ninja-build
sudo apt-get install gcc-9 g++-9 ninja-build
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60'
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
Expand Down Expand Up @@ -146,7 +146,7 @@ jobs:
run: |
sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
sudo apt-get update
sudo apt-get install gcc-8 g++-8 ninja-build
sudo apt-get install gcc-9 g++-9 ninja-build
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60'
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
Expand Down Expand Up @@ -180,7 +180,7 @@ jobs:
run: |
sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
sudo apt-get update
sudo apt-get install gcc-8 g++-8 ninja-build
sudo apt-get install gcc-9 g++-9 ninja-build
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60'
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
Expand Down Expand Up @@ -212,7 +212,7 @@ jobs:
run: |
sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
sudo apt-get update
sudo apt-get install gcc-8 g++-8 ninja-build
sudo apt-get install gcc-9 g++-9 ninja-build
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60'
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
Expand Down Expand Up @@ -244,7 +244,7 @@ jobs:
run: |
sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
sudo apt-get update
sudo apt-get install gcc-8 g++-8 ninja-build
sudo apt-get install gcc-9 g++-9 ninja-build
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60'
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
all: docs test

install: FORCE
pip install -e .[dev,profile]
pip install -e .[dev,profile] --config-settings editable_mode=strict

uninstall: FORCE
pip uninstall pyro-ppl
Expand All @@ -21,7 +21,7 @@ lint: FORCE
ruff check .
black --check *.py pyro examples tests scripts profiler
python scripts/update_headers.py --check
mypy --install-types --non-interactive pyro scripts
mypy --install-types --non-interactive pyro scripts tests

license: FORCE
python scripts/update_headers.py
Expand Down
26 changes: 25 additions & 1 deletion docker/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ cmd?=bash
# Determine name of docker image
build run notebook: img_prefix=pyro-cpu
build-gpu run-gpu notebook-gpu: img_prefix=pyro-gpu
build run lab: img_prefix=pyro-cpu
build-gpu run-gpu lab-gpu: img_prefix=pyro-gpu

ifeq ($(img), )
IMG_NAME=${img_prefix}-${pyro_branch}-${python_version}
Expand Down Expand Up @@ -121,10 +123,32 @@ notebook: ##

notebook-gpu: create-host-workspace
notebook-gpu: ##
## Start a juptyer notebook on the Pyro GPU docker container.
## Start a jupyter notebook on the Pyro GPU docker container.
## Args:
## img: use image name given by `img`.
##
docker run --runtime=nvidia --init -it -p 8888:8888 --user ${USER} \
-v ${HOST_WORK_DIR}:${DOCKER_WORK_DIR} \
${IMG_NAME}

notebook: create-host-workspace
lab: ##
## Start jupyterlab on the Pyro CPU docker container.
## Args:
## img: use image name given by `img`.
##
docker run --init -it -p 8888:8888 --user ${USER} \
-v ${HOST_WORK_DIR}:${DOCKER_WORK_DIR} \
${IMG_NAME} jupyter lab --port=8888 --no-browser --ip=0.0.0.0

lab-gpu: create-host-workspace
lab-gpu: ##
## Start jupyterlab on the Pyro GPU docker container.
## Args:
## img: use image name given by `img`.
##
docker run --runtime=nvidia --init -it -p 8888:8888 --user ${USER} \
-v ${HOST_WORK_DIR}:${DOCKER_WORK_DIR} \
${IMG_NAME} jupyter lab --port=8888 --no-browser --ip=0.0.0.0


2 changes: 1 addition & 1 deletion docker/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
set -xe

pip install --upgrade pip
pip install jupyter matplotlib
pip install notebook ipywidgets matplotlib

# 1. Install PyTorch
# Use conda package if pytorch_branch = 'release'.
Expand Down
7 changes: 7 additions & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,13 @@ Stable
:undoc-members:
:show-inheritance:

StableWithLogProb
-----------------
.. autoclass:: pyro.distributions.StableWithLogProb
:members:
:undoc-members:
:show-inheritance:

TruncatedPolyaGamma
-------------------
.. autoclass:: pyro.distributions.TruncatedPolyaGamma
Expand Down
2 changes: 1 addition & 1 deletion examples/air/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def per_param_optim_args(param_name):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.9.0")
assert pyro.__version__.startswith("1.9.1")
parser = argparse.ArgumentParser(
description="Pyro AIR example", argument_default=argparse.SUPPRESS
)
Expand Down
2 changes: 1 addition & 1 deletion examples/baseball.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.9.0")
assert pyro.__version__.startswith("1.9.1")
parser = argparse.ArgumentParser(description="Baseball batting average using HMC")
parser.add_argument("-n", "--num-samples", nargs="?", default=200, type=int)
parser.add_argument("--num-chains", nargs="?", default=4, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/autoname/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.9.0")
assert pyro.__version__.startswith("1.9.1")
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument("-n", "--num-epochs", default=200, type=int)
parser.add_argument("--jit", action="store_true")
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/autoname/scoping_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.9.0")
assert pyro.__version__.startswith("1.9.1")
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument("-n", "--num-epochs", default=200, type=int)
args = parser.parse_args()
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/autoname/tree_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.9.0")
assert pyro.__version__.startswith("1.9.1")
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument("-n", "--num-epochs", default=100, type=int)
args = parser.parse_args()
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/cevae/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.9.0")
assert pyro.__version__.startswith("1.9.1")
parser = argparse.ArgumentParser(
description="Causal Effect Variational Autoencoder"
)
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/epidemiology/regional.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.9.0")
assert pyro.__version__.startswith("1.9.1")
parser = argparse.ArgumentParser(
description="Regional compartmental epidemiology modeling using HMC"
)
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/epidemiology/sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.9.0")
assert pyro.__version__.startswith("1.9.1")
parser = argparse.ArgumentParser(
description="Compartmental epidemiology modeling using HMC"
)
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/forecast/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def transform(pred, truth):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.9.0")
assert pyro.__version__.startswith("1.9.1")
parser = argparse.ArgumentParser(description="Bart Ridership Forecasting Example")
parser.add_argument("--train-window", default=2160, type=int)
parser.add_argument("--test-window", default=336, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/funsor/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.9.0")
assert pyro.__version__.startswith("1.9.1")
parser = argparse.ArgumentParser(
description="MAP Baum-Welch learning Bach Chorales"
)
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/gp/sv-dkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.9.0")
assert pyro.__version__.startswith("1.9.1")
parser = argparse.ArgumentParser(description="Pyro GP MNIST Example")
parser.add_argument(
"--data-dir",
Expand Down
4 changes: 1 addition & 3 deletions examples/contrib/mue/FactorMuE.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,7 @@ def main(args):
indices = torch.randperm(sum(data_lengths), device=device).tolist()
dataset_train, dataset_test = [
torch.utils.data.Subset(dataset, indices[(offset - length) : offset])
for offset, length in zip(
torch._utils._accumulate(data_lengths), data_lengths
)
for offset, length in zip(np.cumsum(data_lengths), data_lengths)
]
else:
dataset_train = dataset
Expand Down
4 changes: 1 addition & 3 deletions examples/contrib/mue/ProfileHMM.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,7 @@ def main(args):
indices = torch.randperm(sum(data_lengths), device=device).tolist()
dataset_train, dataset_test = [
torch.utils.data.Subset(dataset, indices[(offset - length) : offset])
for offset, length in zip(
torch._utils._accumulate(data_lengths), data_lengths
)
for offset, length in zip(np.cumsum(data_lengths), data_lengths)
]
else:
dataset_train = dataset
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/oed/ab_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def main(num_vi_steps, num_bo_steps, seed):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.9.0")
assert pyro.__version__.startswith("1.9.1")
parser = argparse.ArgumentParser(description="A/B test experiment design using VI")
parser.add_argument("-n", "--num-vi-steps", nargs="?", default=5000, type=int)
parser.add_argument("--num-bo-steps", nargs="?", default=5, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/timeseries/gp_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.9.0")
assert pyro.__version__.startswith("1.9.1")
parser = argparse.ArgumentParser(description="contrib.timeseries example usage")
parser.add_argument("-n", "--num-steps", default=300, type=int)
parser.add_argument("-s", "--seed", default=0, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/cvae/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.9.0")
assert pyro.__version__.startswith("1.9.1")
# parse command line arguments
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument(
Expand Down
2 changes: 1 addition & 1 deletion examples/dmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ def do_evaluation():

# parse command-line arguments and execute the main method
if __name__ == "__main__":
assert pyro.__version__.startswith("1.9.0")
assert pyro.__version__.startswith("1.9.1")

parser = argparse.ArgumentParser(description="parse args")
parser.add_argument("-n", "--num-epochs", type=int, default=5000)
Expand Down
2 changes: 1 addition & 1 deletion examples/eight_schools/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.9.0")
assert pyro.__version__.startswith("1.9.1")
parser = argparse.ArgumentParser(description="Eight Schools MCMC")
parser.add_argument(
"--num-samples",
Expand Down
2 changes: 1 addition & 1 deletion examples/eight_schools/svi.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.9.0")
assert pyro.__version__.startswith("1.9.1")
parser = argparse.ArgumentParser(description="Eight Schools SVI")
parser.add_argument(
"--lr", type=float, default=0.01, help="learning rate (default: 0.01)"
Expand Down
2 changes: 1 addition & 1 deletion examples/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.9.0")
assert pyro.__version__.startswith("1.9.1")
parser = argparse.ArgumentParser(
description="MAP Baum-Welch learning Bach Chorales"
)
Expand Down
2 changes: 1 addition & 1 deletion examples/inclined_plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.9.0")
assert pyro.__version__.startswith("1.9.1")
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument("-n", "--num-samples", default=500, type=int)
args = parser.parse_args()
Expand Down
2 changes: 1 addition & 1 deletion examples/lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.9.0")
assert pyro.__version__.startswith("1.9.1")
parser = argparse.ArgumentParser(
description="Amortized Latent Dirichlet Allocation"
)
Expand Down
2 changes: 1 addition & 1 deletion examples/lkj.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.9.0")
assert pyro.__version__.startswith("1.9.1")
parser = argparse.ArgumentParser(description="Demonstrate the use of an LKJ Prior")
parser.add_argument("--num-samples", nargs="?", default=200, type=int)
parser.add_argument("--n", nargs="?", default=500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/minipyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def guide(data):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.9.0")
assert pyro.__version__.startswith("1.9.1")
parser = argparse.ArgumentParser(description="Mini Pyro demo")
parser.add_argument("-b", "--backend", default="minipyro")
parser.add_argument("-n", "--num-steps", default=1001, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/neutra.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.9.0")
assert pyro.__version__.startswith("1.9.1")
parser = argparse.ArgumentParser(
description="Example illustrating NeuTra Reparametrizer"
)
Expand Down
2 changes: 1 addition & 1 deletion examples/rsa/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.9.0")
assert pyro.__version__.startswith("1.9.1")
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument("-n", "--num-samples", default=10, type=int)
args = parser.parse_args()
Expand Down
Loading

0 comments on commit 45341b1

Please sign in to comment.