# Source code for tramp.algos.initial_conditions

from ..base import ReprMixin, Variable, Factor
import numpy as np


class InitialConditions(ReprMixin):
    """Base class for the initial values of messages.

    Subclasses provide ``init_a(shape, id, direction)`` and
    ``init_b(shape, id, direction)``; :meth:`init` dispatches to one of them
    according to the requested message key.
    """

    def init(self, message_key, shape, id, direction):
        """Return the initial value of the ``message_key`` component.

        Parameters
        ----------
        message_key : str
            Either ``"a"`` (dispatched to ``init_a``) or ``"b"``
            (dispatched to ``init_b``).
        shape :
            Forwarded unchanged to ``init_a`` / ``init_b``.
        id :
            Variable identifier, forwarded unchanged.
        direction :
            Edge direction, forwarded unchanged.

        Returns
        -------
        Whatever ``init_a`` / ``init_b`` returns.

        Raises
        ------
        ValueError
            If ``message_key`` is neither ``"a"`` nor ``"b"``. Returning
            ``None`` here would silently inject an invalid message.
        """
        if message_key == "a":
            return self.init_a(shape, id, direction)
        if message_key == "b":
            return self.init_b(shape, id, direction)
        raise ValueError(
            "Unknown message_key {!r}; expected 'a' or 'b'".format(message_key)
        )


class ConstantInit(InitialConditions):
    """Initialize every message with the same constant values.

    Parameters
    ----------
    a : float
        Value returned for every ``a`` message.
    b : float
        Value used to fill every ``b`` message.
    """

    def __init__(self, a=0, b=0):
        # Assigned in the order a, b before repr_init() is called.
        self.a, self.b = a, b
        self.repr_init()

    def init_a(self, shape, id, direction):
        """Return the constant ``a``; shape, id and direction are ignored."""
        return self.a

    def init_b(self, shape, id, direction):
        """Return ``b`` multiplied by an array of ones of the given ``shape``."""
        assert shape is not None
        ones = np.ones(shape)
        return self.b * ones
class NoisyInit(InitialConditions):
    """Initialize messages with Gaussian noise around given means.

    Parameters
    ----------
    a_mean, a_var : float
        Mean and variance of the scalar ``a`` draw.
    b_mean, b_var : float
        Mean and variance of each entry of the ``b`` draw.
    """

    def __init__(self, a_mean=0, a_var=0, b_mean=0, b_var=1):
        self.a_mean, self.a_var = a_mean, a_var
        self.b_mean, self.b_var = b_mean, b_var
        # NOTE(review): repr_init() runs before the derived standard
        # deviations are stored, presumably so they stay out of the repr;
        # confirm against ReprMixin before reordering.
        self.repr_init()
        self.a_sigma, self.b_sigma = np.sqrt(a_var), np.sqrt(b_var)

    def init_a(self, shape, id, direction):
        """Return a_mean + a_sigma * z with z a single standard-normal draw."""
        z = np.random.standard_normal()
        return self.a_mean + self.a_sigma * z

    def init_b(self, shape, id, direction):
        """Return b_mean + b_sigma * Z with Z standard-normal of ``shape``."""
        assert shape is not None
        z = np.random.standard_normal(shape)
        return self.b_mean + self.b_sigma * z
class CustomInit(InitialConditions):
    """Custom init on variables.

    Parameters
    ----------
    a_init : list of (variable_id, direction, a) tuples, optional
        Edges from/into ``variable_id`` with the given ``direction`` are
        initialized with ``a``.
    b_init : list of (variable_id, direction, b) tuples, optional
        Edges from/into ``variable_id`` with the given ``direction`` are
        initialized with ``b``. ``b`` may be an array of exactly the edge
        shape, or a scalar that is broadcast to that shape.
    a : float
        Default constant value for a.
    b : float
        Default constant value for b.

    Several entries may share the same ``variable_id`` with different
    directions; each (variable_id, direction) pair is kept. If the same pair
    appears twice, the last entry wins.
    """

    def __init__(self, a_init=None, b_init=None, a=0, b=0):
        # Attributes are assigned in the order a_init, b_init, a, b right
        # before repr_init(), as in the original implementation.
        self.a_init = self._group_by_id(a_init or [])
        self.b_init = self._group_by_id(b_init or [])
        self.a = a
        self.b = b
        self.repr_init()

    @staticmethod
    def _group_by_id(entries):
        """Build ``{id: {direction: value}}`` from (id, direction, value) tuples."""
        grouped = {}
        for var_id, direction, value in entries:
            # setdefault keeps directions already recorded for this id instead
            # of replacing the whole inner dict (the old comprehension did).
            grouped.setdefault(var_id, {})[direction] = value
        return grouped

    def init_a(self, shape, id, direction):
        """Return the custom ``a`` for (id, direction), else the default ``a``."""
        try:
            return self.a_init[id][direction]
        except KeyError:
            return self.a

    def init_b(self, shape, id, direction):
        """Return the custom ``b`` for (id, direction), else ``b * ones(shape)``.

        Raises
        ------
        ValueError
            If a custom non-scalar ``b`` does not have the requested shape.
        """
        assert shape is not None
        try:
            b = self.b_init[id][direction]
        except KeyError:
            return self.b * np.ones(shape)
        if np.ndim(b) == 0:
            # Scalars have no ``.shape`` to check; broadcast them like the
            # default ``b`` instead of crashing.
            return b * np.ones(shape)
        # asanyarray returns ndarray inputs unchanged (same object) and only
        # converts sequences such as lists.
        b = np.asanyarray(b)
        # Normalize an int or list shape so e.g. shape=5 matches b.shape == (5,).
        expected = (shape,) if np.ndim(shape) == 0 else tuple(shape)
        if b.shape != expected:
            raise ValueError(
                "b_init for id={!r}, direction={!r} has shape {}, "
                "expected {}".format(id, direction, b.shape, expected)
            )
        return b