from __future__ import annotations
from functools import cached_property
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, InitVar
from typing import List, Iterable, Optional, Any
from os import path
from pathlib import Path
import matplotlib.pyplot as plt
from particle import Particle
import itertools
import numpy as np
from pandas import DataFrame, concat, read_csv
import pandas as pd
from uncertainties import unumpy as unp, ufloat
from .uncertainties import (
Uncertainty,
FullyCorrelatedUncertainty,
FullyCorrelatedUncertaintyInParts,
UncorrelatedUncertainty,
ExplicitlyCorrelatedUncertainty,
get_uncertainty_types,
)
from sysvar.utils import SavableAttributesObject, read_yaml, load_covariance_matrix
from sysvar.visualize import CorrectionVisualizer, UncertaintyVisualizer
import logging
# Configure the root logger once, when this module is imported.
# NOTE(review): configuring logging from a library module affects every
# importer; consider moving this to the application entry point.
logging.basicConfig(
    level=logging.INFO,
    format="%(levelname)s : %(funcName)s: %(lineno)d : %(message)s",
)
class UncertaintyWithSameNameExists(Exception):
    """Raised when an uncertainty is added under a name that is already registered."""
class UnkownUncertaintyType(Exception):
    """Raised when a declared uncertainty category is not a known uncertainty type.

    (The misspelled class name is kept for backward compatibility.)
    """
class NotValidRateType(Exception):
    """Raised for an invalid rate type.

    NOTE(review): not raised anywhere in the visible part of this module;
    confirm where it is used.
    """
class InvalidCorrectionTableKey(Exception):
    """Raised when a required column is missing from a correction table."""
class CorrectionTablehasNaNValues(Exception):
    """Signals that a correction table contains NaN (or otherwise non-finite) values."""
class InvalidInfoDict(Exception):
    """Raised when the ``info`` configuration dictionary lacks a required key
    or holds an invalid ``cov_matrix`` entry."""


class MissingInformationError(Exception):
    """Raised when a correction is constructed without the inputs it needs
    (e.g. no systematic/MC production for YAML input, or no CSV path).

    Defined here because the correction classes raise it but this module
    neither defined nor imported it, which turned every such error into a
    ``NameError``.
    """
@dataclass()
class BaseCorrection(ABC, SavableAttributesObject):
    """Abstract base class for binned corrections carrying named uncertainties.

    Subclasses must implement ``visual_labels`` and ``build_queries``. The
    shared helpers below also read attributes that subclasses are expected to
    set in their ``__post_init__``:

    * ``central_values``: per-bin correction values (used by ``N``).
    * ``info``: configuration dictionary (validated by ``_is_valid_info_dict``
      and read by ``populate_uncertainties``).
    * ``cov_matrix``: optional explicit covariance matrix, or ``None``.
    """

    # Registered uncertainty objects, keyed by uncertainty name.
    uncertainties: dict = field(default_factory=dict)

    @property
    @abstractmethod
    def visual_labels(self):
        """Per-bin labels handed to every uncertainty created by ``add_uncertainty``."""

    @abstractmethod
    def build_queries(self) -> list:
        """Return the selection query strings, one per bin."""

    @staticmethod
    def _build_column_name(prefix: str | None, variable: str) -> str:
        """Construct a column name from an optional prefix and a variable name.

        Args:
            prefix (str | None): String prepended to ``variable`` with an
                underscore, or ``None`` for no prefix.
            variable (str): The variable name.

        Returns:
            str: ``"prefix_variable"``, or just ``"variable"`` if ``prefix`` is ``None``.

        Raises:
            ValueError: If ``variable`` is not a str, or ``prefix`` is neither
                ``None`` nor a str.

        Examples:
            >>> BaseCorrection._build_column_name("prefix", "variable")
            'prefix_variable'
            >>> BaseCorrection._build_column_name(None, "variable")
            'variable'
        """
        if not isinstance(variable, str):
            raise ValueError(
                f"{variable} is expected to be str but you passed {type(variable)}"
            )
        if prefix is None:
            return variable
        if not isinstance(prefix, str):
            raise ValueError(
                f"{prefix} is expected to be str but you passed {type(prefix)}"
            )
        return "_".join([prefix, variable])

    @property
    def N(self) -> int:
        """Number of bins, i.e. ``len(self.central_values)``."""
        return len(self.central_values)

    @property
    def total_error(self) -> np.ndarray:
        """Per-bin quadrature sum of the ``errors`` of all registered uncertainties.

        With exactly one uncertainty, its ``errors`` are returned unchanged.

        Raises:
            ValueError: If no uncertainty has been added yet.
        """
        if not self.uncertainties:
            raise ValueError("No uncertainties have been added to the correction")
        if len(self.uncertainties) == 1:
            return next(iter(self.uncertainties.values())).errors
        return np.sqrt(
            np.sum(
                [np.power(unc.errors, 2) for unc in self.uncertainties.values()],
                axis=0,
            )
        )

    def _is_valid_info_dict(self) -> None:
        """Validate ``self.info``: it must hold a valid ``cov_matrix`` entry.

        Raises:
            InvalidInfoDict: If the key is missing or its value is invalid.
        """
        self._require_key("cov_matrix")
        self._validate_cov_matrix_value(self.info["cov_matrix"])

    # ----------------------------
    # Helpers (instance methods)
    # ----------------------------
    def _require_key(self, key: str) -> None:
        """Raise ``InvalidInfoDict`` if ``key`` is not present in ``self.info``."""
        if key not in self.info:
            raise InvalidInfoDict(f"The info dictionary must contain a '{key}' key")

    def _validate_cov_matrix_value(self, cov: Any) -> None:
        """Validate a cov_matrix value: ``None``, a path, or an array-like.

        Raises:
            InvalidInfoDict: If ``cov`` is of another type or fails the
                path/array checks.
        """
        if cov is None:
            return
        if isinstance(cov, (str, Path)):
            self._validate_cov_matrix_path(cov)
            return
        if isinstance(cov, (list, tuple, np.ndarray)):
            self._validate_cov_matrix_arraylike(cov)
            return
        raise InvalidInfoDict(
            f"'cov_matrix' must be None, a path (str/Path), or an array-like "
            f"(list/tuple/np.ndarray). Got: {type(cov)}"
        )

    def _validate_cov_matrix_path(self, path_value: str | Path) -> None:
        """Check that the covariance-matrix file exists and has a supported suffix.

        Raises:
            InvalidInfoDict: If the file is missing or not ``.npy``/``.tsv``.
        """
        p = Path(path_value)
        if not p.exists():
            raise InvalidInfoDict(f"'cov_matrix' path does not exist: {p}")
        allowed_suffixes = {
            ".npy",
            ".tsv",
        }
        if p.suffix not in allowed_suffixes:
            raise InvalidInfoDict(
                f"Unsupported cov_matrix file type '{p.suffix}'. "
                f"Supported: {', '.join(sorted(allowed_suffixes))}"
            )

    def _validate_cov_matrix_arraylike(self, cov: Any) -> None:
        """Check that ``cov`` is a finite, non-empty, square, symmetric 2D matrix.

        Raises:
            InvalidInfoDict: If any check fails, including when ``cov`` cannot
                be converted to a float array (e.g. ragged rows, non-numeric entries).
        """
        try:
            arr = np.asarray(cov, dtype=float)
        except (TypeError, ValueError) as err:
            # Ragged nested lists / non-numeric entries: report as an invalid
            # config instead of leaking numpy's conversion error.
            raise InvalidInfoDict(
                f"'cov_matrix' could not be converted to a numeric array: {err}"
            ) from err
        if arr.ndim != 2:
            raise InvalidInfoDict("'cov_matrix' must be a 2D array-like (NxN).")
        n, m = arr.shape
        if n == 0 or m == 0:
            raise InvalidInfoDict("'cov_matrix' cannot be empty.")
        if n != m:
            raise InvalidInfoDict(
                f"'cov_matrix' must be square (NxN). Got shape {arr.shape}."
            )
        if not np.isfinite(arr).all():
            raise InvalidInfoDict("'cov_matrix' contains NaN or inf values.")
        # Symmetry is enforced with a tight absolute tolerance (rtol=0, atol=1e-12).
        if not np.allclose(arr, arr.T, rtol=0, atol=1e-12):
            raise InvalidInfoDict("'cov_matrix' must be symmetric.")

    def add_uncertainty(
        self,
        unc_name,
        unc_values,
        unc_obj: type[Uncertainty],
        explicit_cov_matrix: Optional[np.ndarray] = None,
    ) -> None:
        """Instantiate an uncertainty class and register it under ``unc_name``.

        The object is built as ``unc_obj(unc_name, unc_values, self.visual_labels)``,
        with ``explicit_cov_matrix`` appended as a fourth argument when given.

        Args:
            unc_name: Unique name of the uncertainty.
            unc_values: Per-bin values passed to the uncertainty constructor.
            unc_obj: The uncertainty class (not an instance) to instantiate.
            explicit_cov_matrix: Optional covariance matrix forwarded to the
                constructor (used for ``ExplicitlyCorrelatedUncertainty``).

        Raises:
            UncertaintyWithSameNameExists: If an uncertainty with the same name
                has already been added to the correction.
        """
        if unc_name in self.uncertainties:
            raise UncertaintyWithSameNameExists(
                f"An uncertainty with the name {unc_name} already exists in the set of uncertainties that have been added to the Correction. Make sure that you add a specific uncertainty only once, and that there are no duplicate names"
            )
        args = [unc_name, unc_values, self.visual_labels]
        # Decide on the argument actually passed: a correction that has a
        # cov_matrix may still add ordinary uncertainties (e.g. Correction2D),
        # and those must not receive a spurious fourth ``None`` argument.
        if explicit_cov_matrix is not None:
            args.append(explicit_cov_matrix)
        self.uncertainties[unc_name] = unc_obj(*args)

    def populate_uncertainties(self):
        """Fill ``self.uncertainties`` from ``self.cov_matrix`` or ``self.info``.

        If ``self.cov_matrix`` is set, a single ``ExplicitlyCorrelatedUncertainty``
        named ``"explicit_covariance"`` is added, with errors taken as the square
        root of the matrix diagonal. Otherwise ``self.info["uncertainties"]``
        (mapping category -> {name: values}) is walked and one uncertainty per
        entry is added, using the class registered for its category in
        ``get_uncertainty_types()``.

        Raises:
            UnkownUncertaintyType: If a category is not a known uncertainty type.
            UncertaintyWithSameNameExists: If a name is used twice.
        """
        if self.cov_matrix is not None:
            errors = np.sqrt(np.diag(self.cov_matrix))  # errors from the diagonal
            self.add_uncertainty(
                unc_name="explicit_covariance",
                unc_values=errors,
                unc_obj=ExplicitlyCorrelatedUncertainty,
                explicit_cov_matrix=self.cov_matrix,
            )
            return

        sysvar_uncertainties = get_uncertainty_types()
        for unc_ctgy, uncertainty_dictionary in self.info["uncertainties"].items():
            if unc_ctgy not in sysvar_uncertainties.keys():
                raise UnkownUncertaintyType(
                    f"Unknown type of uncertainty is declared in the yaml file. Available uncertainty types are: ({', '.join(sysvar_uncertainties.keys())}) but '{unc_ctgy}' was found in the yaml file"
                )
            for unc_name, unc_values in uncertainty_dictionary.items():
                self.add_uncertainty(
                    unc_name=unc_name,
                    unc_values=unc_values,
                    unc_obj=sysvar_uncertainties[unc_ctgy],
                )

    def plot_error_comparison(
        self, save: bool = False, filename: str = ""
    ) -> tuple[plt.Figure, plt.Axes]:
        """Plot the registered uncertainties via ``CorrectionVisualizer``.

        The visualizer is stored in ``self.visualizer``.

        Returns:
            The ``(figure, axes)`` pair returned by the visualizer.
        """
        self.visualizer = CorrectionVisualizer(self)
        fig, ax = self.visualizer.plot_error_comparison(save=save, filename=filename)
        return fig, ax

    def plot_uncertainty(
        self, unc_name: str | None = None, save: bool = False, filename: str = ""
    ):
        """Plot covariance/correlation for one uncertainty, or for all if ``unc_name`` is None.

        Raises:
            KeyError: If ``unc_name`` is given but not registered.
        """
        selected = (
            list(self.uncertainties.values())
            if unc_name is None
            else [self.uncertainties[unc_name]]
        )
        for unc in selected:
            self.visualizer = UncertaintyVisualizer(unc)
            self.visualizer.plot_cov_and_corr(save=save, filename=filename)
