diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e9e21f0967..1bbf4c35d2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -29,7 +29,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - python -m pip install --upgrade pip wheel 'setuptools!=58.5.*' + python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60' pip install flake8 black isort>=5.0 mypy nbstripout nbformat - name: Lint run: | @@ -51,7 +51,7 @@ jobs: sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test sudo apt-get update sudo apt-get install gcc-8 g++-8 ninja-build graphviz - python -m pip install --upgrade pip wheel 'setuptools!=58.5.*' + python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60' # Keep track of pyro-api master branch pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html @@ -79,7 +79,7 @@ jobs: sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test sudo apt-get update sudo apt-get install gcc-8 g++-8 ninja-build - python -m pip install --upgrade pip wheel 'setuptools!=58.5.*' + python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60' # Keep track of pyro-api master branch pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html @@ -113,7 +113,7 @@ jobs: sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test sudo apt-get update sudo apt-get install gcc-8 g++-8 ninja-build - python -m pip install --upgrade pip wheel 'setuptools!=58.5.*' + python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60' # Keep track of pyro-api master branch pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html @@ -147,7 +147,7 @@ jobs: 
sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test sudo apt-get update sudo apt-get install gcc-8 g++-8 ninja-build - python -m pip install --upgrade pip wheel 'setuptools!=58.5.*' + python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60' # Keep track of pyro-api master branch pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html @@ -179,7 +179,7 @@ jobs: sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test sudo apt-get update sudo apt-get install gcc-8 g++-8 ninja-build - python -m pip install --upgrade pip wheel 'setuptools!=58.5.*' + python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60' # Keep track of pyro-api master branch pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html @@ -211,7 +211,7 @@ jobs: sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test sudo apt-get update sudo apt-get install gcc-8 g++-8 ninja-build - python -m pip install --upgrade pip wheel 'setuptools!=58.5.*' + python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60' # Keep track of pyro-api master branch pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html diff --git a/docs/requirements.txt b/docs/requirements.txt index da561a25c1..4af4a755cf 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -7,3 +7,4 @@ opt_einsum>=2.3.2 pyro-api>=0.1.1 tqdm>=4.36 funsor[torch] +setuptools<60 diff --git a/docs/source/index.rst b/docs/source/index.rst index c5db42fdc6..82b70e684f 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -22,6 +22,7 @@ Pyro Documentation optimization poutine ops + settings testing .. 
toctree:: diff --git a/docs/source/settings.rst b/docs/source/settings.rst new file mode 100644 index 0000000000..e2cd61fd9d --- /dev/null +++ b/docs/source/settings.rst @@ -0,0 +1,6 @@ +Settings +-------- + +.. automodule:: pyro.settings + :members: + :member-order: bysource diff --git a/examples/air/main.py b/examples/air/main.py index 3725bd4aea..c37a8b1c38 100644 --- a/examples/air/main.py +++ b/examples/air/main.py @@ -272,7 +272,7 @@ def per_param_optim_args(param_name): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser( description="Pyro AIR example", argument_default=argparse.SUPPRESS ) diff --git a/examples/baseball.py b/examples/baseball.py index 10a6806b8d..91c986b85c 100644 --- a/examples/baseball.py +++ b/examples/baseball.py @@ -392,7 +392,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser(description="Baseball batting average using HMC") parser.add_argument("-n", "--num-samples", nargs="?", default=200, type=int) parser.add_argument("--num-chains", nargs="?", default=4, type=int) diff --git a/examples/contrib/autoname/mixture.py b/examples/contrib/autoname/mixture.py index 96c68ee44d..918aa8eb60 100644 --- a/examples/contrib/autoname/mixture.py +++ b/examples/contrib/autoname/mixture.py @@ -74,7 +74,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser(description="parse args") parser.add_argument("-n", "--num-epochs", default=200, type=int) parser.add_argument("--jit", action="store_true") diff --git a/examples/contrib/autoname/scoping_mixture.py b/examples/contrib/autoname/scoping_mixture.py index 717c4f4b51..7c70b96f60 100644 --- a/examples/contrib/autoname/scoping_mixture.py +++ 
b/examples/contrib/autoname/scoping_mixture.py @@ -71,7 +71,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser(description="parse args") parser.add_argument("-n", "--num-epochs", default=200, type=int) args = parser.parse_args() diff --git a/examples/contrib/autoname/tree_data.py b/examples/contrib/autoname/tree_data.py index 1a0f4440f8..7ad85db42f 100644 --- a/examples/contrib/autoname/tree_data.py +++ b/examples/contrib/autoname/tree_data.py @@ -104,7 +104,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser(description="parse args") parser.add_argument("-n", "--num-epochs", default=100, type=int) args = parser.parse_args() diff --git a/examples/contrib/cevae/synthetic.py b/examples/contrib/cevae/synthetic.py index 919ebd61e3..fa4848a828 100644 --- a/examples/contrib/cevae/synthetic.py +++ b/examples/contrib/cevae/synthetic.py @@ -86,7 +86,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser( description="Causal Effect Variational Autoencoder" ) diff --git a/examples/contrib/epidemiology/regional.py b/examples/contrib/epidemiology/regional.py index 02205028c9..873419a46b 100644 --- a/examples/contrib/epidemiology/regional.py +++ b/examples/contrib/epidemiology/regional.py @@ -166,7 +166,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser( description="Regional compartmental epidemiology modeling using HMC" ) diff --git a/examples/contrib/epidemiology/sir.py b/examples/contrib/epidemiology/sir.py index 7cc817a811..223e288b3c 100644 --- a/examples/contrib/epidemiology/sir.py +++ 
b/examples/contrib/epidemiology/sir.py @@ -334,7 +334,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser( description="Compartmental epidemiology modeling using HMC" ) diff --git a/examples/contrib/forecast/bart.py b/examples/contrib/forecast/bart.py index a75c26de8d..3cf44fd621 100644 --- a/examples/contrib/forecast/bart.py +++ b/examples/contrib/forecast/bart.py @@ -31,6 +31,7 @@ def preprocess(args): arrivals = dataset["counts"][:, :, i].sum(-1) departures = dataset["counts"][:, i, :].sum(-1) data = torch.stack([arrivals, departures], dim=-1) + print(f"Loaded data of shape {tuple(data.shape)}") # This simple example uses no covariates, so we will construct a # zero-element tensor of the correct length as empty covariates. @@ -165,7 +166,7 @@ def transform(pred, truth): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser(description="Bart Ridership Forecasting Example") parser.add_argument("--train-window", default=2160, type=int) parser.add_argument("--test-window", default=336, type=int) diff --git a/examples/contrib/funsor/hmm.py b/examples/contrib/funsor/hmm.py index dfe930e056..55628fef4d 100644 --- a/examples/contrib/funsor/hmm.py +++ b/examples/contrib/funsor/hmm.py @@ -820,7 +820,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser( description="MAP Baum-Welch learning Bach Chorales" ) diff --git a/examples/contrib/gp/sv-dkl.py b/examples/contrib/gp/sv-dkl.py index ea49069981..a16e06631a 100644 --- a/examples/contrib/gp/sv-dkl.py +++ b/examples/contrib/gp/sv-dkl.py @@ -193,7 +193,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") 
parser = argparse.ArgumentParser(description="Pyro GP MNIST Example") parser.add_argument( "--data-dir", diff --git a/examples/contrib/oed/ab_test.py b/examples/contrib/oed/ab_test.py index 32ad730987..6c311ed5b2 100644 --- a/examples/contrib/oed/ab_test.py +++ b/examples/contrib/oed/ab_test.py @@ -125,7 +125,7 @@ def main(num_vi_steps, num_bo_steps, seed): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser(description="A/B test experiment design using VI") parser.add_argument("-n", "--num-vi-steps", nargs="?", default=5000, type=int) parser.add_argument("--num-bo-steps", nargs="?", default=5, type=int) diff --git a/examples/contrib/timeseries/gp_models.py b/examples/contrib/timeseries/gp_models.py index 4d95613b8f..7bd7ffce4b 100644 --- a/examples/contrib/timeseries/gp_models.py +++ b/examples/contrib/timeseries/gp_models.py @@ -186,7 +186,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser(description="contrib.timeseries example usage") parser.add_argument("-n", "--num-steps", default=300, type=int) parser.add_argument("-s", "--seed", default=0, type=int) diff --git a/examples/cvae/main.py b/examples/cvae/main.py index 18cdcaa959..bdd437d9e8 100644 --- a/examples/cvae/main.py +++ b/examples/cvae/main.py @@ -87,7 +87,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") # parse command line arguments parser = argparse.ArgumentParser(description="parse args") parser.add_argument( diff --git a/examples/dmm.py b/examples/dmm.py index 9731530849..5768acf33b 100644 --- a/examples/dmm.py +++ b/examples/dmm.py @@ -571,7 +571,7 @@ def do_evaluation(): # parse command-line arguments and execute the main method if __name__ == "__main__": - assert 
pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser(description="parse args") parser.add_argument("-n", "--num-epochs", type=int, default=5000) diff --git a/examples/eight_schools/mcmc.py b/examples/eight_schools/mcmc.py index 0f5e61159a..c3792d476f 100644 --- a/examples/eight_schools/mcmc.py +++ b/examples/eight_schools/mcmc.py @@ -43,7 +43,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser(description="Eight Schools MCMC") parser.add_argument( "--num-samples", diff --git a/examples/eight_schools/svi.py b/examples/eight_schools/svi.py index ace2c24a52..22e96b6ef3 100644 --- a/examples/eight_schools/svi.py +++ b/examples/eight_schools/svi.py @@ -81,7 +81,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser(description="Eight Schools SVI") parser.add_argument( "--lr", type=float, default=0.01, help="learning rate (default: 0.01)" diff --git a/examples/hmm.py b/examples/hmm.py index 0881cb2c91..7d1b13d259 100644 --- a/examples/hmm.py +++ b/examples/hmm.py @@ -734,7 +734,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser( description="MAP Baum-Welch learning Bach Chorales" ) diff --git a/examples/inclined_plane.py b/examples/inclined_plane.py index 11f66dbe94..a7dbffdbfd 100644 --- a/examples/inclined_plane.py +++ b/examples/inclined_plane.py @@ -145,7 +145,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser(description="parse args") parser.add_argument("-n", "--num-samples", default=500, type=int) args = parser.parse_args() diff 
--git a/examples/lda.py b/examples/lda.py index 3796757d69..de5058c1f3 100644 --- a/examples/lda.py +++ b/examples/lda.py @@ -149,7 +149,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser( description="Amortized Latent Dirichlet Allocation" ) diff --git a/examples/lkj.py b/examples/lkj.py index 04b4700089..d343e75c51 100644 --- a/examples/lkj.py +++ b/examples/lkj.py @@ -56,7 +56,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser(description="Demonstrate the use of an LKJ Prior") parser.add_argument("--num-samples", nargs="?", default=200, type=int) parser.add_argument("--n", nargs="?", default=500, type=int) diff --git a/examples/minipyro.py b/examples/minipyro.py index 524ff96130..49c7d7c017 100644 --- a/examples/minipyro.py +++ b/examples/minipyro.py @@ -65,7 +65,7 @@ def guide(data): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser(description="Mini Pyro demo") parser.add_argument("-b", "--backend", default="minipyro") parser.add_argument("-n", "--num-steps", default=1001, type=int) diff --git a/examples/neutra.py b/examples/neutra.py index 4eaa8e78bc..3304dbfa58 100644 --- a/examples/neutra.py +++ b/examples/neutra.py @@ -232,7 +232,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser( description="Example illustrating NeuTra Reparametrizer" ) diff --git a/examples/rsa/generics.py b/examples/rsa/generics.py index c443d01ac5..350c5fefc1 100644 --- a/examples/rsa/generics.py +++ b/examples/rsa/generics.py @@ -177,7 +177,7 @@ def main(args): if __name__ == "__main__": - assert 
pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser(description="parse args") parser.add_argument("-n", "--num-samples", default=10, type=int) args = parser.parse_args() diff --git a/examples/rsa/hyperbole.py b/examples/rsa/hyperbole.py index 859e135929..e9bb94accd 100644 --- a/examples/rsa/hyperbole.py +++ b/examples/rsa/hyperbole.py @@ -217,7 +217,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser(description="parse args") parser.add_argument("-n", "--num-samples", default=10, type=int) parser.add_argument("--price", default=10000, type=int) diff --git a/examples/rsa/schelling.py b/examples/rsa/schelling.py index 37a44fc8e3..e8d1317ca1 100644 --- a/examples/rsa/schelling.py +++ b/examples/rsa/schelling.py @@ -79,7 +79,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser(description="parse args") parser.add_argument("-n", "--num-samples", default=10, type=int) parser.add_argument("--depth", default=2, type=int) diff --git a/examples/rsa/schelling_false.py b/examples/rsa/schelling_false.py index e43c6b1b23..1fd4d0c297 100644 --- a/examples/rsa/schelling_false.py +++ b/examples/rsa/schelling_false.py @@ -96,7 +96,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser(description="parse args") parser.add_argument("-n", "--num-samples", default=10, type=int) parser.add_argument("--depth", default=3, type=int) diff --git a/examples/rsa/semantic_parsing.py b/examples/rsa/semantic_parsing.py index 4c4e7dfadb..d63c2b5a53 100644 --- a/examples/rsa/semantic_parsing.py +++ b/examples/rsa/semantic_parsing.py @@ -351,7 +351,7 @@ def is_all_qud(world): if 
__name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser(description="parse args") parser.add_argument("-n", "--num-samples", default=10, type=int) args = parser.parse_args() diff --git a/examples/scanvi/scanvi.py b/examples/scanvi/scanvi.py index 6679f49d27..a7858e0702 100644 --- a/examples/scanvi/scanvi.py +++ b/examples/scanvi/scanvi.py @@ -407,7 +407,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") # Parse command line arguments parser = argparse.ArgumentParser( description="single-cell ANnotation using Variational Inference" diff --git a/examples/sir_hmc.py b/examples/sir_hmc.py index 11851ec784..8e587b24ba 100644 --- a/examples/sir_hmc.py +++ b/examples/sir_hmc.py @@ -633,7 +633,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser(description="SIR epidemiology modeling using HMC") parser.add_argument("-p", "--population", default=10, type=int) parser.add_argument("-m", "--min-observations", default=3, type=int) diff --git a/examples/sparse_gamma_def.py b/examples/sparse_gamma_def.py index dd9c935e3e..fd9bed8cc4 100644 --- a/examples/sparse_gamma_def.py +++ b/examples/sparse_gamma_def.py @@ -269,7 +269,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") # parse command line arguments parser = argparse.ArgumentParser(description="parse args") parser.add_argument( diff --git a/examples/sparse_regression.py b/examples/sparse_regression.py index 0d5d2b0094..83d7e723e8 100644 --- a/examples/sparse_regression.py +++ b/examples/sparse_regression.py @@ -364,7 +364,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert 
pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser(description="Krylov KIT") parser.add_argument("--num-data", type=int, default=750) parser.add_argument("--num-steps", type=int, default=1000) diff --git a/examples/svi_horovod.py b/examples/svi_horovod.py index 5a05955a7d..1c55e78a2e 100644 --- a/examples/svi_horovod.py +++ b/examples/svi_horovod.py @@ -154,7 +154,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser(description="Distributed training via Horovod") parser.add_argument("-o", "--outfile") parser.add_argument("-s", "--size", default=1000000, type=int) diff --git a/examples/toy_mixture_model_discrete_enumeration.py b/examples/toy_mixture_model_discrete_enumeration.py index 26ac1452fb..eed2b8126d 100644 --- a/examples/toy_mixture_model_discrete_enumeration.py +++ b/examples/toy_mixture_model_discrete_enumeration.py @@ -133,7 +133,7 @@ def get_true_pred_CPDs(CPD, posterior_param): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser(description="Toy mixture model") parser.add_argument("-n", "--num-steps", default=4000, type=int) parser.add_argument("-o", "--num-obs", default=10000, type=int) diff --git a/examples/vae/ss_vae_M2.py b/examples/vae/ss_vae_M2.py index 82ba77fab4..e357751910 100644 --- a/examples/vae/ss_vae_M2.py +++ b/examples/vae/ss_vae_M2.py @@ -433,7 +433,7 @@ def main(args): ) if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser(description="SS-VAE\n{}".format(EXAMPLE_RUN)) diff --git a/examples/vae/vae.py b/examples/vae/vae.py index fce421b113..ae3b5d0561 100644 --- a/examples/vae/vae.py +++ b/examples/vae/vae.py @@ -216,7 +216,7 @@ def main(args): if __name__ == "__main__": - assert 
pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") # parse command line arguments parser = argparse.ArgumentParser(description="parse args") parser.add_argument( diff --git a/examples/vae/vae_comparison.py b/examples/vae/vae_comparison.py index 6c0333b70e..3399cf6af5 100644 --- a/examples/vae/vae_comparison.py +++ b/examples/vae/vae_comparison.py @@ -262,7 +262,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith("1.8.2") + assert pyro.__version__.startswith("1.8.3") parser = argparse.ArgumentParser(description="VAE using MNIST dataset") parser.add_argument("-n", "--num-epochs", nargs="?", default=10, type=int) parser.add_argument("--batch_size", nargs="?", default=128, type=int) diff --git a/profiler/gaussianhmm.py b/profiler/gaussianhmm.py new file mode 100644 index 0000000000..898b37f4e0 --- /dev/null +++ b/profiler/gaussianhmm.py @@ -0,0 +1,83 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import argparse + +import torch +from tqdm.auto import tqdm + +import pyro.distributions as dist + + +def random_mvn(batch_shape, dim, requires_grad=False): + rank = dim + dim + loc = torch.randn(batch_shape + (dim,), requires_grad=requires_grad) + cov = torch.randn(batch_shape + (dim, rank)) + cov = cov.matmul(cov.transpose(-1, -2)) + scale_tril = torch.linalg.cholesky(cov) + scale_tril.requires_grad_(requires_grad) + return dist.MultivariateNormal(loc, scale_tril=scale_tril) + + +def main(args): + if args.cuda: + torch.set_default_tensor_type("torch.cuda.FloatTensor") + + hidden_dim = args.hidden_dim + obs_dim = args.obs_dim + duration = args.duration + batch_shape = (args.batch_size,) + + # Initialize parts. 
+ init_dist = random_mvn(batch_shape, hidden_dim, requires_grad=args.grad) + trans_dist = random_mvn( + batch_shape + (duration,), hidden_dim, requires_grad=args.grad + ) + obs_dist = random_mvn(batch_shape + (1,), obs_dim, requires_grad=args.grad) + trans_mat = 0.1 * torch.randn(batch_shape + (duration, hidden_dim, hidden_dim)) + obs_mat = torch.randn(batch_shape + (1, hidden_dim, obs_dim)) + + if args.grad: + # Collect parameters. + params = [ + init_dist.loc, + init_dist.scale_tril, + trans_dist.loc, + trans_dist.scale_tril, + obs_dist.loc, + obs_dist.scale_tril, + trans_mat.requires_grad_(), + obs_mat.requires_grad_(), + ] + + # Build a distribution. + d = dist.GaussianHMM( + init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration + ) + + for step in tqdm(range(args.num_steps)): + if not args.grad: + # Time forward only. + d.sample() + continue + + # Time forward + backward. + x = d.rsample() + grads = torch.autograd.grad( + x.sum(), params, allow_unused=True, retain_graph=True + ) + assert not all(g is None for g in grads) + del x + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="GaussianHMM profiler") + parser.add_argument("--hidden-dim", type=int, default=4) + parser.add_argument("--obs-dim", type=int, default=4) + parser.add_argument("--duration", type=int, default=10000) + parser.add_argument("--batch-size", type=int, default=3) + parser.add_argument("-n", "--num-steps", type=int, default=100) + parser.add_argument("--cuda", action="store_true", default=False) + parser.add_argument("--grad", action="store_true", default=False) + args = parser.parse_args() + main(args) diff --git a/pyro/__init__.py b/pyro/__init__.py index c8c07e6eb2..25a1db0150 100644 --- a/pyro/__init__.py +++ b/pyro/__init__.py @@ -25,8 +25,10 @@ ) from pyro.util import set_rng_seed +from . 
import settings + # After changing this, run scripts/update_version.py -version_prefix = "1.8.2" +version_prefix = "1.8.3" # Get the __version__ string from the auto-generated _version.py file, if exists. try: @@ -41,6 +43,7 @@ "condition", "deterministic", "do", + "enable_module_local_param", "enable_validation", "factor", "get_param_store", @@ -49,6 +52,7 @@ "log", "markov", "module", + "module_local_param_enabled", "param", "plate", "plate", @@ -58,6 +62,7 @@ "render_model", "sample", "set_rng_seed", + "settings", "subsample", "validation_enabled", ] diff --git a/pyro/distributions/hmm.py b/pyro/distributions/hmm.py index 5524766414..2112d9979b 100644 --- a/pyro/distributions/hmm.py +++ b/pyro/distributions/hmm.py @@ -15,10 +15,12 @@ gaussian_tensordot, matrix_and_mvn_to_gaussian, mvn_to_gaussian, + sequential_gaussian_filter_sample, + sequential_gaussian_tensordot, ) from pyro.ops.indexing import Vindex from pyro.ops.special import safe_log -from pyro.ops.tensor_utils import cholesky, cholesky_solve +from pyro.ops.tensor_utils import cholesky_solve, safe_cholesky from . import constraints from .torch import Categorical, Gamma, Independent, MultivariateNormal @@ -159,115 +161,6 @@ def _sequential_index(samples): return samples.squeeze(-3)[..., :duration, :] -def _sequential_gaussian_tensordot(gaussian): - """ - Integrates a Gaussian ``x`` whose rightmost batch dimension is time, computes:: - - x[..., 0] @ x[..., 1] @ ... 
@ x[..., T-1] - """ - assert isinstance(gaussian, Gaussian) - assert gaussian.dim() % 2 == 0, "dim is not even" - batch_shape = gaussian.batch_shape[:-1] - state_dim = gaussian.dim() // 2 - while gaussian.batch_shape[-1] > 1: - time = gaussian.batch_shape[-1] - even_time = time // 2 * 2 - even_part = gaussian[..., :even_time] - x_y = even_part.reshape(batch_shape + (even_time // 2, 2)) - x, y = x_y[..., 0], x_y[..., 1] - contracted = gaussian_tensordot(x, y, state_dim) - if time > even_time: - contracted = Gaussian.cat((contracted, gaussian[..., -1:]), dim=-1) - gaussian = contracted - return gaussian[..., 0] - - -def _is_subshape(x, y): - return broadcast_shape(x, y) == y - - -def _sequential_gaussian_filter_sample(init, trans, sample_shape): - """ - Draws a reparameterized sample from a Markov product of Gaussians via - parallel-scan forward-filter backward-sample. - """ - assert isinstance(init, Gaussian) - assert isinstance(trans, Gaussian) - assert trans.dim() == 2 * init.dim() - assert _is_subshape(trans.batch_shape[:-1], init.batch_shape) - state_dim = trans.dim() // 2 - device = trans.precision.device - perm = torch.cat( - [ - torch.arange(1 * state_dim, 2 * state_dim, device=device), - torch.arange(0 * state_dim, 1 * state_dim, device=device), - torch.arange(2 * state_dim, 3 * state_dim, device=device), - ] - ) - - # Forward filter, similar to _sequential_gaussian_tensordot(). - tape = [] - shape = trans.batch_shape[:-1] # Note trans may be unbroadcasted. 
- gaussian = trans - while gaussian.batch_shape[-1] > 1: - time = gaussian.batch_shape[-1] - even_time = time // 2 * 2 - even_part = gaussian[..., :even_time] - x_y = even_part.reshape(shape + (even_time // 2, 2)) - x, y = x_y[..., 0], x_y[..., 1] - x = x.event_pad(right=state_dim) - y = y.event_pad(left=state_dim) - joint = (x + y).event_permute(perm) - tape.append(joint) - contracted = joint.marginalize(left=state_dim) - if time > even_time: - contracted = Gaussian.cat((contracted, gaussian[..., -1:]), dim=-1) - gaussian = contracted - gaussian = gaussian[..., 0] + init.event_pad(right=state_dim) - - # Backward sample. - shape = sample_shape + init.batch_shape - result = gaussian.rsample(sample_shape).reshape(shape + (2, state_dim)) - for joint in reversed(tape): - # The following comments demonstrate two example computations, one - # EVEN, one ODD. Ignoring sample_shape and batch_shape, let each zn be - # a single sampled event of shape (state_dim,). - if joint.batch_shape[-1] == result.size(-2) - 1: # EVEN case. - # Suppose e.g. result = [z0, z2, z4] - cond = result.repeat_interleave(2, dim=-2) # [z0, z0, z2, z2, z4, z4] - cond = cond[..., 1:-1, :] # [z0, z2, z2, z4] - cond = cond.reshape(shape + (-1, 2 * state_dim)) # [z0z2, z2z4] - sample = joint.condition(cond).rsample() # [z1, z3] - sample = torch.nn.functional.pad(sample, (0, 0, 0, 1)) # [z1, z3, 0] - result = torch.stack( - [ - result, # [z0, z2, z4] - sample, # [z1, z3, 0] - ], - dim=-2, - ) # [[z0, z1], [z2, z3], [z4, 0]] - result = result.reshape(shape + (-1, state_dim)) # [z0, z1, z2, z3, z4, 0] - result = result[..., :-1, :] # [z0, z1, z2, z3, z4] - else: # ODD case. - assert joint.batch_shape[-1] == result.size(-2) - 2 - # Suppose e.g. 
result = [z0, z2, z3] - cond = result[..., :-1, :].repeat_interleave(2, dim=-2) # [z0, z0, z2, z2] - cond = cond[..., 1:-1, :] # [z0, z2] - cond = cond.reshape(shape + (-1, 2 * state_dim)) # [z0z2] - sample = joint.condition(cond).rsample() # [z1] - sample = torch.cat([sample, result[..., -1:, :]], dim=-2) # [z1, z3] - result = torch.stack( - [ - result[..., :-1, :], # [z0, z2] - sample, # [z1, z3] - ], - dim=-2, - ) # [[z0, z1], [z2, z3]] - result = result.reshape(shape + (-1, state_dim)) # [z0, z1, z2, z3] - - return result[..., 1:, :] # [z1, z2, z3, ...] - - def _sequential_gamma_gaussian_tensordot(gamma_gaussian): """ Integrates a GammaGaussian ``x`` whose rightmost batch dimension is time, computes:: @@ -657,9 +550,9 @@ def expand(self, batch_shape, _instance=None): new._obs = self._obs new._trans = self._trans - # To save computation in _sequential_gaussian_tensordot(), we expand + # To save computation in sequential_gaussian_tensordot(), we expand # only _init, which is applied only after - # _sequential_gaussian_tensordot(). + # sequential_gaussian_tensordot(). batch_shape = torch.Size(broadcast_shape(self.batch_shape, batch_shape)) new._init = self._init.expand(batch_shape) @@ -679,7 +572,7 @@ def log_prob(self, value): ) # Eliminate time dimension. - result = _sequential_gaussian_tensordot(result.expand(result.batch_shape)) + result = sequential_gaussian_tensordot(result.expand(result.batch_shape)) # Combine initial factor. 
result = gaussian_tensordot(self._init, result, dims=self.hidden_dim) @@ -695,7 +588,8 @@ def rsample(self, sample_shape=torch.Size()): left=self.hidden_dim ) trans = trans.expand(trans.batch_shape[:-1] + (self.duration,)) - z = _sequential_gaussian_filter_sample(self._init, trans, sample_shape) + z = sequential_gaussian_filter_sample(self._init, trans, sample_shape) + z = z[..., 1:, :] # drop the initial hidden state x = self._obs.left_condition(z).rsample() return x @@ -705,7 +599,8 @@ def rsample_posterior(self, value, sample_shape=torch.Size()): """ trans = self._trans + self._obs.condition(value).event_pad(left=self.hidden_dim) trans = trans.expand(trans.batch_shape) - z = _sequential_gaussian_filter_sample(self._init, trans, sample_shape) + z = sequential_gaussian_filter_sample(self._init, trans, sample_shape) + z = z[..., 1:, :] # drop the initial hidden state return z def filter(self, value): @@ -726,16 +621,16 @@ def filter(self, value): logp = self._trans + self._obs.condition(value).event_pad(left=self.hidden_dim) # Eliminate time dimension. - logp = _sequential_gaussian_tensordot(logp.expand(logp.batch_shape)) + logp = sequential_gaussian_tensordot(logp.expand(logp.batch_shape)) # Combine initial factor. 
logp = gaussian_tensordot(self._init, logp, dims=self.hidden_dim) # Convert to a distribution precision = logp.precision - loc = cholesky_solve(logp.info_vec.unsqueeze(-1), cholesky(precision)).squeeze( - -1 - ) + loc = cholesky_solve( + logp.info_vec.unsqueeze(-1), safe_cholesky(precision) + ).squeeze(-1) return MultivariateNormal( loc, precision_matrix=precision, validate_args=self._validate_args ) @@ -780,7 +675,7 @@ def conjugate_update(self, other): logp = new._trans + new._obs.marginalize(right=new.obs_dim).event_pad( left=new.hidden_dim ) - logp = _sequential_gaussian_tensordot(logp.expand(logp.batch_shape)) + logp = sequential_gaussian_tensordot(logp.expand(logp.batch_shape)) logp = gaussian_tensordot(new._init, logp, dims=new.hidden_dim) log_normalizer = logp.event_logsumexp() new._init = new._init - log_normalizer @@ -970,8 +865,8 @@ def expand(self, batch_shape, _instance=None): new.hidden_dim = self.hidden_dim new.obs_dim = self.obs_dim # We only need to expand one of the inputs, since batch_shape is determined - # by broadcasting all three. To save computation in _sequential_gaussian_tensordot(), - # we expand only _init, which is applied only after _sequential_gaussian_tensordot(). + # by broadcasting all three. To save computation in sequential_gaussian_tensordot(), + # we expand only _init, which is applied only after sequential_gaussian_tensordot(). 
new._init = self._init.expand(batch_shape) new._trans = self._trans new._obs = self._obs @@ -1033,7 +928,7 @@ def filter(self, value): gamma_dist.concentration, gamma_dist.rate, validate_args=self._validate_args ) # Conditional of last state on unit scale - scale_tril = cholesky(logp.precision) + scale_tril = safe_cholesky(logp.precision) loc = cholesky_solve(logp.info_vec.unsqueeze(-1), scale_tril).squeeze(-1) mvn = MultivariateNormal( loc, scale_tril=scale_tril, validate_args=self._validate_args @@ -1380,8 +1275,8 @@ def expand(self, batch_shape, _instance=None): new.hidden_dim = self.hidden_dim new.obs_dim = self.obs_dim # We only need to expand one of the inputs, since batch_shape is determined - # by broadcasting all three. To save computation in _sequential_gaussian_tensordot(), - # we expand only _init, which is applied only after _sequential_gaussian_tensordot(). + # by broadcasting all three. To save computation in sequential_gaussian_tensordot(), + # we expand only _init, which is applied only after sequential_gaussian_tensordot(). new._init = self._init.expand(batch_shape) new._trans = self._trans new._obs = self._obs @@ -1411,7 +1306,7 @@ def log_prob(self, value): logp = Gaussian.cat([logp_oh.expand(batch_shape), logp_h.expand(batch_shape)]) # Eliminate time dimension. - logp = _sequential_gaussian_tensordot(logp) + logp = sequential_gaussian_tensordot(logp) # Combine initial factor. logp = gaussian_tensordot(self._init, logp, dims=self.hidden_dim) diff --git a/pyro/distributions/torch.py b/pyro/distributions/torch.py index 9459803212..c6e8c7e1d2 100644 --- a/pyro/distributions/torch.py +++ b/pyro/distributions/torch.py @@ -9,6 +9,7 @@ from pyro.distributions.util import broadcast_shape, sum_rightmost from pyro.ops.special import log_binomial +from .. import settings from . 
import constraints @@ -98,6 +99,22 @@ def log_prob(self, value): ) +@settings.register( + "binomial_approx_sample_thresh", __name__, "Binomial.approx_sample_thresh" +) +def _validate_thresh(thresh): + assert isinstance(thresh, float) + assert 0 < thresh + + +@settings.register( + "binomial_approx_log_prob_tol", __name__, "Binomial.approx_log_prob_tol" +) +def _validate_tol(tol): + assert isinstance(tol, float) + assert 0 <= tol + + # This overloads .log_prob() and .enumerate_support() to speed up evaluating # log_prob on the support of this variable: we can completely avoid tensor ops # and merely reshape the self.logits tensor. This is especially important for diff --git a/pyro/distributions/transforms/cholesky.py b/pyro/distributions/transforms/cholesky.py index 3b890f5a3b..d2e4d22684 100644 --- a/pyro/distributions/transforms/cholesky.py +++ b/pyro/distributions/transforms/cholesky.py @@ -89,7 +89,7 @@ def log_abs_det_jacobian(self, x, y): class CholeskyTransform(Transform): r""" - Transform via the mapping :math:`y = cholesky(x)`, where `x` is a + Transform via the mapping :math:`y = safe_cholesky(x)`, where `x` is a positive definite matrix. """ bijective = True @@ -116,7 +116,7 @@ def log_abs_det_jacobian(self, x, y): class CorrMatrixCholeskyTransform(CholeskyTransform): r""" - Transform via the mapping :math:`y = cholesky(x)`, where `x` is a + Transform via the mapping :math:`y = safe_cholesky(x)`, where `x` is a correlation matrix. 
""" bijective = True diff --git a/pyro/distributions/transforms/generalized_channel_permute.py b/pyro/distributions/transforms/generalized_channel_permute.py index 0231c63a46..1d2e733823 100644 --- a/pyro/distributions/transforms/generalized_channel_permute.py +++ b/pyro/distributions/transforms/generalized_channel_permute.py @@ -174,7 +174,7 @@ def __init__(self, channels=3, permutation=None): W, _ = torch.linalg.qr(torch.randn(channels, channels)) # Construct the partially pivoted LU-form and the pivots - LU, pivots = W.lu() + LU, pivots = torch.linalg.lu_factor(W) # Convert the pivots into the permutation matrix if permutation is None: diff --git a/pyro/distributions/transforms/spline.py b/pyro/distributions/transforms/spline.py index ba7d240a1b..2f32f61aac 100644 --- a/pyro/distributions/transforms/spline.py +++ b/pyro/distributions/transforms/spline.py @@ -254,6 +254,8 @@ def _monotonic_rational_spline( c = -input_delta * (inputs - input_cumheights) discriminant = b.pow(2) - 4 * a * c + # Make sure outside_interval input can be reversed as identity. + discriminant = discriminant.masked_fill(outside_interval_mask, 0) assert (discriminant >= 0).all() root = (2 * c) / (-b - torch.sqrt(discriminant)) diff --git a/pyro/distributions/util.py b/pyro/distributions/util.py index f3ddbf51a2..ba59273069 100644 --- a/pyro/distributions/util.py +++ b/pyro/distributions/util.py @@ -15,9 +15,18 @@ from pyro.util import ignore_jit_warnings +from .. 
import settings + _VALIDATION_ENABLED = __debug__ torch_dist.Distribution.set_default_validate_args(__debug__) +settings.register("validate_distributions_pyro", __name__, "_VALIDATION_ENABLED") +settings.register( + "validate_distributions_torch", + "torch.distributions.distribution", + "Distribution._validate_args", +) + log_sum_exp = logsumexp # DEPRECATED diff --git a/pyro/infer/elbo.py b/pyro/infer/elbo.py index 3abe07d748..05ebd6f2a9 100644 --- a/pyro/infer/elbo.py +++ b/pyro/infer/elbo.py @@ -5,6 +5,8 @@ import warnings from abc import ABCMeta, abstractmethod +import torch + import pyro import pyro.poutine as poutine from pyro.infer.util import is_validation_enabled @@ -12,6 +14,17 @@ from pyro.util import check_site_shape +class ELBOModule(torch.nn.Module): + def __init__(self, model: torch.nn.Module, guide: torch.nn.Module, elbo: "ELBO"): + super().__init__() + self.model = model + self.guide = guide + self.elbo = elbo + + def forward(self, *args, **kwargs): + return self.elbo.differentiable_loss(self.model, self.guide, *args, **kwargs) + + class ELBO(object, metaclass=ABCMeta): """ :class:`ELBO` is the top-level interface for stochastic variational @@ -23,6 +36,40 @@ class ELBO(object, metaclass=ABCMeta): :class:`~pyro.infer.tracegraph_elbo.TraceGraph_ELBO`, or :class:`~pyro.infer.traceenum_elbo.TraceEnum_ELBO`. + .. 
note:: Derived classes now provide a more idiomatic PyTorch interface via + :meth:`__call__` for (model, guide) pairs that are :class:`~torch.nn.Module` s, + which is useful for integrating Pyro's variational inference tooling with + standard PyTorch interfaces like :class:`~torch.optim.Optimizer` s + and the large ecosystem of libraries like PyTorch Lightning + and the PyTorch JIT that work with these interfaces:: + + model = Model() + guide = pyro.infer.autoguide.AutoNormal(model) + + elbo_ = pyro.infer.Trace_ELBO(num_particles=10) + + # Fix the model/guide pair + elbo = elbo_(model, guide) + + # perform any data-dependent initialization + elbo(data) + + optim = torch.optim.Adam(elbo.parameters(), lr=0.001) + + for _ in range(100): + optim.zero_grad() + loss = elbo(data) + loss.backward() + optim.step() + + Note that Pyro's global parameter store may cause this new interface to + behave unexpectedly relative to standard PyTorch when working with + :class:`~pyro.nn.PyroModule` s. + + Users are therefore strongly encouraged to use this interface in conjunction + with :func:`~pyro.enable_module_local_param` which will override the default + implicit sharing of parameters across :class:`~pyro.nn.PyroModule` instances. + :param num_particles: The number of particles/samples used to form the ELBO (gradient) estimators. :param int max_plate_nesting: Optional bound on max number of nested @@ -86,6 +133,13 @@ def __init__( self.jit_options = jit_options self.tail_adaptive_beta = tail_adaptive_beta + def __call__(self, model: torch.nn.Module, guide: torch.nn.Module) -> ELBOModule: + """ + Given a model and guide, returns a :class:`~torch.nn.Module` which + computes the ELBO loss when called with arguments to the model and guide. 
+ """ + return ELBOModule(model, guide, self) + + def _guess_max_plate_nesting(self, model, guide, args, kwargs): """ Guesses max_plate_nesting by running the (model,guide) pair once diff --git a/pyro/infer/mcmc/hmc.py b/pyro/infer/mcmc/hmc.py index 6de9ddc51a..ff73dedde5 100644 --- a/pyro/infer/mcmc/hmc.py +++ b/pyro/infer/mcmc/hmc.py @@ -66,6 +66,8 @@ class HMC(MCMCKernel): step size, hence the sampling will be slower and more robust. Default to 0.8. :param callable init_strategy: A per-site initialization function. See :ref:`autoguide-initialization` section for available functions. + :param float min_stepsize: Lower bound on stepsize in adaptation strategy. + :param float max_stepsize: Upper bound on stepsize in adaptation strategy. .. note:: Internally, the mass matrix will be ordered according to the order of the names of latent variables, not the order of their appearance in @@ -108,6 +110,9 @@ def __init__( ignore_jit_warnings=False, target_accept_prob=0.8, init_strategy=init_to_uniform, + *, + min_stepsize: float = 1e-10, + max_stepsize: float = 1e10, ): if not ((model is None) ^ (potential_fn is None)): raise ValueError("Only one of `model` or `potential_fn` must be specified.") @@ -119,6 +124,8 @@ def __init__( self._jit_options = jit_options self._ignore_jit_warnings = ignore_jit_warnings self._init_strategy = init_strategy + self._min_stepsize = min_stepsize + self._max_stepsize = max_stepsize self.potential_fn = potential_fn if trajectory_length is not None: @@ -188,9 +195,11 @@ def _find_reasonable_step_size(self, z): step_size_scale = 2**direction direction_new = direction # keep scale step_size until accept_prob crosses its target - # TODO: make thresholds for too small step_size or too large step_size t = 0 - while direction_new == direction: + while ( + direction_new == direction + and self._min_stepsize < step_size < self._max_stepsize + ): t += 1 step_size = step_size_scale * step_size r, r_unscaled =
self._sample_r(name="r_presample_{}".format(t)) @@ -206,6 +215,8 @@ def _find_reasonable_step_size(self, z): energy_new = self._kinetic_energy(r_new_unscaled) + potential_energy_new delta_energy = energy_new - energy_current direction_new = 1 if self._direction_threshold < -delta_energy else -1 + step_size = max(step_size, self._min_stepsize) + step_size = min(step_size, self._max_stepsize) return step_size def _sample_r(self, name): diff --git a/pyro/infer/util.py b/pyro/infer/util.py index 3ee94b884d..7ea460c1ec 100644 --- a/pyro/infer/util.py +++ b/pyro/infer/util.py @@ -16,7 +16,11 @@ from pyro.ops.rings import MarginalRing from pyro.poutine.util import site_is_subsample +from .. import settings + _VALIDATION_ENABLED = __debug__ +settings.register("validate_infer", __name__, "_VALIDATION_ENABLED") + LAST_CACHE_SIZE = [Counter()] # for profiling diff --git a/pyro/nn/module.py b/pyro/nn/module.py index 7a87631e2e..8168461f0c 100644 --- a/pyro/nn/module.py +++ b/pyro/nn/module.py @@ -23,6 +23,17 @@ from pyro.ops.provenance import detach_provenance from pyro.poutine.runtime import _PYRO_PARAM_STORE +_MODULE_LOCAL_PARAMS: bool = False + + +@pyro.settings.register("module_local_params", __name__, "_MODULE_LOCAL_PARAMS") +def _validate_module_local_params(value: bool) -> None: + assert isinstance(value, bool) + + +def _is_module_local_param_enabled() -> bool: + return pyro.settings.get("module_local_params") + class PyroParam(namedtuple("PyroParam", ("init_value", "constraint", "event_dim"))): """ @@ -178,8 +189,13 @@ def __init__(self): self.active = 0 self.cache = {} self.used = False + if _is_module_local_param_enabled(): + self.param_state = {"params": {}, "constraints": {}} def __enter__(self): + if not self.active and _is_module_local_param_enabled(): + self._param_ctx = pyro.get_param_store().scope(state=self.param_state) + self.param_state = self._param_ctx.__enter__() self.active += 1 self.used = True @@ -187,6 +203,9 @@ def __exit__(self, type, value, 
traceback): self.active -= 1 if not self.active: self.cache.clear() + if _is_module_local_param_enabled(): + self._param_ctx.__exit__(type, value, traceback) + del self._param_ctx def get(self, name): if self.active: @@ -409,6 +428,8 @@ def named_pyro_params(self, prefix="", recurse=True): yield elem def _pyro_set_supermodule(self, name, context): + if _is_module_local_param_enabled() and pyro.settings.get("validate_poutine"): + self._check_module_local_param_usage() self._pyro_name = name self._pyro_context = context for key, value in self._modules.items(): @@ -424,7 +445,26 @@ def _pyro_get_fullname(self, name): def __call__(self, *args, **kwargs): with self._pyro_context: - return super().__call__(*args, **kwargs) + result = super().__call__(*args, **kwargs) + if ( + pyro.settings.get("validate_poutine") + and not self._pyro_context.active + and _is_module_local_param_enabled() + ): + self._check_module_local_param_usage() + return result + + def _check_module_local_param_usage(self) -> None: + self_nn_params = set(id(p) for p in self.parameters()) + self_pyro_params = set( + id(p if not hasattr(p, "unconstrained") else p.unconstrained()) + for p in self._pyro_context.param_state["params"].values() + ) + if not self_pyro_params <= self_nn_params: + raise NotImplementedError( + "Support for global pyro.param statements in PyroModules " + "with local param mode enabled is not yet implemented." + ) def __getattr__(self, name): # PyroParams trigger pyro.param statements. 
diff --git a/pyro/ops/gaussian.py b/pyro/ops/gaussian.py index 3712aa8af5..12f17e973c 100644 --- a/pyro/ops/gaussian.py +++ b/pyro/ops/gaussian.py @@ -2,13 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 import math +from typing import Optional, Tuple import torch from torch.distributions.utils import lazy_property from torch.nn.functional import pad from pyro.distributions.util import broadcast_shape -from pyro.ops.tensor_utils import cholesky, matmul, matvecmul, triangular_solve +from pyro.ops.tensor_utils import matmul, matvecmul, safe_cholesky, triangular_solve class Gaussian: @@ -28,7 +29,12 @@ class Gaussian: :param torch.Tensor precision: precision matrix of this gaussian. """ - def __init__(self, log_normalizer, info_vec, precision): + def __init__( + self, + log_normalizer: torch.Tensor, + info_vec: torch.Tensor, + precision: torch.Tensor, + ): # NB: using info_vec instead of mean to deal with rank-deficient problem assert info_vec.dim() >= 1 assert precision.dim() >= 2 @@ -48,21 +54,21 @@ def batch_shape(self): self.precision.shape[:-2], ) - def expand(self, batch_shape): + def expand(self, batch_shape) -> "Gaussian": n = self.dim() log_normalizer = self.log_normalizer.expand(batch_shape) info_vec = self.info_vec.expand(batch_shape + (n,)) precision = self.precision.expand(batch_shape + (n, n)) return Gaussian(log_normalizer, info_vec, precision) - def reshape(self, batch_shape): + def reshape(self, batch_shape) -> "Gaussian": n = self.dim() log_normalizer = self.log_normalizer.reshape(batch_shape) info_vec = self.info_vec.reshape(batch_shape + (n,)) precision = self.precision.reshape(batch_shape + (n, n)) return Gaussian(log_normalizer, info_vec, precision) - def __getitem__(self, index): + def __getitem__(self, index) -> "Gaussian": """ Index into the batch_shape of a Gaussian. 
""" @@ -73,7 +79,7 @@ def __getitem__(self, index): return Gaussian(log_normalizer, info_vec, precision) @staticmethod - def cat(parts, dim=0): + def cat(parts, dim=0) -> "Gaussian": """ Concatenate a list of Gaussians along a given batch dimension. """ @@ -85,7 +91,7 @@ def cat(parts, dim=0): ] return Gaussian(*args) - def event_pad(self, left=0, right=0): + def event_pad(self, left=0, right=0) -> "Gaussian": """ Pad along event dimension. """ @@ -95,7 +101,7 @@ def event_pad(self, left=0, right=0): precision = pad(self.precision, lr + lr) return Gaussian(log_normalizer, info_vec, precision) - def event_permute(self, perm): + def event_permute(self, perm) -> "Gaussian": """ Permute along event dimension. """ @@ -105,7 +111,7 @@ def event_permute(self, perm): precision = self.precision[..., perm][..., perm, :] return Gaussian(self.log_normalizer, info_vec, precision) - def __add__(self, other): + def __add__(self, other: "Gaussian") -> "Gaussian": """ Adds two Gaussians in log-density space. 
""" @@ -120,12 +126,12 @@ def __add__(self, other): return Gaussian(self.log_normalizer + other, self.info_vec, self.precision) raise ValueError("Unsupported type: {}".format(type(other))) - def __sub__(self, other): + def __sub__(self, other: "Gaussian") -> "Gaussian": if isinstance(other, (int, float, torch.Tensor)): return Gaussian(self.log_normalizer - other, self.info_vec, self.precision) raise ValueError("Unsupported type: {}".format(type(other))) - def log_density(self, value): + def log_density(self, value: torch.Tensor) -> torch.Tensor: """ Evaluate the log density of this Gaussian at a point value:: @@ -135,24 +141,31 @@ def log_density(self, value): """ if value.size(-1) == 0: batch_shape = broadcast_shape(value.shape[:-1], self.batch_shape) - return self.log_normalizer.expand(batch_shape) + result: torch.Tensor = self.log_normalizer.expand(batch_shape) + return result result = (-0.5) * matvecmul(self.precision, value) result = result + self.info_vec result = (value * result).sum(-1) return result + self.log_normalizer - def rsample(self, sample_shape=torch.Size()): + def rsample( + self, sample_shape=torch.Size(), noise: Optional[torch.Tensor] = None + ) -> torch.Tensor: """ Reparameterized sampler. """ - P_chol = cholesky(self.precision) + P_chol = safe_cholesky(self.precision) loc = self.info_vec.unsqueeze(-1).cholesky_solve(P_chol).squeeze(-1) shape = sample_shape + self.batch_shape + (self.dim(), 1) - noise = torch.randn(shape, dtype=loc.dtype, device=loc.device) + if noise is None: + noise = torch.randn(shape, dtype=loc.dtype, device=loc.device) + else: + noise = noise.reshape(shape) noise = triangular_solve(noise, P_chol, upper=False, transpose=True).squeeze(-1) - return loc + noise + sample: torch.Tensor = loc + noise + return sample - def condition(self, value): + def condition(self, value: torch.Tensor) -> "Gaussian": """ Condition this Gaussian on a trailing subset of its state. 
This should satisfy:: @@ -189,7 +202,7 @@ def condition(self, value): ) return Gaussian(log_normalizer, info_vec, precision) - def left_condition(self, value): + def left_condition(self, value: torch.Tensor) -> "Gaussian": """ Condition this Gaussian on a leading subset of its state. This should satisfy:: @@ -217,7 +230,7 @@ def left_condition(self, value): ) return self.event_permute(perm).condition(value) - def marginalize(self, left=0, right=0): + def marginalize(self, left=0, right=0) -> "Gaussian": """ Marginalizing out variables on either side of the event dimension:: @@ -241,7 +254,7 @@ def marginalize(self, left=0, right=0): P_aa = self.precision[..., a, a] P_ba = self.precision[..., b, a] P_bb = self.precision[..., b, b] - P_b = cholesky(P_bb) + P_b = safe_cholesky(P_bb) P_a = triangular_solve(P_ba, P_b, upper=False) P_at = P_a.transpose(-1, -2) precision = P_aa - matmul(P_at, P_a) @@ -259,22 +272,23 @@ def marginalize(self, left=0, right=0): ) return Gaussian(log_normalizer, info_vec, precision) - def event_logsumexp(self): + def event_logsumexp(self) -> torch.Tensor: """ Integrates out all latent state (i.e. operating on event dimensions). """ n = self.dim() - chol_P = cholesky(self.precision) + chol_P = safe_cholesky(self.precision) chol_P_u = triangular_solve( self.info_vec.unsqueeze(-1), chol_P, upper=False ).squeeze(-1) u_P_u = chol_P_u.pow(2).sum(-1) - return ( + log_Z: torch.Tensor = ( self.log_normalizer + 0.5 * n * math.log(2 * math.pi) + 0.5 * u_P_u - chol_P.diagonal(dim1=-2, dim2=-1).log().sum(-1) ) + return log_Z class AffineNormal: @@ -339,15 +353,21 @@ def left_condition(self, value): else: return self.to_gaussian().left_condition(value) - def rsample(self, sample_shape=torch.Size()): + def rsample( + self, sample_shape=torch.Size(), noise: Optional[torch.Tensor] = None + ) -> torch.Tensor: """ Reparameterized sampler. 
""" if self.matrix.size(-2) > 0: raise NotImplementedError shape = sample_shape + self.batch_shape + self.loc.shape[-1:] - noise = torch.randn(shape, dtype=self.loc.dtype, device=self.loc.device) - return self.loc + noise * self.scale + if noise is None: + noise = torch.randn(shape, dtype=self.loc.dtype, device=self.loc.device) + else: + noise = noise.reshape(shape) + sample: torch.Tensor = self.loc + noise * self.scale + return sample def to_gaussian(self): if self._gaussian is None: @@ -355,7 +375,7 @@ def to_gaussian(self): torch.distributions.Normal(self.loc, scale=self.scale), 1 ) y_gaussian = mvn_to_gaussian(mvn) - self._gaussian = _matrix_and_gaussian_to_gaussian(self.matrix, y_gaussian) + self._gaussian = matrix_and_gaussian_to_gaussian(self.matrix, y_gaussian) return self._gaussian def expand(self, batch_shape): @@ -426,7 +446,17 @@ def mvn_to_gaussian(mvn): return Gaussian(log_normalizer, info_vec, precision) -def _matrix_and_gaussian_to_gaussian(matrix, y_gaussian): +def matrix_and_gaussian_to_gaussian( + matrix: torch.Tensor, y_gaussian: Gaussian +) -> Gaussian: + """ + Constructs a conditional Gaussian for ``p(y|x)`` where + ``y - x @ matrix ~ y_gaussian``. + + :param torch.Tensor matrix: A right-acting transformation matrix. + :param Gaussian y_gaussian: A distribution over noise of ``y - x@matrix``. 
+ :rtype: Gaussian + """ P_yy = y_gaussian.precision neg_P_xy = matmul(matrix, P_yy) P_xy = -neg_P_xy @@ -471,13 +501,13 @@ def matrix_and_mvn_to_gaussian(matrix, mvn): return AffineNormal(matrix, mvn.base_dist.loc, mvn.base_dist.scale) y_gaussian = mvn_to_gaussian(mvn) - result = _matrix_and_gaussian_to_gaussian(matrix, y_gaussian) + result = matrix_and_gaussian_to_gaussian(matrix, y_gaussian) assert result.batch_shape == batch_shape assert result.dim() == x_dim + y_dim return result -def gaussian_tensordot(x, y, dims=0): +def gaussian_tensordot(x: Gaussian, y: Gaussian, dims: int = 0) -> Gaussian: """ Computes the integral over two gaussians: @@ -520,7 +550,7 @@ def gaussian_tensordot(x, y, dims=0): b = xb + yb # Pbb + Qbb needs to be positive definite, so that we can malginalize out `b` (to have a finite integral) - L = cholesky(Pbb + Qbb) + L = safe_cholesky(Pbb + Qbb) LinvB = triangular_solve(B, L, upper=False) LinvBt = LinvB.transpose(-2, -1) Linvb = triangular_solve(b.unsqueeze(-1), L, upper=False) @@ -538,3 +568,148 @@ def gaussian_tensordot(x, y, dims=0): log_normalizer = log_normalizer + diff return Gaussian(log_normalizer, info_vec, precision) + + +def sequential_gaussian_tensordot(gaussian: Gaussian) -> Gaussian: + """ + Integrates a Gaussian ``x`` whose rightmost batch dimension is time, computes:: + + x[..., 0] @ x[..., 1] @ ... @ x[..., T-1] + + :param Gaussian gaussian: A batched Gaussian whose rightmost dimension is time. + :returns: A Markov product of the Gaussian along its time dimension. 
+ :rtype: Gaussian + """ + assert isinstance(gaussian, Gaussian) + assert gaussian.dim() % 2 == 0, "dim is not even" + batch_shape = gaussian.batch_shape[:-1] + state_dim = gaussian.dim() // 2 + while gaussian.batch_shape[-1] > 1: + time = gaussian.batch_shape[-1] + even_time = time // 2 * 2 + even_part = gaussian[..., :even_time] + x_y = even_part.reshape(batch_shape + (even_time // 2, 2)) + x, y = x_y[..., 0], x_y[..., 1] + contracted = gaussian_tensordot(x, y, state_dim) + if time > even_time: + contracted = Gaussian.cat((contracted, gaussian[..., -1:]), dim=-1) + gaussian = contracted + return gaussian[..., 0] + + +def sequential_gaussian_filter_sample( + init: Gaussian, + trans: Gaussian, + sample_shape: Tuple[int, ...] = (), + noise: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Draws a reparameterized sample from a Markov product of Gaussians via + parallel-scan forward-filter backward-sample. + + :param Gaussian init: A Gaussian representing an initial state. + :param Gaussian trans: A Gaussian representing as series of state transitions, + with time as the rightmost batch dimension. This must have twice the event + dim as ``init``: ``trans.dim() == 2 * init.dim()``. + :param tuple sample_shape: An optional extra shape of samples to draw. + :param torch.Tensor noise: An optional standard white noise tensor of shape + ``sample_shape + batch_shape + (duration, state_dim)``, where + ``duration = 1 + trans.batch_shape[-1]`` is the number of time points + to be sampled, and ``state_dim = init.dim()`` is the state dimension. + This is useful for computing the mean (pass zeros), varying temperature + (pass scaled noise), and antithetic sampling (pass ``cat([z,-z])``). + :returns: A reparametrized sample of shape + ``sample_shape + batch_shape + (duration, state_dim)``. 
+ :rtype: torch.Tensor + """ + assert isinstance(init, Gaussian) + assert isinstance(trans, Gaussian) + assert trans.dim() == 2 * init.dim() + state_dim = trans.dim() // 2 + batch_shape = broadcast_shape(trans.batch_shape[:-1], init.batch_shape) + if init.batch_shape != batch_shape: + init = init.expand(batch_shape) + dtype = trans.precision.dtype + device = trans.precision.device + perm = torch.cat( + [ + torch.arange(1 * state_dim, 2 * state_dim, device=device), + torch.arange(0 * state_dim, 1 * state_dim, device=device), + torch.arange(2 * state_dim, 3 * state_dim, device=device), + ] + ) + + # Forward filter, similar to sequential_gaussian_tensordot(). + tape = [] + shape = trans.batch_shape[:-1] # Note trans may be unbroadcasted. + gaussian = trans + while gaussian.batch_shape[-1] > 1: + time = gaussian.batch_shape[-1] + even_time = time // 2 * 2 + even_part = gaussian[..., :even_time] + x_y = even_part.reshape(shape + (even_time // 2, 2)) + x, y = x_y[..., 0], x_y[..., 1] + x = x.event_pad(right=state_dim) + y = y.event_pad(left=state_dim) + joint = (x + y).event_permute(perm) + tape.append(joint) + contracted = joint.marginalize(left=state_dim) + if time > even_time: + contracted = Gaussian.cat((contracted, gaussian[..., -1:]), dim=-1) + gaussian = contracted + gaussian = gaussian[..., 0] + init.event_pad(right=state_dim) + + # Generate noise in batch, then allow blocks to be consumed incrementally. + duration = 1 + trans.batch_shape[-1] + shape = torch.Size(sample_shape) + init.batch_shape + result_shape = shape + (duration, state_dim) + noise_stride = shape.numel() * state_dim # noise is consumed in time blocks + noise_position: int = 0 + if noise is None: + noise = torch.randn(result_shape, dtype=dtype, device=device) + assert noise.shape == result_shape + + def rsample(g: Gaussian, sample_shape: Tuple[int, ...] 
= ()) -> torch.Tensor: + """Samples, extracting a time-block of noise.""" + nonlocal noise_position + assert noise is not None + numel = torch.Size(sample_shape + g.batch_shape + (g.dim(),)).numel() + assert numel % noise_stride == 0 + beg: int = noise_position + end: int = noise_position + numel // noise_stride + assert end <= duration, "too little noise provided" + noise_position = end + return g.rsample(sample_shape, noise=noise[..., beg:end, :]) + + # Backward sample. + result = rsample(gaussian, sample_shape).reshape(shape + (2, state_dim)) + for joint in reversed(tape): + # The following comments demonstrate two example computations, one + # EVEN, one ODD. Ignoring sample_shape and batch_shape, let each zn be + # a single sampled event of shape (state_dim,). + if joint.batch_shape[-1] == result.size(-2) - 1: # EVEN case. + # Suppose e.g. result = [z0, z2, z4] + cond = result.repeat_interleave(2, dim=-2) # [z0, z0, z2, z2, z4, z4] + cond = cond[..., 1:-1, :] # [z0, z2, z2, z4] + cond = cond.reshape(shape + (-1, 2 * state_dim)) # [z0z2, z2z4] + sample = rsample(joint.condition(cond)) # [z1, z3] + zipper = result.new_empty(shape + (2 * result.size(-2) - 1, state_dim)) + zipper[..., ::2, :] = result # [z0, _, z2, _, z4] + zipper[..., 1::2, :] = sample # [_, z1, _, z3, _] + result = zipper # [z0, z1, z2, z3, z4] + else: # ODD case. + assert joint.batch_shape[-1] == result.size(-2) - 2 + # Suppose e.g. 
result = [z0, z2, z3] + cond = result[..., :-1, :].repeat_interleave(2, dim=-2) # [z0, z0, z2, z2] + cond = cond[..., 1:-1, :] # [z0, z2] + cond = cond.reshape(shape + (-1, 2 * state_dim)) # [z0z2] + sample = rsample(joint.condition(cond)) # [z1] + zipper = result.new_empty(shape + (2 * result.size(-2) - 2, state_dim)) + zipper[..., ::2, :] = result[..., :-1, :] # [z0, _, z2, _] + zipper[..., -1, :] = result[..., -1, :] # [_, _, _, z3] + zipper[..., 1:-1:2, :] = sample # [_, z1, _, _] + result = zipper # [z0, z1, z2, z3] + + assert noise_position == duration, "too much noise provided" + assert result.shape == result_shape + return result # [z0, z1, z2, ...] diff --git a/pyro/ops/tensor_utils.py b/pyro/ops/tensor_utils.py index 7efc847d59..17cb367a17 100644 --- a/pyro/ops/tensor_utils.py +++ b/pyro/ops/tensor_utils.py @@ -6,7 +6,16 @@ import torch from torch.fft import irfft, rfft +from .. import settings + _ROOT_TWO_INVERSE = 1.0 / math.sqrt(2.0) +CHOLESKY_RELATIVE_JITTER = 4.0 # in units of finfo.eps + + +@settings.register("cholesky_relative_jitter", __name__, "CHOLESKY_RELATIVE_JITTER") +def _validate_jitter(value): + assert isinstance(value, (float, int)) + assert 0 <= value def as_complex(x): @@ -393,9 +402,19 @@ def inverse_haar_transform(x): return x -def cholesky(x): +def safe_cholesky(x): if x.size(-1) == 1: + if CHOLESKY_RELATIVE_JITTER: + x = x.clamp(min=torch.finfo(x.dtype).tiny) return x.sqrt() + + if CHOLESKY_RELATIVE_JITTER: + # Add adaptive jitter. + x = x.clone() + x_max = x.data.abs().max(-1).values + jitter = CHOLESKY_RELATIVE_JITTER * torch.finfo(x.dtype).eps * x_max + x.data.diagonal(dim1=-1, dim2=-2).add_(jitter) + return torch.linalg.cholesky(x) diff --git a/pyro/poutine/util.py b/pyro/poutine/util.py index e90c2a5917..3a0ec0316b 100644 --- a/pyro/poutine/util.py +++ b/pyro/poutine/util.py @@ -1,7 +1,10 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 +from .. 
import settings + _VALIDATION_ENABLED = __debug__ +settings.register("validate_poutine", __name__, "_VALIDATION_ENABLED") def enable_validation(is_validate): diff --git a/pyro/settings.py b/pyro/settings.py new file mode 100644 index 0000000000..38dc1e01de --- /dev/null +++ b/pyro/settings.py @@ -0,0 +1,163 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +""" +Example usage:: + + # Simple getting and setting. + print(pyro.settings.get()) # print all settings + print(pyro.settings.get("cholesky_relative_jitter")) # print one + pyro.settings.set(cholesky_relative_jitter=0.5) # set one + pyro.settings.set(**my_settings) # set many + + # Use as a contextmanager. + with pyro.settings.context(cholesky_relative_jitter=0.5): + my_function() + + # Use as a decorator. + fn = pyro.settings.context(cholesky_relative_jitter=0.5)(my_function) + fn() + + # Register a new setting. + pyro.settings.register( + "binomial_approx_sample_thresh", # alias + "pyro.distributions.torch", # module + "Binomial.approx_sample_thresh", # deep name + ) + + # Register a new setting on a user-provided validator. + @pyro.settings.register( + "binomial_approx_sample_thresh", # alias + "pyro.distributions.torch", # module + "Binomial.approx_sample_thresh", # deep name + ) + def validate_thresh(thresh): # called each time setting is set + assert isinstance(thresh, float) + assert thresh > 0 + +Default Settings +---------------- + +{defaults} + +Settings Interface +------------------ +""" + +# This library must have no dependencies on other pyro modules. +import functools +from contextlib import contextmanager +from importlib import import_module +from typing import Any, Callable, Dict, Iterator, Optional, Tuple + +# Docs are updated by register(). +_doc_template = __doc__ + +# Global registry mapping alias:str to (modulename, deepname, validator) +# triples where deepname may have dots to indicate e.g. class variables. 
+_REGISTRY: Dict[str, Tuple[str, str, Optional[Callable]]] = {} + + +def get(alias: Optional[str] = None) -> Any: + """ + Gets one or all global settings. + + :param str alias: The name of a registered setting. + :returns: The currently set value. + """ + if alias is None: + # Return dict of all settings. + return {alias: get(alias) for alias in sorted(_REGISTRY)} + # Get a single setting. + module, deepname, validator = _REGISTRY[alias] + value = import_module(module) + for name in deepname.split("."): + value = getattr(value, name) + return value + + +def set(**kwargs) -> None: + r""" + Sets one or more settings. + + :param \*\*kwargs: alias=value pairs. + """ + for alias, value in kwargs.items(): + module, deepname, validator = _REGISTRY[alias] + if validator is not None: + validator(value) + destin = import_module(module) + names = deepname.split(".") + for name in names[:-1]: + destin = getattr(destin, name) + setattr(destin, names[-1], value) + + +@contextmanager +def context(**kwargs) -> Iterator[None]: + r""" + Context manager to temporarily override one or more settings. This also + works as a decorator. + + :param \*\*kwargs: alias=value pairs. + """ + old = {alias: get(alias) for alias in kwargs} + try: + set(**kwargs) + yield + finally: + set(**old) + + +def register( + alias: str, + modulename: str, + deepname: str, + validator: Optional[Callable] = None, +) -> Callable: + """ + Register a global settings. + + This should be declared in the module where the setting is defined. + + This can be used either as a declaration:: + + settings.register("my_setting", __name__, "MY_SETTING") + + or as a decorator on a user-defined validator function:: + + @settings.register("my_setting", __name__, "MY_SETTING") + def _validate_my_setting(value): + assert isinstance(value, float) + assert 0 < value + + :param str alias: A valid python identifier serving as a settings alias. + Lower snake case preferred, e.g. ``my_setting``. 
+ :param str modulename: The module name where the setting is declared, + typically ``__name__``. + :param str deepname: A ``.``-separated string of names. E.g. for a module + constant, use ``MY_CONSTANT``. For a class attributue, use + ``MyClass.my_attribute``. + :param callable validator: Optional validator that inputs a value, + possibly raises validation errors, and returns None. + """ + global __doc__ + assert isinstance(alias, str) + assert alias.isidentifier() + assert isinstance(modulename, str) + assert isinstance(deepname, str) + _REGISTRY[alias] = modulename, deepname, validator + + # Add default value to module docstring. + __doc__ = _doc_template.format( + defaults="\n".join(f"- {a} = {get(a)}" for a in sorted(_REGISTRY)) + ) + + # Support use as a decorator on an optional user-provided validator. + if validator is None: + # Return a decorator, but its fine if user discards this. + return functools.partial(register, alias, modulename, deepname) + else: + # Test current value passes validation. 
+ validator(get(alias)) + return validator diff --git a/setup.py b/setup.py index 990d7f9402..5a1727ab1e 100644 --- a/setup.py +++ b/setup.py @@ -69,13 +69,13 @@ "graphviz>=0.8", "matplotlib>=1.3", "torchvision>=0.12.0", - "visdom>=0.1.4", + "visdom>=0.1.4,<0.2.2", # FIXME visdom.utils is unavailable >=0.2.2 "pandas", "pillow==8.2.0", # https://github.com/pytorch/pytorch/issues/61125 "scikit-learn", "seaborn>=0.11.0", "wget", - "lap", + "lap", # Requires setuptools<60 # 'biopython>=1.54', # 'scanpy>=1.4', # Requires HDF5 # 'scvi>=0.6', # Requires loopy and other fragile packages diff --git a/tests/distributions/test_hmm.py b/tests/distributions/test_hmm.py index 1bd88a3d1f..01ea316855 100644 --- a/tests/distributions/test_hmm.py +++ b/tests/distributions/test_hmm.py @@ -13,8 +13,6 @@ import pyro.distributions as dist from pyro.distributions.hmm import ( _sequential_gamma_gaussian_tensordot, - _sequential_gaussian_filter_sample, - _sequential_gaussian_tensordot, _sequential_logmatmulexp, ) from pyro.distributions.util import broadcast_shape @@ -36,7 +34,7 @@ random_gamma, random_gamma_gaussian, ) -from tests.ops.gaussian import assert_close_gaussian, random_gaussian, random_mvn +from tests.ops.gaussian import random_mvn logger = logging.getLogger(__name__) @@ -93,35 +91,6 @@ def test_sequential_logmatmulexp(batch_shape, state_dim, num_steps): assert_close(actual, expected) -@pytest.mark.parametrize("num_steps", list(range(1, 20))) -@pytest.mark.parametrize("state_dim", [1, 2, 3]) -@pytest.mark.parametrize("batch_shape", [(), (5,), (2, 4)], ids=str) -def test_sequential_gaussian_tensordot(batch_shape, state_dim, num_steps): - g = random_gaussian(batch_shape + (num_steps,), state_dim + state_dim) - actual = _sequential_gaussian_tensordot(g) - assert actual.dim() == g.dim() - assert actual.batch_shape == batch_shape - - # Check against hand computation. 
- expected = g[..., 0] - for t in range(1, num_steps): - expected = gaussian_tensordot(expected, g[..., t], state_dim) - assert_close_gaussian(actual, expected) - - -@pytest.mark.parametrize("num_steps", list(range(1, 20))) -@pytest.mark.parametrize("state_dim", [1, 2, 3]) -@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str) -@pytest.mark.parametrize("sample_shape", [(), (4,), (3, 2)], ids=str) -def test_sequential_gaussian_filter_sample( - sample_shape, batch_shape, state_dim, num_steps -): - init = random_gaussian(batch_shape, state_dim) - trans = random_gaussian(batch_shape + (num_steps,), state_dim + state_dim) - sample = _sequential_gaussian_filter_sample(init, trans, sample_shape) - assert sample.shape == sample_shape + batch_shape + (num_steps, state_dim) - - @pytest.mark.parametrize("num_steps", list(range(1, 20))) @pytest.mark.parametrize("state_dim", [1, 2, 3]) @pytest.mark.parametrize("batch_shape", [(), (5,), (2, 4)], ids=str) diff --git a/tests/nn/test_module.py b/tests/nn/test_module.py index 41198b9f2a..90b8af0eec 100644 --- a/tests/nn/test_module.py +++ b/tests/nn/test_module.py @@ -66,7 +66,97 @@ def forward(self, *args, **kwargs): svi.step(data) -def test_names(): +@pytest.mark.parametrize("local_params", [True, False]) +@pytest.mark.parametrize("num_particles", [1, 2]) +@pytest.mark.parametrize("vectorize_particles", [True, False]) +@pytest.mark.parametrize("Autoguide", [pyro.infer.autoguide.AutoNormal]) +def test_svi_elbomodule_interface( + local_params, num_particles, vectorize_particles, Autoguide +): + class Model(PyroModule): + def __init__(self): + super().__init__() + self.loc = nn.Parameter(torch.zeros(2)) + self.scale = PyroParam(torch.ones(2), constraint=constraints.positive) + self.z = PyroSample( + lambda self: dist.Normal(self.loc, self.scale).to_event(1) + ) + + def forward(self, data): + loc, log_scale = self.z.unbind(-1) + with pyro.plate("data"): + pyro.sample("obs", dist.Cauchy(loc, log_scale.exp()), obs=data) + 
+ with pyro.settings.context(module_local_params=local_params): + data = torch.randn(5) + model = Model() + model(data) # initialize + + guide = Autoguide(model) + guide(data) # initialize + + elbo = Trace_ELBO( + vectorize_particles=vectorize_particles, num_particles=num_particles + ) + elbo = elbo(model, guide) + assert isinstance(elbo, torch.nn.Module) + assert set( + k[: -len("_unconstrained")] if k.endswith("_unconstrained") else k + for k, v in elbo.named_parameters() + ) == set("model." + k for k, v in model.named_pyro_params()) | set( + "guide." + k for k, v in guide.named_pyro_params() + ) + + adam = torch.optim.Adam(elbo.parameters(), lr=0.0001) + for _ in range(3): + adam.zero_grad() + loss = elbo(data) + loss.backward() + adam.step() + + guide2 = Autoguide(model) + guide2(data) # initialize + if local_params: + assert set(pyro.get_param_store().keys()) == set() + for (name, p), (name2, p2) in zip( + guide.named_parameters(), guide2.named_parameters() + ): + assert name == name2 + assert not torch.allclose(p, p2) + else: + assert set(pyro.get_param_store().keys()) != set() + for (name, p), (name2, p2) in zip( + guide.named_parameters(), guide2.named_parameters() + ): + assert name == name2 + assert torch.allclose(p, p2) + + +@pytest.mark.parametrize("local_params", [True, False]) +def test_local_param_global_behavior_fails(local_params): + class Model(PyroModule): + def __init__(self): + super().__init__() + self.global_nn_param = nn.Parameter(torch.zeros(2)) + + def forward(self): + global_param = pyro.param("_global_param", lambda: torch.randn(2)) + global_nn_param = pyro.param("global_nn_param", self.global_nn_param) + return global_param, global_nn_param + + with pyro.settings.context(module_local_params=local_params): + model = Model() + if local_params: + assert pyro.settings.get("module_local_params") + with pytest.raises(NotImplementedError): + model() + else: + assert not pyro.settings.get("module_local_params") + model() + + 
+@pytest.mark.parametrize("local_params", [True, False]) +def test_names(local_params): class Model(PyroModule): def __init__(self): super().__init__() @@ -86,34 +176,39 @@ def forward(self): self.p.v self.p.w - model = Model() - - # Check named_parameters. - expected = { - "x", - "y_unconstrained", - "m.u", - "p.v", - "p.w_unconstrained", - } - actual = set(name for name, _ in model.named_parameters()) - assert actual == expected - - # Check pyro.param names. - expected = {"x", "y", "m$$$u", "p.v", "p.w"} - with poutine.trace(param_only=True) as param_capture: - model() - actual = { - name - for name, site in param_capture.trace.nodes.items() - if site["type"] == "param" - } - assert actual == expected - - # Check pyro_parameters method - expected = {"x", "y", "m.u", "p.v", "p.w"} - actual = set(k for k, v in model.named_pyro_params()) - assert actual == expected + with pyro.settings.context(module_local_params=local_params): + model = Model() + + # Check named_parameters. + expected = { + "x", + "y_unconstrained", + "m.u", + "p.v", + "p.w_unconstrained", + } + actual = set(name for name, _ in model.named_parameters()) + assert actual == expected + + # Check pyro.param names. 
+ expected = {"x", "y", "m$$$u", "p.v", "p.w"} + with poutine.trace(param_only=True) as param_capture: + model() + actual = { + name + for name, site in param_capture.trace.nodes.items() + if site["type"] == "param" + } + assert actual == expected + if local_params: + assert set(pyro.get_param_store().keys()) == set() + else: + assert set(pyro.get_param_store().keys()) == expected + + # Check pyro_parameters method + expected = {"x", "y", "m.u", "p.v", "p.w"} + actual = set(k for k, v in model.named_pyro_params()) + assert actual == expected def test_delete(): @@ -258,7 +353,8 @@ def test_constraints(shape, constraint_): assert not hasattr(module, "x_unconstrained") -def test_clear(): +@pytest.mark.parametrize("local_params", [True, False]) +def test_clear(local_params): class Model(PyroModule): def __init__(self): super().__init__() @@ -272,28 +368,43 @@ def __init__(self): def forward(self): return [x.clone() for x in [self.x, self.m.weight, self.m.bias, self.p.x]] - assert set(pyro.get_param_store().keys()) == set() - m = Model() - state0 = m() - assert set(pyro.get_param_store().keys()) == {"x", "m$$$weight", "m$$$bias", "p.x"} - - # mutate - for x in pyro.get_param_store().values(): - x.unconstrained().data += torch.randn(()) - state1 = m() - for x, y in zip(state0, state1): - assert not (x == y).all() - assert set(pyro.get_param_store().keys()) == {"x", "m$$$weight", "m$$$bias", "p.x"} - - clear(m) - del m - assert set(pyro.get_param_store().keys()) == set() - - m = Model() - state2 = m() - assert set(pyro.get_param_store().keys()) == {"x", "m$$$weight", "m$$$bias", "p.x"} - for actual, expected in zip(state2, state0): - assert_equal(actual, expected) + with pyro.settings.context(module_local_params=local_params): + m = Model() + state0 = m() + + # mutate + for _, x in m.named_pyro_params(): + x.unconstrained().data += torch.randn(()) + state1 = m() + for x, y in zip(state0, state1): + assert not (x == y).all() + + if local_params: + assert 
set(pyro.get_param_store().keys()) == set() + else: + assert set(pyro.get_param_store().keys()) == { + "x", + "m$$$weight", + "m$$$bias", + "p.x", + } + clear(m) + del m + assert set(pyro.get_param_store().keys()) == set() + + m = Model() + state2 = m() + if local_params: + assert set(pyro.get_param_store().keys()) == set() + else: + assert set(pyro.get_param_store().keys()) == { + "x", + "m$$$weight", + "m$$$bias", + "p.x", + } + for actual, expected in zip(state2, state0): + assert_equal(actual, expected) def test_sample(): @@ -532,48 +643,65 @@ def randomize(model): assert_identical(actual, expected) -def test_torch_serialize_attributes(): - module = PyroModule() - module.x = PyroParam(torch.tensor(1.234), constraints.positive) - module.y = nn.Parameter(torch.randn(3)) - assert isinstance(module.x, torch.Tensor) - - # Work around https://github.com/pytorch/pytorch/issues/27972 - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning) - f = io.BytesIO() - torch.save(module, f) - pyro.clear_param_store() - f.seek(0) - actual = torch.load(f) - - assert_equal(actual.x, module.x) - actual_names = {name for name, _ in actual.named_parameters()} - expected_names = {name for name, _ in module.named_parameters()} - assert actual_names == expected_names - - -def test_torch_serialize_decorators(): - module = DecoratorModel(3) - module() # initialize - - # Work around https://github.com/pytorch/pytorch/issues/27972 - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning) - f = io.BytesIO() - torch.save(module, f) - pyro.clear_param_store() - f.seek(0) - actual = torch.load(f) - - assert_equal(actual.x, module.x) - assert_equal(actual.y, module.y) - assert_equal(actual.z, module.z) - assert actual.s.shape == module.s.shape - assert actual.t.shape == module.t.shape - actual_names = {name for name, _ in actual.named_parameters()} - expected_names = {name for name, _ in module.named_parameters()} - assert 
actual_names == expected_names +@pytest.mark.parametrize("local_params", [True, False]) +def test_torch_serialize_attributes(local_params): + with pyro.settings.context(module_local_params=local_params): + module = PyroModule() + module.x = PyroParam(torch.tensor(1.234), constraints.positive) + module.y = nn.Parameter(torch.randn(3)) + assert isinstance(module.x, torch.Tensor) + + # Work around https://github.com/pytorch/pytorch/issues/27972 + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + f = io.BytesIO() + torch.save(module, f) + pyro.clear_param_store() + f.seek(0) + actual = torch.load(f) + + assert_equal(actual.x, module.x) + actual_names = {name for name, _ in actual.named_parameters()} + expected_names = {name for name, _ in module.named_parameters()} + assert actual_names == expected_names + + +@pytest.mark.parametrize("local_params", [True, False]) +def test_torch_serialize_decorators(local_params): + with pyro.settings.context(module_local_params=local_params): + module = DecoratorModel(3) + module() # initialize + + module2 = DecoratorModel(3) + module2() # initialize + + # Work around https://github.com/pytorch/pytorch/issues/27972 + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + f = io.BytesIO() + torch.save(module, f) + pyro.clear_param_store() + f.seek(0) + actual = torch.load(f) + + assert_equal(actual.x, module.x) + assert_equal(actual.y, module.y) + assert_equal(actual.z, module.z) + assert actual.s.shape == module.s.shape + assert actual.t.shape == module.t.shape + actual_names = {name for name, _ in actual.named_parameters()} + expected_names = {name for name, _ in module.named_parameters()} + assert actual_names == expected_names + + actual() + if local_params: + assert len(set(pyro.get_param_store().keys())) == 0 + assert torch.all(module.y != module2.y) + assert torch.all(actual.y != module2.y) + else: + assert len(set(pyro.get_param_store().keys())) > 0 
+ assert_equal(module.y, module2.y) + assert_equal(actual.y, module2.y) def test_pyro_serialize(): diff --git a/tests/ops/gaussian.py b/tests/ops/gaussian.py index 9bf926bbb7..040273b78a 100644 --- a/tests/ops/gaussian.py +++ b/tests/ops/gaussian.py @@ -8,30 +8,32 @@ from tests.common import assert_close -def random_gaussian(batch_shape, dim, rank=None): +def random_gaussian(batch_shape, dim, rank=None, *, requires_grad=False): """ Generate a random Gaussian for testing. """ if rank is None: rank = dim + dim - log_normalizer = torch.randn(batch_shape) - info_vec = torch.randn(batch_shape + (dim,)) + log_normalizer = torch.randn(batch_shape, requires_grad=requires_grad) + info_vec = torch.randn(batch_shape + (dim,), requires_grad=requires_grad) samples = torch.randn(batch_shape + (dim, rank)) precision = torch.matmul(samples, samples.transpose(-2, -1)) + precision.requires_grad_(requires_grad) result = Gaussian(log_normalizer, info_vec, precision) assert result.dim() == dim assert result.batch_shape == batch_shape return result -def random_mvn(batch_shape, dim): +def random_mvn(batch_shape, dim, *, requires_grad=False): """ Generate a random MultivariateNormal distribution for testing. 
""" rank = dim + dim - loc = torch.randn(batch_shape + (dim,)) + loc = torch.randn(batch_shape + (dim,), requires_grad=requires_grad) cov = torch.randn(batch_shape + (dim, rank)) cov = cov.matmul(cov.transpose(-1, -2)) + cov.requires_grad_(requires_grad) return dist.MultivariateNormal(loc, cov) diff --git a/tests/ops/test_gaussian.py b/tests/ops/test_gaussian.py index fa5924b128..0982d8eb91 100644 --- a/tests/ops/test_gaussian.py +++ b/tests/ops/test_gaussian.py @@ -15,8 +15,11 @@ AffineNormal, Gaussian, gaussian_tensordot, + matrix_and_gaussian_to_gaussian, matrix_and_mvn_to_gaussian, mvn_to_gaussian, + sequential_gaussian_filter_sample, + sequential_gaussian_tensordot, ) from tests.common import assert_close from tests.ops.gaussian import assert_close_gaussian, random_gaussian, random_mvn @@ -376,7 +379,7 @@ def test_gaussian_tensordot( nc = y_dim - dot_dims try: torch.linalg.cholesky(x.precision[..., na:, na:] + y.precision[..., :nb, :nb]) - except RuntimeError: + except Exception: pytest.skip("Cannot marginalize the common variables of two Gaussians.") z = gaussian_tensordot(x, y, dot_dims) @@ -488,3 +491,122 @@ def check_equal(actual, expected, atol=0.01, rtol=0): funsor.ops.mean, "particle" ) check_equal(fp_entropy.data, entropy) + + +@pytest.mark.parametrize("num_steps", list(range(1, 20))) +@pytest.mark.parametrize("state_dim", [1, 2, 3]) +@pytest.mark.parametrize("batch_shape", [(), (5,), (2, 4)], ids=str) +def test_sequential_gaussian_tensordot(batch_shape, state_dim, num_steps): + g = random_gaussian(batch_shape + (num_steps,), state_dim + state_dim) + actual = sequential_gaussian_tensordot(g) + assert actual.dim() == g.dim() + assert actual.batch_shape == batch_shape + + # Check against hand computation. 
+ expected = g[..., 0] + for t in range(1, num_steps): + expected = gaussian_tensordot(expected, g[..., t], state_dim) + assert_close_gaussian(actual, expected) + + +@pytest.mark.parametrize("num_steps", list(range(1, 20))) +@pytest.mark.parametrize("state_dim", [1, 2, 3]) +@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str) +@pytest.mark.parametrize("sample_shape", [(), (4,), (3, 2)], ids=str) +def test_sequential_gaussian_filter_sample( + sample_shape, batch_shape, state_dim, num_steps +): + init = random_gaussian(batch_shape, state_dim, requires_grad=True) + trans = random_gaussian( + batch_shape + (num_steps,), state_dim + state_dim, requires_grad=True + ) + duration = 1 + num_steps + + # Check shape. + sample = sequential_gaussian_filter_sample(init, trans, sample_shape) + assert sample.shape == sample_shape + batch_shape + (duration, state_dim) + + # Check gradients. + assert sample.requires_grad + loss = (torch.randn_like(sample) * sample).sum() + params = [init.info_vec, init.precision, trans.info_vec, trans.precision] + torch.autograd.grad(loss, params) + + +@pytest.mark.parametrize("num_steps", list(range(1, 20))) +@pytest.mark.parametrize("state_dim", [1, 2, 3]) +@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str) +@pytest.mark.parametrize("sample_shape", [(), (4,), (3, 2)], ids=str) +def test_sequential_gaussian_filter_sample_antithetic( + sample_shape, batch_shape, state_dim, num_steps +): + init = random_gaussian(batch_shape, state_dim) + trans = random_gaussian(batch_shape + (num_steps,), state_dim + state_dim) + duration = 1 + num_steps + + noise = torch.randn(sample_shape + batch_shape + (duration, state_dim)) + zero = torch.zeros_like(noise) + sample = sequential_gaussian_filter_sample(init, trans, sample_shape, noise) + mean = sequential_gaussian_filter_sample(init, trans, sample_shape, zero) + assert sample.shape == sample_shape + batch_shape + (duration, state_dim) + assert mean.shape == sample_shape + 
batch_shape + (duration, state_dim) + + # Check that antithetic sampling works as expected. + noise3 = torch.stack([noise, zero, -noise]) + sample3 = sequential_gaussian_filter_sample( + init, trans, (3,) + sample_shape, noise3 + ) + expected = torch.stack([sample, mean, 2 * mean - sample]) + assert torch.allclose(sample3, expected) + + +@pytest.mark.filterwarnings("ignore:Singular matrix in cholesky") +@pytest.mark.parametrize("num_steps", [10, 100, 1000, 10000, 100000, 1000000]) +def test_sequential_gaussian_filter_sample_stability(num_steps): + # This tests long-chain filtering at low precision. + zero = torch.zeros((), dtype=torch.float) + eye = torch.eye(4, dtype=torch.float) + noise = torch.randn(num_steps, 4, dtype=torch.float, requires_grad=True) + trans_matrix = torch.tensor( + [ + [ + 0.8571434617042542, + -0.23285813629627228, + 0.05360094830393791, + -0.017088839784264565, + ], + [ + 0.7609677314758301, + 0.6596274971961975, + -0.022656921297311783, + 0.05166701227426529, + ], + [ + 3.0979342460632324, + 5.446939945220947, + -0.3425334692001343, + 0.01096670888364315, + ], + [ + -1.8180007934570312, + -0.4965082108974457, + -0.006048532668501139, + -0.08525419235229492, + ], + ], + dtype=torch.float, + requires_grad=True, + ) + + init = Gaussian(zero, zero.expand(4), eye) + trans = matrix_and_gaussian_to_gaussian( + trans_matrix, Gaussian(zero, zero.expand(4), eye) + ).expand((num_steps - 1,)) + + # Check numerically stabilized value. + x = sequential_gaussian_filter_sample(init, trans, (), noise) + assert torch.isfinite(x).all() + + # Check gradients. 
+ grads = torch.autograd.grad(x.sum(), [trans_matrix, noise]) + assert all(torch.isfinite(g).all() for g in grads) diff --git a/tests/optim/test_optim.py b/tests/optim/test_optim.py index 7aa5e8fdc8..6b6dc59d8a 100644 --- a/tests/optim/test_optim.py +++ b/tests/optim/test_optim.py @@ -56,27 +56,25 @@ def optim_params(param_name): elif param_name == free_param: return {"lr": 0.01} + def get_steps(adam): + state = adam.get_state()["loc_q"]["state"] + return int(list(state.values())[0]["step"]) + adam = optim.Adam(optim_params) adam2 = optim.Adam(optim_params) svi = SVI(model, guide, adam, loss=TraceGraph_ELBO()) svi2 = SVI(model, guide, adam2, loss=TraceGraph_ELBO()) svi.step() - adam_initial_step_count = list(adam.get_state()["loc_q"]["state"].items())[0][ - 1 - ]["step"] + adam_initial_step_count = get_steps(adam) with TemporaryDirectory() as tempdir: filename = os.path.join(tempdir, "optimizer_state.pt") adam.save(filename) svi.step() - adam_final_step_count = list(adam.get_state()["loc_q"]["state"].items())[0][ - 1 - ]["step"] + adam_final_step_count = get_steps(adam) adam2.load(filename) svi2.step() - adam2_step_count_after_load_and_step = list( - adam2.get_state()["loc_q"]["state"].items() - )[0][1]["step"] + adam2_step_count_after_load_and_step = get_steps(adam2) assert adam_initial_step_count == 1 assert adam_final_step_count == 2 diff --git a/tests/test_settings.py b/tests/test_settings.py new file mode 100644 index 0000000000..fa584b7aa9 --- /dev/null +++ b/tests/test_settings.py @@ -0,0 +1,50 @@ +# Copyright Contributors to the Pyro project. 
+# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from pyro import settings + +_TEST_SETTING: float = 0.1 + +pytestmark = pytest.mark.stage("unit") + + +def test_settings(): + v0 = settings.get() + assert isinstance(v0, dict) + assert all(isinstance(alias, str) for alias in v0) + assert settings.get("validate_distributions_pyro") is True + assert settings.get("validate_distributions_torch") is True + assert settings.get("validate_poutine") is True + assert settings.get("validate_infer") is True + + +def test_register(): + with pytest.raises(KeyError): + settings.get("test_setting") + + @settings.register("test_setting", "tests.test_settings", "_TEST_SETTING") + def _validate(value): + assert isinstance(value, float) + assert 0 < value + + # Test simple get and set. + assert settings.get("test_setting") == 0.1 + settings.set(test_setting=0.2) + assert settings.get("test_setting") == 0.2 + with pytest.raises(AssertionError): + settings.set(test_setting=-0.1) + + # Test context manager. + with settings.context(test_setting=0.3): + assert settings.get("test_setting") == 0.3 + assert settings.get("test_setting") == 0.2 + + # Test decorator. 
+ @settings.context(test_setting=0.4) + def fn(): + assert settings.get("test_setting") == 0.4 + + fn() + assert settings.get("test_setting") == 0.2 diff --git a/tutorial/source/air.ipynb b/tutorial/source/air.ipynb index ad5a613d9b..9c69613b6e 100644 --- a/tutorial/source/air.ipynb +++ b/tutorial/source/air.ipynb @@ -41,7 +41,7 @@ "import numpy as np\n", "\n", "smoke_test = ('CI' in os.environ)\n", - "assert pyro.__version__.startswith('1.8.2')" + "assert pyro.__version__.startswith('1.8.3')" ] }, { diff --git a/tutorial/source/bayesian_regression.ipynb b/tutorial/source/bayesian_regression.ipynb index 2a7140e34a..7a05fc981b 100644 --- a/tutorial/source/bayesian_regression.ipynb +++ b/tutorial/source/bayesian_regression.ipynb @@ -69,7 +69,7 @@ "\n", "# for CI testing\n", "smoke_test = ('CI' in os.environ)\n", - "assert pyro.__version__.startswith('1.8.2')\n", + "assert pyro.__version__.startswith('1.8.3')\n", "pyro.set_rng_seed(1)\n", "\n", "\n", diff --git a/tutorial/source/bayesian_regression_ii.ipynb b/tutorial/source/bayesian_regression_ii.ipynb index 02827db6b3..99cb3dd1c9 100644 --- a/tutorial/source/bayesian_regression_ii.ipynb +++ b/tutorial/source/bayesian_regression_ii.ipynb @@ -44,7 +44,7 @@ "import pyro.optim as optim\n", "\n", "pyro.set_rng_seed(1)\n", - "assert pyro.__version__.startswith('1.8.2')" + "assert pyro.__version__.startswith('1.8.3')" ] }, { diff --git a/tutorial/source/bo.ipynb b/tutorial/source/bo.ipynb index 0db806931b..a0474608bf 100644 --- a/tutorial/source/bo.ipynb +++ b/tutorial/source/bo.ipynb @@ -54,7 +54,7 @@ "import pyro\n", "import pyro.contrib.gp as gp\n", "\n", - "assert pyro.__version__.startswith('1.8.2')\n", + "assert pyro.__version__.startswith('1.8.3')\n", "pyro.set_rng_seed(1)" ] }, diff --git a/tutorial/source/custom_objectives_training.ipynb b/tutorial/source/custom_objectives.ipynb similarity index 100% rename from tutorial/source/custom_objectives_training.ipynb rename to tutorial/source/custom_objectives.ipynb diff 
--git a/tutorial/source/dirichlet_process_mixture.ipynb b/tutorial/source/dirichlet_process_mixture.ipynb index 3266dc9d6e..bc0d918ff1 100644 --- a/tutorial/source/dirichlet_process_mixture.ipynb +++ b/tutorial/source/dirichlet_process_mixture.ipynb @@ -76,7 +76,7 @@ "from pyro.infer import Predictive, SVI, Trace_ELBO\n", "from pyro.optim import Adam\n", "\n", - "assert pyro.__version__.startswith('1.8.2')\n", + "assert pyro.__version__.startswith('1.8.3')\n", "pyro.set_rng_seed(0)" ] }, diff --git a/tutorial/source/easyguide.ipynb b/tutorial/source/easyguide.ipynb index 9661244484..1c3f0fd2e0 100644 --- a/tutorial/source/easyguide.ipynb +++ b/tutorial/source/easyguide.ipynb @@ -44,7 +44,7 @@ "from torch.distributions import constraints\n", "\n", "smoke_test = ('CI' in os.environ)\n", - "assert pyro.__version__.startswith('1.8.2')" + "assert pyro.__version__.startswith('1.8.3')" ] }, { diff --git a/tutorial/source/ekf.ipynb b/tutorial/source/ekf.ipynb index ce44e2e75b..38dcceecd3 100644 --- a/tutorial/source/ekf.ipynb +++ b/tutorial/source/ekf.ipynb @@ -98,7 +98,7 @@ "from pyro.contrib.tracking.measurements import PositionMeasurement\n", "\n", "smoke_test = ('CI' in os.environ)\n", - "assert pyro.__version__.startswith('1.8.2')" + "assert pyro.__version__.startswith('1.8.3')" ] }, { diff --git a/tutorial/source/enumeration.ipynb b/tutorial/source/enumeration.ipynb index 58a033ecdc..343087daed 100644 --- a/tutorial/source/enumeration.ipynb +++ b/tutorial/source/enumeration.ipynb @@ -50,7 +50,7 @@ "from pyro.ops.indexing import Vindex\n", "\n", "smoke_test = ('CI' in os.environ)\n", - "assert pyro.__version__.startswith('1.8.2')\n", + "assert pyro.__version__.startswith('1.8.3')\n", "pyro.set_rng_seed(0)" ] }, diff --git a/tutorial/source/epi_intro.ipynb b/tutorial/source/epi_intro.ipynb index b5c51db19d..b18bbe3e82 100644 --- a/tutorial/source/epi_intro.ipynb +++ b/tutorial/source/epi_intro.ipynb @@ -58,7 +58,7 @@ "from pyro.contrib.epidemiology import 
CompartmentalModel, binomial_dist, infection_dist\n", "\n", "%matplotlib inline\n", - "assert pyro.__version__.startswith('1.8.2')\n", + "assert pyro.__version__.startswith('1.8.3')\n", "torch.set_default_dtype(torch.double) # Required for MCMC inference.\n", "smoke_test = ('CI' in os.environ)" ] diff --git a/tutorial/source/forecasting_dlm.ipynb b/tutorial/source/forecasting_dlm.ipynb index 68f7388c15..9c05a69aa6 100644 --- a/tutorial/source/forecasting_dlm.ipynb +++ b/tutorial/source/forecasting_dlm.ipynb @@ -46,7 +46,7 @@ "from pyro.ops.stats import quantile\n", "\n", "%matplotlib inline\n", - "assert pyro.__version__.startswith('1.8.2')\n", + "assert pyro.__version__.startswith('1.8.3')\n", "\n", "pyro.set_rng_seed(20200928)\n", "\n", diff --git a/tutorial/source/forecasting_i.ipynb b/tutorial/source/forecasting_i.ipynb index c55689aba6..42931b6a82 100644 --- a/tutorial/source/forecasting_i.ipynb +++ b/tutorial/source/forecasting_i.ipynb @@ -47,7 +47,7 @@ "import matplotlib.pyplot as plt\n", "\n", "%matplotlib inline\n", - "assert pyro.__version__.startswith('1.8.2')\n", + "assert pyro.__version__.startswith('1.8.3')\n", "pyro.set_rng_seed(20200221)" ] }, diff --git a/tutorial/source/forecasting_ii.ipynb b/tutorial/source/forecasting_ii.ipynb index df8592f217..11b2eb7d43 100644 --- a/tutorial/source/forecasting_ii.ipynb +++ b/tutorial/source/forecasting_ii.ipynb @@ -40,7 +40,7 @@ "import matplotlib.pyplot as plt\n", "\n", "%matplotlib inline\n", - "assert pyro.__version__.startswith('1.8.2')\n", + "assert pyro.__version__.startswith('1.8.3')\n", "pyro.set_rng_seed(20200305)" ] }, diff --git a/tutorial/source/forecasting_iii.ipynb b/tutorial/source/forecasting_iii.ipynb index f218ab4fe0..9e6c7f15e3 100644 --- a/tutorial/source/forecasting_iii.ipynb +++ b/tutorial/source/forecasting_iii.ipynb @@ -40,7 +40,7 @@ "import matplotlib.pyplot as plt\n", "\n", "%matplotlib inline\n", - "assert pyro.__version__.startswith('1.8.2')\n", + "assert 
pyro.__version__.startswith('1.8.3')\n", "pyro.set_rng_seed(20200305)" ] }, diff --git a/tutorial/source/gmm.ipynb b/tutorial/source/gmm.ipynb index 301ea4c719..973c6b2765 100644 --- a/tutorial/source/gmm.ipynb +++ b/tutorial/source/gmm.ipynb @@ -41,7 +41,7 @@ "from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate, infer_discrete\n", "\n", "smoke_test = ('CI' in os.environ)\n", - "assert pyro.__version__.startswith('1.8.2')" + "assert pyro.__version__.startswith('1.8.3')" ] }, { diff --git a/tutorial/source/gp.ipynb b/tutorial/source/gp.ipynb index 5c1b61ac1e..db48a4bf63 100644 --- a/tutorial/source/gp.ipynb +++ b/tutorial/source/gp.ipynb @@ -69,7 +69,7 @@ "\n", "\n", "smoke_test = \"CI\" in os.environ # ignore; used to check code integrity in the Pyro repo\n", - "assert pyro.__version__.startswith('1.8.2')\n", + "assert pyro.__version__.startswith('1.8.3')\n", "pyro.set_rng_seed(0)" ] }, diff --git a/tutorial/source/gplvm.ipynb b/tutorial/source/gplvm.ipynb index fd3a6e14d3..0765118e98 100644 --- a/tutorial/source/gplvm.ipynb +++ b/tutorial/source/gplvm.ipynb @@ -39,7 +39,7 @@ "import pyro.ops.stats as stats\n", "\n", "smoke_test = ('CI' in os.environ) # ignore; used to check code integrity in the Pyro repo\n", - "assert pyro.__version__.startswith('1.8.2')\n", + "assert pyro.__version__.startswith('1.8.3')\n", "pyro.set_rng_seed(1)" ] }, diff --git a/tutorial/source/intro_long.ipynb b/tutorial/source/intro_long.ipynb index 1c842a5f48..de498fee0f 100644 --- a/tutorial/source/intro_long.ipynb +++ b/tutorial/source/intro_long.ipynb @@ -108,7 +108,7 @@ "outputs": [], "source": [ "smoke_test = ('CI' in os.environ)\n", - "assert pyro.__version__.startswith('1.8.2')\n", + "assert pyro.__version__.startswith('1.8.3')\n", "\n", "pyro.enable_validation(True)\n", "pyro.set_rng_seed(1)\n", diff --git a/tutorial/source/jit.ipynb b/tutorial/source/jit.ipynb index 4404b82cbb..adebeb5774 100644 --- a/tutorial/source/jit.ipynb +++ b/tutorial/source/jit.ipynb @@ -48,7 +48,7 
@@ "from pyro.optim import Adam\n", "\n", "smoke_test = ('CI' in os.environ)\n", - "assert pyro.__version__.startswith('1.8.2')" + "assert pyro.__version__.startswith('1.8.3')" ] }, { diff --git a/tutorial/source/model_rendering.ipynb b/tutorial/source/model_rendering.ipynb index f34d83b4b7..3712844d6f 100644 --- a/tutorial/source/model_rendering.ipynb +++ b/tutorial/source/model_rendering.ipynb @@ -25,7 +25,7 @@ "import pyro.distributions.constraints as constraints\n", "\n", "smoke_test = ('CI' in os.environ)\n", - "assert pyro.__version__.startswith('1.8.2')" + "assert pyro.__version__.startswith('1.8.3')" ] }, { diff --git a/tutorial/source/modules.ipynb b/tutorial/source/modules.ipynb index f52874bac6..404f1d21eb 100644 --- a/tutorial/source/modules.ipynb +++ b/tutorial/source/modules.ipynb @@ -61,7 +61,7 @@ "from pyro.optim import Adam\n", "\n", "smoke_test = ('CI' in os.environ)\n", - "assert pyro.__version__.startswith('1.8.2')" + "assert pyro.__version__.startswith('1.8.3')" ] }, { diff --git a/tutorial/source/prior_predictive.ipynb b/tutorial/source/prior_predictive.ipynb index d01db9cb36..073b0a2e4c 100644 --- a/tutorial/source/prior_predictive.ipynb +++ b/tutorial/source/prior_predictive.ipynb @@ -46,7 +46,7 @@ "import pyro.poutine as poutine\n", "from pyro.infer.resampler import Resampler\n", "\n", - "assert pyro.__version__.startswith('1.8.2')\n", + "assert pyro.__version__.startswith('1.8.3')\n", "smoke_test = ('CI' in os.environ) # for CI testing only" ] }, diff --git a/tutorial/source/prodlda.ipynb b/tutorial/source/prodlda.ipynb index 868c9fa99d..8fae553d80 100644 --- a/tutorial/source/prodlda.ipynb +++ b/tutorial/source/prodlda.ipynb @@ -70,7 +70,7 @@ "from pyro.infer import MCMC, NUTS\n", "import torch\n", "\n", - "assert pyro.__version__.startswith('1.8.2')\n", + "assert pyro.__version__.startswith('1.8.3')\n", "# Enable smoke test - run the notebook cells on CI.\n", "smoke_test = 'CI' in os.environ" ] diff --git a/tutorial/source/stable.ipynb 
b/tutorial/source/stable.ipynb index 6abe840bff..77af299691 100644 --- a/tutorial/source/stable.ipynb +++ b/tutorial/source/stable.ipynb @@ -62,7 +62,7 @@ "from pyro.ops.tensor_utils import convolve\n", "\n", "%matplotlib inline\n", - "assert pyro.__version__.startswith('1.8.2')\n", + "assert pyro.__version__.startswith('1.8.3')\n", "smoke_test = ('CI' in os.environ)" ] }, diff --git a/tutorial/source/svi_part_i.ipynb b/tutorial/source/svi_part_i.ipynb index 4fdd533fbe..d18a767778 100644 --- a/tutorial/source/svi_part_i.ipynb +++ b/tutorial/source/svi_part_i.ipynb @@ -260,7 +260,7 @@ "smoke_test = ('CI' in os.environ)\n", "n_steps = 2 if smoke_test else 2000\n", "\n", - "assert pyro.__version__.startswith('1.8.2')\n", + "assert pyro.__version__.startswith('1.8.3')\n", "\n", "# clear the param store in case we're in a REPL\n", "pyro.clear_param_store()\n", diff --git a/tutorial/source/svi_part_iii.ipynb b/tutorial/source/svi_part_iii.ipynb index b78b464e58..d1154c7b66 100644 --- a/tutorial/source/svi_part_iii.ipynb +++ b/tutorial/source/svi_part_iii.ipynb @@ -283,7 +283,7 @@ "from pyro.infer import SVI, TraceGraph_ELBO\n", "import sys\n", "\n", - "assert pyro.__version__.startswith('1.8.2')\n", + "assert pyro.__version__.startswith('1.8.3')\n", "\n", "# this is for running the notebook in our testing framework\n", "smoke_test = ('CI' in os.environ)\n", diff --git a/tutorial/source/tensor_shapes.ipynb b/tutorial/source/tensor_shapes.ipynb index d5cbfe8f3f..30891fca1b 100644 --- a/tutorial/source/tensor_shapes.ipynb +++ b/tutorial/source/tensor_shapes.ipynb @@ -59,7 +59,7 @@ "from pyro.optim import Adam\n", "\n", "smoke_test = ('CI' in os.environ)\n", - "assert pyro.__version__.startswith('1.8.2')\n", + "assert pyro.__version__.startswith('1.8.3')\n", "\n", "# We'll ue this helper to check our models are correct.\n", "def test_model(model, guide, loss):\n", diff --git a/tutorial/source/tracking_1d.ipynb b/tutorial/source/tracking_1d.ipynb index ec11f0a4ef..847575a7e8 
100644 --- a/tutorial/source/tracking_1d.ipynb +++ b/tutorial/source/tracking_1d.ipynb @@ -30,7 +30,7 @@ "from pyro.optim import Adam\n", "\n", "%matplotlib inline\n", - "assert pyro.__version__.startswith('1.8.2')\n", + "assert pyro.__version__.startswith('1.8.3')\n", "smoke_test = ('CI' in os.environ)" ] }, diff --git a/tutorial/source/vae.ipynb b/tutorial/source/vae.ipynb index d3d570ab28..9ce54e9457 100644 --- a/tutorial/source/vae.ipynb +++ b/tutorial/source/vae.ipynb @@ -115,7 +115,7 @@ "metadata": {}, "outputs": [], "source": [ - "assert pyro.__version__.startswith('1.8.2')\n", + "assert pyro.__version__.startswith('1.8.3')\n", "pyro.distributions.enable_validation(False)\n", "pyro.set_rng_seed(0)\n", "# Enable smoke test - run the notebook cells on CI.\n",