"""Callbacks for ExpectationPropagation and StateEvolution algorithms."""
from .metrics import METRICS
import pandas as pd
import numpy as np
from ..base import ReprMixin
import logging
logger = logging.getLogger(__name__)
class Callback(ReprMixin):
    """Common base class of the callbacks defined in this module.

    The concrete callbacks below are callables with the signature
    ``callback(algo, i, max_iter)``. ``JoinCallback`` combines their return
    values with ``any`` into a single ``stop`` flag.
    """
class PassCallback(Callback):
    """Callback that does nothing when invoked."""

    def __init__(self):
        self.repr_init()

    def __call__(self, algo, i, max_iter):
        """Ignore every argument and return ``None``."""
        return None
class JoinCallback(Callback):
    """Bundle several callbacks behind a single callable.

    Parameters
    ----------
    callbacks : iterable of callables
        Each one is invoked as ``callback(algo, i, max_iter)``.
    """

    def __init__(self, callbacks):
        self.callbacks = callbacks
        self.repr_init(pad="\t")

    def __call__(self, algo, i, max_iter):
        """Invoke every child callback and combine their answers.

        Every callback is called, even after an earlier one has returned a
        truthy value.

        Returns
        -------
        bool
            ``True`` if at least one callback returned a truthy value.
        """
        should_stop = False
        for child in self.callbacks:
            outcome = child(algo, i, max_iter)
            # Only inspect truthiness until the first hit, as ``any`` does.
            if not should_stop and outcome:
                should_stop = True
        return should_stop
class LogProgress(Callback):
    """Log the ``v`` value of the selected variables while iterating.

    Parameters
    ----------
    ids : passed unchanged to ``algo.get_variables_data``
    every : int
        Log only on iterations ``i`` that are multiples of ``every``.
    """

    def __init__(self, ids="all", every=1):
        self.ids = ids
        self.every = every
        self.repr_init()

    def __call__(self, algo, i, max_iter):
        """Emit one INFO line for the iteration plus one per variable."""
        if i % self.every != 0:
            return
        snapshot = algo.get_variables_data(self.ids)
        logger.info(f"iteration={i+1}/{max_iter}")
        for variable_id, data in snapshot.items():
            logger.info(f"id={variable_id} v={data['v']:.3f}")
class TrackMessages(Callback):
    """Accumulate edge data returned by the algorithm at every iteration.

    Parameters
    ----------
    keys : list of str, optional
        Passed unchanged to ``algo.get_edges_data``. Defaults to
        ``["a", "n_iter", "direction"]``.
    """

    def __init__(self, keys=None):
        # Build the default per instance: a list literal in the signature is
        # created once and would be shared by every instance relying on it.
        if keys is None:
            keys = ["a", "n_iter", "direction"]
        self.keys = keys
        self.records = []

    def __call__(self, algo, i, max_iter):
        """Append this iteration's edge data; start afresh when ``i == 0``."""
        if i == 0:
            self.records = []
        self.records += algo.get_edges_data(self.keys)

    def get_dataframe(self):
        """Return the accumulated records as a ``pandas.DataFrame``."""
        return pd.DataFrame(self.records)
class TrackObjective(Callback):
    """Record the ``A`` values of the model, edges and nodes per iteration."""

    def __init__(self):
        self.edge_records = []
        self.node_records = []
        self.model_records = []

    def __call__(self, algo, i, max_iter):
        """Update the objective on ``algo`` and store the resulting values.

        On the first iteration (``i == 0``) all three record lists are
        cleared so that re-using the callback for another run does not mix
        records of different runs.
        """
        if i == 0:
            # The record lists actually used by this class must be reset
            # (not an unrelated ``records`` attribute).
            self.edge_records = []
            self.node_records = []
            self.model_records = []
        algo.update_objective()
        # model
        self.model_records.append(dict(A=algo.A_model, n_iter=algo.n_iter))
        # edges
        self.edge_records += algo.get_edges_data(["A", "n_iter", "direction"])
        # nodes
        self.node_records += algo.get_nodes_data(["A", "n_iter"])

    def get_dataframe(self):
        """Return ``(edge_df, node_df, model_df)`` built from the records."""
        edge_df = pd.DataFrame(self.edge_records)
        node_df = pd.DataFrame(self.node_records)
        model_df = pd.DataFrame(self.model_records)
        return edge_df, node_df, model_df
