from __future__ import annotations
from functools import cached_property
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, InitVar
from typing import List, Iterable, Optional
from os import path
import matplotlib.pyplot as plt
from particle import Particle
import itertools
import numpy as np
from pandas import DataFrame, concat, read_csv
from uncertainties import unumpy as unp, ufloat
from .uncertainties import (
Uncertainty,
FullyCorrelatedUncertainty,
FullyCorrelatedUncertaintyInParts,
UncorrelatedUncertainty,
ExplicitlyCorrelatedUncertainty,
get_uncertainty_types,
)
from sysvar.utils import SavableAttributesObject, read_yaml, load_covariance_matrix
from sysvar.visualize import CorrectionVisualizer, UncertaintyVisualizer
import logging
# NOTE(review): calling basicConfig at import time in a library module
# overrides the host application's logging configuration — consider moving
# this to the application entry point. Left as-is here.
logging.basicConfig(
    format="%(levelname)s : %(funcName)s: %(lineno)d : %(message)s",
    level=logging.INFO,
)
class UncertaintyWithSameNameExists(Exception):
    """Raised when an uncertainty with a duplicate name is added to a correction."""
class UnkownUncertaintyType(Exception):
    """Raised when a yaml file declares an uncertainty type that is not implemented.

    NOTE: the class name keeps the historical misspelling ("Unkown") because
    callers may catch it by this name.
    """
class NotValidRateType(Exception):
    """Raised when a PID rate argument is neither 'eff' nor 'fake'."""
@dataclass()
class BaseCorrection(ABC, SavableAttributesObject):
    """Abstract base class for all correction objects.

    Subclasses must provide human-readable bin labels (``visual_labels``)
    and the dataframe selection strings (``build_queries``) that map each
    correction bin onto the analysed data.
    """

    # Mapping of uncertainty name -> Uncertainty object attached to this correction.
    uncertainties: dict = field(default_factory=dict)

    @property
    @abstractmethod
    def visual_labels(self):
        """Human-readable label for every correction bin."""

    @abstractmethod
    def build_queries(self) -> list:
        """Return one dataframe query string per correction bin."""
@staticmethod
def _build_column_name(prefix: str | None, variable: str) -> str:
    """
    Construct a column name by combining an optional prefix and a variable name.

    Args:
        prefix (str | None): Optional string prepended to the variable name.
            If ``None``, no prefix is added.
        variable (str): The variable name used to construct the column name.

    Returns:
        str: "prefix_variable" if a prefix is given, otherwise just "variable".

    Raises:
        ValueError: If ``variable`` is not a str, or ``prefix`` is neither a
            str nor ``None``.

    Examples:
        >>> BaseCorrection._build_column_name("prefix", "variable")
        'prefix_variable'
        >>> BaseCorrection._build_column_name(None, "variable")
        'variable'
    """
    if not isinstance(variable, str):
        # Fixed message typo: "by you passed" -> "but you passed"
        raise ValueError(
            f"{variable} is expected to be str but you passed {type(variable)}"
        )
    if prefix is not None and not isinstance(prefix, str):
        raise ValueError(
            f"{prefix} is expected to be str but you passed {type(prefix)}"
        )
    return variable if prefix is None else f"{prefix}_{variable}"
@property
def N(self) -> int:
    """Number of correction bins."""
    return len(self.central_values)

@property
def total_error(self) -> np.ndarray:
    """Per-bin quadratic sum of all registered uncertainty amplitudes."""
    n_uncertainties = len(self.uncertainties)
    if n_uncertainties == 0:
        raise ValueError("No uncertainties have been added to the correction")
    if n_uncertainties == 1:
        only_uncertainty = next(iter(self.uncertainties.values()))
        return only_uncertainty.errors
    squared = [np.power(unc.errors, 2) for unc in self.uncertainties.values()]
    return np.sqrt(np.sum(squared, axis=0))
def add_uncertainty(
    self,
    unc_name,
    unc_values,
    unc_obj: Uncertainty,
    explicit_cov_matrix: Optional[np.ndarray] = None,
) -> None:
    """
    Add a named uncertainty to the Correction.

    Args:
        unc_name: Unique name of the uncertainty.
        unc_values: Per-bin error amplitudes.
        unc_obj (Uncertainty): Uncertainty class to instantiate.
        explicit_cov_matrix (Optional[np.ndarray]): Covariance matrix passed
            through to the uncertainty when this correction carries an
            explicit covariance (``self.cov_matrix`` is not None).

    Raises:
        UncertaintyWithSameNameExists: If an uncertainty with the same name
            has already been added to the Correction.
    """
    if unc_name in self.uncertainties:
        raise UncertaintyWithSameNameExists(
            f"An uncertainty with the name {unc_name} already exist in the set of uncertainties that that have been added to the Correction. Make sure that you add a specific uncertainty only once, and that there are no duplicate names"
        )
    # NOTE(review): the branch is keyed on self.cov_matrix (the correction's
    # own covariance) rather than on explicit_cov_matrix itself — confirm
    # this is intended when both are in play.
    if self.cov_matrix is not None:
        # Use the provided covariance matrix to create an explicitly correlated uncertainty
        self.uncertainties[unc_name] = unc_obj(
            unc_name,
            unc_values,
            self.visual_labels,
            explicit_cov_matrix,
        )
    else:
        self.uncertainties[unc_name] = unc_obj(
            unc_name, unc_values, self.visual_labels
        )
[docs]
def populate_uncertainties(self):
    """
    Populate the `uncertainties` attribute with uncertainty objects based on the provided information.
    This method adds uncertainties either from a provided covariance matrix (explicitly correlated uncertainty),
    or from a dictionary of uncertainty types and values specified in the `info` attribute.
    If a covariance matrix (`cov_matrix`) is available, it creates and adds a single
    `ExplicitlyCorrelatedUncertainty`. Otherwise, it looks up available uncertainty types,
    validates them, and populates the uncertainties accordingly.
    Raises:
        UnkownUncertaintyType: If an unsupported uncertainty type is found in the input data.
    Notes:
        - The uncertainty type definitions are expected to be returned from `get_uncertainty_types()`.
        - Each uncertainty must have a unique name.
    """
    if self.cov_matrix is not None:
        # Use the provided covariance matrix to create an explicitly correlated uncertainty
        errors = np.sqrt(np.diag(self.cov_matrix))  # Extract errors from diagonal
        self.add_uncertainty(
            unc_name="explicit_covariance",
            unc_values=errors,
            unc_obj=ExplicitlyCorrelatedUncertainty,
            explicit_cov_matrix=self.cov_matrix,
        )
    else:
        # Load the implemented uncertainty types
        sysvar_uncertainties = get_uncertainty_types()
        for unc_ctgy, uncertainty_dictionary in self.info["uncertainties"].items():
            if unc_ctgy not in sysvar_uncertainties.keys():
                raise UnkownUncertaintyType(
                    f"Unkown type of uncertainty is declared in the yaml file. Available uncertainty types are: ({', '.join(list(sysvar_uncertainties.keys()))}) but '{unc_ctgy}' was found in the yaml file"
                )
            else:
                # One category may hold several named uncertainties.
                for unc_name, unc_values in uncertainty_dictionary.items():
                    self.add_uncertainty(
                        unc_name=unc_name,
                        unc_values=unc_values,
                        unc_obj=sysvar_uncertainties[unc_ctgy],
                    )
