# Source code for sysvar.corrections — correction classes and factory helpers.

from __future__ import annotations

from functools import cached_property

from abc import ABC, abstractmethod
from dataclasses import dataclass, field, InitVar
from typing import List, Iterable, Optional
from os import path
import matplotlib.pyplot as plt

from particle import Particle

import itertools
import numpy as np
from pandas import DataFrame, concat, read_csv
from uncertainties import unumpy as unp, ufloat

from .uncertainties import (
    Uncertainty,
    FullyCorrelatedUncertainty,
    FullyCorrelatedUncertaintyInParts,
    UncorrelatedUncertainty,
    ExplicitlyCorrelatedUncertainty,
    get_uncertainty_types,
)
from sysvar.utils import SavableAttributesObject, read_yaml, load_covariance_matrix
from sysvar.visualize import CorrectionVisualizer, UncertaintyVisualizer

import logging

# Configure root logging: show level, function name and line number with each
# message. NOTE(review): calling basicConfig at import time in a library
# module reconfigures logging for the whole application — consider moving
# this to the application entry point.
logging.basicConfig(
    format="%(levelname)s : %(funcName)s: %(lineno)d :  %(message)s",
    level=logging.INFO,
)


class MissingInformationError(Exception):
    """Raised when required configuration information was not provided."""
class UncertaintyWithSameNameExists(Exception):
    """Raised when an uncertainty name is registered on a correction twice."""
class UnkownUncertaintyType(Exception):
    """Raised for an uncertainty type that is not implemented.

    NOTE: the misspelling ("Unkown") is kept — the name is part of the
    public interface and callers may catch it by this name.
    """
class NotValidRateType(Exception):
    """Raised when a PID rate argument is neither 'eff' nor 'fake'."""
@dataclass()
class BaseCorrection(ABC, SavableAttributesObject):
    """Abstract base class for all corrections.

    Holds the mapping of uncertainty name -> :class:`Uncertainty` object and
    provides the shared machinery for adding/populating uncertainties and for
    plotting them. Concrete subclasses must implement ``visual_labels`` and
    ``build_queries`` and are expected to set ``self.cov_matrix``,
    ``self.central_values`` and ``self.info`` before the uncertainty helpers
    are used (TODO confirm every subclass guarantees this).
    """

    # name -> Uncertainty object
    uncertainties: dict = field(default_factory=dict)

    @property
    @abstractmethod
    def visual_labels(self):
        """Human-readable labels, one per correction bin."""

    @abstractmethod
    def build_queries(self) -> list:
        """Build the DataFrame query strings, one per correction bin."""

    @staticmethod
    def _build_column_name(prefix: str | None, variable: str) -> str:
        """Construct a column name by combining a prefix and a variable name.

        Args:
            prefix (str | None): Optional string to prepend to the variable
                name. If ``None``, no prefix is added.
            variable (str): The variable name.

        Returns:
            str: ``"prefix_variable"``, or just ``"variable"`` when no prefix
            is provided.

        Raises:
            ValueError: If ``variable`` is not a str, or ``prefix`` is neither
                a str nor ``None``.

        Examples:
            >>> BaseCorrection._build_column_name("prefix", "variable")
            'prefix_variable'
            >>> BaseCorrection._build_column_name(None, "variable")
            'variable'
        """
        if not isinstance(variable, str):
            # Fixed error message: original said "by you passed".
            raise ValueError(
                f"{variable} is expected to be str but you passed {type(variable)}"
            )
        if not isinstance(prefix, str) and prefix is not None:
            raise ValueError(
                f"{prefix} is expected to be str but you passed {type(prefix)}"
            )
        elif prefix is None:
            column_name = variable
        else:
            column_name = "_".join([prefix, variable])
        return column_name

    @property
    def N(self) -> int:
        """Number of correction bins."""
        return len(self.central_values)

    @property
    def total_error(self) -> np.ndarray:
        """Per-bin quadrature sum of all registered uncertainty amplitudes.

        Raises:
            ValueError: If no uncertainty has been added yet.
        """
        if not self.uncertainties:
            raise ValueError("No uncertainties have been added to the correction")
        if len(self.uncertainties) == 1:
            # Single source: no summation needed.
            return next(iter(self.uncertainties.values())).errors
        return np.sqrt(
            np.sum(
                [np.power(x.errors, 2) for x in self.uncertainties.values()], axis=0
            )
        )

    def add_uncertainty(
        self,
        unc_name,
        unc_values,
        unc_obj: Uncertainty,
        explicit_cov_matrix: Optional[np.ndarray] = None,
    ) -> None:
        """Add a named uncertainty to the correction.

        Args:
            unc_name: Unique name of the uncertainty.
            unc_values: Per-bin error amplitudes.
            unc_obj (Uncertainty): The uncertainty class to instantiate.
            explicit_cov_matrix (np.ndarray, optional): Covariance matrix,
                forwarded to the uncertainty object when ``self.cov_matrix``
                is set.

        Raises:
            UncertaintyWithSameNameExists: If an uncertainty with the same
                name has already been added.
        """
        if unc_name in self.uncertainties:
            # Fixed duplicated word ("that that") in the original message.
            raise UncertaintyWithSameNameExists(
                f"An uncertainty with the name {unc_name} already exist in the set of uncertainties that have been added to the Correction. Make sure that you add a specific uncertainty only once, and that there are no duplicate names"
            )
        # NOTE(review): ``cov_matrix`` is set by subclasses, not declared on
        # this base class — TODO confirm all concrete subclasses assign it
        # before this method runs.
        if self.cov_matrix is not None:
            # A covariance matrix was provided: forward it so the uncertainty
            # object can model the explicit correlations.
            self.uncertainties[unc_name] = unc_obj(
                unc_name,
                unc_values,
                self.visual_labels,
                explicit_cov_matrix,
            )
        else:
            self.uncertainties[unc_name] = unc_obj(
                unc_name, unc_values, self.visual_labels
            )

    def populate_uncertainties(self):
        """Populate ``self.uncertainties`` from the available configuration.

        If a covariance matrix (``cov_matrix``) is available, a single
        :class:`ExplicitlyCorrelatedUncertainty` is created from it.
        Otherwise the uncertainty types/values declared under
        ``self.info["uncertainties"]`` are validated against
        :func:`get_uncertainty_types` and added one by one.

        Raises:
            UnkownUncertaintyType: If an unsupported uncertainty type is found
                in the configuration.
        """
        if self.cov_matrix is not None:
            # Diagonal of the covariance matrix gives the per-bin errors.
            errors = np.sqrt(np.diag(self.cov_matrix))
            self.add_uncertainty(
                unc_name="explicit_covariance",
                unc_values=errors,
                unc_obj=ExplicitlyCorrelatedUncertainty,
                explicit_cov_matrix=self.cov_matrix,
            )
        else:
            # Load the implemented uncertainty types
            sysvar_uncertainties = get_uncertainty_types()
            for unc_ctgy, uncertainty_dictionary in self.info["uncertainties"].items():
                if unc_ctgy not in sysvar_uncertainties:
                    raise UnkownUncertaintyType(
                        f"Unkown type of uncertainty is declared in the yaml file. Available uncertainty types are: ({', '.join(list(sysvar_uncertainties.keys()))}) but '{unc_ctgy}' was found in the yaml file"
                    )
                for unc_name, unc_values in uncertainty_dictionary.items():
                    self.add_uncertainty(
                        unc_name=unc_name,
                        unc_values=unc_values,
                        unc_obj=sysvar_uncertainties[unc_ctgy],
                    )

    def plot_error_comparison(
        self, save: bool = False, filename: str = ""
    ) -> tuple[plt.Figure, plt.Axes]:
        """Plot all uncertainty sources of this correction side by side."""
        self.visualizer = CorrectionVisualizer(self)
        fig, ax = self.visualizer.plot_error_comparison(save=save, filename=filename)
        return fig, ax

    def plot_uncertainty(
        self, unc_name: str | None = None, save: bool = False, filename: str = ""
    ):
        """Plot covariance/correlation matrices of one (or all) uncertainties."""
        if unc_name is None:
            for unc_obj in self.uncertainties.values():
                self.visualizer = UncertaintyVisualizer(unc_obj)
                self.visualizer.plot_cov_and_corr(save=save, filename=filename)
        else:
            self.visualizer = UncertaintyVisualizer(self.uncertainties[unc_name])
            self.visualizer.plot_cov_and_corr(save=save, filename=filename)