@dataclass()
class BaseCorrectionFromYaml(BaseCorrection):
    """Base class for corrections configured through a YAML file.

    The configuration is loaded with ``read_yaml(systematic, MC_production)``
    into ``self.info``. Keys read here: ``cov_matrix`` (required), ``title``,
    ``extra_cuts`` and, for table-based corrections, ``corrections`` with the
    sub-keys ``table_dir``, ``table_name``, ``table_ext`` and ``table_key``.
    """

    # Systematic effect; first argument of ``read_yaml``.
    systematic: str | None = None
    # MC production identifier; second argument of ``read_yaml``.
    MC_production: str | None = None

    def __post_init__(self):
        """Load and validate the YAML configuration.

        Raises:
            MissingInformationError: If ``read_yaml`` raises ``TypeError``
                (presumably when ``systematic``/``MC_production`` are missing).
            InvalidInfoDict: If the configuration has no valid ``cov_matrix`` entry.
        """
        # NOTE(review): this re-runs BaseCorrection's dataclass __init__, which
        # resets ``uncertainties`` to a fresh empty dict.
        super().__init__()
        try:
            self.info = read_yaml(self.systematic, self.MC_production)
        except TypeError as err:
            raise MissingInformationError(
                f"Need to specify the systematic effect and the MC production in the positional arguments. You passed {self.systematic} and {self.MC_production}"
            ) from err
        self._is_valid_info_dict()
        self.visualizer = None
        self.title = self.info["title"]
        self.cov_matrix = load_covariance_matrix(self.info)

    @staticmethod
    def _extend_queries_with_extra_cut(queries: list, extra_cut: str) -> list:
        """Append ``& extra_cut`` to every query.

        Args:
            queries (list): Query strings.
            extra_cut (str): Condition appended to each query.

        Returns:
            list: New list of extended query strings.

        Example:
            >>> BaseCorrectionFromYaml._extend_queries_with_extra_cut(['query1', 'query2'], 'extra_condition')
            ['query1 & extra_condition', 'query2 & extra_condition']
        """
        return [" & ".join([q, extra_cut]) for q in queries]

    def _get_extra_cut_info(self):
        """Return the ``extra_cuts`` entry of the configuration."""
        return self.info["extra_cuts"]

    @property
    def table_dir(self) -> str:
        """Directory containing the correction tables (``corrections.table_dir``)."""
        return self.info["corrections"]["table_dir"]

    @property
    def table_name(self) -> str:
        """Base name of the correction table (``corrections.table_name``)."""
        return self.info["corrections"]["table_name"]

    @property
    def table_ext(self) -> str:
        """Extension of the correction table (``corrections.table_ext``)."""
        return self.info["corrections"]["table_ext"]

    @property
    def table_key(self) -> str:
        """Column of the table holding the correction values (``corrections.table_key``)."""
        return self.info["corrections"]["table_key"]

    def build_table_path(self, suffix: Optional[str] = None) -> str:
        """Build the path of a table file, optionally with a name suffix.

        Args:
            suffix (str, optional): Appended to the table name as ``"<name>_<suffix>"``.

        Returns:
            str: ``<table_dir>/<table_name>[_<suffix>].<table_ext>``.

        Raises:
            ValueError: If ``suffix`` is not a str, or the table extension is
                neither ``txt`` nor ``csv``.
        """
        if suffix is not None and not isinstance(suffix, str):
            raise ValueError(f"Suffix must be a string, but got {type(suffix)}")
        base_name = self.table_name if suffix is None else f"{self.table_name}_{suffix}"
        if self.table_ext not in ["txt", "csv"]:
            raise ValueError(
                f"Unknown table extension {self.table_ext}. Supported extensions are: txt, csv"
            )
        filename = ".".join((base_name, self.table_ext))
        return path.join(self.table_dir, filename)
@dataclass()
class BaseCorrectionFromCSV(BaseCorrection):
    """
    CSV-driven base class that standardizes reading corrections (central values and
    uncertainties) from a single, long-form CSV file.

    The CSV is expected to contain at least:
        - 'central_value'
    Uncertainty columns (all four headers are required unless an explicit
    covariance matrix is given):
        - 'stat_corr', 'sys_corr'      -> fully_correlated
        - 'stat_uncorr', 'sys_uncorr'  -> uncorrelated
      A column whose cells are all empty (NaN) is ignored. A column that mixes
      values with NaN/inf cells raises ``CorrectionTablehasNaNValues``.
    Optional extra metadata columns that 1D/2D specializations will use:
        - 'dependant_variable' or 'dependant_variable_1' and 'dependant_variable_2'
        - '{var}_unit', '{var}_min', '{var}_max' for each dependant variable
    Optional extra cut columns for automatic query enhancement:
        - 'PDG': PDG codes as strings in format "[521,-521]" or "[521]"
        - 'mcPDG': MC truth PDG codes as strings in format "[521,-521]" or "[521]"
      Note: Only string format like "[521,-521]" is supported.
    Optional explicit covariance matrix:
        - passed via the ``cov_matrix_path`` argument (a ``.npy``/``.tsv`` file,
          not a CSV column); it replaces the per-column uncertainties.
    """

    csv_path: str | None = None
    title: str | None = None
    cov_matrix_path: str | None = None

    def __post_init__(self):
        """Read and validate the CSV, load an optional covariance matrix, build ``info``.

        Raises:
            MissingInformationError: If ``csv_path`` is not given.
            ValueError: If the table is empty, has malformed PDG entries, or the
                covariance-matrix file does not exist.
            InvalidInfoDict: If the covariance-matrix file type is unsupported.
        """
        # NOTE(review): this re-runs BaseCorrection's dataclass __init__, which
        # resets ``uncertainties`` to a fresh empty dict.
        super().__init__()
        if self.csv_path is None:
            raise MissingInformationError(
                "csv_path must be provided when using BaseCorrectionFromCSV."
            )
        # Read only the header first, so that every column whose name contains
        # "unit" can be forced to str in the real read.
        cols = pd.read_csv(self.csv_path, nrows=0).columns
        unit_converters = {col: str for col in cols if "unit" in col.lower()}
        self.table = read_csv(self.csv_path, converters=unit_converters, na_filter=True)
        self._is_valid_table()

        if self.cov_matrix_path is not None:
            if not path.exists(self.cov_matrix_path):
                raise ValueError(
                    f"Covariance matrix file not found: {self.cov_matrix_path}"
                )
            # Reject unsupported file types before trying to load them.
            self._validate_cov_matrix_path(self.cov_matrix_path)
            # load_covariance_matrix expects an info-like dict.
            self.cov_matrix = load_covariance_matrix(
                {"cov_matrix": self.cov_matrix_path}
            )
        else:
            self.cov_matrix = None

        self.info = self._build_info_from_table()
        self._is_valid_info_dict()
        self.title = (
            self.title if isinstance(self.title, str) else path.basename(self.csv_path)
        )
        self.visualizer = None

    def _is_valid_table(self) -> None:
        """Check the table is non-empty and any PDG/mcPDG column is well formed.

        Raises:
            ValueError: On an empty table or a malformed PDG entry.
        """
        if self.table is None or len(self.table) == 0:
            raise ValueError(f"No data found in CSV at {self.csv_path}")
        for pdg_column in ["PDG", "mcPDG"]:
            if pdg_column in self.table.columns:
                self._validate_pdg_column(pdg_column)

    def _validate_pdg_column(self, pdg_column: str) -> None:
        """Check every cell of ``pdg_column`` is NaN or a "[int, int, ...]" string."""
        for row_index, pdg_value in enumerate(self.table[pdg_column].tolist()):
            if (
                isinstance(pdg_value, str)
                and pdg_value.startswith("[")
                and pdg_value.endswith("]")
            ):
                pdg_list_text = pdg_value.strip("[]").strip()
                if pdg_list_text == "":
                    continue
                try:
                    _ = [int(x.strip()) for x in pdg_list_text.split(",")]
                except ValueError as err:
                    raise ValueError(
                        f"Row {row_index} in column '{pdg_column}' has invalid PDG list format: {pdg_value}"
                    ) from err
            elif pd.isna(pdg_value):
                continue
            else:
                raise ValueError(
                    f"Column '{pdg_column}' must use string list format like '[521,-521]'; got: {pdg_value} with datatype {type(pdg_value)}"
                )

    def _build_info_from_table(self) -> dict:
        """Build the ``info`` dictionary consumed by ``populate_uncertainties``.

        Returns:
            dict: ``"uncertainties"`` (category -> {column: values}),
            ``"extra_cuts_columns"`` (PDG/mcPDG columns with at least one
            non-NaN value) and ``"cov_matrix"`` (``cov_matrix_path`` or ``None``).

        Raises:
            InvalidCorrectionTableKey: If no explicit covariance matrix is used
                and one of the four uncertainty columns is missing.
            CorrectionTablehasNaNValues: If an uncertainty column mixes values
                with NaN/inf cells.
        """
        extra_cuts_columns = [
            col
            for col in self.table.columns
            if col in ["PDG", "mcPDG"] and self.table[col].notna().any()
        ]
        # An explicit covariance matrix replaces the individual uncertainty columns.
        if self.cov_matrix is not None:
            return {
                "uncertainties": {},
                "extra_cuts_columns": extra_cuts_columns,
                "cov_matrix": self.cov_matrix_path,
            }

        uncertainties: dict = {
            "fully_correlated": {},
            "uncorrelated": {},
        }
        for key in ["stat_corr", "sys_corr", "stat_uncorr", "sys_uncorr"]:
            if key not in self.table:
                raise InvalidCorrectionTableKey(
                    f"Missing key in correction table: {key}"
                )
            values = self.table[key].tolist()
            nan_mask = np.isnan(values)
            if nan_mask.all():
                # Column left completely empty: this uncertainty is not provided.
                continue
            if nan_mask.any() or np.isinf(values).any():
                # Dropping a partially filled column would silently lose this
                # uncertainty for every bin, so fail loudly instead.
                raise CorrectionTablehasNaNValues(
                    f"Column '{key}' in {self.csv_path} mixes values with NaN/inf "
                    "entries; fill every row or leave the column completely empty."
                )
            if key.endswith("_corr"):
                uncertainties["fully_correlated"][key] = values
            elif key.endswith("_uncorr"):
                uncertainties["uncorrelated"][key] = values
        # Drop categories that received no columns.
        uncertainties = {k: v for k, v in uncertainties.items() if v}
        return {
            "uncertainties": uncertainties,
            "extra_cuts_columns": extra_cuts_columns,
            "cov_matrix": None,
        }