def plot_error_comparison(
    self, save: bool = False, filename: str = ""
) -> tuple[plt.Figure, plt.Axes]:
    """Plot a comparison of all uncertainty sources for this correction.

    Args:
        save (bool): Whether to save the figure to disk.
        filename (str): Output file name when saving.

    Returns:
        tuple[plt.Figure, plt.Axes]: The created figure and axes.
    """
    self.visualizer = CorrectionVisualizer(self)
    fig, ax = self.visualizer.plot_error_comparison(save=save, filename=filename)
    return fig, ax

def plot_uncertainty(
    self, unc_name: str | None = None, save: bool = False, filename: str = ""
):
    """Plot covariance/correlation matrices for one or all uncertainties.

    Args:
        unc_name (str | None): Name of a single uncertainty to plot; when
            ``None``, every registered uncertainty is plotted in turn.
        save (bool): Whether to save the figures to disk.
        filename (str): Output file name when saving.
    """
    if unc_name is None:
        # self.visualizer ends up pointing at the last uncertainty plotted.
        for unc_obj in self.uncertainties.values():
            self.visualizer = UncertaintyVisualizer(unc_obj)
            self.visualizer.plot_cov_and_corr(save=save, filename=filename)
    else:
        self.visualizer = UncertaintyVisualizer(self.uncertainties[unc_name])
        self.visualizer.plot_cov_and_corr(save=save, filename=filename)
[docs]
@dataclass()
class BaseCorrectionFromYaml(BaseCorrection):
    """Base class for corrections whose configuration is read from a yaml file."""

    # Name of the systematic effect; selects the yaml config to read.
    systematic: str | None = None
    # MC production identifier; selects the yaml config to read.
    MC_production: str | None = None

    def __post_init__(self):
        super().__init__()
        try:
            self.info = read_yaml(self.systematic, self.MC_production)
        except TypeError:
            # NOTE(review): MissingInformationError is not imported or defined
            # in this module's visible scope — raising it would itself fail
            # with NameError. TODO confirm where it is defined.
            raise MissingInformationError(
                f"Need to specify the systematic effect and the MC production in the positional arguments. You passed {self.systematic} and {self.MC_production}"
            )
        self.visualizer = None
        self.title = self.info["title"]
        # Optional explicit covariance matrix for this correction (or None).
        self.cov_matrix = load_covariance_matrix(self.info)
@staticmethod
def _extend_queries_with_extra_cut(queries: list, extra_cut: str) -> list:
    """Append ``extra_cut`` to every query with an ``&`` conjunction.

    Args:
        queries (list): Query strings to extend.
        extra_cut (str): Additional condition appended to each query.

    Returns:
        list: New list of extended query strings.

    Example:
        >>> _extend_queries_with_extra_cut(['query1', 'query2'], 'extra_condition')
        ['query1 & extra_condition', 'query2 & extra_condition']
    """
    extended = []
    for query in queries:
        extended.append(f"{query} & {extra_cut}")
    return extended
def _get_extra_cut_info(self):
    """Return the extra-cut configuration from the yaml info."""
    return self.info["extra_cuts"]

@property
def table_dir(self) -> str:
    """Directory where the correction tables are stored (from config)."""
    return self.info["corrections"]["table_dir"]

@property
def table_name(self) -> str:
    """Base file name of the correction table (from config)."""
    return self.info["corrections"]["table_name"]

@property
def table_ext(self) -> str:
    """File extension of the correction table (from config)."""
    return self.info["corrections"]["table_ext"]

@property
def table_key(self) -> str:
    """Column key used to read correction values from the table (from config)."""
    return self.info["corrections"]["table_key"]
[docs]
def build_table_path(self, suffix: Optional[str] = None) -> str:
    """
    Build the full path of a table file, optionally with a suffix.

    Args:
        suffix (Optional[str]): Suffix appended to the table name as
            "<name>_<suffix>". ``None`` means no suffix.

    Returns:
        str: The full path of the table file.

    Raises:
        ValueError: If ``suffix`` is not a string, or the configured table
            extension is not supported (txt, csv).
    """
    if suffix is None:
        base_name = self.table_name
    elif isinstance(suffix, str):
        base_name = f"{self.table_name}_{suffix}"
    else:
        raise ValueError(f"Suffix must be a string, but got {type(suffix)}")
    if self.table_ext not in ("txt", "csv"):
        raise ValueError(
            f"Unknown table extension {self.table_ext}. Supported extensions are: txt, csv"
        )
    filename = f"{base_name}.{self.table_ext}"
    return path.join(self.table_dir, filename)
[docs]
@dataclass
class Correction1D(BaseCorrectionFromYaml):
    """Correction binned in a single dependant variable."""

    # Name of the variable the correction is binned in (set from yaml).
    dependant_variable: str | None = None
    # Per-bin central correction values.
    central_values: Iterable = None
    # Lower bin edges.
    lower_bounds: Iterable = None
    # Upper bin edges.
    upper_bounds: Iterable = None

    def __post_init__(self):
        super().__post_init__()
        self.central_values = self.read_corrections()
        self.dependant_variable = self.info["dependant_variable"]
        self.unit = self.info["unit"]
        self.lower_bounds = self.info["min"]
        self.upper_bounds = self.info["max"]
        self.populate_uncertainties()
def read_corrections(self) -> np.ndarray:
    """
    Read correction values either directly from the config or from a table file.

    If ``info["corrections"]`` is a list or array, the values are used
    directly. If it is a dict, the table path is built via
    ``self.build_table_path()`` and the column ``self.table_key`` is read
    from the csv table.

    Returns:
        np.ndarray: An array of correction values.

    Raises:
        TypeError: If 'corrections' is neither a list/array nor a dict.
        KeyError: If 'corrections' or ``self.table_key`` is missing from the config.
        FileNotFoundError: If the table file does not exist at the constructed path.
    """
    correction_info = self.info["corrections"]
    if isinstance(correction_info, (list, np.ndarray)):
        logging.info("Loading correction values from config array.")
        return np.asarray(correction_info)
    if isinstance(correction_info, dict):
        logging.info("Loading correction values from config table.")
        table_path = self.build_table_path()
        table = read_csv(table_path)
        # Convert the table to a numpy array
        return np.asarray(table[self.table_key])
    # Previously any other type fell through and silently returned None.
    raise TypeError(
        f"'corrections' must be a list/array or a dict, got {type(correction_info)}"
    )