@dataclass()
class BaseCorrectionFromYaml(BaseCorrection):
    """Base class for corrections whose configuration is read from a yaml file.

    ``__post_init__`` loads the yaml entry for the given systematic effect and
    MC production, then exposes the table-related configuration through
    read-only properties.
    """

    # Name of the systematic effect; selects the yaml configuration entry.
    systematic: str | None = None
    # MC production identifier used to look up the configuration.
    MC_production: str | None = None

    def __post_init__(self):
        super().__init__()
        try:
            self.info = read_yaml(self.systematic, self.MC_production)
        except TypeError:
            # presumably read_yaml raises TypeError when an argument is None
            # — TODO confirm against sysvar.utils.read_yaml.
            raise MissingInformationError(
                f"Need to specify the systematic effect and the MC production in the positional arguments. You passed {self.systematic} and {self.MC_production}"
            )
        self.visualizer = None
        self.title = self.info["title"]
        # May be None when no explicit covariance matrix is configured.
        self.cov_matrix = load_covariance_matrix(self.info)

    def add_extra_cuts(self, queries: list, prefix: str) -> list:
        """Append the configured extra cuts (if any) to every query string.

        Args:
            queries: Query strings, one per correction bin.
            prefix: Optional column-name prefix for the extra-cut variables.

        Returns:
            The queries, each extended with the yaml-configured extra cuts
            (unchanged when no extra cuts are configured).
        """
        if self._get_extra_cut_info() is not None:
            # Add the prefix to the extra cut info
            for var, values in self._get_extra_cut_info().items():
                extra_cut = self._build_column_name(prefix, f"{var} in {values}")
                queries = self._extend_queries_with_extra_cut(queries, extra_cut)
            return queries
        else:
            return queries

    @staticmethod
    def _extend_queries_with_extra_cut(queries: list, extra_cut: str) -> list:
        """Extends a list of queries by appending an extra condition to each query.

        Args:
            queries (list): A list of query strings.
            extra_cut (str): An additional condition to be appended to each query.

        Returns:
            list: A new list of query strings with the extra condition appended.

        Example:
            >>> _extend_queries_with_extra_cut(['query1', 'query2'], 'extra_condition')
            ['query1 & extra_condition', 'query2 & extra_condition']
        """
        return [" & ".join([q, extra_cut]) for q in queries]

    def _get_extra_cut_info(self):
        # Raw "extra_cuts" entry from the yaml configuration; may be None.
        return self.info["extra_cuts"]

    @property
    def table_dir(self) -> str:
        """
        Returns the directory path where the tables are stored.

        Returns:
            str: The directory path where the tables are stored.
        """
        return self.info["corrections"]["table_dir"]

    @property
    def table_name(self) -> str:
        """
        Returns the table name by reading it from the config file.

        Returns:
            str: The table name.
        """
        return self.info["corrections"]["table_name"]

    @property
    def table_ext(self) -> str:
        """
        Returns the table extension by reading it from the config file.

        Returns:
            str: The table extension.
        """
        return self.info["corrections"]["table_ext"]

    @property
    def table_key(self) -> str:
        """
        Returns the table key by reading it from the config file.

        Returns:
            str: The table key.
        """
        return self.info["corrections"]["table_key"]

    def build_table_path(self, suffix: Optional[str] = None) -> str:
        """
        Builds the path for a table file based on the given suffix.

        Args:
            suffix (str): The suffix to be added to the table name.

        Returns:
            str: The full path of the table file.

        Raises:
            ValueError: If the suffix is not a string, or the table extension
                is unknown (only "txt" and "csv" are supported).
        """
        if suffix is not None:
            if not isinstance(suffix, str):
                raise ValueError(f"Suffix must be a string, but got {type(suffix)}")
            else:
                base_name = f"{self.table_name}_{suffix}"
        else:
            base_name = self.table_name
        if self.table_ext in ["txt", "csv"]:
            filename = ".".join((base_name, self.table_ext))
        else:
            raise ValueError(
                f"Unknown table extension {self.table_ext}. Supported extensions are: txt, csv"
            )
        return path.join(self.table_dir, filename)
