# Source code for tramp.channels.activation.piecewise_linear_channel

import numpy as np
from ..base_channel import Channel
from scipy.special import logsumexp, softmax
from tramp.utils.linear_region import LinearRegion


class PiecewiseLinearChannel(Channel):
    """Channel assembled from a collection of `LinearRegion` pieces.

    Every method queries each region in turn and combines the per-region
    results: a plain sum (`sample`, `beliefs_measure`, `measure`), a
    probability-weighted sum (`second_moment`), a softmax-weighted mixture
    (`merge_estimates`) or a log-sum-exp (`compute_log_partition`).

    Parameters
    ----------
    name : str
        Label rendered by `math()`.
    regions : iterable of dict
        Each mapping is expanded as keyword arguments to `LinearRegion`.
        The subclasses in this module pass the keys ``zmin``, ``zmax``,
        ``slope`` and ``x0``.
    """

    def __init__(self, name, regions):
        # Called before any attribute is assigned, exactly as in the original.
        # NOTE(review): presumably repr_init() records the attributes present
        # at call time (subclasses set theirs before super().__init__) --
        # confirm before moving this line.
        self.repr_init()
        self.name = name
        self.regions = [LinearRegion(**region) for region in regions]
        # Count the built list, not the raw argument: `regions` may be a
        # one-shot iterable (e.g. a generator) that has no len() and has
        # already been consumed by the comprehension above.
        self.n_regions = len(self.regions)

    def sample(self, Z):
        """Return the sum over regions of ``region.sample(Z)``.

        NOTE(review): presumably each region contributes zero outside its
        own interval so the sum selects the active piece -- confirm in
        `LinearRegion.sample`.
        """
        X = sum(region.sample(Z) for region in self.regions)
        return X

    def math(self):
        """Return the channel name wrapped in a LaTeX textrm command."""
        return r"$\textrm{" + self.name + r"}$"

    def second_moment(self, tau_z):
        """Return ``sum_k p_k * tau_k`` over the regions.

        ``tau_k`` is ``region.second_moment(tau_z)`` and ``p_k`` is
        ``region.proba_tau(tau_z)`` for the k-th region.
        """
        taus = [region.second_moment(tau_z) for region in self.regions]
        ps = [region.proba_tau(tau_z) for region in self.regions]
        tau_x = sum(p*tau for p, tau in zip(ps, taus))
        return tau_x

    def merge_estimates(self, rs, vs, As):
        """Merge per-region means and variances into one mean and variance.

        Parameters
        ----------
        rs, vs, As : sequences with one entry per region
            Per-region means, variances and log-partitions.

        Returns
        -------
        r
            Mixture mean ``sum_k p_k * r_k`` with ``p = softmax(As, axis=0)``.
        v
            Mixture variance averaged over all entries (via ``.mean()``).
        """
        # Region weights: softmax of the log-partitions across regions.
        ps = softmax(As, axis=0)
        r = sum(p*r for p, r in zip(ps, rs))
        # Spread between region means. Because the weights sum to one,
        # sum_{i<j} p_i p_j (r_i - r_j)^2 is the variance of the region means
        # under p (the between-component term of the total variance).
        Dr = sum(
            ps[i]*ps[j]*(rs[i] - rs[j])**2
            for i in range(self.n_regions)
            for j in range(i+1, self.n_regions)
        )
        # Within-region variance plus between-region spread.
        v = sum(p*v for p, v in zip(ps, vs)) + Dr
        v = v.mean()
        return r, v

    def compute_forward_posterior(self, az, bz, ax, bx):
        """Return the merged ``(r, v)`` from each region's forward estimates.

        ``az, bz, ax, bx`` are forwarded unchanged to every region's
        ``forward_mean``, ``forward_variance`` and ``log_partitions``.
        NOTE(review): presumably these are the natural parameters of the
        incoming messages on z and x -- confirm against the `Channel` API.
        """
        rs = [region.forward_mean(az, bz, ax, bx) for region in self.regions]
        vs = [region.forward_variance(az, bz, ax, bx)
              for region in self.regions]
        As = [region.log_partitions(az, bz, ax, bx) for region in self.regions]
        r, v = self.merge_estimates(rs, vs, As)
        return r, v

    def compute_backward_posterior(self, az, bz, ax, bx):
        """Return the merged ``(r, v)`` from each region's backward estimates.

        Same as `compute_forward_posterior` but built from the regions'
        ``backward_mean`` and ``backward_variance``.
        """
        rs = [region.backward_mean(az, bz, ax, bx) for region in self.regions]
        vs = [region.backward_variance(az, bz, ax, bx)
              for region in self.regions]
        As = [region.log_partitions(az, bz, ax, bx) for region in self.regions]
        r, v = self.merge_estimates(rs, vs, As)
        return r, v

    def compute_log_partition(self, az, bz, ax, bx):
        """Return the summed log-partition of the channel.

        The per-region log-partitions are combined with ``logsumexp`` across
        regions (axis 0) and the result is summed over its remaining entries.
        """
        As = [region.log_partitions(az, bz, ax, bx) for region in self.regions]
        A = logsumexp(As, axis=0)
        return A.sum()

    def beliefs_measure(self, az, ax, tau_z, f):
        """Return the sum over regions of ``region.beliefs_measure(...)``."""
        mu = sum(
            region.beliefs_measure(az, ax, tau_z, f) for region in self.regions
        )
        return mu

    def measure(self, f, zmin, zmax):
        """Return the sum over regions of ``region.measure(f, zmin, zmax)``.

        Raises
        ------
        AssertionError
            If ``zmin`` is not strictly smaller than ``zmax``.
        """
        assert zmin < zmax, "zmin must be strictly smaller than zmax"
        mu = sum(region.measure(f, zmin, zmax) for region in self.regions)
        return mu