@property
def value_edges(self) -> np.ndarray:
    """Unique, sorted bin edges built from the lower and upper bounds."""
    all_bounds = np.concatenate((self.lower_bounds, self.upper_bounds))
    return np.unique(all_bounds)

@property
def value_mids(self) -> np.ndarray:
    """Midpoints of the value bins."""
    edges = self.value_edges
    return (edges[:-1] + edges[1:]) / 2

@property
def visual_labels(self) -> List[str]:
    """One human-readable range label per bin."""
    labels = []
    for low, up in zip(self.lower_bounds, self.upper_bounds):
        labels.append(f"{low} < {self.dependant_variable} < {up} {self.unit}")
    return labels
[docs]
def build_queries(self, prefix: str | None = None) -> list:
    """Build one dataframe query per bin, optionally prefixing the column name."""
    column = self._build_column_name(prefix, self.dependant_variable)
    bin_queries = []
    for low, up in zip(self.lower_bounds, self.upper_bounds):
        bin_queries.append(f"{low} <= {column} < {up}")
    return self.add_extra_cuts(bin_queries, prefix)
[docs]
@dataclass
class Correction2D(BaseCorrectionFromYaml):
    """Correction binned in two continuous variables."""

    # NOTE(review): redeclares the `uncertainties` field already defined on
    # BaseCorrection; harmless for dataclass inheritance but redundant.
    uncertainties: dict = field(default_factory=dict)

    def __post_init__(self):
        super().__post_init__()
        # Axis names and units are read from the yaml config.
        self.dependant_variable_1 = self.info["dependant_variable_1"]
        self.dependant_variable_2 = self.info["dependant_variable_2"]
        self.unit_1 = self.info["unit_1"]
        self.unit_2 = self.info["unit_2"]
        self.central_values = self._extract_central_values()
        self.populate_uncertainties()

    # Add an iterator to ensure that we'll loop over the corrections and bins
    # consistently
@property
def iterator(self):
    """Yield (row, column, [p_lo, p_hi], [theta_lo, theta_hi]) per table cell.

    Guarantees a single consistent iteration order over central values,
    errors, labels, and queries.
    """
    rows, columns, momenta, angles = [], [], [], []
    for i, column_name in enumerate(self.central_values_table.columns):
        for j, row_name in enumerate(self.central_values_table.index):
            # clean the pi0 tables.... What a format...
            if i == 0:
                # Only the first column carries this prefix in the raw table.
                column_name = column_name.replace(" row:p column:t ", "")
            # strip the strings and extract the momentum range
            # NOTE(review): momenta are divided by 10 — presumably a unit
            # conversion baked into the table format; confirm against the
            # table specification.
            column_name = column_name.replace("p=", "")
            ps = [float(x) / 10 for x in column_name.split("_")]
            # strip the strings and extract the theta range
            row_name = row_name.replace("cost=", "")
            ts = [float(x) for x in row_name.split("_")]
            rows.append(j)
            columns.append(i)
            momenta.append(ps)
            angles.append(ts)
    # Return a generator. Now can access all the central values and
    # errors using iloc and the rows/colums
    return zip(rows, columns, momenta, angles)
@cached_property
def central_values_table(self) -> DataFrame:
    # Nominal central-value table ("<table_name>_nom.<ext>").
    table_path = self.build_table_path("nom")
    return read_csv(table_path)

@cached_property
def stat_error_table(self) -> DataFrame:
    # Statistical error table ("<table_name>_stat.<ext>").
    table_path = self.build_table_path("stat")
    # Add column names when reading to skip creation of index column
    # NOTE(review): assumes exactly 8 momentum columns — TODO confirm this
    # holds for every table this class reads.
    return read_csv(table_path, names=[f"p bin {i}" for i in range(8)])

@cached_property
def sys_error_table(self) -> DataFrame:
    # Systematic error table ("<table_name>_sys.<ext>").
    table_path = self.build_table_path("sys")
    # Add column names when reading to skip creation of index column
    return read_csv(table_path, names=[f"p bin {i}" for i in range(8)])
@property
def visual_labels(self) -> List[str]:
    """Human-readable 2D bin label for every (momentum, angle) cell."""
    labels = []
    for _, _, momenta, angles in self.iterator:
        labels.append(
            f"{momenta[0]} <= {self.dependant_variable_1} < {momenta[1]} {self.unit_1} & {angles[0]} <= {self.dependant_variable_2} < {angles[1]} {self.unit_2}"
        )
    return labels

def _extract_central_values(self):
    """Central values read cell-by-cell in iterator order."""
    values = []
    for row, column, _, _ in self.iterator:
        values.append(self.central_values_table.iloc[row, column])
    return values

def _extract_errors(self, table: DataFrame):
    """Errors read cell-by-cell from ``table`` in iterator order."""
    return [table.iloc[row, column] for row, column, _, _ in self.iterator]
[docs]
def build_queries(self, prefix: str | None = None) -> list:
    """Build one 2D dataframe query per (momentum, angle) cell."""
    column_1 = self._build_column_name(prefix, self.dependant_variable_1)
    column_2 = self._build_column_name(prefix, self.dependant_variable_2)
    cell_queries = []
    for _, _, momenta, angles in self.iterator:
        cell_queries.append(
            f"{momenta[0]} <= {column_1} < {momenta[1]} & {angles[0]} <= {column_2} < {angles[1]}"
        )
    return self.add_extra_cuts(cell_queries, prefix)
[docs]
def populate_uncertainties(self):
    """Attach the stat and sys uncertainties declared in the yaml config."""
    uncertainty_types = get_uncertainty_types()
    for unc_name, unc_ctgy in self.info["error_correlations"].items():
        # Tables are read lazily: only the table for the declared error
        # name is touched.
        if unc_name == "stat":
            amplitudes = self._extract_errors(self.stat_error_table)
        elif unc_name == "sys":
            amplitudes = self._extract_errors(self.sys_error_table)
        else:
            raise NotImplementedError(
                "Only stat and sys error have been implemented now"
            )
        self.add_uncertainty(
            unc_name=unc_name,
            unc_values=amplitudes,
            unc_obj=uncertainty_types[unc_ctgy],
        )