@dataclass
class Correction1D(BaseCorrectionFromYaml):
    """Correction binned in a single dependant variable.

    Central values come either directly from the yaml configuration or from a
    csv table; bin boundaries come from the yaml ``min``/``max`` entries.
    """

    dependant_variable: str | None = None
    central_values: Iterable = None
    lower_bounds: Iterable = None
    upper_bounds: Iterable = None

    def __post_init__(self):
        super().__post_init__()
        self.central_values = self.read_corrections()
        self.dependant_variable = self.info["dependant_variable"]
        self.unit = self.info["unit"]
        self.lower_bounds = self.info["min"]
        self.upper_bounds = self.info["max"]
        self.populate_uncertainties()

    def read_corrections(self) -> np.ndarray:
        """Read correction values either directly from the config or from a table.

        If the ``corrections`` entry is a list/array, the values are used as
        given; if it is a dict, the csv table at :meth:`build_table_path` is
        read and the ``table_key`` column extracted.

        Returns:
            np.ndarray: Array of correction values.

        Raises:
            ValueError: If the ``corrections`` entry is neither a list/array
                nor a dict.
            KeyError: If ``corrections`` or the table key is missing.
            FileNotFoundError: If the table file does not exist.
        """
        correction_info = self.info["corrections"]
        if isinstance(correction_info, (list, np.ndarray)):
            logging.info("Loading correction values from config array.")
            return np.asarray(correction_info)
        elif isinstance(correction_info, dict):
            logging.info("Loading correction values from config table.")
            table_path = self.build_table_path()
            table = read_csv(table_path)
            # Convert the table to a numpy array
            return np.asarray(table[self.table_key])
        # Previously this fell through and returned None, deferring the
        # failure to some confusing place downstream; fail loudly instead.
        raise ValueError(
            f"'corrections' must be a list/array or a dict, but got {type(correction_info)}"
        )

    @property
    def value_edges(self) -> np.ndarray:
        """Unique, sorted bin edges of the dependant variable."""
        return np.unique(np.concatenate((self.lower_bounds, self.upper_bounds)))

    @property
    def value_mids(self) -> np.ndarray:
        """Bin centers of the dependant variable."""
        return (self.value_edges[1:] + self.value_edges[:-1]) / 2

    @property
    def visual_labels(self) -> List[str]:
        """One "low < var < up unit" label per bin."""
        return [
            f"{low} < {self.dependant_variable} < {up} {self.unit}"
            for low, up in zip(self.lower_bounds, self.upper_bounds)
        ]

    def build_queries(self, prefix: str | None = None) -> list:
        """One half-open interval query per bin, plus any configured extra cuts."""
        column_name = self._build_column_name(prefix, self.dependant_variable)
        queries = [
            f"{low} <= {column_name} < {up}"
            for low, up in zip(self.lower_bounds, self.upper_bounds)
        ]
        queries = self.add_extra_cuts(queries, prefix)
        return queries
@dataclass
class Correction2D(BaseCorrectionFromYaml):
    """Correction binned in two dependant variables (momentum and angle),
    read from "nom"/"stat"/"sys" csv tables.
    """

    # NOTE(review): re-declares the field already defined on BaseCorrection.
    uncertainties: dict = field(default_factory=dict)

    def __post_init__(self):
        super().__post_init__()
        self.dependant_variable_1 = self.info["dependant_variable_1"]
        self.dependant_variable_2 = self.info["dependant_variable_2"]
        self.unit_1 = self.info["unit_1"]
        self.unit_2 = self.info["unit_2"]
        self.central_values = self._extract_central_values()
        self.populate_uncertainties()

    # Add an iterator to ensure that we'll loop over the corrections and bins
    # consistently
    @property
    def iterator(self):
        # Parses the table headers/index into (row, column, momentum-range,
        # angle-range) tuples. Assumes column labels like "p=a_b" (in units of
        # 0.1, hence the /10 — TODO confirm) and index labels like
        # "cost=a_b"; the first column additionally carries a
        # " row:p column:t " marker that is stripped.
        rows, columns, momenta, angles = [], [], [], []
        for i, column_name in enumerate(self.central_values_table.columns):
            for j, row_name in enumerate(self.central_values_table.index):
                # clean the pi0 tables.... What a format...
                if i == 0:
                    column_name = column_name.replace(" row:p column:t ", "")
                # strip the strings and extract the momentum range
                column_name = column_name.replace("p=", "")
                ps = [float(x) / 10 for x in column_name.split("_")]
                # strip the strings and extract the theta range
                row_name = row_name.replace("cost=", "")
                ts = [float(x) for x in row_name.split("_")]
                rows.append(j)
                columns.append(i)
                momenta.append(ps)
                angles.append(ts)
        # Return a zip iterator. Now we can access all the central values and
        # errors using iloc and the rows/columns.
        return zip(rows, columns, momenta, angles)

    @cached_property
    def central_values_table(self) -> DataFrame:
        # Nominal correction values.
        table_path = self.build_table_path("nom")
        return read_csv(table_path)

    @cached_property
    def stat_error_table(self) -> DataFrame:
        # Statistical errors, same binning as the nominal table.
        table_path = self.build_table_path("stat")
        # Add column names when reading to skip creation of index column
        return read_csv(table_path, names=[f"p bin {i}" for i in range(8)])

    @cached_property
    def sys_error_table(self) -> DataFrame:
        # Systematic errors, same binning as the nominal table.
        table_path = self.build_table_path("sys")
        # Add column names when reading to skip creation of index column
        return read_csv(table_path, names=[f"p bin {i}" for i in range(8)])

    @property
    def visual_labels(self) -> List[str]:
        """One "p-range & angle-range" label per 2D bin."""
        return [
            f"{momenta[0]} <= {self.dependant_variable_1} < {momenta[1]} {self.unit_1} & {angles[0]} <= {self.dependant_variable_2} < {angles[1]} {self.unit_2}"
            for r, c, momenta, angles in self.iterator
        ]

    def _extract_central_values(self):
        # Flatten the 2D table into a list, in iterator order.
        return [
            self.central_values_table.iloc[row, column]
            for row, column, ps, ths in self.iterator
        ]

    def _extract_errors(self, table: DataFrame):
        # Flatten an error table into a list, in the same iterator order.
        return [table.iloc[row, column] for row, column, ps, ths in self.iterator]

    def build_queries(self, prefix: str | None = None) -> list:
        """One 2D half-open interval query per bin, plus any extra cuts."""
        column_name_1 = self._build_column_name(prefix, self.dependant_variable_1)
        column_name_2 = self._build_column_name(prefix, self.dependant_variable_2)
        queries = [
            f"{momenta[0]} <= {column_name_1} < {momenta[1]} & {angles[0]} <= {column_name_2} < {angles[1]}"
            for r, c, momenta, angles in self.iterator
        ]
        queries = self.add_extra_cuts(queries, prefix)
        return queries

    def populate_uncertainties(self):
        """Register stat/sys uncertainties with the correlation types declared
        under "error_correlations" in the yaml.

        Raises:
            NotImplementedError: For any name other than "stat" or "sys".
        """
        sysvar_uncertainties = get_uncertainty_types()
        for unc_name, unc_ctgy in self.info["error_correlations"].items():
            if unc_name == "stat":
                unc_values = self._extract_errors(self.stat_error_table)
            elif unc_name == "sys":
                unc_values = self._extract_errors(self.sys_error_table)
            else:
                raise NotImplementedError(
                    "Only stat and sys error have been implemented now"
                )
            self.add_uncertainty(
                unc_name=unc_name,
                unc_values=unc_values,
                unc_obj=sysvar_uncertainties[unc_ctgy],
            )
