Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/pyro.poutine.txt
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,14 @@ __________________
:undoc-members:
:show-inheritance:

SubstituteMessenger
___________________

.. automodule:: pyro.poutine.substitute_messenger
    :members:
    :undoc-members:
    :show-inheritance:

TraceMessenger
_______________

Expand Down
2 changes: 2 additions & 0 deletions pyro/poutine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
replay,
scale,
seed,
substitute,
trace,
uncondition,
)
Expand Down Expand Up @@ -47,6 +48,7 @@
"queue",
"scale",
"seed",
"substitute",
"trace",
"Trace",
"uncondition",
Expand Down
2 changes: 2 additions & 0 deletions pyro/poutine/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from .runtime import NonlocalExit
from .scale_messenger import ScaleMessenger
from .seed_messenger import SeedMessenger
from .substitute_messenger import SubstituteMessenger
from .trace_messenger import TraceMessenger
from .uncondition_messenger import UnconditionMessenger

Expand All @@ -97,6 +98,7 @@
SeedMessenger,
TraceMessenger,
UnconditionMessenger,
SubstituteMessenger,
]


Expand Down
85 changes: 85 additions & 0 deletions pyro/poutine/substitute_messenger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import warnings

from pyro import params
from pyro.poutine.messenger import Messenger
from pyro.poutine.util import is_validation_enabled


class SubstituteMessenger(Messenger):
    """
    Given a stochastic function with param calls and a set of parameter values,
    create a stochastic function where all param calls are substituted with
    the fixed values.
    ``data`` should be a dict of names to values.
    Consider the following Pyro program:

    >>> def model(x):
    ...     a = pyro.param("a", torch.tensor(0.5))
    ...     x = pyro.sample("x", dist.Bernoulli(probs=a))
    ...     return x
    >>> substituted_model = pyro.poutine.substitute(model, data={"a": 0.3})

    In this example, param site ``a`` will now have value ``0.3``.
    ``pyro.sample`` sites are left untouched.

    :param data: dictionary of values keyed by (user-facing) param names.
    :returns: ``fn`` decorated with a :class:`~pyro.poutine.substitute_messenger.SubstituteMessenger`
    """

    def __init__(self, data):
        """
        :param data: values for the parameters, keyed by param name.
        """
        super().__init__()
        self.data = data
        # Param messages seen in the current context, keyed by full site name.
        self._data_cache = {}
        # Validation bookkeeping. Always initialized (not only under
        # validation) so _pyro_param/__exit__ can never hit a missing
        # attribute, e.g. when ``data`` is a non-dict mapping or validation
        # is toggled while the handler is active.
        self._param_hits = set()
        self._param_misses = set()

    def __enter__(self):
        # Start every context with a clean cache and clean validation sets.
        self._data_cache = {}
        self._param_hits = set()
        self._param_misses = set()
        return super().__enter__()

    def __exit__(self, *args, **kwargs):
        self._data_cache = {}
        if is_validation_enabled() and isinstance(self.data, dict):
            # Keys in ``data`` that never matched a pyro.param call.
            extra = set(self.data) - self._param_hits
            if extra:
                # Sorted so the message is deterministic across runs.
                warnings.warn(
                    "pyro.poutine.substitute data did not find params ['{}']. "
                    "Did you instead mean one of ['{}']?".format(
                        "', '".join(sorted(extra)),
                        "', '".join(sorted(self._param_misses)),
                    )
                )
        return super().__exit__(*args, **kwargs)

    def _pyro_sample(self, msg):
        # Sample sites are deliberately left unchanged by this handler.
        return None

    def _pyro_param(self, msg):
        """
        Overrides the value of a ``pyro.param`` site with the substituted value.
        If the param name does not match any key in ``data``,
        that param value is unchanged.
        """
        name = msg["name"]
        # Look up by the user-facing name (as produced by
        # params.user_param_name), not the raw site name.
        param_name = params.user_param_name(name)
        validating = is_validation_enabled()

        if param_name not in self.data:
            if validating:
                self._param_misses.add(param_name)
            return None

        msg["value"] = self.data[param_name]
        if validating:
            self._param_hits.add(param_name)

        if name in self._data_cache:
            # Repeated pyro.param statement with the same name in this
            # context: reuse the value recorded for the first occurrence.
            msg["value"] = self._data_cache[name]["value"]
        else:
            self._data_cache[name] = msg
30 changes: 30 additions & 0 deletions tests/poutine/test_poutines.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,36 @@ def _test_scale_factor(batch_size_outer, batch_size_inner, expected):
_test_scale_factor(2, 1, [2.0] * 2)


class SubstituteHandlerTests(NormalNormalNormalHandlerTestCase):
    """Tests for ``poutine.substitute`` applied to the guide's param sites."""

    def test_substitute(self):
        fixed = {"loc1": torch.randn(2)}
        guide_trace = poutine.trace(
            poutine.substitute(self.guide, data=fixed)
        ).get_trace()
        assert "loc1" in guide_trace
        node = guide_trace.nodes["loc1"]
        assert node["type"] == "param"
        assert node["value"] is fixed["loc1"]

    def test_stack_overwrite_behavior(self):
        # Both handlers target loc1; the outermost substitute's value wins.
        inner = {"loc1": torch.randn(2)}
        outer = {"loc1": torch.randn(2)}
        substituted_guide = poutine.substitute(
            poutine.substitute(self.guide, data=inner), data=outer
        )
        with poutine.trace() as tracer:
            substituted_guide()
        assert tracer.trace.nodes["loc1"]["value"] is outer["loc1"]

    def test_stack_success(self):
        # Handlers targeting disjoint params compose: each sets its own site.
        inner = {"loc1": torch.randn(2)}
        outer = {"loc2": torch.randn(2)}
        stacked_guide = poutine.substitute(
            poutine.substitute(self.guide, data=inner), data=outer
        )
        nodes = poutine.trace(stacked_guide).get_trace().nodes
        for site, source in (("loc1", inner), ("loc2", outer)):
            assert nodes[site]["type"] == "param"
            assert nodes[site]["value"] is source[site]


class ConditionHandlerTests(NormalNormalNormalHandlerTestCase):
def test_condition(self):
data = {"latent2": torch.randn(2)}
Expand Down