[docs]
@dataclass
class Correction2DCategorical(BaseCorrectionFromYaml):
    """Correction binned in one categorical and one continuous variable."""

    categorical_variable: str | None = None
    continuus_variable: str | None = None
    central_values: Iterable = None
    continuus_edges: Iterable = None
    categorical_values: Iterable = None
    categorical_label: str | None = None

    def __post_init__(self):
        super().__post_init__()
        # Flatten the per-category correction lists into one value list.
        self.central_values = []
        part_dimensions = []
        for correction in self.info["corrections"]:
            self.central_values.extend(correction)
            # NOTE(review): part_dimensions is collected but never used.
            part_dimensions.append(len(correction))
        self.categorical_variable = self.info["categorical_variable"]
        self.categorical_values = self.info["categorical_values"]
        self.categorical_label = self.info["categorical_label"]
        self.continuus_variable = self.info["continuus_variable"]
        self.continuus_edges = self.info["continuus_edges"]
        self.extra_variables = self.info["extra_variables"]
        # NOTE(review): unlike the other corrections, populate_uncertainties()
        # is not called here — confirm whether that is intentional.
@property
def iterator(self):
    """Yield (category, (low, high)) for every category × continuous bin."""
    return itertools.product(
        self.categorical_values, zip(self.continuus_edges, self.continuus_edges[1:])
    )

@property
def strings(self) -> List[str]:
    """Human-readable label per (category, bin) pair."""
    return [
        f"{self.categorical_label}: {cv} [ {low} - {up} ]"
        for cv, (low, up) in self.iterator
    ]

@property
def queries(self):
    """Dataframe query per (category, bin) pair, with the extra cut appended."""
    # NOTE(review): _get_extra_cut() is not defined in this module's visible
    # scope (the base class defines _get_extra_cut_info) — TODO confirm it
    # exists, otherwise this property raises AttributeError.
    return [
        f"{self.categorical_variable} == {cv} & {low} <= {self.continuus_variable} < {up} & {self._get_extra_cut()}"
        for cv, (low, up) in self.iterator
    ]
[docs]
@dataclass
class CorrectionBF(BaseCorrectionFromYaml):
    """Branching-fraction scaling correction (PDG vs decay-file values)."""

    dependant_variable: str | None = None
    central_values: Iterable = None
    visual_labels: Iterable = None
    uncertainties: dict = field(default_factory=dict)

    def __post_init__(self):
        super().__post_init__()
        central_values, error_amplitudes = self._calculate_scaling_ratios()
        self.central_values = central_values
        # Visual labels needs to be defined before we populate the uncertainties
        # Otherwise the uncertainty object does not have visual_labels
        self.visual_labels = self._create_strings()
        # This class overrides populate_uncertainties to take the dynamically
        # computed error amplitudes as an argument.
        self.populate_uncertainties(error_amplitudes)
        self.dependant_variable = self.info["dependant_variable"]
def _create_strings(self) -> List[str]:
    """Build a LaTeX decay-string label for every decay mode.

    Returns:
        List[str]: One ``$mother \\rightarrow daughters$`` string per mode.
    """
    mother = Particle.from_pdgid(self.info["mother_particle"]).latex_name
    daughter_pdgs = [
        x for mode in self.info["modes"].values() for x in mode["daughters"]
    ]
    strings = []
    for daughter_set in daughter_pdgs:
        daughter_names = []
        for x in daughter_set:
            try:
                daughter_names.append(Particle.from_pdgid(x).latex_name)
            except Exception:
                # Fall back to the raw identifier when the PDG lookup fails.
                # (Was a bare `except:`, which would also swallow
                # KeyboardInterrupt/SystemExit.)
                daughter_names.append(x)
        strings.append(
            rf"${mother} \rightarrow {' '.join(str(x) for x in daughter_names)}$"
        )
    return strings
def _calculate_scaling_ratios(self):
    """Compute PDG / decay-file branching-fraction ratios with uncertainties.

    Returns:
        tuple[list, list]: Nominal correction values and their std devs.
    """
    # PDG values carry (value, error); decay-file values are taken as exact.
    pdg_BFs = unp.uarray(
        [x["pdg_live"][0] for x in self.info["modes"].values()],
        [x["pdg_live"][1] for x in self.info["modes"].values()],
    )
    decaydec_BFs = unp.uarray(
        [x["decay_dec"] for x in self.info["modes"].values()],
        [0 for x in self.info["modes"].values()],
    )
    # Safe 0 division. Returns 1+- 0 for the ones where
    # decay.dec = 0 or pdg = 0
    # NOTE(review): the comment above says 1+-0, but the `out` array is built
    # with std devs of 1 (ones_like) — masked entries are 1+-1. Confirm which
    # is intended.
    corrections = np.divide(
        pdg_BFs,
        decaydec_BFs,
        out=unp.uarray(np.ones_like(pdg_BFs), np.ones_like(pdg_BFs)),
        where=((decaydec_BFs != 0) & (pdg_BFs != 0)),
    )
    return list(unp.nominal_values(corrections)), list(unp.std_devs(corrections))
[docs]
def populate_uncertainties(self, error_amplitudes: list):
    """
    Override of the base-class method: the error amplitudes are calculated
    dynamically by ``_calculate_scaling_ratios`` instead of being read from
    the yaml configuration.
    """
    correlation = self.get_uncertainty_correlation()
    unc_types = get_uncertainty_types()
    self.add_uncertainty(
        unc_name="BF_unc",
        unc_values=error_amplitudes,
        unc_obj=unc_types[correlation],
    )
def get_uncertainty_correlation(self):
    """Return the correlation type declared in the yaml config.

    Returns:
        str: One of "fully_correlated", "uncorrelated",
        "fully_correlated_in_parts".

    Raises:
        NotImplementedError: For "explicitly_correlated" (not supported yet).
        ValueError: For any other correlation type.
    """
    correlation = self.info["correlation"]
    if correlation in (
        "fully_correlated",
        "uncorrelated",
        "fully_correlated_in_parts",
    ):
        return correlation
    if correlation == "explicitly_correlated":
        raise NotImplementedError(
            "Cannot support custom correlation for BF correction yet"
        )
    # Fixed message typo: "Unkown" -> "Unknown"
    raise ValueError(
        "Unknown correlation type. Available types are: fully_correlated, uncorrelated, fully_correlated_in_parts"
    )
[docs]
def build_queries(self, prefix: str | None = None) -> list:
    """Build one dataframe query per decay mode, matching on the dmID column."""
    column = self._build_column_name(prefix, self.dependant_variable)
    queries = []
    for mode in self.info["modes"].values():
        dm_id = mode["dmID"]
        if isinstance(dm_id, str):
            queries.append(f"{column} == '{dm_id}'")
        else:
            queries.append(f"{column} in {dm_id}")
    return self.add_extra_cuts(queries, prefix)