class TrackEvolution(Callback):
    """Record the ``v`` value of the selected variables over the iterations.

    Parameters
    ----------
    ids : passed unchanged to ``algo.get_variables_data``
    every : int
        Record only on iterations ``i`` that are multiples of ``every``.
    verbose : bool
        If true, also ``print`` each record as it is stored.
    """

    def __init__(self, ids="all", every=1, verbose=False):
        self.ids = ids
        self.every = every
        self.repr_init()
        self.records = []
        self.verbose = verbose

    def __call__(self, algo, i, max_iter):
        """Store one ``{id, v, iter}`` record per variable."""
        if i == 0:
            # New run: forget what was recorded before.
            self.records = []
        if i % self.every != 0:
            return
        snapshot = algo.get_variables_data(self.ids)
        for variable_id, data in snapshot.items():
            entry = {"id": variable_id, "v": data["v"], "iter": i}
            self.records.append(entry)
            if self.verbose:
                print(entry)

    def get_dataframe(self):
        """Return the stored records as a ``pandas.DataFrame``."""
        return pd.DataFrame(self.records)
class TrackEstimate(Callback):
    """Record the ``r`` value of the selected variables over the iterations.

    Parameters
    ----------
    ids : passed unchanged to ``algo.get_variables_data``
    every : int
        Record only on iterations ``i`` that are multiples of ``every``.
    """

    def __init__(self, ids="all", every=1):
        self.ids = ids
        self.every = every
        self.repr_init()
        self.records = []

    def __call__(self, algo, i, max_iter):
        """Store one ``{id, r, iter}`` record per variable."""
        if i == 0:
            # New run: forget what was recorded before.
            self.records = []
        if i % self.every != 0:
            return
        snapshot = algo.get_variables_data(self.ids)
        for variable_id, data in snapshot.items():
            self.records.append({"id": variable_id, "r": data["r"], "iter": i})

    def get_dataframe(self):
        """Return the stored records as a ``pandas.DataFrame``."""
        return pd.DataFrame(self.records)
class TrackErrors(Callback):
    """Record error metrics between the ``r`` values and reference values.

    Parameters
    ----------
    true_values : mapping
        Maps a variable id to its reference value; its keys select the
        variables requested from ``algo.get_variables_data``.
    metrics : list of str
        Names looked up in ``METRICS``; each entry is called as
        ``func(predicted, true)``.
    every : int
        Record only on iterations ``i`` that are multiples of ``every``.
    verbose : bool
        If true, ``print`` the list of records computed at each recorded
        iteration.
    """

    def __init__(self, true_values, metrics=["mse"], every=1, verbose=False):
        self.ids = true_values.keys()
        self.metrics = metrics
        self.every = every
        self.repr_init()
        self.X_true = true_values
        self.errors = []
        self.verbose = verbose

    def __call__(self, algo, i, max_iter):
        """Compute every metric for every tracked variable and store it."""
        if i == 0:
            # New run: forget what was recorded before.
            self.errors = []
        if i % self.every != 0:
            return
        snapshot = algo.get_variables_data(self.ids)
        predictions = {}
        for variable_id, data in snapshot.items():
            predictions[variable_id] = data["r"]
        # Collect locally first so ``self.errors`` is only extended once the
        # whole iteration has been evaluated.
        batch = []
        for variable_id in self.ids:
            entry = {"id": variable_id, "iter": i}
            for metric in self.metrics:
                entry[metric] = METRICS.get(metric)(
                    predictions[variable_id], self.X_true[variable_id]
                )
            batch.append(entry)
        if self.verbose:
            print(batch)
        self.errors += batch

    def get_dataframe(self):
        """Return the stored error records as a ``pandas.DataFrame``."""
        return pd.DataFrame(self.errors)