@dataclass
class Correction2DCategorical(BaseCorrectionFromYaml):
    """Correction binned in one categorical and one continuous variable.

    NOTE(review): this class does not implement the abstract
    ``visual_labels``/``build_queries`` members of BaseCorrection (it defines
    ``strings``/``queries`` instead), and ``queries`` calls ``_get_extra_cut``
    which is not defined anywhere in this file — confirm whether this class is
    actually instantiable and used.
    """

    categorical_variable: str | None = None
    # NOTE: "continuus" spelling kept — it matches the yaml keys used below.
    continuus_variable: str | None = None
    central_values: Iterable = None
    continuus_edges: Iterable = None
    categorical_values: Iterable = None
    categorical_label: str | None = None

    def __post_init__(self):
        super().__post_init__()
        self.central_values = []
        part_dimensions = []
        # Flatten the per-category correction lists into one flat list and
        # record each category's length.
        for correction in self.info["corrections"]:
            self.central_values.extend(correction)
            part_dimensions.append(len(correction))
        self.categorical_variable = self.info["categorical_variable"]
        self.categorical_values = self.info["categorical_values"]
        self.categorical_label = self.info["categorical_label"]
        self.continuus_variable = self.info["continuus_variable"]
        self.continuus_edges = self.info["continuus_edges"]
        self.extra_variables = self.info["extra_variables"]

    @property
    def iterator(self):
        # Cartesian product of the categories with consecutive edge pairs of
        # the continuous variable.
        return itertools.product(
            self.categorical_values, zip(self.continuus_edges, self.continuus_edges[1:])
        )

    @property
    def strings(self) -> List[str]:
        """One human-readable label per (category, continuous-bin) pair."""
        return [
            f"{self.categorical_label}: {cv} [ {low} - {up} ]"
            for cv, (low, up) in self.iterator
        ]

    @property
    def queries(self):
        # NOTE(review): ``_get_extra_cut`` is undefined (the base class
        # provides ``_get_extra_cut_info``) — this property will raise
        # AttributeError when accessed; verify intent before using.
        return [
            f"{self.categorical_variable} == {cv} & {low} <= {self.continuus_variable} < {up} & {self._get_extra_cut()}"
            for cv, (low, up) in self.iterator
        ]
@dataclass
class CorrectionBF(BaseCorrectionFromYaml):
    """Branching-fraction correction.

    Scales each configured decay mode by the ratio of the PDG branching
    fraction to the one used in the decay file, with the PDG uncertainty
    propagated as the error amplitude.
    """

    dependant_variable: str | None = None
    central_values: Iterable = None
    visual_labels: Iterable = None
    # NOTE(review): re-declares the field already defined on BaseCorrection.
    uncertainties: dict = field(default_factory=dict)

    def __post_init__(self):
        super().__post_init__()
        central_values, error_amplitudes = self._calculate_scaling_ratios()
        self.central_values = central_values
        # Visual labels needs to be defined before we populate the uncertainties
        # Otherwise the uncertainty object does not have visual_labels
        self.visual_labels = self._create_strings()
        self.populate_uncertainties(error_amplitudes)
        self.dependant_variable = self.info["dependant_variable"]

    def _create_strings(self) -> List[str]:
        """Build one LaTeX decay-string label per configured decay mode."""
        mother = Particle.from_pdgid(self.info["mother_particle"]).latex_name
        daughter_pdgs = [
            x for mode in self.info["modes"].values() for x in mode["daughters"]
        ]
        strings = []
        for daughter_set in daughter_pdgs:
            daughter_names = []
            for x in daughter_set:
                try:
                    daughter_names.append(Particle.from_pdgid(x).latex_name)
                except Exception:
                    # Fall back to the raw identifier when it is not a valid
                    # PDG id. (Was a bare ``except:``, which also swallowed
                    # KeyboardInterrupt/SystemExit.)
                    daughter_names.append(x)
            strings.append(
                rf"${mother} \rightarrow {' '.join(str(x) for x in daughter_names)}$"
            )
        return strings

    def _calculate_scaling_ratios(self):
        """Compute PDG/decay-file BF ratios and their propagated errors."""
        pdg_BFs = unp.uarray(
            [x["pdg_live"][0] for x in self.info["modes"].values()],
            [x["pdg_live"][1] for x in self.info["modes"].values()],
        )
        decaydec_BFs = unp.uarray(
            [x["decay_dec"] for x in self.info["modes"].values()],
            [0 for x in self.info["modes"].values()],
        )
        # Safe 0 division for the modes where decay.dec = 0 or pdg = 0.
        # NOTE(review): the fallback ``out`` array is uarray(ones, ones),
        # i.e. 1 ± 1, while the original comment claimed 1 ± 0 — confirm
        # which is intended.
        corrections = np.divide(
            pdg_BFs,
            decaydec_BFs,
            out=unp.uarray(np.ones_like(pdg_BFs), np.ones_like(pdg_BFs)),
            where=((decaydec_BFs != 0) & (pdg_BFs != 0)),
        )
        return list(unp.nominal_values(corrections)), list(unp.std_devs(corrections))

    def populate_uncertainties(self, error_amplitudes: list):
        """
        Overrides the base-class method, as the error amplitudes are
        calculated dynamically by ``_calculate_scaling_ratios``.
        """
        unc_correlation = self.get_uncertainty_correlation()
        sysvar_uncertainties = get_uncertainty_types()
        self.add_uncertainty(
            unc_name="BF_unc",
            unc_values=error_amplitudes,
            unc_obj=sysvar_uncertainties[unc_correlation],
        )

    def get_uncertainty_correlation(self):
        """Validate and return the correlation type declared in the yaml.

        Raises:
            NotImplementedError: For "explicitly_correlated" (unsupported here).
            ValueError: For any other unknown correlation type.
        """
        if self.info["correlation"] in [
            "fully_correlated",
            "uncorrelated",
            "fully_correlated_in_parts",
        ]:
            unc_type = self.info["correlation"]
        elif self.info["correlation"] == "explicitly_correlated":
            raise NotImplementedError(
                "Cannot support custom correlation for BF correction yet"
            )
        else:
            raise ValueError(
                "Unkown correlation type. Available types are: fully_correlated, uncorrelated, fully_correlated_in_parts"
            )
        return unc_type

    def build_queries(self, prefix: str | None = None) -> list:
        """One query per decay mode on its decay-mode id (dmID)."""
        column_name = self._build_column_name(prefix, self.dependant_variable)
        queries = [
            (
                # String ids compare with equality, list ids with membership.
                f"{column_name} == '{mode['dmID']}'"
                if isinstance(mode["dmID"], str)
                else f"{column_name} in {mode['dmID']}"
            )
            for mode in self.info["modes"].values()
        ]
        queries = self.add_extra_cuts(queries, prefix)
        return queries
