"""
Base classes.
"""
import numpy as np
import logging
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class ReprMixin():
    """Mixin building ``__repr__`` from a snapshot of the instance attributes.

    ``repr_init`` copies the instance ``__dict__`` at the time it is called and
    ``__repr__`` renders that snapshot as ``ClassName(key=val,...)``.
    """

    _repr_initialized = False
    # Class-level fallback so ``__repr__`` does not fail on an instance for
    # which ``repr_init`` was never called.
    _repr_pad = None

    def repr_init(self, pad=None, reinit=False):
        """Take the attribute snapshot used by ``__repr__``.

        Parameters
        ----------
        pad : str or None
            When truthy, each ``key=val`` entry is rendered on its own line,
            prefixed by this string.
        reinit : bool
            Force a new snapshot even if one was already taken.
        """
        if reinit or not self._repr_initialized:
            # Skip the mixin's own bookkeeping attributes: on a re-init they
            # are already present in ``__dict__`` and would otherwise leak
            # into the rendered repr.
            self._repr_kwargs = {
                key: val
                for key, val in self.__dict__.items()
                if not key.startswith("_repr_")
            }
            self._repr_pad = pad
            self._repr_initialized = True

    def __repr__(self):
        """Return ``ClassName(key=val,...)`` built from the stored snapshot."""
        # No snapshot yet -> render with no arguments instead of raising.
        kwargs = getattr(self, "_repr_kwargs", {})
        pad = f"\n{self._repr_pad}" if self._repr_pad else ""
        args = ",".join(f"{pad}{key}={val}" for key, val in kwargs.items())
        if self._repr_pad:
            args += "\n"
        name = self.__class__.__name__
        return f"{name}({args})"
# NOTE: a message is a list of (source, target, data) triples; data is a dict with a "direction" key.
def filter_message(message, direction):
    """Return the entries of ``message`` whose data has the given direction.

    ``message`` is an iterable of ``(source, target, data)`` triples, where
    ``data`` is a mapping with a ``"direction"`` key. The result is a new list
    of ``(source, target, data)`` tuples, in the original order.
    """
    kept = []
    for entry in message:
        source, target, data = entry
        if data["direction"] == direction:
            kept.append((source, target, data))
    return kept
def inv(v):
    """Return ``1 / v`` after flooring ``v`` at ``1e-20``.

    The floor avoids a division by zero; values below it (including
    negative ones) are replaced by ``1e-20`` before inverting.
    """
    floor = 1e-20
    clipped = np.maximum(v, floor)
    return 1 / clipped