[docs]
@dataclass
class CustomCorrection(BaseCorrection):
    """Correction built directly from a user-supplied info dictionary."""

    info: InitVar[dict] = None
    dependant_variable: str | None = None
    central_values: Iterable = None
    uncertainties: dict = field(default_factory=dict)
    query_targets: Iterable = None

    def __post_init__(self, info):
        # The InitVar is stored back on the instance so the base-class helpers
        # (e.g. populate_uncertainties) can read self.info.
        self.info = info
        self.dependant_variable = self.info["dependant_variable"]
        self.central_values = self.info["central_values"]
        self.query_targets = self.info["query_targets"]
        self.unit = self.info["unit"]
        self.title = self.info["title"]
        self.cov_matrix = load_covariance_matrix(self.info)
        self.populate_uncertainties()
@property
def value_edges(self) -> np.ndarray:
    """Integer bin edges 0..N for the N correction values."""
    return np.arange(len(self.central_values) + 1)

@property
def value_mids(self) -> np.ndarray:
    """Midpoints of the integer bins."""
    edges = self.value_edges
    return (edges[1:] + edges[:-1]) / 2

@property
def visual_labels(self) -> List[str]:
    """One label per query target."""
    return [f"{self.dependant_variable} = {t}" for t in self.query_targets]
[docs]
def build_queries(self, prefix: str | None = None) -> list:
    """Build one equality query per target value of the dependant variable."""
    column = self._build_column_name(prefix, self.dependant_variable)
    queries = []
    for target in self.query_targets:
        queries.append(f"{column} == '{target}'")
    return queries
@dataclass
class CorrectionPID(BaseCorrectionFromYaml):
    """PID efficiency / fake-rate correction read from weight tables."""

    uncertainties: dict = field(default_factory=dict)

    def __post_init__(self):
        super().__post_init__()
        rate = self.info["rate"]
        self.check_valid_rate(rate)
        self.table = self.get_table(rate)
        self.central_values = self._extract_central_values()
        # Variable names used when building dataframe queries.
        self.p = self.info["momentum_variable"]
        self.theta = self.info["theta_variable"]
        self.PDG = self.info["PDG_variable"]
        self.mcPDG = self.info["mcPDG_variable"]
        self.momentum_unit = self.info["momentum_unit"]
        # Stat errors are treated as fully uncorrelated across bins (a
        # conservative choice); sys errors as fully correlated.
        # (Previously two copy-pasted blocks; folded into one loop.)
        for error_id, unc_cls in (
            ("stat", UncorrelatedUncertainty),
            ("sys", FullyCorrelatedUncertainty),
        ):
            name = f"{error_id} uncertainty"
            self.uncertainties[name] = unc_cls(
                name,
                self._extract_errors(f"{error_id}"),
                self.visual_labels,
            )
[docs]
@staticmethod
def check_valid_rate(rate):
    """Raise NotValidRateType unless ``rate`` is 'eff' or 'fake'."""
    valid_rates = ["eff", "fake"]
    if rate in valid_rates:
        return
    raise NotValidRateType(
        f"Valid rate arguments are {tuple(valid_rates)} but you passed {rate}"
    )
[docs]
def get_table(self, rate):
    """Assemble the efficiency or fake-rate table for this systematic.

    Args:
        rate: "eff" or "fake" (already validated by check_valid_rate).

    Returns:
        The concatenated table from the matching table finder(s).
    """
    table_finders = []
    # NOTE(review): if self.systematic contains none of eID/muID/kID/piID,
    # table_finders stays empty and the concat below fails — confirm that
    # upstream config validation guarantees one of these substrings.
    if "eID" in self.systematic:
        if self.systematic == "eID_K_fake":
            table_finders.append(
                CorrectionTableFinder.Kfake_electrons(self.info, self.MC_production)
            )
        elif self.systematic == "eID_pi_fake":
            table_finders.append(
                CorrectionTableFinder.pifake_electrons(
                    self.info, self.MC_production
                )
            )
        else:
            table_finders.append(
                CorrectionTableFinder.electrons(self.info, self.MC_production)
            )
    elif "muID" in self.systematic:
        if self.systematic == "muID_K_fake":
            table_finders.append(
                CorrectionTableFinder.Kfake_muons(self.info, self.MC_production)
            )
        elif self.systematic == "muID_pi_fake":
            table_finders.append(
                CorrectionTableFinder.pifake_muons(self.info, self.MC_production)
            )
        else:
            table_finders.append(
                CorrectionTableFinder.muons(self.info, self.MC_production)
            )
    elif "kID" in self.systematic:
        table_finders.append(
            CorrectionTableFinder.kaons(self.info, self.MC_production)
        )
    elif "piID" in self.systematic:
        table_finders.append(
            CorrectionTableFinder.pions(self.info, self.MC_production)
        )
    eff_table = concat([x.eff_table for x in table_finders])
    fake_rate_table = concat([x.fake_rate_table for x in table_finders])
    if rate == "eff":
        table = eff_table
        # PATCH
        # This has to be read somewhere else somehow
        self._true_pdg = table_finders[0].true_pdg
    elif rate == "fake":
        table = fake_rate_table
        # PATCH
        # This has to be read somewhere else somehow
        self._true_pdg = table_finders[0].fake_pdg
    return table
@property
def true_pdg(self) -> list:
    # PDG codes of the matching species for the selected rate (set in get_table).
    return self._true_pdg

@true_pdg.setter
def true_pdg(self, true_pdg):
    self._true_pdg = true_pdg

@property
def iterator(self):
    # Fresh row iterator over the weight table; each item is (index, Series).
    return self.table.iterrows()

@property
def queries(self):
    # PATCH
    # This just "implements" the property to satisfy the parent class
    # NOTE(review): returns None; callers must use build_queries() instead.
    pass
[docs]
def build_queries(
    self, prefix: str | None = None, extra_cut: str | None = None
) -> List[str]:
    """Build one query per table row, matching momentum, theta, and PDG codes.

    Args:
        prefix: Optional column-name prefix passed to _build_column_name.
        extra_cut: Kept for interface compatibility (extra cuts are applied
            via add_extra_cuts).

    Returns:
        List[str]: One dataframe query string per weight-table row.
    """
    # Column names are fixed for all rows, so build them once up front.
    p_col = self._build_column_name(prefix, self.p)
    theta_col = self._build_column_name(prefix, self.theta)
    pdg_col = self._build_column_name(prefix, self.PDG)
    mc_pdg_col = self._build_column_name(prefix, self.mcPDG)
    true_pdg = self._true_pdg
    queries = [
        (
            f"({row['p_min']} <= {p_col} < {row['p_max']} & "
            f"{row['theta_min']} <= {theta_col} < {row['theta_max']} & "
            f"{pdg_col} == {row['mcPDG']} & "
            f"{mc_pdg_col} in {true_pdg})"
        )
        for _, row in self.iterator
    ]
    return self.add_extra_cuts(queries, prefix)