@dataclass
class CustomCorrection(BaseCorrection):
    """Correction built directly from an in-memory configuration dict
    (no yaml file involved)."""

    # Configuration dict; InitVar so it is handed to __post_init__ only.
    info: InitVar[dict] = None
    dependant_variable: str | None = None
    central_values: Iterable = None
    # NOTE(review): re-declares the field already defined on BaseCorrection.
    uncertainties: dict = field(default_factory=dict)
    query_targets: Iterable = None

    def __post_init__(self, info):
        # Keep the InitVar around as a regular attribute: the inherited
        # populate_uncertainties reads self.info.
        self.info = info
        self.dependant_variable = self.info["dependant_variable"]
        self.central_values = self.info["central_values"]
        self.query_targets = self.info["query_targets"]
        self.unit = self.info["unit"]
        self.title = self.info["title"]
        # May be None when no explicit covariance matrix is configured.
        self.cov_matrix = load_covariance_matrix(self.info)
        self.populate_uncertainties()

    @property
    def value_edges(self) -> np.ndarray:
        # Unit-width bins, one per central value.
        return np.arange(len(self.central_values) + 1)

    @property
    def value_mids(self) -> np.ndarray:
        # Bin centers of the synthetic unit-width binning.
        return (self.value_edges[1:] + self.value_edges[:-1]) / 2

    @property
    def visual_labels(self) -> List[str]:
        """One "var = target" label per query target."""
        return [
            f"{self.dependant_variable} = {target}" for target in self.query_targets
        ]

    def build_queries(self, prefix: str | None = None) -> list:
        """One string-equality query per target value of the dependant variable."""
        column_name = self._build_column_name(prefix, self.dependant_variable)
        queries = [f"{column_name} == '{target}'" for target in self.query_targets]
        return queries
@dataclass
class CorrectionPID(BaseCorrectionFromYaml):
    """PID efficiency / fake-rate correction built from PID performance tables."""

    # NOTE(review): re-declares the field already defined on BaseCorrection.
    uncertainties: dict = field(default_factory=dict)

    def __post_init__(self):
        super().__post_init__()
        rate = self.info["rate"]
        self.check_valid_rate(rate)
        self.table = self.get_table(rate)
        self.central_values = self._extract_central_values()
        self.p = self.info["momentum_variable"]
        self.theta = self.info["theta_variable"]
        self.PDG = self.info["PDG_variable"]
        self.mcPDG = self.info["mcPDG_variable"]
        self.momentum_unit = self.info["momentum_unit"]
        # Add uncertainties as fully uncorrelated. This is a conservative choice
        error_id = "stat"
        self.uncertainties.update(
            {
                f"{error_id} uncertainty": UncorrelatedUncertainty(
                    f"{error_id} uncertainty",
                    self._extract_errors(f"{error_id}"),
                    self.visual_labels,
                )
            }
        )
        # Systematic part is treated as fully correlated across the bins.
        error_id = "sys"
        self.uncertainties.update(
            {
                f"{error_id} uncertainty": FullyCorrelatedUncertainty(
                    f"{error_id} uncertainty",
                    self._extract_errors(f"{error_id}"),
                    self.visual_labels,
                )
            }
        )

    @staticmethod
    def check_valid_rate(rate):
        """Raise NotValidRateType unless ``rate`` is "eff" or "fake"."""
        valid_rates = ["eff", "fake"]
        if rate not in valid_rates:
            raise NotValidRateType(
                f"Valid rate arguments are {*valid_rates,} but you passed {rate}"
            )

    def get_table(self, rate):
        """Collect and concatenate the PID tables matching ``self.systematic``.

        Dispatches on the systematic name ("eID"/"muID"/"kID"/"piID" variants)
        to the matching :class:`CorrectionTableFinder` factory, then returns
        either the efficiency or fake-rate table depending on ``rate``.

        NOTE(review): if ``self.systematic`` matches none of the branches,
        ``table_finders`` stays empty and ``concat([])`` raises — confirm the
        yaml schema guarantees one of the recognised names.
        """
        table_finders = []
        if "eID" in self.systematic:
            if self.systematic == "eID_K_fake":
                table_finders.append(
                    CorrectionTableFinder.Kfake_electrons(self.info, self.MC_production)
                )
            elif self.systematic == "eID_pi_fake":
                table_finders.append(
                    CorrectionTableFinder.pifake_electrons(
                        self.info, self.MC_production
                    )
                )
            else:
                table_finders.append(
                    CorrectionTableFinder.electrons(self.info, self.MC_production)
                )
        elif "muID" in self.systematic:
            if self.systematic == "muID_K_fake":
                table_finders.append(
                    CorrectionTableFinder.Kfake_muons(self.info, self.MC_production)
                )
            elif self.systematic == "muID_pi_fake":
                table_finders.append(
                    CorrectionTableFinder.pifake_muons(self.info, self.MC_production)
                )
            else:
                table_finders.append(
                    CorrectionTableFinder.muons(self.info, self.MC_production)
                )
        elif "kID" in self.systematic:
            table_finders.append(
                CorrectionTableFinder.kaons(self.info, self.MC_production)
            )
        elif "piID" in self.systematic:
            table_finders.append(
                CorrectionTableFinder.pions(self.info, self.MC_production)
            )
        eff_table = concat([x.eff_table for x in table_finders])
        fake_rate_table = concat([x.fake_rate_table for x in table_finders])
        if rate == "eff":
            table = eff_table
            # PATCH
            # This has to be read somewhere else somehow
            self._true_pdg = table_finders[0].true_pdg
        elif rate == "fake":
            table = fake_rate_table
            # PATCH
            # This has to be read somewhere else somehow
            self._true_pdg = table_finders[0].fake_pdg
        return table

    @property
    def true_pdg(self) -> list:
        # PDG codes of the "true" particle hypothesis (set in get_table).
        return self._true_pdg

    @true_pdg.setter
    def true_pdg(self, true_pdg):
        self._true_pdg = true_pdg

    @property
    def iterator(self):
        # Fresh (index, row) iterator over the PID table on every access.
        return self.table.iterrows()

    @property
    def queries(self):
        # PATCH
        # This just "implements" the property to satisfy the parent class
        # (always returns None; use build_queries instead).
        pass

    def build_queries(
        self, prefix: str | None = None, extra_cut: str | None = None
    ) -> List[str]:
        """One (p, theta, PDG, mcPDG) query per row of the PID table.

        Args:
            prefix: Optional column-name prefix.
            extra_cut: Unused here — extra cuts come from the yaml via
                ``add_extra_cuts`` (NOTE(review): parameter appears dead).
        """
        # Pre-compute column names to avoid repeated function calls
        p_column_name = self._build_column_name(prefix, self.p)
        theta_column_name = self._build_column_name(prefix, self.theta)
        PDG_column_name = self._build_column_name(prefix, self.PDG)
        mcPDG_column_name = self._build_column_name(prefix, self.mcPDG)
        # Create a local reference for self._true_pdg to avoid repeated attribute access
        true_pdg = self._true_pdg
        # Use a list comprehension with cached lookups to improve performance
        queries = []
        append_query = queries.append  # Local function assignment for faster append
        # Use local variable access within the loop to speed up string formatting
        for _, row in self.iterator:
            # Access row[1] once and cache its values in local variables
            p_min = row["p_min"]
            p_max = row["p_max"]
            theta_min = row["theta_min"]
            theta_max = row["theta_max"]
            mcPDG = row["mcPDG"]
            # Construct the query string with reduced overhead
            query = (
                f"({p_min} <= {p_column_name} < {p_max} & "
                f"{theta_min} <= {theta_column_name} < {theta_max} & "
                f"{PDG_column_name} == {mcPDG} & "
                f"{mcPDG_column_name} in {true_pdg})"
            )
            # Append the constructed query string to the list
            append_query(query)
        # Add any extra cuts to the queries if needed
        queries = self.add_extra_cuts(queries, prefix)
        return queries

    @property
    def visual_labels(self) -> List[str]:
        """One "p-range & theta-range & charge" label per PID table row."""
        return [
            rf"{row[1]['p_min']} <= p < {row[1]['p_max']} {self.momentum_unit} & {row[1]['theta_min']} <= $\theta$ < {row[1]['theta_max']} & q = {row[1]['charge']}"
            for row in self.iterator
        ]

    def _extract_central_values(self):
        # data/MC ratio per table row, in iterator order.
        return [row[1]["data_MC_ratio"] for row in self.iterator]

    def _extract_errors(self, error_type):
        # this assumes symmetric errors and takes the maximum out of the two.
        return [
            row[1][
                [
                    f"data_MC_uncertainty_{error_type}_up",
                    f"data_MC_uncertainty_{error_type}_dn",
                ]
            ].max()
            for row in self.iterator
        ]