@dataclass
class Correction1D(BaseCorrectionFromYaml):
    """1D binned correction configured from YAML.

    Reads the YAML keys ``corrections`` (values, or a table description),
    ``dependant_variable``, ``unit``, ``min`` and ``max`` (per-bin edges).
    """

    dependant_variable: str | None = None
    central_values: Iterable = None
    lower_bounds: Iterable = None
    upper_bounds: Iterable = None

    def __post_init__(self):
        super().__post_init__()
        self.central_values = self.read_corrections()
        self.dependant_variable = self.info["dependant_variable"]
        self.unit = self.info["unit"]
        self.lower_bounds = self.info["min"]
        self.upper_bounds = self.info["max"]
        # visual_labels (needed by add_uncertainty) reads the fields set above.
        self.populate_uncertainties()

    def read_corrections(self) -> np.ndarray:
        """Read the correction values, either inline from the config or from a table.

        If ``self.info["corrections"]`` is a list/tuple/array, it is used directly.
        If it is a dict, the table at ``self.build_table_path()`` is read and the
        column named by ``self.table_key`` is returned.

        Returns:
            np.ndarray: The correction values.

        Raises:
            KeyError: If 'corrections' is missing from the config, or the table
                has no ``self.table_key`` column.
            FileNotFoundError: If the table file does not exist.
            ValueError: If the table extension is unsupported (from ``build_table_path``).
            TypeError: If 'corrections' is neither a sequence nor a dict.
        """
        correction_info = self.info["corrections"]
        if isinstance(correction_info, (list, tuple, np.ndarray)):
            logging.info("Loading correction values from config array.")
            return np.asarray(correction_info)
        if isinstance(correction_info, dict):
            logging.info("Loading correction values from config table.")
            table = read_csv(self.build_table_path())
            return np.asarray(table[self.table_key])
        # Previously fell through and returned None, deferring the failure.
        raise TypeError(
            "'corrections' must be a list/array of values or a table description "
            f"dict, got {type(correction_info)}"
        )

    @property
    def value_edges(self) -> np.ndarray:
        """Sorted unique bin edges taken from the lower and upper bounds."""
        return np.unique(np.concatenate((self.lower_bounds, self.upper_bounds)))

    @property
    def value_mids(self) -> np.ndarray:
        """Midpoints between consecutive ``value_edges``."""
        return (self.value_edges[1:] + self.value_edges[:-1]) / 2

    @property
    def visual_labels(self) -> List[str]:
        """Per-bin labels of the form ``"low < var < up unit"``."""
        return [
            f"{low} < {self.dependant_variable} < {up} {self.unit}"
            for low, up in zip(self.lower_bounds, self.upper_bounds)
        ]

    def build_queries(self, prefix: str | None = None) -> list:
        """Return one ``low <= column < up`` query per bin, extended by ``add_extra_cuts``.

        Args:
            prefix: Optional column-name prefix (see ``_build_column_name``).
        """
        column_name = self._build_column_name(prefix, self.dependant_variable)
        queries = [
            f"{low} <= {column_name} < {up}"
            for low, up in zip(self.lower_bounds, self.upper_bounds)
        ]
        # NOTE(review): add_extra_cuts is not defined in the visible base
        # classes; confirm where it is provided.
        return self.add_extra_cuts(queries, prefix)
@dataclass
class Correction1DFromCSV(BaseCorrectionFromCSV):
    """
    1D correction read from a single long-form CSV, one row per bin.

    Columns used:
        - 'central_value'
        - 'dependant_variable' (or 'dependant_variable_1'): variable name, first row
        - '{var}_unit', 'unit' or any '*_unit' column (optional; first row is used)
        - '{var}_label', 'label' or any '*_label' column (optional per-bin labels)
        - Either '{var}_min' and '{var}_max': bin edges (range queries)
        - Or a '{var}' column with discrete values (equality queries)
    """

    dependant_variable: str | None = None
    central_values: Iterable = None
    lower_bounds: Iterable = None
    upper_bounds: Iterable = None
    unit: str | None = None
    use_equality_queries: bool = False
    variable_values: Iterable = None

    def __post_init__(self):
        super().__post_init__()
        self._is_valid_1D_table()
        self.unit = self.get_unit()

        variable = self.dependant_variable
        min_col = f"{variable}_min"
        max_col = f"{variable}_max"
        if min_col in self.table and max_col in self.table:
            # Continuous bins described by explicit edges.
            self.use_equality_queries = False
            self.lower_bounds = self.table[min_col].tolist()
            self.upper_bounds = self.table[max_col].tolist()
        elif variable in self.table:
            # Discrete values: one equality query per row.
            self.use_equality_queries = True
            self.variable_values = self.table[variable].tolist()
        else:
            raise ValueError(
                f"CSV must contain either '{min_col}' and '{max_col}' columns for bin edges, "
                f"or a '{variable}' column with discrete values."
            )
        self.central_values = self.table["central_value"].tolist()
        self.populate_uncertainties()

    def _is_valid_1D_table(self) -> None:
        """Set ``dependant_variable`` from the first row of its name column.

        Raises:
            ValueError: If neither 'dependant_variable' nor 'dependant_variable_1' exists.
        """
        for name_column in ("dependant_variable", "dependant_variable_1"):
            if name_column in self.table:
                self.dependant_variable = str(self.table[name_column].iloc[0]).strip()
                return
        raise ValueError(
            "CSV must contain 'dependant_variable' or 'dependant_variable_1' column."
        )

    def _find_column(self, kind: str):
        """Return the metadata column for ``kind`` (e.g. "unit", "label"), or None.

        Lookup order: "{dependant_variable}_{kind}", then "{kind}", then the
        first column whose name ends with "_{kind}".
        """
        columns = self.table.columns
        for candidate in (f"{self.dependant_variable}_{kind}", kind):
            if candidate in columns:
                return candidate
        return next((c for c in columns if c.endswith(f"_{kind}")), None)

    def get_unit(self) -> str:
        """Return the unit of the dependent variable.

        The unit column is chosen by ``_find_column("unit")``. The value in its
        first row is returned, or an empty string if no unit column exists.

        Returns:
            str: The unit string if found; otherwise an empty string.
        """
        unit_col = self._find_column("unit")
        if unit_col is None:
            return ""
        return self.table[unit_col].iloc[0]

    @property
    def value_edges(self) -> np.ndarray:
        """Bin edges: sorted unique bounds, or 0..n_unique for discrete values."""
        if not self.use_equality_queries:
            return np.unique(np.concatenate((self.lower_bounds, self.upper_bounds)))
        n_unique = len(dict.fromkeys(self.variable_values))
        return np.arange(n_unique + 1)

    @property
    def value_mids(self) -> np.ndarray:
        """Bin centres: edge midpoints, or 0..n_unique-1 for discrete values."""
        if not self.use_equality_queries:
            edges = self.value_edges
            return (edges[1:] + edges[:-1]) / 2
        return np.arange(len(dict.fromkeys(self.variable_values)))

    @property
    def visual_labels(self) -> List[str]:
        """Per-bin labels: an explicit label column if present, else generated text."""
        label_col = self._find_column("label")
        if label_col is not None:
            return self.table[label_col].tolist()
        var, unit = self.dependant_variable, self.unit
        if self.use_equality_queries:
            return [f"{var} = {val} {unit}" for val in self.variable_values]
        return [
            f"{low} < {var} < {up} {unit}"
            for low, up in zip(self.lower_bounds, self.upper_bounds)
        ]

    def build_queries(self, prefix: str | None = None) -> list:
        """Return one query per row (equality or range), extended by ``add_extra_cuts``."""
        column = self._build_column_name(prefix, self.dependant_variable)
        if self.use_equality_queries:
            queries = [f"{column} == {val}" for val in self.variable_values]
        else:
            queries = [
                f"{low} <= {column} < {up}"
                for low, up in zip(self.lower_bounds, self.upper_bounds)
            ]
        # NOTE(review): add_extra_cuts is not defined in the visible base
        # classes; confirm where it is provided.
        return self.add_extra_cuts(queries, prefix)