class TrackOverlaps(Callback):
    """Record normalised inner products between ``r`` and reference values.

    For each variable, with ``x`` the reference value, ``r`` the value read
    from the algorithm and ``n = x.shape[0]``, the stored quantities are::

        m = (1/n) * r.T.dot(x)
        q = (1/n) * r.T.dot(r)
        Q = (1/n) * x.T.dot(x)

    Parameters
    ----------
    true_values : mapping
        Maps a variable id to its reference value.
    ids : passed unchanged to ``algo.get_variables_data``
    every : int
        Record only on iterations ``i`` that are multiples of ``every``.
    verbose : bool
        If true, also ``print`` each record as it is stored.
    """

    def __init__(self, true_values, ids="all", every=1, verbose=False):
        self.ids = ids
        self.every = every
        self.repr_init()
        self.X_true = true_values
        self.records = []
        self.verbose = verbose

    def __call__(self, algo, i, max_iter):
        """Store one ``{id, m, q, Q, iter}`` record per variable."""
        if i == 0:
            # New run: forget what was recorded before.
            self.records = []
        if i % self.every != 0:
            return
        snapshot = algo.get_variables_data(self.ids)
        for variable_id, data in snapshot.items():
            truth = self.X_true[variable_id]
            scale = 1 / truth.shape[0]
            estimate = data['r']
            m = scale * (estimate.T).dot(truth)
            q = scale * (estimate.T).dot(estimate)
            Q = scale * (truth.T).dot(truth)
            entry = {"id": variable_id, "m": m, "q": q, "Q": Q, "iter": i}
            self.records.append(entry)
            if self.verbose:
                print(entry)

    def get_dataframe(self):
        """Return the stored records as a ``pandas.DataFrame``."""
        return pd.DataFrame(self.records)
class EarlyStopping(Callback):
    """Stop iterating based on the ``v`` values of the selected variables.

    The callback returns ``True`` (stop) when one of these holds:

    * some ``v`` is below ``min_variance``;
    * some ``v`` is NaN (the previous message DAG is restored if one was
      saved);
    * every ``v`` changed by less than ``tol`` since the previous call;
    * ``i > wait_increase`` and some ``v`` grew by more than
      ``max_increase`` since the previous call (the previous message DAG is
      restored).

    Otherwise it saves the current ``v`` values and a copy of
    ``algo.message_dag`` for the next call and returns ``None``.

    Parameters
    ----------
    ids : passed unchanged to ``algo.get_variables_data``
    tol : float
        Threshold on the largest absolute change of ``v``.
    min_variance : float
        Lower bound on ``v``.
    wait_increase : int
        The increase test is only active for iterations ``i`` strictly
        greater than this value.
    max_increase : float
        Largest tolerated increase of ``v`` between two calls.
    """

    def __init__(self, ids="all", tol=1e-6, min_variance=-1,
                 wait_increase=5, max_increase=0.2):
        self.ids = ids
        self.tol = tol
        self.min_variance = min_variance
        self.wait_increase = wait_increase
        self.max_increase = max_increase
        self.repr_init()
        self.old_vs = None
        # Always defined, so the NaN branch cannot hit a missing attribute.
        self.old_message_dag = None

    def __call__(self, algo, i, max_iter):
        """Return ``True`` if the algorithm should stop, else ``None``."""
        if i == 0:
            # New run: drop the state saved by any previous run, including
            # the message DAG snapshot, which would otherwise be stale.
            self.old_vs = None
            self.old_message_dag = None
        variables_data = algo.get_variables_data(self.ids)
        new_vs = [data["v"] for _, data in variables_data.items()]
        if any(v < self.min_variance for v in new_vs):
            logger.info(f"early stopping min variance {min(new_vs)}")
            return True
        if any(np.isnan(v) for v in new_vs):
            logger.warning("early stopping nan values")
            if self.old_message_dag is not None:
                logger.info("restoring old message dag")
                algo.reset_message_dag(self.old_message_dag)
            else:
                # NaN on the very first iteration: nothing was saved yet.
                logger.info("no previous message dag to restore")
            return True
        # Truthiness (not ``is not None``) also skips an empty list, for
        # which ``max`` below would fail.
        if self.old_vs:
            tols = [
                np.abs(old_v - new_v)
                for old_v, new_v in zip(self.old_vs, new_vs)
            ]
            if max(tols) < self.tol:
                logger.info(
                    "early stopping all tolerances (on v) are "
                    f"below tol={self.tol:.2e}"
                )
                return True
            increase = [
                new_v - old_v for old_v, new_v in zip(self.old_vs, new_vs)
            ]
            if i > self.wait_increase and max(increase) > self.max_increase:
                logger.info(
                    f"increase={max(increase)} above "
                    f"max_increase={self.max_increase:.2e}"
                )
                logger.info("restoring old message dag")
                algo.reset_message_dag(self.old_message_dag)
                return True
        # for next iteration
        self.old_vs = new_vs
        self.old_message_dag = algo.message_dag.copy()
