from ..algos import StateEvolution, CustomInit, ConstantInit
import numpy as np
import logging
logger = logging.getLogger(__name__)
def binary_search(f, xmin, xmax, xtol):
    """Binary search on a boolean function f with f(xmin)=0 and f(xmax)=1.

    The bracket [xmin, xmax] is halved, keeping f(xmin)=0 and f(xmax)=1,
    until its width drops below xtol.

    Parameters
    ----------
    f : function
        Boolean (or 0/1 valued) function of x.
    xmin : float
        Lower bound of the search, f(xmin) must be 0.
    xmax : float
        Upper bound of the search, f(xmax) must be 1.
    xtol : float
        Tolerance on x, must be positive. The search stops as soon as
        xmax - xmin < xtol.

    Returns
    -------
    search : dict
        - xmid : midpoint of the final bracket
        - xmin, xmax : bounds of the final bracket
        - xerr : width of the final bracket (xmax - xmin)
        - n_iter : number of iterations performed

    Raises
    ------
    ValueError
        If xtol is not positive, if f(xmin) != 0 or f(xmax) != 1, or if
        xmax is not greater than xmin.
    RuntimeError
        If the bracket width is still >= xtol after the iteration budget.
    """
    if not xtol > 0:
        raise ValueError(f"xtol must be positive, got xtol={xtol}")
    ymin, ymax = f(xmin), f(xmax)
    if not (ymin == 0 and ymax == 1):
        raise ValueError(f"Bad bounds: ymin={ymin} and ymax={ymax}")
    if not xmax > xmin:
        raise ValueError(f"Bad bounds: xmin={xmin} and xmax={xmax}")
    # Halving a bracket of width w needs floor(log2(w / xtol)) + 2 iterations
    # to get below xtol. The log is negative when the bracket is already
    # narrower than xtol, so always allow at least one iteration: otherwise
    # xmid and xerr would never be computed.
    max_iter = max(1, int(np.log2((xmax - xmin) / xtol)) + 2)
    for n_iter in range(1, max_iter + 1):
        xmid = (xmin + xmax) / 2
        xerr = xmax - xmin
        logger.info(f"binary search {n_iter}/{max_iter} xerr={xerr}")
        if xerr < xtol:
            break
        # f is only evaluated when the bracket still has to shrink; the
        # value at the final midpoint is never needed.
        if f(xmid) == 0:
            xmin = xmid
        else:
            xmax = xmid
    else:
        # Explicit raise instead of assert so the check survives python -O.
        raise RuntimeError(
            f"binary search did not converge: xerr={xerr} >= xtol={xtol}"
        )
    return dict(
        xmid=xmid, xmin=xmin, xmax=xmax, xerr=xerr, n_iter=n_iter
    )
def find_state_evolution_mse(id, a0, alpha, model_builder, **model_kwargs):
    """Run state evolution on a freshly built model and return the mse.

    Parameters
    ----------
    id : str
        id of the variable to infer (signal)
    a0 : float
        Initial value of the a message id -> prior. It is passed to
        CustomInit as the "bwd" entry for this id.
    alpha : float
        Measurement density, forwarded to model_builder as a keyword.
    model_builder : function or class
        Called as model_builder(alpha=alpha, **model_kwargs); must return a
        Model instance.
    **model_kwargs : dict
        Extra keyword arguments forwarded to model_builder.

    Returns
    -------
    v : float
        The "v" entry of the state evolution variable data for id, i.e. the
        variable mse according to state evolution.
    """
    built_model = model_builder(alpha=alpha, **model_kwargs)
    bwd_init = CustomInit(a_init=[(id, "bwd", a0)])
    evolution = StateEvolution(built_model)
    # State evolution is capped at 200 iterations.
    evolution.iterate(initializer=bwd_init, max_iter=200)
    variable_data = evolution.get_variable_data(id=id)
    return variable_data["v"]
def find_critical_alpha(id, a0, mse_criterion, alpha_min, alpha_max,
                        model_builder, alpha_tol=1e-6, vtol=1e-3,
                        **model_kwargs):
    """Find the critical value of the measurement density alpha.

    It performs a binary search on alpha to find the minimal value of alpha
    for which the mse criterion is satisfied.

    Parameters
    ----------
    id : str
        id of the variable to infer (signal)
    a0 : float
        Initial value of the a message id -> prior
    mse_criterion : {"random", "perfect"} or function
        Criterion on the mse:

        - "random" : search the maximal value of alpha for which
          v = tau_x (no better than random guess)
        - "perfect" : search the minimal value of alpha for which
          v = 0 (perfect reconstruction)
        - function : mse_criterion(v) must return False when
          alpha < alpha_c and True when alpha > alpha_c
    alpha_min : float
        Minimal value for the alpha search
    alpha_max : float
        Maximal value for the alpha search
    model_builder : function or class
        Forwarded to find_state_evolution_mse, which builds one model per
        tested alpha. For the "random" criterion it is also called once as
        model_builder(alpha=0.5, **model_kwargs) to read tau_x.
    alpha_tol : float
        Tolerance on alpha, default 1e-6
    vtol : float
        Tolerance on the variance v used in the "perfect" or "random" mse
        criteria
    **model_kwargs : dict
        Extra keyword arguments forwarded to model_builder.

    Returns
    -------
    alpha_c : float
        Midpoint of the final binary search bracket on alpha.

    Raises
    ------
    ValueError
        If mse_criterion is neither "perfect", "random" nor a callable, or
        if the criterion is not False at alpha_min and True at alpha_max
        (raised by binary_search).
    """
    if mse_criterion == "perfect":
        def criterion(v):
            return abs(v) < vtol
    elif mse_criterion == "random":
        # assuming that tau_x does not depend on alpha, we choose a fixed value
        model = model_builder(alpha=0.5, **model_kwargs)
        model.init_second_moments()
        tau_x = model.get_second_moments()[id]

        def criterion(v):
            return abs(v - tau_x) > vtol
    elif callable(mse_criterion):
        criterion = mse_criterion
    else:
        # Fail before any state evolution is run, rather than with a
        # "not callable" TypeError deep inside the search.
        raise ValueError(
            "mse_criterion must be 'perfect', 'random' or a function, "
            f"got {mse_criterion!r}"
        )

    def f(alpha):
        v = find_state_evolution_mse(
            id, a0, alpha, model_builder, **model_kwargs
        )
        return criterion(v)

    search = binary_search(f, alpha_min, alpha_max, alpha_tol)
    alpha_c = search["xmid"]
    return alpha_c