Source code for tramp.algos.expectation_propagation

from .message_passing import MessagePassing
from .callbacks import EarlyStopping, EarlyStoppingEP, JoinCallback


[docs]class ExpectationPropagation(MessagePassing): def __init__(self, model): model.init_shapes() super().__init__(model, message_keys=["a", "b"]) self.default_stopping = EarlyStoppingEP() def forward(self, node, message): return node.forward_message(message) def backward(self, node, message): return node.backward_message(message) def update(self, variable, message): r, v = variable.posterior_rv(message) return dict(r=r, v=v) def node_objective(self, node, message): return node.log_partition(message) def log_evidence(self, update=True): if update: self.update_objective() return self.A_model def surprisal(self, update=True): if update: self.update_objective() return -self.A_model