@dataclass
class Correction2D(BaseCorrectionFromYaml):
    """2D correction configured from YAML with values read from three tables.

    Nominal values, statistical errors and systematic errors are read from
    ``build_table_path("nom")``, ``("stat")`` and ``("sys")``. Bin edges are
    parsed from the nominal table's column names (``p=<low>_<high>``) and
    index labels (``cost=<low>_<high>``).
    """

    uncertainties: dict = field(default_factory=dict)

    def __post_init__(self):
        super().__post_init__()
        self.dependant_variable_1 = self.info["dependant_variable_1"]
        self.dependant_variable_2 = self.info["dependant_variable_2"]
        self.unit_1 = self.info["unit_1"]
        self.unit_2 = self.info["unit_2"]
        self.central_values = self._extract_central_values()
        self.populate_uncertainties()

    @property
    def iterator(self):
        """Iterate consistently over all cells of the nominal table.

        Returns:
            Iterator of ``(row, column, momenta, angles)`` in column-major order
            (all rows of column 0 first). ``momenta`` holds the floats parsed
            from the column name divided by 10; ``angles`` holds the floats
            parsed from the row label.
        """
        table = self.central_values_table
        rows, columns, momenta, angles = [], [], [], []
        if len(table.columns) == 0 or len(table.index) == 0:
            return zip(rows, columns, momenta, angles)

        # Parse each column name once (the original re-parsed it for every row).
        column_edges = []
        for i, column_name in enumerate(table.columns):
            # clean the pi0 tables: the first header carries an extra marker
            if i == 0:
                column_name = column_name.replace(" row:p column:t ", "")
            column_name = column_name.replace("p=", "")
            # NOTE(review): momentum edges are divided by 10 -- confirm the
            # unit convention of the tables.
            column_edges.append([float(x) / 10 for x in column_name.split("_")])
        row_edges = [
            [float(x) for x in row_name.replace("cost=", "").split("_")]
            for row_name in table.index
        ]

        for i, ps in enumerate(column_edges):
            for j, ts in enumerate(row_edges):
                rows.append(j)
                columns.append(i)
                # Fresh lists per cell, so entries never alias each other.
                momenta.append(list(ps))
                angles.append(list(ts))
        return zip(rows, columns, momenta, angles)

    @cached_property
    def central_values_table(self) -> DataFrame:
        """Nominal correction table (``build_table_path("nom")``), read once."""
        return read_csv(self.build_table_path("nom"))

    @cached_property
    def stat_error_table(self) -> DataFrame:
        """Statistical error table (``build_table_path("stat")``), read once."""
        # Explicit column names keep pandas from creating an extra index column.
        # NOTE(review): the 8 momentum bins are hard-coded.
        return read_csv(
            self.build_table_path("stat"), names=[f"p bin {i}" for i in range(8)]
        )

    @cached_property
    def sys_error_table(self) -> DataFrame:
        """Systematic error table (``build_table_path("sys")``), read once."""
        # Explicit column names keep pandas from creating an extra index column.
        # NOTE(review): the 8 momentum bins are hard-coded.
        return read_csv(
            self.build_table_path("sys"), names=[f"p bin {i}" for i in range(8)]
        )

    @property
    def visual_labels(self) -> List[str]:
        """Per-cell labels combining the momentum and angle ranges."""
        return [
            f"{momenta[0]} <= {self.dependant_variable_1} < {momenta[1]} {self.unit_1} & {angles[0]} <= {self.dependant_variable_2} < {angles[1]} {self.unit_2}"
            for r, c, momenta, angles in self.iterator
        ]

    def _extract_central_values(self):
        """Central values in ``iterator`` order."""
        return self._extract_errors(self.central_values_table)

    def _extract_errors(self, table: DataFrame):
        """Values of ``table`` at every ``(row, column)`` produced by ``iterator``."""
        return [table.iloc[row, column] for row, column, _, _ in self.iterator]

    def build_queries(self, prefix: str | None = None) -> list:
        """Return one 2D range query per cell, extended by ``add_extra_cuts``."""
        column_name_1 = self._build_column_name(prefix, self.dependant_variable_1)
        column_name_2 = self._build_column_name(prefix, self.dependant_variable_2)
        queries = [
            f"{momenta[0]} <= {column_name_1} < {momenta[1]} & {angles[0]} <= {column_name_2} < {angles[1]}"
            for r, c, momenta, angles in self.iterator
        ]
        # NOTE(review): add_extra_cuts is not defined in the visible base
        # classes; confirm where it is provided.
        return self.add_extra_cuts(queries, prefix)

    def populate_uncertainties(self):
        """Add the stat/sys uncertainties listed in ``info["error_correlations"]``.

        ``info["error_correlations"]`` maps an error name ("stat" or "sys") to an
        uncertainty category from ``get_uncertainty_types()``.

        Raises:
            NotImplementedError: For error names other than "stat" and "sys".
            UnkownUncertaintyType: If a category is not a known uncertainty type.
        """
        sysvar_uncertainties = get_uncertainty_types()
        for unc_name, unc_ctgy in self.info["error_correlations"].items():
            if unc_name == "stat":
                error_table = self.stat_error_table
            elif unc_name == "sys":
                error_table = self.sys_error_table
            else:
                raise NotImplementedError(
                    "Only stat and sys error have been implemented now"
                )
            # Same error type as BaseCorrection.populate_uncertainties, instead
            # of a bare KeyError from the lookup below.
            if unc_ctgy not in sysvar_uncertainties.keys():
                raise UnkownUncertaintyType(
                    f"Unknown type of uncertainty '{unc_ctgy}' declared for '{unc_name}'. "
                    f"Available uncertainty types are: ({', '.join(sysvar_uncertainties.keys())})"
                )
            self.add_uncertainty(
                unc_name=unc_name,
                unc_values=self._extract_errors(error_table),
                unc_obj=sysvar_uncertainties[unc_ctgy],
            )
@dataclass
class Correction2DFromCSV(BaseCorrectionFromCSV):
    """
    2D correction read from a single long-form CSV, one row per bin.

    Required columns:
        - 'central_value'
        - 'dependant_variable_1', 'dependant_variable_2' (variable names, first row)
        - '{var1}_min', '{var1}_max', '{var2}_min', '{var2}_max' (bin edges)
    Optional columns:
        - '{var1}_unit', '{var2}_unit' (first row is used; default "")
    """

    dependant_variable_1: str | None = None
    dependant_variable_2: str | None = None
    unit_1: str | None = None
    unit_2: str | None = None
    central_values: Iterable = None

    def __post_init__(self):
        super().__post_init__()
        self._is_valid_2D_table()

        var_1 = str(self.table["dependant_variable_1"].iloc[0]).strip()
        var_2 = str(self.table["dependant_variable_2"].iloc[0]).strip()
        self.dependant_variable_1 = var_1
        self.dependant_variable_2 = var_2

        # Units are optional and fall back to an empty string.
        unit_1_column = f"{var_1}_unit"
        unit_2_column = f"{var_2}_unit"
        self.unit_1 = "" if unit_1_column not in self.table else self.table[unit_1_column].iloc[0]
        self.unit_2 = "" if unit_2_column not in self.table else self.table[unit_2_column].iloc[0]

        # Bin-edge column names, reused by ``iterator``.
        self._v1_min, self._v1_max = f"{var_1}_min", f"{var_1}_max"
        self._v2_min, self._v2_max = f"{var_2}_min", f"{var_2}_max"
        edge_columns = (self._v1_min, self._v1_max, self._v2_min, self._v2_max)
        absent = [c for c in edge_columns if c not in self.table.columns]
        if absent:
            raise ValueError(f"CSV must contain '{absent[0]}' column for 2D bin edges.")

        self.central_values = self.table["central_value"].tolist()
        self.populate_uncertainties()

    def _is_valid_2D_table(self) -> None:
        """Raise ValueError unless both dependant-variable name columns exist."""
        required = ("dependant_variable_1", "dependant_variable_2")
        if not all(name in self.table for name in required):
            raise ValueError(
                "CSV must contain 'dependant_variable_1' and 'dependant_variable_2' column."
            )

    @property
    def iterator(self):
        """Yield ``(v1_min, v1_max, v2_min, v2_max)`` for each table row."""
        edge_columns = (self._v1_min, self._v1_max, self._v2_min, self._v2_max)
        for _, row in self.table.iterrows():
            yield tuple(row[c] for c in edge_columns)

    @property
    def visual_labels(self) -> List[str]:
        """Per-bin labels combining both variable ranges and units."""
        labels = []
        for lo_1, hi_1, lo_2, hi_2 in self.iterator:
            labels.append(
                f"{lo_1} <= {self.dependant_variable_1} < {hi_1} {self.unit_1} & "
                f"{lo_2} <= {self.dependant_variable_2} < {hi_2} {self.unit_2}"
            )
        return labels

    def build_queries(self, prefix: str | None = None) -> list:
        """Return one 2D range query per row, extended by ``add_extra_cuts``."""
        col_1 = self._build_column_name(prefix, self.dependant_variable_1)
        col_2 = self._build_column_name(prefix, self.dependant_variable_2)
        queries = []
        for lo_1, hi_1, lo_2, hi_2 in self.iterator:
            queries.append(
                f"{lo_1} <= {col_1} < {hi_1} & "
                f"{lo_2} <= {col_2} < {hi_2}"
            )
        # NOTE(review): add_extra_cuts is not defined in the visible base
        # classes; confirm where it is provided.
        return self.add_extra_cuts(queries, prefix)