@property
def visual_labels(self) -> List[str]:
    """Human-readable bin label per table row (momentum, theta, charge)."""
    labels = []
    for _, entry in self.iterator:
        labels.append(
            rf"{entry['p_min']} <= p < {entry['p_max']} {self.momentum_unit} & {entry['theta_min']} <= $\theta$ < {entry['theta_max']} & q = {entry['charge']}"
        )
    return labels

def _extract_central_values(self):
    """Data/MC ratio per table row."""
    return [entry["data_MC_ratio"] for _, entry in self.iterator]

def _extract_errors(self, error_type):
    # this assumes symmetric errors and takes the maximum out of the two.
    columns = [
        f"data_MC_uncertainty_{error_type}_up",
        f"data_MC_uncertainty_{error_type}_dn",
    ]
    return [entry[columns].max() for _, entry in self.iterator]
# #######################################################################################
def create_correction_object(
    syst_effect: str | dict | None, MC_prod: str
) -> BaseCorrection:
    """Retrieves amd creates the appropriate correction object based on the systematic effect and MC production type.
    Args:
        syst_effect (str | dict): The systematic effect identifier, or an info
            dictionary for a custom correction.
        MC_prod (str): The Monte Carlo production type identifier.
    Returns:
        BaseCorrection: The appropriate correction object based on the provided systematic effect and MC production type.
    Raises:
        NotImplementedError: If the correction type specified in the configuration is not implemented.
        ValueError: If syst_effect is neither a string nor a dictionary.
    Example:
        >>> correction = create_correction_object("syst1", "MC1")
        >>> isinstance(correction, BaseCorrection)
        True
    """
    correction_types = {
        "1D": Correction1D,
        "2D": Correction2D,
        "2DCategorical": Correction2DCategorical,
        "BF": CorrectionBF,
        "PID": CorrectionPID,
    }
    if isinstance(syst_effect, str):
        corr_type = read_yaml(syst_effect, MC_prod)["correction_type"]
        # Look up the class first so a KeyError raised *inside* the
        # constructor is not misreported as an unknown correction type.
        try:
            correction_cls = correction_types[corr_type]
        except KeyError:
            raise NotImplementedError(
                f"Available corrections are: {list(correction_types.keys())} but you passed {corr_type}"
            )
        return correction_cls(systematic=syst_effect, MC_production=MC_prod)
    elif isinstance(syst_effect, dict):
        return CustomCorrection(info=syst_effect)
    else:
        raise ValueError(
            "Pass a string for existing standard systematic to create a correction object from yaml files or a dictionary to create a custom correction object"
        )
[docs]
class CorrectionTableFinder:
    """
    Factory method class to get correction tables for kaons, pions, electrons and muons
    """

    def __init__(
        self, particle_species, online_cut, base_table_path, variable, MC_production
    ):
        # Species identifier, e.g. "kaon", "elec", "Kfake_muon" — see
        # particle_species_settings for the full list.
        self.particle_species = particle_species
        # Online PID selection string; parsed by get_cut_value/get_cut_type.
        self.online_cut = online_cut
        self.base_table_path = base_table_path
        self.variable = variable
        self.MC_production = MC_production
        # Adjust hadron table ids for the chosen MC production first.
        self.production_table_ids()
        self.true_pdg = self.particle_species_settings[particle_species]["true_pdgs"]
        self.fake_pdg = self.particle_species_settings[particle_species]["fake_pdgs"]
        self.value = self.get_cut_value()
        self.cut_type = self.get_cut_type()
        efficiency_table_names = self.build_table_name(
            self.particle_species_settings[self.particle_species]["eff_table_ids"]
        )
        fake_rate_table_names = self.build_table_name(
            self.particle_species_settings[self.particle_species]["fake_rate_table_ids"]
        )
        self.eff_table = self.get_table(efficiency_table_names)
        self.fake_rate_table = self.get_table(fake_rate_table_names)
@classmethod
def _from_info(cls, particle_species, external_info, MC_production, with_variable=True):
    """Shared factory: construct a finder for ``particle_species``.

    ``with_variable`` is False for hadrons (kaon/pion), whose finders are
    built with ``variable=None``.
    """
    return cls(
        particle_species=particle_species,
        online_cut=external_info["online_cut"],
        base_table_path=external_info["table_paths"],
        variable=external_info["variable"] if with_variable else None,
        MC_production=MC_production,
    )

@classmethod
def kaons(cls, external_info, MC_production):
    """Finder for kaon efficiency / pion-fake tables."""
    return cls._from_info("kaon", external_info, MC_production, with_variable=False)

@classmethod
def pions(cls, external_info, MC_production):
    """Finder for pion efficiency / kaon-fake tables."""
    return cls._from_info("pion", external_info, MC_production, with_variable=False)

@classmethod
def electrons(cls, external_info, MC_production):
    """Finder for electron efficiency / hadron-fake tables."""
    return cls._from_info("elec", external_info, MC_production)

@classmethod
def Kfake_electrons(cls, external_info, MC_production):
    """Finder for kaon-faking-electron tables."""
    return cls._from_info("Kfake_elec", external_info, MC_production)

@classmethod
def pifake_electrons(cls, external_info, MC_production):
    """Finder for pion-faking-electron tables."""
    return cls._from_info("pifake_elec", external_info, MC_production)

@classmethod
def muons(cls, external_info, MC_production):
    """Finder for muon efficiency / hadron-fake tables."""
    return cls._from_info("muon", external_info, MC_production)

@classmethod
def Kfake_muons(cls, external_info, MC_production):
    """Finder for kaon-faking-muon tables."""
    return cls._from_info("Kfake_muon", external_info, MC_production)

@classmethod
def pifake_muons(cls, external_info, MC_production):
    """Finder for pion-faking-muon tables."""
    return cls._from_info("pifake_muon", external_info, MC_production)
[docs]
def production_table_ids(self):
    """
    Update the hadron (kaon/pion) table IDs for the configured MC production.

    Raises:
        ValueError: If an unsupported MC production type is provided.
    """
    if self.MC_production not in ["MC15ri", "MC15rd"]:
        raise ValueError("Invalid production type. Must be 'MC15ri' or 'MC15rd'.")
    eff_ids = {
        "MC15ri": {"kaon": "keff", "pion": "pieff"},
        "MC15rd": {"kaon": "KEff", "pion": "piEff"},
    }
    fake_ids = {
        "MC15ri": {"kaon": "piFk", "pion": "kFpi"},
        "MC15rd": {"kaon": "piFakeK", "pion": "KFakepi"},
    }
    # Only hadron species have production-dependent table ids.
    if self.particle_species in ("kaon", "pion"):
        settings = self.particle_species_settings[self.particle_species]
        settings["eff_table_ids"] = [
            eff_ids[self.MC_production][self.particle_species]
        ]
        settings["fake_rate_table_ids"] = [
            fake_ids[self.MC_production][self.particle_species]
        ]
