Source code for pyroapi.dispatch

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

"""
Dispatching allows you to dynamically set a backend using :func:`pyro_backend`
and to register new backends using :func:`register_backend` .  It's easiest to
see how to use these by example:

.. code-block:: python

    from pyroapi import distributions as dist
    from pyroapi import infer, ops, optim, pyro, pyro_backend

    # These model and guide are backend-agnostic.
    def model():
        locs = pyro.param("locs", ops.tensor([0.2, 0.3, 0.5]))
        p = ops.tensor([0.2, 0.3, 0.5])
        with pyro.plate("plate", len(data), dim=-1):
            x = pyro.sample("x", dist.Categorical(p))
            pyro.sample("obs", dist.Normal(locs[x], 1.), obs=data)

    def guide():
        p = pyro.param("p", ops.tensor([0.5, 0.3, 0.2]))
        with pyro.plate("plate", len(data), dim=-1):
            pyro.sample("x", dist.Categorical(p))

    # We can now set a backend at inference time.
    with pyro_backend("numpyro"):
        elbo = infer.Trace_ELBO(ignore_jit_warnings=True)
        adam = optim.Adam({"lr": 1e-6})
        inference = infer.SVI(model, guide, adam, elbo)
        for step in range(10):
            loss = inference.step(*args, **kwargs)
            print("step {} loss = {}".format(step, loss))

"""
import importlib
from contextlib import contextmanager

DEFAULT_RNG_SEED = 1
_ALIASES = {}


class GenericModule(object):
    """
    Wrapper for a module that can be dynamically routed to a custom backend.
    """
    current_backend = {}
    _modules = {}

    def __init__(self, name, default_backend):
        assert isinstance(name, str)
        assert isinstance(default_backend, str)
        self._name = name
        GenericModule.current_backend[name] = default_backend

    def __getattribute__(self, name):
        module_name = super(GenericModule, self).__getattribute__('_name')
        backend = GenericModule.current_backend[module_name]
        try:
            module = GenericModule._modules[backend]
        except KeyError:
            module = importlib.import_module(backend)
            GenericModule._modules[backend] = module
        if name.startswith('__'):
            return getattr(module, name)  # allow magic attributes to return AttributeError
        try:
            return getattr(module, name)
        except AttributeError:
            raise NotImplementedError('This Pyro backend does not implement {}.{}'
                                      .format(module_name, name))


[docs]@contextmanager def pyro_backend(*aliases, **new_backends): """ Context manager to set a custom backend for Pyro models. Backends can be specified either by name (for standard backends or backends registered through :func:`register_backend` ) or by providing kwargs mapping module name to backend module name. Standard backends include: pyro, minipyro, funsor, and numpy. """ if aliases: assert len(aliases) == 1 assert not new_backends new_backends = _ALIASES[aliases[0]] old_backends = {} for name, new_backend in new_backends.items(): old_backends[name] = GenericModule.current_backend[name] GenericModule.current_backend[name] = new_backend try: with handlers.seed(rng_seed=DEFAULT_RNG_SEED): yield finally: for name, old_backend in old_backends.items(): GenericModule.current_backend[name] = old_backend
[docs]def register_backend(alias, new_backends): """ Register a new backend alias. For example:: register_backend("minipyro", { "infer": "pyro.contrib.minipyro", "optim": "pyro.contrib.minipyro", "pyro": "pyro.contrib.minipyro", }) :param str alias: The name of the new backend. :param dict new_backends: A dict mapping standard module name (str) to new module name (str). This needs to include only nonstandard backends (e.g. if your backend uses torch ops, you need not override ``ops``) """ assert isinstance(new_backends, dict) assert all(isinstance(key, str) for key in new_backends.keys()) assert all(isinstance(value, str) for value in new_backends.values()) _ALIASES[alias] = new_backends.copy()
# These modules can be overridden. pyro = GenericModule('pyro', 'pyro') distributions = GenericModule('distributions', 'pyro.distributions') handlers = GenericModule('handlers', 'pyro.poutine') infer = GenericModule('infer', 'pyro.infer') optim = GenericModule('optim', 'pyro.optim') ops = GenericModule('ops', 'torch') # These are standard backends. register_backend('pyro', { 'distributions': 'pyro.distributions', 'handlers': 'pyro.poutine', 'infer': 'pyro.infer', 'ops': 'torch', 'optim': 'pyro.optim', 'pyro': 'pyro', }) register_backend('minipyro', { 'distributions': 'pyro.distributions', 'handlers': 'pyro.poutine', 'infer': 'pyro.contrib.minipyro', 'ops': 'torch', 'optim': 'pyro.contrib.minipyro', 'pyro': 'pyro.contrib.minipyro', }) register_backend('funsor', { 'distributions': 'funsor.torch.distributions', 'handlers': 'funsor.minipyro', 'infer': 'funsor.minipyro', 'ops': 'funsor.compat.ops', 'optim': 'funsor.minipyro', 'pyro': 'funsor.minipyro', }) register_backend('numpy', { 'distributions': 'numpyro.compat.distributions', 'handlers': 'numpyro.compat.handlers', 'infer': 'numpyro.compat.infer', 'ops': 'numpyro.compat.ops', 'optim': 'numpyro.compat.optim', 'pyro': 'numpyro.compat.pyro', })