@dataclass
class Correction3DFromCSV(BaseCorrectionFromCSV):
    """
    3D correction read from a single long-form CSV, one row per bin.

    Required columns:
        - 'central_value'
        - 'dependant_variable_1', 'dependant_variable_2', 'dependant_variable_3'
          (variable names, first row)
        - '{var1}_min','{var1}_max','{var2}_min','{var2}_max','{var3}_min','{var3}_max'
    Optional columns:
        - '{var1}_unit', '{var2}_unit', '{var3}_unit' (first row is used; default "")
    """

    dependant_variable_1: str | None = None
    dependant_variable_2: str | None = None
    dependant_variable_3: str | None = None
    unit_1: str | None = None
    unit_2: str | None = None
    unit_3: str | None = None
    central_values: Iterable = None

    def __post_init__(self):
        """Read variable names, units and bin edges, then populate uncertainties.

        Raises:
            ValueError: If a dependant-variable name column or a bin-edge column is missing.
        """
        super().__post_init__()
        self._is_valid_3D_table()
        self.dependant_variable_1 = str(
            self.table["dependant_variable_1"].iloc[0]
        ).strip()
        self.dependant_variable_2 = str(
            self.table["dependant_variable_2"].iloc[0]
        ).strip()
        self.dependant_variable_3 = str(
            self.table["dependant_variable_3"].iloc[0]
        ).strip()
        # Units (optional, default "")
        unit1_col = f"{self.dependant_variable_1}_unit"
        unit2_col = f"{self.dependant_variable_2}_unit"
        unit3_col = f"{self.dependant_variable_3}_unit"
        self.unit_1 = self.table[unit1_col].iloc[0] if unit1_col in self.table else ""
        self.unit_2 = self.table[unit2_col].iloc[0] if unit2_col in self.table else ""
        self.unit_3 = self.table[unit3_col].iloc[0] if unit3_col in self.table else ""
        # Bin-edge column names, reused by ``iterator``.
        self._v1_min = f"{self.dependant_variable_1}_min"
        self._v1_max = f"{self.dependant_variable_1}_max"
        self._v2_min = f"{self.dependant_variable_2}_min"
        self._v2_max = f"{self.dependant_variable_2}_max"
        self._v3_min = f"{self.dependant_variable_3}_min"
        self._v3_max = f"{self.dependant_variable_3}_max"
        for c in (
            self._v1_min,
            self._v1_max,
            self._v2_min,
            self._v2_max,
            self._v3_min,
            self._v3_max,
        ):
            if c not in self.table.columns:
                # Message previously said "2D" (copied from Correction2DFromCSV).
                raise ValueError(f"CSV must contain '{c}' column for 3D bin edges.")
        self.central_values = self.table["central_value"].tolist()
        self.populate_uncertainties()

    def _is_valid_3D_table(self) -> None:
        """Raise ValueError unless all three dependant-variable name columns exist."""
        if (
            "dependant_variable_1" not in self.table
            or "dependant_variable_2" not in self.table
            or "dependant_variable_3" not in self.table
        ):
            raise ValueError(
                "CSV must contain 'dependant_variable_1', 'dependant_variable_2' and 'dependant_variable_3' column."
            )

    @property
    def iterator(self):
        """Yield ``(v1_min, v1_max, v2_min, v2_max, v3_min, v3_max)`` per table row."""
        for _, row in self.table.iterrows():
            yield (
                row[self._v1_min],
                row[self._v1_max],
                row[self._v2_min],
                row[self._v2_max],
                row[self._v3_min],
                row[self._v3_max],
            )

    @property
    def visual_labels(self) -> List[str]:
        """Per-bin labels combining the three variable ranges and units."""
        return [
            f"{v1min} <= {self.dependant_variable_1} < {v1max} {self.unit_1} & "
            f"{v2min} <= {self.dependant_variable_2} < {v2max} {self.unit_2} & "
            f"{v3min} <= {self.dependant_variable_3} < {v3max} {self.unit_3}"
            for (v1min, v1max, v2min, v2max, v3min, v3max) in self.iterator
        ]

    def build_queries(self, prefix: str | None = None) -> list:
        """Return one 3D range query per row, extended by ``add_extra_cuts``."""
        column_name_1 = self._build_column_name(prefix, self.dependant_variable_1)
        column_name_2 = self._build_column_name(prefix, self.dependant_variable_2)
        column_name_3 = self._build_column_name(prefix, self.dependant_variable_3)
        queries = [
            f"{v1min} <= {column_name_1} < {v1max} & "
            f"{v2min} <= {column_name_2} < {v2max} & "
            f"{v3min} <= {column_name_3} < {v3max}"
            for (v1min, v1max, v2min, v2max, v3min, v3max) in self.iterator
        ]
        # NOTE(review): add_extra_cuts is not defined in the visible base
        # classes; confirm where it is provided.
        return self.add_extra_cuts(queries, prefix)
@dataclass
class Correction2DCategorical(BaseCorrectionFromYaml):
    """Correction binned in one categorical and one continuous variable (YAML-configured).

    Reads the YAML keys ``corrections`` (one list of values per category),
    ``categorical_variable``, ``categorical_values``, ``categorical_label``,
    ``continuus_variable``, ``continuus_edges`` and ``extra_variables``. Bins
    are ordered category-major: every continuous bin of the first category,
    then of the next category, and so on.

    NOTE(review): this class does not populate uncertainties; confirm whether
    ``populate_uncertainties`` should be called in ``__post_init__``.
    """

    categorical_variable: str | None = None
    continuus_variable: str | None = None
    central_values: Iterable = None
    continuus_edges: Iterable = None
    categorical_values: Iterable = None
    categorical_label: str | None = None

    def __post_init__(self):
        super().__post_init__()
        # Flatten the per-category lists into one per-bin list (presumably in
        # the same category-major order as ``iterator``; confirm).
        self.central_values = [
            value for correction in self.info["corrections"] for value in correction
        ]
        self.categorical_variable = self.info["categorical_variable"]
        self.categorical_values = self.info["categorical_values"]
        self.categorical_label = self.info["categorical_label"]
        self.continuus_variable = self.info["continuus_variable"]
        self.continuus_edges = self.info["continuus_edges"]
        self.extra_variables = self.info["extra_variables"]

    @property
    def iterator(self):
        """Iterate ``(category_value, (low, up))`` over all bins, category-major."""
        return itertools.product(
            self.categorical_values, zip(self.continuus_edges, self.continuus_edges[1:])
        )

    @property
    def strings(self) -> List[str]:
        """Human-readable label for each bin."""
        return [
            f"{self.categorical_label}: {cv} [ {low} - {up} ]"
            for cv, (low, up) in self.iterator
        ]

    @property
    def visual_labels(self) -> List[str]:
        """Per-bin labels (same as ``strings``); implements the abstract base property.

        Without this, the abstract ``visual_labels`` made the class impossible
        to instantiate.
        """
        return self.strings

    def build_queries(self, prefix: str | None = None) -> list:
        """Return one query per bin, combined with the YAML ``extra_cuts`` entry.

        Args:
            prefix: Optional column-name prefix (see ``_build_column_name``).
        """
        categorical_column = self._build_column_name(prefix, self.categorical_variable)
        continuus_column = self._build_column_name(prefix, self.continuus_variable)
        # ``_get_extra_cut_info`` is the accessor for ``info["extra_cuts"]``;
        # the ``_get_extra_cut`` called before does not exist.
        extra_cut = self._get_extra_cut_info()
        return [
            f"{categorical_column} == {cv} & {low} <= {continuus_column} < {up} & {extra_cut}"
            for cv, (low, up) in self.iterator
        ]

    @property
    def queries(self):
        """Per-bin queries without a column prefix."""
        return self.build_queries()
