# Source code for callmefair.mitigation.fair_grid

"""
Grid Search for Bias Mitigation Combinations

This module provides a comprehensive grid search framework for evaluating different
combinations of bias mitigation techniques. It supports systematic evaluation of
preprocessing, in-processing, and postprocessing bias mitigation methods across
various machine learning models.

The module implements:
- Automatic model adaptation for different ML frameworks
- Systematic evaluation of bias mitigation combinations
- Comprehensive logging and result aggregation
- Support for single sensitive attribute evaluation (extensible to multiple)

Classes:
    dummy_model: Dummy model class for compatibility
    BMGridSearch: Main grid search class for bias mitigation evaluation

Functions:
    get_model_proba: Adapts different ML models for probability prediction

Example:
    >>> from callmefair.mitigation.fair_grid import BMGridSearch
    >>> from callmefair.mitigation.fair_bm import BMType
    >>> from sklearn.ensemble import RandomForestClassifier
    >>> 
    >>> # Define bias mitigation combinations to test
    >>> bm_combinations = [
    ...     [BMType.preReweighing],
    ...     [BMType.preDisparate],
    ...     [BMType.preReweighing, BMType.posCalibrated],
    ...     [BMType.inAdversarial]
    ... ]
    >>> 
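    >>> # Initialize grid search. The combinations above include an
    >>> # in-processing technique, so ``model`` must be None:
    >>> # run_single_sensitive() raises ValueError when a classifier is
    >>> # combined with in-processing. Pass a classifier such as
    >>> # RandomForestClassifier() only with pre-/post-processing lists.
    >>> grid_search = BMGridSearch(
    ...     bmI=bm_interface,
    ...     model=None,
    ...     bm_list=bm_combinations,
    ...     privileged_group=privileged_groups,
    ...     unprivileged_group=unprivileged_groups
    ... )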
    >>> 
    >>> # Run grid search
    >>> grid_search.run_single_sensitive()
"""

from callmefair.util.fair_util import BMInterface, BMMetrics
from callmefair.mitigation.fair_bm import BMManager, BMType
from callmefair.mitigation.fair_log import csvLogger
import numpy as np
from datetime import datetime
from dataclasses import dataclass   
# TensorFlow is an optional dependency. When it is missing, ``tf`` is bound to
# None so the rest of the module can still be imported.
try:
    import tensorflow.compat.v1 as tf
except ImportError:
    tf = None
else:
    # Only the import itself is guarded; the v1-compat eager-execution switch
    # runs once the import is known to have succeeded.
    tf.disable_eager_execution()
#TODO adapt model fit and predict probabilities to an abstract interface
def get_model_proba(model, bmI: BMInterface) -> tuple[np.ndarray, np.ndarray]:
    """
    Fit ``model`` on the training split and predict validation/test probabilities.

    The extra keyword arguments given to ``model.fit`` are selected from the
    prefix of ``str(model)``:

    - ``LogisticRegression``: ``sample_weight`` is set to the
      ``instance_weights`` of the training dataset returned by
      ``bmI.get_train_BLD()``.
    - ``XGBClassifier`` / ``MLP``: no extra arguments.
    - ``TabNet``: the validation split is passed as ``eval_set``.
    - anything else: the validation split is passed as ``eval_set`` only if
      ``fit`` declares an ``eval_set`` parameter or accepts ``**kwargs``
      (or if its signature cannot be inspected). Otherwise ``fit`` is called
      with the training data alone, so estimators whose ``fit`` has no
      ``eval_set`` parameter no longer fail with ``TypeError``.

    Args:
        model: Estimator exposing ``fit`` and ``predict_proba``.
        bmI (BMInterface): Provides the train/validation/test splits through
            ``get_train_xy``, ``get_val_xy`` and ``get_test_xy``.

    Returns:
        tuple[np.ndarray, np.ndarray]: ``(validation_predictions,
        test_predictions)``, each being exactly what ``model.predict_proba``
        returns for that split.

    Example:
        >>> val_pred, test_pred = get_model_proba(model, bm_interface)
    """
    # Local import: keeps this function self-contained without touching the
    # module-level import block.
    import inspect

    x_train, y_train = bmI.get_train_xy()
    x_val, y_val = bmI.get_val_xy()
    x_test, _ = bmI.get_test_xy()

    model_repr = str(model)
    if model_repr.startswith('LogisticRegression'):
        fit_kwargs = {'sample_weight': bmI.get_train_BLD().instance_weights}
    elif model_repr.startswith(('XGBClassifier', 'MLP')):
        fit_kwargs = {}
    elif model_repr.startswith('TabNet'):
        fit_kwargs = {'eval_set': [(x_val, y_val)]}
    else:
        try:
            fit_params = inspect.signature(model.fit).parameters.values()
        except (TypeError, ValueError):
            # Signature unavailable: keep the historical behaviour of always
            # passing eval_set.
            takes_eval_set = True
        else:
            takes_eval_set = any(
                p.name == 'eval_set' or p.kind is inspect.Parameter.VAR_KEYWORD
                for p in fit_params
            )
        fit_kwargs = {'eval_set': [(x_val, y_val)]} if takes_eval_set else {}

    fitted = model.fit(x_train, y_train, **fit_kwargs)
    # Some estimators fit in place and return None; keep using the original
    # object in that case.
    if fitted is not None:
        model = fitted

    y_val_pred = model.predict_proba(x_val)
    y_test_pred = model.predict_proba(x_test)
    return (y_val_pred, y_test_pred)
@dataclass
class dummy_model:
    """
    Placeholder estimator that exposes nothing but ``classes_``.

    ``BMGridSearch`` passes ``dummy_model.classes_`` to ``BMMetrics`` in place
    of the class labels of a fitted estimator, so the same value is used for
    standard and in-processing models.

    Attributes:
        classes_ (np.ndarray): The binary class labels ``[0, 1]``.
    """

    # Left unannotated, so @dataclass treats it as a plain class attribute
    # rather than as an instance field.
    classes_ = np.arange(2)
class BMGridSearch:
    """
    Grid search over combinations of bias mitigation techniques.

    Each entry of ``bm_list`` is one combination of ``BMType`` members.
    ``run_single_sensitive`` first evaluates a baseline, then applies every
    combination in turn (pre-processing, optional in-processing, then
    post-processing), collects the fairness report of each run and hands all
    rows to a ``csvLogger``.

    Attributes:
        bmI (BMInterface): Interface for managing binary label datasets
        bmMR (BMManager): Bias mitigation manager for applying techniques
        model: The machine learning model to evaluate. Replaced by an
            in-processing estimator when ``bm_list`` contains in-processing.
        bm_list (list[list[BMType]]): List of bias mitigation combinations to test
        privileged_group (list[dict]): List of dictionaries defining privileged groups
        unprivileged_group (list[dict]): List of dictionaries defining unprivileged groups
        is_model_in (bool): Whether using in-processing bias mitigation
        bmM (BMMetrics): Metrics evaluator; only exists after
            ``run_single_sensitive`` has started.

    Example:
        >>> from callmefair.mitigation.fair_grid import BMGridSearch
        >>> from callmefair.mitigation.fair_bm import BMType
        >>>
        >>> # Pre-/post-processing combinations are evaluated with a classifier
        >>> combinations = [
        ...     [BMType.preReweighing],
        ...     [BMType.preDisparate, BMType.posCalibrated],
        ... ]
        >>> grid_search = BMGridSearch(
        ...     bmI=bm_interface,
        ...     model=RandomForestClassifier(),
        ...     bm_list=combinations,
        ...     privileged_group=privileged_groups,
        ...     unprivileged_group=unprivileged_groups
        ... )
        >>> grid_search.run_single_sensitive()
        >>>
        >>> # Lists containing in-processing (e.g. BMType.inAdversarial)
        >>> # must be created with model=None
    """

    def __init__(self, bmI: BMInterface, model, bm_list: list[list[BMType]],
                 privileged_group: list[dict], unprivileged_group: list[dict]):
        """
        Initialize the Bias Mitigation Grid Search.

        Args:
            bmI (BMInterface): Interface for managing binary label datasets
            model: The machine learning model to evaluate. Must be None when
                ``bm_list`` contains an in-processing technique; in that case
                ``bmI.set_transform()`` is called here.
            bm_list (list[list[BMType]]): List of bias mitigation combinations
                to test. Each inner list represents one combination.
            privileged_group (list[dict]): List of dictionaries defining
                privileged groups; forwarded to ``BMManager`` and ``BMMetrics``.
            unprivileged_group (list[dict]): List of dictionaries defining
                unprivileged groups; forwarded to ``BMManager`` and ``BMMetrics``.

        Example:
            >>> grid_search = BMGridSearch(
            ...     bmI=bm_interface,
            ...     model=RandomForestClassifier(),
            ...     bm_list=[[BMType.preReweighing]],
            ...     privileged_group=[{'gender': 1}],
            ...     unprivileged_group=[{'gender': 0}]
            ... )
        """
        self.bmI = bmI
        self.bmMR = BMManager(self.bmI, privileged_group, unprivileged_group)
        self.model = model
        if model is None:
            # prepare BMI to deal with transform classifier (scaler)
            self.bmI.set_transform()
        self.bm_list = bm_list
        self.privileged_group = privileged_group
        self.unprivileged_group = unprivileged_group
        self.is_model_in = False

    def __in_model_run(self) -> tuple:
        """
        Fit the in-processing model and predict on validation and test data.

        Unlike ``get_model_proba``, the model is fitted and queried with the
        dataset objects returned by ``bmI.get_*_BLD()``, and ``predict`` is
        used instead of ``predict_proba``.

        Returns:
            tuple: ``(fit_result, validation_predictions, test_predictions)``
            where ``fit_result`` is whatever ``self.model.fit`` returned.
        """
        infer_model = self.model.fit(self.bmI.get_train_BLD())
        y_val_pred = self.model.predict(self.bmI.get_val_BLD())
        y_test_pred = self.model.predict(self.bmI.get_test_BLD())
        return infer_model, y_val_pred, y_test_pred

    def __warmup(self) -> None:
        """
        Produce the baseline predictions and create ``self.bmM``.

        Uses ``__in_model_run`` when ``is_model_in`` is set and
        ``get_model_proba`` otherwise, then builds the ``BMMetrics`` object
        with ``dummy_model.classes_`` as class labels.
        """
        if self.is_model_in:
            _, y_val_pred, y_test_pred = self.__in_model_run()
        else:
            y_val_pred, y_test_pred = get_model_proba(self.model, self.bmI)

        self.bmM = BMMetrics(self.bmI, dummy_model.classes_, y_val_pred,
                             y_test_pred, self.privileged_group,
                             self.unprivileged_group)

    def __is_valid_in_processing(self, in_set: list[list[BMType]]) -> tuple[bool, BMType]:
        """
        Look for an in-processing technique anywhere in the given combinations.

        Args:
            in_set (list[list[BMType]]): Bias mitigation combinations to scan.

        Returns:
            tuple[bool, BMType]: ``(found, in_type)``. ``found`` is True when
            at least one item has a truthy ``is_in``. ``in_type`` is the last
            such item in iteration order, or None when there is none.
            NOTE(review): if the combinations mix different in-processing
            types, only the last one is reported and used for every run.
        """
        in_type = None
        for current_set in in_set:
            for item in current_set:
                if item.is_in:
                    in_type = item
        return in_type is not None, in_type

    def __build_in_model(self, in_type, debias: bool):
        """
        Create a fresh in-processing estimator for ``in_type``.

        Any ``sess`` attribute on the current model is closed first (clears
        memory on AD). The non-debiased variant is built with the manager's
        defaults; the debiased variant uses ``debias=True`` for adversarial
        debiasing and ``tau=0.7`` for the meta classifier. For any other
        ``in_type`` the current model is returned unchanged.
        """
        if hasattr(self.model, 'sess'):
            self.model.sess.close()

        if in_type == BMType.inAdversarial:
            return self.bmMR.in_AD(debias=True) if debias else self.bmMR.in_AD()
        if in_type == BMType.inMeta:
            protected = self.bmI.get_protected_att()[0]
            if debias:
                return self.bmMR.in_Meta(protected, tau=0.7)
            return self.bmMR.in_Meta(protected)
        return self.model

    def run_single_sensitive(self) -> None:
        """
        Run grid search evaluation for single sensitive attribute.

        The method:
        1. Validates the model / in-processing configuration
        2. Evaluates baseline performance (no bias mitigation)
        3. Evaluates each bias mitigation combination, restoring the datasets
           with ``bmI.restore_BLD()`` after each one
        4. Passes all collected rows to a ``csvLogger`` named after the
           current timestamp

        When in-processing is present, every combination is evaluated with an
        in-processing estimator: debiased for combinations that contain an
        in-processing member, non-debiased for the others. The estimator is
        rebuilt for each combination so that a debiased model from a previous
        combination is never reused.

        Raises:
            ValueError: If in-processing bias mitigation is defined together
                with a classifier model, or if no model is given and no
                in-processing technique is defined.

        Example:
            >>> grid_search.run_single_sensitive()
        """
        # check if in processing is possible
        is_in, in_type = self.__is_valid_in_processing(self.bm_list)
        # is_model_in is only set by an earlier run of this method; in that
        # case self.model is the in-processing estimator created back then,
        # not a user-supplied classifier, so a re-run must not be rejected.
        if is_in and self.model is not None and not self.is_model_in:
            raise ValueError('In processing BM defined. Combination with classifier is invalid.')
        if not is_in and self.model is None:
            raise ValueError('No classifier model provided and no in-processing BM defined.')

        if is_in:
            self.is_model_in = True
            self.model = self.__build_in_model(in_type, debias=False)

        # create BMMetric object
        self.__warmup()

        logger = csvLogger(f'experiment_({datetime.now().strftime("%Y_%m_%d-%I_%M_%S_%p")})')

        experiment_dict = {'model': str(self.model).split('(')[0], 'BM': 'baseline'}
        experiment_dict.update(self.bmM.get_report())
        experiment_dict.update({'fair_score': self.bmM.get_score()})
        exp_data_list = [experiment_dict]

        for c_set in self.bm_list:
            bm_name = ''
            pre_in_set = [c for c in c_set if c.is_pre]
            in_in_set = [c for c in c_set if c.is_in]
            pos_in_set = [c for c in c_set if c.is_pos]

            for c in pre_in_set:
                bm_name += f' {c.name}'
                if c == BMType.preReweighing:
                    self.bmMR.pre_Reweighing()
                elif c == BMType.preDisparate:
                    self.bmMR.pre_DR(self.bmI.get_protected_att()[0])
                elif c == BMType.preLFR:
                    self.bmMR.pre_LFR()

            # check if in-processing is in bm_list
            if is_in:
                # Emptiness of the list decides, not the truthiness of the
                # enum members it contains.
                use_debias = bool(in_in_set)
                self.model = self.__build_in_model(in_type, debias=use_debias)
                if use_debias:
                    if in_type == BMType.inAdversarial:
                        bm_name += ' inAD'
                    elif in_type == BMType.inMeta:
                        # Label the run so it is distinguishable in the log.
                        bm_name += ' inMeta'
                _, y_val_pred, y_test_pred = self.__in_model_run()
            else:
                y_val_pred, y_test_pred = get_model_proba(self.model, self.bmI)

            self.bmM.set_new_pred(y_val_pred, y_test_pred)

            for c in pos_in_set:
                bm_name += f' {c.name}'
                if c == BMType.posCalibrated:
                    self.bmMR.pos_CEO(self.bmI.get_val_BLD(), self.bmI.get_test_BLD())
                elif c == BMType.posEqqOds:
                    self.bmMR.pos_EO(self.bmI.get_val_BLD(), self.bmI.get_test_BLD())
                elif c == BMType.posROC:
                    self.bmMR.pos_ROC(self.bmI.get_val_BLD(), self.bmI.get_test_BLD())

            # bm_name starts with a space; strip it for the log row.
            new_exp_dict = {'model': str(self.model).split('(')[0], 'BM': bm_name[1:]}
            new_exp_dict.update(self.bmM.get_report())
            new_exp_dict.update({'fair_score': self.bmM.get_score()})
            exp_data_list.append(new_exp_dict)

            self.bmI.restore_BLD()

        logger(exp_data_list)
#aggregate_csv_files('./results/', f'./results/experiment_{datetime.now().strftime("%Y_%m_%d-%I_%M_%S_%p")}.csv')