class LeakyReluChannel(PiecewiseLinearChannel):
    """Piecewise-linear channel named "l-relu" with two regions.

    The region for z in (0, +inf) has slope 1 and the region for
    z in (-inf, 0) has slope ``slope``; both use ``x0 = 0``.
    NOTE(review): presumably `LinearRegion` maps z to ``x0 + slope * z`` on
    its interval, which makes this the leaky ReLU -- confirm.

    Parameters
    ----------
    slope
        Slope applied on the negative half-line.
    """

    def __init__(self, slope):
        # Set before super().__init__(), as in the original.
        # NOTE(review): presumably so the parent's repr_init() sees it --
        # confirm before reordering.
        self.slope = slope
        neg = dict(zmin=-np.inf, zmax=0, slope=slope, x0=0)
        pos = dict(zmin=0, zmax=np.inf, slope=1, x0=0)
        super().__init__(name="l-relu", regions=[pos, neg])


class SgnChannel(PiecewiseLinearChannel):
    """Piecewise-linear channel named "sgn" with two flat regions.

    Both regions have slope 0: ``x0 = +1`` for z in (0, +inf) and
    ``x0 = -1`` for z in (-inf, 0).
    NOTE(review): presumably `LinearRegion` maps z to ``x0 + slope * z`` on
    its interval, which makes this the sign function -- confirm.
    """

    def __init__(self):
        neg = dict(zmin=-np.inf, zmax=0, slope=0, x0=-1)
        pos = dict(zmin=0, zmax=+np.inf, slope=0, x0=+1)
        super().__init__(name="sgn", regions=[pos, neg])


class AbsChannel(PiecewiseLinearChannel):
    """Piecewise-linear channel named "abs" with two regions.

    Slope +1 for z in (0, +inf) and slope -1 for z in (-inf, 0); both use
    ``x0 = 0``.
    NOTE(review): presumably `LinearRegion` maps z to ``x0 + slope * z`` on
    its interval, which makes this the absolute value -- confirm.
    """

    def __init__(self):
        neg = dict(zmin=-np.inf, zmax=0, slope=-1, x0=0)
        pos = dict(zmin=0, zmax=+np.inf, slope=+1, x0=0)
        super().__init__(name="abs", regions=[pos, neg])