[docs]
@dataclass
class CorrectionBF(BaseCorrectionFromYaml):
    """Branching-fraction correction: each decay mode is weighted by PDG / decay.dec.

    ``self.info["modes"]`` maps mode names to dicts providing ``pdg_live``
    (value, error), ``decay_dec`` (value), ``daughters`` and ``dmID``.
    """

    dependant_variable: str | None = None
    central_values: Iterable = None
    visual_labels: Iterable = None
    uncertainties: dict = field(default_factory=dict)

    def __post_init__(self):
        super().__post_init__()
        central_values, error_amplitudes = self._calculate_scaling_ratios()
        self.central_values = central_values
        # Visual labels need to be defined before we populate the uncertainties;
        # otherwise the uncertainty object does not have visual_labels.
        self.visual_labels = self._create_strings()
        self.populate_uncertainties(error_amplitudes)
        self.dependant_variable = self.info["dependant_variable"]

    @staticmethod
    def _particle_label(pdg_or_name):
        """Return the LaTeX name for a PDG id, or the input unchanged if it cannot be resolved."""
        try:
            return Particle.from_pdgid(pdg_or_name).latex_name
        except Exception:
            # Not a resolvable PDG id (e.g. a free-form label): use it verbatim.
            # Narrower than a bare ``except`` so KeyboardInterrupt/SystemExit propagate.
            return pdg_or_name

    def _create_strings(self) -> List[str]:
        """Build a LaTeX ``mother -> daughters`` label for every daughter set.

        Each ``mode["daughters"]`` is iterated as a collection of daughter sets,
        and each set as a collection of PDG ids (or plain labels).
        """
        mother = Particle.from_pdgid(self.info["mother_particle"]).latex_name
        daughter_sets = [
            daughters
            for mode in self.info["modes"].values()
            for daughters in mode["daughters"]
        ]
        strings = []
        for daughter_set in daughter_sets:
            daughter_names = [self._particle_label(x) for x in daughter_set]
            strings.append(
                rf"${mother} \rightarrow {' '.join(str(x) for x in daughter_names)}$"
            )
        return strings

    def _calculate_scaling_ratios(self):
        """Return ``(central_values, uncertainties)`` of the per-mode ratio PDG / decay.dec.

        ``pdg_live`` supplies the PDG value and its error. ``decay_dec`` is
        treated as exact (zero uncertainty).
        """
        modes = list(self.info["modes"].values())
        pdg_BFs = unp.uarray(
            [mode["pdg_live"][0] for mode in modes],
            [mode["pdg_live"][1] for mode in modes],
        )
        decaydec_BFs = unp.uarray(
            [mode["decay_dec"] for mode in modes],
            [0 for _ in modes],
        )
        # Safe division: where decay.dec == 0 or PDG == 0 the ratio is not computed,
        # and the pre-filled ``out`` value of 1 +- 1 is kept (both the nominal values
        # and the std. devs are ones). NOTE(review): an earlier comment claimed
        # 1 +- 0 here; confirm which fallback uncertainty is intended.
        corrections = np.divide(
            pdg_BFs,
            decaydec_BFs,
            out=unp.uarray(np.ones_like(pdg_BFs), np.ones_like(pdg_BFs)),
            where=((decaydec_BFs != 0) & (pdg_BFs != 0)),
        )
        return list(unp.nominal_values(corrections)), list(unp.std_devs(corrections))

    def populate_uncertainties(self, error_amplitudes: list):
        """
        Override of the base-class method: the error amplitudes are computed
        dynamically by ``_calculate_scaling_ratios`` rather than read from the config.
        """
        unc_correlation = self.get_uncertainty_correlation()
        sysvar_uncertainties = get_uncertainty_types()
        self.add_uncertainty(
            unc_name="BF_unc",
            unc_values=error_amplitudes,
            unc_obj=sysvar_uncertainties[unc_correlation],
        )

    def get_uncertainty_correlation(self):
        """Return the configured correlation type of the BF uncertainty.

        Raises:
            NotImplementedError: for ``"explicitly_correlated"``.
            ValueError: for any other unknown correlation type.
        """
        if self.info["correlation"] in [
            "fully_correlated",
            "uncorrelated",
            "fully_correlated_in_parts",
        ]:
            unc_type = self.info["correlation"]
        elif self.info["correlation"] == "explicitly_correlated":
            raise NotImplementedError(
                "Cannot support custom correlation for BF correction yet"
            )
        else:
            raise ValueError(
                "Unkown correlation type. Available types are: fully_correlated, uncorrelated, fully_correlated_in_parts"
            )
        return unc_type

    def build_queries(self, prefix: str | None = None) -> list:
        """One query per mode: equality for a string ``dmID``, membership otherwise."""
        column_name = self._build_column_name(prefix, self.dependant_variable)
        queries = [
            (
                f"{column_name} == '{mode['dmID']}'"
                if isinstance(mode["dmID"], str)
                else f"{column_name} in {mode['dmID']}"
            )
            for mode in self.info["modes"].values()
        ]
        return self.add_extra_cuts(queries, prefix)
[docs]
@dataclass
class CustomCorrection(BaseCorrection):
    """Correction defined directly by a user-supplied ``info`` dictionary.

    ``__post_init__`` validates ``info`` with ``_is_valid_info_dict``, then reads
    ``dependant_variable``, ``central_values``, ``query_targets``, ``unit`` and
    ``title`` from it. The covariance matrix comes from
    ``load_covariance_matrix(info)``. Each bin is selected by equality with one
    entry of ``query_targets``.
    """

    info: InitVar[dict] = None
    dependant_variable: str | None = None
    central_values: Iterable = None
    uncertainties: dict = field(default_factory=dict)
    query_targets: Iterable = None

    def __post_init__(self, info):
        self.info = info
        self._is_valid_info_dict()
        config = self.info
        self.dependant_variable = config["dependant_variable"]
        self.central_values = config["central_values"]
        self.query_targets = config["query_targets"]
        self.unit = config["unit"]
        self.title = config["title"]
        self.cov_matrix = load_covariance_matrix(config)
        self.populate_uncertainties()

    @property
    def value_edges(self) -> np.ndarray:
        """Integer bin edges ``0, 1, ..., len(central_values)``."""
        n_bins = len(self.central_values)
        return np.arange(n_bins + 1)

    @property
    def value_mids(self) -> np.ndarray:
        """Bin centres: midpoints of consecutive ``value_edges``."""
        edges = self.value_edges
        lower, upper = edges[:-1], edges[1:]
        return 0.5 * (lower + upper)

    @property
    def visual_labels(self) -> List[str]:
        """One ``"<dependant_variable> = <target>"`` label per query target."""
        labels = []
        for target in self.query_targets:
            labels.append(f"{self.dependant_variable} = {target}")
        return labels

    def build_queries(self, prefix: str | None = None) -> list:
        """One string-equality query per target.

        Targets are always quoted, so they are compared as strings. No
        ``add_extra_cuts`` step is applied here.
        """
        column = self._build_column_name(prefix, self.dependant_variable)
        queries = []
        for target in self.query_targets:
            queries.append(f"{column} == '{target}'")
        return queries
[docs]
@dataclass
class CorrectionPID(BaseCorrectionFromYaml):
    """Particle-identification correction read from PID efficiency / fake-rate tables.

    ``info["rate"]`` selects the efficiency (``"eff"``) or fake-rate (``"fake"``)
    table. The table is located by ``CorrectionTableFinder`` according to
    ``self.systematic``.
    """

    uncertainties: dict = field(default_factory=dict)

    def __post_init__(self):
        super().__post_init__()
        rate = self.info["rate"]
        self.check_valid_rate(rate)
        self.table = self.get_table(rate)
        self.central_values = self._extract_central_values()
        self.p = self.info["momentum_variable"]
        self.theta = self.info["theta_variable"]
        self.PDG = self.info["PDG_variable"]
        self.mcPDG = self.info["mcPDG_variable"]
        self.momentum_unit = self.info["momentum_unit"]
        # Statistical uncertainties are treated as uncorrelated between bins,
        # systematic ones as fully correlated between bins.
        for error_id, uncertainty_cls in (
            ("stat", UncorrelatedUncertainty),
            ("sys", FullyCorrelatedUncertainty),
        ):
            unc_name = f"{error_id} uncertainty"
            self.uncertainties[unc_name] = uncertainty_cls(
                unc_name,
                self._extract_errors(error_id),
                self.visual_labels,
            )

    @staticmethod
    def check_valid_rate(rate):
        """Raise ``NotValidRateType`` unless ``rate`` is ``"eff"`` or ``"fake"``."""
        valid_rates = ["eff", "fake"]
        if rate not in valid_rates:
            raise NotValidRateType(
                f"Valid rate arguments are {*valid_rates,} but you passed {rate}"
            )

    def _make_table_finder(self):
        """Build the ``CorrectionTableFinder`` matching ``self.systematic``.

        Raises:
            ValueError: if ``self.systematic`` names no supported PID systematic.
        """
        systematic = self.systematic
        if "eID" in systematic:
            factory = {
                "eID_K_fake": CorrectionTableFinder.Kfake_electrons,
                "eID_pi_fake": CorrectionTableFinder.pifake_electrons,
            }.get(systematic, CorrectionTableFinder.electrons)
        elif "muID" in systematic:
            factory = {
                "muID_K_fake": CorrectionTableFinder.Kfake_muons,
                "muID_pi_fake": CorrectionTableFinder.pifake_muons,
            }.get(systematic, CorrectionTableFinder.muons)
        elif "kID" in systematic:
            factory = CorrectionTableFinder.kaons
        elif "piID" in systematic:
            factory = CorrectionTableFinder.pions
        else:
            # Previously this surfaced as pandas' cryptic "No objects to concatenate".
            raise ValueError(
                f"Cannot find PID correction tables for systematic '{systematic}'. "
                "Expected it to contain 'eID', 'muID', 'kID' or 'piID'."
            )
        return factory(self.info, self.MC_production)

    def get_table(self, rate):
        """Return the efficiency (``"eff"``) or fake-rate (``"fake"``) table.

        Also stores the PDG codes of the matching particles in ``self._true_pdg``:
        the finder's ``true_pdg`` for ``"eff"`` and its ``fake_pdg`` for ``"fake"``.

        Raises:
            NotValidRateType: for any other ``rate``.
            ValueError: if ``self.systematic`` is not a supported PID systematic.
        """
        self.check_valid_rate(rate)
        finder = self._make_table_finder()
        # PATCH: the PDG codes should eventually be read from somewhere else.
        if rate == "eff":
            self._true_pdg = finder.true_pdg
            return finder.eff_table
        self._true_pdg = finder.fake_pdg
        return finder.fake_rate_table

    @property
    def true_pdg(self) -> list:
        return self._true_pdg

    @true_pdg.setter
    def true_pdg(self, true_pdg):
        self._true_pdg = true_pdg

    @property
    def iterator(self):
        """``(index, row)`` pairs over the PID table (``DataFrame.iterrows``)."""
        return self.table.iterrows()

    @property
    def queries(self):
        # PATCH: only "implements" the property to satisfy the parent class.
        return None

    def build_queries(
        self, prefix: str | None = None, extra_cut: str | None = None
    ) -> List[str]:
        """Build one pandas query string per row of the PID table.

        Each query selects the row's momentum and theta bin, requires the
        ``PDG`` column to equal the row's ``mcPDG``, and requires the ``mcPDG``
        column to be one of ``self._true_pdg``. ``add_extra_cuts`` is then
        applied to the list.

        Args:
            prefix: Optional column-name prefix, see ``_build_column_name``.
            extra_cut: Unused; kept so existing callers keep working.
        """
        p_column = self._build_column_name(prefix, self.p)
        theta_column = self._build_column_name(prefix, self.theta)
        pdg_column = self._build_column_name(prefix, self.PDG)
        mc_pdg_column = self._build_column_name(prefix, self.mcPDG)
        true_pdg = self._true_pdg
        queries = [
            f"({row['p_min']} <= {p_column} < {row['p_max']} & "
            f"{row['theta_min']} <= {theta_column} < {row['theta_max']} & "
            f"{pdg_column} == {row['mcPDG']} & "
            f"{mc_pdg_column} in {true_pdg})"
            for _, row in self.iterator
        ]
        return self.add_extra_cuts(queries, prefix)

    @property
    def visual_labels(self) -> List[str]:
        """Label per table row with its momentum bin, theta bin and charge."""
        return [
            rf"{row['p_min']} <= p < {row['p_max']} {self.momentum_unit} & {row['theta_min']} <= $\theta$ < {row['theta_max']} & q = {row['charge']}"
            for _, row in self.iterator
        ]

    def _extract_central_values(self):
        """Data/MC ratio of every table row."""
        return [row["data_MC_ratio"] for _, row in self.iterator]

    def _extract_errors(self, error_type):
        """Per-row uncertainty of type ``error_type`` (``"stat"`` or ``"sys"``).

        The errors are treated as symmetric: the larger of the up and down
        uncertainties is taken.
        """
        return [
            row[
                [
                    f"data_MC_uncertainty_{error_type}_up",
                    f"data_MC_uncertainty_{error_type}_dn",
                ]
            ].max()
            for _, row in self.iterator
        ]
