Source code for tramp.models.multi_layer_model

from ..priors.base_prior import Prior
from ..channels.base_channel import Channel
from ..likelihoods.base_likelihood import Likelihood
from ..variables import SISOVariable, SILeafVariable
from .base_model import Model


def check_layers(layers):
    """Validate the layer sequence of a multi-layer model.

    The expected structure is ``Prior, Channel, ..., Channel, last`` where
    ``last`` is either a ``Channel`` with a single output (``n_next == 1``)
    or a ``Likelihood``. With a single layer, that layer is checked both as
    the first and as the last layer.

    Parameters
    ----------
    layers : sequence
        Layers of the model, ordered from the prior to the output.

    Raises
    ------
    ValueError
        If ``layers`` is empty or does not follow the structure above. The
        index reported for a bad intermediate layer is its position in
        ``layers``.
    """
    # len() rather than truthiness so array-like sequences are also accepted.
    if len(layers) == 0:
        raise ValueError("layers must be a non-empty sequence")
    if not isinstance(layers[0], Prior):
        raise ValueError("first layer must be a Prior")
    # start=1: the slice skips layers[0], so number from 1 to report the
    # layer's actual index in ``layers``.
    for i, layer in enumerate(layers[1:-1], start=1):
        if not isinstance(layer, Channel):
            raise ValueError(f"intermediate layer i={i} must be a Channel")
    last = layers[-1]
    if isinstance(last, Channel):
        if last.n_next != 1:
            raise ValueError("last layer must be a Channel with one output")
    elif not isinstance(last, Likelihood):
        raise ValueError("last layer must be a Channel or a Likelihood")


def default_ids(n_layers):
    """Return the default variable ids for a chain of ``n_layers`` layers.

    The result is ``["x", "t_1", ..., "t_{n_layers-2}", "y"]``. The first id
    is always ``"x"``, and the last id is ``"y"`` when there are at least two
    layers. A single layer yields ``["x"]``. ``n_layers < 1`` raises
    ``IndexError``.
    """
    inner = ["t_{}".format(k) for k in range(1, n_layers - 1)]
    if n_layers > 1:
        return ["x", *inner, "y"]
    if n_layers == 1:
        return ["x"]
    # No layers: there is no slot for "x", so this is an index error.
    raise IndexError("list assignment index out of range")


class MultiLayerModel(Model):
    """Model built from a linear chain of layers.

    The chain is ``layers[0] @ v_0 @ layers[1] @ v_1 @ ... @ layers[-1] @ v_last``.
    Each ``v_l`` is a variable created with ``id=ids[l]``. The last one is an
    ``SILeafVariable`` and all others are ``SISOVariable``. The chained DAG is
    converted with ``to_model_dag()`` and passed to ``Model.__init__``.

    Parameters
    ----------
    layers : sequence
        Layers of the chain. They are validated by ``check_layers``: a Prior
        first, Channels in between, and a single-output Channel or a
        Likelihood last.
    ids : sequence, optional
        Ids of the variables that follow each layer, one per layer. When it
        is ``None`` or any other falsy value (e.g. an empty list),
        ``default_ids(len(layers))`` is used.

    Raises
    ------
    ValueError
        If ``check_layers`` rejects ``layers``, or if ``ids`` does not have
        exactly one entry per layer.
    """

    def __init__(self, layers, ids=None):
        check_layers(layers)
        n_layers = len(layers)
        if not ids:
            ids = default_ids(n_layers)
        if len(ids) != n_layers:
            raise ValueError(f"ids should be of length {n_layers}")
        self.n_layers = n_layers
        self.layers = layers
        self.ids = ids
        self.repr_init(pad=" ")

        last_index = n_layers - 1

        def make_variable(index):
            # Only the variable after the final layer is created as a leaf.
            variable_cls = SILeafVariable if index == last_index else SISOVariable
            return variable_cls(id=ids[index])

        dag = layers[0] @ make_variable(0)
        for index, layer in enumerate(layers[1:], start=1):
            # Left-associative: (dag @ layer) is formed before the variable.
            dag = dag @ layer @ make_variable(index)
        Model.__init__(self, dag.to_model_dag())