"""tramp.algos.metrics: error and overlap metrics between a true and a predicted signal."""

import numpy as np
from ..utils.misc import complex2array, array2complex


def mean_squared_error(x_true, x_pred):
    """Mean squared error between ``x_true`` and ``x_pred``.

    Parameters
    ----------
    x_true, x_pred : array_like
        Values to compare; shapes must be broadcast-compatible. Array-likes
        such as lists are converted with ``np.asarray`` so that subtraction
        is elementwise.

    Returns
    -------
    numpy scalar
        ``mean((x_true - x_pred) ** 2)`` taken over all elements.
    """
    # np.asarray is a no-op for ndarrays and makes lists/tuples work too.
    x_true = np.asarray(x_true)
    x_pred = np.asarray(x_pred)
    return np.mean((x_true - x_pred) ** 2)
def sign_symmetric_mse(x_true, x_pred):
    """Mean squared error up to a global sign.

    Returns ``min(mean((x_true - x_pred)**2), mean((x_true + x_pred)**2))``,
    so the metric is invariant under ``x_pred -> -x_pred``.

    Parameters
    ----------
    x_true, x_pred : array_like
        Values to compare; shapes must be broadcast-compatible. Array-likes
        are converted with ``np.asarray``.
    """
    x_true = np.asarray(x_true)
    x_pred = np.asarray(x_pred)
    mse_pos = np.mean((x_true - x_pred) ** 2)
    mse_neg = np.mean((x_true + x_pred) ** 2)
    return min(mse_pos, mse_neg)


def phase_symmetric_mse(x_true, x_pred):
    """Mean squared error up to a global phase.

    ``x_true`` and ``x_pred`` are real-array encodings of complex signals in
    the representation used by ``complex2array`` / ``array2complex``. The
    result is ``min over phi of mean_squared_error(x_true,
    complex2array(exp(1j*phi) * array2complex(x_pred)))``, computed exactly.

    Notes
    -----
    Writing ``a = array2complex(x_true)`` and ``b = array2complex(x_pred)``,
    the summed squared error is ``sum|a|^2 + sum|b|^2 - 2 Re(e^{i phi} S)``
    with ``S = sum(conj(a) * b)``. This is minimised at ``phi = -angle(S)``,
    so no search over phases is needed. When ``S == 0`` every phase is
    optimal and ``phi = 0`` is used.
    """
    # NOTE(review): assumes array2complex(x_true) is valid, i.e. x_true uses
    # the same real encoding as x_pred (it is compared to complex2array output).
    z_true = array2complex(x_true)
    z_pred = array2complex(x_pred)
    # np.vdot conjugates its first argument and flattens both inputs.
    cross = np.vdot(z_true, z_pred)
    phi = -np.angle(cross)
    x_phase = complex2array(np.exp(1j * phi) * z_pred)
    return mean_squared_error(x_true, x_phase)
def overlap(x_true, x_pred):
    """Mean elementwise product of ``x_true`` and ``x_pred``.

    Parameters
    ----------
    x_true, x_pred : array_like
        Values to compare; shapes must be broadcast-compatible. Array-likes
        such as lists are converted with ``np.asarray`` so the product is
        elementwise.

    Returns
    -------
    numpy scalar
        ``mean(x_true * x_pred)`` taken over all elements.
    """
    return np.mean(np.asarray(x_true) * np.asarray(x_pred))
# Registry of metric functions keyed by short name; every entry is called
# as metric(x_true, x_pred).
METRICS = {
    "sign_mse": sign_symmetric_mse,
    "phase_mse": phase_symmetric_mse,
    "mse": mean_squared_error,
    "overlap": overlap,
}