# #######################################################################################
def create_correction_object(
    syst_effect: str | dict | None, MC_prod: str
) -> BaseCorrection:
    """Retrieves and creates the appropriate correction object based on the
    systematic effect and MC production type.

    Args:
        syst_effect (str | dict | None): The systematic effect identifier
            (string → configuration is read from yaml) or a configuration
            dict (→ a CustomCorrection is built).
        MC_prod (str): The Monte Carlo production type identifier.

    Returns:
        BaseCorrection: The appropriate correction object.

    Raises:
        NotImplementedError: If the correction type specified in the
            configuration is not implemented.
        ValueError: If ``syst_effect`` is neither a string nor a dict.

    Example:
        >>> correction = create_correction_object("syst1", "MC1")
        >>> isinstance(correction, BaseCorrection)
        True
    """
    correction_types = {
        "1D": Correction1D,
        "2D": Correction2D,
        "2DCategorical": Correction2DCategorical,
        "BF": CorrectionBF,
        "PID": CorrectionPID,
    }
    if isinstance(syst_effect, str):
        corr_type = read_yaml(syst_effect, MC_prod)["correction_type"]
        # Check membership up front instead of wrapping the constructor call
        # in try/except KeyError: the old pattern also swallowed KeyErrors
        # raised *inside* __post_init__ (e.g. a missing yaml key) and
        # misreported them as an unknown correction type.
        if corr_type not in correction_types:
            raise NotImplementedError(
                f"Available corrections are: {list(correction_types.keys())} but you passed {corr_type}"
            )
        return correction_types[corr_type](
            systematic=syst_effect, MC_production=MC_prod
        )
    elif isinstance(syst_effect, dict):
        return CustomCorrection(info=syst_effect)
    else:
        raise ValueError(
            "Pass a string for existing standard systematic to create a correction object from yaml files or a dictionary to create a custom correction object"
        )
class CorrectionTableFinder:
    """
    Factory method class to get correction tables for kaons, pions, electrons
    and muons
    """

    def __init__(
        self, particle_species, online_cut, base_table_path, variable, MC_production
    ):
        # Raw configuration inputs.
        self.particle_species = particle_species
        self.online_cut = online_cut
        self.base_table_path = base_table_path
        self.variable = variable
        self.MC_production = MC_production
        # Adjust the hadron table ids to the requested MC production.
        self.production_table_ids()
        self.true_pdg = self.particle_species_settings[particle_species]["true_pdgs"]
        self.fake_pdg = self.particle_species_settings[particle_species]["fake_pdgs"]
        # NOTE(review): get_cut_value, build_table_name and get_table are not
        # in the visible part of this file — presumably defined further down;
        # verify they exist before relying on this constructor.
        self.value = self.get_cut_value()
        self.cut_type = self.get_cut_type()
        efficiency_table_names = self.build_table_name(
            self.particle_species_settings[self.particle_species]["eff_table_ids"]
        )
        fake_rate_table_names = self.build_table_name(
            self.particle_species_settings[self.particle_species]["fake_rate_table_ids"]
        )
        self.eff_table = self.get_table(efficiency_table_names)
        self.fake_rate_table = self.get_table(fake_rate_table_names)