def norm(x):
    """Return the root-mean-square magnitude of ``x``.

    Computes ``sqrt(mean(|x|**2))`` over all entries of ``x``.

    Parameters
    ----------
    x : array_like
        Real or complex values; plain sequences and scalars are accepted.

    Returns
    -------
    numpy floating scalar
        Always real and non-negative, including for complex input.
    """
    # ``|x|**2`` equals ``x**2`` for real input but, unlike ``x**2``, is the
    # squared magnitude for complex input, so the result stays real.
    magnitudes = np.abs(np.asarray(x))
    return np.sqrt(np.mean(magnitudes ** 2))
class EarlyStoppingEP(Callback):
    """Stop iterating based on the relative change of the ``r`` values.

    For each selected variable the relative change
    ``norm(new_r - old_r) / norm(new_r)`` since the previous call is
    computed. The callback returns ``True`` (stop) when

    * the largest relative change is below ``tol``; or
    * ``i > wait_increase`` and the largest relative change exceeds
      ``max_increase`` (the previously saved message DAG is restored).

    Otherwise it saves the current ``r`` values and a copy of
    ``algo.message_dag`` for the next call and returns ``None``.

    Parameters
    ----------
    ids : passed unchanged to ``algo.get_variables_data``
    tol : float
    wait_increase : int
    max_increase : float
    """

    def __init__(self, ids="all", tol=1e-6, wait_increase=5, max_increase=0.2):
        self.ids = ids
        self.tol = tol
        self.wait_increase = wait_increase
        self.max_increase = max_increase
        self.repr_init()
        self.old_rs = None

    def __call__(self, algo, i, max_iter):
        """Return ``True`` if the algorithm should stop, else ``None``."""
        if i == 0:
            # New run: nothing to compare against yet.
            self.old_rs = None
        snapshot = algo.get_variables_data(self.ids)
        new_rs = [entry["r"] for _, entry in snapshot.items()]
        previous_rs = self.old_rs
        if previous_rs:
            tols = []
            for old_r, new_r in zip(previous_rs, new_rs):
                tols.append(norm(new_r - old_r)/norm(new_r))
            if max(tols) < self.tol:
                logger.info(
                    "early stopping all tolerances (on r) are "
                    f"below tol={self.tol:.2e}"
                )
                return True
            # The same relative changes serve as the "increase" measure.
            increase = tols
            too_large = (
                i > self.wait_increase and max(increase) > self.max_increase
            )
            if too_large:
                logger.info(
                    f"increase={max(increase)} above "
                    f"max_increase={self.max_increase:.2e}"
                )
                logger.info("restoring old message dag")
                algo.reset_message_dag(self.old_message_dag)
                return True
        # Remember the current state for the next call.
        self.old_rs = new_rs
        self.old_message_dag = algo.message_dag.copy()