# #######################################################################################
[docs]
def _infer_csv_correction_type(csv_path) -> str:
    """Return ``"1D"``, ``"2D"`` or ``"3D"`` from the ``dependant_variable*`` header columns of ``csv_path``."""
    try:
        # Only the header is needed to decide the correction type.
        columns = set(read_csv(csv_path, nrows=0).columns)
    except Exception as e:
        raise ValueError(f"Error reading CSV file {csv_path}: {e}") from e
    if {"dependant_variable_1", "dependant_variable_2", "dependant_variable_3"} <= columns:
        return "3D"
    if {"dependant_variable_1", "dependant_variable_2"} <= columns:
        return "2D"
    if "dependant_variable" in columns or "dependant_variable_1" in columns:
        return "1D"
    raise ValueError(
        "Cannot determine CSV correction type from columns. Expected a "
        "'dependant_variable' or 'dependant_variable_1' column (plus "
        "'dependant_variable_2' / 'dependant_variable_3' for 2D / 3D corrections)."
    )


def create_correction_object(
    correction_source: str | Path | dict,
    MC_production: str | None = None,
    title: str | None = None,
    cov_matrix_path: str | None = None,
) -> BaseCorrection:
    """Create the appropriate correction object for ``correction_source``.

    Three kinds of source are supported:

    * a ``Path``, or a ``str`` ending in ``.csv``: a CSV-based correction. The
      dimensionality (1D/2D/3D) is inferred from the CSV header columns.
    * any other ``str``: the identifier of a YAML-based systematic. The
      ``correction_type`` entry of its YAML configuration selects the class, and
      ``MC_production`` is required.
    * a ``dict``: a ``CustomCorrection`` built from that dictionary.

    Args:
        correction_source (str | Path | dict): Systematic identifier (YAML),
            CSV file path, or custom-correction dictionary.
        MC_production (str, optional): Monte Carlo production identifier.
            Required for YAML-based corrections.
        title (str, optional): Title forwarded to CSV-based corrections.
        cov_matrix_path (str, optional): Path to an explicit covariance matrix
            file, forwarded to CSV-based corrections.

    Returns:
        BaseCorrection: The correction object.

    Raises:
        ValueError: If the CSV file does not exist or cannot be read, its type
            cannot be inferred, ``MC_production`` is missing for a YAML-based
            correction, or ``correction_source`` has an unsupported type.
        NotImplementedError: If the YAML ``correction_type`` is not implemented.

    Example:
        >>> # YAML-based correction
        >>> correction = create_correction_object("syst1", "MC1")  # doctest: +SKIP
        >>> # CSV-based correction (type inferred from the columns)
        >>> correction = create_correction_object("path/to/file.csv")  # doctest: +SKIP
        >>> # CSV-based correction with explicit covariance matrix
        >>> correction = create_correction_object(
        ...     "path/to/file.csv", cov_matrix_path="path/to/cov.txt"
        ... )  # doctest: +SKIP
    """
    correction_types = {
        "1D": Correction1D,
        "2D": Correction2D,
        "2DCategorical": Correction2DCategorical,
        "BF": CorrectionBF,
        "PID": CorrectionPID,
    }
    csv_correction_types = {
        "1D": Correction1DFromCSV,
        "2D": Correction2DFromCSV,
        "3D": Correction3DFromCSV,
    }

    # CSV-based corrections
    if isinstance(correction_source, Path) or (
        isinstance(correction_source, str) and correction_source.endswith(".csv")
    ):
        if not path.exists(correction_source):
            raise ValueError(f"CSV file not found: {correction_source}")
        csv_type = _infer_csv_correction_type(correction_source)
        return csv_correction_types[csv_type](
            csv_path=correction_source, title=title, cov_matrix_path=cov_matrix_path
        )

    # YAML-based corrections
    if isinstance(correction_source, str):
        if not MC_production:
            raise ValueError("MC_production is required for YAML-based corrections")
        corr_type = read_yaml(correction_source, MC_production)["correction_type"]
        # Look the class up separately so that a KeyError raised while
        # *constructing* the correction is not misreported as an unknown type.
        try:
            correction_class = correction_types[corr_type]
        except KeyError:
            raise NotImplementedError(
                f"Available corrections are: {list(correction_types.keys())} but you passed {corr_type}"
            ) from None
        return correction_class(
            systematic=correction_source, MC_production=MC_production
        )

    # Custom corrections
    if isinstance(correction_source, dict):
        return CustomCorrection(info=correction_source)

    raise ValueError(
        "Pass a string for existing standard systematic to create a correction object from yaml files, "
        "a dictionary to create a custom correction object, or a path to a .csv file for CSV-based corrections"
    )
[docs]
class CorrectionTableFinder:
    """
    Factory class that locates and loads the PID correction tables
    (efficiency and fake rate) for kaons, pions, electrons and muons.
    """

    def __init__(
        self, particle_species, online_cut, base_table_path, variable, MC_production
    ):
        self.particle_species = particle_species
        self.online_cut = online_cut
        self.base_table_path = base_table_path
        self.variable = variable
        self.MC_production = MC_production
        # Must run before the settings are read: for hadrons it rewrites the
        # table ids according to the MC production.
        self.production_table_ids()
        species_settings = self.particle_species_settings[particle_species]
        self.true_pdg = species_settings["true_pdgs"]
        self.fake_pdg = species_settings["fake_pdgs"]
        # Cut value and type are needed to compose the hadron table file names.
        self.value = self.get_cut_value()
        self.cut_type = self.get_cut_type()
        eff_names, fake_rate_names = (
            self.build_table_name(species_settings[key])
            for key in ("eff_table_ids", "fake_rate_table_ids")
        )
        self.eff_table = self.get_table(eff_names)
        self.fake_rate_table = self.get_table(fake_rate_names)