class Variable(ReprMixin):
    """Variable node exchanging messages with its neighbouring factors.

    A message is a list of ``(source, target, data)`` triples where ``data``
    is a dict with a ``"direction"`` entry. Entries tagged ``"fwd"`` are
    counted against ``n_prev`` (previous factors) and entries tagged
    ``"bwd"`` against ``n_next`` (next factors).
    """

    def __init__(self, id, n_prev, n_next):
        """Store the identifier and the expected numbers of prev/next factors."""
        self.id = id
        self.n_prev = n_prev
        self.n_next = n_next
        self.repr_init()

    def __add__(self, other):
        """Return ``DAG(self) + other``."""
        # Imported lazily, inside the method, as in the original code.
        from .models.dag_algebra import DAG
        return DAG(self) + other

    def __matmul__(self, other):
        """Return ``DAG(self) @ other``."""
        from .models.dag_algebra import DAG
        return DAG(self) @ other

    def math(self):
        """Return the id wrapped in ``$...$`` (``self.id`` must be a str)."""
        return r"$" + self.id + r"$"

    def check_message(self, message):
        """Validate an incoming message.

        Raises
        ------
        ValueError
            If an entry is not targeted at this instance, if a source is not
            a ``Factor``, or if the number of ``"bwd"`` / ``"fwd"`` entries
            differs from ``n_next`` / ``n_prev``.
        """
        for source, target, data in message:
            if target != self:
                raise ValueError(f"target {target} is not the instance {self}")
            if not isinstance(source, Factor):
                raise ValueError(f"source {source} is not a Factor")
        n_next = len(filter_message(message, "bwd"))
        n_prev = len(filter_message(message, "fwd"))
        if self.n_next != n_next:
            raise ValueError(
                f"number of next factors : expected {self.n_next} got {n_next}")
        if self.n_prev != n_prev:
            raise ValueError(
                f"number of prev factors : expected {self.n_prev} got {n_prev}")

    def _parse_message_ab(self, message):
        """Split a message into sources and ``a`` / ``b`` values per direction.

        Returns ``(k_source, l_source, ak, bk, al, bl)`` where the ``k`` items
        come from the ``"fwd"`` entries and the ``l`` items from the ``"bwd"``
        entries. Each item is a list, unwrapped to its single element when the
        corresponding count (``n_prev`` / ``n_next``) is 1.
        """
        # prev factor k send fwd message
        k_message = filter_message(message, "fwd")
        assert len(k_message) == self.n_prev
        ak = [data["a"] for source, target, data in k_message]
        bk = [data["b"] for source, target, data in k_message]
        k_source = [source for source, target, data in k_message]
        if self.n_prev == 1:
            ak = ak[0]
            bk = bk[0]
            k_source = k_source[0]
        # next factor l send bwd message
        l_message = filter_message(message, "bwd")
        assert len(l_message) == self.n_next
        al = [data["a"] for source, target, data in l_message]
        bl = [data["b"] for source, target, data in l_message]
        l_source = [source for source, target, data in l_message]
        if self.n_next == 1:
            al = al[0]
            bl = bl[0]
            l_source = l_source[0]
        return k_source, l_source, ak, bk, al, bl

    def _parse_message_a(self, message):
        """Same as ``_parse_message_ab`` but only reads the ``a`` values.

        Returns ``(k_source, l_source, ak, al)``.
        """
        # prev factor k send fwd message
        k_message = filter_message(message, "fwd")
        assert len(k_message) == self.n_prev
        ak = [data["a"] for source, target, data in k_message]
        k_source = [source for source, target, data in k_message]
        if self.n_prev == 1:
            ak = ak[0]
            k_source = k_source[0]
        # next factor l send bwd message
        l_message = filter_message(message, "bwd")
        assert len(l_message) == self.n_next
        al = [data["a"] for source, target, data in l_message]
        l_source = [source for source, target, data in l_message]
        if self.n_next == 1:
            al = al[0]
            l_source = l_source[0]
        return k_source, l_source, ak, al

    def _parse_tau(self, message):
        """Return the ``"tau"`` value carried by the first message entry."""
        source, target, data = message[0]
        return data["tau"]

    def compute_mutual_information(self, ax, tau_x):
        """Return ``0.5 * log(ax * tau_x)``."""
        I = 0.5*np.log(ax*tau_x)
        return I

    def compute_free_energy(self, ax, tau_x):
        """Return ``0.5*ax*tau_x - I + 0.5*log(2*pi*tau_x/e)``.

        ``I`` is the value of ``compute_mutual_information(ax, tau_x)``.
        """
        I = self.compute_mutual_information(ax, tau_x)
        A = 0.5*ax*tau_x - I + 0.5*np.log(2*np.pi*tau_x/np.e)
        return A

    def compute_dual_mutual_information(self, vx, tau_x):
        """Return ``0.5 * log(tau_x / vx) - 0.5``."""
        I_dual = 0.5*np.log(tau_x/vx) - 0.5
        return I_dual

    def compute_dual_free_energy(self, mx, tau_x):
        """Return ``0.5 * log(2*pi*(tau_x - mx))``."""
        A_dual = 0.5*np.log(2*np.pi*(tau_x - mx))
        return A_dual

    def compute_log_partition(self, ax, bx):
        """Return ``0.5 * sum(bx**2/ax + log(2*pi/ax))``, or ``inf`` if ``ax <= 0``.

        NOTE(review): the ``ax <= 0`` test assumes a scalar ``ax`` - confirm
        against callers.
        """
        if ax <= 0:
            return np.inf
        logZ = 0.5 * np.sum(bx**2 / ax + np.log(2*np.pi/ax))
        return logZ

    def posterior_ab(self, message):
        """Return the sums of the ``a`` and of the ``b`` values over all entries."""
        a_hat = sum(data["a"] for source, target, data in message)
        b_hat = sum(data["b"] for source, target, data in message)
        return a_hat, b_hat

    def posterior_rv(self, message):
        """Return ``(b_hat / a_hat, 1 / a_hat)`` from ``posterior_ab``."""
        a_hat, b_hat = self.posterior_ab(message)
        r_hat = b_hat / a_hat
        v_hat = 1. / a_hat
        return r_hat, v_hat

    def posterior_a(self, message):
        """Return the sum of the ``a`` values over all entries."""
        a_hat = sum(data["a"] for source, target, data in message)
        return a_hat

    def posterior_v(self, message):
        """Return ``1 / a_hat`` with ``a_hat`` from ``posterior_a``."""
        a_hat = self.posterior_a(message)
        v_hat = 1. / a_hat
        return v_hat

    def log_partition(self, message):
        """Return ``compute_log_partition`` evaluated at ``posterior_ab(message)``."""
        ax, bx = self.posterior_ab(message)
        logZ = self.compute_log_partition(ax, bx)
        return logZ

    def free_energy(self, message):
        """Return ``compute_free_energy`` for the summed ``a`` and the parsed ``tau``."""
        ax = self.posterior_a(message)
        tau_x = self._parse_tau(message)
        A = self.compute_free_energy(ax, tau_x)
        return A

    def forward_message(self, message):
        """Build the ``"fwd"`` replies to the entries received with ``"bwd"``.

        Each reply swaps source and target and carries the summed ``a`` / ``b``
        minus the values of the entry being answered. Returns ``[]`` when
        ``n_next == 0``.
        """
        if self.n_next == 0:
            return []
        a_hat, b_hat = self.posterior_ab(message)
        # next factor l send bwd message
        l_message = filter_message(message, "bwd")
        new_message = [
            (target, source,
             dict(a=a_hat - data["a"], b=b_hat - data["b"], direction="fwd"))
            for source, target, data in l_message
        ]
        return new_message

    def backward_message(self, message):
        """Build the ``"bwd"`` replies to the entries received with ``"fwd"``.

        Each reply swaps source and target and carries the summed ``a`` / ``b``
        minus the values of the entry being answered. Returns ``[]`` when
        ``n_prev == 0``.
        """
        if self.n_prev == 0:
            return []
        a_hat, b_hat = self.posterior_ab(message)
        # prev factor k send fwd message
        k_message = filter_message(message, "fwd")
        new_message = [
            (target, source,
             dict(a=a_hat - data["a"], b=b_hat - data["b"], direction="bwd"))
            for source, target, data in k_message
        ]
        return new_message

    def forward_state_evolution(self, message):
        """Same as ``forward_message`` but the replies only carry ``a``."""
        if self.n_next == 0:
            return []
        a_hat = self.posterior_a(message)
        # next factor l send bwd message
        l_message = filter_message(message, "bwd")
        new_message = [
            (target, source,
             dict(a=a_hat - data["a"], direction="fwd"))
            for source, target, data in l_message
        ]
        return new_message

    def backward_state_evolution(self, message):
        """Same as ``backward_message`` but the replies only carry ``a``."""
        if self.n_prev == 0:
            return []
        a_hat = self.posterior_a(message)
        # prev factor k send fwd message
        k_message = filter_message(message, "fwd")
        new_message = [
            (target, source,
             dict(a=a_hat - data["a"], direction="bwd"))
            for source, target, data in k_message
        ]
        return new_message
