Module epiclass.utils.shap.shap_utils
Module containing utility functions for handling SHAP files, plus a bit of analysis.
Functions
def extract_shap_values_and_info(shap_logdir: str | Path, verbose: bool = True)-
Extract SHAP values and related information from an archive, and optionally print basic statistics about them.
Args
shap_logdir:str | Path- The directory where the SHAP values archive is located.
verbose:bool- Whether to print basic statistics about the SHAP values.
Returns
shap_matrices:np.ndarray- SHAP matrices.
eval_md5s:List[str]- List of evaluation MD5s.
classes:List[Tuple[str, str]]- List of classes. Each class is a tuple containing the class index and the class label.
def get_archives(shap_values_dir: str | Path)-
Extracts SHAP values and explainer background information from .npz files in a specified directory.
This function searches for files in the provided directory, specifically looking for files that match the patterns "evaluation.npz" and "explainer_background.npz". It loads these .npz files as dictionaries and returns them. The function raises a FileNotFoundError if the required files are not found in the directory.
Args
shap_values_dir:str | Path- The directory path where the .npz files are located.
Returns
Tuple[Dict, Dict]- The first dictionary contains the SHAP values extracted from the "evaluation.npz" file, and the second contains the explainer background information extracted from the "explainer_background.npz" file.
Raises
FileNotFoundError- If either the SHAP values file or the explainer background file is not found in the specified directory.
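The lookup-and-load behaviour described above can be sketched as follows. This is a minimal illustration, not the epiclass implementation: the function name load_shap_archives and the exact glob patterns are assumptions, and only the documented contract (two dictionaries returned, FileNotFoundError when a file is missing) is taken from the description.

```python
from pathlib import Path
from typing import Dict, Tuple

import numpy as np


def load_shap_archives(shap_values_dir) -> Tuple[Dict, Dict]:
    """Hypothetical stand-in for get_archives (assumed glob patterns)."""
    directory = Path(shap_values_dir)
    archives = []
    # Assumption: file names contain "evaluation" or "explainer_background".
    for pattern in ("*evaluation*.npz", "*explainer_background*.npz"):
        matches = sorted(directory.glob(pattern))
        if not matches:
            raise FileNotFoundError(f"No file matching '{pattern}' in {directory}")
        # np.load returns a lazy NpzFile; copy its arrays into a plain dict
        # so the file handle can be closed right away.
        with np.load(matches[0]) as npz:
            archives.append({key: npz[key] for key in npz.files})
    return archives[0], archives[1]
```

The keys available in each returned dictionary depend on how the archives were written, so they are not shown here.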
def get_shap_matrix(meta: Metadata, shap_matrices: np.ndarray, eval_md5s: List[str], label_category: str, selected_labels: List[str], class_idx: int, copy_meta: bool = True) ‑> Tuple[numpy.ndarray, List[int]]-
Generates a SHAP matrix corresponding to a selected subset of samples.
This function selects a subset of samples and builds the SHAP matrix for them. It filters the metadata by label_category and selected_labels, identifies the remaining samples by their md5 hash, and takes the SHAP values of these samples from the matrix at class_idx.
Args
meta:metadata.Metadata- Metadata object containing information about the samples.
shap_matrices:np.ndarray- Array of SHAP matrices for each class.
eval_md5s:List[str]- List of md5 hashes identifying the evaluation samples.
label_category:str- Name of the category in the metadata that contains the desired labels.
selected_labels:List[str]- Names of the classes for which samples will be considered.
class_idx:int- Index of the class for which the shap values matrix will be used.
Returns
np.ndarray- The selected SHAP matrix for the selected class and for the chosen samples based on the provided criteria.
List[int]- The indices of the chosen samples in the original SHAP matrix.
Raises
IndexError- If class_idx is out of bounds for shap_matrices.
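The row selection step can be sketched with plain NumPy indexing. This is a hedged illustration only: select_class_rows is a hypothetical name, selected_md5s stands in for the md5s left after metadata filtering, and the (n_classes, n_samples, n_features) layout of shap_matrices is an assumption.

```python
from typing import List, Tuple

import numpy as np


def select_class_rows(
    shap_matrices: np.ndarray,
    eval_md5s: List[str],
    selected_md5s: List[str],
    class_idx: int,
) -> Tuple[np.ndarray, List[int]]:
    """Hypothetical sketch of the selection step, not the epiclass code."""
    # Assumed layout: (n_classes, n_samples, n_features).
    if not 0 <= class_idx < len(shap_matrices):
        raise IndexError(f"class_idx {class_idx} is out of bounds")
    wanted = set(selected_md5s)
    # Keep the positions of the evaluation samples that passed the filter.
    chosen_idxs = [i for i, md5 in enumerate(eval_md5s) if md5 in wanted]
    return shap_matrices[class_idx][chosen_idxs, :], chosen_idxs
```

Returning the indices alongside the matrix lets a caller map each row back to its md5 in eval_md5s.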
def n_most_important_features(sample_shaps: np.ndarray, n: int) ‑> numpy.ndarray-
Return indices of features with the highest absolute shap values.
Args
sample_shaps:np.ndarray- Array of SHAP values for a single sample.
n:int- Number of top features to return.
Returns
np.ndarray- Indices of the top n features with the highest absolute SHAP values.
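One way to obtain such indices is to sort by absolute value. The sketch below uses the hypothetical name top_n_by_abs and returns indices ordered from most to least important; whether the documented function orders its output the same way is not stated, so treat the ordering as an assumption.

```python
import numpy as np


def top_n_by_abs(sample_shaps: np.ndarray, n: int) -> np.ndarray:
    """Hypothetical sketch: indices of the n largest |SHAP| values."""
    # argsort is ascending, so reverse it to get the largest magnitudes first.
    order = np.argsort(np.abs(sample_shaps))[::-1]
    return order[:n]


values = np.array([0.1, -0.9, 0.3, -0.2])
print(top_n_by_abs(values, 2))  # indices 1 and 2 (|-0.9| and |0.3|)
```

Taking the absolute value first means strongly negative contributions rank as high as strongly positive ones.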
def select_random_shap_samples(shap_dict: Dict[str, List[np.ndarray]], n: int) ‑> Dict[str, List[numpy.ndarray]]-
Selects a random subset of SHAP values and their corresponding IDs from a given dictionary.
This function randomly selects 'n' samples from the provided SHAP values. It ensures that the selection is non-repetitive. The function is designed to work with a dictionary containing SHAP values and their corresponding IDs. The resulting subset contains both SHAP values and IDs, maintaining their association.
Args
shap_dict:Dict[str, List[np.ndarray]]- A dictionary with two keys: 'shap' and 'ids'. 'shap' should be a list of numpy arrays containing SHAP values, and 'ids' should be a list of identifiers corresponding to each SHAP value.
n:int- The number of random samples to select. If 'n' is larger than the total number of samples available, all samples are returned without duplication.
Returns
Dict[str, List[np.ndarray]]- A dictionary containing two keys: 'shap' and 'ids'. 'shap' is a list of numpy arrays representing the randomly selected SHAP values, and 'ids' is a list of the corresponding identifiers. The length of the lists equals 'n', or the total number of samples if 'n' is larger than the available samples.
Raises
ValueError- If 'n' is negative.
IndexError- If the provided 'shap_dict' does not contain the required keys ('shap' and 'ids').
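The documented behaviour (selection without repetition, n capped at the number of samples, paired values and IDs) can be sketched as below. The name pick_random_samples and the seed parameter are illustrative assumptions, not part of the module.

```python
import random
from typing import Dict, Optional


def pick_random_samples(
    shap_dict: Dict[str, list], n: int, seed: Optional[int] = None
) -> Dict[str, list]:
    """Hypothetical sketch of random, non-repeating SHAP sample selection."""
    if n < 0:
        raise ValueError("n must be non-negative")
    if "shap" not in shap_dict or "ids" not in shap_dict:
        # Mirrors the documented exception type for missing keys.
        raise IndexError("shap_dict must contain 'shap' and 'ids'")
    total = len(shap_dict["ids"])
    rng = random.Random(seed)
    # sample() draws without replacement; cap n at the number available.
    picked = rng.sample(range(total), min(n, total))
    # Index both lists with the same positions to keep values and IDs paired.
    return {
        "shap": [shap_dict["shap"][i] for i in picked],
        "ids": [shap_dict["ids"][i] for i in picked],
    }
```

Drawing positions first and indexing both lists with them is what keeps each SHAP array attached to its identifier.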
def subsample_md5s(md5s: List[str], metadata: Metadata, category_label: str, labels: List[str], copy_metadata: bool = True) ‑> List[int]-
Return the indices of the md5s that remain after filtering the metadata for a given category and set of labels.
Args
md5s:list- A list of MD5 hashes.
metadata:Metadata- A metadata object containing the data to be filtered.
category_label:str- The category label to be used for filtering the metadata.
labels:list- A list of labels to be used for selecting category subsets in the metadata.
Returns
list- A list of indices corresponding to the selected md5s.
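The index selection can be illustrated with a plain dictionary in place of the Metadata object. This is a sketch under assumptions: the mapping of md5 to a record dict and the name indices_matching_labels are hypothetical, and the real Metadata class may expose a different interface.

```python
from typing import Dict, List


def indices_matching_labels(
    md5s: List[str],
    metadata: Dict[str, Dict[str, str]],
    category_label: str,
    labels: List[str],
) -> List[int]:
    """Hypothetical sketch; metadata is a plain dict of md5 -> record."""
    wanted = set(labels)
    # Keep the position of every md5 whose record carries a wanted label.
    return [
        i
        for i, md5 in enumerate(md5s)
        if md5 in metadata and metadata[md5].get(category_label) in wanted
    ]


records = {
    "a": {"assay": "rna_seq"},
    "b": {"assay": "h3k27ac"},
    "c": {"assay": "rna_seq"},
}
print(indices_matching_labels(["a", "b", "c"], records, "assay", ["rna_seq"]))
```

The returned positions refer to the input md5s list, so they can be used directly to index rows of a SHAP matrix built in the same order.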