[docs]
@classmethod
def kaons(cls, external_info, MC_production):
particle_species = "kaon"
return cls(
particle_species=particle_species,
online_cut=external_info["online_cut"],
base_table_path=external_info["table_paths"],
variable=None,
MC_production=MC_production,
)
[docs]
@classmethod
def pions(cls, external_info, MC_production):
particle_species = "pion"
return cls(
particle_species=particle_species,
online_cut=external_info["online_cut"],
base_table_path=external_info["table_paths"],
variable=None,
MC_production=MC_production,
)
[docs]
@classmethod
def electrons(cls, external_info, MC_production):
particle_species = "elec"
return cls(
particle_species=particle_species,
online_cut=external_info["online_cut"],
base_table_path=external_info["table_paths"],
variable=external_info["variable"],
MC_production=MC_production,
)
[docs]
@classmethod
def Kfake_electrons(cls, external_info, MC_production):
particle_species = "Kfake_elec"
return cls(
particle_species=particle_species,
online_cut=external_info["online_cut"],
base_table_path=external_info["table_paths"],
variable=external_info["variable"],
MC_production=MC_production,
)
[docs]
@classmethod
def pifake_electrons(cls, external_info, MC_production):
particle_species = "pifake_elec"
return cls(
particle_species=particle_species,
online_cut=external_info["online_cut"],
base_table_path=external_info["table_paths"],
variable=external_info["variable"],
MC_production=MC_production,
)
[docs]
@classmethod
def muons(cls, external_info, MC_production):
particle_species = "muon"
return cls(
particle_species=particle_species,
online_cut=external_info["online_cut"],
base_table_path=external_info["table_paths"],
variable=external_info["variable"],
MC_production=MC_production,
)
[docs]
@classmethod
def Kfake_muons(cls, external_info, MC_production):
particle_species = "Kfake_muon"
return cls(
particle_species=particle_species,
online_cut=external_info["online_cut"],
base_table_path=external_info["table_paths"],
variable=external_info["variable"],
MC_production=MC_production,
)
[docs]
@classmethod
def pifake_muons(cls, external_info, MC_production):
particle_species = "pifake_muon"
return cls(
particle_species=particle_species,
online_cut=external_info["online_cut"],
base_table_path=external_info["table_paths"],
variable=external_info["variable"],
MC_production=MC_production,
)
[docs]
def production_table_ids(self):
"""
Updates hadron table IDs based on the MC production type.
Raises:
ValueError: If an unsupported MC production type is provided.
"""
if self.MC_production not in ["MC15ri", "MC15rd"]:
raise ValueError("Invalid production type. Must be 'MC15ri' or 'MC15rd'.")
eff_table_mapping = {
"MC15ri": {"kaon": "keff", "pion": "pieff"},
"MC15rd": {"kaon": "KEff", "pion": "piEff"},
}
fake_table_mapping = {
"MC15ri": {"kaon": "piFk", "pion": "kFpi"},
"MC15rd": {"kaon": "piFakeK", "pion": "KFakepi"},
}
if self.particle_species in ["kaon", "pion"]:
self.particle_species_settings[self.particle_species]["eff_table_ids"] = [
eff_table_mapping[self.MC_production][self.particle_species]
]
self.particle_species_settings[self.particle_species][
"fake_rate_table_ids"
] = [fake_table_mapping[self.MC_production][self.particle_species]]
@property
def particle_species_settings(self) -> dict:
if not hasattr(self, "_particle_species_settings"):
self._particle_species_settings = {
"kaon": {
"true_pdgs": [321, -321],
"fake_pdgs": [211, -211],
"eff_table_ids": ["keff"],
"fake_rate_table_ids": ["piFk"],
},
"pion": {
"true_pdgs": [211, -211],
"fake_pdgs": [321, -321],
"eff_table_ids": ["pieff"],
"fake_rate_table_ids": ["kFpi"],
},
"elec": {
"true_pdgs": [11, -11],
"fake_pdgs": [321, 211, -321, -211],
"eff_table_ids": ["e_efficiency"],
"fake_rate_table_ids": [
"K_e_fakeRate",
"pi_e_fakeRate",
],
},
"Kfake_elec": {
"true_pdgs": [11, -11],
"fake_pdgs": [
321,
-321,
],
"eff_table_ids": ["e_efficiency"],
"fake_rate_table_ids": ["K_e_fakeRate"],
},
"pifake_elec": {
"true_pdgs": [11, -11],
"fake_pdgs": [
211,
-211,
],
"eff_table_ids": ["e_efficiency"],
"fake_rate_table_ids": ["pi_e_fakeRate"],
},
"muon": {
"true_pdgs": [13, -13],
"fake_pdgs": [321, 211, -321, -211],
"eff_table_ids": ["mu_efficiency"],
"fake_rate_table_ids": [
"K_mu_fakeRate",
"pi_mu_fakeRate",
],
},
"Kfake_muon": {
"true_pdgs": [13, -13],
"fake_pdgs": [321, -321],
"eff_table_ids": ["mu_efficiency"],
"fake_rate_table_ids": ["K_mu_fakeRate"],
},
"pifake_muon": {
"true_pdgs": [13, -13],
"fake_pdgs": [
211,
-211,
],
"eff_table_ids": ["mu_efficiency"],
"fake_rate_table_ids": ["pi_mu_fakeRate"],
},
}
return self._particle_species_settings
[docs]
def get_cut_type(
self,
) -> str:
"""Reads the yaml configuration file and extracts the cut type that have been applied in the online reconstuction
Args:
species: Particle species, should be K+ or pi+
Returns:
the cut type that has been applied online. Binary or global
"""
# Read the online selections that have been applied on the online reconstruction
# TODO update this to the config file of each experiment to avoid making the mistake of
# changing the value during the offline preproccesing
if "elec" in self.particle_species or "muon" in self.particle_species:
cut_type = self.online_cut
elif self.particle_species in ["kaon", "pion"]:
if "binaryPID" in self.online_cut:
cut_type = "B"
elif "ID" in self.online_cut:
cut_type = "G"
else:
logging.warning(
"Cut type, neither global, nor binary HID selection has been applied online"
)
else:
raise ValueError("Wrong particle species")
return cut_type
[docs]
def get_cut_value(self) -> str:
"""Reads the yaml configuration file and extracts the cut type that have been applied in the online reconstuction
Args:
species: Particle species, should be K+ or pi+
Returns:
the cut type that has been applied online. Binary or global
"""
# Read the online selections that have been applied on the online reconstruction
# TODO update this to the config file of each experiment to avoid making the mistake of
# changing the value during the offline preproccesing
if self.particle_species in ["muon", "elec"]:
cut_value = self.online_cut[-1]
elif self.particle_species in ["kaon", "pion"]:
cut_value = self.online_cut[-3:]
return self.online_cut[-1]
[docs]
def build_table_name(
self,
table_ids: str,
) -> list:
"""Builds the efficiency and fake rate tables path names
Args:
table_ids: efficiency or fake table id
Returns:
list with the efficiency or fake rate table file names
"""
# Create the file names.
# These are both for plus and minus
if self.particle_species in ["kaon", "pion"]:
# Now add thhe names for negative charge
file_names.extend([self.build_hid_table_name(x, "m") for x in table_ids])
# elif self.particle_species in ["elec", "muon"]:
elif "elec" in self.particle_species or "muon" in self.particle_species:
file_names = ["_".join((x, "table.csv")) for x in table_ids]
return [path.join(self.base_table_path, x) for x in file_names]
[docs]
def build_hid_table_name(self, table_id, charge):
if self.MC_production == "MC15ri":
table_name = "_".join(
(
"Rdtmc",
table_id,
charge,
self.cut_type + "0-" + str(self.value)[-1],
"all.log",
)
)
elif self.MC_production == "MC15rd":
table_name = "_".join(
(
"MC15rd",
table_id,
"all",
self.cut_type + "0-" + str(self.value)[-1] + ".log",
)
)
return table_name
[docs]
def get_table(self, table_names):
if self.particle_species in ["kaon", "pion"]:
table = concat([read_csv(x) for x in table_names])
self.make_pidvar_compatible(self.MC_production, table, max_uncertainty=10)
elif "elec" in self.particle_species or "muon" in self.particle_species:
# elif self.particle_species in ["elec", "muon"]:
table = concat([read_csv(x) for x in table_names])
table.query(self.get_lid_queries(), inplace=True)
self.add_mcPDG_to_table(table)
return table
[docs]
def add_mcPDG_to_table(self, table):
table.loc[:, "mcPDG"] = -9999
table.loc[table["charge"] == "-", "mcPDG"] = self.true_pdg[0]
table.loc[table["charge"] == "+", "mcPDG"] = self.true_pdg[1]
[docs]
@staticmethod
def make_pidvar_compatible(
MC_production: str, table: DataFrame, max_uncertainty: Optional[float] = 1e2
):
"""
Convert the pandas dataframes obtained Hadron ID CSV tables via into a
format consistent with the format of the lepton ID tables which ``PIDvar``
understands.
In particular, convert the ``charge`` column from ``1``/``-1`` integer
entries to ``+``/``-`` string entries and calculate the
``theta_min``/``theta_max`` columns.
:param table: Pandas dataframe obtained from ``pandas.read_csv`` on HID table
:param inplace: Whether to modify the existing dataframe in place.
Otherwise, a copy of the existing dataframe will be returned.
:param max_uncertainty: Drop rows in HID tables where any of the data-MC
uncertainties (sys/stat up/down) exceed this value. Rationale: The HID
tables contain rows with nonsense uncertainties > 10⁸, so it is meant to
remove those entries. Therefore, the exact value is not important. Set
to ``None`` to disable dropping any columns.
:return: Modified dataframe that can be used by ``PIDvar``.
"""
# Some checks that table has expected format of Hadron ID tables
if MC_production == "MC15ri":
if not set(table["charge"]).issubset({1, -1}):
raise ValueError(
"Expected that the ``charge`` entries of the original Hadron ID dataframe consists"
+ "only of ``1`` and ``-1``, but it contains {}".format(
set(table["charge"])
)
)
# if "theta_min" in table or "theta_max" in table:
# raise ValueError("Dataframe already has ``theta_…`` columns")
if max_uncertainty is not None:
unc_cols = [
"data_MC_uncertainty_stat_up",
"data_MC_uncertainty_stat_dn",
"data_MC_uncertainty_sys_up",
"data_MC_uncertainty_sys_dn",
]
# table = table[table[unc_cols].max(axis=1) <= max_uncertainty]
# for θ in [0, π], cos(θ) is strictly decreasing, so we have invert min and max when inverting the cosine
if MC_production == "MC15ri":
table["theta_min"] = -9999
table["theta_max"] = -9999
table.loc[:, "theta_min"] = np.arccos(table["cos_max"].copy(deep=True))
table.loc[:, "theta_max"] = np.arccos(table["cos_min"].copy(deep=True))
elif MC_production == "MC15rd":
pass
# PIDvar expects charge columns to contain + or -
if MC_production == "MC15ri":
table.loc[:, "charge"] = np.where(table["charge"] == +1, "+", "-")
elif MC_production == "MC15rd":
table.loc[:, "charge"] = np.where(table["charge_min"] == -2, "-", "+")
return table
[docs]
def get_lid_queries(self):
working_point = f"(working_point == '{self.cut_type}')"
best_available = "(is_best_available == True)"
if "elec" in self.particle_species:
exclude_bins = "(not ((theta_min == 0.56 and theta_max == 2.23) or (theta_min == 0.22 and theta_max == 2.71) or (p_min == 0.2 and p_max == 7) or (p_min == 0.2 and p_max == 5)))"
elif "muon" in self.particle_species:
exclude_bins = "(not ((theta_min == 0.82 and theta_max == 2.22) or (theta_min == 0.4 and theta_max == 0.82) or (theta_min == 0.4 and theta_max == 2.6) or (p_min == 0.2 and p_max == 5)))"
variable = f"(variable == '{self.variable}')"
# PATCH exclude negative and extremely large weights
physical_values = "(0 < data_MC_ratio < 10)"
logging.warning(
f"If negative weights or extremely large weights are present in the table, these will be excluded. The arbitrarily selected 'physical range' is {physical_values}"
)
total_cutstring = " and ".join(
(working_point, best_available, exclude_bins, variable, physical_values)
)
logging.info(
f"The following cutstring has been applied to the provide LID table: {total_cutstring}"
)
return total_cutstring