Source code for tramp.channels.linear.gradient_channel

import numpy as np
from numpy.fft import fftn, ifftn
from ..base_channel import Channel
from tramp.utils.conv_filters import gradient_filters
import logging
logger = logging.getLogger(__name__)



[docs]class GradientChannel(Channel): "Gradient channel x = grad z " def __init__(self, shape, real=True): self.d = len(shape) self.shape = shape self.real = real self.repr_init() self.filter = gradient_filters(shape) self.axes = list(range(1, self.d + 1)) # axes over which fft is taken # conv weights = time reversed filter; their ffts are conjugate self.w_fft_bar = fftn(self.filter, axes=self.axes) self.w_fft = np.conjugate(self.w_fft_bar) self.spectrum = (np.absolute(self.w_fft)**2).sum(axis=0) assert self.spectrum.shape == shape def convolve(self, z): if (z.shape != self.shape): raise ValueError(f"Bad shape for z: {z.shape} expected {self.shape}") z_fft = fftn(z) x_fft = self.w_fft * z_fft[np.newaxis,:] x = ifftn(x_fft, axes=self.axes) # no fft over axis=0 (grad direction) if self.real: x = np.real(x) return x def sample(self, Z): return self.convolve(Z) def math(self): return r"$\nabla$" def second_moment(self, tau_z): return tau_z * self.spectrum.mean() / self.d def compute_n_eff(self, az, ax): "Effective number of parameters = overlap in z" if ax == 0: logger.info(f"ax=0 in {self} compute_n_eff") return 0. if az / ax == 0: logger.info(f"az/ax=0 in {self} compute_n_eff") return 1. n_eff = np.mean(self.spectrum / (az / ax + self.spectrum)) return n_eff def compute_backward_mean(self, az, bz, ax, bx, return_fft=False): # estimate z from x = Wz bx_fft = fftn(bx, axes=self.axes) # no fft over axis=0 (grad direction) bz_fft = fftn(bz) resolvent = 1 / (az + ax * self.spectrum) rz_fft = resolvent * (bz_fft + (self.w_fft_bar * bx_fft).sum(axis=0)) if return_fft: return rz_fft rz = ifftn(rz_fft) if self.real: rz = np.real(rz) return rz def compute_forward_mean(self, az, bz, ax, bx): # estimate x from x = Wz we have rx = W rz rz_fft = self.compute_backward_mean(az, bz, ax, bx, return_fft=True) rx_fft = self.w_fft * rz_fft[np.newaxis,:] rx = ifftn(rx_fft, axes=self.axes) if self.real: rx = np.real(rx) return rx def compute_backward_variance(self, az, ax): assert az > 0 n_eff = self.compute_n_eff(az, ax) vz = (1 - n_eff) / az return vz def compute_forward_variance(self, az, ax): if ax == 0: s_mean = np.mean(self.spectrum) return s_mean / az n_eff = self.compute_n_eff(az, ax) vx = n_eff / (ax * self.d) return vx def compute_backward_posterior(self, az, bz, ax, bx): # estimate z from x = Wz rz = self.compute_backward_mean(az, bz, ax, bx) vz = self.compute_backward_variance(az, ax) return rz, vz def compute_forward_posterior(self, az, bz, ax, bx): # estimate x from x = Wz rx = self.compute_forward_mean(az, bz, ax, bx) vx = self.compute_forward_variance(az, ax) return rx, vx def compute_backward_error(self, az, ax, tau_z): vz = self.compute_backward_variance(az, ax) return vz def compute_forward_error(self, az, ax, tau_z): vx = self.compute_forward_variance(az, ax) return vx