class AsymmetricAbsChannel(PiecewiseLinearChannel):
    """Piecewise-linear channel named "a-abs" with a shifted breakpoint.

    Same slopes and offsets as an absolute value (slope -1 below, slope +1
    above, ``x0 = 0`` for both), but the two regions meet at ``z = shift``
    instead of ``z = 0``.

    Parameters
    ----------
    shift
        Boundary between the two regions. Defaults to ``1e-4``.
    """

    def __init__(self, shift=1e-4):
        # Set before super().__init__(), as in the original.
        # NOTE(review): presumably so the parent's repr_init() sees it --
        # confirm before reordering.
        self.shift = shift
        neg = dict(zmin=-np.inf, zmax=shift, slope=-1, x0=0)
        pos = dict(zmin=shift, zmax=+np.inf, slope=+1, x0=0)
        super().__init__(name="a-abs", regions=[pos, neg])


class ReluChannel(PiecewiseLinearChannel):
    """Piecewise-linear channel named "relu" with two regions.

    Slope 1 for z in (0, +inf) and slope 0 for z in (-inf, 0); both use
    ``x0 = 0``.
    NOTE(review): presumably `LinearRegion` maps z to ``x0 + slope * z`` on
    its interval, which makes this the ReLU -- confirm.
    """

    def __init__(self):
        neg = dict(zmin=-np.inf, zmax=0, slope=0, x0=0)
        pos = dict(zmin=0, zmax=+np.inf, slope=1, x0=0)
        super().__init__(name="relu", regions=[pos, neg])


class HardTanhChannel(PiecewiseLinearChannel):
    """Piecewise-linear channel named "h-tanh" with three regions.

    * z in (-inf, -1): slope 0, ``x0 = -1``
    * z in (-1, +1):   slope 1, ``x0 = 0``
    * z in (+1, +inf): slope 0, ``x0 = +1``

    NOTE(review): presumably `LinearRegion` maps z to ``x0 + slope * z`` on
    its interval, which makes this z clipped to [-1, 1] -- confirm.
    """

    def __init__(self):
        neg = dict(zmin=-np.inf, zmax=-1, slope=0, x0=-1)
        mid = dict(zmin=-1, zmax=+1, slope=1, x0=0)
        pos = dict(zmin=1, zmax=np.inf, slope=0, x0=1)
        super().__init__(name="h-tanh", regions=[pos, mid, neg])


class HardSigmoidChannel(PiecewiseLinearChannel):
    """Piecewise-linear channel named "h-sigm" with three regions.

    With half-width 2.5:

    * z in (-inf, -2.5): slope 0, ``x0 = 0``
    * z in (-2.5, +2.5): slope ``1 / (2 * 2.5)``, ``x0 = 0.5``
    * z in (+2.5, +inf): slope 0, ``x0 = 1``

    NOTE(review): presumably `LinearRegion` maps z to ``x0 + slope * z`` on
    its interval, giving a ramp from 0 to 1 across the middle region --
    confirm.
    """

    def __init__(self):
        # Half-width of the linear ramp around z = 0.
        half_width = 2.5
        neg = dict(zmin=-np.inf, zmax=-half_width, slope=0, x0=0)
        mid = dict(
            zmin=-half_width, zmax=+half_width,
            slope=1/(2*half_width), x0=0.5,
        )
        pos = dict(zmin=half_width, zmax=np.inf, slope=0, x0=1)
        super().__init__(name="h-sigm", regions=[pos, mid, neg])


class SymmetricDoorChannel(PiecewiseLinearChannel):
    """Piecewise-linear channel named "door" with three flat regions.

    All regions have slope 0: ``x0 = -1`` for z in (-width, +width) and
    ``x0 = +1`` for z in (-inf, -width) and (+width, +inf).

    Parameters
    ----------
    width
        Half-width of the central region. It is not validated here.
    """

    def __init__(self, width):
        # Set before super().__init__(), as in the original.
        # NOTE(review): presumably so the parent's repr_init() sees it --
        # confirm before reordering.
        self.width = width
        neg = dict(zmin=-np.inf, zmax=-width, slope=0, x0=+1)
        mid = dict(zmin=-width, zmax=+width, slope=0, x0=-1)
        pos = dict(zmin=+width, zmax=+np.inf, slope=0, x0=+1)
        super().__init__(name="door", regions=[pos, mid, neg])