@property
def particle_species_settings(self) -> dict:
    """Per-species registry of true/fake PDG codes and table id defaults."""
    # Built lazily and cached on the instance; production_table_ids()
    # mutates the hadron entries in place for the selected MC production.
    if not hasattr(self, "_particle_species_settings"):
        self._particle_species_settings = {
            "kaon": {
                "true_pdgs": [321, -321],
                "fake_pdgs": [211, -211],
                "eff_table_ids": ["keff"],
                "fake_rate_table_ids": ["piFk"],
            },
            "pion": {
                "true_pdgs": [211, -211],
                "fake_pdgs": [321, -321],
                "eff_table_ids": ["pieff"],
                "fake_rate_table_ids": ["kFpi"],
            },
            "elec": {
                "true_pdgs": [11, -11],
                "fake_pdgs": [321, 211, -321, -211],
                "eff_table_ids": ["e_efficiency"],
                "fake_rate_table_ids": [
                    "K_e_fakeRate",
                    "pi_e_fakeRate",
                ],
            },
            "Kfake_elec": {
                "true_pdgs": [11, -11],
                "fake_pdgs": [
                    321,
                    -321,
                ],
                "eff_table_ids": ["e_efficiency"],
                "fake_rate_table_ids": ["K_e_fakeRate"],
            },
            "pifake_elec": {
                "true_pdgs": [11, -11],
                "fake_pdgs": [
                    211,
                    -211,
                ],
                "eff_table_ids": ["e_efficiency"],
                "fake_rate_table_ids": ["pi_e_fakeRate"],
            },
            "muon": {
                "true_pdgs": [13, -13],
                "fake_pdgs": [321, 211, -321, -211],
                "eff_table_ids": ["mu_efficiency"],
                "fake_rate_table_ids": [
                    "K_mu_fakeRate",
                    "pi_mu_fakeRate",
                ],
            },
            "Kfake_muon": {
                "true_pdgs": [13, -13],
                "fake_pdgs": [321, -321],
                "eff_table_ids": ["mu_efficiency"],
                "fake_rate_table_ids": ["K_mu_fakeRate"],
            },
            "pifake_muon": {
                "true_pdgs": [13, -13],
                "fake_pdgs": [
                    211,
                    -211,
                ],
                "eff_table_ids": ["mu_efficiency"],
                "fake_rate_table_ids": ["pi_mu_fakeRate"],
            },
        }
    return self._particle_species_settings
[docs]
def get_cut_type(
    self,
) -> str:
    """Determine the PID cut type applied during online reconstruction.

    For lepton species the online cut string itself is used as the working
    point; for hadrons ("kaon"/"pion") the cut string is mapped to "B"
    (binary PID) or "G" (global PID).

    Returns:
        The cut type: the online cut string for leptons, "B"/"G" for
        hadrons, or ``None`` when the hadron cut is neither binary nor
        global (a warning is logged in that case).

    Raises:
        ValueError: If the particle species is neither leptonic nor hadronic.
    """
    # Read the online selections that have been applied on the online reconstruction
    # TODO update this to the config file of each experiment to avoid making the
    # mistake of changing the value during the offline preprocessing
    if "elec" in self.particle_species or "muon" in self.particle_species:
        cut_type = self.online_cut
    elif self.particle_species in ["kaon", "pion"]:
        # Check for the binary working point first: "binaryPID" also
        # contains the substring "ID" matched by the global branch.
        if "binaryPID" in self.online_cut:
            cut_type = "B"
        elif "ID" in self.online_cut:
            cut_type = "G"
        else:
            logging.warning(
                "Cut type, neither global, nor binary HID selection has been applied online"
            )
            # BUG FIX: this path previously left ``cut_type`` unbound and
            # the return below raised UnboundLocalError.
            cut_type = None
    else:
        raise ValueError("Wrong particle species")
    return cut_type
[docs]
def get_cut_value(self) -> str:
    """Extract the numeric cut value from the online cut string.

    For leptons the value is the last character of the cut string
    (e.g. ``"..._9" -> "9"``); for hadrons it is the last three characters
    (e.g. ``"..._0.5" -> "0.5"``).

    Returns:
        The cut-value substring of ``self.online_cut``.

    Raises:
        ValueError: If the particle species is unknown.
    """
    # TODO update this to the config file of each experiment to avoid making the
    # mistake of changing the value during the offline preprocessing
    if self.particle_species in ["muon", "elec"]:
        return self.online_cut[-1]
    if self.particle_species in ["kaon", "pion"]:
        # BUG FIX: the hadron value was previously computed but discarded —
        # the method always returned ``self.online_cut[-1]``.
        return self.online_cut[-3:]
    raise ValueError("Wrong particle species")
[docs]
def build_table_name(
    self,
    table_ids: Iterable[str],
) -> list:
    """Build the efficiency / fake-rate table file paths.

    Args:
        table_ids: Efficiency or fake-rate table ids.  (Annotation fixed:
            this was annotated ``str`` but is iterated as a collection.)

    Returns:
        List of full paths under ``self.base_table_path``.  For hadrons one
        file per id and charge ("p" then "m") is produced; for leptons one
        ``<id>_table.csv`` per id.

    Raises:
        ValueError: If the particle species is unknown (previously this
            left ``file_names`` unbound and crashed with UnboundLocalError).
    """
    if self.particle_species in ["kaon", "pion"]:
        # Hadron tables are split by charge: positive first, then negative.
        file_names = [self.build_hid_table_name(x, "p") for x in table_ids]
        file_names.extend(self.build_hid_table_name(x, "m") for x in table_ids)
    elif "elec" in self.particle_species or "muon" in self.particle_species:
        file_names = ["_".join((x, "table.csv")) for x in table_ids]
    else:
        raise ValueError(f"Unknown particle species: {self.particle_species}")
    return [path.join(self.base_table_path, x) for x in file_names]
[docs]
def build_hid_table_name(self, table_id, charge):
    """Build a single Hadron-ID table file name for the configured production.

    Args:
        table_id: Table identifier, e.g. ``"keff"`` or ``"piFk"``.
        charge: Charge tag, ``"p"`` or ``"m"`` (only used in MC15ri names;
            MC15rd tables are charge-inclusive, tagged ``"all"``).

    Returns:
        The table file name.

    Raises:
        ValueError: If ``self.MC_production`` is neither ``"MC15ri"`` nor
            ``"MC15rd"`` (previously ``table_name`` was left unbound).
    """
    # Working point encoded as e.g. "B0-5": cut type plus the last digit of
    # the cut value (str(0.5)[-1] == "5").
    working_point = self.cut_type + "0-" + str(self.value)[-1]
    if self.MC_production == "MC15ri":
        return "_".join(("Rdtmc", table_id, charge, working_point, "all.log"))
    if self.MC_production == "MC15rd":
        return "_".join(("MC15rd", table_id, "all", working_point + ".log"))
    raise ValueError(f"Unknown MC production: {self.MC_production}")