class Factor(ReprMixin):
    """Factor node exchanging messages with its neighbouring variables.

    This class reads ``self.n_prev`` and ``self.n_next`` and calls
    ``compute_forward_posterior``, ``compute_backward_posterior``,
    ``compute_forward_error``, ``compute_backward_error``, ``second_moment``,
    ``compute_log_partition`` and ``compute_free_energy``. It also reads
    ``self.y`` in ``log_partition``. None of these is defined here, so they
    must be provided elsewhere (presumably by subclasses).
    """

    # Bounds applied with ``np.clip`` to every newly computed ``a``.
    AMAX = 1e+11
    AMIN = 1e-11

    def reset_precision_bounds(self, AMIN, AMAX):
        """Override the clipping bounds on this instance."""
        self.AMIN = AMIN
        self.AMAX = AMAX

    def compute_a_new(self, v, a):
        """Return ``inv(v) - a`` clipped to ``[AMIN, AMAX]``."""
        return np.clip(inv(v) - a, self.AMIN, self.AMAX)

    def compute_ab_new(self, r, v, a, b):
        """Return ``(a_new, b_new)``.

        ``a_new`` is ``inv(v) - a`` clipped to ``[AMIN, AMAX]`` and
        ``b_new = r * (a + a_new) - b``.
        """
        a_new = np.clip(inv(v) - a, self.AMIN, self.AMAX)
        # Total recomputed from the clipped value, not from ``inv(v)``.
        total_a = a + a_new
        return a_new, r * total_a - b

    def compute_a_mhat_qhat_new(self, v, m, q, a, m_hat, q_hat, t0):
        """Return ``(a_new, m_hat_new, q_hat_new)``.

        With ``a_new`` the clipped ``inv(v) - a`` and ``s = a + a_new``:
        ``m_hat_new = s * (m / t0) - m_hat`` and
        ``q_hat_new = s**2 * (q - m**2 / t0) - q_hat``.
        """
        a_new = np.clip(inv(v) - a, self.AMIN, self.AMAX)
        total_a = a + a_new
        m_hat_new = total_a * (m / t0) - m_hat
        q_hat_new = (total_a**2) * (q - m**2 / t0) - q_hat
        return a_new, m_hat_new, q_hat_new

    def __add__(self, other):
        """Return ``DAG(self) + other``."""
        from .models.dag_algebra import DAG
        graph = DAG(self)
        return graph + other

    def __matmul__(self, other):
        """Return ``DAG(self) @ other``."""
        from .models.dag_algebra import DAG
        graph = DAG(self)
        return graph @ other

    def check_message(self, message):
        """Validate an incoming message.

        Raises ``ValueError`` if an entry is not targeted at this instance,
        if a source is not a ``Variable``, or if the number of ``"fwd"`` /
        ``"bwd"`` entries differs from ``n_prev`` / ``n_next``.
        """
        for source, target, _ in message:
            if target != self:
                raise ValueError(f"target {target} is not the instance {self}")
            if not isinstance(source, Variable):
                raise ValueError(f"source {source} is not a Variable")
        n_prev = len(filter_message(message, "fwd"))
        n_next = len(filter_message(message, "bwd"))
        if n_prev != self.n_prev:
            raise ValueError(f"expected n_prev={self.n_prev} got {n_prev}")
        if n_next != self.n_next:
            raise ValueError(f"expected n_next={self.n_next} got {n_next}")

    def _parse_message_ab(self, message):
        """Split a message into sources and ``a`` / ``b`` values per direction.

        Returns ``(z_source, x_source, az, bz, ax, bx)``: the ``z`` items come
        from the ``"fwd"`` entries, the ``x`` items from the ``"bwd"`` ones.
        Each item is a list, unwrapped to its single element when the
        matching count (``n_prev`` / ``n_next``) is 1.
        """
        # "fwd" entries: sent by the previous variables z
        incoming_fwd = filter_message(message, "fwd")
        assert len(incoming_fwd) == self.n_prev
        az = [payload["a"] for _, _, payload in incoming_fwd]
        bz = [payload["b"] for _, _, payload in incoming_fwd]
        z_source = [sender for sender, _, _ in incoming_fwd]
        if self.n_prev == 1:
            az, bz, z_source = az[0], bz[0], z_source[0]
        # "bwd" entries: sent by the next variables x
        incoming_bwd = filter_message(message, "bwd")
        assert len(incoming_bwd) == self.n_next
        ax = [payload["a"] for _, _, payload in incoming_bwd]
        bx = [payload["b"] for _, _, payload in incoming_bwd]
        x_source = [sender for sender, _, _ in incoming_bwd]
        if self.n_next == 1:
            ax, bx, x_source = ax[0], bx[0], x_source[0]
        return z_source, x_source, az, bz, ax, bx

    def _parse_message_a(self, message):
        """Split a message into sources, ``a`` values and forward ``tau`` values.

        Returns ``(z_source, x_source, az, ax, tau_z)``. ``tau_z`` is read
        from the ``"tau"`` key of the ``"fwd"`` entries. The same
        single-element unwrapping as in ``_parse_message_ab`` applies.
        """
        # "fwd" entries: sent by the previous variables z
        incoming_fwd = filter_message(message, "fwd")
        assert len(incoming_fwd) == self.n_prev
        az = [payload["a"] for _, _, payload in incoming_fwd]
        tau_z = [payload["tau"] for _, _, payload in incoming_fwd]
        z_source = [sender for sender, _, _ in incoming_fwd]
        if self.n_prev == 1:
            az, tau_z, z_source = az[0], tau_z[0], z_source[0]
        # "bwd" entries: sent by the next variables x
        incoming_bwd = filter_message(message, "bwd")
        assert len(incoming_bwd) == self.n_next
        ax = [payload["a"] for _, _, payload in incoming_bwd]
        x_source = [sender for sender, _, _ in incoming_bwd]
        if self.n_next == 1:
            ax, x_source = ax[0], x_source[0]
        return z_source, x_source, az, ax, tau_z

    def forward_message(self, message):
        """Build the ``"fwd"`` entries sent to the next variables.

        Returns ``[]`` when ``n_next == 0``. ``compute_forward_message`` is
        called with ``(ax, bx)`` when ``n_prev == 0`` and with
        ``(az, bz, ax, bx)`` otherwise.
        """
        if self.n_next == 0:
            return []
        z_source, x_source, az, bz, ax, bx = self._parse_message_ab(message)
        call_args = (ax, bx) if self.n_prev == 0 else (az, bz, ax, bx)
        ax_new, bx_new = self.compute_forward_message(*call_args)
        if self.n_next == 1:
            return [
                (self, x_source, dict(a=ax_new, b=bx_new, direction="fwd"))
            ]
        return [
            (self, receiver, dict(a=a_k, b=b_k, direction="fwd"))
            for a_k, b_k, receiver in zip(ax_new, bx_new, x_source)
        ]

    def backward_message(self, message):
        """Build the ``"bwd"`` entries sent to the previous variables.

        Returns ``[]`` when ``n_prev == 0``. ``compute_backward_message`` is
        called with ``(az, bz)`` when ``n_next == 0`` and with
        ``(az, bz, ax, bx)`` otherwise.
        """
        if self.n_prev == 0:
            return []
        z_source, x_source, az, bz, ax, bx = self._parse_message_ab(message)
        call_args = (az, bz) if self.n_next == 0 else (az, bz, ax, bx)
        az_new, bz_new = self.compute_backward_message(*call_args)
        if self.n_prev == 1:
            return [
                (self, z_source, dict(a=az_new, b=bz_new, direction="bwd"))
            ]
        return [
            (self, receiver, dict(a=a_k, b=b_k, direction="bwd"))
            for a_k, b_k, receiver in zip(az_new, bz_new, z_source)
        ]

    def log_partition(self, message):
        """Return ``compute_log_partition`` on the parsed message.

        Arguments are ``(ax, bx)`` when ``n_prev == 0``, ``(az, bz, self.y)``
        when ``n_next == 0`` and ``(az, bz, ax, bx)`` otherwise.
        """
        z_source, x_source, az, bz, ax, bx = self._parse_message_ab(message)
        if self.n_prev == 0:
            return self.compute_log_partition(ax, bx)
        if self.n_next == 0:
            return self.compute_log_partition(az, bz, self.y)
        return self.compute_log_partition(az, bz, ax, bx)

    def forward_state_evolution(self, message):
        """Build the ``"fwd"`` entries (``a`` only) sent to the next variables.

        Returns ``[]`` when ``n_next == 0``.
        """
        if self.n_next == 0:
            return []
        z_source, x_source, az, ax, tau_z = self._parse_message_a(message)
        call_args = (ax,) if self.n_prev == 0 else (az, ax, tau_z)
        ax_new = self.compute_forward_state_evolution(*call_args)
        if self.n_next == 1:
            return [(self, x_source, dict(a=ax_new, direction="fwd"))]
        return [
            (self, receiver, dict(a=a_k, direction="fwd"))
            for a_k, receiver in zip(ax_new, x_source)
        ]

    def backward_state_evolution(self, message):
        """Build the ``"bwd"`` entries (``a`` only) sent to the previous variables.

        Returns ``[]`` when ``n_prev == 0``.
        """
        if self.n_prev == 0:
            return []
        z_source, x_source, az, ax, tau_z = self._parse_message_a(message)
        call_args = (az, tau_z) if self.n_next == 0 else (az, ax, tau_z)
        az_new = self.compute_backward_state_evolution(*call_args)
        if self.n_prev == 1:
            return [(self, z_source, dict(a=az_new, direction="bwd"))]
        return [
            (self, receiver, dict(a=a_k, direction="bwd"))
            for a_k, receiver in zip(az_new, z_source)
        ]

    def free_energy(self, message):
        """Return ``compute_free_energy`` on the parsed message.

        Arguments are ``(ax,)`` when ``n_prev == 0``, ``(az, tau_z)`` when
        ``n_next == 0`` and ``(az, ax, tau_z)`` otherwise.
        """
        z_source, x_source, az, ax, tau_z = self._parse_message_a(message)
        if self.n_prev == 0:
            return self.compute_free_energy(ax)
        if self.n_next == 0:
            return self.compute_free_energy(az, tau_z)
        return self.compute_free_energy(az, ax, tau_z)

    def compute_forward_message(self, az, bz, ax, bx):
        """Apply ``compute_ab_new`` elementwise to the forward posterior.

        Returns two lists ``(ax_new, bx_new)``.
        """
        rx, vx = self.compute_forward_posterior(az, bz, ax, bx)
        ax_new, bx_new = [], []
        for rk, vk, ak, bk in zip(rx, vx, ax, bx):
            a_k, b_k = self.compute_ab_new(rk, vk, ak, bk)
            ax_new.append(a_k)
            bx_new.append(b_k)
        return ax_new, bx_new

    def compute_backward_message(self, az, bz, ax, bx):
        """Apply ``compute_ab_new`` elementwise to the backward posterior.

        Returns two lists ``(az_new, bz_new)``.
        """
        rz, vz = self.compute_backward_posterior(az, bz, ax, bx)
        az_new, bz_new = [], []
        for rk, vk, ak, bk in zip(rz, vz, az, bz):
            a_k, b_k = self.compute_ab_new(rk, vk, ak, bk)
            az_new.append(a_k)
            bz_new.append(b_k)
        return az_new, bz_new

    def compute_forward_state_evolution(self, az, ax, tau_z):
        """Apply ``compute_a_new`` elementwise to the forward error; return a list."""
        vx = self.compute_forward_error(az, ax, tau_z)
        ax_new = []
        for vk, ak in zip(vx, ax):
            ax_new.append(self.compute_a_new(vk, ak))
        return ax_new

    def compute_backward_state_evolution(self, az, ax, tau_z):
        """Apply ``compute_a_new`` elementwise to the backward error; return a list."""
        vz = self.compute_backward_error(az, ax, tau_z)
        az_new = []
        for vk, ak in zip(vz, az):
            az_new.append(self.compute_a_new(vk, ak))
        return az_new

    def compute_forward_overlap(self, az, ax, tau_z):
        """Return the list ``tau_x[k] - vx[k]``.

        ``vx`` is the forward error and ``tau_x = self.second_moment(tau_z)``.
        """
        vx = self.compute_forward_error(az, ax, tau_z)
        tau_x = self.second_moment(tau_z)
        mx = []
        for tau_k, vk in zip(tau_x, vx):
            mx.append(tau_k - vk)
        return mx

    def compute_backward_overlap(self, az, ax, tau_z):
        """Return the list ``tau_z[k] - vz[k]`` with ``vz`` the backward error."""
        vz = self.compute_backward_error(az, ax, tau_z)
        mz = []
        for tau_k, vk in zip(tau_z, vz):
            mz.append(tau_k - vk)
        return mz