[docs] @classmethod def kaons(cls, external_info, MC_production): particle_species = "kaon" return cls( particle_species=particle_species, online_cut=external_info["online_cut"], base_table_path=external_info["table_paths"], variable=None, MC_production=MC_production, )
[docs] @classmethod def pions(cls, external_info, MC_production): particle_species = "pion" return cls( particle_species=particle_species, online_cut=external_info["online_cut"], base_table_path=external_info["table_paths"], variable=None, MC_production=MC_production, )
[docs] @classmethod def electrons(cls, external_info, MC_production): particle_species = "elec" return cls( particle_species=particle_species, online_cut=external_info["online_cut"], base_table_path=external_info["table_paths"], variable=external_info["variable"], MC_production=MC_production, )
[docs] @classmethod def Kfake_electrons(cls, external_info, MC_production): particle_species = "Kfake_elec" return cls( particle_species=particle_species, online_cut=external_info["online_cut"], base_table_path=external_info["table_paths"], variable=external_info["variable"], MC_production=MC_production, )
[docs] @classmethod def pifake_electrons(cls, external_info, MC_production): particle_species = "pifake_elec" return cls( particle_species=particle_species, online_cut=external_info["online_cut"], base_table_path=external_info["table_paths"], variable=external_info["variable"], MC_production=MC_production, )
[docs] @classmethod def muons(cls, external_info, MC_production): particle_species = "muon" return cls( particle_species=particle_species, online_cut=external_info["online_cut"], base_table_path=external_info["table_paths"], variable=external_info["variable"], MC_production=MC_production, )
[docs] @classmethod def Kfake_muons(cls, external_info, MC_production): particle_species = "Kfake_muon" return cls( particle_species=particle_species, online_cut=external_info["online_cut"], base_table_path=external_info["table_paths"], variable=external_info["variable"], MC_production=MC_production, )
[docs] @classmethod def pifake_muons(cls, external_info, MC_production): particle_species = "pifake_muon" return cls( particle_species=particle_species, online_cut=external_info["online_cut"], base_table_path=external_info["table_paths"], variable=external_info["variable"], MC_production=MC_production, )
[docs] def production_table_ids(self): """ Updates hadron table IDs based on the MC production type. Raises: ValueError: If an unsupported MC production type is provided. """ if self.MC_production not in ["MC15ri", "MC15rd"]: raise ValueError("Invalid production type. Must be 'MC15ri' or 'MC15rd'.") eff_table_mapping = { "MC15ri": {"kaon": "keff", "pion": "pieff"}, "MC15rd": {"kaon": "KEff", "pion": "piEff"}, } fake_table_mapping = { "MC15ri": {"kaon": "piFk", "pion": "kFpi"}, "MC15rd": {"kaon": "piFakeK", "pion": "KFakepi"}, } if self.particle_species in ["kaon", "pion"]: self.particle_species_settings[self.particle_species]["eff_table_ids"] = [ eff_table_mapping[self.MC_production][self.particle_species] ] self.particle_species_settings[self.particle_species][ "fake_rate_table_ids" ] = [fake_table_mapping[self.MC_production][self.particle_species]]
@property
def particle_species_settings(self) -> dict:
    """Per-species lookup of PID-table configuration.

    Maps each supported particle species to its true/fake PDG codes and
    the default efficiency / fake-rate table identifiers. The dictionary
    is built once on first access and cached on the instance.

    NOTE(review): the default hadron table ids match the MC15ri naming;
    ``production_table_ids`` overwrites them for other campaigns.
    """
    try:
        return self._particle_species_settings
    except AttributeError:
        pass
    settings = {
        "kaon": {
            "true_pdgs": [321, -321],
            "fake_pdgs": [211, -211],
            "eff_table_ids": ["keff"],
            "fake_rate_table_ids": ["piFk"],
        },
        "pion": {
            "true_pdgs": [211, -211],
            "fake_pdgs": [321, -321],
            "eff_table_ids": ["pieff"],
            "fake_rate_table_ids": ["kFpi"],
        },
        "elec": {
            "true_pdgs": [11, -11],
            "fake_pdgs": [321, 211, -321, -211],
            "eff_table_ids": ["e_efficiency"],
            "fake_rate_table_ids": ["K_e_fakeRate", "pi_e_fakeRate"],
        },
        "Kfake_elec": {
            "true_pdgs": [11, -11],
            "fake_pdgs": [321, -321],
            "eff_table_ids": ["e_efficiency"],
            "fake_rate_table_ids": ["K_e_fakeRate"],
        },
        "pifake_elec": {
            "true_pdgs": [11, -11],
            "fake_pdgs": [211, -211],
            "eff_table_ids": ["e_efficiency"],
            "fake_rate_table_ids": ["pi_e_fakeRate"],
        },
        "muon": {
            "true_pdgs": [13, -13],
            "fake_pdgs": [321, 211, -321, -211],
            "eff_table_ids": ["mu_efficiency"],
            "fake_rate_table_ids": ["K_mu_fakeRate", "pi_mu_fakeRate"],
        },
        "Kfake_muon": {
            "true_pdgs": [13, -13],
            "fake_pdgs": [321, -321],
            "eff_table_ids": ["mu_efficiency"],
            "fake_rate_table_ids": ["K_mu_fakeRate"],
        },
        "pifake_muon": {
            "true_pdgs": [13, -13],
            "fake_pdgs": [211, -211],
            "eff_table_ids": ["mu_efficiency"],
            "fake_rate_table_ids": ["pi_mu_fakeRate"],
        },
    }
    self._particle_species_settings = settings
    return settings
def get_cut_type(
    self,
) -> str:
    """Determine which cut type was applied during online reconstruction.

    Lepton species pass the online cut string through unchanged; hadron
    species are mapped to ``"B"`` (binary) or ``"G"`` (global) depending
    on the selection variable found in the cut string.

    Returns:
        The cut type: the raw online cut for leptons, ``"B"``/``"G"``
        for hadrons.

    Raises:
        ValueError: If the particle species is not recognised.
    """
    # TODO update this to the config file of each experiment to avoid
    # accidentally changing the value during offline preprocessing.
    species = self.particle_species
    if "elec" in species or "muon" in species:
        # Lepton working points are used verbatim.
        cut_type = self.online_cut
    elif species in ("kaon", "pion"):
        if "binaryPID" in self.online_cut:
            cut_type = "B"
        elif "ID" in self.online_cut:
            cut_type = "G"
        else:
            # NOTE(review): cut_type stays unbound on this path, so the
            # return below raises UnboundLocalError — confirm intended.
            logging.warning(
                "Cut type, neither global, nor binary HID selection has been applied online"
            )
    else:
        raise ValueError("Wrong particle species")
    return cut_type
def get_cut_value(self) -> str:
    """Extract the cut value from the online selection string.

    Lepton species encode the working point in the last character of the
    online cut string; hadron species carry a numeric threshold such as
    ``"0.5"`` in the last three characters.

    Returns:
        The cut value as a substring of ``online_cut``.

    Raises:
        ValueError: If the particle species is not recognised.
    """
    # BUGFIX: the old implementation computed ``cut_value`` and then
    # unconditionally returned ``self.online_cut[-1]``, discarding the
    # hadron slice. It also matched species with an exact-membership test
    # that missed e.g. "pifake_elec"; use the same substring matching as
    # ``get_cut_type`` for consistency.
    if "elec" in self.particle_species or "muon" in self.particle_species:
        cut_value = self.online_cut[-1]
    elif self.particle_species in ("kaon", "pion"):
        cut_value = self.online_cut[-3:]
    else:
        raise ValueError("Wrong particle species")
    return cut_value
def build_table_name(
    self,
    table_ids: Iterable[str],
) -> list:
    """Build the full paths of the efficiency / fake-rate table files.

    For hadron species one file name per table id and charge
    (positive first, then negative) is produced via
    ``build_hid_table_name``; lepton species use the ``<id>_table.csv``
    naming scheme.

    Args:
        table_ids: Efficiency or fake-rate table identifiers.
            (Fixed annotation: this is an iterable of ids, not a str.)

    Returns:
        List of file paths under ``base_table_path``.

    Raises:
        ValueError: If the particle species is not recognised (previously
            this fell through to an UnboundLocalError).
    """
    if self.particle_species in ["kaon", "pion"]:
        # Positive-charge tables first, then negative-charge tables.
        file_names = [
            self.build_hid_table_name(table_id, charge)
            for charge in ("p", "m")
            for table_id in table_ids
        ]
    elif "elec" in self.particle_species or "muon" in self.particle_species:
        file_names = ["_".join((table_id, "table.csv")) for table_id in table_ids]
    else:
        raise ValueError("Wrong particle species")
    return [path.join(self.base_table_path, name) for name in file_names]
def build_hid_table_name(self, table_id, charge):
    """Compose the log-file name of a single Hadron-ID table.

    The name encodes the campaign, table id, charge (MC15ri only — the
    MC15rd naming uses "all" instead) and the working point built from
    ``cut_type`` and the last digit of ``value``.

    Args:
        table_id: Efficiency or fake-rate table identifier.
        charge: Charge tag, e.g. "p" or "m" (ignored for MC15rd).

    Returns:
        The table file name.
    """
    # Working-point tag, e.g. "B0-5" for cut_type "B" and value 0.5.
    working_point = self.cut_type + "0-" + str(self.value)[-1]
    if self.MC_production == "MC15ri":
        parts = ("Rdtmc", table_id, charge, working_point, "all.log")
    elif self.MC_production == "MC15rd":
        parts = ("MC15rd", table_id, "all", working_point + ".log")
    return "_".join(parts)
