Source code for tramp.channels.shape.concat_channel

from ..base_channel import SOFactor
import numpy as np


class ConcatChannel(SOFactor):
    """Deterministic concatenation channel x = [z_1, ..., z_K].

    The factor has ``len(Ns)`` incoming variables z_k and one outgoing
    variable x (``n_next = 1``). The k-th incoming array has dimension
    ``Ns[k]`` along ``axis``, and x is their concatenation along ``axis``,
    of dimension ``N = sum(Ns)``.

    Parameters
    ----------
    Ns : sequence of int
        Dimension of each incoming array z_k along ``axis``.
    axis : int, optional
        Axis along which the z_k are concatenated (default 0).
    """

    n_next = 1

    def __init__(self, Ns, axis=0):
        self.Ns = Ns
        self.axis = axis
        # repr_init is provided by SOFactor; called here after Ns/axis are set,
        # as in the original implementation.
        self.repr_init()
        self.n_prev = len(Ns)
        self.N = sum(Ns)

    def sample(self, *Zs):
        """Return the concatenation of ``Zs`` along ``self.axis``.

        Raises
        ------
        ValueError
            If the number of arrays is not ``n_prev``, or if array k does not
            have dimension ``Ns[k]`` along ``self.axis``.
        """
        if len(Zs) != self.n_prev:
            raise ValueError(f"expect {self.n_prev} arrays")
        for k, Z in enumerate(Zs):
            if Z.shape[self.axis] != self.Ns[k]:
                raise ValueError(
                    f"expect Z k={k} array of dimension {self.Ns[k]} "
                    f"along axis {self.axis} "
                    f"but got array of dimension {Z.shape[self.axis]}"
                )
        X = np.concatenate(Zs, axis=self.axis)
        assert X.shape[self.axis] == self.N
        return X

    def math(self):
        """Return the LaTeX symbol used to display this channel."""
        return r"$\oplus$"

    def second_moment(self, *tau_zs):
        """Return the second moment of x.

        The result is the average of the ``tau_zs`` weighted by ``Ns``:
        ``sum(Ns[k] * tau_zs[k]) / N``.

        Raises
        ------
        ValueError
            If the number of ``tau_zs`` is not ``n_prev``.
        """
        if len(tau_zs) != self.n_prev:
            raise ValueError(f"expect {self.n_prev} tau_zs")
        tau_x = sum(N * tau_z for N, tau_z in zip(self.Ns, tau_zs)) / self.N
        return tau_x

    def compute_forward_posterior(self, az, bz, ax, bx):
        """Estimate x = [z_k] from the messages on z = {z_k} and x.

        The mean ``rx`` is the concatenation of the backward means ``rz``.
        The variance ``vx`` is the ``Ns``-weighted average of the backward
        variances ``vz``.
        """
        rz, vz = self.compute_backward_posterior(az, bz, ax, bx)
        rx = np.concatenate(rz, axis=self.axis)
        vx = sum(N * v for N, v in zip(self.Ns, vz)) / self.N
        return rx, vx

    def _compute_ak_bk(self, az, bz, ax, bx):
        """Combine incoming messages into natural parameters for each z_k.

        Returns lists ``ak`` and ``bk`` with ``ak[k] = az[k] + ax`` and
        ``bk[k] = bz[k] + bx_k``. Here ``bx_k`` is the slice of ``bx`` along
        ``self.axis`` that lines up with z_k.

        ``ax`` is added unchanged to every ``az[k]``. It is therefore presumably
        a scalar, or at least broadcastable to each z_k — TODO confirm against
        callers.

        Raises
        ------
        ValueError
            If ``az``/``bz`` do not have ``n_prev`` entries, if ``bz[k]`` does
            not have dimension ``Ns[k]`` along ``self.axis``, or if ``bx`` does
            not have dimension ``N`` along ``self.axis``.
        """
        if len(az) != self.n_prev or len(bz) != self.n_prev:
            raise ValueError(f"expect {self.n_prev} az and bz")
        for k, (N, b) in enumerate(zip(self.Ns, bz)):
            # Check each element's shape (the list bz itself has no .shape).
            if b.shape[self.axis] != N:
                raise ValueError(
                    f"expect bz k={k} array of dimension {N} "
                    f"along axis {self.axis} "
                    f"but got array of dimension {b.shape[self.axis]}"
                )
        if bx.shape[self.axis] != self.N:
            raise ValueError(
                f"expect bx array of dimension {self.N} "
                f"along axis {self.axis} "
                f"but got array of dimension {bx.shape[self.axis]}"
            )
        # Split bx at the interior cumulative boundaries so that bx_subs[k]
        # covers the same positions along axis as z_k in the concatenation.
        split_points = np.cumsum(self.Ns)[:-1]
        bx_subs = np.split(bx, split_points, axis=self.axis)
        ak = [a + ax for a in az]
        bk = [b + bx_sub for b, bx_sub in zip(bz, bx_subs)]
        return ak, bk

    def compute_backward_posterior(self, az, bz, ax, bx):
        """Estimate z = {z_k} from the messages on z and x = [z_k].

        Returns lists ``rz`` and ``vz`` with ``rz[k] = bk[k] / ak[k]`` and
        ``vz[k] = 1 / ak[k]``.
        """
        ak, bk = self._compute_ak_bk(az, bz, ax, bx)
        vz = [1 / a for a in ak]
        rz = [b / a for a, b in zip(ak, bk)]
        return rz, vz

    def compute_forward_error(self, az, ax, tau_z):
        """Return the ``Ns``-weighted average of the backward errors."""
        vz = self.compute_backward_error(az, ax, tau_z)
        vx = sum(N * v for N, v in zip(self.Ns, vz)) / self.N
        return vx

    def compute_backward_error(self, az, ax, tau_z):
        """Return ``[1 / (az[k] + ax)]`` for each incoming variable.

        ``tau_z`` is accepted for interface compatibility but not used here.
        """
        ak = [a + ax for a in az]
        # Invert each precision individually (not the list itself).
        vz = [1 / a for a in ak]
        return vz

    def compute_log_partition(self, az, bz, ax, bx):
        """Return the log partition of the channel.

        Computes ``sum_k 0.5 * sum(bk**2 / ak + log(2*pi / ak))``, i.e. the
        sum over k of the Gaussian log-normalizers for the combined
        parameters of each z_k.
        """
        ak, bk = self._compute_ak_bk(az, bz, ax, bx)
        logZ = sum(
            0.5 * np.sum(b**2 / a + np.log(2 * np.pi / a))
            for a, b in zip(ak, bk)
        )
        return logZ

    def compute_free_energy(self, az, ax, tau_z):
        """Not implemented for this channel."""
        raise NotImplementedError