"""Implements the MAP_L21NormPrior class."""
import numpy as np
from .base_prior import Prior
import warnings
def l21_norm(x, axis):
    """Return the mixed l_{2,1} norm of ``x``.

    The Euclidean norm is taken along ``axis``; the resulting group norms
    (one per index of the remaining axes) are then summed.
    """
    group_norms = np.linalg.norm(x, axis=axis)
    return np.sum(group_norms)
def group_soft_threshold(x, gamma, axis):
    """Group soft-thresholding of ``x`` with threshold ``gamma``.

    Each group (the entries of ``x`` along ``axis``) is rescaled by
    ``max(0, 1 - gamma / ||group||_2)``: groups whose Euclidean norm does
    not exceed ``gamma`` are set to zero, the others are shrunk towards zero
    by ``gamma`` in norm.  This is the proximal operator of
    ``gamma * ||.||_{2,1}``.

    Returns an array with the same shape as ``x``.
    """
    x_norm = np.linalg.norm(x, axis=axis, keepdims=True)  # broadcast against x
    # All-zero groups map to zero whatever the scale factor, so divide by 1
    # instead of 0 there: this avoids divide-by-zero warnings and the NaN
    # (0/0) the unguarded formula produced when gamma == 0.
    safe_norm = np.where(x_norm > 0, x_norm, 1.0)
    return np.maximum(0, 1 - gamma / safe_norm) * x
def v_group_soft_threshold(x, gamma, axis):
    """Elementwise derivative of ``group_soft_threshold(x, gamma, axis)``.

    For an entry ``x_i`` of a group with norm ``n > gamma`` the derivative
    of the thresholded value with respect to ``x_i`` is
    ``1 + (x_i**2 / n**2 - 1) * gamma / n``; entries of groups with
    ``n <= gamma`` (including all-zero groups) get 0.

    Returns an array with the same shape as ``x``.
    """
    # keepdims=True so the group norms broadcast against x for ANY axis,
    # not only the leading one.
    x_norm = np.linalg.norm(x, axis=axis, keepdims=True)
    # Divide by 1 for all-zero groups; their entries are masked to 0 below,
    # and this avoids the 0/0 NaN of the unguarded formula.
    safe_norm = np.where(x_norm > 0, x_norm, 1.0)
    active = x_norm > gamma  # groups that survive the thresholding
    v = active * (1 + (x**2 / safe_norm**2 - 1) * gamma / safe_norm)
    return v
class MAP_L21NormPrior(Prior):
    r"""MAP prior associated with the $\Vert . \Vert_{2,1}$ penalty.

    The corresponding factor is given by $f(x)=e^{-\gamma \Vert x \Vert_{2,1}}$
    where $\gamma$ is the regularization parameter and $\Vert x \Vert_{2,1}$
    is the sum, over the indices of the remaining axes, of the Euclidean
    norms taken along ``axis``.

    Parameters
    ----------
    size : tuple of int
        Shape of x; must have at least two dimensions.
    gamma : float
        Regularization parameter $\gamma$
    axis : int
        Axis over which the $\Vert . \Vert_2$ norm is taken
    isotropic : bool
        If True, the variance ``vx`` returned by
        ``compute_forward_posterior`` is averaged to a scalar (isotropic
        beliefs); otherwise it is kept elementwise (diagonal beliefs).
    """

    def __init__(self, size, gamma=1, axis=0, isotropic=True):
        # x must have at least two dimensions: one for the l2 norm (`axis`)
        # and at least one more indexing the groups that are summed.
        assert isinstance(size, tuple) and len(size) > 1, \
            "size must be a tuple of length > 1"
        self.size = size
        self.gamma = gamma
        self.axis = axis
        self.isotropic = isotropic
        # NOTE(review): repr_init comes from the Prior base class; it is
        # called before N and d are set, presumably so that only the
        # constructor arguments appear in the repr — confirm in base_prior.
        self.repr_init()
        self.N = np.prod(size)  # total number of components of x
        self.d = size[axis]  # length of each group (dimension along `axis`)

    def sample(self):
        """Placeholder: warn and return a zero array of shape ``size``."""
        warnings.warn(
            "MAP_L21NormPrior.sample not implemented "
            "return zero array as a placeholder"
        )
        return np.zeros(self.size)

    def math(self):
        """LaTeX label of the prior."""
        return r"$\Vert . \Vert_{2,1}$"

    def second_moment(self):
        raise NotImplementedError

    def forward_second_moment_FG(self, tx_hat):
        raise NotImplementedError

    def compute_forward_posterior(self, ax, bx):
        """MAP estimate ``rx`` and its sensitivity ``vx`` given ``(ax, bx)``.

        ``rx`` is the group soft-thresholding of ``bx`` (threshold
        ``gamma``) rescaled by ``1 / ax``, i.e. the maximizer of
        ``sum(bx*x - 0.5*ax*x**2) - gamma*||x||_{2,1}`` (assuming ax > 0).
        ``vx`` is the elementwise derivative of ``rx`` with respect to
        ``bx``, averaged to a scalar when the prior is isotropic.

        NOTE(review): ``ax`` is assumed to be a positive scalar or an array
        broadcastable against ``bx`` — confirm against callers.
        """
        rx = (1 / ax) * group_soft_threshold(bx, self.gamma, self.axis)
        vx = (1 / ax) * v_group_soft_threshold(bx, self.gamma, self.axis)
        if self.isotropic:
            vx = vx.mean()
        return rx, vx

    def compute_log_partition(self, ax, bx):
        """MAP log-partition per component.

        Evaluates ``sum(bx*rx - 0.5*ax*rx**2) - gamma*||rx||_{2,1}`` at the
        MAP estimate ``rx`` (same formula as in
        ``compute_forward_posterior``) and divides it by ``N``, the total
        number of components of x.
        """
        rx = (1 / ax) * group_soft_threshold(bx, self.gamma, self.axis)
        A_sum = np.sum(bx*rx - 0.5*ax*(rx**2)) - self.gamma*l21_norm(rx, self.axis)
        return A_sum / self.N

    def b_measure(self, mx_hat, qx_hat, tx0_hat, f):
        raise NotImplementedError

    def bx_measure(self, mx_hat, qx_hat, tx0_hat, f):
        raise NotImplementedError

    def beliefs_measure(self, ax, f):
        raise NotImplementedError

    def measure(self, f):
        raise NotImplementedError