# Source code for tramp.channels.linear.unitary_channel

import numpy as np
from ..base_channel import Channel
from tramp.utils.misc import complex2array, array2complex


def check_unitary(U, *, rtol=1e-05, atol=1e-08):
    """Raise ValueError unless U is a square unitary matrix.

    Parameters
    ----------
    U : array
        Candidate matrix; must be 2-D and square.
    rtol, atol : float, optional
        Tolerances forwarded to ``np.allclose`` when comparing
        ``U @ U.conj().T`` with the identity. The defaults are those of
        ``np.allclose``, so the default check is unchanged.

    Raises
    ------
    ValueError
        If U is not a 2-D square array (message ``"U.shape = ..."``), or if
        ``U @ U.conj().T`` is not close to the identity (``"U not unitary"``).
    """
    # Check ndim first: on a 0-D/1-D array, U.shape[1] would raise IndexError
    # instead of ValueError, and a stacked N-D array could pass the
    # shape[0] == shape[1] test.
    if U.ndim != 2 or U.shape[0] != U.shape[1]:
        raise ValueError(f"U.shape = {U.shape}")
    N = U.shape[0]
    if not np.allclose(U @ U.conj().T, np.identity(N), rtol=rtol, atol=atol):
        raise ValueError("U not unitary")


class UnitaryChannel(Channel):
    """Unitary channel x = U z.

    Parameters
    ----------
    - U: unitary matrix
    - name: str
        name of unitary matrix U for display

    Notes
    -----
    For message passing it is more convenient to represent a complex array x
    as a real array X where X[0] = x.real and X[1] = x.imag

    In particular:

    - input of sample(): Z array of shape (2, z.shape)
    - output of sample(): X array of shape (2, x.shape)
    - message bz, posterior rz: real arrays of shape (2, z.shape)
    - message bx, posterior rx: real arrays of shape (2, x.shape)
    """

    def __init__(self, U, name="U"):
        # Raises ValueError if U is not a square unitary matrix.
        check_unitary(U)
        self.name = name
        self.N = U.shape[0]
        # NOTE(review): repr_init comes from the Channel base class (not
        # visible here). The original order (repr_init before self.U is set)
        # is kept in case it depends on which attributes already exist.
        self.repr_init()
        self.U = U

    def sample(self, Z):
        """Return X = U z in the real representation.

        We assume Z[0] = z.real and Z[1] = z.imag, with z a length-N vector.

        Raises
        ------
        ValueError
            If the result does not have shape (2, N).
        """
        z = array2complex(Z)
        x = self.U @ z
        X = complex2array(x)
        # Explicit check instead of `assert`, which is stripped under -O.
        # ValueError also matches what `U @ z` raises on mismatched sizes.
        if X.shape != (2, self.N):
            raise ValueError(f"X.shape = {X.shape}, expected {(2, self.N)}")
        return X

    def math(self):
        """Return the LaTeX label used to display this channel."""
        return r"$" + self.name + "$"

    def second_moment(self, tau_z):
        """Return tau_x; a unitary map leaves the second moment unchanged."""
        return tau_z

    def compute_forward_message(self, az, bz, ax, bx):
        """Message towards x = U z: ax_new = az and bx_new = U bz."""
        ax_new = az
        bx_new = complex2array(self.U @ array2complex(bz))
        return ax_new, bx_new

    def compute_backward_message(self, az, bz, ax, bx):
        """Message towards z = U^H x: az_new = ax and bz_new = U^H bx."""
        az_new = ax
        bz_new = complex2array(self.U.conj().T @ array2complex(bx))
        return az_new, bz_new

    def compute_forward_state_evolution(self, az, ax, tau_z):
        """State-evolution precision sent towards x: ax_new = az."""
        ax_new = az
        return ax_new

    def compute_backward_state_evolution(self, az, ax, tau_z):
        """State-evolution precision sent towards z: az_new = ax."""
        az_new = ax
        return az_new

    def compute_log_partition(self, az, bz, ax, bx):
        """Log-partition of the Gaussian in z with precision az + ax.

        The linear term is b = bz + U^H bx, in the real representation.
        """
        b = complex2array(
            array2complex(bz) + self.U.conj().T @ array2complex(bx)
        )
        a = az + ax
        # b holds the real and imaginary parts of every complex entry of z,
        # so there are b.size real Gaussian components. Each contributes
        # 0.5 * log(2 pi / a); for a length-N vector this is N * log(2 pi / a).
        # Unlike the hard-coded self.N, it stays correct for (2, N, ...)
        # messages, which the message methods above also accept.
        logZ = 0.5 * np.sum(b**2 / a) + 0.5 * b.size * np.log(2 * np.pi / a)
        return logZ

    def compute_mutual_information(self, az, ax, tau_z):
        """Return I = 0.5 * log((ax + az) * tau_z)."""
        a = ax + az
        I = 0.5 * np.log(a * tau_z)
        return I

    def compute_free_energy(self, az, ax, tau_z):
        """Return the free energy from the second moments and mutual information."""
        tau_x = self.second_moment(tau_z)
        I = self.compute_mutual_information(az, ax, tau_z)
        A = (
            0.5 * (az * tau_z + ax * tau_x)
            - I
            + 0.5 * np.log(2 * np.pi * tau_z / np.e)
        )
        return A