[docs]
def get_table(self, table_names):
    """Read and concatenate PID performance tables into one dataframe.

    Args:
        table_names: Paths of the CSV tables to read and concatenate.

    Returns:
        The combined dataframe.  Hadron tables are normalised to the
        PIDvar-compatible format; lepton tables are filtered with the LID
        query string and gain an ``mcPDG`` column.
    """
    if self.particle_species in ["kaon", "pion"]:
        table = concat([read_csv(x) for x in table_names])
        # Normalises charge/theta columns; the return value is ignored
        # because ``table`` is modified in place.
        self.make_pidvar_compatible(self.MC_production, table, max_uncertainty=10)
    elif "elec" in self.particle_species or "muon" in self.particle_species:
        # elif self.particle_species in ["elec", "muon"]:
        table = concat([read_csv(x) for x in table_names])
        # Keep only rows matching the configured working point / variable.
        table.query(self.get_lid_queries(), inplace=True)
        self.add_mcPDG_to_table(table)
    # NOTE(review): an unknown species leaves ``table`` unbound and this
    # return raises UnboundLocalError — confirm callers guard the species.
    return table
[docs]
def add_mcPDG_to_table(self, table):
    """Attach an ``mcPDG`` column to a lepton-ID table, in place.

    Rows with charge "-" receive ``self.true_pdg[0]`` and rows with
    charge "+" receive ``self.true_pdg[1]``; any other charge value keeps
    the ``-9999`` sentinel.
    """
    # Initialise every row with the sentinel, then overwrite per charge.
    table["mcPDG"] = -9999
    is_minus = table["charge"] == "-"
    is_plus = table["charge"] == "+"
    table.loc[is_minus, "mcPDG"] = self.true_pdg[0]
    table.loc[is_plus, "mcPDG"] = self.true_pdg[1]
[docs]
@staticmethod
def make_pidvar_compatible(
MC_production: str, table: DataFrame, max_uncertainty: Optional[float] = 1e2
):
"""
Convert the pandas dataframes obtained Hadron ID CSV tables via into a
format consistent with the format of the lepton ID tables which ``PIDvar``
understands.
In particular, convert the ``charge`` column from ``1``/``-1`` integer
entries to ``+``/``-`` string entries and calculate the
``theta_min``/``theta_max`` columns.
:param table: Pandas dataframe obtained from ``pandas.read_csv`` on HID table
:param inplace: Whether to modify the existing dataframe in place.
Otherwise, a copy of the existing dataframe will be returned.
:param max_uncertainty: Drop rows in HID tables where any of the data-MC
uncertainties (sys/stat up/down) exceed this value. Rationale: The HID
tables contain rows with nonsense uncertainties > 10⁸, so it is meant to
remove those entries. Therefore, the exact value is not important. Set
to ``None`` to disable dropping any columns.
:return: Modified dataframe that can be used by ``PIDvar``.
"""
# Some checks that table has expected format of Hadron ID tables
if MC_production == "MC15ri":
if not set(table["charge"]).issubset({1, -1}):
raise ValueError(
"Expected that the ``charge`` entries of the original Hadron ID dataframe consists"
+ "only of ``1`` and ``-1``, but it contains {}".format(
set(table["charge"])
)
)
# if "theta_min" in table or "theta_max" in table:
# raise ValueError("Dataframe already has ``theta_…`` columns")
if max_uncertainty is not None:
unc_cols = [
"data_MC_uncertainty_stat_up",
"data_MC_uncertainty_stat_dn",
"data_MC_uncertainty_sys_up",
"data_MC_uncertainty_sys_dn",
]
# table = table[table[unc_cols].max(axis=1) <= max_uncertainty]
# for θ in [0, π], cos(θ) is strictly decreasing, so we have invert min and max when inverting the cosine
if MC_production == "MC15ri":
table["theta_min"] = -9999
table["theta_max"] = -9999
table.loc[:, "theta_min"] = np.arccos(table["cos_max"].copy(deep=True))
table.loc[:, "theta_max"] = np.arccos(table["cos_min"].copy(deep=True))
elif MC_production == "MC15rd":
pass
# PIDvar expects charge columns to contain + or -
if MC_production == "MC15ri":
table.loc[:, "charge"] = np.where(table["charge"] == +1, "+", "-")
elif MC_production == "MC15rd":
table.loc[:, "charge"] = np.where(table["charge_min"] == -2, "-", "+")
return table
[docs]
def get_lid_queries(self):
    """Build the pandas query string used to filter a lepton ID table.

    The combined query keeps rows that match the configured working point
    and variable, are flagged best-available, lie outside known
    problematic coarse bins, and have a data/MC ratio inside an
    arbitrarily chosen "physical" range.

    Returns:
        str: The query clauses joined with ``" and "``.

    Raises:
        ValueError: If the species is neither an electron nor a muon type
            (previously ``exclude_bins`` was left unbound and the join
            below raised UnboundLocalError).
    """
    working_point = f"(working_point == '{self.cut_type}')"
    best_available = "(is_best_available == True)"
    if "elec" in self.particle_species:
        exclude_bins = "(not ((theta_min == 0.56 and theta_max == 2.23) or (theta_min == 0.22 and theta_max == 2.71) or (p_min == 0.2 and p_max == 7) or (p_min == 0.2 and p_max == 5)))"
    elif "muon" in self.particle_species:
        exclude_bins = "(not ((theta_min == 0.82 and theta_max == 2.22) or (theta_min == 0.4 and theta_max == 0.82) or (theta_min == 0.4 and theta_max == 2.6) or (p_min == 0.2 and p_max == 5)))"
    else:
        raise ValueError(
            f"LID queries are only defined for lepton species, got '{self.particle_species}'"
        )
    variable = f"(variable == '{self.variable}')"
    # PATCH exclude negative and extremely large weights
    physical_values = "(0 < data_MC_ratio < 10)"
    logging.warning(
        f"If negative weights or extremely large weights are present in the table, these will be excluded. The arbitrarily selected 'physical range' is {physical_values}"
    )
    total_cutstring = " and ".join(
        (working_point, best_available, exclude_bins, variable, physical_values)
    )
    logging.info(
        f"The following cutstring has been applied to the provide LID table: {total_cutstring}"
    )
    return total_cutstring