Source code for zfit_physics.compwa.loss
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING
import zfit
from zfit.util.container import convert_to_container
from .variables import params_from_intensity
if TYPE_CHECKING:
from tensorwaves.estimator import Estimator
from zfit.interface import ZfitLoss
# Public API of this module: only the estimator-to-loss factory below is exported.
__all__ = ["nll_from_estimator"]
[docs]
def nll_from_estimator(estimator: Estimator, *, params=None, errordef=None, numgrad=None) -> ZfitLoss:
    r"""Create a negative log-likelihood function from a tensorwaves estimator.

    Args:
        estimator: An estimator object that computes a scalar loss function.
        params: A list of zfit parameters that the loss function depends on. If ``None``, the
            parameters are extracted from the intensity function stored on the estimator.
        errordef: The error definition of the loss function. If ``None``, ``estimator.errordef`` is
            used when present, otherwise 1.0 for ``ChiSquared`` and 0.5 for ``UnbinnedNLL``.
        numgrad: If True, the gradient of the loss function is computed numerically and the ComPWA
            estimator's gradient method is not used. Can be useful as not all backends in ComPWA
            support gradients.

    Returns:
        A zfit loss function that can be used with zfit.

    Raises:
        ValueError: If ``params`` is ``None`` and no intensity function can be found on the estimator.
    """
    from collections.abc import Mapping

    from tensorwaves.estimator import ChiSquared, UnbinnedNLL

    if params is None:
        # The intensity is kept in a private, name-mangled ``__function`` attribute. Python mangles
        # with the name of the class that *defines* the attribute (leading underscores stripped),
        # so walk the MRO; this also finds it on subclasses of the tensorwaves estimators.
        intensity = None
        for klass in type(estimator).__mro__:
            intensity = getattr(estimator, f"_{klass.__name__.lstrip('_')}__function", None)
            if intensity is not None:
                break
        if intensity is None:
            msg = f"Could not find intensity function in {estimator}. Maybe the attribute changed?"
            raise ValueError(msg)
        params = params_from_intensity(intensity)
    else:
        params = convert_to_container(params)
    paramnames = [param.name for param in params]

    def func(params):
        # zfit hands over the parameter values positionally; the estimator wants a name -> value map.
        paramdict = dict(zip(paramnames, params, strict=False))
        return estimator(paramdict)

    if numgrad:
        grad = None  # no analytic gradient: zfit falls back to a numerical one
    else:

        def grad(params):
            paramdict = dict(zip(paramnames, params, strict=False))
            gradients = estimator.gradient(paramdict)
            # tensorwaves' ``Estimator.gradient`` is annotated to return a dict keyed by parameter
            # name; order it like ``params`` so it lines up with the positional parameter values.
            if isinstance(gradients, Mapping):
                return [gradients[name] for name in paramnames]
            return gradients

    if errordef is None:
        if hasattr(estimator, "errordef"):
            errordef = estimator.errordef
        elif isinstance(estimator, ChiSquared):
            errordef = 1.0
        elif isinstance(estimator, UnbinnedNLL):
            errordef = 0.5
    return zfit.loss.SimpleLoss(func=func, gradient=grad, params=params, errordef=errordef)
def _nll_from_estimator_or_false(estimator: Estimator, *, params=None, errordef=None) -> ZfitLoss | bool:
    """Convert a tensorwaves estimator into a zfit loss, or return ``False`` if that is not possible.

    This is the converter hook registered with ``zfit.loss.SimpleLoss.register_convertable_loss``.

    Args:
        estimator: Any object; only objects whose type repr mentions ``tensorwaves`` are considered.
        params: Forwarded to :func:`nll_from_estimator`.
        errordef: Forwarded to :func:`nll_from_estimator`.

    Returns:
        A zfit loss if ``estimator`` is a tensorwaves ``ChiSquared`` or ``UnbinnedNLL``; ``False``
        otherwise (not a tensorwaves object, tensorwaves not importable, or unsupported estimator).
    """
    if "tensorwaves" not in repr(type(estimator)):
        # Not ours to convert: report "cannot convert" the same way as every other fallback below.
        return False
    try:
        # Import the submodule explicitly: ``import tensorwaves`` alone does not guarantee that the
        # ``tensorwaves.estimator`` attribute is available on the package.
        from tensorwaves.estimator import ChiSquared, UnbinnedNLL
    except ImportError:
        return False
    if not isinstance(estimator, ChiSquared | UnbinnedNLL):
        warnings.warn(
            "Only ChiSquared and UnbinnedNLL are supported from tensorwaves currently. "
            f"TensorWaves is in name of {estimator}, this could be a bug.",
            stacklevel=2,
        )
        return False
    return nll_from_estimator(estimator, params=params, errordef=errordef)
# Runs at import time: registers the converter above with zfit's SimpleLoss, presumably so that
# tensorwaves estimators passed where a zfit loss is expected can be converted (verify in zfit).
zfit.loss.SimpleLoss.register_convertable_loss(_nll_from_estimator_or_false)