Source code for zfit_physics.roofit.loss

#  Copyright (c) 2024 zfit
from __future__ import annotations

from collections.abc import Iterable
from typing import TYPE_CHECKING

from zfit.interface import ZfitParameter

if TYPE_CHECKING:
    try:
        import ROOT
    except ImportError:
        ROOT = None
import warnings

import zfit
from zfit.util.container import convert_to_container

from .variables import roo2z_param


[docs] def nll_from_roofit(nll: ROOT.RooAbsReal, params: ZfitParameter | Iterable[ZfitParameter] = None): """ Converts a RooFit NLL (negative log-likelihood) to a Zfit loss object. Args: nll: The RooFit NLL object to be converted. params: The ``zfit.Parameter`` to be used in the loss. If None, all parameters in the NLL will be used Returns: zfit.loss.SimpleLoss: The converted Zfit loss object. Raises: TypeError: If the provided RooFit loss does not have an error level. """ params = {} if params is None else {p.name: p for p in convert_to_container(params)} import zfit def roofit_eval(x): for par, arg in zip(nll.getVariables(), x, strict=False): par.setVal(arg) # following RooMinimizerFcn.cxx nll.setHideOffset(False) r = nll.getVal() nll.setHideOffset(True) return r paramsall = [] for v in nll.getVariables(): param = params[name] if (name := v.GetName()) in params else roo2z_param(v) paramsall.append(param) if (errordef := getattr(nll, "defaultErrorLevel", lambda: None)()) is None and ( errordef := getattr(nll, "errordef", lambda: None)() ) is None: msg = ( "Provided loss is RooFit loss but has not error level. " "Either set it or create an attribute on the fly (like `nllroofit.errordef = 0.5`) " ) raise TypeError(msg) return zfit.loss.SimpleLoss(roofit_eval, paramsall, errordef=errordef, jit=False, gradient="num", hessian="num")
def _nll_from_roofit_or_false(nll, params=None): ROOT = None if "RooAbsReal" in str(type(nll)): try: import ROOT except ImportError: warnings.warn( f"nll ({nll}) seems to be of type RooAbsReal but ROOT is not available, skipping.", stacklevel=2 ) if ROOT is None or not isinstance(nll, ROOT.RooAbsReal): return False # not a RooFit loss return nll_from_roofit(nll, params=params) zfit.loss.SimpleLoss.register_convertable_loss(nll_from_roofit, priority=50)