def get_table(self, table_names):
    """Load and combine the correction tables from disk.

    Hadron tables are converted into the lepton-table format understood
    by ``PIDvar``; lepton tables are filtered with the LID query string
    and get an ``mcPDG`` column attached.

    Args:
        table_names: Iterable of CSV file paths to read and concatenate.

    Returns:
        The combined, post-processed dataframe.

    Raises:
        ValueError: If the particle species is not recognised (checked
            before any file is read).
    """
    is_hadron = self.particle_species in ["kaon", "pion"]
    is_lepton = "elec" in self.particle_species or "muon" in self.particle_species
    if not (is_hadron or is_lepton):
        raise ValueError("Wrong particle species")
    table = concat([read_csv(name) for name in table_names])
    if is_hadron:
        # BUGFIX: keep the returned frame — the old code discarded the
        # return value of make_pidvar_compatible, so any transformation
        # that is not strictly in place would be silently lost.
        table = self.make_pidvar_compatible(
            self.MC_production, table, max_uncertainty=10
        )
    else:
        table.query(self.get_lid_queries(), inplace=True)
        self.add_mcPDG_to_table(table)
    return table
def add_mcPDG_to_table(self, table):
    """Attach an ``mcPDG`` column derived from the ``charge`` column.

    Negative-charge rows get ``true_pdg[0]``, positive-charge rows get
    ``true_pdg[1]``; anything else keeps the sentinel ``-9999``.

    Args:
        table: Dataframe with a ``charge`` column of "+"/"-" entries;
            modified in place.
    """
    negative_pdg, positive_pdg = self.true_pdg[0], self.true_pdg[1]
    # Sentinel for rows whose charge matches neither sign.
    table.loc[:, "mcPDG"] = -9999
    table.loc[table["charge"] == "-", "mcPDG"] = negative_pdg
    table.loc[table["charge"] == "+", "mcPDG"] = positive_pdg
@staticmethod
def make_pidvar_compatible(
    MC_production: str, table: DataFrame, max_uncertainty: Optional[float] = 1e2
):
    """
    Convert a Hadron-ID CSV table into the lepton-ID table format that
    ``PIDvar`` understands.

    Concretely, this
    * validates the integer ``charge`` column (MC15ri only),
    * drops rows whose data-MC uncertainties exceed ``max_uncertainty``
      (the HID tables contain rows with nonsense uncertainties > 10^8,
      so the exact threshold is not important),
    * derives ``theta_min``/``theta_max`` from ``cos_min``/``cos_max``
      (MC15ri only), and
    * rewrites the ``charge`` column to ``"+"``/``"-"`` string entries.

    All modifications happen in place on ``table``, which is also
    returned for convenience.

    :param MC_production: MC campaign identifier, ``"MC15ri"`` or ``"MC15rd"``.
    :param table: Pandas dataframe obtained from ``pandas.read_csv`` on a HID table.
    :param max_uncertainty: Drop rows where any data-MC uncertainty
        (sys/stat up/down) exceeds this value. Set to ``None`` to disable.
    :return: The modified dataframe that can be used by ``PIDvar``.
    """
    # Sanity check that the table has the raw HID integer charge format.
    if MC_production == "MC15ri":
        if not set(table["charge"]).issubset({1, -1}):
            raise ValueError(
                "Expected that the ``charge`` entries of the original Hadron ID dataframe consists "
                "only of ``1`` and ``-1``, but it contains {}".format(
                    set(table["charge"])
                )
            )

    # BUGFIX: the row-dropping promised by the docstring was commented out,
    # leaving ``max_uncertainty`` dead. Reinstated as an in-place drop so
    # callers that ignore the return value still see the cleaned table.
    if max_uncertainty is not None:
        unc_cols = [
            "data_MC_uncertainty_stat_up",
            "data_MC_uncertainty_stat_dn",
            "data_MC_uncertainty_sys_up",
            "data_MC_uncertainty_sys_dn",
        ]
        # Only filter on uncertainty columns that actually exist — the
        # MC15rd table schema is not guaranteed to contain all of them.
        present = [col for col in unc_cols if col in table.columns]
        if present:
            bad_rows = table[present].max(axis=1) > max_uncertainty
            table.drop(index=table.index[bad_rows], inplace=True)

    if MC_production == "MC15ri":
        # cos(θ) is strictly decreasing for θ in [0, π], so min and max
        # swap when inverting the cosine.
        table["theta_min"] = -9999
        table["theta_max"] = -9999
        table.loc[:, "theta_min"] = np.arccos(table["cos_max"].copy(deep=True))
        table.loc[:, "theta_max"] = np.arccos(table["cos_min"].copy(deep=True))

    # PIDvar expects the charge column to contain "+" or "-".
    if MC_production == "MC15ri":
        table.loc[:, "charge"] = np.where(table["charge"] == +1, "+", "-")
    elif MC_production == "MC15rd":
        table.loc[:, "charge"] = np.where(table["charge_min"] == -2, "-", "+")
    return table
def get_lid_queries(self):
    """Assemble the pandas query string used to filter a lepton-ID table.

    The query selects the configured working point and variable, keeps
    only the best-available rows, removes a hard-coded set of overly
    coarse kinematic bins per lepton flavour, and restricts the data/MC
    ratio to an arbitrarily chosen "physical" range.

    Returns:
        The full cut string, clauses joined with " and ".
    """
    working_point = f"(working_point == '{self.cut_type}')"
    best_available = "(is_best_available == True)"
    if "elec" in self.particle_species:
        exclude_bins = "(not ((theta_min == 0.56 and theta_max == 2.23) or (theta_min == 0.22 and theta_max == 2.71) or (p_min == 0.2 and p_max == 7) or (p_min == 0.2 and p_max == 5)))"
    elif "muon" in self.particle_species:
        exclude_bins = "(not ((theta_min == 0.82 and theta_max == 2.22) or (theta_min == 0.4 and theta_max == 0.82) or (theta_min == 0.4 and theta_max == 2.6) or (p_min == 0.2 and p_max == 5)))"
    variable = f"(variable == '{self.variable}')"
    # PATCH exclude negative and extremely large weights
    physical_values = "(0 < data_MC_ratio < 10)"
    logging.warning(
        f"If negative weights or extremely large weights are present in the table, these will be excluded. The arbitrarily selected 'physical range' is {physical_values}"
    )
    clauses = [working_point, best_available, exclude_bins, variable, physical_values]
    total_cutstring = " and ".join(clauses)
    logging.info(
        f"The following cutstring has been applied to the provide LID table: {total_cutstring}"
    )
    return total_cutstring