diff --git a/examples/demo_FACTS.ipynb b/examples/demo_FACTS.ipynb
index a99d26ef..7541cc86 100644
--- a/examples/demo_FACTS.ipynb
+++ b/examples/demo_FACTS.ipynb
@@ -1,775 +1,847 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Fairness auditing for subgroups using Fairness Aware Counterfactuals for Subgroups (FACTS).\n",
- "\n",
- "[FACTS](https://arxiv.org/abs/2306.14978) is an efficient, model-agnostic, highly parameterizable, and explainable framework for auditing subgroup fairness through counterfactual explanations. FACTS focuses on identifying a specific type of bias, i.e. the *difficulty in achieving recourse*. In short, it focuses on the population that has obtained the unfavorable outcome (*affected population*) by a ML model and tries to identify differences in the difficulty of changing the ML model's decision to obtain the favorable outcome, between affected subpopulations.\n",
- "\n",
- "In this notebook, we will see how to use this algorithm for discovering subgroups where the bias of a model (logistic regression for simplicity) between Males and Females is high.\n",
- "\n",
- "We will use the Adult dataset from UCI ([reference](https://archive.ics.uci.edu/ml/datasets/adult))."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Preliminaries"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Import dependencies\n",
- "\n",
- "As usual in python, the first step is to import all necessary packages."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
+ "cells": [
{
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "WARNING:root:No module named 'tempeh': fetch_lawschool_gpa will be unavailable. To install, run:\n",
- "pip install 'aif360[LawSchoolGPA]'\n"
- ]
- }
- ],
- "source": [
- "from sklearn.model_selection import train_test_split\n",
- "from sklearn.linear_model import LogisticRegression\n",
- "from sklearn.compose import ColumnTransformer\n",
- "from sklearn.pipeline import Pipeline\n",
- "from sklearn.preprocessing import OneHotEncoder\n",
- "\n",
- "from aif360.sklearn.datasets.openml_datasets import fetch_adult\n",
- "from aif360.sklearn.detectors.facts.clean import clean_dataset\n",
- "from aif360.sklearn.detectors.facts import FACTS, FACTS_bias_scan\n",
- "\n",
- "from IPython.display import display\n",
- "\n",
- "import warnings\n",
- "warnings.filterwarnings(\"ignore\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Below, you can change the `random_seed` variable to `None` if you would like for the pseudo-random parts to actually change between runs. We have set it to a specific value for reproducibility."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "random_seed = 131313 # for reproducibility"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Load Dataset"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
+ "cell_type": "markdown",
+ "source": [
+ "[Open In Colab](https://colab.research.google.com/github/Trusted-AI/AIF360/blob/main/examples/demo_FACTS.ipynb)\n"
+ ],
+ "metadata": {
+ "id": "3VmYWXk6gRbj"
+ }
+ },
{
- "data": {
- "text/html": [
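- "<div>\n",
- "<table border=\"1\" class=\"dataframe\">\n",
- "  <thead>\n",
- "    <tr style=\"text-align: right;\">\n",
- "      <th></th>\n",
- "      <th>age</th>\n",
- "      <th>workclass</th>\n",
- "      <th>education-num</th>\n",
- "      <th>marital-status</th>\n",
- "      <th>occupation</th>\n",
- "      <th>relationship</th>\n",
- "      <th>race</th>\n",
- "      <th>sex</th>\n",
- "      <th>capital-gain</th>\n",
- "      <th>capital-loss</th>\n",
- "      <th>hours-per-week</th>\n",
- "      <th>native-country</th>\n",
- "      <th>income</th>\n",
- "    </tr>\n",
- "  </thead>\n",
- "  <tbody>\n",
- "    <tr>\n",
- "      <th>0</th>\n",
- "      <td>(16.999, 26.0]</td>\n",
- "      <td>Private</td>\n",
- "      <td>7.0</td>\n",
- "      <td>Never-married</td>\n",
- "      <td>Machine-op-inspct</td>\n",
- "      <td>Own-child</td>\n",
- "      <td>Black</td>\n",
- "      <td>Male</td>\n",
- "      <td>0.0</td>\n",
- "      <td>0.0</td>\n",
- "      <td>FullTime</td>\n",
- "      <td>United-States</td>\n",
- "      <td>0</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <th>1</th>\n",
- "      <td>(34.0, 41.0]</td>\n",
- "      <td>Private</td>\n",
- "      <td>9.0</td>\n",
- "      <td>Married-civ-spouse</td>\n",
- "      <td>Farming-fishing</td>\n",
- "      <td>Married</td>\n",
- "      <td>White</td>\n",
- "      <td>Male</td>\n",
- "      <td>0.0</td>\n",
- "      <td>0.0</td>\n",
- "      <td>OverTime</td>\n",
- "      <td>United-States</td>\n",
- "      <td>0</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <th>2</th>\n",
- "      <td>(26.0, 34.0]</td>\n",
- "      <td>Local-gov</td>\n",
- "      <td>12.0</td>\n",
- "      <td>Married-civ-spouse</td>\n",
- "      <td>Protective-serv</td>\n",
- "      <td>Married</td>\n",
- "      <td>White</td>\n",
- "      <td>Male</td>\n",
- "      <td>0.0</td>\n",
- "      <td>0.0</td>\n",
- "      <td>FullTime</td>\n",
- "      <td>United-States</td>\n",
- "      <td>1</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <th>3</th>\n",
- "      <td>(41.0, 50.0]</td>\n",
- "      <td>Private</td>\n",
- "      <td>10.0</td>\n",
- "      <td>Married-civ-spouse</td>\n",
- "      <td>Machine-op-inspct</td>\n",
- "      <td>Married</td>\n",
- "      <td>Black</td>\n",
- "      <td>Male</td>\n",
- "      <td>7688.0</td>\n",
- "      <td>0.0</td>\n",
- "      <td>FullTime</td>\n",
- "      <td>United-States</td>\n",
- "      <td>1</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <th>4</th>\n",
- "      <td>(26.0, 34.0]</td>\n",
- "      <td>Private</td>\n",
- "      <td>6.0</td>\n",
- "      <td>Never-married</td>\n",
- "      <td>Other-service</td>\n",
- "      <td>Not-in-family</td>\n",
- "      <td>White</td>\n",
- "      <td>Male</td>\n",
- "      <td>0.0</td>\n",
- "      <td>0.0</td>\n",
- "      <td>MidTime</td>\n",
- "      <td>United-States</td>\n",
- "      <td>0</td>\n",
- "    </tr>\n",
- "  </tbody>\n",
- "</table>\n",
- "</div>"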
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "qHiNaF-VgPB4"
+ },
+ "source": [
+ "# Fairness auditing for subgroups using Fairness Aware Counterfactuals for Subgroups (FACTS).\n",
+ "\n",
+ "[FACTS](https://arxiv.org/abs/2306.14978) is an efficient, model-agnostic, highly parameterizable, and explainable framework for auditing subgroup fairness through counterfactual explanations. FACTS focuses on one specific type of bias: the *difficulty in achieving recourse*. In short, it considers the population that has received the unfavorable outcome from an ML model (the *affected population*) and looks for differences, between affected subpopulations, in how difficult it is to change the model's decision to the favorable outcome.\n",
+ "\n",
+ "In this notebook, we will see how to use this algorithm to discover subgroups where the bias of a model (logistic regression, for simplicity) between males and females is high.\n",
+ "\n",
+ "We will use the Adult dataset from UCI ([reference](https://archive.ics.uci.edu/ml/datasets/adult))."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "4cavLAfZgPB5"
+ },
+ "source": [
+ "# Preliminaries"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1FLtToZmgPB6"
+ },
+ "source": [
+ "## Import dependencies\n",
+ "\n",
+ "As usual in Python, the first step is to import all necessary packages."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "qz_p2JzegPB6",
+ "outputId": "dc1206fb-14b9-4ac7-ce74-0f2d83ce6a61"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:root:No module named 'tempeh': fetch_lawschool_gpa will be unavailable. To install, run:\n",
+ "pip install 'aif360[LawSchoolGPA]'\n"
+ ]
+ }
],
- "text/plain": [
- " age workclass education-num marital-status \\\n",
- "0 (16.999, 26.0] Private 7.0 Never-married \n",
- "1 (34.0, 41.0] Private 9.0 Married-civ-spouse \n",
- "2 (26.0, 34.0] Local-gov 12.0 Married-civ-spouse \n",
- "3 (41.0, 50.0] Private 10.0 Married-civ-spouse \n",
- "4 (26.0, 34.0] Private 6.0 Never-married \n",
- "\n",
- " occupation relationship race sex capital-gain capital-loss \\\n",
- "0 Machine-op-inspct Own-child Black Male 0.0 0.0 \n",
- "1 Farming-fishing Married White Male 0.0 0.0 \n",
- "2 Protective-serv Married White Male 0.0 0.0 \n",
- "3 Machine-op-inspct Married Black Male 7688.0 0.0 \n",
- "4 Other-service Not-in-family White Male 0.0 0.0 \n",
- "\n",
- " hours-per-week native-country income \n",
- "0 FullTime United-States 0 \n",
- "1 OverTime United-States 0 \n",
- "2 FullTime United-States 1 \n",
- "3 FullTime United-States 1 \n",
- "4 MidTime United-States 0 "
+ "source": [
+ "from sklearn.model_selection import train_test_split\n",
+ "from sklearn.linear_model import LogisticRegression\n",
+ "from sklearn.compose import ColumnTransformer\n",
+ "from sklearn.pipeline import Pipeline\n",
+ "from sklearn.preprocessing import OneHotEncoder\n",
+ "\n",
+ "from aif360.sklearn.datasets.openml_datasets import fetch_adult\n",
+ "from aif360.sklearn.detectors.facts.clean import clean_dataset\n",
+ "from aif360.sklearn.detectors.facts import FACTS, FACTS_bias_scan\n",
+ "\n",
+ "from IPython.display import display\n",
+ "\n",
+ "import warnings\n",
+ "warnings.filterwarnings(\"ignore\")"
]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "# load the adult dataset and perform some simple preprocessing steps\n",
- "# See output for a glimpse of the final dataset's characteristics\n",
- "X, y, sample_weight = fetch_adult()\n",
- "data = clean_dataset(X.assign(income=y), \"adult\")\n",
- "display(data.head())\n",
- "\n",
- "# split into train-test data\n",
- "y = data['income']\n",
- "X = data.drop('income', axis=1)\n",
- "X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7, random_state=random_seed, stratify=y)"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Example Model to be used for Auditing\n",
- "\n",
- "We use the train set to train a simple logistic regression model. This will serve as the demonstrative model, which we will then treat as a black box and apply our algorithm.\n",
- "\n",
- "Of course, any model can be used in its place. Our purpose here is not to produce a good model, but to audit the fairness of an existing one."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [],
- "source": [
- "#### here, we incrementally build the example model. It consists of one preprocessing step,\n",
- "#### which is to turn categorical features into the respective one-hot encodings, and\n",
- "#### a simple scikit-learn logistic regressor.\n",
- "categorical_features = X.select_dtypes(include=[\"object\", \"category\"]).columns.to_list()\n",
- "categorical_features_onehot_transformer = ColumnTransformer(\n",
- " transformers=[\n",
- " (\"one-hot-encoder\", OneHotEncoder(), categorical_features)\n",
- " ],\n",
- " remainder=\"passthrough\"\n",
- ")\n",
- "model = Pipeline([\n",
- " (\"one-hot-encoder\", categorical_features_onehot_transformer),\n",
- " (\"clf\", LogisticRegression(max_iter=1500))\n",
- "])\n",
- "\n",
- "#### train the model\n",
- "model = model.fit(X_train, y_train)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [
+ },
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Accuracy = 85.16%\n"
- ]
- }
- ],
- "source": [
- "# showcase model's accuracy\n",
- "y_pred = model.predict(X_test)\n",
- "print(f\"Accuracy = {(y_test.values == y_pred).sum() / y_test.shape[0]:.2%}\")"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# A Practical Example of FACTS\n",
- "\n",
- "The real essence of our work starts here. Specifically, we showcase the generation of candidate subpopulation groups and counterfactuals and the detection of those groups that exhibit the greatest unfairness, with respect to one of several metrics."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Load and Fit FACTS"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [],
- "source": [
- "# load FACTS framework with:\n",
- "# - the model to be audited\n",
- "# - protected attribute \"sex\" and\n",
- "# - assigning equal, unit weights to all features for cost computation.\n",
- "# - no features forbidden from changing, i.e. user can specify any features that cannot change at all.\n",
- "detector = FACTS(\n",
- " clf=model,\n",
- " prot_attr=\"sex\",\n",
- " feature_weights={f: 1 for f in X.columns},\n",
- " feats_not_allowed_to_change=[]\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "m1Ye20legPB8"
+ },
+ "source": [
+ "Below, you can set the `random_seed` variable to `None` if you want the pseudo-random steps to vary between runs. We have fixed it to a specific value for reproducibility."
+ ]
+ },
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Computing candidate subgroups.\n"
- ]
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "t8KJD-ICgPB8"
+ },
+ "outputs": [],
+ "source": [
+ "random_seed = 131313 # for reproducibility"
+ ]
},
{
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████████████████████████████████████████████████████████████████████| 1046/1046 [00:00<00:00, 523287.45it/s]"
- ]
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Tq79M4GxgPB8"
+ },
+ "source": [
+ "## Load Dataset"
+ ]
},
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Number of subgroups: 563\n",
- "Computing candidate recourses for all subgroups.\n"
- ]
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "HwnmXFJ_gPB8",
+ "outputId": "3cddbc05-73e7-44c5-8482-af3b942f5c6f"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
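+ "<div>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ "  <thead>\n",
+ "    <tr style=\"text-align: right;\">\n",
+ "      <th></th>\n",
+ "      <th>age</th>\n",
+ "      <th>workclass</th>\n",
+ "      <th>education-num</th>\n",
+ "      <th>marital-status</th>\n",
+ "      <th>occupation</th>\n",
+ "      <th>relationship</th>\n",
+ "      <th>race</th>\n",
+ "      <th>sex</th>\n",
+ "      <th>capital-gain</th>\n",
+ "      <th>capital-loss</th>\n",
+ "      <th>hours-per-week</th>\n",
+ "      <th>native-country</th>\n",
+ "      <th>income</th>\n",
+ "    </tr>\n",
+ "  </thead>\n",
+ "  <tbody>\n",
+ "    <tr>\n",
+ "      <th>0</th>\n",
+ "      <td>(16.999, 26.0]</td>\n",
+ "      <td>Private</td>\n",
+ "      <td>7.0</td>\n",
+ "      <td>Never-married</td>\n",
+ "      <td>Machine-op-inspct</td>\n",
+ "      <td>Own-child</td>\n",
+ "      <td>Black</td>\n",
+ "      <td>Male</td>\n",
+ "      <td>0.0</td>\n",
+ "      <td>0.0</td>\n",
+ "      <td>FullTime</td>\n",
+ "      <td>United-States</td>\n",
+ "      <td>0</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>1</th>\n",
+ "      <td>(34.0, 41.0]</td>\n",
+ "      <td>Private</td>\n",
+ "      <td>9.0</td>\n",
+ "      <td>Married-civ-spouse</td>\n",
+ "      <td>Farming-fishing</td>\n",
+ "      <td>Married</td>\n",
+ "      <td>White</td>\n",
+ "      <td>Male</td>\n",
+ "      <td>0.0</td>\n",
+ "      <td>0.0</td>\n",
+ "      <td>OverTime</td>\n",
+ "      <td>United-States</td>\n",
+ "      <td>0</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>2</th>\n",
+ "      <td>(26.0, 34.0]</td>\n",
+ "      <td>Local-gov</td>\n",
+ "      <td>12.0</td>\n",
+ "      <td>Married-civ-spouse</td>\n",
+ "      <td>Protective-serv</td>\n",
+ "      <td>Married</td>\n",
+ "      <td>White</td>\n",
+ "      <td>Male</td>\n",
+ "      <td>0.0</td>\n",
+ "      <td>0.0</td>\n",
+ "      <td>FullTime</td>\n",
+ "      <td>United-States</td>\n",
+ "      <td>1</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>3</th>\n",
+ "      <td>(41.0, 50.0]</td>\n",
+ "      <td>Private</td>\n",
+ "      <td>10.0</td>\n",
+ "      <td>Married-civ-spouse</td>\n",
+ "      <td>Machine-op-inspct</td>\n",
+ "      <td>Married</td>\n",
+ "      <td>Black</td>\n",
+ "      <td>Male</td>\n",
+ "      <td>7688.0</td>\n",
+ "      <td>0.0</td>\n",
+ "      <td>FullTime</td>\n",
+ "      <td>United-States</td>\n",
+ "      <td>1</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>4</th>\n",
+ "      <td>(26.0, 34.0]</td>\n",
+ "      <td>Private</td>\n",
+ "      <td>6.0</td>\n",
+ "      <td>Never-married</td>\n",
+ "      <td>Other-service</td>\n",
+ "      <td>Not-in-family</td>\n",
+ "      <td>White</td>\n",
+ "      <td>Male</td>\n",
+ "      <td>0.0</td>\n",
+ "      <td>0.0</td>\n",
+ "      <td>MidTime</td>\n",
+ "      <td>United-States</td>\n",
+ "      <td>0</td>\n",
+ "    </tr>\n",
+ "  </tbody>\n",
+ "</table>\n",
+ "</div>"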
+ ],
+ "text/plain": [
+ " age workclass education-num marital-status \\\n",
+ "0 (16.999, 26.0] Private 7.0 Never-married \n",
+ "1 (34.0, 41.0] Private 9.0 Married-civ-spouse \n",
+ "2 (26.0, 34.0] Local-gov 12.0 Married-civ-spouse \n",
+ "3 (41.0, 50.0] Private 10.0 Married-civ-spouse \n",
+ "4 (26.0, 34.0] Private 6.0 Never-married \n",
+ "\n",
+ " occupation relationship race sex capital-gain capital-loss \\\n",
+ "0 Machine-op-inspct Own-child Black Male 0.0 0.0 \n",
+ "1 Farming-fishing Married White Male 0.0 0.0 \n",
+ "2 Protective-serv Married White Male 0.0 0.0 \n",
+ "3 Machine-op-inspct Married Black Male 7688.0 0.0 \n",
+ "4 Other-service Not-in-family White Male 0.0 0.0 \n",
+ "\n",
+ " hours-per-week native-country income \n",
+ "0 FullTime United-States 0 \n",
+ "1 OverTime United-States 0 \n",
+ "2 FullTime United-States 1 \n",
+ "3 FullTime United-States 1 \n",
+ "4 MidTime United-States 0 "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# load the adult dataset and perform some simple preprocessing steps\n",
+ "# See output for a glimpse of the final dataset's characteristics\n",
+ "X, y, sample_weight = fetch_adult()\n",
+ "data = clean_dataset(X.assign(income=y), \"adult\")\n",
+ "display(data.head())\n",
+ "\n",
+ "# split into train-test data\n",
+ "y = data['income']\n",
+ "X = data.drop('income', axis=1)\n",
+ "X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7, random_state=random_seed, stratify=y)"
+ ]
},
{
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "\n",
- "100%|█████████████████████████████████████████████████████████████████████████████| 563/563 [00:00<00:00, 50669.32it/s]"
- ]
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "WdckzZdwgPB8"
+ },
+ "source": [
+ "## Example Model to be used for Auditing\n",
+ "\n",
+ "We use the training set to train a simple logistic regression model. This will serve as the demonstration model, which we will then treat as a black box and audit with our algorithm.\n",
+ "\n",
+ "Of course, any model can be used in its place. Our purpose here is not to produce a good model, but to audit the fairness of an existing one."
+ ]
},
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Computing percentages of individuals flipped by each action independently.\n"
- ]
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "9NuWT4LMgPB9"
+ },
+ "outputs": [],
+ "source": [
+ "#### Here we build the example model step by step. It consists of one preprocessing step,\n",
+ "#### which turns categorical features into their one-hot encodings, and\n",
+ "#### a simple scikit-learn logistic regression classifier.\n",
+ "categorical_features = X.select_dtypes(include=[\"object\", \"category\"]).columns.to_list()\n",
+ "categorical_features_onehot_transformer = ColumnTransformer(\n",
+ " transformers=[\n",
+ " (\"one-hot-encoder\", OneHotEncoder(), categorical_features)\n",
+ " ],\n",
+ " remainder=\"passthrough\"\n",
+ ")\n",
+ "model = Pipeline([\n",
+ " (\"one-hot-encoder\", categorical_features_onehot_transformer),\n",
+ " (\"clf\", LogisticRegression(max_iter=1500))\n",
+ "])\n",
+ "\n",
+ "#### train the model\n",
+ "model = model.fit(X_train, y_train)"
+ ]
},
{
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "\n",
- "100%|████████████████████████████████████████████████████████████████████████████████| 590/590 [00:13<00:00, 43.37it/s]"
- ]
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "iHQwJo_PgPB9",
+ "outputId": "88e849b7-f24a-4b87-fc48-256612ed3626"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Accuracy = 85.16%\n"
+ ]
+ }
+ ],
+ "source": [
+ "# showcase model's accuracy\n",
+ "y_pred = model.predict(X_test)\n",
+ "print(f\"Accuracy = {(y_test.values == y_pred).sum() / y_test.shape[0]:.2%}\")"
+ ]
},
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Computing percentages of individuals flipped by any action with cost up to c, for every c\n"
- ]
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "avuD6xf-gPB9"
+ },
+ "source": [
+ "# A Practical Example of FACTS\n",
+ "\n",
+ "The core of our work starts here. Specifically, we showcase the generation of candidate subpopulation groups and counterfactuals, and the detection of the groups that exhibit the greatest unfairness with respect to one of several metrics."
+ ]
},
{
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "\n",
- "100%|████████████████████████████████████████████████████████████████████████████████| 416/416 [00:12<00:00, 32.57it/s]\n"
- ]
- }
- ],
- "source": [
- "# generates candidate subpopulation groups for bias and candidate actions\n",
- "detector = detector.fit(X_test)"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Detect Groups with Unfairness in Protected Subgroups (using \"Equal Choice for Recourse\" metric)\n",
- "\n",
- "Here we demonstrate the `bias_scan` method of our detector, which ranks subpopulation groups from most to least unfair, with respect to the chosen metric and, of course, the protected attribute.\n",
- "\n",
- "For the purposes of the demo, we use the \"Equal Choice for Recourse\" definition / metric. This posits that the classifier acts fairly for the group in question if the protected subgroups can choose among the same number of sufficiently effective actions to achieve recourse. By sufficiently effective we mean those actions (out of all candidates) which work for at least $100\\phi \\%$ (for some $\\phi \\in [0,1]$) of the subgroup.\n",
- "\n",
- "Given this definition, the respective unfairness *metric* is defined to be the difference in the number of sufficiently effective actions between the two protected subgroups.\n",
- "\n",
- "**Suggestion**: this metric may find utility in scenarios where the aim is to guarantee that protected subgroups have a similar range of options available to them when it comes to making adjustments in order to attain a favorable outcome. For example, when evaluating job candidates, the employer may wish to ensure that applicants from different backgrounds (that currently fail to meet expectations) have an equal array of career / retraining opportunities that may land them the job, so as to ensure diversity in all sectors of the company, which employ individuals with a plethora of roles."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Detects the top `top_count` most biased groups based on the given metric\n",
- "# available metrics are:\n",
- "# - equal-effectiveness\n",
- "# - equal-choice-for-recourse\n",
- "# - equal-effectiveness-within-budget\n",
- "# - equal-cost-of-effectiveness\n",
- "# - equal-mean-recourse\n",
- "# - fair-tradeoff\n",
- "# a short description for each metric is given below\n",
- "detector.bias_scan(\n",
- " metric=\"equal-choice-for-recourse\",\n",
- " phi=0.1,\n",
- " top_count=3\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "35vYtSBFgPB9"
+ },
+ "source": [
+ "## Load and Fit FACTS"
+ ]
+ },
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "If \u001b[1mage = (26.0, 34.0], hours-per-week = FullTime\u001b[0m:\n",
- "\tProtected Subgroup '\u001b[1mFemale\u001b[0m', \u001b[34m10.59%\u001b[39m covered\n",
- "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m7.73%\u001b[39m.\n",
- "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m\u001b[0m with effectiveness \u001b[32m3.98%\u001b[39m.\n",
- "\t\tMake \u001b[1m\u001b[31mage = (34.0, 41.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m5.39%\u001b[39m.\n",
- "\t\t\u001b[1mAggregate cost\u001b[0m of the above recourses = \u001b[35m0.00\u001b[39m\n",
- "\tProtected Subgroup '\u001b[1mMale\u001b[0m', \u001b[34m13.78%\u001b[39m covered\n",
- "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m19.66%\u001b[39m.\n",
- "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m\u001b[0m with effectiveness \u001b[32m10.63%\u001b[39m.\n",
- "\t\tMake \u001b[1m\u001b[31mage = (34.0, 41.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m13.39%\u001b[39m.\n",
- "\t\t\u001b[1mAggregate cost\u001b[0m of the above recourses = \u001b[35m-3.00\u001b[39m\n",
- "\t\u001b[35mBias against Female with respect to equal-choice-for-recourse. Unfairness score = 3.\u001b[39m\n",
- "If \u001b[1mage = (26.0, 34.0], capital-loss = 0.0, hours-per-week = FullTime\u001b[0m:\n",
- "\tProtected Subgroup '\u001b[1mFemale\u001b[0m', \u001b[34m10.34%\u001b[39m covered\n",
- "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m7.67%\u001b[39m.\n",
- "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m\u001b[0m with effectiveness \u001b[32m4.08%\u001b[39m.\n",
- "\t\tMake \u001b[1m\u001b[31mage = (34.0, 41.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m5.28%\u001b[39m.\n",
- "\t\t\u001b[1mAggregate cost\u001b[0m of the above recourses = \u001b[35m0.00\u001b[39m\n",
- "\tProtected Subgroup '\u001b[1mMale\u001b[0m', \u001b[34m13.27%\u001b[39m covered\n",
- "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m18.43%\u001b[39m.\n",
- "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m\u001b[0m with effectiveness \u001b[32m9.27%\u001b[39m.\n",
- "\t\tMake \u001b[1m\u001b[31mage = (34.0, 41.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m11.92%\u001b[39m.\n",
- "\t\t\u001b[1mAggregate cost\u001b[0m of the above recourses = \u001b[35m-2.00\u001b[39m\n",
- "\t\u001b[35mBias against Female with respect to equal-choice-for-recourse. Unfairness score = 2.\u001b[39m\n",
- "If \u001b[1mhours-per-week = FullTime, native-country = United-States\u001b[0m:\n",
- "\tProtected Subgroup '\u001b[1mFemale\u001b[0m', \u001b[34m41.66%\u001b[39m covered\n",
- "\t\tMake \u001b[1m\u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m2.62%\u001b[39m.\n",
- "\t\tMake \u001b[1m\u001b[31mhours-per-week = BrainDrain\u001b[39m\u001b[0m with effectiveness \u001b[32m1.79%\u001b[39m.\n",
- "\t\t\u001b[1mAggregate cost\u001b[0m of the above recourses = \u001b[35m0.00\u001b[39m\n",
- "\tProtected Subgroup '\u001b[1mMale\u001b[0m', \u001b[34m46.78%\u001b[39m covered\n",
- "\t\tMake \u001b[1m\u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m10.08%\u001b[39m.\n",
- "\t\tMake \u001b[1m\u001b[31mhours-per-week = BrainDrain\u001b[39m\u001b[0m with effectiveness \u001b[32m8.70%\u001b[39m.\n",
- "\t\t\u001b[1mAggregate cost\u001b[0m of the above recourses = \u001b[35m-1.00\u001b[39m\n",
- "\t\u001b[35mBias against Female with respect to equal-choice-for-recourse. Unfairness score = 1.\u001b[39m\n"
- ]
- }
- ],
- "source": [
- "# prints the result into a nicely formatted report\n",
- "detector.print_recourse_report(\n",
- " show_action_costs=False,\n",
- " show_subgroup_costs=True,\n",
- " show_unbiased_subgroups=False,\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Example Output Breakdown"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Let us now disect the above example and the output we see, one step at a time.\n",
- "\n",
- "#### Prelude: $\\phi = 0.1$\n",
- "\n",
- "As we mentioned in the general description of this metric, this is the parameter that determines whether we consider an action sufficiently effective or not. So, here, we consider an action effective if it manages to flip the prediction for at least 10% of the individuals under study, and ineffective otherwise.\n",
- "\n",
- "#### **age = (26.0, 34.0], hours-per-week = FullTime**\n",
- "\n",
- "This is the first (hence, most biased) group. The group description is mostly self-explanatory: everything inside this block concerns all those (affected) individuals that are from 26 (not inclusive) to 34 years old and have a fulltime job. Now, since the output has the same structure for all groups, let us consider this group as an example and further disect the output we see in this block.\n",
- "\n",
- "#### *Protected subgroups 'Male' / 'Female'*\n",
- "\n",
- "We split the population of this group, according to the protected attribute. Hence, we distinguish between males that are 26-34 years old and have a fulltime job and females that are 26-34 years old and have a fulltime job.\n",
- "\n",
- "The \"covered\" percentage reported here in blue signifies that out of all affected females, 10.59% are 26-34 years old and have a fulltime job, while the respective percentage for males is 13.78%.\n",
- "\n",
- "#### *Make age = (41.0, 50.0], hours-per-week = OverTime*\n",
- "\n",
- "This is one of the 3 actions we have tried to apply on the individuals in the current subpopulation group. We report the action, along with its effectiveness and, optionally, the cost; here we omit the action cost because the \"Equal Choice for Recourse\" metric does not take it into account.\n",
- "\n",
- "At this point, let us give a more direct interpretation for the **effectiveness**. In this case, for example, the interpretation could be the following: if all females aged 26-34 with fulltime jobs change their age group to 41-50 years old and their working hours to overtime, then 7.73% of them will actually manage to receive the positive prediction from the model. The rest will still receive the negative prediction.\n",
- "\n",
- "#### *Protected Subgroups' Aggregate Cost*\n",
- "\n",
- "The \"aggregate cost of the above recourses\" message shows how we quantify the *cost of recourse* for all actions in each protected subgroup.\n",
- "\n",
- "This is derived directly from the definition of each metric. Here, for example, we use the \"Equal Choice for Recourse\" metric, which counts the number of effective actions available to each of the protected subgroups. In this group, females have no (sufficiently) effective actions, and as such we say that they gain 0 units. Males have 3 effective actions, so they gain 3 units.\n",
- "\n",
- "Finally, to keep the formalization of having costs everywhere, we rephrase this instead into males having a recourse cost of -3 and females having a recourse cost of 0.\n",
- "\n",
- "As we also mention in the next paragraph, the final bias score of the subgroup is nothing more than the absolute difference of these 2 costs.\n",
- "\n",
- "#### *Bias Deduction / Metric Application*\n",
- "\n",
- "Given the above, one can see that the (same) actions, if applied to females of the subpopulation group, cannot yield more than 10% effectiveness, while in males they achieve up to 19.66%! This is why we argue that, in the terms of bias of recourse, this group exhibits bias against females.\n",
- "\n",
- "This is, of course, with respect to the \"Equal Choice for Recourse\" metric, which posits that the 2 protected subgroups should have the same number of effective actions. Since none of the 3 actions are sufficiently effective for females, and all 3 of them are sufficiently effective for males, we score this group as having a bias measure of $|0 - 3| = 3$."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Example without Bias of Recourse\n",
- "\n",
- "For completeness, we also demonstrate how, for some choices of metrics and parameters, FACTS may fail to find any subpopulation groups that exhibit bias between the protected populations, and thus deduce that in this case there is no recourse related bias."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [],
- "source": [
- "detector.bias_scan(\n",
- " metric=\"equal-choice-for-recourse\",\n",
- " phi=0.7,\n",
- " top_count=3\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "3FEEyse9gPB9"
+ },
+ "outputs": [],
+ "source": [
+ "# load the FACTS framework with:\n",
+ "# - the model to be audited,\n",
+ "# - the protected attribute \"sex\",\n",
+ "# - equal, unit weights for all features in the cost computation, and\n",
+ "# - no features forbidden from changing (`feats_not_allowed_to_change` lists the features that must stay fixed; here it is empty).\n",
+ "detector = FACTS(\n",
+ " clf=model,\n",
+ " prot_attr=\"sex\",\n",
+ " feature_weights={f: 1 for f in X.columns},\n",
+ " feats_not_allowed_to_change=[]\n",
+ ")"
+ ]
+ },
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\u001b[1mWith the given parameters, no recourses showing unfairness have been found!\u001b[0m\n"
- ]
- }
- ],
- "source": [
- "# prints the result into a nicely formatted report\n",
- "detector.print_recourse_report(\n",
- " show_action_costs=False,\n",
- " show_subgroup_costs=True,\n",
- " show_unbiased_subgroups=False,\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Aternative API\n",
- "\n",
- "We also provide a more succinct API in the form of a wrapper function. This is closer in style to the API of existing `aif360` detectors.\n",
- "\n",
- "The previous example could be run equivalently with the following."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "duEkwMgngPB-",
+ "outputId": "ddcb45a9-fa3e-466c-e492-bf6a8b26b2dc"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Computing candidate subgroups.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████████████████████████████████████████████████████████████████████| 1046/1046 [00:00<00:00, 523287.45it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of subgroups: 563\n",
+ "Computing candidate recourses for all subgroups.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "100%|█████████████████████████████████████████████████████████████████████████████| 563/563 [00:00<00:00, 50669.32it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Computing percentages of individuals flipped by each action independently.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "100%|████████████████████████████████████████████████████████████████████████████████| 590/590 [00:13<00:00, 43.37it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Computing percentages of individuals flipped by any action with cost up to c, for every c\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "100%|████████████████████████████████████████████████████████████████████████████████| 416/416 [00:12<00:00, 32.57it/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "# generates candidate subpopulation groups for bias and candidate actions\n",
+ "detector = detector.fit(X_test)"
+ ]
+ },
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "If \u001b[1mage = (26.0, 34.0], hours-per-week = FullTime\u001b[0m:\n",
- "\tProtected Subgroup '\u001b[1mFemale\u001b[0m', \u001b[34m10.59%\u001b[39m covered\n",
- "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m7.73%\u001b[39m and counterfactual cost = 2.0.\n",
- "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m\u001b[0m with effectiveness \u001b[32m3.98%\u001b[39m and counterfactual cost = 1.0.\n",
- "\t\tMake \u001b[1m\u001b[31mage = (34.0, 41.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m5.39%\u001b[39m and counterfactual cost = 2.0.\n",
- "\t\t\u001b[1mAggregate cost\u001b[0m of the above recourses = \u001b[35m0.00\u001b[39m\n",
- "\tProtected Subgroup '\u001b[1mMale\u001b[0m', \u001b[34m13.78%\u001b[39m covered\n",
- "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m19.66%\u001b[39m and counterfactual cost = 2.0.\n",
- "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m\u001b[0m with effectiveness \u001b[32m10.63%\u001b[39m and counterfactual cost = 1.0.\n",
- "\t\tMake \u001b[1m\u001b[31mage = (34.0, 41.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m13.39%\u001b[39m and counterfactual cost = 2.0.\n",
- "\t\t\u001b[1mAggregate cost\u001b[0m of the above recourses = \u001b[35m-3.00\u001b[39m\n",
- "\t\u001b[35mBias against Female with respect to equal-choice-for-recourse.. Unfairness score = 3.\u001b[39m\n",
- "If \u001b[1mage = (26.0, 34.0], capital-loss = 0.0, hours-per-week = FullTime\u001b[0m:\n",
- "\tProtected Subgroup '\u001b[1mFemale\u001b[0m', \u001b[34m10.34%\u001b[39m covered\n",
- "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m7.67%\u001b[39m and counterfactual cost = 2.0.\n",
- "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m\u001b[0m with effectiveness \u001b[32m4.08%\u001b[39m and counterfactual cost = 1.0.\n",
- "\t\tMake \u001b[1m\u001b[31mage = (34.0, 41.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m5.28%\u001b[39m and counterfactual cost = 2.0.\n",
- "\t\t\u001b[1mAggregate cost\u001b[0m of the above recourses = \u001b[35m0.00\u001b[39m\n",
- "\tProtected Subgroup '\u001b[1mMale\u001b[0m', \u001b[34m13.27%\u001b[39m covered\n",
- "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m18.43%\u001b[39m and counterfactual cost = 2.0.\n",
- "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m\u001b[0m with effectiveness \u001b[32m9.27%\u001b[39m and counterfactual cost = 1.0.\n",
- "\t\tMake \u001b[1m\u001b[31mage = (34.0, 41.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m11.92%\u001b[39m and counterfactual cost = 2.0.\n",
- "\t\t\u001b[1mAggregate cost\u001b[0m of the above recourses = \u001b[35m-2.00\u001b[39m\n",
- "\t\u001b[35mBias against Female with respect to equal-choice-for-recourse.. Unfairness score = 2.\u001b[39m\n",
- "If \u001b[1mhours-per-week = FullTime, native-country = United-States\u001b[0m:\n",
- "\tProtected Subgroup '\u001b[1mFemale\u001b[0m', \u001b[34m41.66%\u001b[39m covered\n",
- "\t\tMake \u001b[1m\u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m2.62%\u001b[39m and counterfactual cost = 1.0.\n",
- "\t\tMake \u001b[1m\u001b[31mhours-per-week = BrainDrain\u001b[39m\u001b[0m with effectiveness \u001b[32m1.79%\u001b[39m and counterfactual cost = 1.0.\n",
- "\t\t\u001b[1mAggregate cost\u001b[0m of the above recourses = \u001b[35m0.00\u001b[39m\n",
- "\tProtected Subgroup '\u001b[1mMale\u001b[0m', \u001b[34m46.78%\u001b[39m covered\n",
- "\t\tMake \u001b[1m\u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m10.08%\u001b[39m and counterfactual cost = 1.0.\n",
- "\t\tMake \u001b[1m\u001b[31mhours-per-week = BrainDrain\u001b[39m\u001b[0m with effectiveness \u001b[32m8.70%\u001b[39m and counterfactual cost = 1.0.\n",
- "\t\t\u001b[1mAggregate cost\u001b[0m of the above recourses = \u001b[35m-1.00\u001b[39m\n",
- "\t\u001b[35mBias against Female with respect to equal-choice-for-recourse.. Unfairness score = 1.\u001b[39m\n"
- ]
- }
- ],
- "source": [
- "most_biased_subgroups = FACTS_bias_scan(\n",
- " X=X_test,\n",
- " clf=model,\n",
- " prot_attr=\"sex\",\n",
- " feature_weights={f: 1 for f in X.columns},\n",
- " feats_not_allowed_to_change=[],\n",
- " metric=\"equal-choice-for-recourse\",\n",
- " phi=0.1,\n",
- " top_count=3,\n",
- " verbose=False, # hides progress bars\n",
- " print_recourse_report=True,\n",
- " show_action_costs=True,\n",
- " show_subgroup_costs=True,\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "cFPOLhkEgPB-"
+ },
+ "source": [
+ "## Detect Groups with Unfairness in Protected Subgroups (using \"Equal Choice for Recourse\" metric)\n",
+ "\n",
+ "Here we demonstrate the `bias_scan` method of our detector, which ranks subpopulation groups from most to least unfair, with respect to the chosen metric and, of course, the protected attribute.\n",
+ "\n",
+ "For the purposes of the demo, we use the \"Equal Choice for Recourse\" definition / metric. This posits that the classifier acts fairly for the group in question if the protected subgroups can choose among the same number of sufficiently effective actions to achieve recourse. By sufficiently effective we mean those actions (out of all candidates) which work for at least $100\\phi \\%$ (for some $\\phi \\in [0,1]$) of the subgroup.\n",
+ "\n",
+ "Given this definition, the respective unfairness *metric* is defined to be the difference in the number of sufficiently effective actions between the two protected subgroups.\n",
+ "\n",
+ "**Suggestion**: this metric may find utility in scenarios where the aim is to guarantee that protected subgroups have a similar range of options available to them when it comes to making adjustments in order to attain a favorable outcome. For example, when evaluating job candidates, the employer may wish to ensure that applicants from different backgrounds (that currently fail to meet expectations) have an equal array of career / retraining opportunities that may land them the job, so as to ensure diversity in all sectors of the company, which employ individuals with a plethora of roles."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "lApb3FmVgPB-"
+ },
+ "outputs": [],
+ "source": [
+ "# Detects the top `top_count` most biased groups based on the given metric\n",
+ "# available metrics are:\n",
+ "# - equal-effectiveness\n",
+ "# - equal-choice-for-recourse\n",
+ "# - equal-effectiveness-within-budget\n",
+ "# - equal-cost-of-effectiveness\n",
+ "# - equal-mean-recourse\n",
+ "# - fair-tradeoff\n",
+ "# a short description for each metric is given below\n",
+ "detector.bias_scan(\n",
+ " metric=\"equal-choice-for-recourse\",\n",
+ " phi=0.1,\n",
+ " top_count=3\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "QbuX_rC1gPB-",
+ "outputId": "8292d5e3-b86e-4f3a-fcb5-4eac0093e3ba"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "If \u001b[1mage = (26.0, 34.0], hours-per-week = FullTime\u001b[0m:\n",
+ "\tProtected Subgroup '\u001b[1mFemale\u001b[0m', \u001b[34m10.59%\u001b[39m covered\n",
+ "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m7.73%\u001b[39m.\n",
+ "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m\u001b[0m with effectiveness \u001b[32m3.98%\u001b[39m.\n",
+ "\t\tMake \u001b[1m\u001b[31mage = (34.0, 41.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m5.39%\u001b[39m.\n",
+ "\t\t\u001b[1mAggregate cost\u001b[0m of the above recourses = \u001b[35m0.00\u001b[39m\n",
+ "\tProtected Subgroup '\u001b[1mMale\u001b[0m', \u001b[34m13.78%\u001b[39m covered\n",
+ "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m19.66%\u001b[39m.\n",
+ "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m\u001b[0m with effectiveness \u001b[32m10.63%\u001b[39m.\n",
+ "\t\tMake \u001b[1m\u001b[31mage = (34.0, 41.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m13.39%\u001b[39m.\n",
+ "\t\t\u001b[1mAggregate cost\u001b[0m of the above recourses = \u001b[35m-3.00\u001b[39m\n",
+ "\t\u001b[35mBias against Female with respect to equal-choice-for-recourse. Unfairness score = 3.\u001b[39m\n",
+ "If \u001b[1mage = (26.0, 34.0], capital-loss = 0.0, hours-per-week = FullTime\u001b[0m:\n",
+ "\tProtected Subgroup '\u001b[1mFemale\u001b[0m', \u001b[34m10.34%\u001b[39m covered\n",
+ "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m7.67%\u001b[39m.\n",
+ "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m\u001b[0m with effectiveness \u001b[32m4.08%\u001b[39m.\n",
+ "\t\tMake \u001b[1m\u001b[31mage = (34.0, 41.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m5.28%\u001b[39m.\n",
+ "\t\t\u001b[1mAggregate cost\u001b[0m of the above recourses = \u001b[35m0.00\u001b[39m\n",
+ "\tProtected Subgroup '\u001b[1mMale\u001b[0m', \u001b[34m13.27%\u001b[39m covered\n",
+ "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m18.43%\u001b[39m.\n",
+ "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m\u001b[0m with effectiveness \u001b[32m9.27%\u001b[39m.\n",
+ "\t\tMake \u001b[1m\u001b[31mage = (34.0, 41.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m11.92%\u001b[39m.\n",
+ "\t\t\u001b[1mAggregate cost\u001b[0m of the above recourses = \u001b[35m-2.00\u001b[39m\n",
+ "\t\u001b[35mBias against Female with respect to equal-choice-for-recourse. Unfairness score = 2.\u001b[39m\n",
+ "If \u001b[1mhours-per-week = FullTime, native-country = United-States\u001b[0m:\n",
+ "\tProtected Subgroup '\u001b[1mFemale\u001b[0m', \u001b[34m41.66%\u001b[39m covered\n",
+ "\t\tMake \u001b[1m\u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m2.62%\u001b[39m.\n",
+ "\t\tMake \u001b[1m\u001b[31mhours-per-week = BrainDrain\u001b[39m\u001b[0m with effectiveness \u001b[32m1.79%\u001b[39m.\n",
+ "\t\t\u001b[1mAggregate cost\u001b[0m of the above recourses = \u001b[35m0.00\u001b[39m\n",
+ "\tProtected Subgroup '\u001b[1mMale\u001b[0m', \u001b[34m46.78%\u001b[39m covered\n",
+ "\t\tMake \u001b[1m\u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m10.08%\u001b[39m.\n",
+ "\t\tMake \u001b[1m\u001b[31mhours-per-week = BrainDrain\u001b[39m\u001b[0m with effectiveness \u001b[32m8.70%\u001b[39m.\n",
+ "\t\t\u001b[1mAggregate cost\u001b[0m of the above recourses = \u001b[35m-1.00\u001b[39m\n",
+ "\t\u001b[35mBias against Female with respect to equal-choice-for-recourse. Unfairness score = 1.\u001b[39m\n"
+ ]
+ }
+ ],
+ "source": [
+ "# prints the result into a nicely formatted report\n",
+ "detector.print_recourse_report(\n",
+ " show_action_costs=False,\n",
+ " show_subgroup_costs=True,\n",
+ " show_unbiased_subgroups=False,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HK5JNhNwgPB-"
+ },
+ "source": [
+ "### Example Output Breakdown"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "xf-Qbq6agPB-"
+ },
+ "source": [
+    "Let us now dissect the above example and the output we see, one step at a time.\n",
+ "\n",
+ "#### Prelude: $\\phi = 0.1$\n",
+ "\n",
+ "As we mentioned in the general description of this metric, this is the parameter that determines whether we consider an action sufficiently effective or not. So, here, we consider an action effective if it manages to flip the prediction for at least 10% of the individuals under study, and ineffective otherwise.\n",
+ "\n",
+ "#### **age = (26.0, 34.0], hours-per-week = FullTime**\n",
+ "\n",
+    "This is the first (hence, most biased) group. The group description is mostly self-explanatory: everything inside this block concerns all those (affected) individuals that are from 26 (not inclusive) to 34 years old and have a fulltime job. Now, since the output has the same structure for all groups, let us consider this group as an example and further dissect the output we see in this block.\n",
+ "\n",
+ "#### *Protected subgroups 'Male' / 'Female'*\n",
+ "\n",
+ "We split the population of this group, according to the protected attribute. Hence, we distinguish between males that are 26-34 years old and have a fulltime job and females that are 26-34 years old and have a fulltime job.\n",
+ "\n",
+ "The \"covered\" percentage reported here in blue signifies that out of all affected females, 10.59% are 26-34 years old and have a fulltime job, while the respective percentage for males is 13.78%.\n",
+ "\n",
+ "#### *Make age = (41.0, 50.0], hours-per-week = OverTime*\n",
+ "\n",
+    "This is one of the 3 actions we have tried to apply to the individuals in the current subpopulation group. We report the action, along with its effectiveness and, optionally, the cost; here we omit the action cost because the \"Equal Choice for Recourse\" metric does not take it into account.\n",
+ "\n",
+ "At this point, let us give a more direct interpretation for the **effectiveness**. In this case, for example, the interpretation could be the following: if all females aged 26-34 with fulltime jobs change their age group to 41-50 years old and their working hours to overtime, then 7.73% of them will actually manage to receive the positive prediction from the model. The rest will still receive the negative prediction.\n",
+ "\n",
+ "#### *Protected Subgroups' Aggregate Cost*\n",
+ "\n",
+ "The \"aggregate cost of the above recourses\" message shows how we quantify the *cost of recourse* for all actions in each protected subgroup.\n",
+ "\n",
+ "This is derived directly from the definition of each metric. Here, for example, we use the \"Equal Choice for Recourse\" metric, which counts the number of effective actions available to each of the protected subgroups. In this group, females have no (sufficiently) effective actions, and as such we say that they gain 0 units. Males have 3 effective actions, so they gain 3 units.\n",
+ "\n",
+ "Finally, to keep the formalization of having costs everywhere, we rephrase this instead into males having a recourse cost of -3 and females having a recourse cost of 0.\n",
+ "\n",
+ "As we also mention in the next paragraph, the final bias score of the subgroup is nothing more than the absolute difference of these 2 costs.\n",
+ "\n",
+ "#### *Bias Deduction / Metric Application*\n",
+ "\n",
+    "Given the above, one can see that the (same) actions, if applied to females of the subpopulation group, never reach the 10% effectiveness threshold (the best one achieves 7.73%), while in males they achieve up to 19.66%. This is why we argue that, in terms of bias of recourse, this group exhibits bias against females.\n",
+ "\n",
+ "This is, of course, with respect to the \"Equal Choice for Recourse\" metric, which posits that the 2 protected subgroups should have the same number of effective actions. Since none of the 3 actions are sufficiently effective for females, and all 3 of them are sufficiently effective for males, we score this group as having a bias measure of $|0 - 3| = 3$."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "qba9HsyEgPB-"
+ },
+ "source": [
+ "### Example without Bias of Recourse\n",
+ "\n",
+    "For completeness, we also demonstrate how, for some choices of metrics and parameters, FACTS may fail to find any subpopulation groups that exhibit bias between the protected populations, and thus deduce that in this case there is no recourse-related bias."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "9ZazYqQ5gPB_"
+ },
+ "outputs": [],
+ "source": [
+ "detector.bias_scan(\n",
+ " metric=\"equal-choice-for-recourse\",\n",
+ " phi=0.7,\n",
+ " top_count=3\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "LNuBKASZgPB_",
+ "outputId": "f2c75748-e4af-4719-b774-c8904bb5f6f6"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[1mWith the given parameters, no recourses showing unfairness have been found!\u001b[0m\n"
+ ]
+ }
+ ],
+ "source": [
+ "# prints the result into a nicely formatted report\n",
+ "detector.print_recourse_report(\n",
+ " show_action_costs=False,\n",
+ " show_subgroup_costs=True,\n",
+ " show_unbiased_subgroups=False,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HIp_hVhLgPB_"
+ },
+ "source": [
+    "## Alternative API\n",
+ "\n",
+ "We also provide a more succinct API in the form of a wrapper function. This is closer in style to the API of existing `aif360` detectors.\n",
+ "\n",
+    "The first example above (with $\\phi = 0.1$) could be run equivalently with the following."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "RyIvVWtKgPB_",
+ "outputId": "cfc9da11-5120-48f1-fbad-63c5018f4b95"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "If \u001b[1mage = (26.0, 34.0], hours-per-week = FullTime\u001b[0m:\n",
+ "\tProtected Subgroup '\u001b[1mFemale\u001b[0m', \u001b[34m10.59%\u001b[39m covered\n",
+ "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m7.73%\u001b[39m and counterfactual cost = 2.0.\n",
+ "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m\u001b[0m with effectiveness \u001b[32m3.98%\u001b[39m and counterfactual cost = 1.0.\n",
+ "\t\tMake \u001b[1m\u001b[31mage = (34.0, 41.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m5.39%\u001b[39m and counterfactual cost = 2.0.\n",
+ "\t\t\u001b[1mAggregate cost\u001b[0m of the above recourses = \u001b[35m0.00\u001b[39m\n",
+ "\tProtected Subgroup '\u001b[1mMale\u001b[0m', \u001b[34m13.78%\u001b[39m covered\n",
+ "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m19.66%\u001b[39m and counterfactual cost = 2.0.\n",
+ "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m\u001b[0m with effectiveness \u001b[32m10.63%\u001b[39m and counterfactual cost = 1.0.\n",
+ "\t\tMake \u001b[1m\u001b[31mage = (34.0, 41.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m13.39%\u001b[39m and counterfactual cost = 2.0.\n",
+ "\t\t\u001b[1mAggregate cost\u001b[0m of the above recourses = \u001b[35m-3.00\u001b[39m\n",
+ "\t\u001b[35mBias against Female with respect to equal-choice-for-recourse.. Unfairness score = 3.\u001b[39m\n",
+ "If \u001b[1mage = (26.0, 34.0], capital-loss = 0.0, hours-per-week = FullTime\u001b[0m:\n",
+ "\tProtected Subgroup '\u001b[1mFemale\u001b[0m', \u001b[34m10.34%\u001b[39m covered\n",
+ "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m7.67%\u001b[39m and counterfactual cost = 2.0.\n",
+ "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m\u001b[0m with effectiveness \u001b[32m4.08%\u001b[39m and counterfactual cost = 1.0.\n",
+ "\t\tMake \u001b[1m\u001b[31mage = (34.0, 41.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m5.28%\u001b[39m and counterfactual cost = 2.0.\n",
+ "\t\t\u001b[1mAggregate cost\u001b[0m of the above recourses = \u001b[35m0.00\u001b[39m\n",
+ "\tProtected Subgroup '\u001b[1mMale\u001b[0m', \u001b[34m13.27%\u001b[39m covered\n",
+ "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m18.43%\u001b[39m and counterfactual cost = 2.0.\n",
+ "\t\tMake \u001b[1m\u001b[31mage = (41.0, 50.0]\u001b[39m\u001b[0m with effectiveness \u001b[32m9.27%\u001b[39m and counterfactual cost = 1.0.\n",
+ "\t\tMake \u001b[1m\u001b[31mage = (34.0, 41.0]\u001b[39m, \u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m11.92%\u001b[39m and counterfactual cost = 2.0.\n",
+ "\t\t\u001b[1mAggregate cost\u001b[0m of the above recourses = \u001b[35m-2.00\u001b[39m\n",
+ "\t\u001b[35mBias against Female with respect to equal-choice-for-recourse.. Unfairness score = 2.\u001b[39m\n",
+ "If \u001b[1mhours-per-week = FullTime, native-country = United-States\u001b[0m:\n",
+ "\tProtected Subgroup '\u001b[1mFemale\u001b[0m', \u001b[34m41.66%\u001b[39m covered\n",
+ "\t\tMake \u001b[1m\u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m2.62%\u001b[39m and counterfactual cost = 1.0.\n",
+ "\t\tMake \u001b[1m\u001b[31mhours-per-week = BrainDrain\u001b[39m\u001b[0m with effectiveness \u001b[32m1.79%\u001b[39m and counterfactual cost = 1.0.\n",
+ "\t\t\u001b[1mAggregate cost\u001b[0m of the above recourses = \u001b[35m0.00\u001b[39m\n",
+ "\tProtected Subgroup '\u001b[1mMale\u001b[0m', \u001b[34m46.78%\u001b[39m covered\n",
+ "\t\tMake \u001b[1m\u001b[31mhours-per-week = OverTime\u001b[39m\u001b[0m with effectiveness \u001b[32m10.08%\u001b[39m and counterfactual cost = 1.0.\n",
+ "\t\tMake \u001b[1m\u001b[31mhours-per-week = BrainDrain\u001b[39m\u001b[0m with effectiveness \u001b[32m8.70%\u001b[39m and counterfactual cost = 1.0.\n",
+ "\t\t\u001b[1mAggregate cost\u001b[0m of the above recourses = \u001b[35m-1.00\u001b[39m\n",
+ "\t\u001b[35mBias against Female with respect to equal-choice-for-recourse.. Unfairness score = 1.\u001b[39m\n"
+ ]
+ }
+ ],
+ "source": [
+ "most_biased_subgroups = FACTS_bias_scan(\n",
+ " X=X_test,\n",
+ " clf=model,\n",
+ " prot_attr=\"sex\",\n",
+ " feature_weights={f: 1 for f in X.columns},\n",
+ " feats_not_allowed_to_change=[],\n",
+ " metric=\"equal-choice-for-recourse\",\n",
+ " phi=0.1,\n",
+ " top_count=3,\n",
+ " verbose=False, # hides progress bars\n",
+ " print_recourse_report=True,\n",
+ " show_action_costs=True,\n",
+ " show_subgroup_costs=True,\n",
+ ")"
+ ]
+ },
{
- "data": {
- "text/plain": [
- "[({'hours-per-week': 'FullTime', 'native-country': 'United-States'}, 1),\n",
- " ({'age': Interval(26.0, 34.0, closed='right'), 'hours-per-week': 'FullTime'},\n",
- " 3),\n",
- " ({'age': Interval(26.0, 34.0, closed='right'),\n",
- " 'capital-loss': 0.0,\n",
- " 'hours-per-week': 'FullTime'},\n",
- " 2)]"
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "IGbEBY7NgPB_",
+ "outputId": "5796d312-5ef6-4b0f-c2ec-8888e3376552"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[({'hours-per-week': 'FullTime', 'native-country': 'United-States'}, 1),\n",
+ " ({'age': Interval(26.0, 34.0, closed='right'), 'hours-per-week': 'FullTime'},\n",
+ " 3),\n",
+ " ({'age': Interval(26.0, 34.0, closed='right'),\n",
+ " 'capital-loss': 0.0,\n",
+ " 'hours-per-week': 'FullTime'},\n",
+ " 2)]"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "most_biased_subgroups"
]
- },
- "execution_count": 13,
- "metadata": {},
- "output_type": "execute_result"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "3po1Pz8KgPB_"
+ },
+ "source": [
+ "# Short Description of all Definitions / Metrics of Subgroup Recourse Fairness\n",
+ "\n",
+ "Here we give a brief description of each of the metrics available in our framework apart from \"Equal Choice for Recourse\".\n",
+ "\n",
+ "## Equal Effectiveness\n",
+ "\n",
+ "The classifier is considered to act fairly for a population group if the same proportion of individuals in the protected subgroups can achieve recourse.\n",
+ "\n",
+    "**Suggestion**: this metric ignores costs altogether and compares only the percentage of males vs. females that can cross the model's decision boundary by the same actions. We would use it in applications where the goal is equal impact, in the sense that a change (or a set thereof) affects the same proportion of individuals in the protected subgroups. For example, in a hiring scenario, a similar proportion of males and females are expected to benefit from the same change.\n",
+ "\n",
+ "## Equal Effectiveness within Budget\n",
+ "\n",
+ "The classifier is considered to act fairly for a population group if the same proportion of individuals in the protected subgroups can achieve recourse with a cost at most $c$, where $c$ is some user-provided cost budget.\n",
+ "\n",
+    "**Suggestion**: this metric is similar to the above, but puts a bound on how large the cost of an action can be. It could be used to limit changes with undesirably large cost, e.g., salary changes up to 10K.\n",
+ "\n",
+ "## Equal Cost of Effectiveness\n",
+ "\n",
+ "The classifier is considered to act fairly for a population group if the minimum cost required to be sufficiently effective in the protected subgroups is equal. Again, as in \"Equal Choice for Recourse\", by \"sufficiently effective\" we refer to those actions that successfully flip the model's decision for at least $100\\phi \\%$ (for $\\phi \\in [0,1]$) of the subgroup.\n",
+ "\n",
+ "**Suggestion**: this metric could be useful when an external factor imposes a specific threshold, e.g. in credit risk assessment, a guideline which states that the effort required to be 80% certain that you will have your loan accepted should be the same for males and females.\n",
+ "\n",
+ "## Equal (Conditional) Mean Recourse\n",
+ "\n",
+    "This definition extends the notion of *burden* from the literature ([reference](https://dl.acm.org/doi/10.1145/3375627.3375812)) to the case where not all individuals may achieve recourse. Omitting some details, given any set of individuals, the **conditional mean recourse cost** is the mean recourse cost among the subset of individuals that can actually achieve recourse, i.e. by at least one of the available actions.\n",
+ "\n",
+ "Given the above, this definition considers the classifier to act fairly for a population group if the (conditional) mean recourse cost for the protected subgroups is the same.\n",
+ "\n",
+    "**Suggestion**: this metric compares the mean cost required to achieve recourse for the protected subgroups. It could be useful in a scenario like loan approval, where one needs to ensure that the cost of changes needed to receive the loan is the same for males and females on average.\n",
+ "\n",
+ "## Fair Effectiveness-Cost Trade-Off\n",
+ "\n",
+ "This is the strictest definition, which considers the classifier to act fairly for a population group only if the protected subgroups have the same effectiveness-cost distribution (checked in the implementation via a statistical test).\n",
+ "\n",
+ "Equivalently, Equal Effectiveness within Budget must hold for *every* value of the cost budget $c$.\n",
+ "\n",
+ "**Suggestion**: this metric considers all available actions and compares all their possible trade-offs between effectiveness and cost among the protected subgroups. This could be useful for cases where the protected attribute should have absolutely no impact on the available options to achieve recourse, such as in high-risk situations like estimating the risk of a convicted individual to act unlawfully in the future (as in the well known [COMPAS dataset](https://www.propublica.org/datastore/dataset/compas-recidivism-risk-score-data-and-analysis))."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "5sGwLyBNgPCA"
+ },
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.18"
+ },
+ "colab": {
+ "provenance": []
}
- ],
- "source": [
- "most_biased_subgroups"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Short Description of all Definitions / Metrics of Subgroup Recourse Fairness\n",
- "\n",
- "Here we give a brief description of each of the metrics available in our framework apart from \"Equal Choice for Recourse\".\n",
- "\n",
- "## Equal Effectiveness\n",
- "\n",
- "The classifier is considered to act fairly for a population group if the same proportion of individuals in the protected subgroups can achieve recourse.\n",
- "\n",
- "**Suggestion**: this metric ignores costs altogether and compares only the percentage of males VS females that can cross the model's decision boundary by the same actions. We would use it in applications where the goal is equal impact, in the sense that a change (or a set thereof) affects the same proportion of individuals in the protected subgroups. For example, in a hiring scenario, a similar proportion of males and females are expected to benefit from the same change.\n",
- "\n",
- "## Equal Effectiveness within Budget\n",
- "\n",
- "The classifier is considered to act fairly for a population group if the same proportion of individuals in the protected subgroups can achieve recourse with a cost at most $c$, where $c$ is some user-provided cost budget.\n",
- "\n",
- "**Suggestion**: this metric is similar to the above, but puts a bound on how large the cost of an action can be. Could be used to limit changes with undesirably large cost, e.g., salary changes up to 10K.\n",
- "\n",
- "## Equal Cost of Effectiveness\n",
- "\n",
- "The classifier is considered to act fairly for a population group if the minimum cost required to be sufficiently effective in the protected subgroups is equal. Again, as in \"Equal Choice for Recourse\", by \"sufficiently effective\" we refer to those actions that successfully flip the model's decision for at least $100\\phi \\%$ (for $\\phi \\in [0,1]$) of the subgroup.\n",
- "\n",
- "**Suggestion**: this metric could be useful when an external factor imposes a specific threshold, e.g. in credit risk assessment, a guideline which states that the effort required to be 80% certain that you will have your loan accepted should be the same for males and females.\n",
- "\n",
- "## Equal (Conditional) Mean Recourse\n",
- "\n",
- "This definition extends the notion of *burden* from literature ([reference](https://dl.acm.org/doi/10.1145/3375627.3375812)) to the case where not all individuals may achieve recourse. Omitting some details, given any set of individuals, the **conditional mean recourse cost** is the mean recourse cost among the subset of individuals that can actually achieve recourse, i.e. by at least one of the available actions.\n",
- "\n",
- "Given the above, this definition considers the classifier to act fairly for a population group if the (conditional) mean recourse cost for the protected subgroups is the same.\n",
- "\n",
- "**Suggestion**: this metric compares the mean cost required to achieve recourse for the protected subgroups. It could be useful in a scenario like loan approval, where one needs to ensure that the cost of changes needed to receive the loan are the same for males and females on average.\n",
- "\n",
- "## Fair Effectiveness-Cost Trade-Off\n",
- "\n",
- "This is the strictest definition, which considers the classifier to act fairly for a population group only if the protected subgroups have the same effectiveness-cost distribution (checked in the implementation via a statistical test).\n",
- "\n",
- "Equivalently, Equal Effectiveness within Budget must hold for *every* value of the cost budget $c$.\n",
- "\n",
- "**Suggestion**: this metric considers all available actions and compares all their possible trade-offs between effectiveness and cost among the protected subgroups. This could be useful for cases where the protected attribute should have absolutely no impact on the available options to achieve recourse, such as in high-risk situations like estimating the risk of a convicted individual to act unlawfully in the future (as in the well known [COMPAS dataset](https://www.propublica.org/datastore/dataset/compas-recidivism-risk-score-data-and-analysis))."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
},
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.18"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/examples/demo_gerryfair.ipynb b/examples/demo_gerryfair.ipynb
index efdb5983..6b6352f5 100644
--- a/examples/demo_gerryfair.ipynb
+++ b/examples/demo_gerryfair.ipynb
@@ -1,884 +1,918 @@
{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "pycharm": {
- "is_executing": false
- }
- },
- "outputs": [],
- "source": [
- "%matplotlib inline\n",
- "import warnings\n",
- "warnings.filterwarnings(\"ignore\")\n",
- "import sys\n",
- "sys.path.append(\"../\")\n",
- "from aif360.algorithms.inprocessing import GerryFairClassifier\n",
- "from aif360.algorithms.inprocessing.gerryfair.clean import array_to_tuple\n",
- "from aif360.algorithms.inprocessing.gerryfair.auditor import Auditor\n",
- "from aif360.algorithms.preprocessing.optim_preproc_helpers.data_preproc_functions import load_preproc_data_adult\n",
- "from sklearn import svm\n",
- "from sklearn import tree\n",
- "from sklearn.kernel_ridge import KernelRidge\n",
- "from sklearn import linear_model\n",
- "from aif360.metrics import BinaryLabelDatasetMetric\n",
- "from IPython.display import Image\n",
- "import pickle\n",
- "import matplotlib.pyplot as plt\n",
- "\n",
- "# load data set\n",
- "data_set = load_preproc_data_adult(sub_samp=1000, balance=True)\n",
- "max_iterations = 500"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "**instantiate, fit, and predict** \n",
- "\n",
- "\n",
- "We first demonstrate how to instantiate a `GerryFairClassifier`, `train` it with respect to rich subgroup fairness, and `predict` the label of a new example. We remark that when we set the `print_flag = True` at each iteration of the algorithm we print the error, fairness violation, and violated group size of most recent model. The error is the classification error of the classifier. At each round the Learner tries to find a classifier that minimizes the classification error plus a weighted sum of the fairness disparities on all the groups that the Auditor has found up until that point. By contrast the Auditor tries to find the group at each round with the greatest rich subgroup disparity with respect to the Learner's model. We define `violated group size` as the size (as a fraction of the dataset size) of this group, and the `fairness violation` as the `violated group size` times the difference in the statistical rate (FP or FN rate) on the group vs. the whole population. \n",
- "\n",
- "In the example below we set `max_iterations=500` which is an order of magnitude less than the time to convergence observed in [the rich subgroup fairness empirical paper](https://arxiv.org/abs/1808.08166), but advise that this can be highly dataset dependent. Our target $\\gamma$-disparity is $\\gamma = .005$, our statistical rate is false positive rate or `FP`, and our cost-sensitive classification oracle is linear regression (more on that below). \n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {
- "pycharm": {
- "is_executing": true
- }
- },
- "outputs": [
+ "cells": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "iteration: 1, error: 0.263, fairness violation: 0.028780000000000007, violated group size: 0.217\n",
- "iteration: 2, error: 0.3815, fairness violation: 0.014390000000000003, violated group size: 0.217\n",
- "iteration: 3, error: 0.42099999999999993, fairness violation: 0.009593333333333339, violated group size: 0.283\n",
- "iteration: 4, error: 0.44075, fairness violation: 0.007195000000000002, violated group size: 0.217\n",
- "iteration: 5, error: 0.45260000000000006, fairness violation: 0.005756000000000001, violated group size: 0.217\n",
- "iteration: 6, error: 0.4605000000000001, fairness violation: 0.004796666666666668, violated group size: 0.283\n",
- "iteration: 7, error: 0.4661428571428572, fairness violation: 0.004111428571428572, violated group size: 0.217\n",
- "iteration: 8, error: 0.470375, fairness violation: 0.0035975000000000017, violated group size: 0.217\n",
- "iteration: 9, error: 0.4691111111111112, fairness violation: 0.0033906666666666677, violated group size: 0.283\n",
- "iteration: 10, error: 0.4681, fairness violation: 0.003225200000000001, violated group size: 0.283\n",
- "iteration: 11, error: 0.4672727272727271, fairness violation: 0.0030898181818181836, violated group size: 0.283\n",
- "iteration: 12, error: 0.4665833333333333, fairness violation: 0.0029769999999999996, violated group size: 0.217\n",
- "iteration: 13, error: 0.466, fairness violation: 0.0028815384615384627, violated group size: 0.283\n",
- "iteration: 14, error: 0.4655000000000001, fairness violation: 0.0027997142857142865, violated group size: 0.217\n",
- "iteration: 15, error: 0.46506666666666674, fairness violation: 0.002728800000000001, violated group size: 0.217\n",
- "iteration: 16, error: 0.4646875, fairness violation: 0.0026667500000000007, violated group size: 0.217\n",
- "iteration: 17, error: 0.4643529411764707, fairness violation: 0.002612000000000001, violated group size: 0.283\n",
- "iteration: 18, error: 0.46405555555555567, fairness violation: 0.002563333333333334, violated group size: 0.217\n",
- "iteration: 19, error: 0.4637894736842106, fairness violation: 0.0025197894736842096, violated group size: 0.217\n",
- "iteration: 20, error: 0.46354999999999996, fairness violation: 0.0024806000000000008, violated group size: 0.283\n",
- "iteration: 21, error: 0.4633333333333334, fairness violation: 0.0024451428571428584, violated group size: 0.217\n",
- "iteration: 22, error: 0.4631363636363638, fairness violation: 0.0024129090909090914, violated group size: 0.283\n",
- "iteration: 23, error: 0.46295652173913054, fairness violation: 0.002383478260869566, violated group size: 0.217\n",
- "iteration: 24, error: 0.4627916666666667, fairness violation: 0.002356500000000001, violated group size: 0.283\n",
- "iteration: 25, error: 0.4626400000000001, fairness violation: 0.0023316800000000018, violated group size: 0.283\n",
- "iteration: 26, error: 0.4625000000000001, fairness violation: 0.0023087692307692314, violated group size: 0.217\n",
- "iteration: 27, error: 0.4623703703703705, fairness violation: 0.0022875555555555557, violated group size: 0.217\n",
- "iteration: 28, error: 0.46224999999999994, fairness violation: 0.0022678571428571426, violated group size: 0.217\n",
- "iteration: 29, error: 0.46213793103448264, fairness violation: 0.0022495172413793106, violated group size: 0.217\n",
- "iteration: 30, error: 0.46203333333333335, fairness violation: 0.0022324000000000003, violated group size: 0.217\n",
- "iteration: 31, error: 0.46193548387096783, fairness violation: 0.0022163870967741935, violated group size: 0.217\n",
- "iteration: 32, error: 0.46184375, fairness violation: 0.0022013749999999993, violated group size: 0.217\n",
- "iteration: 33, error: 0.459969696969697, fairness violation: 0.0023319393939393944, violated group size: 0.283\n",
- "iteration: 34, error: 0.4582058823529412, fairness violation: 0.002454823529411765, violated group size: 0.217\n",
- "iteration: 35, error: 0.45654285714285714, fairness violation: 0.0025706857142857144, violated group size: 0.217\n",
- "iteration: 36, error: 0.4549722222222221, fairness violation: 0.0026801111111111114, violated group size: 0.283\n",
- "iteration: 37, error: 0.4534864864864866, fairness violation: 0.0027836216216216214, violated group size: 0.283\n",
- "iteration: 38, error: 0.45207894736842097, fairness violation: 0.0028816842105263162, violated group size: 0.283\n",
- "iteration: 39, error: 0.4507435897435898, fairness violation: 0.0029747179487179492, violated group size: 0.217\n",
- "iteration: 40, error: 0.44947499999999996, fairness violation: 0.0030631000000000005, violated group size: 0.217\n",
- "iteration: 41, error: 0.44826829268292684, fairness violation: 0.0031471707317073175, violated group size: 0.283\n",
- "iteration: 42, error: 0.4471190476190476, fairness violation: 0.0032272380952380955, violated group size: 0.217\n",
- "iteration: 43, error: 0.44602325581395347, fairness violation: 0.0033035813953488386, violated group size: 0.283\n",
- "iteration: 44, error: 0.44497727272727267, fairness violation: 0.0033764545454545453, violated group size: 0.283\n",
- "iteration: 45, error: 0.4439777777777778, fairness violation: 0.003446088888888888, violated group size: 0.217\n",
- "iteration: 46, error: 0.44302173913043474, fairness violation: 0.0035126956521739122, violated group size: 0.217\n",
- "iteration: 47, error: 0.44210638297872346, fairness violation: 0.0035764680851063826, violated group size: 0.217\n",
- "iteration: 48, error: 0.4412291666666666, fairness violation: 0.003637583333333332, violated group size: 0.217\n",
- "iteration: 49, error: 0.4403877551020407, fairness violation: 0.0036962040816326523, violated group size: 0.217\n",
- "iteration: 50, error: 0.4395600000000001, fairness violation: 0.0037524800000000003, violated group size: 0.217\n",
- "iteration: 51, error: 0.43876470588235295, fairness violation: 0.0038065490196078425, violated group size: 0.217\n",
- "iteration: 52, error: 0.438, fairness violation: 0.003858538461538461, violated group size: 0.283\n",
- "iteration: 53, error: 0.4372641509433963, fairness violation: 0.003908566037735848, violated group size: 0.217\n",
- "iteration: 54, error: 0.4365555555555556, fairness violation: 0.003956740740740741, violated group size: 0.283\n",
- "iteration: 55, error: 0.4358181818181819, fairness violation: 0.004003163636363636, violated group size: 0.217\n",
- "iteration: 56, error: 0.4351071428571429, fairness violation: 0.004047928571428571, violated group size: 0.217\n",
- "iteration: 57, error: 0.4344736842105262, fairness violation: 0.004091122807017543, violated group size: 0.217\n",
- "iteration: 58, error: 0.43381034482758624, fairness violation: 0.004132827586206895, violated group size: 0.217\n",
- "iteration: 59, error: 0.4331694915254237, fairness violation: 0.0041731186440677965, violated group size: 0.283\n",
- "iteration: 60, error: 0.43254999999999993, fairness violation: 0.004212066666666666, violated group size: 0.217\n",
- "iteration: 61, error: 0.4319508196721312, fairness violation: 0.004249737704918031, violated group size: 0.217\n",
- "iteration: 62, error: 0.4313709677419356, fairness violation: 0.004286193548387096, violated group size: 0.217\n",
- "iteration: 63, error: 0.43080952380952386, fairness violation: 0.004321492063492062, violated group size: 0.283\n",
- "iteration: 64, error: 0.430265625, fairness violation: 0.004355687499999999, violated group size: 0.283\n",
- "iteration: 65, error: 0.4297384615384615, fairness violation: 0.004388830769230769, violated group size: 0.283\n",
- "iteration: 66, error: 0.42922727272727274, fairness violation: 0.004420969696969697, violated group size: 0.217\n",
- "iteration: 67, error: 0.42873134328358203, fairness violation: 0.004452149253731343, violated group size: 0.217\n",
- "iteration: 68, error: 0.42824999999999996, fairness violation: 0.0044824117647058815, violated group size: 0.283\n",
- "iteration: 69, error: 0.42778260869565227, fairness violation: 0.004511797101449274, violated group size: 0.217\n",
- "iteration: 70, error: 0.42732857142857145, fairness violation: 0.004540342857142856, violated group size: 0.283\n",
- "iteration: 71, error: 0.42688732394366197, fairness violation: 0.004568084507042252, violated group size: 0.217\n",
- "iteration: 72, error: 0.4264583333333332, fairness violation: 0.004595055555555555, violated group size: 0.283\n",
- "iteration: 73, error: 0.42604109589041106, fairness violation: 0.004621287671232876, violated group size: 0.217\n",
- "iteration: 74, error: 0.4256351351351351, fairness violation: 0.0046468108108108095, violated group size: 0.283\n",
- "iteration: 75, error: 0.42524, fairness violation: 0.004671653333333331, violated group size: 0.217\n",
- "iteration: 76, error: 0.4248552631578947, fairness violation: 0.004695842105263155, violated group size: 0.217\n",
- "iteration: 77, error: 0.42448051948051946, fairness violation: 0.004719402597402596, violated group size: 0.217\n",
- "iteration: 78, error: 0.4239871794871795, fairness violation: 0.00475905128205128, violated group size: 0.217\n",
- "iteration: 79, error: 0.42363291139240505, fairness violation: 0.004781215189873418, violated group size: 0.283\n",
- "iteration: 80, error: 0.42328750000000015, fairness violation: 0.004802824999999999, violated group size: 0.283\n"
- ]
+ "cell_type": "markdown",
+ "source": [
+ "[Open In Colab](https://colab.research.google.com/github/Trusted-AI/AIF360/blob/main/examples/demo_gerryfair.ipynb)"
+ ],
+ "metadata": {
+ "id": "_Y3xm4CthJnm"
+ }
},
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "iteration: 81, error: 0.4229506172839506, fairness violation: 0.004823901234567901, violated group size: 0.283\n",
- "iteration: 82, error: 0.4226219512195123, fairness violation: 0.004844463414634145, violated group size: 0.217\n",
- "iteration: 83, error: 0.4221807228915662, fairness violation: 0.004880216867469879, violated group size: 0.217\n",
- "iteration: 84, error: 0.42175, fairness violation: 0.004915119047619047, violated group size: 0.217\n",
- "iteration: 85, error: 0.4214470588235294, fairness violation: 0.004933882352941174, violated group size: 0.217\n",
- "iteration: 86, error: 0.4210348837209302, fairness violation: 0.004967348837209301, violated group size: 0.217\n",
- "iteration: 87, error: 0.420632183908046, fairness violation: 0.005000045977011494, violated group size: 0.283\n",
- "iteration: 88, error: 0.42035227272727277, fairness violation: 0.0050172045454545434, violated group size: 0.217\n",
- "iteration: 89, error: 0.4200786516853933, fairness violation: 0.005033977528089887, violated group size: 0.217\n",
- "iteration: 90, error: 0.4198111111111112, fairness violation: 0.005050377777777776, violated group size: 0.283\n",
- "iteration: 91, error: 0.4195824175824176, fairness violation: 0.0050664175824175805, violated group size: 0.217\n",
- "iteration: 92, error: 0.4193695652173913, fairness violation: 0.005082108695652173, violated group size: 0.217\n",
- "iteration: 93, error: 0.41916129032258065, fairness violation: 0.005097462365591397, violated group size: 0.217\n",
- "iteration: 94, error: 0.41895744680851066, fairness violation: 0.005112489361702126, violated group size: 0.217\n",
- "iteration: 95, error: 0.41875789473684216, fairness violation: 0.005127199999999998, violated group size: 0.217\n",
- "iteration: 96, error: 0.41856250000000006, fairness violation: 0.005141604166666665, violated group size: 0.283\n",
- "iteration: 97, error: 0.418979381443299, fairness violation: 0.005106494845360823, violated group size: 0.217\n",
- "iteration: 98, error: 0.41938775510204085, fairness violation: 0.005072102040816325, violated group size: 0.217\n",
- "iteration: 99, error: 0.4197878787878788, fairness violation: 0.0050384040404040376, violated group size: 0.217\n",
- "iteration: 100, error: 0.42018000000000005, fairness violation: 0.0050053799999999985, violated group size: 0.217\n",
- "iteration: 101, error: 0.42056435643564366, fairness violation: 0.004973009900990098, violated group size: 0.217\n",
- "iteration: 102, error: 0.42094117647058826, fairness violation: 0.00494127450980392, violated group size: 0.217\n",
- "iteration: 103, error: 0.4213106796116506, fairness violation: 0.004910155339805824, violated group size: 0.217\n",
- "iteration: 104, error: 0.4216730769230769, fairness violation: 0.004879634615384614, violated group size: 0.217\n",
- "iteration: 105, error: 0.4220285714285715, fairness violation: 0.004849695238095237, violated group size: 0.217\n",
- "iteration: 106, error: 0.4223773584905662, fairness violation: 0.004820320754716981, violated group size: 0.283\n",
- "iteration: 107, error: 0.42271962616822434, fairness violation: 0.004791495327102803, violated group size: 0.217\n",
- "iteration: 108, error: 0.4230555555555556, fairness violation: 0.0047632037037037035, violated group size: 0.217\n",
- "iteration: 109, error: 0.4233853211009175, fairness violation: 0.00473543119266055, violated group size: 0.217\n",
- "iteration: 110, error: 0.4237090909090908, fairness violation: 0.004708163636363636, violated group size: 0.217\n",
- "iteration: 111, error: 0.424027027027027, fairness violation: 0.004681387387387387, violated group size: 0.283\n",
- "iteration: 112, error: 0.42433928571428586, fairness violation: 0.004655089285714286, violated group size: 0.283\n",
- "iteration: 113, error: 0.4241238938053097, fairness violation: 0.004671504424778761, violated group size: 0.217\n",
- "iteration: 114, error: 0.42442982456140343, fairness violation: 0.004645754385964912, violated group size: 0.283\n",
- "iteration: 115, error: 0.42473043478260875, fairness violation: 0.0046204521739130425, violated group size: 0.283\n",
- "iteration: 116, error: 0.42502586206896553, fairness violation: 0.0045955862068965524, violated group size: 0.283\n",
- "iteration: 117, error: 0.42481196581196584, fairness violation: 0.004611948717948717, violated group size: 0.217\n",
- "iteration: 118, error: 0.4251016949152542, fairness violation: 0.004587576271186439, violated group size: 0.217\n",
- "iteration: 119, error: 0.42489075630252104, fairness violation: 0.004603731092436974, violated group size: 0.217\n",
- "iteration: 120, error: 0.4251750000000001, fairness violation: 0.0045798333333333325, violated group size: 0.217\n",
- "iteration: 121, error: 0.4249669421487604, fairness violation: 0.004595785123966942, violated group size: 0.283\n",
- "iteration: 122, error: 0.4247622950819671, fairness violation: 0.0046114754098360656, violated group size: 0.217\n",
- "iteration: 123, error: 0.42456097560975614, fairness violation: 0.00462691056910569, violated group size: 0.217\n",
- "iteration: 124, error: 0.42436290322580644, fairness violation: 0.004642096774193548, violated group size: 0.217\n",
- "iteration: 125, error: 0.4241680000000001, fairness violation: 0.00465704, violated group size: 0.217\n",
- "iteration: 126, error: 0.4239761904761905, fairness violation: 0.004671746031746031, violated group size: 0.217\n",
- "iteration: 127, error: 0.42425196850393704, fairness violation: 0.004648629921259842, violated group size: 0.217\n",
- "iteration: 128, error: 0.4240625, fairness violation: 0.004663171874999999, violated group size: 0.217\n",
- "iteration: 129, error: 0.4238759689922481, fairness violation: 0.004677488372093024, violated group size: 0.283\n",
- "iteration: 130, error: 0.42369230769230776, fairness violation: 0.004691584615384614, violated group size: 0.217\n",
- "iteration: 131, error: 0.42351145038167937, fairness violation: 0.004705465648854962, violated group size: 0.217\n",
- "iteration: 132, error: 0.4233333333333333, fairness violation: 0.004719136363636364, violated group size: 0.283\n",
- "iteration: 133, error: 0.423157894736842, fairness violation: 0.0047326015037594, violated group size: 0.217\n",
- "iteration: 134, error: 0.4229850746268656, fairness violation: 0.004745865671641791, violated group size: 0.217\n",
- "iteration: 135, error: 0.42281481481481475, fairness violation: 0.004758933333333335, violated group size: 0.283\n",
- "iteration: 136, error: 0.4226470588235294, fairness violation: 0.004771808823529411, violated group size: 0.217\n",
- "iteration: 137, error: 0.42248175182481745, fairness violation: 0.004784496350364964, violated group size: 0.283\n",
- "iteration: 138, error: 0.42231884057971014, fairness violation: 0.004797000000000002, violated group size: 0.283\n",
- "iteration: 139, error: 0.42215827338129497, fairness violation: 0.004809323741007196, violated group size: 0.283\n",
- "iteration: 140, error: 0.42200000000000004, fairness violation: 0.004821471428571429, violated group size: 0.217\n",
- "iteration: 141, error: 0.4218439716312057, fairness violation: 0.0048334468085106394, violated group size: 0.217\n",
- "iteration: 142, error: 0.42169014084507045, fairness violation: 0.004845253521126761, violated group size: 0.283\n",
- "iteration: 143, error: 0.4215384615384616, fairness violation: 0.004856895104895106, violated group size: 0.283\n",
- "iteration: 144, error: 0.4213888888888888, fairness violation: 0.004868375, violated group size: 0.217\n",
- "iteration: 145, error: 0.42124137931034483, fairness violation: 0.004879696551724138, violated group size: 0.217\n",
- "iteration: 146, error: 0.4210958904109589, fairness violation: 0.00489086301369863, violated group size: 0.217\n",
- "iteration: 147, error: 0.4209523809523809, fairness violation: 0.004901877551020409, violated group size: 0.217\n",
- "iteration: 148, error: 0.42081081081081084, fairness violation: 0.004912743243243244, violated group size: 0.217\n",
- "iteration: 149, error: 0.42067114093959734, fairness violation: 0.004923463087248323, violated group size: 0.283\n",
- "iteration: 150, error: 0.4205333333333334, fairness violation: 0.004934040000000001, violated group size: 0.217\n",
- "iteration: 151, error: 0.4203973509933776, fairness violation: 0.004944476821192053, violated group size: 0.217\n",
- "iteration: 152, error: 0.4202631578947368, fairness violation: 0.0049547763157894754, violated group size: 0.283\n",
- "iteration: 153, error: 0.4201307189542483, fairness violation: 0.00496494117647059, violated group size: 0.283\n",
- "iteration: 154, error: 0.42, fairness violation: 0.004974974025974027, violated group size: 0.283\n",
- "iteration: 155, error: 0.4198709677419355, fairness violation: 0.0049848774193548395, violated group size: 0.217\n",
- "iteration: 156, error: 0.4197435897435898, fairness violation: 0.004994653846153847, violated group size: 0.217\n",
- "iteration: 157, error: 0.4196178343949045, fairness violation: 0.0050043057324840766, violated group size: 0.217\n"
- ]
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "pycharm": {
+ "is_executing": false
+ },
+ "id": "0xQSRpo5hI5H"
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline\n",
+ "import warnings\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "import sys\n",
+ "sys.path.append(\"../\")\n",
+ "from aif360.algorithms.inprocessing import GerryFairClassifier\n",
+ "from aif360.algorithms.inprocessing.gerryfair.clean import array_to_tuple\n",
+ "from aif360.algorithms.inprocessing.gerryfair.auditor import Auditor\n",
+ "from aif360.algorithms.preprocessing.optim_preproc_helpers.data_preproc_functions import load_preproc_data_adult\n",
+ "from sklearn import svm\n",
+ "from sklearn import tree\n",
+ "from sklearn.kernel_ridge import KernelRidge\n",
+ "from sklearn import linear_model\n",
+ "from aif360.metrics import BinaryLabelDatasetMetric\n",
+ "from IPython.display import Image\n",
+ "import pickle\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "# load data set\n",
+ "data_set = load_preproc_data_adult(sub_samp=1000, balance=True)\n",
+ "max_iterations = 500"
+ ]
},
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "iteration: 158, error: 0.4194936708860761, fairness violation: 0.005013835443037975, violated group size: 0.217\n",
- "iteration: 159, error: 0.41937106918239, fairness violation: 0.005023245283018869, violated group size: 0.283\n",
- "iteration: 160, error: 0.41925000000000007, fairness violation: 0.0050325375, violated group size: 0.217\n",
- "iteration: 161, error: 0.4191304347826087, fairness violation: 0.005041714285714285, violated group size: 0.217\n",
- "iteration: 162, error: 0.41901234567901235, fairness violation: 0.005050777777777778, violated group size: 0.283\n",
- "iteration: 163, error: 0.41889570552147243, fairness violation: 0.005059730061349694, violated group size: 0.283\n",
- "iteration: 164, error: 0.4191402439024391, fairness violation: 0.005039463414634148, violated group size: 0.283\n",
- "iteration: 165, error: 0.41938181818181824, fairness violation: 0.005019442424242424, violated group size: 0.217\n",
- "iteration: 166, error: 0.4192650602409639, fairness violation: 0.005028421686746988, violated group size: 0.217\n",
- "iteration: 167, error: 0.41950299401197605, fairness violation: 0.005008706586826348, violated group size: 0.217\n",
- "iteration: 168, error: 0.41973809523809513, fairness violation: 0.004989226190476189, violated group size: 0.217\n",
- "iteration: 169, error: 0.41997041420118336, fairness violation: 0.0049699763313609474, violated group size: 0.283\n",
- "iteration: 170, error: 0.4202, fairness violation: 0.00495095294117647, violated group size: 0.283\n",
- "iteration: 171, error: 0.4204269005847953, fairness violation: 0.004932152046783625, violated group size: 0.217\n",
- "iteration: 172, error: 0.42065116279069764, fairness violation: 0.00491356976744186, violated group size: 0.217\n",
- "iteration: 173, error: 0.4208728323699421, fairness violation: 0.004895202312138728, violated group size: 0.217\n",
- "iteration: 174, error: 0.42109195402298855, fairness violation: 0.004877045977011494, violated group size: 0.217\n",
- "iteration: 175, error: 0.4213085714285715, fairness violation: 0.004859097142857142, violated group size: 0.217\n",
- "iteration: 176, error: 0.42152272727272727, fairness violation: 0.0048413522727272715, violated group size: 0.217\n",
- "iteration: 177, error: 0.42173446327683617, fairness violation: 0.00482380790960452, violated group size: 0.217\n",
- "iteration: 178, error: 0.42161235955056187, fairness violation: 0.004833280898876404, violated group size: 0.217\n",
- "iteration: 179, error: 0.42182122905027924, fairness violation: 0.004815977653631285, violated group size: 0.217\n",
- "iteration: 180, error: 0.4220277777777778, fairness violation: 0.004798866666666665, violated group size: 0.217\n",
- "iteration: 181, error: 0.42223204419889504, fairness violation: 0.004781944751381214, violated group size: 0.283\n",
- "iteration: 182, error: 0.4224340659340659, fairness violation: 0.004765208791208789, violated group size: 0.217\n",
- "iteration: 183, error: 0.4226338797814208, fairness violation: 0.004748655737704917, violated group size: 0.217\n",
- "iteration: 184, error: 0.4228315217391304, fairness violation: 0.004732282608695651, violated group size: 0.217\n",
- "iteration: 185, error: 0.4230270270270271, fairness violation: 0.004716086486486487, violated group size: 0.217\n",
- "iteration: 186, error: 0.4229032258064515, fairness violation: 0.0047257311827957, violated group size: 0.283\n",
- "iteration: 187, error: 0.4230962566844919, fairness violation: 0.0047097433155080205, violated group size: 0.217\n",
- "iteration: 188, error: 0.4229734042553191, fairness violation: 0.00471931914893617, violated group size: 0.283\n",
- "iteration: 189, error: 0.4231640211640213, fairness violation: 0.004703534391534391, violated group size: 0.217\n",
- "iteration: 190, error: 0.42304210526315794, fairness violation: 0.004713042105263158, violated group size: 0.217\n",
- "iteration: 191, error: 0.4232303664921467, fairness violation: 0.004697455497382198, violated group size: 0.217\n",
- "iteration: 192, error: 0.42310937499999995, fairness violation: 0.004706895833333333, violated group size: 0.217\n",
- "iteration: 193, error: 0.42329533678756476, fairness violation: 0.004691502590673575, violated group size: 0.283\n",
- "iteration: 194, error: 0.4231752577319588, fairness violation: 0.004700876288659792, violated group size: 0.217\n",
- "iteration: 195, error: 0.4230564102564103, fairness violation: 0.004710153846153845, violated group size: 0.217\n",
- "iteration: 196, error: 0.4229387755102041, fairness violation: 0.004719336734693878, violated group size: 0.283\n",
- "iteration: 197, error: 0.4228223350253807, fairness violation: 0.004728426395939086, violated group size: 0.283\n",
- "iteration: 198, error: 0.4227070707070707, fairness violation: 0.004737424242424242, violated group size: 0.217\n",
- "iteration: 199, error: 0.4228894472361809, fairness violation: 0.004722341708542713, violated group size: 0.217\n",
- "iteration: 200, error: 0.42277499999999996, fairness violation: 0.004731279999999999, violated group size: 0.217\n",
- "iteration: 201, error: 0.4226616915422886, fairness violation: 0.004740129353233829, violated group size: 0.217\n",
- "iteration: 202, error: 0.4225495049504951, fairness violation: 0.00474889108910891, violated group size: 0.217\n",
- "iteration: 203, error: 0.42243842364532025, fairness violation: 0.004757566502463053, violated group size: 0.283\n",
- "iteration: 204, error: 0.42232843137254905, fairness violation: 0.004766156862745097, violated group size: 0.283\n",
- "iteration: 205, error: 0.422219512195122, fairness violation: 0.004774663414634145, violated group size: 0.217\n",
- "iteration: 206, error: 0.422111650485437, fairness violation: 0.004783087378640775, violated group size: 0.217\n",
- "iteration: 207, error: 0.4220048309178744, fairness violation: 0.00479142995169082, violated group size: 0.217\n",
- "iteration: 208, error: 0.4218990384615385, fairness violation: 0.004799692307692306, violated group size: 0.217\n",
- "iteration: 209, error: 0.42179425837320567, fairness violation: 0.004807875598086124, violated group size: 0.217\n",
- "iteration: 210, error: 0.4216904761904762, fairness violation: 0.004815980952380952, violated group size: 0.217\n",
- "iteration: 211, error: 0.42158767772511846, fairness violation: 0.0048240094786729856, violated group size: 0.217\n",
- "iteration: 212, error: 0.42148584905660386, fairness violation: 0.004831962264150944, violated group size: 0.217\n",
- "iteration: 213, error: 0.4213849765258215, fairness violation: 0.004839840375586855, violated group size: 0.283\n",
- "iteration: 214, error: 0.421285046728972, fairness violation: 0.0048476448598130835, violated group size: 0.217\n",
- "iteration: 215, error: 0.42118604651162794, fairness violation: 0.004855376744186045, violated group size: 0.217\n",
- "iteration: 216, error: 0.4210879629629629, fairness violation: 0.004863037037037037, violated group size: 0.283\n",
- "iteration: 217, error: 0.42099078341013824, fairness violation: 0.004870626728110601, violated group size: 0.283\n",
- "iteration: 218, error: 0.42089449541284396, fairness violation: 0.004878146788990825, violated group size: 0.217\n",
- "iteration: 219, error: 0.42079908675799094, fairness violation: 0.004885598173515983, violated group size: 0.283\n",
- "iteration: 220, error: 0.4207045454545455, fairness violation: 0.004892981818181818, violated group size: 0.283\n",
- "iteration: 221, error: 0.4206108597285068, fairness violation: 0.004900298642533936, violated group size: 0.283\n",
- "iteration: 222, error: 0.4205180180180179, fairness violation: 0.004907549549549549, violated group size: 0.217\n",
- "iteration: 223, error: 0.4204260089686098, fairness violation: 0.004914735426008968, violated group size: 0.217\n",
- "iteration: 224, error: 0.4203348214285714, fairness violation: 0.004921857142857143, violated group size: 0.283\n",
- "iteration: 225, error: 0.4202444444444444, fairness violation: 0.004928915555555555, violated group size: 0.217\n",
- "iteration: 226, error: 0.42015486725663725, fairness violation: 0.004935911504424777, violated group size: 0.217\n",
- "iteration: 227, error: 0.4200660792951542, fairness violation: 0.004942845814977973, violated group size: 0.217\n",
- "iteration: 228, error: 0.4199780701754386, fairness violation: 0.004949719298245614, violated group size: 0.217\n",
- "iteration: 229, error: 0.4198908296943231, fairness violation: 0.004956532751091703, violated group size: 0.283\n",
- "iteration: 230, error: 0.419804347826087, fairness violation: 0.004963286956521739, violated group size: 0.283\n",
- "iteration: 231, error: 0.4197186147186147, fairness violation: 0.004969982683982686, violated group size: 0.283\n",
- "iteration: 232, error: 0.4196336206896552, fairness violation: 0.004976620689655175, violated group size: 0.283\n",
- "iteration: 233, error: 0.419549356223176, fairness violation: 0.004983201716738197, violated group size: 0.283\n",
- "iteration: 234, error: 0.419465811965812, fairness violation: 0.0049897264957264945, violated group size: 0.217\n"
- ]
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "iO2ueydohI5J"
+ },
+ "source": [
+ "**instantiate, fit, and predict**\n",
+ "\n",
+ "\n",
+ "We first demonstrate how to instantiate a `GerryFairClassifier`, `fit` it with respect to rich subgroup fairness, and `predict` the label of a new example. When `print_flag = True`, each iteration of the algorithm prints the error, fairness violation, and violated group size of the most recent model. The error is the classification error of the classifier. At each round, the Learner tries to find a classifier that minimizes the classification error plus a weighted sum of the fairness disparities on all the groups that the Auditor has found up to that point. By contrast, at each round the Auditor tries to find the group with the greatest rich subgroup disparity with respect to the Learner's model. We define `violated group size` as the size of this group (as a fraction of the dataset size), and `fairness violation` as the `violated group size` times the difference in the statistical rate (FP or FN rate) between the group and the whole population.\n",
+ "\n",
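+ "As a rough illustration of the definition above, write $\\alpha(g)$ for the `violated group size` of a group $g$ and $\\mathrm{FP}(g)$, $\\mathrm{FP}$ for the false positive rates on the group and on the whole population. These symbols are assumed here for illustration and are not taken from the code. The definition then reads\n",
+ "\n",
+ "$$\\text{fairness violation} = \\alpha(g)\\,\\left|\\mathrm{FP}(g) - \\mathrm{FP}\\right|.$$\n",
+ "\n",
+ "For example, a group with $\\alpha(g) = 0.217$ and a fairness violation of about $0.0288$ would have an FP-rate gap of roughly $0.0288 / 0.217 \\approx 0.133$, assuming both printed values refer to the same group and the same model.\n",
+ "\n",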
+ "In the example below we set `max_iterations=500`, which is an order of magnitude less than the time to convergence observed in [the rich subgroup fairness empirical paper](https://arxiv.org/abs/1808.08166); note, however, that the time to convergence can be highly dataset dependent. Our target $\\gamma$-disparity is $\\gamma = .005$, our statistical rate is the false positive rate (`FP`), and our cost-sensitive classification oracle is linear regression (more on that below).\n"
+ ]
},
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "iteration: 235, error: 0.4193829787234044, fairness violation: 0.00499619574468085, violated group size: 0.217\n",
- "iteration: 236, error: 0.41930084745762713, fairness violation: 0.005002610169491525, violated group size: 0.283\n",
- "iteration: 237, error: 0.4192194092827004, fairness violation: 0.005008970464135021, violated group size: 0.283\n",
- "iteration: 238, error: 0.41913865546218493, fairness violation: 0.00501527731092437, violated group size: 0.283\n",
- "iteration: 239, error: 0.4190585774058578, fairness violation: 0.005021531380753138, violated group size: 0.283\n",
- "iteration: 240, error: 0.4189791666666667, fairness violation: 0.005027733333333333, violated group size: 0.217\n",
- "iteration: 241, error: 0.41890041493775937, fairness violation: 0.005033883817427385, violated group size: 0.217\n",
- "iteration: 242, error: 0.4188223140495868, fairness violation: 0.0050399834710743805, violated group size: 0.217\n",
- "iteration: 243, error: 0.4187448559670781, fairness violation: 0.005046032921810699, violated group size: 0.217\n",
- "iteration: 244, error: 0.41890983606557386, fairness violation: 0.005032467213114753, violated group size: 0.217\n",
- "iteration: 245, error: 0.419073469387755, fairness violation: 0.005019012244897959, violated group size: 0.217\n",
- "iteration: 246, error: 0.41923577235772364, fairness violation: 0.005005666666666667, violated group size: 0.283\n",
- "iteration: 247, error: 0.4191578947368422, fairness violation: 0.005011757085020243, violated group size: 0.217\n",
- "iteration: 248, error: 0.41908064516129034, fairness violation: 0.005017798387096774, violated group size: 0.283\n",
- "iteration: 249, error: 0.4192409638554217, fairness violation: 0.005004618473895581, violated group size: 0.217\n",
- "iteration: 250, error: 0.4194000000000001, fairness violation: 0.004991544, violated group size: 0.217\n",
- "iteration: 251, error: 0.4195577689243028, fairness violation: 0.004978573705179282, violated group size: 0.283\n",
- "iteration: 252, error: 0.4197142857142857, fairness violation: 0.004965706349206349, violated group size: 0.217\n",
- "iteration: 253, error: 0.4198695652173912, fairness violation: 0.00495294071146245, violated group size: 0.217\n",
- "iteration: 254, error: 0.42002362204724414, fairness violation: 0.00494027559055118, violated group size: 0.217\n",
- "iteration: 255, error: 0.4201764705882353, fairness violation: 0.004927709803921568, violated group size: 0.217\n",
- "iteration: 256, error: 0.420328125, fairness violation: 0.004915242187499999, violated group size: 0.217\n",
- "iteration: 257, error: 0.42047859922178993, fairness violation: 0.004902871595330739, violated group size: 0.217\n",
- "iteration: 258, error: 0.42062790697674424, fairness violation: 0.004890596899224807, violated group size: 0.217\n",
- "iteration: 259, error: 0.4207760617760617, fairness violation: 0.004878416988416988, violated group size: 0.283\n",
- "iteration: 260, error: 0.42069615384615383, fairness violation: 0.004884692307692307, violated group size: 0.217\n",
- "iteration: 261, error: 0.42061685823754785, fairness violation: 0.004890919540229886, violated group size: 0.283\n",
- "iteration: 262, error: 0.42076335877862603, fairness violation: 0.004878877862595419, violated group size: 0.217\n",
- "iteration: 263, error: 0.4209087452471483, fairness violation: 0.004866927756653992, violated group size: 0.217\n",
- "iteration: 264, error: 0.4208295454545454, fairness violation: 0.004873151515151516, violated group size: 0.283\n",
- "iteration: 265, error: 0.4209735849056603, fairness violation: 0.004861313207547171, violated group size: 0.283\n",
- "iteration: 266, error: 0.4211165413533834, fairness violation: 0.004849563909774436, violated group size: 0.217\n",
- "iteration: 267, error: 0.42125842696629223, fairness violation: 0.0048379026217228475, violated group size: 0.217\n",
- "iteration: 268, error: 0.42139925373134324, fairness violation: 0.004826328358208956, violated group size: 0.283\n",
- "iteration: 269, error: 0.42131970260223045, fairness violation: 0.0048325873605947955, violated group size: 0.217\n",
- "iteration: 270, error: 0.42145925925925937, fairness violation: 0.00482111851851852, violated group size: 0.283\n",
- "iteration: 271, error: 0.42159778597785963, fairness violation: 0.004809734317343174, violated group size: 0.217\n",
- "iteration: 272, error: 0.42173529411764704, fairness violation: 0.004798433823529413, violated group size: 0.283\n",
- "iteration: 273, error: 0.4216556776556777, fairness violation: 0.004804703296703296, violated group size: 0.217\n",
- "iteration: 274, error: 0.42157664233576647, fairness violation: 0.00481092700729927, violated group size: 0.217\n",
- "iteration: 275, error: 0.42171272727272724, fairness violation: 0.004799745454545455, violated group size: 0.217\n",
- "iteration: 276, error: 0.42184782608695653, fairness violation: 0.004788644927536233, violated group size: 0.217\n",
- "iteration: 277, error: 0.4219819494584837, fairness violation: 0.004777624548736462, violated group size: 0.217\n",
- "iteration: 278, error: 0.42190287769784157, fairness violation: 0.004783856115107913, violated group size: 0.217\n",
- "iteration: 279, error: 0.42182437275985674, fairness violation: 0.004790043010752689, violated group size: 0.217\n",
- "iteration: 280, error: 0.42174642857142863, fairness violation: 0.004796185714285715, violated group size: 0.283\n",
- "iteration: 281, error: 0.42166903914590753, fairness violation: 0.0048022846975088965, violated group size: 0.283\n",
- "iteration: 282, error: 0.4218014184397163, fairness violation: 0.00479141134751773, violated group size: 0.217\n",
- "iteration: 283, error: 0.42172438162544174, fairness violation: 0.0047974840989399295, violated group size: 0.217\n",
- "iteration: 284, error: 0.4216478873239437, fairness violation: 0.004803514084507042, violated group size: 0.217\n",
- "iteration: 285, error: 0.42157192982456126, fairness violation: 0.004809501754385964, violated group size: 0.217\n",
- "iteration: 286, error: 0.4217027972027972, fairness violation: 0.004798755244755245, violated group size: 0.217\n",
- "iteration: 287, error: 0.4218327526132404, fairness violation: 0.004788083623693379, violated group size: 0.283\n",
- "iteration: 288, error: 0.4219618055555556, fairness violation: 0.004777486111111113, violated group size: 0.283\n",
- "iteration: 289, error: 0.4218858131487888, fairness violation: 0.004783480968858131, violated group size: 0.217\n",
- "iteration: 290, error: 0.4218103448275861, fairness violation: 0.004789434482758621, violated group size: 0.217\n",
- "iteration: 291, error: 0.42193814432989696, fairness violation: 0.004778941580756014, violated group size: 0.283\n",
- "iteration: 292, error: 0.42186301369863016, fairness violation: 0.0047848698630136985, violated group size: 0.217\n",
- "iteration: 293, error: 0.4217883959044368, fairness violation: 0.004790757679180888, violated group size: 0.217\n",
- "iteration: 294, error: 0.42171428571428576, fairness violation: 0.004796605442176871, violated group size: 0.217\n",
- "iteration: 295, error: 0.4216406779661017, fairness violation: 0.004802413559322035, violated group size: 0.217\n",
- "iteration: 296, error: 0.4215675675675675, fairness violation: 0.004808182432432432, violated group size: 0.217\n",
- "iteration: 297, error: 0.4214949494949495, fairness violation: 0.00481391245791246, violated group size: 0.283\n",
- "iteration: 298, error: 0.4214228187919464, fairness violation: 0.0048196040268456385, violated group size: 0.217\n",
- "iteration: 299, error: 0.42135117056856186, fairness violation: 0.0048252575250836115, violated group size: 0.217\n",
- "iteration: 300, error: 0.42128, fairness violation: 0.004830873333333335, violated group size: 0.283\n",
- "iteration: 301, error: 0.42120930232558146, fairness violation: 0.004836451827242525, violated group size: 0.217\n",
- "iteration: 302, error: 0.42113907284768215, fairness violation: 0.004841993377483444, violated group size: 0.217\n",
- "iteration: 303, error: 0.42106930693069305, fairness violation: 0.004847498349834984, violated group size: 0.217\n",
- "iteration: 304, error: 0.4211940789473684, fairness violation: 0.004837263157894738, violated group size: 0.283\n",
- "iteration: 305, error: 0.4211245901639345, fairness violation: 0.004842747540983607, violated group size: 0.283\n",
- "iteration: 306, error: 0.4210555555555555, fairness violation: 0.004848196078431373, violated group size: 0.217\n",
- "iteration: 307, error: 0.42098697068403895, fairness violation: 0.004853609120521175, violated group size: 0.283\n",
- "iteration: 308, error: 0.4209188311688312, fairness violation: 0.004858987012987015, violated group size: 0.283\n",
- "iteration: 309, error: 0.42085113268608415, fairness violation: 0.00486433009708738, violated group size: 0.283\n"
- ]
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "pycharm": {
+ "is_executing": true
+ },
+ "id": "sgydei4GhI5K",
+ "outputId": "5ec1d33f-dce5-4bee-8f91-3f2901f5a15a"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "iteration: 1, error: 0.263, fairness violation: 0.028780000000000007, violated group size: 0.217\n",
+ "iteration: 2, error: 0.3815, fairness violation: 0.014390000000000003, violated group size: 0.217\n",
+ "iteration: 3, error: 0.42099999999999993, fairness violation: 0.009593333333333339, violated group size: 0.283\n",
+ "iteration: 4, error: 0.44075, fairness violation: 0.007195000000000002, violated group size: 0.217\n",
+ "iteration: 5, error: 0.45260000000000006, fairness violation: 0.005756000000000001, violated group size: 0.217\n",
+ "iteration: 6, error: 0.4605000000000001, fairness violation: 0.004796666666666668, violated group size: 0.283\n",
+ "iteration: 7, error: 0.4661428571428572, fairness violation: 0.004111428571428572, violated group size: 0.217\n",
+ "iteration: 8, error: 0.470375, fairness violation: 0.0035975000000000017, violated group size: 0.217\n",
+ "iteration: 9, error: 0.4691111111111112, fairness violation: 0.0033906666666666677, violated group size: 0.283\n",
+ "iteration: 10, error: 0.4681, fairness violation: 0.003225200000000001, violated group size: 0.283\n",
+ "iteration: 11, error: 0.4672727272727271, fairness violation: 0.0030898181818181836, violated group size: 0.283\n",
+ "iteration: 12, error: 0.4665833333333333, fairness violation: 0.0029769999999999996, violated group size: 0.217\n",
+ "iteration: 13, error: 0.466, fairness violation: 0.0028815384615384627, violated group size: 0.283\n",
+ "iteration: 14, error: 0.4655000000000001, fairness violation: 0.0027997142857142865, violated group size: 0.217\n",
+ "iteration: 15, error: 0.46506666666666674, fairness violation: 0.002728800000000001, violated group size: 0.217\n",
+ "iteration: 16, error: 0.4646875, fairness violation: 0.0026667500000000007, violated group size: 0.217\n",
+ "iteration: 17, error: 0.4643529411764707, fairness violation: 0.002612000000000001, violated group size: 0.283\n",
+ "iteration: 18, error: 0.46405555555555567, fairness violation: 0.002563333333333334, violated group size: 0.217\n",
+ "iteration: 19, error: 0.4637894736842106, fairness violation: 0.0025197894736842096, violated group size: 0.217\n",
+ "iteration: 20, error: 0.46354999999999996, fairness violation: 0.0024806000000000008, violated group size: 0.283\n",
+ "iteration: 21, error: 0.4633333333333334, fairness violation: 0.0024451428571428584, violated group size: 0.217\n",
+ "iteration: 22, error: 0.4631363636363638, fairness violation: 0.0024129090909090914, violated group size: 0.283\n",
+ "iteration: 23, error: 0.46295652173913054, fairness violation: 0.002383478260869566, violated group size: 0.217\n",
+ "iteration: 24, error: 0.4627916666666667, fairness violation: 0.002356500000000001, violated group size: 0.283\n",
+ "iteration: 25, error: 0.4626400000000001, fairness violation: 0.0023316800000000018, violated group size: 0.283\n",
+ "iteration: 26, error: 0.4625000000000001, fairness violation: 0.0023087692307692314, violated group size: 0.217\n",
+ "iteration: 27, error: 0.4623703703703705, fairness violation: 0.0022875555555555557, violated group size: 0.217\n",
+ "iteration: 28, error: 0.46224999999999994, fairness violation: 0.0022678571428571426, violated group size: 0.217\n",
+ "iteration: 29, error: 0.46213793103448264, fairness violation: 0.0022495172413793106, violated group size: 0.217\n",
+ "iteration: 30, error: 0.46203333333333335, fairness violation: 0.0022324000000000003, violated group size: 0.217\n",
+ "iteration: 31, error: 0.46193548387096783, fairness violation: 0.0022163870967741935, violated group size: 0.217\n",
+ "iteration: 32, error: 0.46184375, fairness violation: 0.0022013749999999993, violated group size: 0.217\n",
+ "iteration: 33, error: 0.459969696969697, fairness violation: 0.0023319393939393944, violated group size: 0.283\n",
+ "iteration: 34, error: 0.4582058823529412, fairness violation: 0.002454823529411765, violated group size: 0.217\n",
+ "iteration: 35, error: 0.45654285714285714, fairness violation: 0.0025706857142857144, violated group size: 0.217\n",
+ "iteration: 36, error: 0.4549722222222221, fairness violation: 0.0026801111111111114, violated group size: 0.283\n",
+ "iteration: 37, error: 0.4534864864864866, fairness violation: 0.0027836216216216214, violated group size: 0.283\n",
+ "iteration: 38, error: 0.45207894736842097, fairness violation: 0.0028816842105263162, violated group size: 0.283\n",
+ "iteration: 39, error: 0.4507435897435898, fairness violation: 0.0029747179487179492, violated group size: 0.217\n",
+ "iteration: 40, error: 0.44947499999999996, fairness violation: 0.0030631000000000005, violated group size: 0.217\n",
+ "iteration: 41, error: 0.44826829268292684, fairness violation: 0.0031471707317073175, violated group size: 0.283\n",
+ "iteration: 42, error: 0.4471190476190476, fairness violation: 0.0032272380952380955, violated group size: 0.217\n",
+ "iteration: 43, error: 0.44602325581395347, fairness violation: 0.0033035813953488386, violated group size: 0.283\n",
+ "iteration: 44, error: 0.44497727272727267, fairness violation: 0.0033764545454545453, violated group size: 0.283\n",
+ "iteration: 45, error: 0.4439777777777778, fairness violation: 0.003446088888888888, violated group size: 0.217\n",
+ "iteration: 46, error: 0.44302173913043474, fairness violation: 0.0035126956521739122, violated group size: 0.217\n",
+ "iteration: 47, error: 0.44210638297872346, fairness violation: 0.0035764680851063826, violated group size: 0.217\n",
+ "iteration: 48, error: 0.4412291666666666, fairness violation: 0.003637583333333332, violated group size: 0.217\n",
+ "iteration: 49, error: 0.4403877551020407, fairness violation: 0.0036962040816326523, violated group size: 0.217\n",
+ "iteration: 50, error: 0.4395600000000001, fairness violation: 0.0037524800000000003, violated group size: 0.217\n",
+ "iteration: 51, error: 0.43876470588235295, fairness violation: 0.0038065490196078425, violated group size: 0.217\n",
+ "iteration: 52, error: 0.438, fairness violation: 0.003858538461538461, violated group size: 0.283\n",
+ "iteration: 53, error: 0.4372641509433963, fairness violation: 0.003908566037735848, violated group size: 0.217\n",
+ "iteration: 54, error: 0.4365555555555556, fairness violation: 0.003956740740740741, violated group size: 0.283\n",
+ "iteration: 55, error: 0.4358181818181819, fairness violation: 0.004003163636363636, violated group size: 0.217\n",
+ "iteration: 56, error: 0.4351071428571429, fairness violation: 0.004047928571428571, violated group size: 0.217\n",
+ "iteration: 57, error: 0.4344736842105262, fairness violation: 0.004091122807017543, violated group size: 0.217\n",
+ "iteration: 58, error: 0.43381034482758624, fairness violation: 0.004132827586206895, violated group size: 0.217\n",
+ "iteration: 59, error: 0.4331694915254237, fairness violation: 0.0041731186440677965, violated group size: 0.283\n",
+ "iteration: 60, error: 0.43254999999999993, fairness violation: 0.004212066666666666, violated group size: 0.217\n",
+ "iteration: 61, error: 0.4319508196721312, fairness violation: 0.004249737704918031, violated group size: 0.217\n",
+ "iteration: 62, error: 0.4313709677419356, fairness violation: 0.004286193548387096, violated group size: 0.217\n",
+ "iteration: 63, error: 0.43080952380952386, fairness violation: 0.004321492063492062, violated group size: 0.283\n",
+ "iteration: 64, error: 0.430265625, fairness violation: 0.004355687499999999, violated group size: 0.283\n",
+ "iteration: 65, error: 0.4297384615384615, fairness violation: 0.004388830769230769, violated group size: 0.283\n",
+ "iteration: 66, error: 0.42922727272727274, fairness violation: 0.004420969696969697, violated group size: 0.217\n",
+ "iteration: 67, error: 0.42873134328358203, fairness violation: 0.004452149253731343, violated group size: 0.217\n",
+ "iteration: 68, error: 0.42824999999999996, fairness violation: 0.0044824117647058815, violated group size: 0.283\n",
+ "iteration: 69, error: 0.42778260869565227, fairness violation: 0.004511797101449274, violated group size: 0.217\n",
+ "iteration: 70, error: 0.42732857142857145, fairness violation: 0.004540342857142856, violated group size: 0.283\n",
+ "iteration: 71, error: 0.42688732394366197, fairness violation: 0.004568084507042252, violated group size: 0.217\n",
+ "iteration: 72, error: 0.4264583333333332, fairness violation: 0.004595055555555555, violated group size: 0.283\n",
+ "iteration: 73, error: 0.42604109589041106, fairness violation: 0.004621287671232876, violated group size: 0.217\n",
+ "iteration: 74, error: 0.4256351351351351, fairness violation: 0.0046468108108108095, violated group size: 0.283\n",
+ "iteration: 75, error: 0.42524, fairness violation: 0.004671653333333331, violated group size: 0.217\n",
+ "iteration: 76, error: 0.4248552631578947, fairness violation: 0.004695842105263155, violated group size: 0.217\n",
+ "iteration: 77, error: 0.42448051948051946, fairness violation: 0.004719402597402596, violated group size: 0.217\n",
+ "iteration: 78, error: 0.4239871794871795, fairness violation: 0.00475905128205128, violated group size: 0.217\n",
+ "iteration: 79, error: 0.42363291139240505, fairness violation: 0.004781215189873418, violated group size: 0.283\n",
+ "iteration: 80, error: 0.42328750000000015, fairness violation: 0.004802824999999999, violated group size: 0.283\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "iteration: 81, error: 0.4229506172839506, fairness violation: 0.004823901234567901, violated group size: 0.283\n",
+ "iteration: 82, error: 0.4226219512195123, fairness violation: 0.004844463414634145, violated group size: 0.217\n",
+ "iteration: 83, error: 0.4221807228915662, fairness violation: 0.004880216867469879, violated group size: 0.217\n",
+ "iteration: 84, error: 0.42175, fairness violation: 0.004915119047619047, violated group size: 0.217\n",
+ "iteration: 85, error: 0.4214470588235294, fairness violation: 0.004933882352941174, violated group size: 0.217\n",
+ "iteration: 86, error: 0.4210348837209302, fairness violation: 0.004967348837209301, violated group size: 0.217\n",
+ "iteration: 87, error: 0.420632183908046, fairness violation: 0.005000045977011494, violated group size: 0.283\n",
+ "iteration: 88, error: 0.42035227272727277, fairness violation: 0.0050172045454545434, violated group size: 0.217\n",
+ "iteration: 89, error: 0.4200786516853933, fairness violation: 0.005033977528089887, violated group size: 0.217\n",
+ "iteration: 90, error: 0.4198111111111112, fairness violation: 0.005050377777777776, violated group size: 0.283\n",
+ "iteration: 91, error: 0.4195824175824176, fairness violation: 0.0050664175824175805, violated group size: 0.217\n",
+ "iteration: 92, error: 0.4193695652173913, fairness violation: 0.005082108695652173, violated group size: 0.217\n",
+ "iteration: 93, error: 0.41916129032258065, fairness violation: 0.005097462365591397, violated group size: 0.217\n",
+ "iteration: 94, error: 0.41895744680851066, fairness violation: 0.005112489361702126, violated group size: 0.217\n",
+ "iteration: 95, error: 0.41875789473684216, fairness violation: 0.005127199999999998, violated group size: 0.217\n",
+ "iteration: 96, error: 0.41856250000000006, fairness violation: 0.005141604166666665, violated group size: 0.283\n",
+ "iteration: 97, error: 0.418979381443299, fairness violation: 0.005106494845360823, violated group size: 0.217\n",
+ "iteration: 98, error: 0.41938775510204085, fairness violation: 0.005072102040816325, violated group size: 0.217\n",
+ "iteration: 99, error: 0.4197878787878788, fairness violation: 0.0050384040404040376, violated group size: 0.217\n",
+ "iteration: 100, error: 0.42018000000000005, fairness violation: 0.0050053799999999985, violated group size: 0.217\n",
+ "iteration: 101, error: 0.42056435643564366, fairness violation: 0.004973009900990098, violated group size: 0.217\n",
+ "iteration: 102, error: 0.42094117647058826, fairness violation: 0.00494127450980392, violated group size: 0.217\n",
+ "iteration: 103, error: 0.4213106796116506, fairness violation: 0.004910155339805824, violated group size: 0.217\n",
+ "iteration: 104, error: 0.4216730769230769, fairness violation: 0.004879634615384614, violated group size: 0.217\n",
+ "iteration: 105, error: 0.4220285714285715, fairness violation: 0.004849695238095237, violated group size: 0.217\n",
+ "iteration: 106, error: 0.4223773584905662, fairness violation: 0.004820320754716981, violated group size: 0.283\n",
+ "iteration: 107, error: 0.42271962616822434, fairness violation: 0.004791495327102803, violated group size: 0.217\n",
+ "iteration: 108, error: 0.4230555555555556, fairness violation: 0.0047632037037037035, violated group size: 0.217\n",
+ "iteration: 109, error: 0.4233853211009175, fairness violation: 0.00473543119266055, violated group size: 0.217\n",
+ "iteration: 110, error: 0.4237090909090908, fairness violation: 0.004708163636363636, violated group size: 0.217\n",
+ "iteration: 111, error: 0.424027027027027, fairness violation: 0.004681387387387387, violated group size: 0.283\n",
+ "iteration: 112, error: 0.42433928571428586, fairness violation: 0.004655089285714286, violated group size: 0.283\n",
+ "iteration: 113, error: 0.4241238938053097, fairness violation: 0.004671504424778761, violated group size: 0.217\n",
+ "iteration: 114, error: 0.42442982456140343, fairness violation: 0.004645754385964912, violated group size: 0.283\n",
+ "iteration: 115, error: 0.42473043478260875, fairness violation: 0.0046204521739130425, violated group size: 0.283\n",
+ "iteration: 116, error: 0.42502586206896553, fairness violation: 0.0045955862068965524, violated group size: 0.283\n",
+ "iteration: 117, error: 0.42481196581196584, fairness violation: 0.004611948717948717, violated group size: 0.217\n",
+ "iteration: 118, error: 0.4251016949152542, fairness violation: 0.004587576271186439, violated group size: 0.217\n",
+ "iteration: 119, error: 0.42489075630252104, fairness violation: 0.004603731092436974, violated group size: 0.217\n",
+ "iteration: 120, error: 0.4251750000000001, fairness violation: 0.0045798333333333325, violated group size: 0.217\n",
+ "iteration: 121, error: 0.4249669421487604, fairness violation: 0.004595785123966942, violated group size: 0.283\n",
+ "iteration: 122, error: 0.4247622950819671, fairness violation: 0.0046114754098360656, violated group size: 0.217\n",
+ "iteration: 123, error: 0.42456097560975614, fairness violation: 0.00462691056910569, violated group size: 0.217\n",
+ "iteration: 124, error: 0.42436290322580644, fairness violation: 0.004642096774193548, violated group size: 0.217\n",
+ "iteration: 125, error: 0.4241680000000001, fairness violation: 0.00465704, violated group size: 0.217\n",
+ "iteration: 126, error: 0.4239761904761905, fairness violation: 0.004671746031746031, violated group size: 0.217\n",
+ "iteration: 127, error: 0.42425196850393704, fairness violation: 0.004648629921259842, violated group size: 0.217\n",
+ "iteration: 128, error: 0.4240625, fairness violation: 0.004663171874999999, violated group size: 0.217\n",
+ "iteration: 129, error: 0.4238759689922481, fairness violation: 0.004677488372093024, violated group size: 0.283\n",
+ "iteration: 130, error: 0.42369230769230776, fairness violation: 0.004691584615384614, violated group size: 0.217\n",
+ "iteration: 131, error: 0.42351145038167937, fairness violation: 0.004705465648854962, violated group size: 0.217\n",
+ "iteration: 132, error: 0.4233333333333333, fairness violation: 0.004719136363636364, violated group size: 0.283\n",
+ "iteration: 133, error: 0.423157894736842, fairness violation: 0.0047326015037594, violated group size: 0.217\n",
+ "iteration: 134, error: 0.4229850746268656, fairness violation: 0.004745865671641791, violated group size: 0.217\n",
+ "iteration: 135, error: 0.42281481481481475, fairness violation: 0.004758933333333335, violated group size: 0.283\n",
+ "iteration: 136, error: 0.4226470588235294, fairness violation: 0.004771808823529411, violated group size: 0.217\n",
+ "iteration: 137, error: 0.42248175182481745, fairness violation: 0.004784496350364964, violated group size: 0.283\n",
+ "iteration: 138, error: 0.42231884057971014, fairness violation: 0.004797000000000002, violated group size: 0.283\n",
+ "iteration: 139, error: 0.42215827338129497, fairness violation: 0.004809323741007196, violated group size: 0.283\n",
+ "iteration: 140, error: 0.42200000000000004, fairness violation: 0.004821471428571429, violated group size: 0.217\n",
+ "iteration: 141, error: 0.4218439716312057, fairness violation: 0.0048334468085106394, violated group size: 0.217\n",
+ "iteration: 142, error: 0.42169014084507045, fairness violation: 0.004845253521126761, violated group size: 0.283\n",
+ "iteration: 143, error: 0.4215384615384616, fairness violation: 0.004856895104895106, violated group size: 0.283\n",
+ "iteration: 144, error: 0.4213888888888888, fairness violation: 0.004868375, violated group size: 0.217\n",
+ "iteration: 145, error: 0.42124137931034483, fairness violation: 0.004879696551724138, violated group size: 0.217\n",
+ "iteration: 146, error: 0.4210958904109589, fairness violation: 0.00489086301369863, violated group size: 0.217\n",
+ "iteration: 147, error: 0.4209523809523809, fairness violation: 0.004901877551020409, violated group size: 0.217\n",
+ "iteration: 148, error: 0.42081081081081084, fairness violation: 0.004912743243243244, violated group size: 0.217\n",
+ "iteration: 149, error: 0.42067114093959734, fairness violation: 0.004923463087248323, violated group size: 0.283\n",
+ "iteration: 150, error: 0.4205333333333334, fairness violation: 0.004934040000000001, violated group size: 0.217\n",
+ "iteration: 151, error: 0.4203973509933776, fairness violation: 0.004944476821192053, violated group size: 0.217\n",
+ "iteration: 152, error: 0.4202631578947368, fairness violation: 0.0049547763157894754, violated group size: 0.283\n",
+ "iteration: 153, error: 0.4201307189542483, fairness violation: 0.00496494117647059, violated group size: 0.283\n",
+ "iteration: 154, error: 0.42, fairness violation: 0.004974974025974027, violated group size: 0.283\n",
+ "iteration: 155, error: 0.4198709677419355, fairness violation: 0.0049848774193548395, violated group size: 0.217\n",
+ "iteration: 156, error: 0.4197435897435898, fairness violation: 0.004994653846153847, violated group size: 0.217\n",
+ "iteration: 157, error: 0.4196178343949045, fairness violation: 0.0050043057324840766, violated group size: 0.217\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "iteration: 158, error: 0.4194936708860761, fairness violation: 0.005013835443037975, violated group size: 0.217\n",
+ "iteration: 159, error: 0.41937106918239, fairness violation: 0.005023245283018869, violated group size: 0.283\n",
+ "iteration: 160, error: 0.41925000000000007, fairness violation: 0.0050325375, violated group size: 0.217\n",
+ "iteration: 161, error: 0.4191304347826087, fairness violation: 0.005041714285714285, violated group size: 0.217\n",
+ "iteration: 162, error: 0.41901234567901235, fairness violation: 0.005050777777777778, violated group size: 0.283\n",
+ "iteration: 163, error: 0.41889570552147243, fairness violation: 0.005059730061349694, violated group size: 0.283\n",
+ "iteration: 164, error: 0.4191402439024391, fairness violation: 0.005039463414634148, violated group size: 0.283\n",
+ "iteration: 165, error: 0.41938181818181824, fairness violation: 0.005019442424242424, violated group size: 0.217\n",
+ "iteration: 166, error: 0.4192650602409639, fairness violation: 0.005028421686746988, violated group size: 0.217\n",
+ "iteration: 167, error: 0.41950299401197605, fairness violation: 0.005008706586826348, violated group size: 0.217\n",
+ "iteration: 168, error: 0.41973809523809513, fairness violation: 0.004989226190476189, violated group size: 0.217\n",
+ "iteration: 169, error: 0.41997041420118336, fairness violation: 0.0049699763313609474, violated group size: 0.283\n",
+ "iteration: 170, error: 0.4202, fairness violation: 0.00495095294117647, violated group size: 0.283\n",
+ "iteration: 171, error: 0.4204269005847953, fairness violation: 0.004932152046783625, violated group size: 0.217\n",
+ "iteration: 172, error: 0.42065116279069764, fairness violation: 0.00491356976744186, violated group size: 0.217\n",
+ "iteration: 173, error: 0.4208728323699421, fairness violation: 0.004895202312138728, violated group size: 0.217\n",
+ "iteration: 174, error: 0.42109195402298855, fairness violation: 0.004877045977011494, violated group size: 0.217\n",
+ "iteration: 175, error: 0.4213085714285715, fairness violation: 0.004859097142857142, violated group size: 0.217\n",
+ "iteration: 176, error: 0.42152272727272727, fairness violation: 0.0048413522727272715, violated group size: 0.217\n",
+ "iteration: 177, error: 0.42173446327683617, fairness violation: 0.00482380790960452, violated group size: 0.217\n",
+ "iteration: 178, error: 0.42161235955056187, fairness violation: 0.004833280898876404, violated group size: 0.217\n",
+ "iteration: 179, error: 0.42182122905027924, fairness violation: 0.004815977653631285, violated group size: 0.217\n",
+ "iteration: 180, error: 0.4220277777777778, fairness violation: 0.004798866666666665, violated group size: 0.217\n",
+ "iteration: 181, error: 0.42223204419889504, fairness violation: 0.004781944751381214, violated group size: 0.283\n",
+ "iteration: 182, error: 0.4224340659340659, fairness violation: 0.004765208791208789, violated group size: 0.217\n",
+ "iteration: 183, error: 0.4226338797814208, fairness violation: 0.004748655737704917, violated group size: 0.217\n",
+ "iteration: 184, error: 0.4228315217391304, fairness violation: 0.004732282608695651, violated group size: 0.217\n",
+ "iteration: 185, error: 0.4230270270270271, fairness violation: 0.004716086486486487, violated group size: 0.217\n",
+ "iteration: 186, error: 0.4229032258064515, fairness violation: 0.0047257311827957, violated group size: 0.283\n",
+ "iteration: 187, error: 0.4230962566844919, fairness violation: 0.0047097433155080205, violated group size: 0.217\n",
+ "iteration: 188, error: 0.4229734042553191, fairness violation: 0.00471931914893617, violated group size: 0.283\n",
+ "iteration: 189, error: 0.4231640211640213, fairness violation: 0.004703534391534391, violated group size: 0.217\n",
+ "iteration: 190, error: 0.42304210526315794, fairness violation: 0.004713042105263158, violated group size: 0.217\n",
+ "iteration: 191, error: 0.4232303664921467, fairness violation: 0.004697455497382198, violated group size: 0.217\n",
+ "iteration: 192, error: 0.42310937499999995, fairness violation: 0.004706895833333333, violated group size: 0.217\n",
+ "iteration: 193, error: 0.42329533678756476, fairness violation: 0.004691502590673575, violated group size: 0.283\n",
+ "iteration: 194, error: 0.4231752577319588, fairness violation: 0.004700876288659792, violated group size: 0.217\n",
+ "iteration: 195, error: 0.4230564102564103, fairness violation: 0.004710153846153845, violated group size: 0.217\n",
+ "iteration: 196, error: 0.4229387755102041, fairness violation: 0.004719336734693878, violated group size: 0.283\n",
+ "iteration: 197, error: 0.4228223350253807, fairness violation: 0.004728426395939086, violated group size: 0.283\n",
+ "iteration: 198, error: 0.4227070707070707, fairness violation: 0.004737424242424242, violated group size: 0.217\n",
+ "iteration: 199, error: 0.4228894472361809, fairness violation: 0.004722341708542713, violated group size: 0.217\n",
+ "iteration: 200, error: 0.42277499999999996, fairness violation: 0.004731279999999999, violated group size: 0.217\n",
+ "iteration: 201, error: 0.4226616915422886, fairness violation: 0.004740129353233829, violated group size: 0.217\n",
+ "iteration: 202, error: 0.4225495049504951, fairness violation: 0.00474889108910891, violated group size: 0.217\n",
+ "iteration: 203, error: 0.42243842364532025, fairness violation: 0.004757566502463053, violated group size: 0.283\n",
+ "iteration: 204, error: 0.42232843137254905, fairness violation: 0.004766156862745097, violated group size: 0.283\n",
+ "iteration: 205, error: 0.422219512195122, fairness violation: 0.004774663414634145, violated group size: 0.217\n",
+ "iteration: 206, error: 0.422111650485437, fairness violation: 0.004783087378640775, violated group size: 0.217\n",
+ "iteration: 207, error: 0.4220048309178744, fairness violation: 0.00479142995169082, violated group size: 0.217\n",
+ "iteration: 208, error: 0.4218990384615385, fairness violation: 0.004799692307692306, violated group size: 0.217\n",
+ "iteration: 209, error: 0.42179425837320567, fairness violation: 0.004807875598086124, violated group size: 0.217\n",
+ "iteration: 210, error: 0.4216904761904762, fairness violation: 0.004815980952380952, violated group size: 0.217\n",
+ "iteration: 211, error: 0.42158767772511846, fairness violation: 0.0048240094786729856, violated group size: 0.217\n",
+ "iteration: 212, error: 0.42148584905660386, fairness violation: 0.004831962264150944, violated group size: 0.217\n",
+ "iteration: 213, error: 0.4213849765258215, fairness violation: 0.004839840375586855, violated group size: 0.283\n",
+ "iteration: 214, error: 0.421285046728972, fairness violation: 0.0048476448598130835, violated group size: 0.217\n",
+ "iteration: 215, error: 0.42118604651162794, fairness violation: 0.004855376744186045, violated group size: 0.217\n",
+ "iteration: 216, error: 0.4210879629629629, fairness violation: 0.004863037037037037, violated group size: 0.283\n",
+ "iteration: 217, error: 0.42099078341013824, fairness violation: 0.004870626728110601, violated group size: 0.283\n",
+ "iteration: 218, error: 0.42089449541284396, fairness violation: 0.004878146788990825, violated group size: 0.217\n",
+ "iteration: 219, error: 0.42079908675799094, fairness violation: 0.004885598173515983, violated group size: 0.283\n",
+ "iteration: 220, error: 0.4207045454545455, fairness violation: 0.004892981818181818, violated group size: 0.283\n",
+ "iteration: 221, error: 0.4206108597285068, fairness violation: 0.004900298642533936, violated group size: 0.283\n",
+ "iteration: 222, error: 0.4205180180180179, fairness violation: 0.004907549549549549, violated group size: 0.217\n",
+ "iteration: 223, error: 0.4204260089686098, fairness violation: 0.004914735426008968, violated group size: 0.217\n",
+ "iteration: 224, error: 0.4203348214285714, fairness violation: 0.004921857142857143, violated group size: 0.283\n",
+ "iteration: 225, error: 0.4202444444444444, fairness violation: 0.004928915555555555, violated group size: 0.217\n",
+ "iteration: 226, error: 0.42015486725663725, fairness violation: 0.004935911504424777, violated group size: 0.217\n",
+ "iteration: 227, error: 0.4200660792951542, fairness violation: 0.004942845814977973, violated group size: 0.217\n",
+ "iteration: 228, error: 0.4199780701754386, fairness violation: 0.004949719298245614, violated group size: 0.217\n",
+ "iteration: 229, error: 0.4198908296943231, fairness violation: 0.004956532751091703, violated group size: 0.283\n",
+ "iteration: 230, error: 0.419804347826087, fairness violation: 0.004963286956521739, violated group size: 0.283\n",
+ "iteration: 231, error: 0.4197186147186147, fairness violation: 0.004969982683982686, violated group size: 0.283\n",
+ "iteration: 232, error: 0.4196336206896552, fairness violation: 0.004976620689655175, violated group size: 0.283\n",
+ "iteration: 233, error: 0.419549356223176, fairness violation: 0.004983201716738197, violated group size: 0.283\n",
+ "iteration: 234, error: 0.419465811965812, fairness violation: 0.0049897264957264945, violated group size: 0.217\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "iteration: 235, error: 0.4193829787234044, fairness violation: 0.00499619574468085, violated group size: 0.217\n",
+ "iteration: 236, error: 0.41930084745762713, fairness violation: 0.005002610169491525, violated group size: 0.283\n",
+ "iteration: 237, error: 0.4192194092827004, fairness violation: 0.005008970464135021, violated group size: 0.283\n",
+ "iteration: 238, error: 0.41913865546218493, fairness violation: 0.00501527731092437, violated group size: 0.283\n",
+ "iteration: 239, error: 0.4190585774058578, fairness violation: 0.005021531380753138, violated group size: 0.283\n",
+ "iteration: 240, error: 0.4189791666666667, fairness violation: 0.005027733333333333, violated group size: 0.217\n",
+ "iteration: 241, error: 0.41890041493775937, fairness violation: 0.005033883817427385, violated group size: 0.217\n",
+ "iteration: 242, error: 0.4188223140495868, fairness violation: 0.0050399834710743805, violated group size: 0.217\n",
+ "iteration: 243, error: 0.4187448559670781, fairness violation: 0.005046032921810699, violated group size: 0.217\n",
+ "iteration: 244, error: 0.41890983606557386, fairness violation: 0.005032467213114753, violated group size: 0.217\n",
+ "iteration: 245, error: 0.419073469387755, fairness violation: 0.005019012244897959, violated group size: 0.217\n",
+ "iteration: 246, error: 0.41923577235772364, fairness violation: 0.005005666666666667, violated group size: 0.283\n",
+ "iteration: 247, error: 0.4191578947368422, fairness violation: 0.005011757085020243, violated group size: 0.217\n",
+ "iteration: 248, error: 0.41908064516129034, fairness violation: 0.005017798387096774, violated group size: 0.283\n",
+ "iteration: 249, error: 0.4192409638554217, fairness violation: 0.005004618473895581, violated group size: 0.217\n",
+ "iteration: 250, error: 0.4194000000000001, fairness violation: 0.004991544, violated group size: 0.217\n",
+ "iteration: 251, error: 0.4195577689243028, fairness violation: 0.004978573705179282, violated group size: 0.283\n",
+ "iteration: 252, error: 0.4197142857142857, fairness violation: 0.004965706349206349, violated group size: 0.217\n",
+ "iteration: 253, error: 0.4198695652173912, fairness violation: 0.00495294071146245, violated group size: 0.217\n",
+ "iteration: 254, error: 0.42002362204724414, fairness violation: 0.00494027559055118, violated group size: 0.217\n",
+ "iteration: 255, error: 0.4201764705882353, fairness violation: 0.004927709803921568, violated group size: 0.217\n",
+ "iteration: 256, error: 0.420328125, fairness violation: 0.004915242187499999, violated group size: 0.217\n",
+ "iteration: 257, error: 0.42047859922178993, fairness violation: 0.004902871595330739, violated group size: 0.217\n",
+ "iteration: 258, error: 0.42062790697674424, fairness violation: 0.004890596899224807, violated group size: 0.217\n",
+ "iteration: 259, error: 0.4207760617760617, fairness violation: 0.004878416988416988, violated group size: 0.283\n",
+ "iteration: 260, error: 0.42069615384615383, fairness violation: 0.004884692307692307, violated group size: 0.217\n",
+ "iteration: 261, error: 0.42061685823754785, fairness violation: 0.004890919540229886, violated group size: 0.283\n",
+ "iteration: 262, error: 0.42076335877862603, fairness violation: 0.004878877862595419, violated group size: 0.217\n",
+ "iteration: 263, error: 0.4209087452471483, fairness violation: 0.004866927756653992, violated group size: 0.217\n",
+ "iteration: 264, error: 0.4208295454545454, fairness violation: 0.004873151515151516, violated group size: 0.283\n",
+ "iteration: 265, error: 0.4209735849056603, fairness violation: 0.004861313207547171, violated group size: 0.283\n",
+ "iteration: 266, error: 0.4211165413533834, fairness violation: 0.004849563909774436, violated group size: 0.217\n",
+ "iteration: 267, error: 0.42125842696629223, fairness violation: 0.0048379026217228475, violated group size: 0.217\n",
+ "iteration: 268, error: 0.42139925373134324, fairness violation: 0.004826328358208956, violated group size: 0.283\n",
+ "iteration: 269, error: 0.42131970260223045, fairness violation: 0.0048325873605947955, violated group size: 0.217\n",
+ "iteration: 270, error: 0.42145925925925937, fairness violation: 0.00482111851851852, violated group size: 0.283\n",
+ "iteration: 271, error: 0.42159778597785963, fairness violation: 0.004809734317343174, violated group size: 0.217\n",
+ "iteration: 272, error: 0.42173529411764704, fairness violation: 0.004798433823529413, violated group size: 0.283\n",
+ "iteration: 273, error: 0.4216556776556777, fairness violation: 0.004804703296703296, violated group size: 0.217\n",
+ "iteration: 274, error: 0.42157664233576647, fairness violation: 0.00481092700729927, violated group size: 0.217\n",
+ "iteration: 275, error: 0.42171272727272724, fairness violation: 0.004799745454545455, violated group size: 0.217\n",
+ "iteration: 276, error: 0.42184782608695653, fairness violation: 0.004788644927536233, violated group size: 0.217\n",
+ "iteration: 277, error: 0.4219819494584837, fairness violation: 0.004777624548736462, violated group size: 0.217\n",
+ "iteration: 278, error: 0.42190287769784157, fairness violation: 0.004783856115107913, violated group size: 0.217\n",
+ "iteration: 279, error: 0.42182437275985674, fairness violation: 0.004790043010752689, violated group size: 0.217\n",
+ "iteration: 280, error: 0.42174642857142863, fairness violation: 0.004796185714285715, violated group size: 0.283\n",
+ "iteration: 281, error: 0.42166903914590753, fairness violation: 0.0048022846975088965, violated group size: 0.283\n",
+ "iteration: 282, error: 0.4218014184397163, fairness violation: 0.00479141134751773, violated group size: 0.217\n",
+ "iteration: 283, error: 0.42172438162544174, fairness violation: 0.0047974840989399295, violated group size: 0.217\n",
+ "iteration: 284, error: 0.4216478873239437, fairness violation: 0.004803514084507042, violated group size: 0.217\n",
+ "iteration: 285, error: 0.42157192982456126, fairness violation: 0.004809501754385964, violated group size: 0.217\n",
+ "iteration: 286, error: 0.4217027972027972, fairness violation: 0.004798755244755245, violated group size: 0.217\n",
+ "iteration: 287, error: 0.4218327526132404, fairness violation: 0.004788083623693379, violated group size: 0.283\n",
+ "iteration: 288, error: 0.4219618055555556, fairness violation: 0.004777486111111113, violated group size: 0.283\n",
+ "iteration: 289, error: 0.4218858131487888, fairness violation: 0.004783480968858131, violated group size: 0.217\n",
+ "iteration: 290, error: 0.4218103448275861, fairness violation: 0.004789434482758621, violated group size: 0.217\n",
+ "iteration: 291, error: 0.42193814432989696, fairness violation: 0.004778941580756014, violated group size: 0.283\n",
+ "iteration: 292, error: 0.42186301369863016, fairness violation: 0.0047848698630136985, violated group size: 0.217\n",
+ "iteration: 293, error: 0.4217883959044368, fairness violation: 0.004790757679180888, violated group size: 0.217\n",
+ "iteration: 294, error: 0.42171428571428576, fairness violation: 0.004796605442176871, violated group size: 0.217\n",
+ "iteration: 295, error: 0.4216406779661017, fairness violation: 0.004802413559322035, violated group size: 0.217\n",
+ "iteration: 296, error: 0.4215675675675675, fairness violation: 0.004808182432432432, violated group size: 0.217\n",
+ "iteration: 297, error: 0.4214949494949495, fairness violation: 0.00481391245791246, violated group size: 0.283\n",
+ "iteration: 298, error: 0.4214228187919464, fairness violation: 0.0048196040268456385, violated group size: 0.217\n",
+ "iteration: 299, error: 0.42135117056856186, fairness violation: 0.0048252575250836115, violated group size: 0.217\n",
+ "iteration: 300, error: 0.42128, fairness violation: 0.004830873333333335, violated group size: 0.283\n",
+ "iteration: 301, error: 0.42120930232558146, fairness violation: 0.004836451827242525, violated group size: 0.217\n",
+ "iteration: 302, error: 0.42113907284768215, fairness violation: 0.004841993377483444, violated group size: 0.217\n",
+ "iteration: 303, error: 0.42106930693069305, fairness violation: 0.004847498349834984, violated group size: 0.217\n",
+ "iteration: 304, error: 0.4211940789473684, fairness violation: 0.004837263157894738, violated group size: 0.283\n",
+ "iteration: 305, error: 0.4211245901639345, fairness violation: 0.004842747540983607, violated group size: 0.283\n",
+ "iteration: 306, error: 0.4210555555555555, fairness violation: 0.004848196078431373, violated group size: 0.217\n",
+ "iteration: 307, error: 0.42098697068403895, fairness violation: 0.004853609120521175, violated group size: 0.283\n",
+ "iteration: 308, error: 0.4209188311688312, fairness violation: 0.004858987012987015, violated group size: 0.283\n",
+ "iteration: 309, error: 0.42085113268608415, fairness violation: 0.00486433009708738, violated group size: 0.283\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "iteration: 310, error: 0.4207838709677419, fairness violation: 0.00486963870967742, violated group size: 0.283\n",
+ "iteration: 311, error: 0.420717041800643, fairness violation: 0.004874913183279744, violated group size: 0.217\n",
+ "iteration: 312, error: 0.42065064102564104, fairness violation: 0.0048801538461538466, violated group size: 0.217\n",
+ "iteration: 313, error: 0.42058466453674115, fairness violation: 0.0048853610223642185, violated group size: 0.217\n",
+ "iteration: 314, error: 0.42051910828025474, fairness violation: 0.004890535031847135, violated group size: 0.217\n",
+ "iteration: 315, error: 0.4204539682539683, fairness violation: 0.004895676190476191, violated group size: 0.217\n",
+ "iteration: 316, error: 0.4203892405063292, fairness violation: 0.004900784810126583, violated group size: 0.217\n",
+ "iteration: 317, error: 0.4203249211356468, fairness violation: 0.004905861198738172, violated group size: 0.283\n",
+ "iteration: 318, error: 0.4202610062893082, fairness violation: 0.00491090566037736, violated group size: 0.217\n",
+ "iteration: 319, error: 0.4201974921630094, fairness violation: 0.004915918495297806, violated group size: 0.217\n",
+ "iteration: 320, error: 0.4201343749999999, fairness violation: 0.004920900000000002, violated group size: 0.217\n",
+ "iteration: 321, error: 0.4200716510903427, fairness violation: 0.004925850467289721, violated group size: 0.217\n",
+ "iteration: 322, error: 0.4200093167701862, fairness violation: 0.0049307701863354056, violated group size: 0.283\n",
+ "iteration: 323, error: 0.4199473684210526, fairness violation: 0.00493565944272446, violated group size: 0.283\n",
+ "iteration: 324, error: 0.41988580246913576, fairness violation: 0.004940518518518519, violated group size: 0.217\n",
+ "iteration: 325, error: 0.41982461538461535, fairness violation: 0.004945347692307694, violated group size: 0.217\n",
+ "iteration: 326, error: 0.4197638036809816, fairness violation: 0.004950147239263805, violated group size: 0.283\n",
+ "iteration: 327, error: 0.4197033639143731, fairness violation: 0.004954917431192661, violated group size: 0.283\n",
+ "iteration: 328, error: 0.4196432926829268, fairness violation: 0.004959658536585366, violated group size: 0.217\n",
+ "iteration: 329, error: 0.41958358662613987, fairness violation: 0.004964370820668694, violated group size: 0.283\n",
+ "iteration: 330, error: 0.41952424242424236, fairness violation: 0.004969054545454545, violated group size: 0.217\n",
+ "iteration: 331, error: 0.41946525679758306, fairness violation: 0.00497370996978852, violated group size: 0.217\n",
+ "iteration: 332, error: 0.41940662650602417, fairness violation: 0.004978337349397591, violated group size: 0.217\n",
+ "iteration: 333, error: 0.4193483483483482, fairness violation: 0.004982936936936937, violated group size: 0.217\n",
+ "iteration: 334, error: 0.4192904191616766, fairness violation: 0.004987508982035928, violated group size: 0.217\n",
+ "iteration: 335, error: 0.4192328358208956, fairness violation: 0.004992053731343284, violated group size: 0.283\n",
+ "iteration: 336, error: 0.4191755952380953, fairness violation: 0.00499657142857143, violated group size: 0.283\n",
+ "iteration: 337, error: 0.4191186943620178, fairness violation: 0.0050010623145400595, violated group size: 0.217\n",
+ "iteration: 338, error: 0.41906213017751476, fairness violation: 0.005005526627218935, violated group size: 0.217\n",
+ "iteration: 339, error: 0.4190058997050148, fairness violation: 0.005009964601769911, violated group size: 0.217\n",
+ "iteration: 340, error: 0.41894999999999993, fairness violation: 0.005014376470588236, violated group size: 0.283\n",
+ "iteration: 341, error: 0.41889442815249267, fairness violation: 0.005018762463343108, violated group size: 0.217\n",
+ "iteration: 342, error: 0.41883918128654973, fairness violation: 0.005023122807017544, violated group size: 0.217\n",
+ "iteration: 343, error: 0.41878425655976675, fairness violation: 0.0050274577259475225, violated group size: 0.283\n",
+ "iteration: 344, error: 0.4187296511627907, fairness violation: 0.005031767441860465, violated group size: 0.217\n",
+ "iteration: 345, error: 0.4186753623188406, fairness violation: 0.005036052173913045, violated group size: 0.283\n",
+ "iteration: 346, error: 0.4186213872832369, fairness violation: 0.005040312138728323, violated group size: 0.217\n",
+ "iteration: 347, error: 0.41856772334293946, fairness violation: 0.005044547550432276, violated group size: 0.283\n",
+ "iteration: 348, error: 0.41851436781609197, fairness violation: 0.005048758620689655, violated group size: 0.217\n",
+ "iteration: 349, error: 0.418461318051576, fairness violation: 0.005052945558739255, violated group size: 0.283\n",
+ "iteration: 350, error: 0.4185771428571428, fairness violation: 0.005043468571428572, violated group size: 0.283\n",
+ "iteration: 351, error: 0.4186923076923077, fairness violation: 0.005034045584045584, violated group size: 0.217\n",
+ "iteration: 352, error: 0.4188068181818182, fairness violation: 0.005024676136363637, violated group size: 0.283\n",
+ "iteration: 353, error: 0.4189206798866855, fairness violation: 0.005015359773371105, violated group size: 0.217\n",
+ "iteration: 354, error: 0.41903389830508475, fairness violation: 0.005006096045197741, violated group size: 0.283\n",
+ "iteration: 355, error: 0.41914647887323936, fairness violation: 0.004996884507042254, violated group size: 0.283\n",
+ "iteration: 356, error: 0.4192584269662922, fairness violation: 0.004987724719101122, violated group size: 0.217\n",
+ "iteration: 357, error: 0.41936974789915965, fairness violation: 0.0049786162464986, violated group size: 0.217\n",
+ "iteration: 358, error: 0.41948044692737424, fairness violation: 0.004969558659217878, violated group size: 0.217\n",
+ "iteration: 359, error: 0.41959052924791085, fairness violation: 0.004960551532033426, violated group size: 0.283\n",
+ "iteration: 360, error: 0.4195361111111111, fairness violation: 0.004964855555555557, violated group size: 0.283\n",
+ "iteration: 361, error: 0.4196454293628808, fairness violation: 0.004955911357340723, violated group size: 0.283\n",
+ "iteration: 362, error: 0.4197541436464089, fairness violation: 0.004947016574585636, violated group size: 0.217\n",
+ "iteration: 363, error: 0.4198622589531681, fairness violation: 0.004938170798898072, violated group size: 0.283\n",
+ "iteration: 364, error: 0.41996978021978026, fairness violation: 0.004929373626373626, violated group size: 0.217\n",
+ "iteration: 365, error: 0.42007671232876714, fairness violation: 0.004920624657534246, violated group size: 0.217\n",
+ "iteration: 366, error: 0.42018306010928963, fairness violation: 0.004911923497267759, violated group size: 0.217\n",
+ "iteration: 367, error: 0.4202888283378746, fairness violation: 0.004903269754768393, violated group size: 0.217\n",
+ "iteration: 368, error: 0.42039402173913043, fairness violation: 0.00489466304347826, violated group size: 0.217\n",
+ "iteration: 369, error: 0.4204986449864499, fairness violation: 0.00488610298102981, violated group size: 0.283\n",
+ "iteration: 370, error: 0.4206027027027027, fairness violation: 0.0048775891891891885, violated group size: 0.217\n",
+ "iteration: 371, error: 0.4207061994609164, fairness violation: 0.004869121293800538, violated group size: 0.217\n",
+ "iteration: 372, error: 0.4208091397849463, fairness violation: 0.004860698924731182, violated group size: 0.217\n",
+ "iteration: 373, error: 0.420911528150134, fairness violation: 0.004852321715817694, violated group size: 0.217\n",
+ "iteration: 374, error: 0.420855614973262, fairness violation: 0.004856754010695187, violated group size: 0.217\n",
+ "iteration: 375, error: 0.4209573333333334, fairness violation: 0.004848432, violated group size: 0.217\n",
+ "iteration: 376, error: 0.42105851063829786, fairness violation: 0.004840154255319148, violated group size: 0.217\n",
+ "iteration: 377, error: 0.4211591511936339, fairness violation: 0.004831920424403182, violated group size: 0.217\n",
+ "iteration: 378, error: 0.4211031746031746, fairness violation: 0.004836359788359788, violated group size: 0.217\n",
+ "iteration: 379, error: 0.42120316622691284, fairness violation: 0.004828179419525066, violated group size: 0.217\n",
+ "iteration: 380, error: 0.42130263157894726, fairness violation: 0.004820042105263157, violated group size: 0.217\n",
+ "iteration: 381, error: 0.42124671916010503, fairness violation: 0.004824477690288715, violated group size: 0.283\n",
+ "iteration: 382, error: 0.42134554973821986, fairness violation: 0.004816392670157068, violated group size: 0.283\n",
+ "iteration: 383, error: 0.42144386422976504, fairness violation: 0.004808349869451696, violated group size: 0.217\n",
+ "iteration: 384, error: 0.42154166666666676, fairness violation: 0.004800348958333333, violated group size: 0.283\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "iteration: 385, error: 0.4214857142857143, fairness violation: 0.00480478961038961, violated group size: 0.217\n",
+ "iteration: 386, error: 0.4215829015544041, fairness violation: 0.0047968393782383415, violated group size: 0.217\n",
+ "iteration: 387, error: 0.42152713178294576, fairness violation: 0.004801266149870801, violated group size: 0.283\n",
+ "iteration: 388, error: 0.4216237113402061, fairness violation: 0.004793365979381444, violated group size: 0.283\n",
+ "iteration: 389, error: 0.42171979434447304, fairness violation: 0.004785506426735219, violated group size: 0.217\n",
+ "iteration: 390, error: 0.4218153846153847, fairness violation: 0.00477768717948718, violated group size: 0.217\n",
+ "iteration: 391, error: 0.4219104859335038, fairness violation: 0.004769907928388747, violated group size: 0.217\n",
+ "iteration: 392, error: 0.4218545918367348, fairness violation: 0.00477434693877551, violated group size: 0.217\n",
+ "iteration: 393, error: 0.42194910941475816, fairness violation: 0.004766615776081424, violated group size: 0.217\n",
+ "iteration: 394, error: 0.42204314720812186, fairness violation: 0.004758923857868021, violated group size: 0.283\n",
+ "iteration: 395, error: 0.4221367088607595, fairness violation: 0.00475127088607595, violated group size: 0.217\n",
+ "iteration: 396, error: 0.42208080808080806, fairness violation: 0.004755712121212122, violated group size: 0.283\n",
+ "iteration: 397, error: 0.42202518891687657, fairness violation: 0.004760130982367758, violated group size: 0.217\n",
+ "iteration: 398, error: 0.42196984924623115, fairness violation: 0.004764527638190955, violated group size: 0.283\n",
+ "iteration: 399, error: 0.42191478696741846, fairness violation: 0.004768902255639098, violated group size: 0.217\n",
+ "iteration: 400, error: 0.42186, fairness violation: 0.004773255, violated group size: 0.217\n",
+ "iteration: 401, error: 0.4218054862842893, fairness violation: 0.004777586034912718, violated group size: 0.283\n",
+ "iteration: 402, error: 0.4217512437810945, fairness violation: 0.00478189552238806, violated group size: 0.217\n",
+ "iteration: 403, error: 0.4218436724565757, fairness violation: 0.0047743374689826305, violated group size: 0.283\n",
+ "iteration: 404, error: 0.42178960396039605, fairness violation: 0.004778633663366336, violated group size: 0.217\n",
+ "iteration: 405, error: 0.4217358024691357, fairness violation: 0.004782908641975308, violated group size: 0.217\n",
+ "iteration: 406, error: 0.4216822660098523, fairness violation: 0.004787162561576355, violated group size: 0.217\n",
+ "iteration: 407, error: 0.4216289926289926, fairness violation: 0.004791395577395577, violated group size: 0.283\n",
+ "iteration: 408, error: 0.421575980392157, fairness violation: 0.004795607843137254, violated group size: 0.217\n",
+ "iteration: 409, error: 0.4215232273838631, fairness violation: 0.004799799511002444, violated group size: 0.217\n",
+ "iteration: 410, error: 0.42147073170731714, fairness violation: 0.004803970731707317, violated group size: 0.283\n",
+ "iteration: 411, error: 0.4214184914841849, fairness violation: 0.0048081216545012165, violated group size: 0.217\n",
+ "iteration: 412, error: 0.4213665048543689, fairness violation: 0.004812252427184466, violated group size: 0.283\n",
+ "iteration: 413, error: 0.42131476997578693, fairness violation: 0.004816363196125908, violated group size: 0.217\n",
+ "iteration: 414, error: 0.42126328502415455, fairness violation: 0.004820454106280194, violated group size: 0.217\n",
+ "iteration: 415, error: 0.4212120481927711, fairness violation: 0.004824525301204821, violated group size: 0.283\n",
+ "iteration: 416, error: 0.42116105769230766, fairness violation: 0.004828576923076923, violated group size: 0.217\n",
+ "iteration: 417, error: 0.4211103117505996, fairness violation: 0.004832609112709832, violated group size: 0.283\n",
+ "iteration: 418, error: 0.42105980861244025, fairness violation: 0.004836622009569378, violated group size: 0.283\n",
+ "iteration: 419, error: 0.42100954653937933, fairness violation: 0.004840615751789977, violated group size: 0.217\n",
+ "iteration: 420, error: 0.42110000000000003, fairness violation: 0.0048332238095238084, violated group size: 0.217\n",
+ "iteration: 421, error: 0.42104988123515436, fairness violation: 0.004837206650831354, violated group size: 0.217\n",
+ "iteration: 422, error: 0.42100000000000004, fairness violation: 0.004841170616113744, violated group size: 0.217\n",
+ "iteration: 423, error: 0.420950354609929, fairness violation: 0.004845115839243499, violated group size: 0.217\n",
+ "iteration: 424, error: 0.42104009433962253, fairness violation: 0.004837783018867924, violated group size: 0.217\n",
+ "iteration: 425, error: 0.4209905882352941, fairness violation: 0.004841717647058822, violated group size: 0.217\n",
+ "iteration: 426, error: 0.42094131455399053, fairness violation: 0.004845633802816901, violated group size: 0.217\n",
+ "iteration: 427, error: 0.42089227166276344, fairness violation: 0.004849531615925057, violated group size: 0.217\n",
+ "iteration: 428, error: 0.4208434579439252, fairness violation: 0.004853411214953271, violated group size: 0.217\n",
+ "iteration: 429, error: 0.4207948717948717, fairness violation: 0.0048572727272727274, violated group size: 0.283\n",
+ "iteration: 430, error: 0.42074651162790694, fairness violation: 0.004861116279069767, violated group size: 0.217\n",
+ "iteration: 431, error: 0.4206983758700697, fairness violation: 0.004864941995359629, violated group size: 0.283\n",
+ "iteration: 432, error: 0.420650462962963, fairness violation: 0.00486875, violated group size: 0.217\n",
+ "iteration: 433, error: 0.42060277136258656, fairness violation: 0.0048725404157043874, violated group size: 0.217\n",
+ "iteration: 434, error: 0.42055529953917054, fairness violation: 0.0048763133640553, violated group size: 0.283\n",
+ "iteration: 435, error: 0.42050804597701147, fairness violation: 0.004880068965517241, violated group size: 0.217\n",
+ "iteration: 436, error: 0.4204610091743119, fairness violation: 0.004883807339449542, violated group size: 0.217\n",
+ "iteration: 437, error: 0.4204141876430207, fairness violation: 0.004887528604118992, violated group size: 0.217\n",
+ "iteration: 438, error: 0.42036757990867574, fairness violation: 0.004891232876712329, violated group size: 0.217\n",
+ "iteration: 439, error: 0.4203211845102506, fairness violation: 0.0048949202733485206, violated group size: 0.283\n",
+ "iteration: 440, error: 0.42027499999999995, fairness violation: 0.00489859090909091, violated group size: 0.283\n",
+ "iteration: 441, error: 0.42022902494331066, fairness violation: 0.004902244897959184, violated group size: 0.283\n",
+ "iteration: 442, error: 0.42018325791855204, fairness violation: 0.004905882352941177, violated group size: 0.217\n",
+ "iteration: 443, error: 0.42013769751693003, fairness violation: 0.004909503386004516, violated group size: 0.283\n",
+ "iteration: 444, error: 0.42009234234234244, fairness violation: 0.004913108108108108, violated group size: 0.217\n",
+ "iteration: 445, error: 0.420047191011236, fairness violation: 0.004916696629213483, violated group size: 0.217\n",
+ "iteration: 446, error: 0.42000224215246645, fairness violation: 0.004920269058295964, violated group size: 0.217\n",
+ "iteration: 447, error: 0.4199574944071588, fairness violation: 0.004923825503355704, violated group size: 0.217\n",
+ "iteration: 448, error: 0.41991294642857147, fairness violation: 0.004927366071428571, violated group size: 0.217\n",
+ "iteration: 449, error: 0.41986859688195993, fairness violation: 0.004930890868596881, violated group size: 0.217\n",
+ "iteration: 450, error: 0.41982444444444433, fairness violation: 0.004934399999999999, violated group size: 0.217\n",
+ "iteration: 451, error: 0.41978048780487814, fairness violation: 0.004937893569844789, violated group size: 0.217\n",
+ "iteration: 452, error: 0.41973672566371684, fairness violation: 0.004941371681415928, violated group size: 0.217\n",
+ "iteration: 453, error: 0.41969315673289187, fairness violation: 0.0049448344370860925, violated group size: 0.217\n",
+ "iteration: 454, error: 0.41964977973568285, fairness violation: 0.004948281938325991, violated group size: 0.283\n",
+ "iteration: 455, error: 0.41960659340659345, fairness violation: 0.004951714285714286, violated group size: 0.283\n",
+ "iteration: 456, error: 0.41956359649122804, fairness violation: 0.004955131578947367, violated group size: 0.217\n",
+ "iteration: 457, error: 0.41952078774617063, fairness violation: 0.0049585339168490145, violated group size: 0.217\n",
+ "iteration: 458, error: 0.41947816593886456, fairness violation: 0.004961921397379911, violated group size: 0.217\n",
+ "iteration: 459, error: 0.4194357298474945, fairness violation: 0.00496529411764706, violated group size: 0.283\n",
+ "iteration: 460, error: 0.4193934782608696, fairness violation: 0.004968652173913044, violated group size: 0.283\n",
+ "iteration: 461, error: 0.41935140997830805, fairness violation: 0.004971995661605205, violated group size: 0.283\n",
+ "iteration: 462, error: 0.41930952380952374, fairness violation: 0.0049753246753246735, violated group size: 0.217\n",
+ "iteration: 463, error: 0.41926781857451406, fairness violation: 0.004978639308855291, violated group size: 0.217\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "iteration: 464, error: 0.41922629310344817, fairness violation: 0.004981939655172413, violated group size: 0.217\n",
+ "iteration: 465, error: 0.41918494623655916, fairness violation: 0.004985225806451612, violated group size: 0.217\n",
+ "iteration: 466, error: 0.4191437768240344, fairness violation: 0.004988497854077254, violated group size: 0.283\n",
+ "iteration: 467, error: 0.41910278372591, fairness violation: 0.004991755888650964, violated group size: 0.283\n",
+ "iteration: 468, error: 0.4190619658119658, fairness violation: 0.004994999999999998, violated group size: 0.217\n",
+ "iteration: 469, error: 0.4190213219616205, fairness violation: 0.0049982302771855, violated group size: 0.217\n",
+ "iteration: 470, error: 0.41898085106382993, fairness violation: 0.005001446808510636, violated group size: 0.283\n",
+ "iteration: 471, error: 0.41894055201698516, fairness violation: 0.005004649681528661, violated group size: 0.217\n",
+ "iteration: 472, error: 0.41890042372881364, fairness violation: 0.005007838983050847, violated group size: 0.283\n",
+ "iteration: 473, error: 0.4188604651162791, fairness violation: 0.005011014799154333, violated group size: 0.217\n",
+ "iteration: 474, error: 0.41882067510548526, fairness violation: 0.005014177215189871, violated group size: 0.217\n",
+ "iteration: 475, error: 0.41878105263157905, fairness violation: 0.0050173263157894735, violated group size: 0.283\n",
+ "iteration: 476, error: 0.41874159663865557, fairness violation: 0.0050204621848739485, violated group size: 0.217\n",
+ "iteration: 477, error: 0.4188259958071279, fairness violation: 0.005013576519916141, violated group size: 0.217\n",
+ "iteration: 478, error: 0.41878661087866115, fairness violation: 0.005016707112970709, violated group size: 0.217\n",
+ "iteration: 479, error: 0.4188705636743216, fairness violation: 0.005009858037578285, violated group size: 0.217\n",
+ "iteration: 480, error: 0.41883125000000004, fairness violation: 0.005012983333333334, violated group size: 0.283\n",
+ "iteration: 481, error: 0.4187920997920998, fairness violation: 0.005016095634095634, violated group size: 0.283\n",
+ "iteration: 482, error: 0.4187531120331951, fairness violation: 0.0050191950207468874, violated group size: 0.283\n",
+ "iteration: 483, error: 0.4188364389233955, fairness violation: 0.00501239751552795, violated group size: 0.217\n",
+ "iteration: 484, error: 0.41891942148760325, fairness violation: 0.005005628099173555, violated group size: 0.283\n",
+ "iteration: 485, error: 0.4190020618556701, fairness violation: 0.004998886597938144, violated group size: 0.283\n",
+ "iteration: 486, error: 0.41896296296296304, fairness violation: 0.005001995884773661, violated group size: 0.217\n",
+ "iteration: 487, error: 0.4190451745379877, fairness violation: 0.004995289527720739, violated group size: 0.283\n",
+ "iteration: 488, error: 0.4191270491803279, fairness violation: 0.004988610655737704, violated group size: 0.283\n",
+ "iteration: 489, error: 0.4192085889570552, fairness violation: 0.004981959100204497, violated group size: 0.217\n",
+ "iteration: 490, error: 0.41916938775510204, fairness violation: 0.004985077551020407, violated group size: 0.283\n",
+ "iteration: 491, error: 0.4192505091649695, fairness violation: 0.004978460285132381, violated group size: 0.217\n",
+ "iteration: 492, error: 0.4192113821138212, fairness violation: 0.004981573170731706, violated group size: 0.217\n",
+ "iteration: 493, error: 0.41917241379310355, fairness violation: 0.004984673427991887, violated group size: 0.283\n",
+ "iteration: 494, error: 0.41913360323886634, fairness violation: 0.004987761133603237, violated group size: 0.217\n",
+ "iteration: 495, error: 0.4192141414141415, fairness violation: 0.004981191919191918, violated group size: 0.217\n",
+ "iteration: 496, error: 0.4192943548387097, fairness violation: 0.004974649193548386, violated group size: 0.217\n",
+ "iteration: 497, error: 0.419374245472837, fairness violation: 0.004968132796780683, violated group size: 0.217\n",
+ "iteration: 498, error: 0.4194538152610441, fairness violation: 0.004961642570281124, violated group size: 0.217\n",
+ "iteration: 499, error: 0.41953306613226454, fairness violation: 0.0049551783567134255, violated group size: 0.217\n"
+ ]
+ }
+ ],
+ "source": [
+ "C = 100\n",
+ "print_flag = True\n",
+ "gamma = .005\n",
+ "\n",
+ "\n",
+ "fair_model = GerryFairClassifier(C=C, printflag=print_flag, gamma=gamma, fairness_def='FP',\n",
+ " max_iters=max_iterations, heatmapflag=False)\n",
+ "\n",
+ "# fit method\n",
+ "fair_model.fit(data_set, early_termination=True)\n",
+ "\n",
+ "# predict method: a threshold in (0, 1) yields binary predictions; threshold=False keeps the soft predictions\n",
+ "\n",
+ "dataset_yhat = fair_model.predict(data_set, threshold=False)\n",
+ "\n",
+ "\n"
+ ]
},
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "iteration: 310, error: 0.4207838709677419, fairness violation: 0.00486963870967742, violated group size: 0.283\n",
- "iteration: 311, error: 0.420717041800643, fairness violation: 0.004874913183279744, violated group size: 0.217\n",
- "iteration: 312, error: 0.42065064102564104, fairness violation: 0.0048801538461538466, violated group size: 0.217\n",
- "iteration: 313, error: 0.42058466453674115, fairness violation: 0.0048853610223642185, violated group size: 0.217\n",
- "iteration: 314, error: 0.42051910828025474, fairness violation: 0.004890535031847135, violated group size: 0.217\n",
- "iteration: 315, error: 0.4204539682539683, fairness violation: 0.004895676190476191, violated group size: 0.217\n",
- "iteration: 316, error: 0.4203892405063292, fairness violation: 0.004900784810126583, violated group size: 0.217\n",
- "iteration: 317, error: 0.4203249211356468, fairness violation: 0.004905861198738172, violated group size: 0.283\n",
- "iteration: 318, error: 0.4202610062893082, fairness violation: 0.00491090566037736, violated group size: 0.217\n",
- "iteration: 319, error: 0.4201974921630094, fairness violation: 0.004915918495297806, violated group size: 0.217\n",
- "iteration: 320, error: 0.4201343749999999, fairness violation: 0.004920900000000002, violated group size: 0.217\n",
- "iteration: 321, error: 0.4200716510903427, fairness violation: 0.004925850467289721, violated group size: 0.217\n",
- "iteration: 322, error: 0.4200093167701862, fairness violation: 0.0049307701863354056, violated group size: 0.283\n",
- "iteration: 323, error: 0.4199473684210526, fairness violation: 0.00493565944272446, violated group size: 0.283\n",
- "iteration: 324, error: 0.41988580246913576, fairness violation: 0.004940518518518519, violated group size: 0.217\n",
- "iteration: 325, error: 0.41982461538461535, fairness violation: 0.004945347692307694, violated group size: 0.217\n",
- "iteration: 326, error: 0.4197638036809816, fairness violation: 0.004950147239263805, violated group size: 0.283\n",
- "iteration: 327, error: 0.4197033639143731, fairness violation: 0.004954917431192661, violated group size: 0.283\n",
- "iteration: 328, error: 0.4196432926829268, fairness violation: 0.004959658536585366, violated group size: 0.217\n",
- "iteration: 329, error: 0.41958358662613987, fairness violation: 0.004964370820668694, violated group size: 0.283\n",
- "iteration: 330, error: 0.41952424242424236, fairness violation: 0.004969054545454545, violated group size: 0.217\n",
- "iteration: 331, error: 0.41946525679758306, fairness violation: 0.00497370996978852, violated group size: 0.217\n",
- "iteration: 332, error: 0.41940662650602417, fairness violation: 0.004978337349397591, violated group size: 0.217\n",
- "iteration: 333, error: 0.4193483483483482, fairness violation: 0.004982936936936937, violated group size: 0.217\n",
- "iteration: 334, error: 0.4192904191616766, fairness violation: 0.004987508982035928, violated group size: 0.217\n",
- "iteration: 335, error: 0.4192328358208956, fairness violation: 0.004992053731343284, violated group size: 0.283\n",
- "iteration: 336, error: 0.4191755952380953, fairness violation: 0.00499657142857143, violated group size: 0.283\n",
- "iteration: 337, error: 0.4191186943620178, fairness violation: 0.0050010623145400595, violated group size: 0.217\n",
- "iteration: 338, error: 0.41906213017751476, fairness violation: 0.005005526627218935, violated group size: 0.217\n",
- "iteration: 339, error: 0.4190058997050148, fairness violation: 0.005009964601769911, violated group size: 0.217\n",
- "iteration: 340, error: 0.41894999999999993, fairness violation: 0.005014376470588236, violated group size: 0.283\n",
- "iteration: 341, error: 0.41889442815249267, fairness violation: 0.005018762463343108, violated group size: 0.217\n",
- "iteration: 342, error: 0.41883918128654973, fairness violation: 0.005023122807017544, violated group size: 0.217\n",
- "iteration: 343, error: 0.41878425655976675, fairness violation: 0.0050274577259475225, violated group size: 0.283\n",
- "iteration: 344, error: 0.4187296511627907, fairness violation: 0.005031767441860465, violated group size: 0.217\n",
- "iteration: 345, error: 0.4186753623188406, fairness violation: 0.005036052173913045, violated group size: 0.283\n",
- "iteration: 346, error: 0.4186213872832369, fairness violation: 0.005040312138728323, violated group size: 0.217\n",
- "iteration: 347, error: 0.41856772334293946, fairness violation: 0.005044547550432276, violated group size: 0.283\n",
- "iteration: 348, error: 0.41851436781609197, fairness violation: 0.005048758620689655, violated group size: 0.217\n",
- "iteration: 349, error: 0.418461318051576, fairness violation: 0.005052945558739255, violated group size: 0.283\n",
- "iteration: 350, error: 0.4185771428571428, fairness violation: 0.005043468571428572, violated group size: 0.283\n",
- "iteration: 351, error: 0.4186923076923077, fairness violation: 0.005034045584045584, violated group size: 0.217\n",
- "iteration: 352, error: 0.4188068181818182, fairness violation: 0.005024676136363637, violated group size: 0.283\n",
- "iteration: 353, error: 0.4189206798866855, fairness violation: 0.005015359773371105, violated group size: 0.217\n",
- "iteration: 354, error: 0.41903389830508475, fairness violation: 0.005006096045197741, violated group size: 0.283\n",
- "iteration: 355, error: 0.41914647887323936, fairness violation: 0.004996884507042254, violated group size: 0.283\n",
- "iteration: 356, error: 0.4192584269662922, fairness violation: 0.004987724719101122, violated group size: 0.217\n",
- "iteration: 357, error: 0.41936974789915965, fairness violation: 0.0049786162464986, violated group size: 0.217\n",
- "iteration: 358, error: 0.41948044692737424, fairness violation: 0.004969558659217878, violated group size: 0.217\n",
- "iteration: 359, error: 0.41959052924791085, fairness violation: 0.004960551532033426, violated group size: 0.283\n",
- "iteration: 360, error: 0.4195361111111111, fairness violation: 0.004964855555555557, violated group size: 0.283\n",
- "iteration: 361, error: 0.4196454293628808, fairness violation: 0.004955911357340723, violated group size: 0.283\n",
- "iteration: 362, error: 0.4197541436464089, fairness violation: 0.004947016574585636, violated group size: 0.217\n",
- "iteration: 363, error: 0.4198622589531681, fairness violation: 0.004938170798898072, violated group size: 0.283\n",
- "iteration: 364, error: 0.41996978021978026, fairness violation: 0.004929373626373626, violated group size: 0.217\n",
- "iteration: 365, error: 0.42007671232876714, fairness violation: 0.004920624657534246, violated group size: 0.217\n",
- "iteration: 366, error: 0.42018306010928963, fairness violation: 0.004911923497267759, violated group size: 0.217\n",
- "iteration: 367, error: 0.4202888283378746, fairness violation: 0.004903269754768393, violated group size: 0.217\n",
- "iteration: 368, error: 0.42039402173913043, fairness violation: 0.00489466304347826, violated group size: 0.217\n",
- "iteration: 369, error: 0.4204986449864499, fairness violation: 0.00488610298102981, violated group size: 0.283\n",
- "iteration: 370, error: 0.4206027027027027, fairness violation: 0.0048775891891891885, violated group size: 0.217\n",
- "iteration: 371, error: 0.4207061994609164, fairness violation: 0.004869121293800538, violated group size: 0.217\n",
- "iteration: 372, error: 0.4208091397849463, fairness violation: 0.004860698924731182, violated group size: 0.217\n",
- "iteration: 373, error: 0.420911528150134, fairness violation: 0.004852321715817694, violated group size: 0.217\n",
- "iteration: 374, error: 0.420855614973262, fairness violation: 0.004856754010695187, violated group size: 0.217\n",
- "iteration: 375, error: 0.4209573333333334, fairness violation: 0.004848432, violated group size: 0.217\n",
- "iteration: 376, error: 0.42105851063829786, fairness violation: 0.004840154255319148, violated group size: 0.217\n",
- "iteration: 377, error: 0.4211591511936339, fairness violation: 0.004831920424403182, violated group size: 0.217\n",
- "iteration: 378, error: 0.4211031746031746, fairness violation: 0.004836359788359788, violated group size: 0.217\n",
- "iteration: 379, error: 0.42120316622691284, fairness violation: 0.004828179419525066, violated group size: 0.217\n",
- "iteration: 380, error: 0.42130263157894726, fairness violation: 0.004820042105263157, violated group size: 0.217\n",
- "iteration: 381, error: 0.42124671916010503, fairness violation: 0.004824477690288715, violated group size: 0.283\n",
- "iteration: 382, error: 0.42134554973821986, fairness violation: 0.004816392670157068, violated group size: 0.283\n",
- "iteration: 383, error: 0.42144386422976504, fairness violation: 0.004808349869451696, violated group size: 0.217\n",
- "iteration: 384, error: 0.42154166666666676, fairness violation: 0.004800348958333333, violated group size: 0.283\n"
- ]
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ak4FlXHmhI5L"
+ },
+ "source": [
+ "**3-d heatmaps**\n",
+ "\n",
+ "We now show how to generate a 3-d heatmap of unfairness using the `generate_heatmap` method. The $X$ and $Y$ axes of the plot represent the coefficients of the linear threshold function that defines a protected subgroup with respect to the first two sensitive attributes. Which two attributes are treated as sensitive can be overridden with the `col_index` argument. The $Z$-axis is the $\\gamma$-disparity (FP) of the subgroup defined by the corresponding linear threshold function. This is important because it allows us to (1) visualize convergence as the heatmap flattens and (2) check fairness by brute force in low dimensions without relying on a heuristic auditor. See the [rich subgroup fairness empirical paper](https://arxiv.org/abs/1808.08166) for a discussion of these plots. Note that in the plot below no group has a $\\gamma$-disparity greater than $.005$, which is what we expect: the set of linear threshold functions on two attributes is a subset of the set of linear threshold functions on all protected attributes, and the final model is $\\gamma$-fair."
+ ]
},
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "iteration: 385, error: 0.4214857142857143, fairness violation: 0.00480478961038961, violated group size: 0.217\n",
- "iteration: 386, error: 0.4215829015544041, fairness violation: 0.0047968393782383415, violated group size: 0.217\n",
- "iteration: 387, error: 0.42152713178294576, fairness violation: 0.004801266149870801, violated group size: 0.283\n",
- "iteration: 388, error: 0.4216237113402061, fairness violation: 0.004793365979381444, violated group size: 0.283\n",
- "iteration: 389, error: 0.42171979434447304, fairness violation: 0.004785506426735219, violated group size: 0.217\n",
- "iteration: 390, error: 0.4218153846153847, fairness violation: 0.00477768717948718, violated group size: 0.217\n",
- "iteration: 391, error: 0.4219104859335038, fairness violation: 0.004769907928388747, violated group size: 0.217\n",
- "iteration: 392, error: 0.4218545918367348, fairness violation: 0.00477434693877551, violated group size: 0.217\n",
- "iteration: 393, error: 0.42194910941475816, fairness violation: 0.004766615776081424, violated group size: 0.217\n",
- "iteration: 394, error: 0.42204314720812186, fairness violation: 0.004758923857868021, violated group size: 0.283\n",
- "iteration: 395, error: 0.4221367088607595, fairness violation: 0.00475127088607595, violated group size: 0.217\n",
- "iteration: 396, error: 0.42208080808080806, fairness violation: 0.004755712121212122, violated group size: 0.283\n",
- "iteration: 397, error: 0.42202518891687657, fairness violation: 0.004760130982367758, violated group size: 0.217\n",
- "iteration: 398, error: 0.42196984924623115, fairness violation: 0.004764527638190955, violated group size: 0.283\n",
- "iteration: 399, error: 0.42191478696741846, fairness violation: 0.004768902255639098, violated group size: 0.217\n",
- "iteration: 400, error: 0.42186, fairness violation: 0.004773255, violated group size: 0.217\n",
- "iteration: 401, error: 0.4218054862842893, fairness violation: 0.004777586034912718, violated group size: 0.283\n",
- "iteration: 402, error: 0.4217512437810945, fairness violation: 0.00478189552238806, violated group size: 0.217\n",
- "iteration: 403, error: 0.4218436724565757, fairness violation: 0.0047743374689826305, violated group size: 0.283\n",
- "iteration: 404, error: 0.42178960396039605, fairness violation: 0.004778633663366336, violated group size: 0.217\n",
- "iteration: 405, error: 0.4217358024691357, fairness violation: 0.004782908641975308, violated group size: 0.217\n",
- "iteration: 406, error: 0.4216822660098523, fairness violation: 0.004787162561576355, violated group size: 0.217\n",
- "iteration: 407, error: 0.4216289926289926, fairness violation: 0.004791395577395577, violated group size: 0.283\n",
- "iteration: 408, error: 0.421575980392157, fairness violation: 0.004795607843137254, violated group size: 0.217\n",
- "iteration: 409, error: 0.4215232273838631, fairness violation: 0.004799799511002444, violated group size: 0.217\n",
- "iteration: 410, error: 0.42147073170731714, fairness violation: 0.004803970731707317, violated group size: 0.283\n",
- "iteration: 411, error: 0.4214184914841849, fairness violation: 0.0048081216545012165, violated group size: 0.217\n",
- "iteration: 412, error: 0.4213665048543689, fairness violation: 0.004812252427184466, violated group size: 0.283\n",
- "iteration: 413, error: 0.42131476997578693, fairness violation: 0.004816363196125908, violated group size: 0.217\n",
- "iteration: 414, error: 0.42126328502415455, fairness violation: 0.004820454106280194, violated group size: 0.217\n",
- "iteration: 415, error: 0.4212120481927711, fairness violation: 0.004824525301204821, violated group size: 0.283\n",
- "iteration: 416, error: 0.42116105769230766, fairness violation: 0.004828576923076923, violated group size: 0.217\n",
- "iteration: 417, error: 0.4211103117505996, fairness violation: 0.004832609112709832, violated group size: 0.283\n",
- "iteration: 418, error: 0.42105980861244025, fairness violation: 0.004836622009569378, violated group size: 0.283\n",
- "iteration: 419, error: 0.42100954653937933, fairness violation: 0.004840615751789977, violated group size: 0.217\n",
- "iteration: 420, error: 0.42110000000000003, fairness violation: 0.0048332238095238084, violated group size: 0.217\n",
- "iteration: 421, error: 0.42104988123515436, fairness violation: 0.004837206650831354, violated group size: 0.217\n",
- "iteration: 422, error: 0.42100000000000004, fairness violation: 0.004841170616113744, violated group size: 0.217\n",
- "iteration: 423, error: 0.420950354609929, fairness violation: 0.004845115839243499, violated group size: 0.217\n",
- "iteration: 424, error: 0.42104009433962253, fairness violation: 0.004837783018867924, violated group size: 0.217\n",
- "iteration: 425, error: 0.4209905882352941, fairness violation: 0.004841717647058822, violated group size: 0.217\n",
- "iteration: 426, error: 0.42094131455399053, fairness violation: 0.004845633802816901, violated group size: 0.217\n",
- "iteration: 427, error: 0.42089227166276344, fairness violation: 0.004849531615925057, violated group size: 0.217\n",
- "iteration: 428, error: 0.4208434579439252, fairness violation: 0.004853411214953271, violated group size: 0.217\n",
- "iteration: 429, error: 0.4207948717948717, fairness violation: 0.0048572727272727274, violated group size: 0.283\n",
- "iteration: 430, error: 0.42074651162790694, fairness violation: 0.004861116279069767, violated group size: 0.217\n",
- "iteration: 431, error: 0.4206983758700697, fairness violation: 0.004864941995359629, violated group size: 0.283\n",
- "iteration: 432, error: 0.420650462962963, fairness violation: 0.00486875, violated group size: 0.217\n",
- "iteration: 433, error: 0.42060277136258656, fairness violation: 0.0048725404157043874, violated group size: 0.217\n",
- "iteration: 434, error: 0.42055529953917054, fairness violation: 0.0048763133640553, violated group size: 0.283\n",
- "iteration: 435, error: 0.42050804597701147, fairness violation: 0.004880068965517241, violated group size: 0.217\n",
- "iteration: 436, error: 0.4204610091743119, fairness violation: 0.004883807339449542, violated group size: 0.217\n",
- "iteration: 437, error: 0.4204141876430207, fairness violation: 0.004887528604118992, violated group size: 0.217\n",
- "iteration: 438, error: 0.42036757990867574, fairness violation: 0.004891232876712329, violated group size: 0.217\n",
- "iteration: 439, error: 0.4203211845102506, fairness violation: 0.0048949202733485206, violated group size: 0.283\n",
- "iteration: 440, error: 0.42027499999999995, fairness violation: 0.00489859090909091, violated group size: 0.283\n",
- "iteration: 441, error: 0.42022902494331066, fairness violation: 0.004902244897959184, violated group size: 0.283\n",
- "iteration: 442, error: 0.42018325791855204, fairness violation: 0.004905882352941177, violated group size: 0.217\n",
- "iteration: 443, error: 0.42013769751693003, fairness violation: 0.004909503386004516, violated group size: 0.283\n",
- "iteration: 444, error: 0.42009234234234244, fairness violation: 0.004913108108108108, violated group size: 0.217\n",
- "iteration: 445, error: 0.420047191011236, fairness violation: 0.004916696629213483, violated group size: 0.217\n",
- "iteration: 446, error: 0.42000224215246645, fairness violation: 0.004920269058295964, violated group size: 0.217\n",
- "iteration: 447, error: 0.4199574944071588, fairness violation: 0.004923825503355704, violated group size: 0.217\n",
- "iteration: 448, error: 0.41991294642857147, fairness violation: 0.004927366071428571, violated group size: 0.217\n",
- "iteration: 449, error: 0.41986859688195993, fairness violation: 0.004930890868596881, violated group size: 0.217\n",
- "iteration: 450, error: 0.41982444444444433, fairness violation: 0.004934399999999999, violated group size: 0.217\n",
- "iteration: 451, error: 0.41978048780487814, fairness violation: 0.004937893569844789, violated group size: 0.217\n",
- "iteration: 452, error: 0.41973672566371684, fairness violation: 0.004941371681415928, violated group size: 0.217\n",
- "iteration: 453, error: 0.41969315673289187, fairness violation: 0.0049448344370860925, violated group size: 0.217\n",
- "iteration: 454, error: 0.41964977973568285, fairness violation: 0.004948281938325991, violated group size: 0.283\n",
- "iteration: 455, error: 0.41960659340659345, fairness violation: 0.004951714285714286, violated group size: 0.283\n",
- "iteration: 456, error: 0.41956359649122804, fairness violation: 0.004955131578947367, violated group size: 0.217\n",
- "iteration: 457, error: 0.41952078774617063, fairness violation: 0.0049585339168490145, violated group size: 0.217\n",
- "iteration: 458, error: 0.41947816593886456, fairness violation: 0.004961921397379911, violated group size: 0.217\n",
- "iteration: 459, error: 0.4194357298474945, fairness violation: 0.00496529411764706, violated group size: 0.283\n",
- "iteration: 460, error: 0.4193934782608696, fairness violation: 0.004968652173913044, violated group size: 0.283\n",
- "iteration: 461, error: 0.41935140997830805, fairness violation: 0.004971995661605205, violated group size: 0.283\n",
- "iteration: 462, error: 0.41930952380952374, fairness violation: 0.0049753246753246735, violated group size: 0.217\n",
- "iteration: 463, error: 0.41926781857451406, fairness violation: 0.004978639308855291, violated group size: 0.217\n"
- ]
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "pycharm": {
+ "is_executing": true
+ },
+ "id": "u364ULCthI5L",
+ "outputId": "24fa9897-8b4a-4d42-afb3-2832bd76bb99"
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "\n",
+ "# output heatmap (brute force)\n",
+ "# heatmap_path is the relative path (without the .png extension) where the plot is saved\n",
+ "fair_model.heatmapflag = True\n",
+ "fair_model.heatmap_path = 'heatmap'\n",
+ "fair_model.generate_heatmap(data_set, dataset_yhat.labels)\n",
+ "Image(filename='{}.png'.format(fair_model.heatmap_path))\n",
+ "\n",
+ "\n",
+ "\n"
+ ]
},
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "iteration: 464, error: 0.41922629310344817, fairness violation: 0.004981939655172413, violated group size: 0.217\n",
- "iteration: 465, error: 0.41918494623655916, fairness violation: 0.004985225806451612, violated group size: 0.217\n",
- "iteration: 466, error: 0.4191437768240344, fairness violation: 0.004988497854077254, violated group size: 0.283\n",
- "iteration: 467, error: 0.41910278372591, fairness violation: 0.004991755888650964, violated group size: 0.283\n",
- "iteration: 468, error: 0.4190619658119658, fairness violation: 0.004994999999999998, violated group size: 0.217\n",
- "iteration: 469, error: 0.4190213219616205, fairness violation: 0.0049982302771855, violated group size: 0.217\n",
- "iteration: 470, error: 0.41898085106382993, fairness violation: 0.005001446808510636, violated group size: 0.283\n",
- "iteration: 471, error: 0.41894055201698516, fairness violation: 0.005004649681528661, violated group size: 0.217\n",
- "iteration: 472, error: 0.41890042372881364, fairness violation: 0.005007838983050847, violated group size: 0.283\n",
- "iteration: 473, error: 0.4188604651162791, fairness violation: 0.005011014799154333, violated group size: 0.217\n",
- "iteration: 474, error: 0.41882067510548526, fairness violation: 0.005014177215189871, violated group size: 0.217\n",
- "iteration: 475, error: 0.41878105263157905, fairness violation: 0.0050173263157894735, violated group size: 0.283\n",
- "iteration: 476, error: 0.41874159663865557, fairness violation: 0.0050204621848739485, violated group size: 0.217\n",
- "iteration: 477, error: 0.4188259958071279, fairness violation: 0.005013576519916141, violated group size: 0.217\n",
- "iteration: 478, error: 0.41878661087866115, fairness violation: 0.005016707112970709, violated group size: 0.217\n",
- "iteration: 479, error: 0.4188705636743216, fairness violation: 0.005009858037578285, violated group size: 0.217\n",
- "iteration: 480, error: 0.41883125000000004, fairness violation: 0.005012983333333334, violated group size: 0.283\n",
- "iteration: 481, error: 0.4187920997920998, fairness violation: 0.005016095634095634, violated group size: 0.283\n",
- "iteration: 482, error: 0.4187531120331951, fairness violation: 0.0050191950207468874, violated group size: 0.283\n",
- "iteration: 483, error: 0.4188364389233955, fairness violation: 0.00501239751552795, violated group size: 0.217\n",
- "iteration: 484, error: 0.41891942148760325, fairness violation: 0.005005628099173555, violated group size: 0.283\n",
- "iteration: 485, error: 0.4190020618556701, fairness violation: 0.004998886597938144, violated group size: 0.283\n",
- "iteration: 486, error: 0.41896296296296304, fairness violation: 0.005001995884773661, violated group size: 0.217\n",
- "iteration: 487, error: 0.4190451745379877, fairness violation: 0.004995289527720739, violated group size: 0.283\n",
- "iteration: 488, error: 0.4191270491803279, fairness violation: 0.004988610655737704, violated group size: 0.283\n",
- "iteration: 489, error: 0.4192085889570552, fairness violation: 0.004981959100204497, violated group size: 0.217\n",
- "iteration: 490, error: 0.41916938775510204, fairness violation: 0.004985077551020407, violated group size: 0.283\n",
- "iteration: 491, error: 0.4192505091649695, fairness violation: 0.004978460285132381, violated group size: 0.217\n",
- "iteration: 492, error: 0.4192113821138212, fairness violation: 0.004981573170731706, violated group size: 0.217\n",
- "iteration: 493, error: 0.41917241379310355, fairness violation: 0.004984673427991887, violated group size: 0.283\n",
- "iteration: 494, error: 0.41913360323886634, fairness violation: 0.004987761133603237, violated group size: 0.217\n",
- "iteration: 495, error: 0.4192141414141415, fairness violation: 0.004981191919191918, violated group size: 0.217\n",
- "iteration: 496, error: 0.4192943548387097, fairness violation: 0.004974649193548386, violated group size: 0.217\n",
- "iteration: 497, error: 0.419374245472837, fairness violation: 0.004968132796780683, violated group size: 0.217\n",
- "iteration: 498, error: 0.4194538152610441, fairness violation: 0.004961642570281124, violated group size: 0.217\n",
- "iteration: 499, error: 0.41953306613226454, fairness violation: 0.0049551783567134255, violated group size: 0.217\n"
- ]
- }
- ],
- "source": [
- "C = 100\n",
- "print_flag = True\n",
- "gamma = .005\n",
- "\n",
- "\n",
- "fair_model = GerryFairClassifier(C=C, printflag=print_flag, gamma=gamma, fairness_def='FP',\n",
- " max_iters=max_iterations, heatmapflag=False)\n",
- "\n",
- "# fit method\n",
- "fair_model.fit(data_set, early_termination=True)\n",
- "\n",
- "# predict method. If threshold in (0, 1) produces binary predictions\n",
- "\n",
- "dataset_yhat = fair_model.predict(data_set, threshold=False)\n",
- "\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "**3-d heatmaps**\n",
- "\n",
- "We now show to generate a 3d-heatmap of unfairness using the `generate_heatmap` method. The $X-Y$ axes in the plot represent the coefficients of the linear threshold function that defines a protected subgroup with respect to the first two sensitive attributes. Which $2$ attributes are considered sensitive can be overwritten with the `col_index` argument. The $Z$-axes is the $\\gamma$-disparity (FP) of the corresponding subgroup defined by the linear threshold function. This is important because it allows us to (1) visualize convergence as the heatmap flattens and (2) brute force check the fairness in low-dimensions without relying on a heuristic auditor. See the [the rich subgroup fairness empirical paper](https://arxiv.org/abs/1808.08166) for a discussion of these plots. Note that in the below plot no group has a $\\gamma$-disparity of greater than $.005$, which we would expect since the set of linear threshold functions on two attributes is a subset of the set of linear threshold functions on all protected attributes, and the final model is $\\gamma$-fair. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "pycharm": {
- "is_executing": true
- }
- },
- "outputs": [
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- ""
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "SwMw4smGhI5M"
+ },
+ "source": [
+ "**black-box auditing**\n",
+ "\n",
+ "We now show how to audit any black-box classifier with respect to rich subgroup fairness under either the FP or the FN rate. Note that the auditing procedure below works for any set of (soft) predictions $\\hat{y}$ and makes no assumptions about the structure of the predictor. As expected, the disparity of the group found is the same as the disparity printed in the last iteration of the `fit` method.\n",
+ " "
]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "\n",
- "# output heatmap (brute force)\n",
- "# replace None with the relative path if you want to save the plot\n",
- "fair_model.heatmapflag = True\n",
- "fair_model.heatmap_path = 'heatmap'\n",
- "fair_model.generate_heatmap(data_set, dataset_yhat.labels)\n",
- "Image(filename='{}.png'.format(fair_model.heatmap_path)) \n",
- "\n",
- "\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "**black-box auditing**\n",
- "\n",
- "We now show to audit any black box classifier with respect to rich subgroup fairness under either FP or FN rate. Note the below auditing procedure would work for any set of (soft) predictions $\\hat{y}$, and need make no assumptions about the structure of the predictor. We note that as expected the disparity of the group found is the same as the disparity printed out in the last iteration of the `fit` method.\n",
- " "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {
- "pycharm": {
- "is_executing": true
- }
- },
- "outputs": [
+ },
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "0.004955178356713431\n"
- ]
- }
- ],
- "source": [
- "\n",
- "\n",
- "gerry_metric = BinaryLabelDatasetMetric(data_set)\n",
- "gamma_disparity = gerry_metric.rich_subgroup(array_to_tuple(dataset_yhat.labels), 'FP')\n",
- "print(gamma_disparity)\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "**pareto curves**\n",
- "\n",
- "The `FairFictPlay` algorithm implemented in the `fit` method converges given access to perfect oracles for solving cost-sensitive classification (CSC) problems. A cost-sensitive classification problem over a hypothesis class $\\mathcal{H}$ is $$\\min_{h}\\sum_{i = 1}^{n}(1-h(x_i))c_0 + h(x_i)c_1$$\n",
- "By default in this package, and in the companion [empirical](https://arxiv.org/abs/1808.08166) and [theory](https://arxiv.org/pdf/1711.05144.pdf) papers, the hypothesis class of the learner and the of the subgroups are hyperplanes. The corresponding heuristic oracle for solving the CSC problem first forms two regression problems $(x_i, c_0)$ and $(x_i, c_1)$. Then in the case of hyperplanes, trains two regressions $r_i: \\mathcal{X} \\to R$ which predict the costs of classifying a given point $x$ $0,1$ respectively. Finally the binary classifier output by the oracle is defined as $\\hat{r}(x) = \\arg\\min_{j \\in \\{0,1\\}}r_j(x)$. But of course if we are interesting in different hypothesis classes for the learner, we simply need different regressors. In this package in addition to linear regression, we've added support for regression trees, kernelized ridge regression, and support vector regression. Below we trace out Pareto curves of $\\gamma$-disparity vs. error for each of these different heuristic oracles. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "pycharm": {
- "is_executing": true
- }
- },
- "outputs": [
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "pycharm": {
+ "is_executing": true
+ },
+ "id": "YsozvR3hhI5M",
+ "outputId": "6ef397ab-120c-49a0-a767-5201a047ee3d"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "0.004955178356713431\n"
+ ]
+ }
+ ],
+ "source": [
+ "\n",
+ "\n",
+ "gerry_metric = BinaryLabelDatasetMetric(data_set)\n",
+ "gamma_disparity = gerry_metric.rich_subgroup(array_to_tuple(dataset_yhat.labels), 'FP')\n",
+ "print(gamma_disparity)\n"
+ ]
+ },
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Curr Predictor: Linear\n",
- "Curr Predictor: SVR\n",
- "Curr Predictor: Tree\n",
- "Curr Predictor: Kernel\n"
- ]
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1qK4jg_PhI5M"
+ },
+ "source": [
+ "**pareto curves**\n",
+ "\n",
+ "The `FairFictPlay` algorithm implemented in the `fit` method converges given access to perfect oracles for solving cost-sensitive classification (CSC) problems. A cost-sensitive classification problem over a hypothesis class $\\mathcal{H}$ is $$\\min_{h}\\sum_{i = 1}^{n}(1-h(x_i))c_0 + h(x_i)c_1$$\n",
+ "By default in this package, and in the companion [empirical](https://arxiv.org/abs/1808.08166) and [theory](https://arxiv.org/pdf/1711.05144.pdf) papers, the hypothesis classes of the learner and of the subgroups are hyperplanes. The corresponding heuristic oracle for solving the CSC problem first forms two regression problems $(x_i, c_0)$ and $(x_i, c_1)$. Then, in the case of hyperplanes, it trains two regressions $r_j: \\mathcal{X} \\to \\mathbb{R}$, $j \\in \\{0,1\\}$, which predict the costs of classifying a given point $x$ as $0$ and as $1$ respectively. Finally, the binary classifier output by the oracle is defined as $\\hat{r}(x) = \\arg\\min_{j \\in \\{0,1\\}}r_j(x)$. If we are interested in different hypothesis classes for the learner, we simply need different regressors. In addition to linear regression, this package supports regression trees, kernelized ridge regression, and support vector regression. Below we trace out Pareto curves of $\\gamma$-disparity vs. error for each of these heuristic oracles."
+ ]
},
{
- "data": {
- "image/png": "\n",
- "text/plain": [
- ""
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "pycharm": {
+ "is_executing": true
+ },
+ "id": "WJspqrmEhI5M",
+ "outputId": "d1478011-d530-45e8-a32f-3862db281721"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Curr Predictor: Linear\n",
+ "Curr Predictor: SVR\n",
+ "Curr Predictor: Tree\n",
+ "Curr Predictor: Kernel\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# set to 50 iterations for fast running of notebook - set >= 1000 when running real experiments\n",
+ "pareto_iters = 50\n",
+ "def multiple_classifiers_pareto(dataset, gamma_list=[0.002, 0.005, 0.01, 0.02, 0.05, 0.1], save_results=False, iters=pareto_iters):\n",
+ "\n",
+ "    ln_predictor = linear_model.LinearRegression()\n",
+ "    svm_predictor = svm.LinearSVR()\n",
+ "    tree_predictor = tree.DecisionTreeRegressor(max_depth=3)\n",
+ "    kernel_predictor = KernelRidge(alpha=1.0, gamma=1.0, kernel='rbf')\n",
+ "    predictor_dict = {'Linear': {'predictor': ln_predictor, 'iters': iters},\n",
+ "                      'SVR': {'predictor': svm_predictor, 'iters': iters},\n",
+ "                      'Tree': {'predictor': tree_predictor, 'iters': iters},\n",
+ "                      'Kernel': {'predictor': kernel_predictor, 'iters': iters}}\n",
+ "\n",
+ "    results_dict = {}\n",
+ "\n",
+ "    for pred in predictor_dict:\n",
+ "        print('Curr Predictor: {}'.format(pred))\n",
+ "        predictor = predictor_dict[pred]['predictor']\n",
+ "        max_iters = predictor_dict[pred]['iters']\n",
+ "        fair_clf = GerryFairClassifier(C=100, printflag=True, gamma=1, predictor=predictor, max_iters=max_iters)\n",
+ "        fair_clf.printflag = False\n",
+ "        fair_clf.max_iters = max_iters\n",
+ "        errors, fp_violations, fn_violations = fair_clf.pareto(dataset, gamma_list)\n",
+ "        results_dict[pred] = {'errors': errors, 'fp_violations': fp_violations, 'fn_violations': fn_violations}\n",
+ "        plt.plot(errors, fp_violations, label=pred)\n",
+ "\n",
+ "    if save_results:\n",
+ "        pickle.dump(results_dict, open('results_dict_' + str(gamma_list) + '_gammas' + str(gamma_list) + '.pkl', 'wb'))\n",
+ "\n",
+ "    plt.xlabel('Error')\n",
+ "    plt.ylabel('Unfairness')\n",
+ "    plt.legend()\n",
+ "    plt.title('Error vs. Unfairness\\n(Adult Dataset)')\n",
+ "    plt.savefig('gerryfair_pareto.png')\n",
+ "    plt.close()\n",
+ "multiple_classifiers_pareto(data_set)\n",
+ "Image(filename='gerryfair_pareto.png')"
]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# set to 50 iterations for fast running of notebook - set >= 1000 when running real experiments\n",
- "pareto_iters = 50\n",
- "def multiple_classifiers_pareto(dataset, gamma_list=[0.002, 0.005, 0.01, 0.02, 0.05, 0.1], save_results=False, iters=pareto_iters):\n",
- "\n",
- " ln_predictor = linear_model.LinearRegression()\n",
- " svm_predictor = svm.LinearSVR()\n",
- " tree_predictor = tree.DecisionTreeRegressor(max_depth=3)\n",
- " kernel_predictor = KernelRidge(alpha=1.0, gamma=1.0, kernel='rbf')\n",
- " predictor_dict = {'Linear': {'predictor': ln_predictor, 'iters': iters},\n",
- " 'SVR': {'predictor': svm_predictor, 'iters': iters},\n",
- " 'Tree': {'predictor': tree_predictor, 'iters': iters},\n",
- " 'Kernel': {'predictor': kernel_predictor, 'iters': iters}}\n",
- "\n",
- " results_dict = {}\n",
- "\n",
- " for pred in predictor_dict:\n",
- " print('Curr Predictor: {}'.format(pred))\n",
- " predictor = predictor_dict[pred]['predictor']\n",
- " max_iters = predictor_dict[pred]['iters']\n",
- " fair_clf = GerryFairClassifier(C=100, printflag=True, gamma=1, predictor=predictor, max_iters=max_iters)\n",
- " fair_clf.printflag = False\n",
- " fair_clf.max_iters=max_iters\n",
- " errors, fp_violations, fn_violations = fair_clf.pareto(dataset, gamma_list)\n",
- " results_dict[pred] = {'errors': errors, 'fp_violations': fp_violations, 'fn_violations': fn_violations}\n",
- " plt.plot(errors, fp_violations, label=pred)\n",
- "\n",
- " if save_results:\n",
- " pickle.dump(results_dict, open('results_dict_' + str(gamma_list) + '_gammas' + str(gamma_list) + '.pkl', 'wb'))\n",
- "\n",
- " plt.xlabel('Error')\n",
- " plt.ylabel('Unfairness')\n",
- " plt.legend()\n",
- " plt.title('Error vs. Unfairness\\n(Adult Dataset)')\n",
- " plt.savefig('gerryfair_pareto.png')\n",
- " plt.close()\n",
- "multiple_classifiers_pareto(data_set)\n",
- "Image(filename='gerryfair_pareto.png') "
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "A natural question one might ask is, suppose we fix a statistical fairness definition for rich subgroup fairness like equality of false positive rates, `FP`. Does learning a classifier that is fair with respect to `FP` increase or decrease fairness with respect to false negative rates `FN`? One could see this relationship going in either direction - and indeed we submit that it is dataset dependent. In some cases, if enforcing `FP` fairness pushes the classifier towards the constant classifier, then it will also satisify `FN` rate fairness, since the constant classifier is perfectly fair. However, if the hypothesis class is sufficiently rich, then one would expect that ceteris paribus since we are optimizing for error in addition to `FP` rate fairness, the algorithm would increase `FN` rate unfairness in order to decrease error. Below we trace the FN vs. FP rate tradeoff across a range of input $\\gamma$, where the classifier is optimized only for `FP` rate fairness. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {
- "pycharm": {
- "is_executing": true
- }
- },
- "outputs": [
+ },
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "gamma: 0.001 gamma: 0.002 gamma: 0.003 gamma: 0.004 gamma: 0.005 gamma: 0.0075 gamma: 0.01 gamma: 0.02 gamma: 0.03 gamma: 0.05 "
- ]
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "N_ghZDh6hI5N"
+ },
+ "source": [
+ "A natural question arises once we fix a statistical fairness definition for rich subgroup fairness, such as equality of false positive rates (`FP`): does learning a classifier that is fair with respect to `FP` increase or decrease fairness with respect to false negative rates (`FN`)? The relationship could go in either direction, and indeed we submit that it is dataset dependent. In some cases, if enforcing `FP` fairness pushes the classifier towards the constant classifier, then it will also satisfy `FN` rate fairness, since the constant classifier is perfectly fair. However, if the hypothesis class is sufficiently rich, one would expect that, ceteris paribus, since we are optimizing for error in addition to `FP` rate fairness, the algorithm would increase `FN` rate unfairness in order to decrease error. Below we trace the FN vs. FP rate tradeoff across a range of input $\\gamma$, where the classifier is optimized only for `FP` rate fairness."
+ ]
},
{
- "data": {
- "image/png": "\n",
- "text/plain": [
- ""
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "pycharm": {
+ "is_executing": true
+ },
+ "id": "0xLIr5QdhI5N",
+ "outputId": "33fb403b-09c0-44a7-9c55-260319380020"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "gamma: 0.001 gamma: 0.002 gamma: 0.003 gamma: 0.004 gamma: 0.005 gamma: 0.0075 gamma: 0.01 gamma: 0.02 gamma: 0.03 gamma: 0.05 "
+ ]
+ },
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "def fp_vs_fn(dataset, gamma_list, iters):\n",
+ "    fp_auditor = Auditor(dataset, 'FP')\n",
+ "    fn_auditor = Auditor(dataset, 'FN')\n",
+ "    fp_violations = []\n",
+ "    fn_violations = []\n",
+ "    for g in gamma_list:\n",
+ "        print('gamma: {} '.format(g), end=\" \")\n",
+ "        fair_model = GerryFairClassifier(C=100, printflag=False, gamma=g, max_iters=iters)\n",
+ "        fair_model.gamma = g\n",
+ "        fair_model.fit(dataset)\n",
+ "        predictions = array_to_tuple((fair_model.predict(dataset)).labels)\n",
+ "        _, fp_diff = fp_auditor.audit(predictions)\n",
+ "        _, fn_diff = fn_auditor.audit(predictions)\n",
+ "        fp_violations.append(fp_diff)\n",
+ "        fn_violations.append(fn_diff)\n",
+ "\n",
+ "    plt.plot(fp_violations, fn_violations, label='adult')\n",
+ "    plt.xlabel('False Positive Disparity')\n",
+ "    plt.ylabel('False Negative Disparity')\n",
+ "    plt.legend()\n",
+ "    plt.title('FP vs FN Unfairness')\n",
+ "    plt.savefig('gerryfair_fp_fn.png')\n",
+ "    plt.close()\n",
+ "\n",
+ "gamma_list = [0.001, 0.002, 0.003, 0.004, 0.005, 0.0075, 0.01, 0.02, 0.03, 0.05]\n",
+ "fp_vs_fn(data_set, gamma_list, pareto_iters)\n",
+ "Image(filename='gerryfair_fp_fn.png')"
]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "pycharm": {
+ "is_executing": true
+ },
+ "id": "wn6BQLiFhI5N"
+ },
+ "outputs": [],
+ "source": []
}
- ],
- "source": [
- "def fp_vs_fn(dataset, gamma_list, iters):\n",
- " fp_auditor = Auditor(dataset, 'FP')\n",
- " fn_auditor = Auditor(dataset, 'FN')\n",
- " fp_violations = []\n",
- " fn_violations = []\n",
- " for g in gamma_list:\n",
- " print('gamma: {} '.format(g), end =\" \")\n",
- " fair_model = GerryFairClassifier(C=100, printflag=False, gamma=g, max_iters=iters)\n",
- " fair_model.gamma=g\n",
- " fair_model.fit(dataset)\n",
- " predictions = array_to_tuple((fair_model.predict(dataset)).labels)\n",
- " _, fp_diff = fp_auditor.audit(predictions)\n",
- " _, fn_diff = fn_auditor.audit(predictions)\n",
- " fp_violations.append(fp_diff)\n",
- " fn_violations.append(fn_diff)\n",
- "\n",
- " plt.plot(fp_violations, fn_violations, label='adult')\n",
- " plt.xlabel('False Positive Disparity')\n",
- " plt.ylabel('False Negative Disparity')\n",
- " plt.legend()\n",
- " plt.title('FP vs FN Unfairness')\n",
- " plt.savefig('gerryfair_fp_fn.png')\n",
- " plt.close()\n",
- "\n",
- "gamma_list = [0.001, 0.002, 0.003, 0.004, 0.005, 0.0075, 0.01, 0.02, 0.03, 0.05]\n",
- "fp_vs_fn(data_set, gamma_list, pareto_iters)\n",
- "Image(filename='gerryfair_fp_fn.png')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.9"
+ },
"pycharm": {
- "is_executing": true
+ "stem_cell": {
+ "cell_type": "raw",
+ "metadata": {
+ "collapsed": false
+ },
+ "source": []
+ }
+ },
+ "colab": {
+ "provenance": []
}
- },
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
},
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.6.9"
- },
- "pycharm": {
- "stem_cell": {
- "cell_type": "raw",
- "metadata": {
- "collapsed": false
- },
- "source": []
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 1
-}
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/examples/demo_json_explainers.ipynb b/examples/demo_json_explainers.ipynb
index da23274c..fe8e8ddb 100644
--- a/examples/demo_json_explainers.ipynb
+++ b/examples/demo_json_explainers.ipynb
@@ -1,263 +1,309 @@
{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "# Load all necessary packages\n",
- "import sys\n",
- "sys.path.append(\"../\")\n",
- "from collections import OrderedDict\n",
- "import json\n",
- "from pprint import pprint\n",
- "from aif360.datasets import GermanDataset\n",
- "from aif360.metrics import BinaryLabelDatasetMetric\n",
- "from aif360.explainers import MetricTextExplainer, MetricJSONExplainer\n",
- "from IPython.display import JSON, display_json"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "##### Load dataset"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "gd = GermanDataset()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "##### Create metrics"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "priv = [{'sex': 1}]\n",
- "unpriv = [{'sex': 0}]\n",
- "bldm = BinaryLabelDatasetMetric(gd, unprivileged_groups=unpriv, privileged_groups=priv)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "##### Create explainers"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "text_expl = MetricTextExplainer(bldm)\n",
- "json_expl = MetricJSONExplainer(bldm)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "##### Text explanations"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [
+ "cells": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Number of positive-outcome instances: 700.0\n"
- ]
- }
- ],
- "source": [
- "print(text_expl.num_positives())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
+ "cell_type": "markdown",
+ "source": [
+ "[Open in Colab](https://colab.research.google.com/github/Trusted-AI/AIF360/blob/main/examples/demo_json_explainers.ipynb)\n"
+ ],
+ "metadata": {
+ "id": "PXISC-e0l2JJ"
+ }
+ },
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Mean difference (mean label value on privileged instances - mean label value on unprivileged instances): -0.0748013090229\n"
- ]
- }
- ],
- "source": [
- "print(text_expl.mean_difference())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": true,
+ "id": "s-b2u2Rvl013"
+ },
+ "outputs": [],
+ "source": [
+ "# Load all necessary packages\n",
+ "import sys\n",
+ "sys.path.append(\"../\")\n",
+ "from collections import OrderedDict\n",
+ "import json\n",
+ "from pprint import pprint\n",
+ "from aif360.datasets import GermanDataset\n",
+ "from aif360.metrics import BinaryLabelDatasetMetric\n",
+ "from aif360.explainers import MetricTextExplainer, MetricJSONExplainer\n",
+ "from IPython.display import JSON, display_json"
+ ]
+ },
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Disparate impact (probability of favorable outcome for unprivileged instances / probability of favorable outcome for privileged instances): 0.896567328205\n"
- ]
- }
- ],
- "source": [
- "print(text_expl.disparate_impact())"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "##### JSON Explanations"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "def format_json(json_str):\n",
- " return json.dumps(json.loads(json_str, object_pairs_hook=OrderedDict), indent=2)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "3HohcQsBl016"
+ },
+ "source": [
+ "##### Load dataset"
+ ]
+ },
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{\n",
- " \"metric\": \"num_positives\", \n",
- " \"message\": \"Number of positive-outcome instances: 700.0\", \n",
- " \"numPositives\": 700.0, \n",
- " \"description\": \"Computed as the number of positive instances for the given (privileged or unprivileged) group.\", \n",
- " \"ideal\": \"The ideal value of this metric lies in the total number of positive instances made available\"\n",
- "}\n"
- ]
- }
- ],
- "source": [
- "print(format_json(json_expl.num_positives()))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": true,
+ "id": "vHsL60ZSl017"
+ },
+ "outputs": [],
+ "source": [
+ "gd = GermanDataset()"
+ ]
+ },
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{\n",
- " \"metric\": \"mean_difference\", \n",
- " \"message\": \"Mean difference (mean label value on privileged instances - mean label value on unprivileged instances): -0.0748013090229\", \n",
- " \"numPositivesUnprivileged\": 201.0, \n",
- " \"numInstancesUnprivileged\": 310.0, \n",
- " \"numPositivesPrivileged\": 499.0, \n",
- " \"numInstancesPrivileged\": 690.0, \n",
- " \"description\": \"Computed as the difference of the rate of favorable outcomes received by the unprivileged group to the privileged group.\", \n",
- " \"ideal\": \"The ideal value of this metric is 0.0\"\n",
- "}\n"
- ]
- }
- ],
- "source": [
- "print(format_json(json_expl.mean_difference()))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Fy7C2Lttl017"
+ },
+ "source": [
+ "##### Create metrics"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": true,
+ "id": "BWIQjCUdl018"
+ },
+ "outputs": [],
+ "source": [
+ "priv = [{'sex': 1}]\n",
+ "unpriv = [{'sex': 0}]\n",
+ "bldm = BinaryLabelDatasetMetric(gd, unprivileged_groups=unpriv, privileged_groups=priv)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "eUsxBOxVl018"
+ },
+ "source": [
+ "##### Create explainers"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": true,
+ "id": "7agLzgCCl018"
+ },
+ "outputs": [],
+ "source": [
+ "text_expl = MetricTextExplainer(bldm)\n",
+ "json_expl = MetricJSONExplainer(bldm)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "hEKcHC3Fl019"
+ },
+ "source": [
+ "##### Text explanations"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "4yfddrtAl019",
+ "outputId": "680af6c9-53c0-4b51-d4b6-d09207f00ccd"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of positive-outcome instances: 700.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(text_expl.num_positives())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Xg91poCjl01-",
+ "outputId": "e310a8ce-3730-4e08-dafd-6f9ecd26945a"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Mean difference (mean label value on privileged instances - mean label value on unprivileged instances): -0.0748013090229\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(text_expl.mean_difference())"
+ ]
+ },
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{\n",
- " \"metric\": \"disparate_impact\", \n",
- " \"message\": \"Disparate impact (probability of favorable outcome for unprivileged instances / probability of favorable outcome for privileged instances): 0.896567328205\", \n",
- " \"numPositivePredictionsUnprivileged\": 201.0, \n",
- " \"numUnprivileged\": 310.0, \n",
- " \"numPositivePredictionsPrivileged\": 499.0, \n",
- " \"numPrivileged\": 690.0, \n",
- " \"description\": \"Computed as the ratio of likelihood of favorable outcome for the unprivileged group to that of the privileged group.\", \n",
- " \"ideal\": \"The ideal value of this metric is 1.0\"\n",
- "}\n"
- ]
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "W9HdX1wUl01-",
+ "outputId": "aff81d6b-f850-4a9c-e25a-60e77c69b093"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Disparate impact (probability of favorable outcome for unprivileged instances / probability of favorable outcome for privileged instances): 0.896567328205\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(text_expl.disparate_impact())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "LA2uVnjPl01_"
+ },
+ "source": [
+ "##### JSON Explanations"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": true,
+ "id": "JCosLnQbl01_"
+ },
+ "outputs": [],
+ "source": [
+ "def format_json(json_str):\n",
+ " return json.dumps(json.loads(json_str, object_pairs_hook=OrderedDict), indent=2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "F1nemIIel01_",
+ "outputId": "7c03e9aa-4aee-4862-ddd5-e7abd20d5941"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{\n",
+ " \"metric\": \"num_positives\", \n",
+ " \"message\": \"Number of positive-outcome instances: 700.0\", \n",
+ " \"numPositives\": 700.0, \n",
+ " \"description\": \"Computed as the number of positive instances for the given (privileged or unprivileged) group.\", \n",
+ " \"ideal\": \"The ideal value of this metric lies in the total number of positive instances made available\"\n",
+ "}\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(format_json(json_expl.num_positives()))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "vaBapB4xl01_",
+ "outputId": "381346be-c9e9-4f41-dbaf-2c18cdd06ff6"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{\n",
+ " \"metric\": \"mean_difference\", \n",
+ " \"message\": \"Mean difference (mean label value on privileged instances - mean label value on unprivileged instances): -0.0748013090229\", \n",
+ " \"numPositivesUnprivileged\": 201.0, \n",
+ " \"numInstancesUnprivileged\": 310.0, \n",
+ " \"numPositivesPrivileged\": 499.0, \n",
+ " \"numInstancesPrivileged\": 690.0, \n",
+ " \"description\": \"Computed as the difference of the rate of favorable outcomes received by the unprivileged group to the privileged group.\", \n",
+ " \"ideal\": \"The ideal value of this metric is 0.0\"\n",
+ "}\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(format_json(json_expl.mean_difference()))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "qExRlmw8l02A",
+ "outputId": "11b63f8e-4cc6-4192-e600-de0231ae33cc"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{\n",
+ " \"metric\": \"disparate_impact\", \n",
+ " \"message\": \"Disparate impact (probability of favorable outcome for unprivileged instances / probability of favorable outcome for privileged instances): 0.896567328205\", \n",
+ " \"numPositivePredictionsUnprivileged\": 201.0, \n",
+ " \"numUnprivileged\": 310.0, \n",
+ " \"numPositivePredictionsPrivileged\": 499.0, \n",
+ " \"numPrivileged\": 690.0, \n",
+ " \"description\": \"Computed as the ratio of likelihood of favorable outcome for the unprivileged group to that of the privileged group.\", \n",
+ " \"ideal\": \"The ideal value of this metric is 1.0\"\n",
+ "}\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(format_json(json_expl.disparate_impact()))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": true,
+ "id": "mu6tkv_xl02A"
+ },
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 2",
+ "language": "python",
+ "name": "python2"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 2
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython2",
+ "version": "2.7.11"
+ },
+ "colab": {
+ "provenance": []
}
- ],
- "source": [
- "print(format_json(json_expl.disparate_impact()))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 2",
- "language": "python",
- "name": "python2"
},
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 2
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython2",
- "version": "2.7.11"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/examples/demo_lfr.ipynb b/examples/demo_lfr.ipynb
index 134a2729..9ed91c43 100644
--- a/examples/demo_lfr.ipynb
+++ b/examples/demo_lfr.ipynb
@@ -1,473 +1,528 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "#### This notebook demonstrates the use of the learning fair representations algorithm for bias mitigation\n",
- "Learning fair representations [1] is a pre-processing technique that finds a latent representation which encodes the data well but obfuscates information about protected attributes. We will see how to use this algorithm for learning representations that encourage individual fairness and apply them on the Adult dataset.\n",
- "\n",
- "References:\n",
- "\n",
- "[1] R. Zemel, Y. Wu, K. Swersky, T. Pitassi, and C. Dwork, \"Learning Fair Representations.\" \n",
- "International Conference on Machine Learning, 2013."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "%matplotlib inline\n",
- "# Load all necessary packages\n",
- "import sys\n",
- "sys.path.append(\"../\")\n",
- "from aif360.datasets import BinaryLabelDataset\n",
- "from aif360.datasets import AdultDataset\n",
- "from aif360.metrics import BinaryLabelDatasetMetric\n",
- "from aif360.metrics import ClassificationMetric\n",
- "from aif360.metrics.utils import compute_boolean_conditioning_vector\n",
- "\n",
- "from aif360.algorithms.preprocessing.optim_preproc_helpers.data_preproc_functions import load_preproc_data_adult\n",
- "from aif360.algorithms.preprocessing.lfr import LFR\n",
- "\n",
- "from sklearn.linear_model import LogisticRegression\n",
- "from sklearn.preprocessing import StandardScaler\n",
- "from sklearn.metrics import accuracy_score\n",
- "from sklearn.metrics import classification_report\n",
- "\n",
- "from IPython.display import Markdown, display\n",
- "import matplotlib.pyplot as plt\n",
- "import numpy as np\n",
- "\n",
- "from common_utils import compute_metrics"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "#### Load dataset and set options"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Get the dataset and split into train and test\n",
- "dataset_orig = load_preproc_data_adult()\n",
- "dataset_orig_train, dataset_orig_test = dataset_orig.split([0.7], shuffle=True)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "#### Clean up training data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
+ "cells": [
{
- "data": {
- "text/markdown": [
- "#### Training Dataset shape"
+ "cell_type": "markdown",
+ "source": [
+                "[Open In Colab](https://colab.research.google.com/github/Trusted-AI/AIF360/blob/main/examples/demo_lfr.ipynb)\n"
],
- "text/plain": [
- ""
+ "metadata": {
+ "id": "tVXbNnjsmVIz"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "BKhA67dOmMcG"
+ },
+ "source": [
+ "#### This notebook demonstrates the use of the learning fair representations algorithm for bias mitigation\n",
+ "Learning fair representations [1] is a pre-processing technique that finds a latent representation which encodes the data well but obfuscates information about protected attributes. We will see how to use this algorithm for learning representations that encourage individual fairness and apply them on the Adult dataset.\n",
+ "\n",
+ "References:\n",
+ "\n",
+ "[1] R. Zemel, Y. Wu, K. Swersky, T. Pitassi, and C. Dwork, \"Learning Fair Representations.\"\n",
+ "International Conference on Machine Learning, 2013."
]
- },
- "metadata": {},
- "output_type": "display_data"
},
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "(34189, 18)\n"
- ]
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "wG_aksDfmMcJ"
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline\n",
+ "# Load all necessary packages\n",
+ "import sys\n",
+ "sys.path.append(\"../\")\n",
+ "from aif360.datasets import BinaryLabelDataset\n",
+ "from aif360.datasets import AdultDataset\n",
+ "from aif360.metrics import BinaryLabelDatasetMetric\n",
+ "from aif360.metrics import ClassificationMetric\n",
+ "from aif360.metrics.utils import compute_boolean_conditioning_vector\n",
+ "\n",
+ "from aif360.algorithms.preprocessing.optim_preproc_helpers.data_preproc_functions import load_preproc_data_adult\n",
+ "from aif360.algorithms.preprocessing.lfr import LFR\n",
+ "\n",
+ "from sklearn.linear_model import LogisticRegression\n",
+ "from sklearn.preprocessing import StandardScaler\n",
+ "from sklearn.metrics import accuracy_score\n",
+ "from sklearn.metrics import classification_report\n",
+ "\n",
+ "from IPython.display import Markdown, display\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "\n",
+ "from common_utils import compute_metrics"
+ ]
},
{
- "data": {
- "text/markdown": [
- "#### Favorable and unfavorable labels"
- ],
- "text/plain": [
- ""
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "AtUDRcFzmMcK"
+ },
+ "source": [
+ "#### Load dataset and set options"
]
- },
- "metadata": {},
- "output_type": "display_data"
},
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "1.0 0.0\n"
- ]
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "UeELy5k1mMcL"
+ },
+ "outputs": [],
+ "source": [
+ "# Get the dataset and split into train and test\n",
+ "dataset_orig = load_preproc_data_adult()\n",
+ "dataset_orig_train, dataset_orig_test = dataset_orig.split([0.7], shuffle=True)"
+ ]
},
{
- "data": {
- "text/markdown": [
- "#### Protected attribute names"
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "rKE1eMYemMcL"
+ },
+ "source": [
+ "#### Clean up training data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "1rBcLsI-mMcL",
+ "outputId": "027ce352-06e6-478c-ceb6-1b810401643d"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/markdown": [
+ "#### Training Dataset shape"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(34189, 18)\n"
+ ]
+ },
+ {
+ "data": {
+ "text/markdown": [
+ "#### Favorable and unfavorable labels"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1.0 0.0\n"
+ ]
+ },
+ {
+ "data": {
+ "text/markdown": [
+ "#### Protected attribute names"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['sex', 'race']\n"
+ ]
+ },
+ {
+ "data": {
+ "text/markdown": [
+ "#### Privileged and unprivileged protected attribute values"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[array([1.]), array([1.])] [array([0.]), array([0.])]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/markdown": [
+ "#### Dataset feature names"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['race', 'sex', 'Age (decade)=10', 'Age (decade)=20', 'Age (decade)=30', 'Age (decade)=40', 'Age (decade)=50', 'Age (decade)=60', 'Age (decade)=>=70', 'Education Years=6', 'Education Years=7', 'Education Years=8', 'Education Years=9', 'Education Years=10', 'Education Years=11', 'Education Years=12', 'Education Years=<6', 'Education Years=>12']\n"
+ ]
+ }
],
- "text/plain": [
- ""
+ "source": [
+ "# print out some labels, names, etc.\n",
+ "display(Markdown(\"#### Training Dataset shape\"))\n",
+ "print(dataset_orig_train.features.shape)\n",
+ "display(Markdown(\"#### Favorable and unfavorable labels\"))\n",
+ "print(dataset_orig_train.favorable_label, dataset_orig_train.unfavorable_label)\n",
+ "display(Markdown(\"#### Protected attribute names\"))\n",
+ "print(dataset_orig_train.protected_attribute_names)\n",
+ "display(Markdown(\"#### Privileged and unprivileged protected attribute values\"))\n",
+ "print(dataset_orig_train.privileged_protected_attributes,\n",
+ " dataset_orig_train.unprivileged_protected_attributes)\n",
+ "display(Markdown(\"#### Dataset feature names\"))\n",
+ "print(dataset_orig_train.feature_names)"
]
- },
- "metadata": {},
- "output_type": "display_data"
},
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "['sex', 'race']\n"
- ]
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "4pY8KI0QmMcM"
+ },
+ "source": [
+ "#### Metric for original training data"
+ ]
},
{
- "data": {
- "text/markdown": [
- "#### Privileged and unprivileged protected attribute values"
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "n5qIXMbemMcM",
+ "outputId": "ad97d8eb-3381-4388-982f-5c1f070d9504"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/markdown": [
+ "#### Original training dataset"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Difference in mean outcomes between unprivileged and privileged groups = -0.193139\n"
+ ]
+ },
+ {
+ "data": {
+ "text/markdown": [
+ "#### Original test dataset"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Difference in mean outcomes between unprivileged and privileged groups = -0.197697\n"
+ ]
+ }
],
- "text/plain": [
- ""
+ "source": [
+ "# Metric for the original dataset\n",
+ "privileged_groups = [{'sex': 1.0}]\n",
+ "unprivileged_groups = [{'sex': 0.0}]\n",
+ "\n",
+ "metric_orig_train = BinaryLabelDatasetMetric(dataset_orig_train,\n",
+ " unprivileged_groups=unprivileged_groups,\n",
+ " privileged_groups=privileged_groups)\n",
+ "display(Markdown(\"#### Original training dataset\"))\n",
+ "print(\"Difference in mean outcomes between unprivileged and privileged groups = %f\" % metric_orig_train.mean_difference())\n",
+ "metric_orig_test = BinaryLabelDatasetMetric(dataset_orig_test,\n",
+ " unprivileged_groups=unprivileged_groups,\n",
+ " privileged_groups=privileged_groups)\n",
+ "display(Markdown(\"#### Original test dataset\"))\n",
+ "print(\"Difference in mean outcomes between unprivileged and privileged groups = %f\" % metric_orig_test.mean_difference())\n"
]
- },
- "metadata": {},
- "output_type": "display_data"
},
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[array([1.]), array([1.])] [array([0.]), array([0.])]\n"
- ]
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "xvap7BOvmMcN"
+ },
+ "source": [
+ "#### Train with and transform the original training data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "HelNehJ4mMcN"
+ },
+ "outputs": [],
+ "source": [
+ "scale_orig = StandardScaler()\n",
+ "dataset_orig_train.features = scale_orig.fit_transform(dataset_orig_train.features)\n",
+ "dataset_orig_test.features = scale_orig.transform(dataset_orig_test.features)"
+ ]
},
{
- "data": {
- "text/markdown": [
- "#### Dataset feature names"
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "V3qCK0q-mMcN",
+ "outputId": "76a156cf-7de7-40ef-87a1-c89add05e108"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "step: 0, loss: 1.0939550595829053, L_x: 2.531834521858599, L_y: 0.8200826015334493, L_z: 0.010344502931797964\n",
+ "step: 250, loss: 0.9162820270109503, L_x: 2.529109218043187, L_y: 0.6432961063010657, L_z: 0.010037499452782905\n",
+ "step: 500, loss: 0.8207071510514392, L_x: 2.5204911168067197, L_y: 0.5500397646035967, L_z: 0.00930913738358528\n",
+ "step: 750, loss: 0.8102771268166408, L_x: 2.511873834704061, L_y: 0.5427956868742799, L_z: 0.008147028235977415\n",
+ "step: 1000, loss: 0.7996570283329768, L_x: 2.480828451323288, L_y: 0.5399446552800813, L_z: 0.00581476396028337\n",
+ "step: 1250, loss: 0.7844631169970814, L_x: 2.4242508289183613, L_y: 0.5304307199052671, L_z: 0.005803657099989009\n",
+ "step: 1500, loss: 0.7653305722023572, L_x: 2.3297047767431986, L_y: 0.5176248867874912, L_z: 0.007367603870273078\n",
+ "step: 1750, loss: 0.7154304631442515, L_x: 2.085955877234543, L_y: 0.48081670080967953, L_z: 0.013009087305558827\n",
+ "step: 2000, loss: 0.6906420918886886, L_x: 1.896344106091722, L_y: 0.4646651544564373, L_z: 0.018171263411539594\n",
+ "step: 2250, loss: 0.6783680937630076, L_x: 1.7895665853948028, L_y: 0.4587714378849466, L_z: 0.020319998669290275\n",
+ "step: 2500, loss: 0.6725576747654705, L_x: 1.742061633693402, L_y: 0.4577729094336143, L_z: 0.020289300981257967\n",
+ "step: 2750, loss: 0.6694103860159343, L_x: 1.7548885984309939, L_y: 0.4545867175857845, L_z: 0.019667404293525217\n",
+ "step: 3000, loss: 0.6658207636894926, L_x: 1.7515234617350093, L_y: 0.4539151313299769, L_z: 0.018376643093007367\n",
+ "step: 3250, loss: 0.6481415219979564, L_x: 1.7252276686316934, L_y: 0.4491717858033674, L_z: 0.013223484665709846\n",
+ "step: 3500, loss: 0.645366243737316, L_x: 1.7196207136719521, L_y: 0.4482843307446003, L_z: 0.012559920812760247\n",
+ "step: 3750, loss: 0.6425278186287126, L_x: 1.7117758355776211, L_y: 0.4473063883366716, L_z: 0.012021923367139413\n",
+ "step: 4000, loss: 0.6419409673076768, L_x: 1.7092609385556714, L_y: 0.44744616781598634, L_z: 0.011784352818061686\n",
+ "step: 4250, loss: 0.6377801462539607, L_x: 1.6917081956472533, L_y: 0.4496335370425122, L_z: 0.009487894823361622\n"
+ ]
+ }
],
- "text/plain": [
- ""
+ "source": [
+                "# Input reconstruction quality - Ax\n",
+ "# Fairness constraint - Az\n",
+ "# Output prediction error - Ay\n",
+ "\n",
+ "privileged_groups = [{'sex': 1}]\n",
+ "unprivileged_groups = [{'sex': 0}]\n",
+ "\n",
+ "TR = LFR(unprivileged_groups=unprivileged_groups,\n",
+ " privileged_groups=privileged_groups,\n",
+ " k=10, Ax=0.1, Ay=1.0, Az=2.0,\n",
+ " verbose=1\n",
+ " )\n",
+ "TR = TR.fit(dataset_orig_train, maxiter=5000, maxfun=5000)"
]
- },
- "metadata": {},
- "output_type": "display_data"
},
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "['race', 'sex', 'Age (decade)=10', 'Age (decade)=20', 'Age (decade)=30', 'Age (decade)=40', 'Age (decade)=50', 'Age (decade)=60', 'Age (decade)=>=70', 'Education Years=6', 'Education Years=7', 'Education Years=8', 'Education Years=9', 'Education Years=10', 'Education Years=11', 'Education Years=12', 'Education Years=<6', 'Education Years=>12']\n"
- ]
- }
- ],
- "source": [
- "# print out some labels, names, etc.\n",
- "display(Markdown(\"#### Training Dataset shape\"))\n",
- "print(dataset_orig_train.features.shape)\n",
- "display(Markdown(\"#### Favorable and unfavorable labels\"))\n",
- "print(dataset_orig_train.favorable_label, dataset_orig_train.unfavorable_label)\n",
- "display(Markdown(\"#### Protected attribute names\"))\n",
- "print(dataset_orig_train.protected_attribute_names)\n",
- "display(Markdown(\"#### Privileged and unprivileged protected attribute values\"))\n",
- "print(dataset_orig_train.privileged_protected_attributes, \n",
- " dataset_orig_train.unprivileged_protected_attributes)\n",
- "display(Markdown(\"#### Dataset feature names\"))\n",
- "print(dataset_orig_train.feature_names)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "#### Metric for original training data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "sO4LI_hbmMcO"
+ },
+ "outputs": [],
+ "source": [
+                "# Transform the training and test data with the fitted LFR model\n",
+ "dataset_transf_train = TR.transform(dataset_orig_train)\n",
+ "dataset_transf_test = TR.transform(dataset_orig_test)"
+ ]
+ },
{
- "data": {
- "text/markdown": [
- "#### Original training dataset"
- ],
- "text/plain": [
- ""
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ZNXZ4AlUmMcO"
+ },
+ "outputs": [],
+ "source": [
+ "print(classification_report(dataset_orig_test.labels, dataset_transf_test.labels))"
]
- },
- "metadata": {},
- "output_type": "display_data"
},
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Difference in mean outcomes between unprivileged and privileged groups = -0.193139\n"
- ]
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "UA_fxUSemMcO"
+ },
+ "outputs": [],
+ "source": [
+ "metric_transf_train = BinaryLabelDatasetMetric(dataset_transf_train,\n",
+ " unprivileged_groups=unprivileged_groups,\n",
+ " privileged_groups=privileged_groups)\n",
+ "display(Markdown(\"#### Transformed training dataset\"))\n",
+ "print(\"Difference in mean outcomes between unprivileged and privileged groups = %f\" % metric_transf_train.mean_difference())\n",
+ "metric_transf_test = BinaryLabelDatasetMetric(dataset_transf_test,\n",
+ " unprivileged_groups=unprivileged_groups,\n",
+ " privileged_groups=privileged_groups)\n",
+ "display(Markdown(\"#### Transformed test dataset\"))\n",
+ "print(\"Difference in mean outcomes between unprivileged and privileged groups = %f\" % metric_transf_test.mean_difference())\n"
+ ]
},
{
- "data": {
- "text/markdown": [
- "#### Original test dataset"
- ],
- "text/plain": [
- ""
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "PRATChL3mMcO"
+ },
+ "outputs": [],
+ "source": [
+ "from common_utils import compute_metrics\n",
+ "\n",
+ "display(Markdown(\"#### Predictions from transformed testing data\"))\n",
+ "bal_acc_arr_transf = []\n",
+ "disp_imp_arr_transf = []\n",
+ "\n",
+ "class_thresh_arr = np.linspace(0.01, 0.99, 100)\n",
+ "\n",
+ "dataset_transf_test_new = dataset_orig_test.copy(deepcopy=True)\n",
+ "dataset_transf_test_new.scores = dataset_transf_test.scores\n",
+ "\n",
+ "\n",
+ "for thresh in class_thresh_arr:\n",
+ "\n",
+ " fav_inds = dataset_transf_test_new.scores > thresh\n",
+ " dataset_transf_test_new.labels[fav_inds] = 1.0\n",
+ " dataset_transf_test_new.labels[~fav_inds] = 0.0\n",
+ "\n",
+ " metric_test_aft = compute_metrics(dataset_orig_test, dataset_transf_test_new,\n",
+ " unprivileged_groups, privileged_groups,\n",
+ " disp = False)\n",
+ "\n",
+ " bal_acc_arr_transf.append(metric_test_aft[\"Balanced accuracy\"])\n",
+ " disp_imp_arr_transf.append(metric_test_aft[\"Disparate impact\"])"
]
- },
- "metadata": {},
- "output_type": "display_data"
},
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Difference in mean outcomes between unprivileged and privileged groups = -0.197697\n"
- ]
- }
- ],
- "source": [
- "# Metric for the original dataset\n",
- "privileged_groups = [{'sex': 1.0}]\n",
- "unprivileged_groups = [{'sex': 0.0}]\n",
- "\n",
- "metric_orig_train = BinaryLabelDatasetMetric(dataset_orig_train, \n",
- " unprivileged_groups=unprivileged_groups,\n",
- " privileged_groups=privileged_groups)\n",
- "display(Markdown(\"#### Original training dataset\"))\n",
- "print(\"Difference in mean outcomes between unprivileged and privileged groups = %f\" % metric_orig_train.mean_difference())\n",
- "metric_orig_test = BinaryLabelDatasetMetric(dataset_orig_test, \n",
- " unprivileged_groups=unprivileged_groups,\n",
- " privileged_groups=privileged_groups)\n",
- "display(Markdown(\"#### Original test dataset\"))\n",
- "print(\"Difference in mean outcomes between unprivileged and privileged groups = %f\" % metric_orig_test.mean_difference())\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "#### Train with and transform the original training data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [],
- "source": [
- "scale_orig = StandardScaler()\n",
- "dataset_orig_train.features = scale_orig.fit_transform(dataset_orig_train.features)\n",
- "dataset_orig_test.features = scale_orig.transform(dataset_orig_test.features)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "HGsipNVamMcO"
+ },
+ "outputs": [],
+ "source": [
+ "fig, ax1 = plt.subplots(figsize=(10,7))\n",
+ "ax1.plot(class_thresh_arr, bal_acc_arr_transf)\n",
+ "ax1.set_xlabel('Classification Thresholds', fontsize=16, fontweight='bold')\n",
+ "ax1.set_ylabel('Balanced Accuracy', color='b', fontsize=16, fontweight='bold')\n",
+ "ax1.xaxis.set_tick_params(labelsize=14)\n",
+ "ax1.yaxis.set_tick_params(labelsize=14)\n",
+ "\n",
+ "\n",
+ "ax2 = ax1.twinx()\n",
+ "ax2.plot(class_thresh_arr, np.abs(1.0-np.array(disp_imp_arr_transf)), color='r')\n",
+ "ax2.set_ylabel('abs(1-disparate impact)', color='r', fontsize=16, fontweight='bold')\n",
+ "ax2.yaxis.set_tick_params(labelsize=14)\n",
+ "ax2.grid(True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "biB1DXjKmMcP"
+ },
+ "source": [
+ "abs(1-disparate impact) must be small (close to 0) for classifier predictions to be fair."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "DxtBqemwmMcP"
+ },
+ "outputs": [],
+ "source": [
+ "display(Markdown(\"#### Individual fairness metrics\"))\n",
+ "print(\"Consistency of labels in transformed training dataset= %f\" %metric_transf_train.consistency())\n",
+ "print(\"Consistency of labels in original training dataset= %f\" %metric_orig_train.consistency())\n",
+ "print(\"Consistency of labels in transformed test dataset= %f\" %metric_transf_test.consistency())\n",
+ "print(\"Consistency of labels in original test dataset= %f\" %metric_orig_test.consistency())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "90gI7jtwmMcP"
+ },
+ "outputs": [],
+ "source": [
+ "def check_algorithm_success():\n",
+ " \"\"\"Transformed dataset consistency should be greater than original dataset.\"\"\"\n",
+ " assert metric_transf_test.consistency() > metric_orig_test.consistency(), \"Transformed dataset consistency should be greater than original dataset.\"\n",
+ "\n",
+ "check_algorithm_success()"
+ ]
+ },
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "step: 0, loss: 1.0939550595829053, L_x: 2.531834521858599, L_y: 0.8200826015334493, L_z: 0.010344502931797964\n",
- "step: 250, loss: 0.9162820270109503, L_x: 2.529109218043187, L_y: 0.6432961063010657, L_z: 0.010037499452782905\n",
- "step: 500, loss: 0.8207071510514392, L_x: 2.5204911168067197, L_y: 0.5500397646035967, L_z: 0.00930913738358528\n",
- "step: 750, loss: 0.8102771268166408, L_x: 2.511873834704061, L_y: 0.5427956868742799, L_z: 0.008147028235977415\n",
- "step: 1000, loss: 0.7996570283329768, L_x: 2.480828451323288, L_y: 0.5399446552800813, L_z: 0.00581476396028337\n",
- "step: 1250, loss: 0.7844631169970814, L_x: 2.4242508289183613, L_y: 0.5304307199052671, L_z: 0.005803657099989009\n",
- "step: 1500, loss: 0.7653305722023572, L_x: 2.3297047767431986, L_y: 0.5176248867874912, L_z: 0.007367603870273078\n",
- "step: 1750, loss: 0.7154304631442515, L_x: 2.085955877234543, L_y: 0.48081670080967953, L_z: 0.013009087305558827\n",
- "step: 2000, loss: 0.6906420918886886, L_x: 1.896344106091722, L_y: 0.4646651544564373, L_z: 0.018171263411539594\n",
- "step: 2250, loss: 0.6783680937630076, L_x: 1.7895665853948028, L_y: 0.4587714378849466, L_z: 0.020319998669290275\n",
- "step: 2500, loss: 0.6725576747654705, L_x: 1.742061633693402, L_y: 0.4577729094336143, L_z: 0.020289300981257967\n",
- "step: 2750, loss: 0.6694103860159343, L_x: 1.7548885984309939, L_y: 0.4545867175857845, L_z: 0.019667404293525217\n",
- "step: 3000, loss: 0.6658207636894926, L_x: 1.7515234617350093, L_y: 0.4539151313299769, L_z: 0.018376643093007367\n",
- "step: 3250, loss: 0.6481415219979564, L_x: 1.7252276686316934, L_y: 0.4491717858033674, L_z: 0.013223484665709846\n",
- "step: 3500, loss: 0.645366243737316, L_x: 1.7196207136719521, L_y: 0.4482843307446003, L_z: 0.012559920812760247\n",
- "step: 3750, loss: 0.6425278186287126, L_x: 1.7117758355776211, L_y: 0.4473063883366716, L_z: 0.012021923367139413\n",
- "step: 4000, loss: 0.6419409673076768, L_x: 1.7092609385556714, L_y: 0.44744616781598634, L_z: 0.011784352818061686\n",
- "step: 4250, loss: 0.6377801462539607, L_x: 1.6917081956472533, L_y: 0.4496335370425122, L_z: 0.009487894823361622\n"
- ]
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "GkkBIqXgmMcP"
+ },
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.10"
+ },
+ "colab": {
+ "provenance": []
}
- ],
- "source": [
- "# Input recontruction quality - Ax\n",
- "# Fairness constraint - Az\n",
- "# Output prediction error - Ay\n",
- "\n",
- "privileged_groups = [{'sex': 1}]\n",
- "unprivileged_groups = [{'sex': 0}]\n",
- " \n",
- "TR = LFR(unprivileged_groups=unprivileged_groups,\n",
- " privileged_groups=privileged_groups,\n",
- " k=10, Ax=0.1, Ay=1.0, Az=2.0,\n",
- " verbose=1\n",
- " )\n",
- "TR = TR.fit(dataset_orig_train, maxiter=5000, maxfun=5000)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Transform training data and align features\n",
- "dataset_transf_train = TR.transform(dataset_orig_train)\n",
- "dataset_transf_test = TR.transform(dataset_orig_test)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "print(classification_report(dataset_orig_test.labels, dataset_transf_test.labels))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "metric_transf_train = BinaryLabelDatasetMetric(dataset_transf_train, \n",
- " unprivileged_groups=unprivileged_groups,\n",
- " privileged_groups=privileged_groups)\n",
- "display(Markdown(\"#### Transformed training dataset\"))\n",
- "print(\"Difference in mean outcomes between unprivileged and privileged groups = %f\" % metric_transf_train.mean_difference())\n",
- "metric_transf_test = BinaryLabelDatasetMetric(dataset_transf_test, \n",
- " unprivileged_groups=unprivileged_groups,\n",
- " privileged_groups=privileged_groups)\n",
- "display(Markdown(\"#### Transformed test dataset\"))\n",
- "print(\"Difference in mean outcomes between unprivileged and privileged groups = %f\" % metric_transf_test.mean_difference())\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from common_utils import compute_metrics\n",
- "\n",
- "display(Markdown(\"#### Predictions from transformed testing data\"))\n",
- "bal_acc_arr_transf = []\n",
- "disp_imp_arr_transf = []\n",
- "\n",
- "class_thresh_arr = np.linspace(0.01, 0.99, 100)\n",
- "\n",
- "dataset_transf_test_new = dataset_orig_test.copy(deepcopy=True)\n",
- "dataset_transf_test_new.scores = dataset_transf_test.scores\n",
- "\n",
- "\n",
- "for thresh in class_thresh_arr:\n",
- " \n",
- " fav_inds = dataset_transf_test_new.scores > thresh\n",
- " dataset_transf_test_new.labels[fav_inds] = 1.0\n",
- " dataset_transf_test_new.labels[~fav_inds] = 0.0\n",
- " \n",
- " metric_test_aft = compute_metrics(dataset_orig_test, dataset_transf_test_new, \n",
- " unprivileged_groups, privileged_groups,\n",
- " disp = False)\n",
- "\n",
- " bal_acc_arr_transf.append(metric_test_aft[\"Balanced accuracy\"])\n",
- " disp_imp_arr_transf.append(metric_test_aft[\"Disparate impact\"])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "fig, ax1 = plt.subplots(figsize=(10,7))\n",
- "ax1.plot(class_thresh_arr, bal_acc_arr_transf)\n",
- "ax1.set_xlabel('Classification Thresholds', fontsize=16, fontweight='bold')\n",
- "ax1.set_ylabel('Balanced Accuracy', color='b', fontsize=16, fontweight='bold')\n",
- "ax1.xaxis.set_tick_params(labelsize=14)\n",
- "ax1.yaxis.set_tick_params(labelsize=14)\n",
- "\n",
- "\n",
- "ax2 = ax1.twinx()\n",
- "ax2.plot(class_thresh_arr, np.abs(1.0-np.array(disp_imp_arr_transf)), color='r')\n",
- "ax2.set_ylabel('abs(1-disparate impact)', color='r', fontsize=16, fontweight='bold')\n",
- "ax2.yaxis.set_tick_params(labelsize=14)\n",
- "ax2.grid(True)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "abs(1-disparate impact) must be small (close to 0) for classifier predictions to be fair."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "display(Markdown(\"#### Individual fairness metrics\"))\n",
- "print(\"Consistency of labels in transformed training dataset= %f\" %metric_transf_train.consistency())\n",
- "print(\"Consistency of labels in original training dataset= %f\" %metric_orig_train.consistency())\n",
- "print(\"Consistency of labels in transformed test dataset= %f\" %metric_transf_test.consistency())\n",
- "print(\"Consistency of labels in original test dataset= %f\" %metric_orig_test.consistency())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "def check_algorithm_success():\n",
- " \"\"\"Transformed dataset consistency should be greater than original dataset.\"\"\"\n",
- " assert metric_transf_test.consistency() > metric_orig_test.consistency(), \"Transformed dataset consistency should be greater than original dataset.\"\n",
- "\n",
- "check_algorithm_success() "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
},
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.6.10"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/examples/demo_mdss_classifier_metric.ipynb b/examples/demo_mdss_classifier_metric.ipynb
index 004d2956..13f94437 100644
--- a/examples/demo_mdss_classifier_metric.ipynb
+++ b/examples/demo_mdss_classifier_metric.ipynb
@@ -1,1230 +1,1381 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Bias scan using Multi-Dimensional Subset Scan (MDSS)\n",
- "\n",
- "\"Identifying Significant Predictive Bias in Classifiers\" https://arxiv.org/abs/1611.08292\n",
- "\n",
- "The goal of bias scan is to identify a subgroup(s) that has significantly more predictive bias than would be expected from an unbiased classifier. There are $\\prod_{m=1}^{M}\\left(2^{|X_{m}|}-1\\right)$ unique subgroups from a dataset with $M$ features, with each feature having $|X_{m}|$ discretized values, where a subgroup is any $M$-dimension\n",
- "Cartesian set product, between subsets of feature-values from each feature --- excluding the empty set. Bias scan mitigates this computational hurdle by approximately identifing the most statistically biased subgroup in linear time (rather than exponential).\n",
- "\n",
- "\n",
- "We define the statistical measure of predictive bias function, $score_{bias}(S)$ as a likelihood ratio score and a function of a given subgroup $S$. The null hypothesis is that the given prediction's odds are correct for all subgroups in $\\mathcal{D}$:\n",
- "\n",
- "$$H_{0}:odds(y_{i})=\\frac{\\hat{p}_{i}}{1-\\hat{p}_{i}}\\ \\forall i\\in\\mathcal{D}.$$\n",
- "\n",
- "The alternative hypothesis assumes some constant multiplicative bias in the odds for some given subgroup $S$:\n",
- "\n",
- "$$H_{1}:\\ odds(y_{i})=q\\frac{\\hat{p}_{i}}{1-\\hat{p}_{i}},\\ \\text{where}\\ q>1\\ \\forall i\\in S\\ \\mathrm{and}\\ q=1\\ \\forall i\\notin S.$$\n",
- "\n",
- "In the classification setting, each observation's likelihood is Bernoulli distributed and assumed independent. This results in the following scoring function for a subgroup $S$:\n",
- "\n",
- "\\begin{align*}\n",
- "score_{bias}(S)= & \\max_{q}\\log\\prod_{i\\in S}\\frac{Bernoulli(\\frac{q\\hat{p}_{i}}{1-\\hat{p}_{i}+q\\hat{p}_{i}})}{Bernoulli(\\hat{p}_{i})}\\\\\n",
- "= & \\max_{q}\\log(q)\\sum_{i\\in S}y_{i}-\\sum_{i\\in S}\\log(1-\\hat{p}_{i}+q\\hat{p}_{i}).\n",
- "\\end{align*}\n",
- "Our bias scan is thus represented as: $S^{*}=FSS(\\mathcal{D},\\mathcal{E},F_{score})=MDSS(\\mathcal{D},\\hat{p},score_{bias})$.\n",
- "\n",
- "where $S^{*}$ is the detected most anomalous subgroup, $FSS$ is one of several subset scan algorithms for different problem settings, $\\mathcal{D}$ is a dataset with outcomes $Y$ and discretized features $\\mathcal{X}$, $\\mathcal{E}$ are a set of expectations or 'normal' values for $Y$, and $F_{score}$ is an expectation-based scoring statistic that measures the amount of anomalousness between subgroup observations and their expectations.\n",
- "\n",
- "Predictive bias emphasizes comparable predictions for a subgroup and its observations and Bias scan provides a more general method that can detect and characterize such bias, or poor classifier fit, in the larger space of all possible subgroups, without a priori specification."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "import itertools\n",
- "\n",
- "import numpy as np\n",
- "import pandas as pd\n",
- "\n",
- "from aif360.metrics import BinaryLabelDatasetMetric, MDSSClassificationMetric\n",
- "from aif360.detectors import bias_scan\n",
- "\n",
- "from aif360.algorithms.preprocessing.optim_preproc_helpers.data_preproc_functions import load_preproc_data_compas"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We'll demonstrate scoring a subset and finding the most anomalous subset with bias scan using the compas dataset.\n",
- "\n",
- "We can specify subgroups to be scored or scan for the most anomalous subgroup. Bias scan allows us to decide if we aim to identify bias as `higher` than expected probabilities or `lower` than expected probabilities. Depending on the favourable label, the corresponding subgroup may be categorized as priviledged or unprivileged."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "dataset_orig = load_preproc_data_compas()\n",
- "\n",
- "female_group = [{'sex': 1}]\n",
- "male_group = [{'sex': 0}]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The dataset has the categorical features one-hot encoded so we'll modify the dataset to convert them back \n",
- "to the categorical featues because scanning one-hot encoded features may find subgroups that are not meaningful e.g., a subgroup with 2 race values. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [],
- "source": [
- "dataset_orig_df = pd.DataFrame(dataset_orig.features, columns=dataset_orig.feature_names)\n",
- "\n",
- "age_cat = np.argmax(dataset_orig_df[['age_cat=Less than 25', 'age_cat=25 to 45',\n",
- " 'age_cat=Greater than 45']].values, axis=1).reshape(-1, 1)\n",
- "priors_count = np.argmax(dataset_orig_df[['priors_count=0', 'priors_count=1 to 3',\n",
- " 'priors_count=More than 3']].values, axis=1).reshape(-1, 1)\n",
- "c_charge_degree = np.argmax(dataset_orig_df[['c_charge_degree=M', 'c_charge_degree=F']].values, axis=1).reshape(-1, 1)\n",
- "\n",
- "features = np.concatenate((dataset_orig_df[['sex', 'race']].values, age_cat, priors_count,\n",
- " c_charge_degree, dataset_orig.labels), axis=1)\n",
- "feature_names = ['sex', 'race', 'age_cat', 'priors_count', 'c_charge_degree']"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {
- "scrolled": true
- },
- "outputs": [
+ "cells": [
{
- "data": {
- "text/html": [
- "<div>\n",
- "<table border=\"1\" class=\"dataframe\">\n",
- " <thead>\n",
- " <tr style=\"text-align: right;\"><th></th><th>sex</th><th>race</th><th>age_cat</th><th>priors_count</th><th>c_charge_degree</th><th>two_year_recid</th></tr>\n",
- " </thead>\n",
- " <tbody>\n",
- " <tr><th>0</th><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>1.0</td></tr>\n",
- " <tr><th>1</th><td>0.0</td><td>0.0</td><td>0.0</td><td>2.0</td><td>1.0</td><td>1.0</td></tr>\n",
- " <tr><th>2</th><td>0.0</td><td>1.0</td><td>1.0</td><td>2.0</td><td>1.0</td><td>1.0</td></tr>\n",
- " <tr><th>3</th><td>1.0</td><td>1.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
- " <tr><th>4</th><td>0.0</td><td>1.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>0.0</td></tr>\n",
- " </tbody>\n",
- "</table>\n",
- "</div>"
+ "cell_type": "markdown",
+ "source": [
+ "[Open In Colab](https://colab.research.google.com/github/Trusted-AI/AIF360/blob/main/examples/demo_mdss_classifier_metric.ipynb)"
],
- "text/plain": [
- " sex race age_cat priors_count c_charge_degree two_year_recid\n",
- "0 0.0 0.0 1.0 0.0 1.0 1.0\n",
- "1 0.0 0.0 0.0 2.0 1.0 1.0\n",
- "2 0.0 1.0 1.0 2.0 1.0 1.0\n",
- "3 1.0 1.0 1.0 0.0 0.0 0.0\n",
- "4 0.0 1.0 1.0 0.0 1.0 0.0"
+ "metadata": {
+ "id": "zu6QM0J8CN0q"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "kci10l16B90B"
+ },
+ "source": [
+ "## Bias scan using Multi-Dimensional Subset Scan (MDSS)\n",
+ "\n",
+ "\"Identifying Significant Predictive Bias in Classifiers\" https://arxiv.org/abs/1611.08292\n",
+ "\n",
+ "The goal of bias scan is to identify the subgroup(s) with significantly more predictive bias than would be expected from an unbiased classifier. There are $\\prod_{m=1}^{M}\\left(2^{|X_{m}|}-1\\right)$ unique subgroups in a dataset with $M$ features, where feature $m$ has $|X_{m}|$ discretized values and a subgroup is any $M$-dimensional\n",
+ "Cartesian product of non-empty subsets of feature values, one subset per feature. Bias scan sidesteps this computational hurdle by approximately identifying the most statistically biased subgroup in linear time (rather than exponential).\n",
+ "\n",
+ "\n",
+ "We define the statistical measure of predictive bias, $score_{bias}(S)$, as a likelihood ratio score for a given subgroup $S$. The null hypothesis is that the predicted odds are correct for all observations in $\\mathcal{D}$:\n",
+ "\n",
+ "$$H_{0}:odds(y_{i})=\\frac{\\hat{p}_{i}}{1-\\hat{p}_{i}}\\ \\forall i\\in\\mathcal{D}.$$\n",
+ "\n",
+ "The alternative hypothesis assumes some constant multiplicative bias in the odds for some given subgroup $S$:\n",
+ "\n",
+ "$$H_{1}:\\ odds(y_{i})=q\\frac{\\hat{p}_{i}}{1-\\hat{p}_{i}},\\ \\text{where}\\ q>1\\ \\forall i\\in S\\ \\mathrm{and}\\ q=1\\ \\forall i\\notin S.$$\n",
+ "\n",
+ "In the classification setting, each observation's likelihood is Bernoulli distributed and assumed independent. This results in the following scoring function for a subgroup $S$:\n",
+ "\n",
+ "\\begin{align*}\n",
+ "score_{bias}(S)= & \\max_{q}\\log\\prod_{i\\in S}\\frac{Bernoulli(\\frac{q\\hat{p}_{i}}{1-\\hat{p}_{i}+q\\hat{p}_{i}})}{Bernoulli(\\hat{p}_{i})}\\\\\n",
+ "= & \\max_{q}\\log(q)\\sum_{i\\in S}y_{i}-\\sum_{i\\in S}\\log(1-\\hat{p}_{i}+q\\hat{p}_{i}).\n",
+ "\\end{align*}\n",
+ "Our bias scan is thus represented as: $S^{*}=FSS(\\mathcal{D},\\mathcal{E},F_{score})=MDSS(\\mathcal{D},\\hat{p},score_{bias})$.\n",
+ "\n",
+ "Here, $S^{*}$ is the most anomalous subgroup detected, $FSS$ is one of several subset scan algorithms for different problem settings, $\\mathcal{D}$ is a dataset with outcomes $Y$ and discretized features $\\mathcal{X}$, $\\mathcal{E}$ is a set of expectations or 'normal' values for $Y$, and $F_{score}$ is an expectation-based scoring statistic that measures how anomalous the subgroup's observations are relative to their expectations.\n",
+ "\n",
+ "Predictive bias concerns whether a subgroup's predictions are comparable to its observations. Bias scan provides a more general method that can detect and characterize such bias, or poor classifier fit, in the larger space of all possible subgroups, without a priori specification."
]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "df = pd.DataFrame(features, columns=feature_names + ['two_year_recid'])\n",
- "df.head()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Training\n",
- "We'll create a structured dataset and then train a simple classifier to predict the probability of the outcome"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [],
- "source": [
- "from aif360.datasets import StandardDataset\n",
- "dataset = StandardDataset(df, label_name='two_year_recid', favorable_classes=[0],\n",
- " protected_attribute_names=['sex', 'race'],\n",
- " privileged_classes=[[1], [1]],\n",
- " instance_weights_name=None)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [],
- "source": [
- "dataset_orig_train, dataset_orig_test = dataset.split([0.7], shuffle=True, seed=0)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
+ },
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Train set: Difference in mean outcomes between unprivileged and privileged groups = -0.124496\n",
- "Test set: Difference in mean outcomes between unprivileged and privileged groups = -0.159410\n"
- ]
- }
- ],
- "source": [
- "metric_train = BinaryLabelDatasetMetric(dataset_orig_train,\n",
- " unprivileged_groups=male_group,\n",
- " privileged_groups=female_group)\n",
- "\n",
- "print(\"Train set: Difference in mean outcomes between unprivileged and privileged groups = %f\" % metric_train.mean_difference())\n",
- "metric_test = BinaryLabelDatasetMetric(dataset_orig_test,\n",
- " unprivileged_groups=male_group,\n",
- " privileged_groups=female_group)\n",
- "print(\"Test set: Difference in mean outcomes between unprivileged and privileged groups = %f\" % metric_test.mean_difference())\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "It shows that overall Females in the dataset have a lower observed recidivism them Males."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "If we train a classifier, the model is likely to pick up this bias in the dataset"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "2DYOd2UsB90E"
+ },
+ "outputs": [],
+ "source": [
+ "import itertools\n",
+ "\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "\n",
+ "from aif360.metrics import BinaryLabelDatasetMetric, MDSSClassificationMetric\n",
+ "from aif360.detectors import bias_scan\n",
+ "\n",
+ "from aif360.algorithms.preprocessing.optim_preproc_helpers.data_preproc_functions import load_preproc_data_compas"
+ ]
+ },
{
- "data": {
- "text/plain": [
- "LogisticRegression(random_state=0)"
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "SqG1-54_B90F"
+ },
+ "source": [
+ "We'll demonstrate scoring a subset and finding the most anomalous subset with bias scan using the COMPAS dataset.\n",
+ "\n",
+ "We can either specify subgroups to be scored or scan for the most anomalous subgroup. Bias scan lets us choose whether to look for bias as `higher`-than-expected or `lower`-than-expected probabilities. Depending on the favorable label, the corresponding subgroup may be categorized as privileged or unprivileged."
]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from sklearn.linear_model import LogisticRegression\n",
- "clf = LogisticRegression(solver='lbfgs', C=1.0, penalty='l2', random_state=0)\n",
- "clf.fit(dataset_orig_train.features, dataset_orig_train.labels.flatten())"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Note that the probability scores we use are the probabilities of the favorable label, which is 0 in this case."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [
+ },
{
- "data": {
- "text/plain": [
- "array([0., 1.])"
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "CeaDWPaEB90G"
+ },
+ "outputs": [],
+ "source": [
+ "dataset_orig = load_preproc_data_compas()\n",
+ "\n",
+ "female_group = [{'sex': 1}]\n",
+ "male_group = [{'sex': 0}]"
]
- },
- "execution_count": 9,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "clf.classes_"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "predictions should reflect the probability of a favorable outcome (i.e. no recidivism)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [],
- "source": [
- "dataset_bias_test_prob = clf.predict_proba(dataset_orig_test.features)[:, 0]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [
+ },
{
- "data": {
- "text/html": [
- "<div>\n",
- "<table border=\"1\" class=\"dataframe\">\n",
- " <thead>\n",
- " <tr style=\"text-align: right;\"><th></th><th>sex</th><th>race</th><th>age_cat</th><th>priors_count</th><th>c_charge_degree</th><th>observed</th><th>probabilities</th></tr>\n",
- " </thead>\n",
- " <tbody>\n",
- " <tr><th>0</th><td>1.0</td><td>1.0</td><td>2.0</td><td>2.0</td><td>1.0</td><td>1.0</td><td>0.552951</td></tr>\n",
- " <tr><th>1</th><td>1.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>0.740959</td></tr>\n",
- " <tr><th>2</th><td>0.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>1.0</td><td>0.0</td><td>0.374728</td></tr>\n",
- " <tr><th>3</th><td>0.0</td><td>0.0</td><td>2.0</td><td>2.0</td><td>1.0</td><td>1.0</td><td>0.444487</td></tr>\n",
- " <tr><th>4</th><td>0.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>0.584908</td></tr>\n",
- " </tbody>\n",
- "</table>\n",
- "</div>"
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "UbcjVVhAB90G"
+ },
+ "source": [
+ "The dataset has its categorical features one-hot encoded, so we'll convert them back\n",
+ "to categorical features, because scanning one-hot encoded features may find subgroups that are not meaningful, e.g., a subgroup with two race values."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "NzbzKqmfB90G"
+ },
+ "outputs": [],
+ "source": [
+ "dataset_orig_df = pd.DataFrame(dataset_orig.features, columns=dataset_orig.feature_names)\n",
+ "\n",
+ "age_cat = np.argmax(dataset_orig_df[['age_cat=Less than 25', 'age_cat=25 to 45',\n",
+ " 'age_cat=Greater than 45']].values, axis=1).reshape(-1, 1)\n",
+ "priors_count = np.argmax(dataset_orig_df[['priors_count=0', 'priors_count=1 to 3',\n",
+ " 'priors_count=More than 3']].values, axis=1).reshape(-1, 1)\n",
+ "c_charge_degree = np.argmax(dataset_orig_df[['c_charge_degree=M', 'c_charge_degree=F']].values, axis=1).reshape(-1, 1)\n",
+ "\n",
+ "features = np.concatenate((dataset_orig_df[['sex', 'race']].values, age_cat, priors_count,\n",
+ " c_charge_degree, dataset_orig.labels), axis=1)\n",
+ "feature_names = ['sex', 'race', 'age_cat', 'priors_count', 'c_charge_degree']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true,
+ "id": "QxeSBzMVB90H",
+ "outputId": "a64c16ec-f5e1-40fc-81f9-1c0a4e96d521"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\"><th></th><th>sex</th><th>race</th><th>age_cat</th><th>priors_count</th><th>c_charge_degree</th><th>two_year_recid</th></tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr><th>0</th><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>1.0</td></tr>\n",
+ " <tr><th>1</th><td>0.0</td><td>0.0</td><td>0.0</td><td>2.0</td><td>1.0</td><td>1.0</td></tr>\n",
+ " <tr><th>2</th><td>0.0</td><td>1.0</td><td>1.0</td><td>2.0</td><td>1.0</td><td>1.0</td></tr>\n",
+ " <tr><th>3</th><td>1.0</td><td>1.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
+ " <tr><th>4</th><td>0.0</td><td>1.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>0.0</td></tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " sex race age_cat priors_count c_charge_degree two_year_recid\n",
+ "0 0.0 0.0 1.0 0.0 1.0 1.0\n",
+ "1 0.0 0.0 0.0 2.0 1.0 1.0\n",
+ "2 0.0 1.0 1.0 2.0 1.0 1.0\n",
+ "3 1.0 1.0 1.0 0.0 0.0 0.0\n",
+ "4 0.0 1.0 1.0 0.0 1.0 0.0"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
],
- "text/plain": [
- " sex race age_cat priors_count c_charge_degree observed probabilities\n",
- "0 1.0 1.0 2.0 2.0 1.0 1.0 0.552951\n",
- "1 1.0 0.0 1.0 0.0 1.0 0.0 0.740959\n",
- "2 0.0 1.0 0.0 1.0 1.0 0.0 0.374728\n",
- "3 0.0 0.0 2.0 2.0 1.0 1.0 0.444487\n",
- "4 0.0 1.0 1.0 1.0 0.0 1.0 0.584908"
+ "source": [
+ "df = pd.DataFrame(features, columns=feature_names + ['two_year_recid'])\n",
+ "df.head()"
]
- },
- "execution_count": 11,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "df = pd.DataFrame(dataset_orig_test.features, columns=dataset_orig_test.feature_names)\n",
- "df['observed'] = pd.Series(dataset_orig_test.labels.flatten(), index=df.index)\n",
- "df['probabilities'] = pd.Series(dataset_bias_test_prob, index=df.index)\n",
- "df.head()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We'll the create another structured dataset as the classified dataset by assigning the predicted probabilities to the scores attribute"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [],
- "source": [
- "dataset_bias_test = dataset_orig_test.copy()\n",
- "dataset_bias_test.scores = dataset_bias_test_prob\n",
- "dataset_bias_test.labels = dataset_orig_test.labels"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Bias scoring"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "First, we try to observe the difference between the model prediction and the actual observations of the favorable label, which in this case is 0. We create a new test_df for this computation. \n",
- "\n",
- "If the model's average prediction of the favorable label is higher than the actual observations average, then the group is said to be privileged. In the converse case, the group is said to be unprivileged.\n",
- "\n",
- "We would check for whether the male and female groups are privileged or not using mdss score"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "9MBbYHoyB90I"
+ },
+ "source": [
+ "### Training\n",
+ "We'll create a structured dataset and then train a simple classifier to predict the probability of the outcome."
+ ]
+ },
{
- "data": {
- "text/html": [
- "<div>\n",
- "<table border=\"1\" class=\"dataframe\">\n",
- " <thead>\n",
- " <tr style=\"text-align: right;\"><th></th><th>sex</th><th>race</th><th>age_cat</th><th>priors_count</th><th>c_charge_degree</th><th>two_year_recid</th><th>model_not_recid</th><th>observed_not_recid</th></tr>\n",
- " </thead>\n",
- " <tbody>\n",
- " <tr><th>2479</th><td>1.0</td><td>1.0</td><td>2.0</td><td>2.0</td><td>1.0</td><td>1.0</td><td>0.552951</td><td>0.0</td></tr>\n",
- " <tr><th>3574</th><td>1.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>0.740959</td><td>1.0</td></tr>\n",
- " <tr><th>513</th><td>0.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>1.0</td><td>0.0</td><td>0.374728</td><td>1.0</td></tr>\n",
- " <tr><th>1725</th><td>0.0</td><td>0.0</td><td>2.0</td><td>2.0</td><td>1.0</td><td>1.0</td><td>0.444487</td><td>0.0</td></tr>\n",
- " <tr><th>96</th><td>0.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>0.584908</td><td>0.0</td></tr>\n",
- " <tr><th>...</th><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td></tr>\n",
- " <tr><th>4931</th><td>0.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>1.0</td><td>0.0</td><td>0.374728</td><td>1.0</td></tr>\n",
- " <tr><th>3264</th><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>1.0</td><td>0.535753</td><td>0.0</td></tr>\n",
- " <tr><th>1653</th><td>0.0</td><td>0.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>0.0</td><td>0.490037</td><td>1.0</td></tr>\n",
- " <tr><th>2607</th><td>1.0</td><td>1.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>1.0</td><td>0.769140</td><td>0.0</td></tr>\n",
- " <tr><th>2732</th><td>0.0</td><td>1.0</td><td>0.0</td><td>2.0</td><td>0.0</td><td>1.0</td><td>0.251726</td><td>0.0</td></tr>\n",
- " </tbody>\n",
- "</table>\n",
- "<p>1584 rows × 8 columns</p>\n",
- "</div>"
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Mr3tEubCB90I"
+ },
+ "outputs": [],
+ "source": [
+ "from aif360.datasets import StandardDataset\n",
+ "dataset = StandardDataset(df, label_name='two_year_recid', favorable_classes=[0],\n",
+ " protected_attribute_names=['sex', 'race'],\n",
+ " privileged_classes=[[1], [1]],\n",
+ " instance_weights_name=None)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ub0RsE6_B90I"
+ },
+ "outputs": [],
+ "source": [
+ "dataset_orig_train, dataset_orig_test = dataset.split([0.7], shuffle=True, seed=0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ozr0_L3AB90I",
+ "outputId": "b4a747cc-8121-452b-dc23-effcd43b89d7"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Train set: Difference in mean outcomes between unprivileged and privileged groups = -0.124496\n",
+ "Test set: Difference in mean outcomes between unprivileged and privileged groups = -0.159410\n"
+ ]
+ }
],
- "text/plain": [
- " sex race age_cat priors_count c_charge_degree two_year_recid \\\n",
- "2479 1.0 1.0 2.0 2.0 1.0 1.0 \n",
- "3574 1.0 0.0 1.0 0.0 1.0 0.0 \n",
- "513 0.0 1.0 0.0 1.0 1.0 0.0 \n",
- "1725 0.0 0.0 2.0 2.0 1.0 1.0 \n",
- "96 0.0 1.0 1.0 1.0 0.0 1.0 \n",
- "... ... ... ... ... ... ... \n",
- "4931 0.0 1.0 0.0 1.0 1.0 0.0 \n",
- "3264 0.0 0.0 0.0 0.0 1.0 1.0 \n",
- "1653 0.0 0.0 1.0 1.0 1.0 0.0 \n",
- "2607 1.0 1.0 1.0 0.0 1.0 1.0 \n",
- "2732 0.0 1.0 0.0 2.0 0.0 1.0 \n",
- "\n",
- " model_not_recid observed_not_recid \n",
- "2479 0.552951 0.0 \n",
- "3574 0.740959 1.0 \n",
- "513 0.374728 1.0 \n",
- "1725 0.444487 0.0 \n",
- "96 0.584908 0.0 \n",
- "... ... ... \n",
- "4931 0.374728 1.0 \n",
- "3264 0.535753 0.0 \n",
- "1653 0.490037 1.0 \n",
- "2607 0.769140 0.0 \n",
- "2732 0.251726 0.0 \n",
- "\n",
- "[1584 rows x 8 columns]"
+ "source": [
+ "metric_train = BinaryLabelDatasetMetric(dataset_orig_train,\n",
+ " unprivileged_groups=male_group,\n",
+ " privileged_groups=female_group)\n",
+ "\n",
+ "print(\"Train set: Difference in mean outcomes between unprivileged and privileged groups = %f\" % metric_train.mean_difference())\n",
+ "metric_test = BinaryLabelDatasetMetric(dataset_orig_test,\n",
+ " unprivileged_groups=male_group,\n",
+ " privileged_groups=female_group)\n",
+ "print(\"Test set: Difference in mean outcomes between unprivileged and privileged groups = %f\" % metric_test.mean_difference())\n"
]
- },
- "execution_count": 13,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "test_df = dataset_bias_test.convert_to_dataframe()[0]\n",
- "test_df['model_not_recid'] = dataset_bias_test.scores.flatten()\n",
- "test_df['observed_not_recid'] = 1 - test_df['two_year_recid']\n",
- "test_df"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {},
- "outputs": [
+ },
{
- "data": {
- "text/plain": [
- "model_not_recid 0.617561\n",
- "observed_not_recid 0.657051\n",
- "dtype: float64"
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "uhDQDm_5B90J"
+ },
+ "source": [
+ "This shows that, overall, females in the dataset have a lower observed recidivism rate than males."
]
- },
- "execution_count": 14,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# Females actual vs predicted rates of positive label\n",
- "test_df[test_df.sex == 1][['model_not_recid','observed_not_recid']].mean()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Since model average predictions for the positive label is lower than the observed average by a substantial amount (about 4%), the female group is most likely unprivileged."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {},
- "outputs": [
+ },
{
- "data": {
- "text/plain": [
- "model_not_recid 0.512443\n",
- "observed_not_recid 0.497642\n",
- "dtype: float64"
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "VNo-Ix97B90J"
+ },
+ "source": [
+ "If we train a classifier, the model is likely to pick up this bias in the dataset."
]
- },
- "execution_count": 15,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# Males actual vs predicted rates of positive label\n",
- "test_df[test_df.sex == 0][['model_not_recid','observed_not_recid']].mean()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Since model average predictions for the positive label is greater than the observed average by a small amount (about 1.5%), the male group could be privileged."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Now, we'll create an instance of the MDSS Classification Metric and assess the apriori defined privileged and unprivileged groups; females and males respectively. \n",
- "\n",
- "By apriori defining the male group as unprivileged, we are saying we expect that the model's predictions is systematically lower than the actual observation.\n",
- "\n",
- "By apriori defining the female group as privileged, we are saying we expect that the model's predictions is systematically higher than the actual observation.\n",
- "\n",
- "From our mini-analysis above, we know that these hypothesis are unlikely to be true "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {},
- "outputs": [],
- "source": [
- "mdss_classified = MDSSClassificationMetric(dataset_orig_test, dataset_bias_test,\n",
- " unprivileged_groups=male_group,\n",
- " privileged_groups=female_group)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {},
- "outputs": [
+ },
{
- "data": {
- "text/plain": [
- "-0.0"
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Hj1KWwaOB90J",
+ "outputId": "abf6f0c4-ec8c-4033-a88a-bbb11fa47900"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "LogisticRegression(random_state=0)"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from sklearn.linear_model import LogisticRegression\n",
+ "clf = LogisticRegression(solver='lbfgs', C=1.0, penalty='l2', random_state=0)\n",
+ "clf.fit(dataset_orig_train.features, dataset_orig_train.labels.flatten())"
]
- },
- "execution_count": 17,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# We are asking the question:\n",
- "# Is there evidence that the hypothesized privileged group is actually privileged?\n",
- "\n",
- "female_privileged_score = mdss_classified.score_groups(privileged=True)\n",
- "female_privileged_score"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "By having a score very close to zero, mdss bias score is informing us that there is no evidence from the data that our hypothesis of the female group being privileged is true."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {},
- "outputs": [
+ },
{
- "data": {
- "text/plain": [
- "-0.0"
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "8bC5aDp4B90K"
+ },
+ "source": [
+ "Note that the probability scores we use are the probabilities of the favorable label, which is 0 in this case."
]
- },
- "execution_count": 18,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# We are asking the question:\n",
- "# Is there evidence that the hypothesized unprivileged group is actually unprivileged?\n",
- "\n",
- "male_unprivileged_score = mdss_classified.score_groups(privileged=False)\n",
- "male_unprivileged_score"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "By having a score very close zero, mdss bias score is informing us that there is no evidence from the data to support our hypothesis of the male group being unprivileged is true."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can flip our initial hypothesis and check if the male group is privileged or the female group is unprivileged."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 19,
- "metadata": {},
- "outputs": [],
- "source": [
- "mdss_classified = MDSSClassificationMetric(dataset_orig_test, dataset_bias_test,\n",
- " unprivileged_groups=female_group,\n",
- " privileged_groups=male_group)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 20,
- "metadata": {},
- "outputs": [
+ },
{
- "data": {
- "text/plain": [
- "0.63"
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "tQU_CQtOB90K",
+ "outputId": "f9beec78-31e0-48e7-bfcf-7f276b09af16"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([0., 1.])"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "clf.classes_"
]
- },
- "execution_count": 20,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "male_privileged_score = mdss_classified.score_groups(privileged=True)\n",
- "male_privileged_score"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "By having a positive score, mdss bias score is informing us that there is evidence from the data that our hypothesis of the male group being privileged is true."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 21,
- "metadata": {},
- "outputs": [
+ },
{
- "data": {
- "text/plain": [
- "1.1769"
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "UyhqLqhNB90K"
+ },
+ "source": [
+ "Predictions should reflect the probability of the favorable outcome (i.e. no recidivism), so we keep the `predict_proba` column for class 0, the first entry of `clf.classes_`."
]
- },
- "execution_count": 21,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "female_unprivileged_score = mdss_classified.score_groups(privileged=False)\n",
- "female_unprivileged_score"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "By having a positive score, mdss bias score is informing us that there is evidence from the data to support our hypothesis of the female group being unprivileged is true."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "By taking into account the size of the group and the magnitude of the deviation, mdss bias core has been able to tell us the following about the male and female groups:\n",
- "- There is no evidence that the female group is privileged.\n",
- "- There is no evidence that the male group is unprivileged.\n",
- "- There is evidence that the male group is privileged.\n",
- "- There is evidence that the female is unprivileged."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Bias scan\n",
- "We get the bias score for the apriori defined subgroup but assuming we had no prior knowledge \n",
- "about the predictive bias and wanted to find the subgroups with the most bias, we can apply bias scan to identify the priviledged and unpriviledged groups. The privileged argument is not a reference to a group but the direction for which to scan for bias."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 22,
- "metadata": {},
- "outputs": [],
- "source": [
- "privileged_subset = bias_scan(df.iloc[:, :-2], df.observed, df.probabilities,\n",
- " favorable_value=dataset_orig_test.favorable_label,\n",
- " penalty=0.5, overpredicted=True)\n",
- "unprivileged_subset = bias_scan(df.iloc[:, :-2], df.observed, df.probabilities,\n",
- " favorable_value=dataset_orig_test.favorable_label,\n",
- " penalty=0.5, overpredicted=False)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 23,
- "metadata": {},
- "outputs": [
+ },
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "({'race': [0.0], 'age_cat': [0.0], 'sex': [0.0]}, 3.1526)\n",
- "({'sex': [1.0], 'race': [0.0]}, 3.3036)\n"
- ]
- }
- ],
- "source": [
- "print(privileged_subset)\n",
- "print(unprivileged_subset)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 24,
- "metadata": {},
- "outputs": [],
- "source": [
- "assert privileged_subset[0]\n",
- "assert unprivileged_subset[0]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can observe that the bias score is higher than the score of the prior groups. These subgroups are guaranteed to be the highest scoring subgroup among the exponentially many subgroups.\n",
- "\n",
- "For the purposes of this example, the logistic regression model systematically underestimates the recidivism risk of individuals in the `Non-caucasian`, `less than 25`, `Male` subgroup whereas individuals belonging to the `Non-caucasian`, `Female` are assigned a higher risk than is actually observed. We refer to these subgroups as the `detected privileged group` and `detected unprivileged group` respectively."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can create another srtuctured dataset using the new groups to compute other dataset metrics. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 25,
- "metadata": {},
- "outputs": [],
- "source": [
- "protected_attr_names = set(privileged_subset[0].keys()).union(set(unprivileged_subset[0].keys()))\n",
- "dataset_orig_test.protected_attribute_names = list(protected_attr_names)\n",
- "dataset_bias_test.protected_attribute_names = list(protected_attr_names)\n",
- "\n",
- "protected_attr = np.where(np.isin(dataset_orig_test.feature_names, list(protected_attr_names)))[0]\n",
- "\n",
- "dataset_orig_test.protected_attributes = dataset_orig_test.features[:, protected_attr]\n",
- "dataset_bias_test.protected_attributes = dataset_bias_test.features[:, protected_attr]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 26,
- "metadata": {},
- "outputs": [],
- "source": [
- "# converts from dictionary of lists to list of dictionaries\n",
- "a = list(privileged_subset[0].values())\n",
- "subset_values = list(itertools.product(*a))\n",
- "\n",
- "detected_privileged_groups = []\n",
- "for vals in subset_values:\n",
- " detected_privileged_groups.append((dict(zip(privileged_subset[0].keys(), vals))))\n",
- "\n",
- "a = list(unprivileged_subset[0].values())\n",
- "subset_values = list(itertools.product(*a))\n",
- "\n",
- "detected_unprivileged_groups = []\n",
- "for vals in subset_values:\n",
- " detected_unprivileged_groups.append((dict(zip(unprivileged_subset[0].keys(), vals))))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 27,
- "metadata": {},
- "outputs": [
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "20HUpT98B90K"
+ },
+ "outputs": [],
+ "source": [
+ "dataset_bias_test_prob = clf.predict_proba(dataset_orig_test.features)[:, 0]"
+ ]
+ },
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Test set: Difference in mean outcomes between unprivileged and privileged groups = 0.275836\n"
- ]
- }
- ],
- "source": [
- "metric_bias_test = BinaryLabelDatasetMetric(dataset_bias_test,\n",
- " unprivileged_groups=detected_unprivileged_groups,\n",
- " privileged_groups=detected_privileged_groups)\n",
- "\n",
- "print(\"Test set: Difference in mean outcomes between unprivileged and privileged groups = %f\"\n",
- " % metric_bias_test.mean_difference())"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "It appears the detected privileged group have a higher risk of recidivism than the unprivileged group."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "As noted in the paper, predictive bias is different from predictive fairness so there's no the emphasis in the subgroups having comparable predictions between them. \n",
- "We can investigate the difference in what the model predicts vs what we actually observed as well as the multiplicative difference in the odds of the subgroups."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 28,
- "metadata": {},
- "outputs": [],
- "source": [
- "to_choose = df[privileged_subset[0].keys()].isin(privileged_subset[0]).all(axis=1)\n",
- "temp_df = df.loc[to_choose]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 29,
- "metadata": {},
- "outputs": [
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ryABk3w5B90K",
+ "outputId": "234f8d7b-f69a-42ad-9544-bc151e367245"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "   sex  race  age_cat  priors_count  c_charge_degree  observed  probabilities\n",
+ "0  1.0   1.0      2.0           2.0              1.0       1.0       0.552951\n",
+ "1  1.0   0.0      1.0           0.0              1.0       0.0       0.740959\n",
+ "2  0.0   1.0      0.0           1.0              1.0       0.0       0.374728\n",
+ "3  0.0   0.0      2.0           2.0              1.0       1.0       0.444487\n",
+ "4  0.0   1.0      1.0           1.0              0.0       1.0       0.584908"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df = pd.DataFrame(dataset_orig_test.features, columns=dataset_orig_test.feature_names)\n",
+ "df['observed'] = pd.Series(dataset_orig_test.labels.flatten(), index=df.index)\n",
+ "df['probabilities'] = pd.Series(dataset_bias_test_prob, index=df.index)\n",
+ "df.head()"
+ ]
+ },
{
- "data": {
- "text/plain": [
- "'Our detected priviledged group has a size of 192, we observe 67.71% as the average risk of recidivism, but our model predicts 57.30%'"
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "8cH_NmUuB90L"
+ },
+ "source": [
+ "We'll then create another structured dataset, the classified dataset, by assigning the predicted probabilities to its `scores` attribute."
]
- },
- "execution_count": 29,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "group_obs = temp_df['observed'].mean()\n",
- "group_prob = 1-temp_df['probabilities'].mean()\n",
- "\n",
- "\"Our detected priviledged group has a size of {}, we observe {:.2%} as the average risk of recidivism, but our model predicts {:.2%}\"\\\n",
- ".format(len(temp_df), group_obs, group_prob)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 30,
- "metadata": {},
- "outputs": [
+ },
{
- "data": {
- "text/plain": [
- "'This is a multiplicative increase in the odds by 1.562'"
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "kctU9DX1B90L"
+ },
+ "outputs": [],
+ "source": [
+ "dataset_bias_test = dataset_orig_test.copy()\n",
+ "dataset_bias_test.scores = dataset_bias_test_prob\n",
+ "dataset_bias_test.labels = dataset_orig_test.labels"
]
- },
- "execution_count": 30,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "odds_mul = (group_obs / (1 - group_obs)) / (group_prob /(1 - group_prob))\n",
- "\"This is a multiplicative increase in the odds by {:.3f}\".format(odds_mul)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 31,
- "metadata": {},
- "outputs": [],
- "source": [
- "assert odds_mul > 1"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 32,
- "metadata": {},
- "outputs": [],
- "source": [
- "to_choose = df[unprivileged_subset[0].keys()].isin(unprivileged_subset[0]).all(axis=1)\n",
- "temp_df = df.loc[to_choose]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 33,
- "metadata": {},
- "outputs": [
+ },
{
- "data": {
- "text/plain": [
- "'Our detected unpriviledged group has a size of 169, we observe 33.14% as the average risk of recidivism, but our model predicts 43.65%'"
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "OB_6l47PB90L"
+ },
+ "source": [
+ "### Bias scoring"
]
- },
- "execution_count": 33,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "group_obs = temp_df['observed'].mean()\n",
- "group_prob = 1-temp_df['probabilities'].mean()\n",
- "\n",
- "\"Our detected unpriviledged group has a size of {}, we observe {:.2%} as the average risk of recidivism, but our model predicts {:.2%}\"\\\n",
- ".format(len(temp_df), group_obs, group_prob)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 34,
- "metadata": {},
- "outputs": [
+ },
{
- "data": {
- "text/plain": [
- "'This is a multiplicative decrease in the odds by 0.640'"
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ykJRqEA3B90L"
+ },
+ "source": [
+ "First, we look at the difference between the model's predictions and the actual observations of the favorable label, which in this case is 0. We create a new `test_df` for this computation.\n",
+ "\n",
+ "If the model's average prediction of the favorable label is higher than the observed average, the group is said to be privileged. In the converse case, the group is said to be unprivileged.\n",
+ "\n",
+ "We will then check whether the male and female groups are privileged or unprivileged using the MDSS score."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "PV5F2zMwB90L",
+ "outputId": "32bd9359-b7a1-4ee5-89fe-7f838dd50e15"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "      sex  race  age_cat  priors_count  c_charge_degree  two_year_recid  \\\n",
+ "2479  1.0   1.0      2.0           2.0              1.0             1.0  \n",
+ "3574  1.0   0.0      1.0           0.0              1.0             0.0  \n",
+ "513   0.0   1.0      0.0           1.0              1.0             0.0  \n",
+ "1725  0.0   0.0      2.0           2.0              1.0             1.0  \n",
+ "96    0.0   1.0      1.0           1.0              0.0             1.0  \n",
+ "...   ...   ...      ...           ...              ...             ...  \n",
+ "4931  0.0   1.0      0.0           1.0              1.0             0.0  \n",
+ "3264  0.0   0.0      0.0           0.0              1.0             1.0  \n",
+ "1653  0.0   0.0      1.0           1.0              1.0             0.0  \n",
+ "2607  1.0   1.0      1.0           0.0              1.0             1.0  \n",
+ "2732  0.0   1.0      0.0           2.0              0.0             1.0  \n",
+ "\n",
+ "      model_not_recid  observed_not_recid  \n",
+ "2479         0.552951                 0.0  \n",
+ "3574         0.740959                 1.0  \n",
+ "513          0.374728                 1.0  \n",
+ "1725         0.444487                 0.0  \n",
+ "96           0.584908                 0.0  \n",
+ "...               ...                 ...  \n",
+ "4931         0.374728                 1.0  \n",
+ "3264         0.535753                 0.0  \n",
+ "1653         0.490037                 1.0  \n",
+ "2607         0.769140                 0.0  \n",
+ "2732         0.251726                 0.0  \n",
+ "\n",
+ "[1584 rows x 8 columns]"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "test_df = dataset_bias_test.convert_to_dataframe()[0]\n",
+ "test_df['model_not_recid'] = dataset_bias_test.scores.flatten()\n",
+ "test_df['observed_not_recid'] = 1 - test_df['two_year_recid']\n",
+ "test_df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "qMjfpXO4B90M",
+ "outputId": "78697715-5600-4c49-ee4d-a747ded2606d"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "model_not_recid 0.617561\n",
+ "observed_not_recid 0.657051\n",
+ "dtype: float64"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Females: predicted vs. observed rates of the favorable label (no recidivism)\n",
+ "test_df[test_df.sex == 1][['model_not_recid','observed_not_recid']].mean()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HKaZKOgNB90M"
+ },
+ "source": [
+ "Since the model's average prediction for the favorable label is lower than the observed average by a substantial amount (about 4 percentage points), the female group is most likely unprivileged."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "_tKQJ0FUB90M",
+ "outputId": "1b6bbce7-e1b7-44c9-8803-59241e5cdc27"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "model_not_recid 0.512443\n",
+ "observed_not_recid 0.497642\n",
+ "dtype: float64"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Males: predicted vs. observed rates of the favorable label (no recidivism)\n",
+ "test_df[test_df.sex == 0][['model_not_recid','observed_not_recid']].mean()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "y-wvPZH3B90M"
+ },
+ "source": [
+ "Since the model's average prediction for the favorable label is greater than the observed average by a small amount (about 1.5 percentage points), the male group could be privileged."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "6fQGJP-2B90M"
+ },
+ "source": [
+ "Now, we'll create an instance of the MDSS Classification Metric and assess the a priori defined privileged and unprivileged groups: females and males, respectively.\n",
+ "\n",
+ "By defining the male group a priori as unprivileged, we are saying we expect the model's predictions to be systematically lower than the actual observations.\n",
+ "\n",
+ "By defining the female group a priori as privileged, we are saying we expect the model's predictions to be systematically higher than the actual observations.\n",
+ "\n",
+ "From our mini-analysis above, we know that these hypotheses are unlikely to be true."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "yIqKBRPbB90M"
+ },
+ "outputs": [],
+ "source": [
+ "mdss_classified = MDSSClassificationMetric(dataset_orig_test, dataset_bias_test,\n",
+ " unprivileged_groups=male_group,\n",
+ " privileged_groups=female_group)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "SI_TOyu0B90N",
+ "outputId": "9c118292-8368-4e0b-fdee-5ef094b6e9a1"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "-0.0"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# We are asking the question:\n",
+ "# Is there evidence that the hypothesized privileged group is actually privileged?\n",
+ "\n",
+ "female_privileged_score = mdss_classified.score_groups(privileged=True)\n",
+ "female_privileged_score"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "OOKIOvxUB90N"
+ },
+ "source": [
+ "A score very close to zero tells us that the data provide no evidence for our hypothesis that the female group is privileged."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "B8dsK90PB90N",
+ "outputId": "4cf06af8-6ca3-49d6-a9d7-895e7ce69478"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "-0.0"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# We are asking the question:\n",
+ "# Is there evidence that the hypothesized unprivileged group is actually unprivileged?\n",
+ "\n",
+ "male_unprivileged_score = mdss_classified.score_groups(privileged=False)\n",
+ "male_unprivileged_score"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Anh6JvDAB90N"
+ },
+ "source": [
+ "A score very close to zero tells us that the data provide no evidence for our hypothesis that the male group is unprivileged."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "2QuN7nxrB90N"
+ },
+ "source": [
+ "We can flip our initial hypothesis and check if the male group is privileged or the female group is unprivileged."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Om4DTlYtB90O"
+ },
+ "outputs": [],
+ "source": [
+ "mdss_classified = MDSSClassificationMetric(dataset_orig_test, dataset_bias_test,\n",
+ " unprivileged_groups=female_group,\n",
+ " privileged_groups=male_group)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "baP4caqSB90O",
+ "outputId": "a892284c-1b54-411b-c839-45925cf17bc6"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "0.63"
+ ]
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "male_privileged_score = mdss_classified.score_groups(privileged=True)\n",
+ "male_privileged_score"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "vZWkojHvB90O"
+ },
+ "source": [
+ "A positive score tells us that the data provide evidence for our hypothesis that the male group is privileged."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "bp8kBr9nB90O",
+ "outputId": "d689519f-775d-4a5f-e372-43de18e122ae"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "1.1769"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "female_unprivileged_score = mdss_classified.score_groups(privileged=False)\n",
+ "female_unprivileged_score"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "v7LHafDqB90P"
+ },
+ "source": [
+ "A positive score tells us that the data provide evidence for our hypothesis that the female group is unprivileged."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "fIO7ig11B90P"
+ },
+ "source": [
+ "By taking into account the size of the group and the magnitude of the deviation, the MDSS bias score has told us the following about the male and female groups:\n",
+ "- There is no evidence that the female group is privileged.\n",
+ "- There is no evidence that the male group is unprivileged.\n",
+ "- There is evidence that the male group is privileged.\n",
+ "- There is evidence that the female group is unprivileged."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "XlYgT7krB90Q"
+ },
+ "source": [
+ "### Bias scan\n",
+ "So far we have scored a priori defined subgroups. If we had no prior knowledge\n",
+ "about the predictive bias and wanted to find the subgroups with the most bias, we could apply bias scan to identify the privileged and unprivileged groups. The `overpredicted` argument is not a reference to a group but the direction in which to scan for bias."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "JDyYYrMsB90Q"
+ },
+ "outputs": [],
+ "source": [
+ "privileged_subset = bias_scan(df.iloc[:, :-2], df.observed, df.probabilities,\n",
+ " favorable_value=dataset_orig_test.favorable_label,\n",
+ " penalty=0.5, overpredicted=True)\n",
+ "unprivileged_subset = bias_scan(df.iloc[:, :-2], df.observed, df.probabilities,\n",
+ " favorable_value=dataset_orig_test.favorable_label,\n",
+ " penalty=0.5, overpredicted=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "SmDnrJb_B90Q",
+ "outputId": "325ae712-7d15-4a19-9732-7afe72ab3e41"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "({'race': [0.0], 'age_cat': [0.0], 'sex': [0.0]}, 3.1526)\n",
+ "({'sex': [1.0], 'race': [0.0]}, 3.3036)\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(privileged_subset)\n",
+ "print(unprivileged_subset)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "k3i-WdwLB90Q"
+ },
+ "outputs": [],
+ "source": [
+ "assert privileged_subset[0]\n",
+ "assert unprivileged_subset[0]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "LEkybEddB90Q"
+ },
+ "source": [
+ "We can observe that these bias scores (3.1526 and 3.3036) are higher than the scores of the a priori groups (0.63 and 1.1769). In each scan direction, the returned subgroup is guaranteed to be the highest-scoring subgroup among the exponentially many candidate subgroups.\n",
+ "\n",
+ "For the purposes of this example, the logistic regression model systematically underestimates the recidivism risk of individuals in the `Non-caucasian`, `less than 25`, `Male` subgroup, whereas individuals belonging to the `Non-caucasian`, `Female` subgroup are assigned a higher risk than is actually observed. We refer to these subgroups as the `detected privileged group` and `detected unprivileged group` respectively."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "I_YxruutB90Q"
+ },
+ "source": [
+ "We can create another structured dataset using the new groups to compute other dataset metrics."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "xdVrmsmtB90R"
+ },
+ "outputs": [],
+ "source": [
+ "protected_attr_names = set(privileged_subset[0].keys()).union(set(unprivileged_subset[0].keys()))\n",
+ "dataset_orig_test.protected_attribute_names = list(protected_attr_names)\n",
+ "dataset_bias_test.protected_attribute_names = list(protected_attr_names)\n",
+ "\n",
+ "protected_attr = np.where(np.isin(dataset_orig_test.feature_names, list(protected_attr_names)))[0]\n",
+ "\n",
+ "dataset_orig_test.protected_attributes = dataset_orig_test.features[:, protected_attr]\n",
+ "dataset_bias_test.protected_attributes = dataset_bias_test.features[:, protected_attr]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Q2W84utLB90R"
+ },
+ "outputs": [],
+ "source": [
+ "# converts from dictionary of lists to list of dictionaries\n",
+ "a = list(privileged_subset[0].values())\n",
+ "subset_values = list(itertools.product(*a))\n",
+ "\n",
+ "detected_privileged_groups = []\n",
+ "for vals in subset_values:\n",
+ "    detected_privileged_groups.append(dict(zip(privileged_subset[0].keys(), vals)))\n",
+ "\n",
+ "a = list(unprivileged_subset[0].values())\n",
+ "subset_values = list(itertools.product(*a))\n",
+ "\n",
+ "detected_unprivileged_groups = []\n",
+ "for vals in subset_values:\n",
+ "    detected_unprivileged_groups.append(dict(zip(unprivileged_subset[0].keys(), vals)))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "5CQLNeLMB90R",
+ "outputId": "fdf5c829-9610-43a1-b46c-6d6e7409f52c"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Test set: Difference in mean outcomes between unprivileged and privileged groups = 0.275836\n"
+ ]
+ }
+ ],
+ "source": [
+ "metric_bias_test = BinaryLabelDatasetMetric(dataset_bias_test,\n",
+ " unprivileged_groups=detected_unprivileged_groups,\n",
+ " privileged_groups=detected_privileged_groups)\n",
+ "\n",
+ "print(\"Test set: Difference in mean outcomes between unprivileged and privileged groups = %f\"\n",
+ " % metric_bias_test.mean_difference())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "FpsKHVQOB90R"
+ },
+ "source": [
+ "It appears the detected privileged group has a higher risk of recidivism than the detected unprivileged group."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "aM5Xi7wYB90R"
+ },
+ "source": [
+ "As noted in the paper, predictive bias is different from predictive fairness, so there is no emphasis on the subgroups having comparable predictions.\n",
+ "We can investigate the difference between what the model predicts and what we actually observed, as well as the multiplicative difference in the odds of the subgroups."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-S80czQ6B90R"
+ },
+ "outputs": [],
+ "source": [
+ "to_choose = df[privileged_subset[0].keys()].isin(privileged_subset[0]).all(axis=1)\n",
+ "temp_df = df.loc[to_choose]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Tv-pOahHB90S",
+ "outputId": "b4b11f9c-8ec7-4cff-c4fb-fb1dbc17f383"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'Our detected privileged group has a size of 192, we observe 67.71% as the average risk of recidivism, but our model predicts 57.30%'"
+ ]
+ },
+ "execution_count": 29,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "group_obs = temp_df['observed'].mean()\n",
+ "group_prob = 1 - temp_df['probabilities'].mean()\n",
+ "\n",
+ "\"Our detected privileged group has a size of {}, we observe {:.2%} as the average risk of recidivism, but our model predicts {:.2%}\"\\\n",
+ ".format(len(temp_df), group_obs, group_prob)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "mhzLh_ZWB90S",
+ "outputId": "6507496a-8086-4605-a135-e031dde9de69"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'This is a multiplicative increase in the odds by 1.562'"
+ ]
+ },
+ "execution_count": 30,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "odds_mul = (group_obs / (1 - group_obs)) / (group_prob /(1 - group_prob))\n",
+ "\"This is a multiplicative increase in the odds by {:.3f}\".format(odds_mul)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "4CBlc07aB90S"
+ },
+ "outputs": [],
+ "source": [
+ "assert odds_mul > 1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "u-e4oDAwB90S"
+ },
+ "outputs": [],
+ "source": [
+ "to_choose = df[unprivileged_subset[0].keys()].isin(unprivileged_subset[0]).all(axis=1)\n",
+ "temp_df = df.loc[to_choose]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "hOwgUMAsB90S",
+ "outputId": "d45edf71-6256-4ae8-8c21-60c1a5275de4"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'Our detected unprivileged group has a size of 169, we observe 33.14% as the average risk of recidivism, but our model predicts 43.65%'"
+ ]
+ },
+ "execution_count": 33,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "group_obs = temp_df['observed'].mean()\n",
+ "group_prob = 1 - temp_df['probabilities'].mean()\n",
+ "\n",
+ "\"Our detected unprivileged group has a size of {}, we observe {:.2%} as the average risk of recidivism, but our model predicts {:.2%}\"\\\n",
+ ".format(len(temp_df), group_obs, group_prob)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "DHMHVUyOB90T",
+ "outputId": "705cf23c-7d1c-41e4-a2a7-cb8b39a5b498"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'This is a multiplicative decrease in the odds by 0.640'"
+ ]
+ },
+ "execution_count": 34,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "odds_mul = (group_obs / (1 - group_obs)) / (group_prob /(1 - group_prob))\n",
+ "\"This is a multiplicative decrease in the odds by {:.3f}\".format(odds_mul)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "RComzQQWB90T"
+ },
+ "outputs": [],
+ "source": [
+ "assert odds_mul < 1"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "pak55cFNB90T"
+ },
+ "source": [
+ "In summary, this notebook demonstrates how bias scan uses subset scanning to identify subgroups with significant predictive bias, as quantified by a likelihood ratio score. This allows consideration of not just subgroups of a priori interest or of low dimension, but the space of all possible subgroups of features.\n",
+ "It also presents an opportunity for a bias mitigation technique that uses the multiplicative odds in the over- or under-estimated subgroups to adjust for predictive fairness."
]
- },
- "execution_count": 34,
- "metadata": {},
- "output_type": "execute_result"
}
- ],
- "source": [
- "odds_mul = (group_obs / (1 - group_obs)) / (group_prob /(1 - group_prob))\n",
- "\"This is a multiplicative decrease in the odds by {:.3f}\".format(odds_mul)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 35,
- "metadata": {},
- "outputs": [],
- "source": [
- "assert odds_mul < 1"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "In summary, this notebook demonstrates the use of bias scan to identify subgroups with significant predictive bias, as quantified by a likelihood ratio score, using subset scanning. This allows consideration of not just subgroups of a priori interest or small dimensions, but the space of all possible subgroups of features.\n",
- "It also presents opportunity for a kind of bias mitigation technique that uses the multiplicative odds in the over-or-under estimated subgroups to adjust for predictive fairness."
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3.9.7 ('aif360')",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.7"
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.9.7 ('aif360')",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.7"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "d0c5ced7753e77a483fec8ff7063075635521cce6e0bd54998c8f174742209dd"
+ }
+ },
+ "colab": {
+ "provenance": []
+ }
},
- "vscode": {
- "interpreter": {
- "hash": "d0c5ced7753e77a483fec8ff7063075635521cce6e0bd54998c8f174742209dd"
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/examples/demo_mdss_detector.ipynb b/examples/demo_mdss_detector.ipynb
index 6ce0ed8d..1cf24f3e 100644
--- a/examples/demo_mdss_detector.ipynb
+++ b/examples/demo_mdss_detector.ipynb
@@ -1,1542 +1,1730 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Bias scan using Multi-Dimensional Subset Scan (MDSS)\n",
- "\n",
- "\"Identifying Significant Predictive Bias in Classifiers\" https://arxiv.org/abs/1611.08292\n",
- "\n",
- "The goal of bias scan is to identify a subgroup(s) that has significantly more predictive bias than would be expected from an unbiased classifier. There are $\\prod_{m=1}^{M}\\left(2^{|X_{m}|}-1\\right)$ unique subgroups from a dataset with $M$ features, with each feature having $|X_{m}|$ discretized values, where a subgroup is any $M$-dimension\n",
- "Cartesian set product, between subsets of feature-values from each feature --- excluding the empty set. Bias scan mitigates this computational hurdle by approximately identifing the most statistically biased subgroup in linear time (rather than exponential).\n",
- "\n",
- "\n",
- "We define the statistical measure of predictive bias function, $score_{bias}(S)$ as a likelihood ratio score and a function of a given subgroup $S$. The null hypothesis is that the given prediction's odds are correct for all subgroups in $\\mathcal{D}$:\n",
- "\n",
- "$$H_{0}:odds(y_{i})=\\frac{\\hat{p}_{i}}{1-\\hat{p}_{i}}\\ \\forall i\\in\\mathcal{D}.$$\n",
- "\n",
- "The alternative hypothesis assumes some constant multiplicative bias in the odds for some given subgroup $S$:\n",
- "\n",
- "$$H_{1}:\\ odds(y_{i})=q\\frac{\\hat{p}_{i}}{1-\\hat{p}_{i}},\\ \\text{where}\\ q>1\\ \\forall i\\in S\\ \\mathrm{and}\\ q=1\\ \\forall i\\notin S.$$\n",
- "\n",
- "In the classification setting, each observation's likelihood is Bernoulli distributed and assumed independent. This results in the following scoring function for a subgroup $S$:\n",
- "\n",
- "\\begin{align*}\n",
- "score_{bias}(S)= & \\max_{q}\\log\\prod_{i\\in S}\\frac{Bernoulli(\\frac{q\\hat{p}_{i}}{1-\\hat{p}_{i}+q\\hat{p}_{i}})}{Bernoulli(\\hat{p}_{i})}\\\\\n",
- "= & \\max_{q}\\log(q)\\sum_{i\\in S}y_{i}-\\sum_{i\\in S}\\log(1-\\hat{p}_{i}+q\\hat{p}_{i}).\n",
- "\\end{align*}\n",
- "Our bias scan is thus represented as: $S^{*}=FSS(\\mathcal{D},\\mathcal{E},F_{score})=MDSS(\\mathcal{D},\\hat{p},score_{bias})$.\n",
- "\n",
- "where $S^{*}$ is the detected most anomalous subgroup, $FSS$ is one of several subset scan algorithms for different problem settings, $\\mathcal{D}$ is a dataset with outcomes $Y$ and discretized features $\\mathcal{X}$, $\\mathcal{E}$ are a set of expectations or 'normal' values for $Y$, and $F_{score}$ is an expectation-based scoring statistic that measures the amount of anomalousness between subgroup observations and their expectations.\n",
- "\n",
- "Predictive bias emphasizes comparable predictions for a subgroup and its observations and Bias scan provides a more general method that can detect and characterize such bias, or poor classifier fit, in the larger space of all possible subgroups, without a priori specification."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Usage\n",
- "\n",
- "MDScan currently supports three scoring functions. These scoring functions usage are described below:\n",
- "- *BerkJones*: Non-parametric scoring function. To be used for all of the four types of outcomes supported - binary, continuous, nominal, ordinal.\n",
- "- *Bernoulli*: Parametric scoring function. To used for two of the four types of outcomes supported - binary and nominal.\n",
- "- *Guassian*: Parametric scoring function. To used for one of the four types of outcomes supported - continuous.\n",
- "- *Poisson*: Parametric scoring function. To be used for three of the four types of outcomes supported - binary, continuous, and ordinal.\n",
- "\n",
- "Note, non-parametric scoring functions can only be used for datasets where the expectations are constant or none.\n",
- "\n",
- "The type of outcomes must be provided using the mode keyword argument. The definition for the four types of outcomes supported are provided below:\n",
- "- Binary: Yes/no outcomes. Outcomes must 0 or 1.\n",
- "- Continuous: Continuous outcomes. Outcomes could be any real number.\n",
- "- Nominal: Multiclass outcomes with no rank or order between them. Outcomes must be a finite set of integers with dimensionality <= 10.\n",
- "- Ordinal: Multiclass outcomes that are ranked in a specific order. Outcomes must be positive integers.\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "from aif360.detectors.mdss_detector import bias_scan\n",
- "from aif360.algorithms.preprocessing.optim_preproc_helpers.data_preproc_functions import load_preproc_data_compas\n",
- "\n",
- "import numpy as np\n",
- "import pandas as pd"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We'll demonstrate finding the most anomalous subset with bias scan using the compas dataset. We can specify subgroups to be scored or scan for the most anomalous subgroup. Bias scan allows us to decide if we aim to identify bias as `higher` than expected probabilities or `lower` than expected probabilities."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Compas Dataset\n",
- "This is a binary classification use case where the favorable label is 0 and the scoring function is the default bernoulli."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "np.random.seed(0)\n",
- "\n",
- "dataset_orig = load_preproc_data_compas()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The dataset has the categorical features one-hot encoded so we'll modify the dataset to convert them back \n",
- "to the categorical featues because scanning one-hot encoded features may find subgroups that are not meaningful eg. a subgroup with 2 race values. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [],
- "source": [
- "dataset_orig_df = pd.DataFrame(dataset_orig.features, columns=dataset_orig.feature_names)\n",
- "\n",
- "age_cat = np.argmax(dataset_orig_df[['age_cat=Less than 25', 'age_cat=25 to 45', \n",
- " 'age_cat=Greater than 45']].values, axis=1).reshape(-1, 1)\n",
- "priors_count = np.argmax(dataset_orig_df[['priors_count=0', 'priors_count=1 to 3', \n",
- " 'priors_count=More than 3']].values, axis=1).reshape(-1, 1)\n",
- "c_charge_degree = np.argmax(dataset_orig_df[['c_charge_degree=F', 'c_charge_degree=M']].values, axis=1).reshape(-1, 1)\n",
- "\n",
- "features = np.concatenate((dataset_orig_df[['sex', 'race']].values, age_cat, priors_count, \\\n",
- " c_charge_degree, dataset_orig.labels), axis=1)\n",
- "feature_names = ['sex', 'race', 'age_cat', 'priors_count', 'c_charge_degree']"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [],
- "source": [
- "df = pd.DataFrame(features, columns=feature_names + ['two_year_recid'])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "scrolled": true
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "<div>\n",
- "<table border=\"1\" class=\"dataframe\">\n",
- "  <thead>\n",
- "    <tr style=\"text-align: right;\"><th></th><th>sex</th><th>race</th><th>age_cat</th><th>priors_count</th><th>c_charge_degree</th><th>two_year_recid</th></tr>\n",
- "  </thead>\n",
- "  <tbody>\n",
- "    <tr><th>0</th><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>1.0</td></tr>\n",
- "    <tr><th>1</th><td>0.0</td><td>0.0</td><td>0.0</td><td>2.0</td><td>0.0</td><td>1.0</td></tr>\n",
- "    <tr><th>2</th><td>0.0</td><td>1.0</td><td>1.0</td><td>2.0</td><td>0.0</td><td>1.0</td></tr>\n",
- "    <tr><th>3</th><td>1.0</td><td>1.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>0.0</td></tr>\n",
- "    <tr><th>4</th><td>0.0</td><td>1.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
- "  </tbody>\n",
- "</table>\n",
- "</div>"
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "[Open in Colab](https://colab.research.google.com/github/Trusted-AI/AIF360/blob/main/examples/demo_mdss_detector.ipynb)"
],
- "text/plain": [
- " sex race age_cat priors_count c_charge_degree two_year_recid\n",
- "0 0.0 0.0 1.0 0.0 0.0 1.0\n",
- "1 0.0 0.0 0.0 2.0 0.0 1.0\n",
- "2 0.0 1.0 1.0 2.0 0.0 1.0\n",
- "3 1.0 1.0 1.0 0.0 1.0 0.0\n",
- "4 0.0 1.0 1.0 0.0 0.0 0.0"
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "df.head()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### training\n",
- "We'll train a simple classifier to predict the probability of the outcome"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "LogisticRegression()"
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from sklearn.linear_model import LogisticRegression\n",
- "X = df.drop('two_year_recid', axis = 1)\n",
- "y = df['two_year_recid']\n",
- "clf = LogisticRegression(solver='lbfgs', C=1.0, penalty='l2')\n",
- "clf.fit(X, y)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Note that the probability scores we use are the probabilities of the favorable label, which is 0 in this case."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [],
- "source": [
- "probs = pd.Series(clf.predict_proba(X)[:,0])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### bias scan\n",
- "We can scan for a privileged and unprivileged subset using bias scan"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [],
- "source": [
- "privileged_subset = bias_scan(data=X,observations=y,expectations=probs,favorable_value=0, overpredicted=True)\n",
- "unprivileged_subset = bias_scan(data=X,observations=y,expectations=probs,favorable_value=0,overpredicted=False)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "({'age_cat': [1.0], 'priors_count': [0.0, 1.0, 2.0], 'sex': [1.0], 'race': [1.0], 'c_charge_degree': [0.0]}, 7.9086)\n",
- "({'race': [0.0], 'age_cat': [1.0, 2.0], 'priors_count': [1.0], 'c_charge_degree': [0.0, 1.0]}, 7.0227)\n"
- ]
- }
- ],
- "source": [
- "print(privileged_subset)\n",
- "print(unprivileged_subset)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [],
- "source": [
- "dff = X.copy()\n",
- "dff['observed'] = y \n",
- "dff['probabilities'] = 1 - probs"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [],
- "source": [
- "to_choose = dff[privileged_subset[0].keys()].isin(privileged_subset[0]).all(axis=1)\n",
- "temp_df = dff.loc[to_choose]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "'Our detected priviledged group has a size of 147, we observe 0.5374149659863946 as the average risk of recidivism, but our model predicts 0.38278159716895366'"
- ]
- },
- "execution_count": 12,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "\"Our detected priviledged group has a size of {}, we observe {} as the average risk of recidivism, but our model predicts {}\"\\\n",
- ".format(len(temp_df), temp_df['observed'].mean(), temp_df['probabilities'].mean())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [],
- "source": [
- "to_choose = dff[unprivileged_subset[0].keys()].isin(unprivileged_subset[0]).all(axis=1)\n",
- "temp_df = dff.loc[to_choose]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "'Our detected priviledged group has a size of 732, we observe 0.3770491803278688 as the average risk of recidivism, but our model predicts 0.4447038821779929'"
- ]
- },
- "execution_count": 14,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "\"Our detected priviledged group has a size of {}, we observe {} as the average risk of recidivism, but our model predicts {}\"\\\n",
- ".format(len(temp_df), temp_df['observed'].mean(), temp_df['probabilities'].mean())"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Adult Dataset\n",
- "This is a binary classification use case where the favorable label is 1 and the scoring function is the berk jones."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "
\n",
- "\n",
- "
\n",
- " \n",
- "
\n",
- "
\n",
- "
workclass
\n",
- "
education
\n",
- "
marital_status
\n",
- "
occupation
\n",
- "
relationship
\n",
- "
race
\n",
- "
sex
\n",
- "
native_country
\n",
- "
age_bin
\n",
- "
education_num_bin
\n",
- "
hours_per_week_bin
\n",
- "
capital_gain_bin
\n",
- "
capital_loss_bin
\n",
- "
observed
\n",
- "
expectation
\n",
- "
\n",
- " \n",
- " \n",
- "
\n",
- "
0
\n",
- "
Private
\n",
- "
11th
\n",
- "
Never-married
\n",
- "
Machine-op-inspct
\n",
- "
Own-child
\n",
- "
Black
\n",
- "
Male
\n",
- "
United-States
\n",
- "
17-27
\n",
- "
1-8
\n",
- "
40-44
\n",
- "
0
\n",
- "
0
\n",
- "
0
\n",
- "
0.236226
\n",
- "
\n",
- "
\n",
- "
1
\n",
- "
Private
\n",
- "
HS-grad
\n",
- "
Married-civ-spouse
\n",
- "
Farming-fishing
\n",
- "
Husband
\n",
- "
White
\n",
- "
Male
\n",
- "
United-States
\n",
- "
37-47
\n",
- "
9
\n",
- "
45-99
\n",
- "
0
\n",
- "
0
\n",
- "
0
\n",
- "
0.236226
\n",
- "
\n",
- "
\n",
- "
2
\n",
- "
Local-gov
\n",
- "
Assoc-acdm
\n",
- "
Married-civ-spouse
\n",
- "
Protective-serv
\n",
- "
Husband
\n",
- "
White
\n",
- "
Male
\n",
- "
United-States
\n",
- "
28-36
\n",
- "
12-16
\n",
- "
40-44
\n",
- "
0
\n",
- "
0
\n",
- "
1
\n",
- "
0.236226
\n",
- "
\n",
- "
\n",
- "
3
\n",
- "
Private
\n",
- "
Some-college
\n",
- "
Married-civ-spouse
\n",
- "
Machine-op-inspct
\n",
- "
Husband
\n",
- "
Black
\n",
- "
Male
\n",
- "
United-States
\n",
- "
37-47
\n",
- "
10-11
\n",
- "
40-44
\n",
- "
7298-7978
\n",
- "
0
\n",
- "
1
\n",
- "
0.236226
\n",
- "
\n",
- "
\n",
- "
4
\n",
- "
?
\n",
- "
Some-college
\n",
- "
Never-married
\n",
- "
?
\n",
- "
Own-child
\n",
- "
White
\n",
- "
Female
\n",
- "
United-States
\n",
- "
17-27
\n",
- "
10-11
\n",
- "
1-39
\n",
- "
0
\n",
- "
0
\n",
- "
0
\n",
- "
0.236226
\n",
- "
\n",
- " \n",
- "
\n",
- "
"
+ "metadata": {
+ "id": "cot8Opn7Ck5r"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "xYzhBYSCChiJ"
+ },
+ "source": [
+ "## Bias scan using Multi-Dimensional Subset Scan (MDSS)\n",
+ "\n",
+ "\"Identifying Significant Predictive Bias in Classifiers\" https://arxiv.org/abs/1611.08292\n",
+ "\n",
+ "The goal of bias scan is to identify one or more subgroups that have significantly more predictive bias than would be expected from an unbiased classifier. There are $\\prod_{m=1}^{M}\\left(2^{|X_{m}|}-1\\right)$ unique subgroups from a dataset with $M$ features, with each feature having $|X_{m}|$ discretized values, where a subgroup is any $M$-dimensional\n",
+ "Cartesian set product between subsets of feature-values from each feature, excluding the empty set. Bias scan mitigates this computational hurdle by approximately identifying the most statistically biased subgroup in linear time (rather than exponential).\n",
+ "\n",
+ "\n",
+ "We define the statistical measure of predictive bias, $score_{bias}(S)$, as a likelihood ratio score that is a function of a given subgroup $S$. The null hypothesis is that the predicted odds are correct for every observation in $\\mathcal{D}$:\n",
+ "\n",
+ "$$H_{0}:odds(y_{i})=\\frac{\\hat{p}_{i}}{1-\\hat{p}_{i}}\\ \\forall i\\in\\mathcal{D}.$$\n",
+ "\n",
+ "The alternative hypothesis assumes some constant multiplicative bias in the odds for some given subgroup $S$:\n",
+ "\n",
+ "$$H_{1}:\\ odds(y_{i})=q\\frac{\\hat{p}_{i}}{1-\\hat{p}_{i}},\\ \\text{where}\\ q>1\\ \\forall i\\in S\\ \\mathrm{and}\\ q=1\\ \\forall i\\notin S.$$\n",
+ "\n",
+ "In the classification setting, each observation's likelihood is Bernoulli distributed and assumed independent. This results in the following scoring function for a subgroup $S$:\n",
+ "\n",
+ "\\begin{align*}\n",
+ "score_{bias}(S)= & \\max_{q}\\log\\prod_{i\\in S}\\frac{Bernoulli(\\frac{q\\hat{p}_{i}}{1-\\hat{p}_{i}+q\\hat{p}_{i}})}{Bernoulli(\\hat{p}_{i})}\\\\\n",
+ "= & \\max_{q}\\log(q)\\sum_{i\\in S}y_{i}-\\sum_{i\\in S}\\log(1-\\hat{p}_{i}+q\\hat{p}_{i}).\n",
+ "\\end{align*}\n",
+ "Our bias scan is thus represented as: $S^{*}=FSS(\\mathcal{D},\\mathcal{E},F_{score})=MDSS(\\mathcal{D},\\hat{p},score_{bias})$.\n",
+ "\n",
+ "where $S^{*}$ is the detected most anomalous subgroup, $FSS$ is one of several subset scan algorithms for different problem settings, $\\mathcal{D}$ is a dataset with outcomes $Y$ and discretized features $\\mathcal{X}$, $\\mathcal{E}$ are a set of expectations or 'normal' values for $Y$, and $F_{score}$ is an expectation-based scoring statistic that measures the amount of anomalousness between subgroup observations and their expectations.\n",
+ "\n",
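+ "As a sketch of how the maximization over $q$ can be carried out (a derivation from the score above, offered as an illustration rather than a description of the library's implementation), setting the derivative of the score with respect to $q$ to zero gives\n",
+ "\n",
+ "$$\\frac{1}{q}\\sum_{i\\in S}y_{i}=\\sum_{i\\in S}\\frac{\\hat{p}_{i}}{1-\\hat{p}_{i}+q\\hat{p}_{i}}.$$\n",
+ "\n",
+ "In the special case (an assumption made here for illustration) where every $\\hat{p}_{i}$ in $S$ equals a common value $\\hat{p}$ and the observed rate is $\\bar{y}=\\frac{1}{|S|}\\sum_{i\\in S}y_{i}$ with $0<\\bar{y}<1$, this yields $q=\\frac{\\bar{y}/(1-\\bar{y})}{\\hat{p}/(1-\\hat{p})}$, i.e. the ratio of observed odds to predicted odds.\n",
+ "\n",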
+ "Predictive bias emphasizes comparable predictions for a subgroup and its observations, and bias scan provides a more general method that can detect and characterize such bias, or poor classifier fit, in the larger space of all possible subgroups, without a priori specification."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_StPscBKChiM"
+ },
+ "source": [
+ "### Usage\n",
+ "\n",
+ "MDScan currently supports four scoring functions. Their usage is described below:\n",
+ "- *BerkJones*: Non-parametric scoring function. To be used for all four of the supported outcome types: binary, continuous, nominal, and ordinal.\n",
+ "- *Bernoulli*: Parametric scoring function. To be used for two of the four supported outcome types: binary and nominal.\n",
+ "- *Gaussian*: Parametric scoring function. To be used for one of the four supported outcome types: continuous.\n",
+ "- *Poisson*: Parametric scoring function. To be used for three of the four supported outcome types: binary, continuous, and ordinal.\n",
+ "\n",
+ "Note, non-parametric scoring functions can only be used for datasets where the expectations are constant or none.\n",
+ "\n",
+ "The type of outcomes must be provided using the `mode` keyword argument. The definitions of the four supported outcome types are provided below:\n",
+ "- Binary: Yes/no outcomes. Outcomes must be 0 or 1.\n",
+ "- Continuous: Continuous outcomes. Outcomes could be any real number.\n",
+ "- Nominal: Multiclass outcomes with no rank or order between them. Outcomes must be a finite set of integers with dimensionality <= 10.\n",
+ "- Ordinal: Multiclass outcomes that are ranked in a specific order. Outcomes must be positive integers.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "pEw-BVOaChiM"
+ },
+ "outputs": [],
+ "source": [
+ "from aif360.detectors.mdss_detector import bias_scan\n",
+ "from aif360.algorithms.preprocessing.optim_preproc_helpers.data_preproc_functions import load_preproc_data_compas\n",
+ "\n",
+ "import numpy as np\n",
+ "import pandas as pd"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "WBaUrHfAChiN"
+ },
+ "source": [
+ "We'll demonstrate finding the most anomalous subset with bias scan using the COMPAS dataset. We can specify subgroups to be scored or scan for the most anomalous subgroup. Bias scan lets us choose whether to look for bias in the form of `higher` than expected probabilities or `lower` than expected probabilities."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "b-4Qb66XChiN"
+ },
+ "source": [
+ "# Compas Dataset\n",
+ "This is a binary classification use case where the favorable label is 0 and the scoring function is the default, Bernoulli."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Yzhak7pCChiO"
+ },
+ "outputs": [],
+ "source": [
+ "np.random.seed(0)\n",
+ "\n",
+ "dataset_orig = load_preproc_data_compas()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_aaSrX5ZChiO"
+ },
+ "source": [
+ "The dataset has the categorical features one-hot encoded, so we'll convert them back\n",
+ "to categorical features, because scanning one-hot encoded features may find subgroups that are not meaningful, e.g. a subgroup with 2 race values."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "EsV3GU1TChiO"
+ },
+ "outputs": [],
+ "source": [
+ "dataset_orig_df = pd.DataFrame(dataset_orig.features, columns=dataset_orig.feature_names)\n",
+ "\n",
+ "age_cat = np.argmax(dataset_orig_df[['age_cat=Less than 25', 'age_cat=25 to 45',\n",
+ " 'age_cat=Greater than 45']].values, axis=1).reshape(-1, 1)\n",
+ "priors_count = np.argmax(dataset_orig_df[['priors_count=0', 'priors_count=1 to 3',\n",
+ " 'priors_count=More than 3']].values, axis=1).reshape(-1, 1)\n",
+ "c_charge_degree = np.argmax(dataset_orig_df[['c_charge_degree=F', 'c_charge_degree=M']].values, axis=1).reshape(-1, 1)\n",
+ "\n",
+ "features = np.concatenate((dataset_orig_df[['sex', 'race']].values, age_cat, priors_count, \\\n",
+ " c_charge_degree, dataset_orig.labels), axis=1)\n",
+ "feature_names = ['sex', 'race', 'age_cat', 'priors_count', 'c_charge_degree']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "jWLqSK0pChiO"
+ },
+ "outputs": [],
+ "source": [
+ "df = pd.DataFrame(features, columns=feature_names + ['two_year_recid'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true,
+ "id": "EfTmpJKMChiP",
+ "outputId": "b3287ae8-bee7-4d60-bfac-a201c384849f"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ "  <thead>\n",
+ "    <tr style=\"text-align: right;\"><th></th><th>sex</th><th>race</th><th>age_cat</th><th>priors_count</th><th>c_charge_degree</th><th>two_year_recid</th></tr>\n",
+ "  </thead>\n",
+ "  <tbody>\n",
+ "    <tr><th>0</th><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>1.0</td></tr>\n",
+ "    <tr><th>1</th><td>0.0</td><td>0.0</td><td>0.0</td><td>2.0</td><td>0.0</td><td>1.0</td></tr>\n",
+ "    <tr><th>2</th><td>0.0</td><td>1.0</td><td>1.0</td><td>2.0</td><td>0.0</td><td>1.0</td></tr>\n",
+ "    <tr><th>3</th><td>1.0</td><td>1.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>0.0</td></tr>\n",
+ "    <tr><th>4</th><td>0.0</td><td>1.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
+ "  </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " sex race age_cat priors_count c_charge_degree two_year_recid\n",
+ "0 0.0 0.0 1.0 0.0 0.0 1.0\n",
+ "1 0.0 0.0 0.0 2.0 0.0 1.0\n",
+ "2 0.0 1.0 1.0 2.0 0.0 1.0\n",
+ "3 1.0 1.0 1.0 0.0 1.0 0.0\n",
+ "4 0.0 1.0 1.0 0.0 0.0 0.0"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
],
- "text/plain": [
- " workclass education marital_status occupation \\\n",
- "0 Private 11th Never-married Machine-op-inspct \n",
- "1 Private HS-grad Married-civ-spouse Farming-fishing \n",
- "2 Local-gov Assoc-acdm Married-civ-spouse Protective-serv \n",
- "3 Private Some-college Married-civ-spouse Machine-op-inspct \n",
- "4 ? Some-college Never-married ? \n",
- "\n",
- " relationship race sex native_country age_bin education_num_bin \\\n",
- "0 Own-child Black Male United-States 17-27 1-8 \n",
- "1 Husband White Male United-States 37-47 9 \n",
- "2 Husband White Male United-States 28-36 12-16 \n",
- "3 Husband Black Male United-States 37-47 10-11 \n",
- "4 Own-child White Female United-States 17-27 10-11 \n",
- "\n",
- " hours_per_week_bin capital_gain_bin capital_loss_bin observed expectation \n",
- "0 40-44 0 0 0 0.236226 \n",
- "1 45-99 0 0 0 0.236226 \n",
- "2 40-44 0 0 1 0.236226 \n",
- "3 40-44 7298-7978 0 1 0.236226 \n",
- "4 1-39 0 0 0 0.236226 "
- ]
- },
- "execution_count": 15,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "data = pd.read_csv('https://gist.githubusercontent.com/Viktour19/b690679802c431646d36f7e2dd117b9e/raw/d8f17bf25664bd2d9fa010750b9e451c4155dd61/adult_autostrat.csv')\n",
- "data.head()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Note that for the adult dataset, the positive label is 1 and thus the expectations provided is the probability of the earning >50k i.e label 1 and the favorable label is 1 which is the default for binary classification tasks. Since we would be using scoring function BerkJones, we also need to pass in an alpha value. Alpha can be interpreted as what proportion of the data you expect to have the favorable value"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {},
- "outputs": [],
- "source": [
- "X = data.drop(['observed','expectation'], axis = 1)\n",
- "probs = data['expectation']\n",
- "y = data['observed']"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {},
- "outputs": [],
- "source": [
- "privileged_subset = bias_scan(data=X, observations=y, scoring='BerkJones', expectations=probs, overpredicted=True,penalty=50, alpha = .24)\n",
- "unprivileged_subset = bias_scan(data=X,observations=y, scoring='BerkJones', expectations=probs, overpredicted=False,penalty=50, alpha = .24)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "({'relationship': [' Not-in-family', ' Other-relative', ' Own-child', ' Unmarried'], 'capital_gain_bin': ['0']}, 932.4812)\n",
- "({'education_num_bin': ['12-16'], 'marital_status': [' Married-civ-spouse']}, 1041.1901)\n"
- ]
- }
- ],
- "source": [
- "print(privileged_subset)\n",
- "print(unprivileged_subset)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 19,
- "metadata": {},
- "outputs": [],
- "source": [
- "dff = X.copy()\n",
- "dff['observed'] = y \n",
- "dff['probabilities'] = probs"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 20,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "'Our detected privileged group has a size of 8532, we observe 0.0472 as the average probability of earning >50k, but our model predicts 0.2362'"
- ]
- },
- "execution_count": 20,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "to_choose = dff[privileged_subset[0].keys()].isin(privileged_subset[0]).all(axis=1)\n",
- "temp_df = dff.loc[to_choose]\n",
- "\n",
- "\"Our detected privileged group has a size of {}, we observe {} as the average probability of earning >50k, but our model predicts {}\"\\\n",
- ".format(len(temp_df), np.round(temp_df['observed'].mean(),4), np.round(temp_df['probabilities'].mean(),4))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 21,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "'Our detected unprivileged group has a size of 2430, we observe 0.6996 as the average probability of earning >50k, but our model predicts 0.2362'"
- ]
- },
- "execution_count": 21,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "to_choose = dff[unprivileged_subset[0].keys()].isin(unprivileged_subset[0]).all(axis=1)\n",
- "temp_df = dff.loc[to_choose]\n",
- "\n",
- "\"Our detected unprivileged group has a size of {}, we observe {} as the average probability of earning >50k, but our model predicts {}\"\\\n",
- ".format(len(temp_df), np.round(temp_df['observed'].mean(),4), np.round(temp_df['probabilities'].mean(),4))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Insurance Costs\n",
- "This is a regression use case where the favorable value is 0 and the scoring function is Gaussian."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 22,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "(1338, 7)"
- ]
- },
- "execution_count": 22,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "data = pd.read_csv('https://raw.githubusercontent.com/Adebayo-Oshingbesan/data/main/insurance.csv')\n",
- "data.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 23,
- "metadata": {},
- "outputs": [],
- "source": [
- "for col in ['bmi','age']:\n",
- " data[col] = pd.qcut(data[col], 10, duplicates='drop')\n",
- " data[col] = data[col].apply(lambda x: str(round(x.left, 2)) + ' - ' + str(round(x.right,2)))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 24,
- "metadata": {},
- "outputs": [],
- "source": [
- "features = data.drop('charges', axis = 1)\n",
- "X = features.copy()\n",
- "\n",
- "for feature in X.columns:\n",
- " X[feature] = X[feature].astype('category').cat.codes\n",
- "\n",
- "y = data['charges']"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 25,
- "metadata": {},
- "outputs": [],
- "source": [
- "from sklearn.linear_model import LinearRegression\n",
- "reg = LinearRegression()\n",
- "reg.fit(X, y)\n",
- "y_pred = pd.Series(reg.predict(X))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 26,
- "metadata": {},
- "outputs": [],
- "source": [
- "privileged_subset = bias_scan(data=features, observations=y, expectations=y_pred, scoring = 'Gaussian', \n",
- " overpredicted=True, penalty=1e10, mode ='continuous', favorable_value='low')\n",
- "\n",
- "unprivileged_subset = bias_scan(data=features, observations=y, expectations=y_pred, scoring = 'Gaussian', \n",
- " overpredicted=False, penalty=1e10, mode ='continuous', favorable_value='low')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 27,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "({'bmi': ['15.96 - 22.99', '22.99 - 25.33', '25.33 - 27.36'], 'smoker': ['no']}, 2384.5786)\n",
- "({'bmi': ['15.96 - 22.99', '22.99 - 25.33', '25.33 - 27.36', '27.36 - 28.8'], 'smoker': ['yes']}, 3927.8765)\n"
- ]
- }
- ],
- "source": [
- "print(privileged_subset)\n",
- "print(unprivileged_subset)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 28,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "'Our detected privileged group has a size of 321, we observe 7844.8402958566985 as the mean insurance costs, but our model predicts 5420.493262774548'"
- ]
- },
- "execution_count": 28,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "to_choose = data[privileged_subset[0].keys()].isin(privileged_subset[0]).all(axis=1)\n",
- "temp_df = data.loc[to_choose].copy()\n",
- "temp_y = y_pred.loc[to_choose].copy()\n",
- "\n",
- "\"Our detected privileged group has a size of {}, we observe {} as the mean insurance costs, but our model predicts {}\"\\\n",
- ".format(len(temp_df), temp_df['charges'].mean(), temp_y.mean())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 29,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "'Our detected unprivileged group has a size of 115, we observe 21148.373896173915 as the mean insurance costs, but our model predicts 29694.035319112845'"
- ]
- },
- "execution_count": 29,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "to_choose = data[unprivileged_subset[0].keys()].isin(unprivileged_subset[0]).all(axis=1)\n",
- "temp_df = data.loc[to_choose].copy()\n",
- "temp_y = y_pred.loc[to_choose].copy()\n",
- "\n",
- "\"Our detected unprivileged group has a size of {}, we observe {} as the mean insurance costs, but our model predicts {}\"\\\n",
- ".format(len(temp_df), temp_df['charges'].mean(), temp_y.mean())"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Hospitalization Time\n",
- "This is an ordinal, multiclass classification use case where the favorable value is 1 and the scoring function is Poisson."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 30,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "(29980, 22)"
- ]
- },
- "execution_count": 30,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "data = pd.read_csv('https://raw.githubusercontent.com/Adebayo-Oshingbesan/data/main/hospital.csv')\n",
- "data = data[data['Length of Stay'] != '120 +'].fillna('Unknown')\n",
- "data.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 31,
- "metadata": {},
- "outputs": [],
- "source": [
- "X = data.drop(['Length of Stay'], axis = 1)\n",
- "y = pd.to_numeric(data['Length of Stay'])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 32,
- "metadata": {},
- "outputs": [],
- "source": [
- "privileged_subset = bias_scan(data=X, observations=y, scoring = 'Poisson', favorable_value = 'low', overpredicted=True, penalty=50, mode ='ordinal')\n",
- "unprivileged_subset = bias_scan(data=X, observations=y, scoring = 'Poisson', favorable_value = 'low', overpredicted=False, penalty=50, mode ='ordinal')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 33,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "({'APR Severity of Illness Description': ['Extreme']}, 11180.5386)\n",
- "({'Patient Disposition': ['Home or Self Care', 'Left Against Medical Advice', 'Short-term Hospital'], 'APR Severity of Illness Description': ['Minor', 'Moderate'], 'APR MDC Code': [1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 21]}, 9950.881)\n"
- ]
- }
- ],
- "source": [
- "print(privileged_subset)\n",
- "print(unprivileged_subset)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 34,
- "metadata": {},
- "outputs": [],
- "source": [
- "dff = X.copy()\n",
- "dff['observed'] = y \n",
- "dff['predicted'] = y.mean()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 35,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "'Our detected privileged group has a size of 1900, we observe 15.2216 as the average number of days spent in the hospital, but our model predicts 5.4231'"
- ]
- },
- "execution_count": 35,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "to_choose = dff[privileged_subset[0].keys()].isin(privileged_subset[0]).all(axis=1)\n",
- "temp_df = dff.loc[to_choose]\n",
- "\n",
- "\"Our detected privileged group has a size of {}, we observe {} as the average number of days spent in the hospital, but our model predicts {}\"\\\n",
- ".format(len(temp_df), np.round(temp_df['observed'].mean(),4), np.round(temp_df['predicted'].mean(),4))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 36,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "'Our detected unprivileged group has a size of 14620, we observe 2.8301 as the average number of days spent in the hospital, but our model predicts 5.4231'"
- ]
- },
- "execution_count": 36,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "to_choose = dff[unprivileged_subset[0].keys()].isin(unprivileged_subset[0]).all(axis=1)\n",
- "temp_df = dff.loc[to_choose]\n",
- "\n",
- "\"Our detected unprivileged group has a size of {}, we observe {} as the average number of days spent in the hospital, but our model predicts {}\"\\\n",
- ".format(len(temp_df), np.round(temp_df['observed'].mean(),4), np.round(temp_df['predicted'].mean(),4))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Temperature Dataset\n",
- "This is a regression use case where the favorable value is the higher temperatures and the scoring function is Berk Jones."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 37,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "
",
+ "image/svg+xml": "\n\n\n\n",
+ "image/png": "\n"
+ },
+ "metadata": {
+ "needs_background": "light"
+ }
+ }
+ ],
+ "source": [
+ "fig, ax1 = plt.subplots(figsize=(13,7))\n",
+ "ax1.plot(all_tau, accuracies, color='r')\n",
+ "ax1.set_title(r'Accuracy and $\\gamma_{sr}$ vs Tau', fontsize=16, fontweight='bold')\n",
+ "ax1.set_xlabel('Input Tau', fontsize=16, fontweight='bold')\n",
+ "ax1.set_ylabel('Accuracy', color='r', fontsize=16, fontweight='bold')\n",
+ "ax1.xaxis.set_tick_params(labelsize=14)\n",
+ "ax1.yaxis.set_tick_params(labelsize=14)\n",
+ "\n",
+ "ax2 = ax1.twinx()\n",
+ "ax2.plot(all_tau, statistical_rates, color='b')\n",
+ "ax2.set_ylabel(r'$\\gamma_{sr}$', color='b', fontsize=16, fontweight='bold')\n",
+ "ax2.yaxis.set_tick_params(labelsize=14)\n",
+ "ax2.grid(True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HRHHvcuQC1uS"
+ },
+ "source": [
+ "References:\n",
+ "\n",
+ " Celis, L. E., Huang, L., Keswani, V., & Vishnoi, N. K. (2018).\n",
+ " \"Classification with Fairness Constraints: A Meta-Algorithm with Provable Guarantees.\"\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.6.9 64-bit",
+ "language": "python",
+ "name": "python_defaultSpec_1596663900877"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 2
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython2",
+ "version": "2.7.15"
+ },
+ "colab": {
+ "provenance": []
}
- ],
- "source": [
- "fig, ax1 = plt.subplots(figsize=(13,7))\n",
- "ax1.plot(all_tau, accuracies, color='r')\n",
- "ax1.set_title('Accuracy and $\\gamma_{sr}$ vs Tau', fontsize=16, fontweight='bold')\n",
- "ax1.set_xlabel('Input Tau', fontsize=16, fontweight='bold')\n",
- "ax1.set_ylabel('Accuracy', color='r', fontsize=16, fontweight='bold')\n",
- "ax1.xaxis.set_tick_params(labelsize=14)\n",
- "ax1.yaxis.set_tick_params(labelsize=14)\n",
- "\n",
- "ax2 = ax1.twinx()\n",
- "ax2.plot(all_tau, statistical_rates, color='b')\n",
- "ax2.set_ylabel('$\\gamma_{sr}$', color='b', fontsize=16, fontweight='bold')\n",
- "ax2.yaxis.set_tick_params(labelsize=14)\n",
- "ax2.grid(True)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "References:\n",
- "\n",
- " Celis, L. E., Huang, L., Keswani, V., & Vishnoi, N. K. (2018). \n",
- " \"Classification with Fairness Constraints: A Meta-Algorithm with Provable Guarantees.\"\n"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3.6.9 64-bit",
- "language": "python",
- "name": "python_defaultSpec_1596663900877"
},
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 2
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython2",
- "version": "2.7.15"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
+ "nbformat": 4,
+ "nbformat_minor": 0
}
\ No newline at end of file
diff --git a/examples/demo_ot_metric.ipynb b/examples/demo_ot_metric.ipynb
index 9c8bc904..dd8217d8 100644
--- a/examples/demo_ot_metric.ipynb
+++ b/examples/demo_ot_metric.ipynb
@@ -1,1379 +1,1485 @@
{
- "cells": [
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Measuring bias with Optimal Transport by calculating the Wasserstein distance"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Table of contents\n",
- "\n",
- "- Introduction\n",
- "- General Optimal Transport examples\n",
- "- Usage\n",
- "- Application to Compas Dataset\n",
- "- Application to Adult Dataset\n",
- "- More details\n",
- " - OT for mapping estimation\n",
- " - Kantorovich optimal transport problem\n",
- " - Solving optimal transport\n",
- " - Necessity and priority of usage"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## References\n",
- "\n",
- "\"FlipTest: fairness testing via optimal transport\" https://dl.acm.org/doi/abs/10.1145/3351095.3372845\n",
- "\n",
- "\"Obtaining Fairness using Optimal Transport Theory\" http://proceedings.mlr.press/v97/gordaliza19a.html\n",
- "\n",
- "\"Computational Optimal Transport\" https://arxiv.org/abs/1803.00567\n",
- "\n",
- "\"POT: Python Optimal Transport\" https://jmlr.org/papers/v22/20-451.html"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Introduction\n",
- "\n",
- "Optimal Transport (OT) is a field of mathematics which studies the geometry of probability spaces. Among its many contributions, OT provides a principled way to compare and align probability distributions by taking into account the underlying geometry of the\n",
- "considered metric space.\n",
- "\n",
- "Optimal Transport (OT) is a mathematical problem that was first introduced by Gaspard Monge in 1781. It addresses the task of determining the most efficient method for transporting mass from one distribution to another. In this problem, the cost associated with moving a unit of mass from one position to another is referred to as the ground cost. The primary objective of OT is to minimize the total cost incurred when moving one mass distribution onto another. The optimization problem can be expressed for two distributions $\\mu_s$ and $\\mu_t$ as\n",
- "\n",
- "$$\n",
- "\\min_{m, m_{\\#} \\mu_s=\\mu_t} \\int c(x, m(x)) d \\mu_s(x)\n",
- "$$\n",
- "in the continuous case, and\n",
- "$$\n",
- "\\min_{\\sigma \\in \\text{Perm}(n)} \\frac{1}{n} \\sum_{i=1}^n \\textbf{C}_{i,\\sigma(i)}\n",
- "$$\n",
- "in the discrete case, where $\\textbf{C}_{\\cdot, \\cdot}$ is the ground cost and the constraint $m_{\\#} \\mu_s=\\mu_t$ ensures that $\\mu_s$ is completely transported to $\\mu_t$. Here the push-forward $m_{\\#} \\mu_s$ is defined by $m_{\\#} \\mu_s(B) = \\mu_s(m^{-1}(B)) = \\mu_t(B)$ for every measurable set $B$, with $m$ the transport map from $\\mu_s$ to $\\mu_t$."
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "OT can be used to detect **model-induced bias** by calculating the above cost (also known as **Earth Mover's distance** or **Wasserstein distance**) between the distribution of ground truth labels and model predictions for each of the **protected groups**. If its value is close to 1, the model is **biased** towards this group."
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## General Optimal Transport examples"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Let us start with some simple examples of calculating the Earth Mover's distance between two distributions - the basis of Optimal Transport for bias detection. We do this using the `earth_movers_distance` function.\n",
- "\n",
- "For concrete examples of bias detection on real datasets, skip to the next chapter."
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 1. General Optimal Transport\n",
- "\n",
- "Suppose we have two distributions $a$ and $b$ (as shown in the picture below), and we need to calculate the Wasserstein distance between these two distributions."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "\n",
- "# Initial distribution\n",
- "a = np.array([0., 0.01547988, 0.03095975, 0.04643963, 0.05727554, 0.05417957, 0.04643963, 0.07739938, \n",
- " 0.10835913, 0.12383901, 0.11764706, 0.10526316, 0.09287926, 0.07739938, 0.04643962, 0. ])\n",
- "# Required distribution\n",
- "b = np.array([0., 0.01829787, 0.02702128, 0.04106383, 0.07, 0.10829787, 0.14212766, 0.14468085, \n",
- " 0.13, 0.10808511, 0.08255319, 0.05170213, 0.03361702, 0.02702128, 0.01553191, 0. ])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "[](https://colab.research.google.com/github/Trusted-AI/AIF360/blob/main/examples/demo_ot_metric.ipynb)"
+ ],
+ "metadata": {
+ "id": "vGqYImdfDCyy"
+ }
+ },
{
- "data": {
- "image/png": "",
- "text/plain": [
- "
"
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Fp89L1E4DCXe"
+ },
+ "source": [
+ "# Measuring bias with Optimal Transport by calculating the Wasserstein distance"
]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "import matplotlib.pyplot as plt\n",
- "from scipy.interpolate import make_interp_spline\n",
- "\n",
- "# Drawing both of them\n",
- "figure, axis = plt.subplots(1, 2)\n",
- "figure.set_figheight(4)\n",
- "figure.set_figwidth(12)\n",
- "figure.tight_layout(w_pad = 5)\n",
- "\n",
- "def draw(y, id):\n",
- " x = np.array(range(0, np.size(y)))\n",
- " XYSpline = make_interp_spline(x, y) \n",
- " X = np.linspace(x.min(), x.max(), 500)\n",
- " Y = XYSpline(X)\n",
- " axis[id].bar(x, y, color=\"lightgreen\", ec='black')\n",
- " axis[id].scatter(x, y, color=\"orange\")\n",
- " axis[id].plot(X, Y, color='blue')\n",
- "\n",
- "axis[0].title.set_text(\"Initial distribution\")\n",
- "axis[1].title.set_text(\"Required distribution\")\n",
- "draw(a, 0)\n",
- "draw(b, 1)\n",
- "\n",
- "plt.show()"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "To better understand how Optimal Transport works, the code below considers the case where the cost matrix is given explicitly and defined as the absolute difference between the positions of the bins of the distribution, that is, $\\text{distance}[i][j] = |i - j|$."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [],
- "source": [
- "import pandas as pd\n",
- "\n",
- "_a = pd.Series(a)\n",
- "_b = pd.Series(b)\n",
- "distance = np.zeros((np.size(a), np.size(b)))\n",
- "for i in range(np.size(a)):\n",
- " for j in range(np.size(b)):\n",
- " distance[i][j] = abs(i - j)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
+ },
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Wasserstein distance is equal to 1.3773703499999999.\n"
- ]
- }
- ],
- "source": [
- "from aif360.sklearn.metrics import ot_distance\n",
- "c0 = ot_distance(y_true=_a, y_pred=_b, cost_matrix=distance, mode='continuous')\n",
- "\n",
- "print(\"Wasserstein distance is equal to \", c0, \".\", sep=\"\")"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 2. Randomly distributed samples\n",
- "\n",
- "Suppose we have two distributions $a$ and $b$ with length $N$, that are generated randomly, and we need to calculate earth_movers_distance for them."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "\n",
- "N = 1000\n",
- "np.random.seed(seed=1)\n",
- "\n",
- "# Initial distribution\n",
- "a = np.random.rand(N)\n",
- "a /= np.sum(a)\n",
- "\n",
- "# Required distribution\n",
- "b = np.random.rand(N)\n",
- "b /= np.sum(b)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1Fk66iQtDCXg"
+ },
+ "source": [
+ "## Table of contents\n",
+ "\n",
+ "- Introduction\n",
+ "- General Optimal Transport examples\n",
+ "- Usage\n",
+ "- Application to Compas Dataset\n",
+ "- Application to Adult Dataset\n",
+ "- More details\n",
+ " - OT for mapping estimation\n",
+ " - Kantorovich optimal transport problem\n",
+ " - Solving optimal transport\n",
+ " - Necessity and priority of usage"
+ ]
+ },
{
- "data": {
- "image/png": "",
- "text/plain": [
- "
"
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "l6BtF35fDCXg"
+ },
+ "source": [
+ "## References\n",
+ "\n",
+ "\"FlipTest: fairness testing via optimal transport\" https://dl.acm.org/doi/abs/10.1145/3351095.3372845\n",
+ "\n",
+ "\"Obtaining Fairness using Optimal Transport Theory\" http://proceedings.mlr.press/v97/gordaliza19a.html\n",
+ "\n",
+ "\"Computational Optimal Transport\" https://arxiv.org/abs/1803.00567\n",
+ "\n",
+ "\"POT: Python Optimal Transport\" https://jmlr.org/papers/v22/20-451.html"
]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "import matplotlib.pyplot as plt\n",
- "\n",
- "# Drawing both of them\n",
- "figure, axis = plt.subplots(1, 2)\n",
- "figure.set_figheight(4)\n",
- "figure.set_figwidth(15)\n",
- "figure.tight_layout(w_pad = 5)\n",
- "\n",
- "def draw(y, id):\n",
- " axis[id].hist(y, color='lightgreen', ec='black', bins=10)\n",
- "\n",
- "axis[0].title.set_text(\"Initial distribution\")\n",
- "axis[1].title.set_text(\"Required distribution\")\n",
- "draw(a, 0)\n",
- "draw(b, 1)\n",
- "\n",
- "plt.show()"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "In this case the Wasserstein distance tends to zero as the sample size increases."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [
+ },
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Wasserstein distance is: 2.003382269162742e-05.\n"
- ]
- }
- ],
- "source": [
- "import pandas as pd\n",
- "from aif360.sklearn.metrics import ot_distance\n",
- "\n",
- "_a = pd.Series(a)\n",
- "_b = pd.Series(b)\n",
- "c = ot_distance(y_true=_a, y_pred=_b, mode='continuous')\n",
- "\n",
- "print(\"Wasserstein distance is: \", c, \".\", sep=\"\")"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 3. Permutations\n",
- "\n",
- "The example below shows clearly what the permutations in the discrete formula refer to."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "\n",
- "# Initial distribution\n",
- "a = np.array([0., 0.1, 0.1, 0.1, 0.08, 0., 0.1, 0.1, 0.08, 0.08, 0., 0.1, 0.08, 0.08, 0.08, 0.])\n",
- "# Required distribution\n",
- "b = np.array([0., 0.08, 0.08, 0.08, 0.1, 0., 0.08, 0.08, 0.1, 0.1, 0., 0.08, 0.1, 0.1, 0.1, 0.])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "iy6Q5v0gDCXg"
+ },
+ "source": [
+ "## Introduction\n",
+ "\n",
+ "Optimal Transport (OT) is a field of mathematics which studies the geometry of probability spaces. Among its many contributions, OT provides a principled way to compare and align probability distributions by taking into account the underlying geometry of the\n",
+ "considered metric space.\n",
+ "\n",
+ "Optimal Transport (OT) is a mathematical problem that was first introduced by Gaspard Monge in 1781. It addresses the task of determining the most efficient method for transporting mass from one distribution to another. In this problem, the cost associated with moving a unit of mass from one position to another is referred to as the ground cost. The primary objective of OT is to minimize the total cost incurred when moving one mass distribution onto another. The optimization problem can be expressed for two distributions $\\mu_s$ and $\\mu_t$ as\n",
+ "\n",
+ "$$\n",
+ "\\min_{m, m_{\\#} \\mu_s=\\mu_t} \\int c(x, m(x)) d \\mu_s(x)\n",
+ "$$\n",
+ "in the continuous case, and\n",
+ "$$\n",
+ "\\min_{\\sigma \\in \\text{Perm}(n)} \\frac{1}{n} \\sum_{i=1}^n \\textbf{C}_{i,\\sigma(i)}\n",
+ "$$\n",
+ "in the discrete case, where $\\textbf{C}_{\\cdot, \\cdot}$ is the ground cost and the constraint $m_{\\#} \\mu_s=\\mu_t$ ensures that $\\mu_s$ is completely transported to $\\mu_t$. Here the push-forward $m_{\\#} \\mu_s$ is defined by $m_{\\#} \\mu_s(B) = \\mu_s(m^{-1}(B)) = \\mu_t(B)$ for every measurable set $B$, with $m$ the transport map from $\\mu_s$ to $\\mu_t$."
+ ]
+ },
{
- "data": {
- "image/png": "",
- "text/plain": [
- "
"
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "euPQoGR9DCXh"
+ },
+ "source": [
+ "OT can be used to detect **model-induced bias** by calculating the above cost (also known as **Earth Mover's distance** or **Wasserstein distance**) between the distribution of ground truth labels and model predictions for each of the **protected groups**. If its value is close to 1, the model is **biased** towards this group."
]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "import matplotlib.pyplot as plt\n",
- "\n",
- "# Drawing both of them\n",
- "figure, axis = plt.subplots(1, 2)\n",
- "figure.set_figheight(4)\n",
- "figure.set_figwidth(12)\n",
- "figure.tight_layout(w_pad = 5)\n",
- "\n",
- "def draw(y, id):\n",
- " x = np.array(range(0, np.size(y)))\n",
- " axis[id].bar(x, y, color=\"lightgreen\", ec='black')\n",
- " axis[id].scatter(x, y, color=\"orange\")\n",
- "\n",
- "axis[0].title.set_text(\"Initial distribution\")\n",
- "axis[1].title.set_text(\"Required distribution\")\n",
- "draw(a, 0)\n",
- "draw(b, 1)\n",
- "\n",
- "plt.show()"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Here, since we can go from the initial distribution to the required one using only permutations, the Wasserstein distance is zero."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [
+ },
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "0.0\n"
- ]
- }
- ],
- "source": [
- "import pandas as pd\n",
- "from aif360.sklearn.metrics import ot_distance\n",
- "\n",
- "_a = pd.Series(a)\n",
- "_b = pd.Series(b)\n",
- "c = ot_distance(_a, _b, mode='continuous')\n",
- "\n",
- "print(c)"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 4. Extreme case\n",
- "\n",
- "One more example, closer to our use case, concerns \"normalization\". It explains why the maximum Wasserstein distance we can obtain approaches 1 as the sample size increases, that is, why the distance is normalized. This happens when the whole population has a 2-year recidivism value of 0 (see the \"Compas Dataset\" section) and the classifier fails in every case, labeling everyone with 1. That is the worst-case scenario."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "\n",
- "# Initial distribution\n",
- "a = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.001])\n",
- "# Required distribution\n",
- "b = np.array([0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {},
- "outputs": [
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "e1VAdG5wDCXh"
+ },
+ "source": [
+ "## General Optimal Transport examples"
+ ]
+ },
{
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAABLgAAAGUCAYAAAA285u8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAABVNklEQVR4nO3de1yUdd7/8TdxmNFSSlxBChDdCsxSg1ahEL1NTC076Eq1YaVW3NQqkKVopnlXpLkuax74WZrrr1La0LKNNbGUtZXa5GAnN6slMYM1rEQzOV6/P/wxd+MMyCA6Xszr+Xhcj73nO5/r+ny/w0if+8N18DIMwxAAAAAAAABgUue5ewIAAAAAAADA6aDBBQAAAAAAAFOjwQUAAAAAAABTo8EFAAAAAAAAU6PBBQAAAAAAAFOjwQUAAAAAAABTo8EFAAAAAAAAU6PBBQAAAAAAAFOjwQUAAAAAAABTo8EFeIA1a9bIy8tLu3btatP+Xl5emjdvnu31Z599pnnz5unrr792iL3nnnvUq1evNuU5nX0laejQoRo6dKjt9ddffy0vLy+tWbPGpeO88sorysrKcmkfZ7nmzZsnLy8vVVVVuXSslpypzx4AALSPprqrafPx8VHPnj11++2364svvnD39Gy2b98uLy8vbd++/azk69Wrl+6555427dtSneWKY8eOad68eS6v2VmuXr166cYbb3TpOKfSUg16cj0OwBENLgCnVFhYqClTpthef/bZZ3riiSecNlnmzJmjjRs3nsXZNa9nz54qLCzUmDFjXNqvLQ2utuZylVk+ewAAPN2LL76owsJCbd26VQ899JA2bdqk6667Tj/88IO7pyZJuvrqq1VYWKirr77a3VNpkylTpqiwsNClfY4dO6YnnnjC5QZXW3K1RUs16Mn1OABHPu6eAIBz3+DBg1sd26dPnzM4E9dYLBaX5t4WDQ0Nqq+vPyu5TuVc+uwBAPB0/fr1U3R0tKQTZ5k3NDRo7ty5ev3113Xvvfe6eXZS165dW1W7HDt2TJ07dz4LM3LNJZdcoksuueSM5mha+9nIdSrurjMBM+AMLsBD3XPPPbrgggv05ZdfavTo0brgggsUEhKihx9+WDU1NXaxvzwles2aNfrtb38rSRo2bJjt9PumU8adXSa3bNkyDRkyRD169ND555+vK6+8UgsXLlRdXV2b5m4YhhYuXKiwsDBZrVZdffXV+tvf/uYQ5+x09u+++07333+/QkJCZLFY9Ktf/UrXXnuttm7dKulEAfrWW29p3759dpcX/PJ4Cxcu1JNPPqnw8HBZLBZt27atxcsh9+/fr9tuu01du3aVv7+/7rrrLn333Xd2Mc2ddv7L0/nb8tkfP35cGRkZCg8Pl5+fny6++GI9+OCD+vHHHx3y3Hjjjdq8ebOuvvpqderUSREREVq9enUzPwUAAOCKpmbXf/7zH7vxXbt2aezYserWrZusVqsGDhyoV1991WH/999/X9dee62sVquCg4OVkZGh559/Xl5eXnZndremppCcX6LYVB9+/PHHSkhIUJcuXTR8+HBJUm1trZ588klFRETYaqh7773Xoaapq6vTo48+qqCgIHXu3FnXXXed/vnPf7b6c/r22281YcIEdenSRf7+/kpMTFRlZaVDnLPLBt99910NHTpUAQEB6tSpk0JDQzVu3DgdO3ZMX3/9tX71q19Jkp544glbHdX0mTQdr7i4WOPHj9dFF11k++NhS5dDbty4UVdddZWsVqt69+6tJUuW2L3fdMnqyWffn/z5t1SDSs5/rp988oluvvlmXXTRRbJarRowYID+/Oc/O82zbt06zZ49W8HBweratauuv/56ff75507XBJgVZ3ABHqyurk5jx47V5MmT9fDDD+vvf/+7/ud//kf+/v56/PHHne4zZswYPf3005o1a5aWLVtmO629pbOHvvrqK9155522Jsvu3bv11FNP6V//+lebGihPPPGEnnji
CU2ePFnjx4/X/v37dd9996mhoUGXX355i/smJSWpuLhYTz31lC677DL9+OOPKi4u1qFDhyRJy5cv1/3336+vvvqq2cv9lixZossuu0yLFi1S165ddemll7aY89Zbb9WECROUnJysTz/9VHPmzNFnn32mDz74QL6+vq1et6ufvWEYuuWWW/TOO+8oIyNDcXFx+uijjzR37lwVFhaqsLBQFovFFr979249/PDDmjlzpgIDA/XCCy9o8uTJ+vWvf60hQ4a0ep4AAMBRWVmZJOmyyy6zjW3btk033HCDBg0apOzsbPn7+2v9+vVKTEzUsWPHbM2Xzz77TMOHD1evXr20Zs0ade7cWcuXL9crr7zS7vOsra3V2LFj9cADD2jmzJmqr69XY2Ojbr75Zu3YsUOPPvqoYmNjtW/fPs2dO1dDhw7Vrl271KlTJ0nSfffdp7Vr12r69OkaMWKEPvnkE9122206cuTIKXP//PPPuv766/Xtt98qMzNTl112md566y0lJiaect+vv/5aY8aMUVxcnFavXq0LL7xQBw4c0ObNm1VbW6uePXtq8+bNuuGGGzR58mTb5X5NTa8mt912m26//XYlJyfrp59+ajFnaWmpUlNTNW/ePAUFBenll1/WtGnTVFtbq+nTp59yzr/Umhr0lz7//HPFxsaqR48eWrJkiQICAvTSSy/pnnvu0X/+8x89+uijdvGzZs3StddeqxdeeEHV1dWaMWOGbrrpJu3Zs0fe3t4uzRU4ZxkAOrwXX3zRkGR8+OGHtrG7777bkGS8+uqrdrGjR482Lr/8crsxScbcuXNtr//yl78Ykoxt27Y55Lr77ruNsLCwZufS0NBg1NXVGWvXrjW8vb2N77//vtX7GoZh/PDDD4bVajVuvfVWu/F//OMfhiQjPj7eNlZWVmZIMl588UXb2AUXXGCkpqa2mGPMmDFO59F0vD59+hi1tbVO3/tlrrlz5xqSjLS0NLvYl19+2ZBkvPTSS7axkz/jJmFhYcbdd99te+3KZ79582ZDkrFw4UK7uJycHEOSsXLlSrs8VqvV2Ldvn23s559/Nrp162Y88MADDrkAAIBzTXXX+++/b9TV1RlHjhwxNm/ebAQFBRlDhgwx6urqbLERERHGwIED7cYMwzBuvPFGo2fPnkZDQ4NhGIaRmJhodOrUyaisrLTF1NfXGxEREYYko6yszDbe2ppi27ZtDjVFU324evVqu33XrVtnSDJyc3Ptxj/88ENDkrF8+XLDMAxjz549LdY+v8zvzIoVKwxJxhtvvGE3ft999zVbZzV57bXXDElGaWlps8f/7rvvmv18mo73+OOPN/veL4WFhRleXl4O+UaMGGF07drV+OmnnwzD+N/vwy9/Robh/PNvrgY1DMef6+23325YLBajvLzcLm7UqFFG586djR9//NEuz+jRo+3iXn31VUOSUVhY6DQfYEZcogh4MC8vL9100012Y1dddZX27dvXrnlKSko0duxYBQQEyNvbW76+vpo4caIaGhq0d+9el45VWFio48eP63e/+53deGxsrMLCwk65/29+8xutWbNGTz75pN5///02XSY5duxYl868OnmuEyZMkI+Pj7Zt2+Zyble8++67kuTwxKLf/va3Ov/88/XOO+/YjQ8YMEChoaG211arVZdddlm7fx8AAPAEgwcPlq+vr7p06aIbbrhBF110kd544w35+Jy4iObLL7/Uv/71L1udUF9fb9tGjx6tiooK2yVk27Zt0/DhwxUYGGg7vre3d6vObGqLcePG2b3+61//qgsvvFA33XST3TwHDBigoKAg22V2TbVNc7XPqWzbtk1dunTR2LFj7cbvvPPOU+47YMAA+fn56f7779ef//xn/fvf/z7lPs6cvPaWXHHFFerfv7/d2J133qnq6moVFxe3KX9rvfvuuxo+fLhCQkLsxu+55x4dO3bM4ab4J3+mV111lSRR56FDocEFeLDOnTvLarXajVksFh0/frzdcpSX
lysuLk4HDhzQn/70J+3YsUMffvihli1bJunEqeiuaLqUMCgoyOE9Z2Mny8nJ0d13360XXnhBMTEx6tatmyZOnOj03g7N6dmzZ+sn7GRePj4+CggIsK3lTDl06JB8fHwcTr338vJSUFCQQ/6AgACHY1gsFpd/RgAAQFq7dq0+/PBDvfvuu3rggQe0Z88e3XHHHbb3m+7FNX36dPn6+tptKSkpkqSqqipJJ/6b3tbax1WdO3dW165d7cb+85//6Mcff5Sfn5/DXCsrK+3m6WxeTbXPqRw6dMiuidekNevs06ePtm7dqh49eujBBx9Unz591KdPH/3pT3865b6/5Eqd19LP5GzUec7mGhwc7DT/yZ9/020qqPPQkXAPLgBn1Ouvv66ffvpJGzZssDvDqrS0tE3Ha/qPs7OGVGVlpcNN1k/WvXt3ZWVlKSsrS+Xl5dq0aZNmzpypgwcPavPmza2aQ3M3GW1OZWWlLr74Ytvr+vp6HTp0yK7QsFgsDjf3l06vOAoICFB9fb2+++47uyaXYRiqrKzUNddc0+ZjAwCAlkVGRtpuLD9s2DA1NDTohRde0Guvvabx48ere/fukqSMjAzddtttTo/RdG/RgICAZmufk51uTeGszunevbsCAgKarZW6dOlim2fTvJzVPqcSEBDg9Ib0rf1DZFxcnOLi4tTQ0KBdu3bpueeeU2pqqgIDA3X77be36hiu1Hkt/UyaPoumPyaf/DNpagq2VUBAgCoqKhzGv/32W0myfb8AT8IZXABc5spffJqKhF/ezNwwDD3//PNtyj148GBZrVa9/PLLduM7d+50+RTr0NBQPfTQQxoxYoTdaeTtfdbSyXN99dVXVV9fr6FDh9rGevXqpY8++sgu7t1339XRo0ftxlz57JueevTSSy/Zjefm5uqnn36yvQ8AAM68hQsX6qKLLtLjjz+uxsZGXX755br00ku1e/duRUdHO92aGkfDhg3TO++8Y/cExoaGBuXk5DjkaW1N4Yobb7xRhw4dUkNDg9N5NjXimmqb5mqfUxk2bJiOHDmiTZs22Y27ejN9b29vDRo0yHbFQFOd195nLX366afavXu33dgrr7yiLl262B4G1PTH15N/JievsWl+rZ3b8OHD9e6779oaWk3Wrl2rzp07a/Dgwa1dBtBhcAYXAJf169dPkrRy5Up16dJFVqtV4eHhTk89HzFihPz8/HTHHXfo0Ucf1fHjx7VixQr98MMPbcp90UUXafr06XryySc1ZcoU/fa3v9X+/fttT69pyeHDhzVs2DDdeeedioiIUJcuXfThhx9q8+bNdn85vfLKK7VhwwatWLFCUVFROu+882x/gW2LDRs2yMfHRyNGjLA9RbF///6aMGGCLSYpKUlz5szR448/rvj4eH322WdaunSp/P397Y7l6mc/cuRIzZgxQ9XV1br22mttT1EcOHCgkpKS2rwmAADgmosuukgZGRl69NFH9corr+iuu+7S//k//0ejRo3SyJEjdc899+jiiy/W999/rz179qi4uFh/+ctfJEmPPfaYNm3apP/6r//S448/rs6dO2vZsmVOn/LX2prCFbfffrtefvlljR49WtOmTdNvfvMb+fr66ptvvtG2bdt0880369Zbb1VkZKTuuusuZWVlydfXV9dff70++eQT25OnT2XixIn64x//qIkTJ+qpp57SpZdeqry8PL399tun3Dc7O1vvvvuuxowZo9DQUB0/ftz2tO7rr79e0okzzcLCwvTGG29o+PDh6tatm7p3737KKwCaExwcrLFjx2revHnq2bOnXnrpJeXn52vBggXq3LmzJOmaa67R5ZdfrunTp6u+vl4XXXSRNm7cqPfee8/heK7UoHPnztVf//pXDRs2TI8//ri6deuml19+WW+99ZYWLlx4Wj9vwLTcfZd7AGdec09RPP/88x1inT0lRk6eNpOVlWWEh4cb3t7edk+1cfYkxDfffNPo37+/YbVajYsvvth45JFHjL/9
7W9On9xzqqcoGoZhNDY2GpmZmUZISIjh5+dnXHXVVcabb75pxMfHt/gUxePHjxvJycnGVVddZXTt2tXo1KmTcfnllxtz5861PenGMAzj+++/N8aPH29ceOGFhpeXl+3zaDres88+6zCnlp6iWFRUZNx0003GBRdcYHTp0sW44447jP/85z92+9fU1BiPPvqoERISYnTq1MmIj483SktLHZ545Opn//PPPxszZswwwsLCDF9fX6Nnz57Gf//3fxs//PCDXVxYWJgxZswYh3Wd/JkCAICWOau7mvz8889GaGiocemllxr19fWGYRjG7t27jQkTJhg9evQwfH19jaCgIOO//uu/jOzsbLt9//GPfxiDBw82LBaLERQUZDzyyCPGypUrHZ7Q19qaormnKDqrDw3DMOrq6oxFixbZaroLLrjAiIiIMB544AHjiy++sMv/8MMPGz169DCsVqsxePBgo7Cw0GlN48w333xjjBs3zlY3jRs3zti5c+cpn6JYWFho3HrrrUZYWJhhsViMgIAAIz4+3ti0aZPd8bdu3WoMHDjQsFgsdk92bDred9995zCn5p6iOGbMGOO1114zrrjiCsPPz8/o1auXsXjxYof99+7dayQkJBhdu3Y1fvWrXxm///3vjbfeesvh82+uBjUM5/X4xx9/bNx0002Gv7+/4efnZ/Tv39/uMzKM//05/+Uvf7Ebd1a7AmbnZRiGcfbaaQAAAACA9rBmzRrde++9Kisra/NZSADQUXAPLgAAAAAAAJgaDS4AAAAAAACYGpcoAgAAAAAAwNQ4gwsAAAAAAACmRoMLAAAAAAAApkaDCwAAAAAAAKbm4+4JnEsaGxv17bffqkuXLvLy8nL3dAAAwBlgGIaOHDmi4OBgnXcef+szG+o1AAA6vrbUazS4fuHbb79VSEiIu6cBAADOgv379+uSSy5x9zTgIuo1AAA8hyv1Gg2uX+jSpYukEx9g165d3TwbAABwJlRXVyskJMT2332YC/UaAAAdX1vqNRpcv9B0mnvXrl0pmAAA6OC4vM2cqNcAAPAcrtRr3HgCAAAAAAAApkaDCwAAAAAAAKZGgwsAAAAAAACmRoMLAAAAAAAApkaDCwAAAAAAAKZGgwsAAAAAAACmRoMLAAAAAAAApkaDCwAAAAAAAKbm4+4JAACADq6xQfpuh/RzhdSpp/SrOOk8b3fPCjg97vhee0pOd+UlJznNmpecHSunu/J2gHqtTQ2u5cuX69lnn1VFRYWuuOIKZWVlKS4urtn4goICpaen69NPP1VwcLAeffRRJScn28Xk5uZqzpw5+uqrr9SnTx899dRTuvXWW23v//3vf9ezzz6roqIiVVRUaOPGjbrlllvsjmEYhp544gmtXLlSP/zwgwYNGqRly5bpiiuuaMsyAQDA6dq/QSqaJh375n/HOl8iRf1JCrnNffMCToc7vteektNdeclJTrPmJWfHyumuvB2kXnP5EsWcnBylpqZq9uzZKikpUVxcnEaNGqXy8nKn8WVlZRo9erTi4uJUUlKiWbNmaerUqcrNzbXFFBYWKjExUUlJSdq9e7eSkpI0YcIEffDBB7aYn376Sf3799fSpUubndvChQu1ePFiLV26VB9++KGCgoI0YsQIHTlyxNVlAgCA07V/g7RjvH2xJEnHDpwY37/BPfMCToc7vteektNdeclJTrPmJWfHyumuvB2oXvMyDMNwZYdBgwbp6quv1ooVK2xjkZGRuuWWW5SZmekQP2PGDG3atEl79uyxjSUnJ2v37t0qLCyUJCUmJqq6ulp/+9vfbDE33HCDLrroIq1bt85x0l5eDmdwGYah4OBgpaamasaMGZKkmpoaBQYGasGCBXrggQdOubbq6mr5+/vr8OHD6tq166k/DAAA4Fxjg7Spl2OxZON14i+DY8vO+unv/Pfe3Nz683PH99pTcrorLznJada85OxYOd2Vt4PVay5dolhbW6uioiLNnDnTbjwhIUE7d+50uk9h
YaESEhLsxkaOHKlVq1aprq5Ovr6+KiwsVFpamkNMVlZWq+dWVlamyspKu1wWi0Xx8fHauXOn0wZXTU2NampqbK+rq6tbnQ8AAE9XXl6uqqoqp+9d8NMuXdZssSRJhnRsv/b+Y5WOnh/tNKJ79+4KDQ1th5kCreeO77W7/i01l9cdOdsjr6fkbC4v3yO+u67kbC4v3yO+R67kbCmvO7jU4KqqqlJDQ4MCAwPtxgMDA1VZWel0n8rKSqfx9fX1qqqqUs+ePZuNae6YzeVp2u/k4+zbt8/pPpmZmXriiSdanQMAAJxQXl6uiMgI/XzsZ6fv3x4jrXvo1MeZO+MBrS90/l6nzp30rz3/OmeKJnR87vheu+vfUkt53ZGzPfJ6Sk5nefkeuS9ne+Tle3Ru/Uz5HrU+Z3N53aVNN5n38vKye20YhsPYqeJPHnf1mO0xt4yMDKWnp9teV1dXKyQkxOWcAAB4mqqqKv187Gfd9X/uUuBlgQ7v96nZLx169ZTHGTJ3gi62OP639z97/6OXHnhJVVVV50TBBM/gju+1u/4ttZTXHTlPN6+n5GwuL98j9+U83bx8j869nynfo9blbCmvu7jU4Orevbu8vb0dzqw6ePCgw5lTTYKCgpzG+/j4KCAgoMWY5o7ZXB7pxJlcPXv2bNVxLBaLLBZLq3MAAAB7gZcFKqS/Y8FTZ1ysI9u36ILjP8rZn5kMSUetF6rumsEK8XL5mTfAGeWO77W7/i05y+uOnGc6Lzk943vEd9e8OZvL6yk5z3ReT6nXXJqhn5+foqKilJ+fbzeen5+v2NhYp/vExMQ4xG/ZskXR0dHy9fVtMaa5YzoTHh6uoKAgu+PU1taqoKDApeMAAIDTZ3idp4LIW0/83ye/9///tyDyVhkmKJaAJu74XntKTnflJSc5zZqXnB0rp7vydrR6zeVZpqen64UXXtDq1au1Z88epaWlqby8XMnJyZJOXPY3ceJEW3xycrL27dun9PR07dmzR6tXr9aqVas0ffp0W8y0adO0ZcsWLViwQP/617+0YMECbd26VampqbaYo0ePqrS0VKWlpZJO3FS+tLRU5eXlkk5cmpiamqqnn35aGzdu1CeffKJ77rlHnTt31p133tmWzwYAAJyGr4L6662B9+qo9UK78aPWC/XWwHv1VVB/90wMOA3u+F57Sk535SUnOc2al5wdK6e78nakes3le3AlJibq0KFDmj9/vioqKtSvXz/l5eUpLCxMklRRUWFrOkknzqzKy8tTWlqali1bpuDgYC1ZskTjxo2zxcTGxmr9+vV67LHHNGfOHPXp00c5OTkaNGiQLWbXrl0aNmyY7XXTvbPuvvturVmzRpL06KOP6ueff1ZKSop++OEHDRo0SFu2bFGXLl1cXSYAAGgHXwX1178Dr1Tw91/p/Jpq/WTpqm+79THNXwIBZ9zxvfaUnO7KS05ymjUvOTtWTnfl7Sj1WptuMp+SkqKUlBSn7zU1m34pPj5excXFLR5z/PjxGj9+fLPvDx061HZz+uZ4eXlp3rx5mjdvXotxAADg7DG8ztOBgEvdPQ2gXbnje+0pOd2Vl5zkNGtecnasnO7K2xHqNXO14wAAAAAAAICT0OACAAAAAACAqdHgAgAAgFPLly9XeHi4rFaroqKitGPHjhbjCwoKFBUVJavVqt69eys7O9sh5scff9SDDz6onj17ymq1KjIyUnl5eWdqCQAAwEPQ4AIAAICDnJwcpaamavbs2SopKVFcXJxGjRpl9zChXyorK9Po0aMVFxenkpISzZo1S1OnTlVubq4tpra2ViNGjNDXX3+t1157TZ9//rmef/55XXzxxWdrWQAAoINq003mAQAA0LEtXrxYkydP1pQpUyRJWVlZevvtt7VixQplZmY6xGdnZys0NFRZWVmSpMjISO3atUuLFi2yPT179erV+v7777Vz5075+vpKku1J3AAAAKeDM7gAAABgp7a2VkVFRUpISLAbT0hI0M6dO53u
U1hY6BA/cuRI7dq1S3V1dZKkTZs2KSYmRg8++KACAwPVr18/Pf3002poaGh2LjU1NaqurrbbAAAATkaDCwAAAHaqqqrU0NCgwMBAu/HAwEBVVlY63aeystJpfH19vaqqqiRJ//73v/Xaa6+poaFBeXl5euyxx/SHP/xBTz31VLNzyczMlL+/v20LCQk5zdUBAICOiAYXAAAAnPLy8rJ7bRiGw9ip4n853tjYqB49emjlypWKiorS7bffrtmzZ2vFihXNHjMjI0OHDx+2bfv372/rcgAAQAfGPbgAAABgp3v37vL29nY4W+vgwYMOZ2k1CQoKchrv4+OjgIAASVLPnj3l6+srb29vW0xkZKQqKytVW1srPz8/h+NaLBZZLJbTXRIAAOjgOIMLAAAAdvz8/BQVFaX8/Hy78fz8fMXGxjrdJyYmxiF+y5Ytio6Ott1Q/tprr9WXX36pxsZGW8zevXvVs2dPp80tAACA1qLBBQAAAAfp6el64YUXtHr1au3Zs0dpaWkqLy9XcnKypBOXDk6cONEWn5ycrH379ik9PV179uzR6tWrtWrVKk2fPt0W89///d86dOiQpk2bpr179+qtt97S008/rQcffPCsrw8AAHQsXKIIAAAAB4mJiTp06JDmz5+viooK9evXT3l5eQoLC5MkVVRUqLy83BYfHh6uvLw8paWladmyZQoODtaSJUs0btw4W0xISIi2bNmitLQ0XXXVVbr44os1bdo0zZgx46yvDwAAdCw0uAAAAOBUSkqKUlJSnL63Zs0ah7H4+HgVFxe3eMyYmBi9//777TE9AAAAGy5RBAAAAAAAgKnR4AIAAAAAAICp0eACAAAAAACAqdHgAgAAAAAAgKnR4AIAAAAAAICp0eACAAAAAACAqdHgAgAAAAAAgKnR4AIAAAAAAICp0eACAAAAAACAqdHgAgAAAAAAgKnR4AIAAAAAAICp0eACAAAAAACAqdHgAgAAAAAAgKnR4AIAAAAAAICp0eACAAAAAACAqdHgAgAAAAAAgKnR4AIAAAAAAICp0eACAAAAAACAqdHgAgAAAAAAgKnR4AIAAAAAAICp0eACAAAAAACAqdHgAgAAAAAAgKnR4AIAAAAAAICp0eACAAAAAACAqdHgAgAAAAAAgKnR4AIAAAAAAICp0eACAAAAAACAqdHgAgAAAAAAgKnR4AIAAAAAAICp0eACAAAAAACAqdHgAgAAAAAAgKnR4AIAAAAAAICp0eACAAAAAACAqdHgAgAAAAAAgKnR4AIAAAAAAICp0eACAAAAAACAqdHgAgAAAAAAgKnR4AIAAAAAAICp0eACAAAAAACAqdHgAgAAAAAAgKnR4AIAAAAAAICptanBtXz5coWHh8tqtSoqKko7duxoMb6goEBRUVGyWq3q3bu3srOzHWJyc3PVt29fWSwW9e3bVxs3bnQ579GjR/XQQw/pkksuUadOnRQZGakVK1a0ZYkAAAAer71rvjVr1sjLy8thO378+JlcBgAA8AAuN7hycnKUmpqq2bNnq6SkRHFxcRo1apTKy8udxpeVlWn06NGKi4tTSUmJZs2apalTpyo3N9cWU1hYqMTERCUlJWn37t1KSkrShAkT9MEHH7iUNy0tTZs3b9ZLL72kPXv2KC0tTb///e/1xhtvuLpMAAAAj3Ymaj5J6tq1qyoqKuw2q9V6NpYEAAA6MJcbXIsXL9bkyZM1ZcoURUZGKisrSyEhIc2eKZWdna3Q0FBlZWUpMjJSU6ZM0aRJk7Ro0SJbTFZWlkaMGKGMjAxFREQoIyNDw4cPV1ZWlkt5CwsLdffdd2vo0KHq1auX7r//fvXv31+7du1ydZkAAAAe7UzUfJLk5eWloKAguw0AAOB0udTgqq2tVVFRkRISEuzGExIStHPnTqf7FBYWOsSPHDlSu3btUl1dXYsxTcdsbd7rrrtOmzZt0oEDB2QYhrZt26a9e/dq5MiRTudWU1Oj6upquw0AAMDTnamaTzpx
S4mwsDBdcskluvHGG1VSUtLiXKjXAABAa7jU4KqqqlJDQ4MCAwPtxgMDA1VZWel0n8rKSqfx9fX1qqqqajGm6ZitzbtkyRL17dtXl1xyifz8/HTDDTdo+fLluu6665zOLTMzU/7+/rYtJCSkFZ8CAABAx3amar6IiAitWbNGmzZt0rp162S1WnXttdfqiy++aHYu1GsAAKA12nSTeS8vL7vXhmE4jJ0q/uTx1hzzVDFLlizR+++/r02bNqmoqEh/+MMflJKSoq1btzqdV0ZGhg4fPmzb9u/f3+waAAAAPE1713yDBw/WXXfdpf79+ysuLk6vvvqqLrvsMj333HPNHpN6DQAAtIaPK8Hdu3eXt7e3w1/uDh486PAXuyZBQUFO4318fBQQENBiTNMxW5P3559/1qxZs7Rx40aNGTNGknTVVVeptLRUixYt0vXXX+8wN4vFIovF0trlAwAAeIQzVfOd7LzzztM111zT4hlc1GsAAKA1XDqDy8/PT1FRUcrPz7cbz8/PV2xsrNN9YmJiHOK3bNmi6Oho+fr6thjTdMzW5K2rq1NdXZ3OO89+Sd7e3mpsbHRlmQAAAB7tTNV8JzMMQ6WlperZs2f7TBwAAHgsl87gkqT09HQlJSUpOjpaMTExWrlypcrLy5WcnCzpxGnkBw4c0Nq1ayVJycnJWrp0qdLT03XfffepsLBQq1at0rp162zHnDZtmoYMGaIFCxbo5ptv1htvvKGtW7fqvffea3Xerl27Kj4+Xo888og6deqksLAwFRQUaO3atVq8ePFpfUgAAACe5kzUfE888YQGDx6sSy+9VNXV1VqyZIlKS0u1bNkyt6wRAAB0HC43uBITE3Xo0CHNnz9fFRUV6tevn/Ly8hQWFiZJqqioUHl5uS0+PDxceXl5SktL07JlyxQcHKwlS5Zo3LhxtpjY2FitX79ejz32mObMmaM+ffooJydHgwYNanVeSVq/fr0yMjL0u9/9Tt9//73CwsL01FNP2QoxAAAAtM6ZqPl+/PFH3X///aqsrJS/v78GDhyov//97/rNb35z1tcHAAA6FpcbXJKUkpKilJQUp++tWbPGYSw+Pl7FxcUtHnP8+PEaP358m/NKJ+798OKLL7Z4DAAAALROe9d8f/zjH/XHP/6xvaYHAABg06anKAIAAAAAAADnChpcAAAAAAAAMDUaXAAAAAAAADA1GlwAAAAAAAAwNRpcAAAAAAAAMDUaXAAAAAAAADA1GlwAAAAAAAAwNRpcAAAAAAAAMDUaXAAAAAAAADA1GlwAAAAAAAAwNRpcAAAAAAAAMDUaXAAAAAAAADA1GlwAAAAAAAAwNRpcAAAAAAAAMDUaXAAAAAAAADA1GlwAAAAAAAAwNRpcAAAAAAAAMDUaXAAAAAAAADA1GlwAAAAAAAAwNRpcAAAAAAAAMDUaXAAAAAAAADA1GlwAAAAAAAAwNRpcAAAAAAAAMDUaXAAAAAAAADA1GlwAAAAAAAAwNRpcAAAAAAAAMDUaXAAAAAAAADA1GlwAAAAAAAAwNRpcAAAAAAAAMDUaXAAAAAAAADA1GlwAAAAAAAAwNRpcAAAAAAAAMDUaXAAAAAAAADA1GlwAAAAAAAAwNRpcAAAAAAAAMDUaXAAAAAAAADA1GlwAAAAAAAAwNRpcAAAAAAAAMDUaXAAAAAAAADA1GlwAAAAAAAAwNRpcAAAAAAAAMDUaXAAAAHBq+fLlCg8Pl9VqVVRUlHbs2NFifEFBgaKiomS1WtW7d29lZ2c3G7t+/Xp5eXnplltuaedZAwAAT0SDCwAAAA5ycnKUmpqq2bNnq6SkRHFxcRo1apTKy8udxpeVlWn06NGKi4tTSUmJZs2apalTpyo3N9chdt++fZo+fbri4uLO9DIAAICHoMEFAAAAB4sXL9bkyZM1ZcoURUZGKisrSyEhIVqxYoXT+OzsbIWGhiorK0uRkZGaMmWKJk2apEWLFtnFNTQ06He/+52eeOIJ9e7d+2ws
BQAAeAAaXAAAALBTW1uroqIiJSQk2I0nJCRo586dTvcpLCx0iB85cqR27dqluro629j8+fP1q1/9SpMnT27VXGpqalRdXW23AQAAnIwGFwAAAOxUVVWpoaFBgYGBduOBgYGqrKx0uk9lZaXT+Pr6elVVVUmS/vGPf2jVqlV6/vnnWz2XzMxM+fv727aQkBAXVwMAADwBDS4AAAA45eXlZffaMAyHsVPFN40fOXJEd911l55//nl179691XPIyMjQ4cOHbdv+/ftdWAEAAPAUPu6eAAAAAM4t3bt3l7e3t8PZWgcPHnQ4S6tJUFCQ03gfHx8FBATo008/1ddff62bbrrJ9n5jY6MkycfHR59//rn69OnjcFyLxSKLxXK6SwIAAB0cZ3ABAADAjp+fn6KiopSfn283np+fr9jYWKf7xMTEOMRv2bJF0dHR8vX1VUREhD7++GOVlpbatrFjx2rYsGEqLS3l0kMAAHBaOIMLAAAADtLT05WUlKTo6GjFxMRo5cqVKi8vV3JysqQTlw4eOHBAa9eulSQlJydr6dKlSk9P13333afCwkKtWrVK69atkyRZrVb169fPLseFF14oSQ7jAAAArqLBBQAAAAeJiYk6dOiQ5s+fr4qKCvXr1095eXkKCwuTJFVUVKi8vNwWHx4erry8PKWlpWnZsmUKDg7WkiVLNG7cOHctAQAAeBAaXAAAAHAqJSVFKSkpTt9bs2aNw1h8fLyKi4tbfXxnxwAAAGgL7sEFAAAAAAAAU6PBBQAAAAAAAFNrU4Nr+fLlCg8Pl9VqVVRUlHbs2NFifEFBgaKiomS1WtW7d29lZ2c7xOTm5qpv376yWCzq27evNm7c2Ka8e/bs0dixY+Xv768uXbpo8ODBdveHAAAAAAAAQMficoMrJydHqampmj17tkpKShQXF6dRo0Y120QqKyvT6NGjFRcXp5KSEs2aNUtTp05Vbm6uLaawsFCJiYlKSkrS7t27lZSUpAkTJuiDDz5wKe9XX32l6667ThEREdq+fbt2796tOXPmyGq1urpMAAAAAAAAmITLDa7Fixdr8uTJmjJliiIjI5WVlaWQkBCtWLHCaXx2drZCQ0OVlZWlyMhITZkyRZMmTdKiRYtsMVlZWRoxYoQyMjIUERGhjIwMDR8+XFlZWS7lnT17tkaPHq2FCxdq4MCB6t27t8aMGaMePXq4ukwAAAAAAACYhEsNrtraWhUVFSkhIcFuPCEhQTt37nS6T2FhoUP8yJEjtWvXLtXV1bUY03TM1uRtbGzUW2+9pcsuu0wjR45Ujx49NGjQIL3++uvNrqempkbV1dV2GwAAAAAAAMzFpQZXVVWVGhoaFBgYaDceGBioyspKp/tUVlY6ja+vr1dVVVWLMU3HbE3egwcP6ujRo3rmmWd0ww03aMuWLbr11lt12223qaCgwOncMjMz5e/vb9tCQkJa+UkAAAAAAADgXNGmm8x7eXnZvTYMw2HsVPEnj7fmmC3FNDY2SpJuvvlmpaWlacCAAZo5c6ZuvPFGpze1l6SMjAwdPnzYtu3fv7/ZNQAAAAAAAODc5ONKcPfu3eXt7e1wttbBgwcdzq5qEhQU5DTex8dHAQEBLcY0HbM1ebt37y4fHx/17dvXLiYyMlLvvfee07lZLBZZLJaWlgwAAAAAAIBznEtncPn5+SkqKkr5+fl24/n5+YqNjXW6T0xMjEP8li1bFB0dLV9f3xZjmo7Zmrx+fn665ppr9Pnnn9vF7N27V2FhYa4sEwAAAAAAACbi0hlckpSenq6kpCRFR0crJiZGK1euVHl5uZKTkyWduOzvwIEDWrt2rSQpOTlZS5cuVXp6uu677z4VFhZq1apVWrdune2Y06ZN05AhQ7RgwQLdfPPNeuONN7R161a7M69OlVeSHnnkESUmJmrIkCEaNmyYNm/erDfffFPbt29v6+cDAAAAAACAc5zLDa7ExEQdOnRI8+fPV0VFhfr166e8vDzbWVIVFRUqLy+3
xYeHhysvL09paWlatmyZgoODtWTJEo0bN84WExsbq/Xr1+uxxx7TnDlz1KdPH+Xk5GjQoEGtzitJt956q7Kzs5WZmampU6fq8ssvV25urq677ro2fTgAAAAAAAA497nc4JKklJQUpaSkOH1vzZo1DmPx8fEqLi5u8Zjjx4/X+PHj25y3yaRJkzRp0qQWYwAAAAAAANBxtOkpigAAAAAAAMC5ggYXAAAAAAAATI0GFwAAAAAAAEyNBhcAAAAAAABMjQYXAAAAAAAATI0GFwAAAAAAAEyNBhcAAAAAAABMjQYXAAAAAAAATI0GFwAAAAAAAEyNBhcAAAAAAABMjQYXAAAAAAAATI0GFwAAAAAAAEyNBhcAAAAAAABMjQYXAAAAAAAATI0GFwAAAAAAAEyNBhcAAAAAAABMjQYXAAAAAAAATI0GFwAAAAAAAEyNBhcAAAAAAABMjQYXAAAAAAAATI0GFwAAAAAAAEyNBhcAAAAAAABMjQYXAAAAAAAATI0GFwAAAAAAAEyNBhcAAAAAAABMjQYXAAAAAAAATI0GFwAAAAAAAEyNBhcAAAAAAABMjQYXAAAAnFq+fLnCw8NltVoVFRWlHTt2tBhfUFCgqKgoWa1W9e7dW9nZ2Xbvb9iwQdHR0brwwgt1/vnna8CAAfq///f/nsklAAAAD0GDCwAAAA5ycnKUmpqq2bNnq6SkRHFxcRo1apTKy8udxpeVlWn06NGKi4tTSUmJZs2apalTpyo3N9cW061bN82ePVuFhYX66KOPdO+99+ree+/V22+/fbaWBQAAOigfd08AAAAA557Fixdr8uTJmjJliiQpKytLb7/9tlasWKHMzEyH+OzsbIWGhiorK0uSFBkZqV27dmnRokUaN26cJGno0KF2+0ybNk1//vOf9d5772nkyJFO51FTU6Oamhrb6+rq6nZYHQAA6Gg4gwsAAAB2amtrVVRUpISEBLvxhIQE7dy50+k+hYWFDvEjR47Url27VFdX5xBvGIbeeecdff755xoyZEizc8nMzJS/v79tCwkJacOKAABAR0eDCwAAAHaqqqrU0NCgwMBAu/HAwEBVVlY63aeystJpfH19vaqqqmxjhw8f1gUXXCA/Pz+NGTNGzz33nEaMGNHsXDIyMnT48GHbtn///tNYGQAA6Ki4RBEAAABOeXl52b02DMNh7FTxJ4936dJFpaWlOnr0qN555x2lp6erd+/eDpcvNrFYLLJYLG1cAQAA8BQ0uAAAAGCne/fu8vb2djhb6+DBgw5naTUJCgpyGu/j46OAgADb2Hnnnadf//rXkqQBAwZoz549yszMbLbBBQAA0BpcoggAAAA7fn5+ioqKUn5+vt14fn6+YmNjne4TExPjEL9lyxZFR0fL19e32VyGYdjdRB4AAKAtOIMLAAAADtLT05WUlKTo6GjFxMRo5cqVKi8vV3JysqQT98Y6cOCA1q5dK0lKTk7W0qVLlZ6ervvuu0+FhYVatWqV1q1bZztmZmamoqOj1adPH9XW1iovL09r167VihUr3LJGAADQcdDgAgAAgIPExEQdOnRI8+fPV0VFhfr166e8vDyFhYVJkioqKlReXm6LDw8PV15entLS0rRs2TIFBwdryZIlGjdunC3mp59+UkpKir755ht16tRJEREReumll5SYmHjW1wcAADoWGlwAAABwKiUlRSkpKU7fW7NmjcNYfHy8iouLmz3ek08+qSeffLK9pgcAAGDDPbgAAAAAAABgajS4AAAAAAAAYGo0uAAAAAAAAGBqNLgAAAAAAABgajS4AAAAAAAAYGo0uAAAAAAAAGBqNLgAAAAAAABgajS4AAAAAAAAYGo0uAAAAAAAAGBqNLgAAAAAAABgajS4AAAAAAAAYGo0uAAAAAAAAGBqNLgAAAAAAABgajS4AAAAAAAAYGo0uAAAAAAAAGBqbWpwLV++XOHh4bJarYqKitKOHTtajC8oKFBUVJSsVqt69+6t7Oxsh5jc3Fz1
7dtXFotFffv21caNG08r7wMPPCAvLy9lZWW5vD4AAAAAAACYh8sNrpycHKWmpmr27NkqKSlRXFycRo0apfLycqfxZWVlGj16tOLi4lRSUqJZs2Zp6tSpys3NtcUUFhYqMTFRSUlJ2r17t5KSkjRhwgR98MEHbcr7+uuv64MPPlBwcLCrywMAAAAAAIDJuNzgWrx4sSZPnqwpU6YoMjJSWVlZCgkJ0YoVK5zGZ2dnKzQ0VFlZWYqMjNSUKVM0adIkLVq0yBaTlZWlESNGKCMjQxEREcrIyNDw4cPtzr5qbd4DBw7ooYce0ssvvyxfX98W11JTU6Pq6mq7DQAAAAAAAObiUoOrtrZWRUVFSkhIsBtPSEjQzp07ne5TWFjoED9y5Ejt2rVLdXV1LcY0HbO1eRsbG5WUlKRHHnlEV1xxxSnXk5mZKX9/f9sWEhJyyn0AAAAAAABwbnGpwVVVVaWGhgYFBgbajQcGBqqystLpPpWVlU7j6+vrVVVV1WJM0zFbm3fBggXy8fHR1KlTW7WejIwMHT582Lbt37+/VfsBAAAAAADg3OHTlp28vLzsXhuG4TB2qviTx1tzzJZiioqK9Kc//UnFxcUtzuWXLBaLLBZLq2IBAAAAAABwbnLpDK7u3bvL29vb4WytgwcPOpxd1SQoKMhpvI+PjwICAlqMaTpma/Lu2LFDBw8eVGhoqHx8fOTj46N9+/bp4YcfVq9evVxZJgAAAAAAAEzEpQaXn5+foqKilJ+fbzeen5+v2NhYp/vExMQ4xG/ZskXR0dG2m8A3F9N0zNbkTUpK0kcffaTS0lLbFhwcrEceeURvv/22K8sEAAAAAACAibh8iWJ6erqSkpIUHR2tmJgYrVy5UuXl5UpOTpZ04r5WBw4c0Nq1ayVJycnJWrp0qdLT03XfffepsLBQq1at0rp162zHnDZtmoYMGaIFCxbo5ptv1htvvKGtW7fqvffea3XegIAA2xlhTXx9fRUUFKTLL7/c9U8GAAAAAAAApuBygysxMVGHDh3S/PnzVVFRoX79+ikvL09hYWGSpIqKCpWXl9viw8PDlZeXp7S0NC1btkzBwcFasmSJxo0bZ4uJjY3V+vXr9dhjj2nOnDnq06ePcnJyNGjQoFbnBQAAAAAAgGdq003mU1JSlJKS4vS9NWvWOIzFx8eruLi4xWOOHz9e48ePb3NeZ77++utWxwIAAAAAAMCcXLoHFwAAAAAAAHCuocEFAAAAAAAAU6PBBQAAAAAAAFOjwQUAAAAAAABTo8EFAAAAAAAAU6PBBQAAAAAAAFOjwQUAAAAAAABTo8EFAAAAAAAAU6PBBQAAAAAAAFOjwQUAAAAAAABTo8EFAAAAAAAAU6PBBQAAAAAAAFOjwQUAAAAAAABTo8EFAAAAAAAAU6PBBQAAAAAAAFOjwQUAAACnli9frvDwcFmtVkVFRWnHjh0txhcUFCgqKkpWq1W9e/dWdna23fvPP/+84uLidNFFF+miiy7S9ddfr3/+859ncgkAAMBD0OACAACAg5ycHKWmpmr27NkqKSlRXFycRo0apfLycqfxZWVlGj16tOLi4lRSUqJZs2Zp6tSpys3NtcVs375dd9xxh7Zt26bCwkKFhoYqISFBBw4cOFvLAgAAHRQNLgAAADhYvHixJk+erClTpigyMlJZWVkKCQnRihUrnMZnZ2crNDRUWVlZioyM1JQpUzRp0iQtWrTIFvPyyy8rJSVFAwYMUEREhJ5//nk1NjbqnXfeaXYeNTU1qq6uttsAAABORoMLAAAAdmpra1VUVKSEhAS78YSEBO3cudPpPoWFhQ7xI0eO1K5du1RXV+d0n2PHjqmurk7dunVrdi6ZmZny9/e3bSEhIS6uBgAAeAIaXAAAALBTVVWlhoYGBQYG2o0HBgaqsrLS6T6VlZVO4+vr61VVVeV0n5kzZ+riiy/W9ddf3+xcMjIydPjwYdu2f/9+F1cDAAA8gY+7JwAAAIBzk5eXl91rwzAc
xk4V72xckhYuXKh169Zp+/btslqtzR7TYrHIYrG4Mm0AAOCBaHABAADATvfu3eXt7e1wttbBgwcdztJqEhQU5DTex8dHAQEBduOLFi3S008/ra1bt+qqq65q38kDAACPxCWKAAAAsOPn56eoqCjl5+fbjefn5ys2NtbpPjExMQ7xW7ZsUXR0tHx9fW1jzz77rP7nf/5HmzdvVnR0dPtPHgAAeCQaXAAAAHCQnp6uF154QatXr9aePXuUlpam8vJyJScnSzpxb6yJEyfa4pOTk7Vv3z6lp6drz549Wr16tVatWqXp06fbYhYuXKjHHntMq1evVq9evVRZWanKykodPXr0rK8PAAB0LFyiCAAAAAeJiYk6dOiQ5s+fr4qKCvXr1095eXkKCwuTJFVUVKi8vNwWHx4erry8PKWlpWnZsmUKDg7WkiVLNG7cOFvM8uXLVVtbq/Hjx9vlmjt3rubNm3dW1gUAADomGlwAAABwKiUlRSkpKU7fW7NmjcNYfHy8iouLmz3e119/3U4zAwAAsMcligAAAAAAADA1GlwAAAAAAAAwNRpcAAAAAAAAMDUaXAAAAAAAADA1GlwAAAAAAAAwNRpcAAAAAAAAMDUaXAAAAAAAADA1GlwAAAAAAAAwNRpcAAAAAAAAMDUaXAAAAAAAADA1GlwAAAAAAAAwNRpcAAAAAAAAMDUaXAAAAAAAADA1GlwAAAAAAAAwNRpcAAAAAAAAMDUaXAAAAAAAADA1GlwAAAAAAAAwNRpcAAAAAAAAMDUaXAAAAAAAADA1GlwAAAAAAAAwNRpcAAAAAAAAMDUaXAAAAAAAADA1GlwAAAAAAAAwNRpcAAAAAAAAMDUaXAAAAAAAADA1GlwAAAAAAAAwNRpcAAAAAAAAMDUaXAAAAAAAADC1NjW4li9frvDwcFmtVkVFRWnHjh0txhcUFCgqKkpWq1W9e/dWdna2Q0xubq769u0ri8Wivn37auPGjS7lraur04wZM3TllVfq/PPPV3BwsCZOnKhvv/22LUsEAAAAAACASbjc4MrJyVFqaqpmz56tkpISxcXFadSoUSovL3caX1ZWptGjRysuLk4lJSWaNWuWpk6dqtzcXFtMYWGhEhMTlZSUpN27dyspKUkTJkzQBx980Oq8x44dU3FxsebMmaPi4mJt2LBBe/fu1dixY11dIgAAAAAAAEzE5QbX4sWLNXnyZE2ZMkWRkZHKyspSSEiIVqxY4TQ+OztboaGhysrKUmRkpKZMmaJJkyZp0aJFtpisrCyNGDFCGRkZioiIUEZGhoYPH66srKxW5/X391d+fr4mTJigyy+/XIMHD9Zzzz2noqKiZptvNTU1qq6uttsAAAAAAABgLi41uGpra1VUVKSEhAS78YSEBO3cudPpPoWFhQ7xI0eO1K5du1RXV9diTNMx25JXkg4fPiwvLy9deOGFTt/PzMyUv7+/bQsJCWn2WAAAAAAAADg3udTgqqqqUkNDgwIDA+3GAwMDVVlZ6XSfyspKp/H19fWqqqpqMabpmG3Je/z4cc2cOVN33nmnunbt6jQmIyNDhw8ftm379+9vZuUAAAAAAAA4V/m0ZScvLy+714ZhOIydKv7k8dYcs7V56+rqdPvtt6uxsVHLly9vdl4Wi0UWi6XZ9wEAAAAAAHDuc6nB1b17d3l7ezucNXXw4EGHs6uaBAUFOY338fFRQEBAizFNx3Qlb11dnSZMmKCysjK9++67zZ69BQAAAAAAgI7BpUsU/fz8FBUVpfz8fLvx/Px8xcbGOt0nJibGIX7Lli2Kjo6Wr69vizFNx2xt3qbm1hdffKGtW7faGmgAAAAAAADouFy+RDE9PV1JSUmKjo5WTEyMVq5cqfLyciUnJ0s6cV+rAwcOaO3atZKk5ORkLV26VOnp6brvvvtUWFioVatWad26dbZjTps2TUOGDNGCBQt0880364033tDWrVv13nvvtTpvfX29xo8fr+LiYv31r39VQ0OD7Yyvbt26
yc/Pr+2fEgAAAAAAAM5ZLje4EhMTdejQIc2fP18VFRXq16+f8vLyFBYWJkmqqKhQeXm5LT48PFx5eXlKS0vTsmXLFBwcrCVLlmjcuHG2mNjYWK1fv16PPfaY5syZoz59+ignJ0eDBg1qdd5vvvlGmzZtkiQNGDDAbs7btm3T0KFDXV0qAAAAAAAATKBNN5lPSUlRSkqK0/fWrFnjMBYfH6/i4uIWjzl+/HiNHz++zXl79eplu3k9AAAAAAAAPIdL9+ACAAAAAAAAzjU0uAAAAAAAAGBqNLgAAADg1PLlyxUeHi6r1aqoqCjt2LGjxfiCggJFRUXJarWqd+/eys7Otnv/008/1bhx49SrVy95eXkpKyvrDM4eAAB4EhpcAAAAcJCTk6PU1FTNnj1bJSUliouL06hRo+weJvRLZWVlGj16tOLi4lRSUqJZs2Zp6tSpys3NtcUcO3ZMvXv31jPPPKOgoKCztRQAAOAB2nSTeQAAAHRsixcv1uTJkzVlyhRJUlZWlt5++22tWLFCmZmZDvHZ2dkKDQ21nZUVGRmpXbt2adGiRbanZ19zzTW65pprJEkzZ85s1TxqampUU1Nje11dXX06ywIAAB0UZ3ABAADATm1trYqKipSQkGA3npCQoJ07dzrdp7Cw0CF+5MiR2rVrl+rq6to8l8zMTPn7+9u2kJCQNh8LAAB0XDS4AAAAYKeqqkoNDQ0KDAy0Gw8MDFRlZaXTfSorK53G19fXq6qqqs1zycjI0OHDh23b/v3723wsAADQcXGJIgAAAJzy8vKye20YhsPYqeKdjbvCYrHIYrG0eX8AAOAZOIMLAAAAdrp37y5vb2+Hs7UOHjzocJZWk6CgIKfxPj4+CggIOGNzBQAAkGhwAQAA4CR+fn6KiopSfn6+3Xh+fr5iY2Od7hMTE+MQv2XLFkVHR8vX1/eMzRUAAECiwQUAAAAn0tPT9cILL2j16tXas2eP0tLSVF5eruTkZEkn7o01ceJEW3xycrL27dun9PR07dmzR6tXr9aqVas0ffp0W0xtba1KS0tVWlqq2tpaHThwQKWlpfryyy/P+voAAEDHwj24AAAA4CAxMVGHDh3S/PnzVVFRoX79+ikvL09hYWGSpIqKCpWXl9viw8PDlZeXp7S0NC1btkzBwcFasmSJxo0bZ4v59ttvNXDgQNvrRYsWadGiRYqPj9f27dvP2toAAEDHQ4MLAAAATqWkpCglJcXpe2vWrHEYi4+PV3FxcbPH69Wrl+3G8wAAAO2JSxQBAAAAAABgajS4AAAAAAAAYGo0uAAAAAAAAGBqNLgAAAAAAABgajS4AAAAAAAAYGo0uAAAAAAAAGBqNLgAAAAAAABgajS4AAAAAAAAYGo0uAAAAAAAAGBqNLgAAAAAAABgajS4AAAAAAAAYGo0uAAAAAAAAGBqNLgAAAAAAABgajS4AAAAAAAAYGo0uAAAAAAAAGBqNLgAAAAAAABgajS4AAAAAAAAYGo0uAAAAAAAAGBqNLgAAAAAAABgajS4AAAAAAAAYGo0uAAAAAAAAGBqNLgAAAAAAABgajS4AAAAAAAAYGo0uAAAAAAAAGBqNLgAAAAAAABgajS4AAAAAAAAYGo0uAAAAAAAAGBqNLgAAAAAAABgajS4AAAAAAAAYGo0uAAAAAAAAGBqNLgAAAAAAABgajS4AAAAAAAAYGo0uAAAAAAAAGBqNLgAAAAAAABgajS4AAAAAAAAYGo0uAAAAAAAAGBqNLgAAAAAAABgajS4AAAAAAAAYGo0uAAAAAAAAGBqPu6eAJxobJC+2yH9XCF16in9Kk46z7vj5XRXXnJ2rJzuykvOjpXTXXk9JScAAABwhtHgOtfs3yAVTZOOffO/Y50vkaL+JIXc1nFyuisvOTtWTnflJWfHyumuvJ6SEwAAADgL2nSJ4vLlyxUeHi6r1aqoqCjt2LGjxfiCggJFRUXJarWqd+/eys7OdojJzc1V3759ZbFY1LdvX23cuNHl
vIZhaN68eQoODlanTp00dOhQffrpp21Zonvs3yDtGG///3hI0rEDJ8b3b+gYOd2Vl5wdK6e78pKzY+V0V15PyQnTc1fNBwAA4CqXG1w5OTlKTU3V7NmzVVJSori4OI0aNUrl5eVO48vKyjR69GjFxcWppKREs2bN0tSpU5Wbm2uLKSwsVGJiopKSkrR7924lJSVpwoQJ+uCDD1zKu3DhQi1evFhLly7Vhx9+qKCgII0YMUJHjhxxdZlnX2PDib+qy3Dy5v8fK0o9EWfmnO7KS86OldNdecnZsXK6K6+n5ITpuavmAwAAaAuXL1FcvHixJk+erClTpkiSsrKy9Pbbb2vFihXKzMx0iM/OzlZoaKiysrIkSZGRkdq1a5cWLVqkcePG2Y4xYsQIZWRkSJIyMjJUUFCgrKwsrVu3rlV5DcNQVlaWZs+erdtuO3GZxZ///GcFBgbqlVde0QMPPOAwt5qaGtXU1NheHz58WJJUXV3t6sfSapWVlaqsrHQYv+BYiX5d9Y2TPZoY0rH9+nLLUh3tPNBpRFBQkIKCgs7pnO2R11NyNpfXHTlbysv36NzO2Vxevkftk9dTcjrLe/ToUUnS/t37VfNTjdN9TsfBLw/a8rT3f5ebjmcYzhp+aOKumu9kZ7Nec8f32l3/ls5kXnLyPSLnuZ2zubx8j8jZXnnbQ5vqNcMFNTU1hre3t7Fhwwa78alTpxpDhgxxuk9cXJwxdepUu7ENGzYYPj4+Rm1trWEYhhESEmIsXrzYLmbx4sVGaGhoq/N+9dVXhiSjuLjYLmbs2LHGxIkTnc5t7ty5hk786ZqNjY2NjY3Nw7b9+/e3VPZ4NHfVfM5Qr7GxsbGxsXnu5kq95tIZXFVVVWpoaFBgYKDdeGBgYLN/Na6srHQaX19fr6qqKvXs2bPZmKZjtiZv0/86i9m3b5/TuWVkZCg9Pd32urGxUd9//70CAgLk5eXldJ+zpbq6WiEhIdq/f7+6du3q1rmcaZ6yVtbZ8XjKWj1lnZLnrNXT12kYho4cOaLg4GA3zu7c5q6azxnqNffzlHVKnrNW1tnxeMpaPWWdkuestT3rtTY9RfHkYsIwjBYLDGfxJ4+35pjtFdPEYrHIYrHYjV144YXNrMI9unbt2qG/zL/kKWtlnR2Pp6zVU9Ypec5aPXmd/v7+bpqNubir5vsl6rVzh6esU/KctbLOjsdT1uop65Q8Z63tUa+5dJP57t27y9vb2+GvbAcPHnT4a1yToKAgp/E+Pj4KCAhoMabpmK3J23RfEFfmBgAAAEfuqvkAAADayqUGl5+fn6KiopSfn283np+fr9jYWKf7xMTEOMRv2bJF0dHR8vX1bTGm6ZityRseHq6goCC7mNraWhUUFDQ7NwAAADhyV80HAADQZq2+W9f/t379esPX19dYtWqV8dlnnxmpqanG+eefb3z99deGYRjGzJkzjaSkJFv8v//9b6Nz585GWlqa8dlnnxmrVq0yfH19jddee80W849//MPw9vY2nnnmGWPPnj3GM888Y/j4+Bjvv/9+q/MahmE888wzhr+/v7Fhwwbj448/Nu644w6jZ8+eRnV1tavLdLvjx48bc+fONY4fP+7uqZxxnrJW1tnxeMpaPWWdhuE5a2WdaA131Xxm4infMU9Zp2F4zlpZZ8fjKWv1lHUahuestT3X6XKDyzAMY9myZUZYWJjh5+dnXH311UZBQYHtvbvvvtuIj4+3i9++fbsxcOBAw8/Pz+jVq5exYsUKh2P+5S9/MS6//HLD19fXiIiIMHJzc13KaxiG0djYaMydO9cICgoyLBaLMWTIEOPjjz9uyxIBAAA8nrtqPgAAAFd5Gcb/v/snAAAAAAAAYEIu3YMLAAAAAAAAONfQ4AIAAAAAAICp0eACAAAAAACAqdHgAgAAAAAAgKnR4DpHLV++XOHh4bJarYqKitKOHTvcPaV2lZmZqWuuuUZdunRRjx49dMstt+jzzz9397TOuMzMTHl5eSk1
NdXdUzkjDhw4oLvuuksBAQHq3LmzBgwYoKKiIndPq13V19frscceU3h4uDp16qTevXtr/vz5amxsdPfUTtvf//533XTTTQoODpaXl5def/11u/cNw9C8efMUHBysTp06aejQofr000/dM9nT0NI66+rqNGPGDF155ZU6//zzFRwcrIkTJ+rbb79134RPw6l+pr/0wAMPyMvLS1lZWWdtfu2lNevcs2ePxo4dK39/f3Xp0kWDBw9WeXn52Z8sOoyOXqtJ1GvUa+ZFvUa9ZhaeUqtJZ6deo8F1DsrJyVFqaqpmz56tkpISxcXFadSoUR2qEC8oKNCDDz6o999/X/n5+aqvr1dCQoJ++uknd0/tjPnwww+1cuVKXXXVVe6eyhnxww8/6Nprr5Wvr6/+9re/6bPPPtMf/vAHXXjhhe6eWrtasGCBsrOztXTpUu3Zs0cLFy7Us88+q+eee87dUzttP/30k/r376+lS5c6fX/hwoVavHixli5dqg8//FBBQUEaMWKEjhw5cpZnenpaWuexY8dUXFysOXPmqLi4WBs2bNDevXs1duxYN8z09J3qZ9rk9ddf1wcffKDg4OCzNLP2dap1fvXVV7ruuusUERGh7du3a/fu3ZozZ46sVutZnik6Ck+o1STqtY6Ieo16zSw8pV7zlFpNOkv1moFzzm9+8xsjOTnZbiwiIsKYOXOmm2Z05h08eNCQZBQUFLh7KmfEkSNHjEsvvdTIz8834uPjjWnTprl7Su1uxowZxnXXXefuaZxxY8aMMSZNmmQ3dttttxl33XWXm2Z0ZkgyNm7caHvd2NhoBAUFGc8884xt7Pjx44a/v7+RnZ3thhm2j5PX6cw///lPQ5Kxb9++szOpM6S5tX7zzTfGxRdfbHzyySdGWFiY8cc//vGsz609OVtnYmJih/s3CvfyxFrNMKjXOgLqtY713wLqtf/VEeo1T6nVDOPM1WucwXWOqa2tVVFRkRISEuzGExIStHPnTjfN6sw7fPiwJKlbt25unsmZ8eCDD2rMmDG6/vrr3T2VM2bTpk2Kjo7Wb3/7W/Xo0UMDBw7U888/7+5ptbvrrrtO77zzjvbu3StJ2r17t9577z2NHj3azTM7s8rKylRZWWn3u8lisSg+Pr5D/26STvx+8vLy6nB/3ZakxsZGJSUl6ZFHHtEVV1zh7umcEY2NjXrrrbd02WWXaeTIkerRo4cGDRrU4iUAQEs8tVaTqNc6Auo16rWOqqPWa55Qq0ntV6/R4DrHVFVVqaGhQYGBgXbjgYGBqqysdNOszizDMJSenq7rrrtO/fr1c/d02t369etVVFSkzMxMd0/ljPr3v/+tFStW6NJLL9Xbb7+t5ORkTZ06VWvXrnX31NrVjBkzdMcddygiIkK+vr4aOHCgUlNTdccdd7h7amdU0+8fT/rdJEnHjx/XzJkzdeedd6pr167unk67W7BggXx8fDR16lR3T+WMOXjwoI4ePapnnnlGN9xwg7Zs2aJbb71Vt912mwoKCtw9PZiQJ9ZqEvVaR0G9Rr3WEXXkes0TajWp/eo1nzM4R5wGLy8vu9eGYTiMdRQPPfSQPvroI7333nvunkq7279/v6ZNm6YtW7Z0+Hu9NDY2Kjo6Wk8//bQkaeDAgfr000+1YsUKTZw40c2zaz85OTl66aWX9Morr+iKK65QaWmpUlNTFRwcrLvvvtvd0zvjPOl3U11dnW6//XY1NjZq+fLl7p5OuysqKtKf/vQnFRcXd9ifoSTbDYVvvvlmpaWlSZIGDBignTt3Kjs7W/Hx8e6cHkzMk34fStRrHQX1GvVaR9OR6zVPqdWk9qvXOIPrHNO9e3d5e3s7dNgPHjzo0InvCH7/+99r06ZN2rZtmy655BJ3T6fdFRUV6eDBg4qKipKPj498fHxUUFCgJUuWyMfHRw0NDe6eYrvp2bOn+vbtazcWGRnZ4W64+8gjj2jmzJm6/fbbdeWVVyopKUlpaWkd/i++QUFBkuQxv5vq
6uo0YcIElZWVKT8/v8P9NVCSduzYoYMHDyo0NNT2+2nfvn16+OGH1atXL3dPr910795dPj4+HvH7CWeHp9VqEvUa9Zr5UK95xu+njl6veUqtJrVfvUaD6xzj5+enqKgo5efn243n5+crNjbWTbNqf4Zh6KGHHtKGDRv07rvvKjw83N1TOiOGDx+ujz/+WKWlpbYtOjpav/vd71RaWipvb293T7HdXHvttQ6PDt+7d6/CwsLcNKMz49ixYzrvPPtfnd7e3h3isdMtCQ8PV1BQkN3vptraWhUUFHSo303S/xZLX3zxhbZu3aqAgAB3T+mMSEpK0kcffWT3+yk4OFiPPPKI3n77bXdPr934+fnpmmuu8YjfTzg7PKVWk6jXqNfMi3qNeq0j8JRaTWq/eo1LFM9B6enpSkpKUnR0tGJiYrRy5UqVl5crOTnZ3VNrNw8++KBeeeUVvfHGG+rSpYvtrwz+/v7q1KmTm2fXfrp06eJwn4rzzz9fAQEBHe7+FWlpaYqNjdXTTz+tCRMm6J///KdWrlyplStXuntq7eqmm27SU089pdDQUF1xxRUqKSnR4sWLNWnSJHdP7bQdPXpUX375pe11WVmZSktL1a1bN4WGhio1NVVPP/20Lr30Ul166aV6+umn1blzZ915551unLXrWlpncHCwxo8fr+LiYv31r39VQ0OD7fdTt27d5Ofn565pt8mpfqYnF4O+vr4KCgrS5ZdffranelpOtc5HHnlEiYmJGjJkiIYNG6bNmzfrzTff1Pbt2903aZiaJ9RqEvUa9Zp5Ua9Rr5mFp9Rq0lmq107rGYw4Y5YtW2aEhYUZfn5+xtVXX93hHscsyen24osvuntqZ1xHfey0YRjGm2++afTr18+wWCxGRESEsXLlSndPqd1VV1cb06ZNM0JDQw2r1Wr07t3bmD17tlFTU+PuqZ22bdu2Of13effddxuGceLR03PnzjWCgoIMi8ViDBkyxPj444/dO+k2aGmdZWVlzf5+2rZtm7un7rJT/UxPZtZHT7dmnatWrTJ+/etfG1ar1ejfv7/x+uuvu2/C6BA6eq1mGNRr1GvmRb1GvWYWnlKrGcbZqde8DMMwWt8OAwAAAAAAAM4t3IMLAAAAAAAApkaDCwAAAAAAAKZGgwsAAAAAAACmRoMLAAAAAAAApkaDCwAAAAAAAKZGgwsAAAAAAACmRoMLAAAAAAAApkaDCwAAAAAAAKZGgwsAAAAAAACmRoMLAAAAAAAApkaDCwAAAAAAAKb2/wAst3qZIyioRQAAAABJRU5ErkJggg==",
- "text/plain": [
- "
"
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Jl5mlU65DCXh"
+ },
+ "source": [
+ "Let us start with some simple examples of calculating the Earth Mover's distance between two distributions, which is the basis of Optimal Transport for bias detection. We do this using the `ot_distance` function.\n",
+ "\n",
+ "For concrete examples of bias detection on real datasets, skip to the next chapter."
]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "import matplotlib.pyplot as plt\n",
- "\n",
- "# Drawing both of them\n",
- "figure, axis = plt.subplots(1, 2)\n",
- "figure.set_figheight(4)\n",
- "figure.set_figwidth(12)\n",
- "figure.tight_layout(w_pad = 5)\n",
- "\n",
- "def draw(y, id):\n",
- " x = np.array(range(0, np.size(y)))\n",
- " axis[id].bar(x, y, color=\"lightgreen\", ec='black')\n",
- " axis[id].scatter(x, y, color=\"orange\")\n",
- "\n",
- "axis[0].title.set_text(\"Initial distribution\")\n",
- "axis[1].title.set_text(\"Required distribution\")\n",
- "draw(a, 0)\n",
- "draw(b, 1)\n",
- "\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {},
- "outputs": [
+ },
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "0.9375\n"
- ]
- }
- ],
- "source": [
- "import pandas as pd\n",
- "from aif360.sklearn.metrics import ot_distance\n",
- "\n",
- "_a = pd.Series(a)\n",
- "_b = pd.Series(b)\n",
- "c = ot_distance(_a, _b)\n",
- "\n",
- "print(c)"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Usage\n",
- "\n",
- "The type of outcomes must be provided using the `mode` keyword argument. The definition for the four types of outcomes supported are provided below:\n",
- "- Binary: Yes/no outcomes. Outcomes must 0 or 1.\n",
- "- Continuous: Continuous outcomes. Outcomes could be any real number.\n",
- "- Nominal: Multiclass outcomes with no rank or order between them. Outcomes must be a finite set of integers.\n",
- "- Ordinal: Multiclass outcomes that are ranked in a specific order. Outcomes must be positive integers."
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Compas Dataset"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We'll demonstrate finding the scanning for bias with earth_movers_distance using the Compas dataset. We scan for bias in the predictions of an `sklearn` logistic regression model with respect to different groups."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 23,
- "metadata": {},
- "outputs": [],
- "source": [
- "from aif360.algorithms.preprocessing.optim_preproc_helpers.data_preproc_functions import load_preproc_data_compas\n",
- "\n",
- "import numpy as np\n",
- "import pandas as pd\n",
- "\n",
- "np.random.seed(0)\n",
- "dataset_orig = load_preproc_data_compas()"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We scan for bias at first with respect to `sex`, and then `age`.\n",
- "\n",
- "To scan for bias with respect for a feature that is one-hot encoded - in this case, age category - we need to convert it to nominal or ordinal format."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 24,
- "metadata": {},
- "outputs": [
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Zxqm3r4bDCXh"
+ },
+ "source": [
+ "### 1. General Optimal Transport\n",
+ "\n",
+ "Suppose we have two distributions $a$ and $b$ (as shown in the picture below), and we need to calculate the Wasserstein distance between these two distributions."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "mf1QLAOSDCXh"
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "# Initial distribution\n",
+ "a = np.array([0., 0.01547988, 0.03095975, 0.04643963, 0.05727554, 0.05417957, 0.04643963, 0.07739938,\n",
+ " 0.10835913, 0.12383901, 0.11764706, 0.10526316, 0.09287926, 0.07739938, 0.04643962, 0. ])\n",
+ "# Required distribution\n",
+ "b = np.array([0., 0.01829787, 0.02702128, 0.04106383, 0.07, 0.10829787, 0.14212766, 0.14468085,\n",
+ " 0.13, 0.10808511, 0.08255319, 0.05170213, 0.03361702, 0.02702128, 0.01553191, 0. ])"
+ ]
+ },
{
- "data": {
- "text/html": [
- "
"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
],
- "text/plain": [
- " sex race priors_count=0 priors_count=1 to 3 priors_count=More than 3 \\\n",
- "0 0.0 0.0 1.0 0.0 0.0 \n",
- "1 0.0 0.0 0.0 0.0 1.0 \n",
- "2 0.0 1.0 0.0 0.0 1.0 \n",
- "3 1.0 1.0 1.0 0.0 0.0 \n",
- "4 0.0 1.0 1.0 0.0 0.0 \n",
- "\n",
- " c_charge_degree=F c_charge_degree=M age_cat two_year_recid \n",
- "0 1.0 0.0 1 1.0 \n",
- "1 1.0 0.0 0 1.0 \n",
- "2 1.0 0.0 1 1.0 \n",
- "3 0.0 1.0 1 0.0 \n",
- "4 1.0 0.0 1 0.0 "
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "from scipy.interpolate import make_interp_spline\n",
+ "\n",
+ "# Drawing both of them\n",
+ "figure, axis = plt.subplots(1, 2)\n",
+ "figure.set_figheight(4)\n",
+ "figure.set_figwidth(12)\n",
+ "figure.tight_layout(w_pad = 5)\n",
+ "\n",
+ "def draw(y, id):\n",
+ " x = np.array(range(0, np.size(y)))\n",
+ " XYSpline = make_interp_spline(x, y)\n",
+ " X = np.linspace(x.min(), x.max(), 500)\n",
+ " Y = XYSpline(X)\n",
+ " axis[id].bar(x, y, color=\"lightgreen\", ec='black')\n",
+ " axis[id].scatter(x, y, color=\"orange\")\n",
+ " axis[id].plot(X, Y, color='blue')\n",
+ "\n",
+ "axis[0].title.set_text(\"Initial distribution\")\n",
+ "axis[1].title.set_text(\"Required distribution\")\n",
+ "draw(a, 0)\n",
+ "draw(b, 1)\n",
+ "\n",
+ "plt.show()"
]
- },
- "execution_count": 24,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "dataset_orig_df = pd.DataFrame(dataset_orig.features, columns=dataset_orig.feature_names)\n",
- "# Binning the features corresponding to age ('reshaping' them into one ordinal column)\n",
- "age_cat_cols = ['age_cat=Less than 25', 'age_cat=25 to 45', 'age_cat=Greater than 45']\n",
- "age_cat = np.argmax(dataset_orig_df[age_cat_cols].values, axis=1).reshape(-1, 1)\n",
- "df = dataset_orig_df.drop(age_cat_cols, axis=1)\n",
- "df['age_cat'] = age_cat\n",
- "df['two_year_recid'] = dataset_orig.labels\n",
- "df.head()"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Measuring bias with respect to `sex`\n",
- "\n",
- "\n",
- "We train a linear regression model on the dataset, and scan its results for bias with respect to `sex` using `earth_movers_distance`.\n",
- "\n",
- "The arguments are as follows:\n",
- "- `ground_truth`: ground truth labels;\n",
- "- `classifier`: predicted labels;\n",
- "- `prot_attr`: the values of the sensitive attributes (with respect to which the classifier may be introducing bias);\n",
- "- `num_iters`: maximum number of iterations performed when calculating the Earth Mover's Distance;\n",
- "- `mode`: mode of the labels, one of binary, nominal, ordinal and continious; in our case the labels are binary."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 25,
- "metadata": {},
- "outputs": [],
- "source": [
- "from aif360.sklearn.metrics import ot_distance\n",
- "from sklearn.linear_model import LogisticRegression\n",
- "from sklearn.model_selection import train_test_split\n",
- "\n",
- "X = df.drop('two_year_recid', axis=1)\n",
- "y = df['two_year_recid']\n",
- "clf = LogisticRegression(solver='lbfgs', max_iter=10000, C=1.0, penalty='l2')\n",
- "clf.fit(X, y)\n",
- "preds = pd.Series(clf.predict_proba(X)[:,0])\n",
- "\n",
- "ot_val1 = ot_distance(y_true=y, y_pred=preds, prot_attr=df['sex'])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 26,
- "metadata": {},
- "outputs": [
+ },
{
- "data": {
- "text/html": [
- "
\n",
- "\n",
- "
\n",
- " \n",
- "
\n",
- "
\n",
- "
sex
\n",
- "
ot_val
\n",
- "
\n",
- " \n",
- " \n",
- "
\n",
- "
0
\n",
- "
0.0
\n",
- "
0.000209
\n",
- "
\n",
- "
\n",
- "
1
\n",
- "
1.0
\n",
- "
0.001647
\n",
- "
\n",
- " \n",
- "
\n",
- "
"
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_w9qXdzUDCXj"
+ },
+ "source": [
+ "To better understand how Optimal Transport works, the code below handles the case where the cost matrix is given explicitly, defined as the absolute difference between the positions of the bins of the two distributions. That is, $\\text{distance}[i][j] = |i - j|$."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "3Y0nUhvcDCXj"
+ },
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "\n",
+ "_a = pd.Series(a)\n",
+ "_b = pd.Series(b)\n",
+ "distance = np.zeros((np.size(a), np.size(b)))\n",
+ "for i in range(np.size(a)):\n",
+ " for j in range(np.size(b)):\n",
+ " distance[i][j] = abs(i - j)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "P2i9hj8VDCXj",
+ "outputId": "d2affef6-0656-40f3-f5ad-1be39352a547"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Wasserstein distance is equal to 1.3773703499999999.\n"
+ ]
+ }
],
- "text/plain": [
- " sex ot_val\n",
- "0 0.0 0.000209\n",
- "1 1.0 0.001647"
+ "source": [
+ "from aif360.sklearn.metrics import ot_distance\n",
+ "c0 = ot_distance(y_true=_a, y_pred=_b, cost_matrix=distance, mode='continuous')\n",
+ "\n",
+ "print(\"Wasserstein distance is equal to \", c0, \".\", sep=\"\")"
]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "bs1 = pd.DataFrame({\"sex\": ot_val1.keys(), \"ot_val\": ot_val1.values()})\n",
- "display(bs1)"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We get the bias value for each each of the protected groups - in this case, Male (`0`) and Female (`1`). \n",
- "\n",
- "These values range from 0 to 1 and can be interpreted as the difference in percent between the ground truth distribution and the distribution of the protected group: for example, a value of 0.3 would mean a 30% difference."
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Measuring bias with respect to `age_cat`\n",
- "\n",
- "Now we measure the bias of the same classifier with respect to the age category."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 27,
- "metadata": {},
- "outputs": [],
- "source": [
- "ot_val2 = ot_distance(y_true=y, y_pred=preds, prot_attr=df['age_cat'])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 28,
- "metadata": {},
- "outputs": [
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "iXt221UKDCXk"
+ },
+ "source": [
+ "### 2. Randomly distributed samples\n",
+ "\n",
+ "Suppose we have two randomly generated distributions $a$ and $b$ of length $N$, and we need to calculate the earth mover's distance between them."
+ ]
+ },
{
- "data": {
- "text/html": [
- "
"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
],
- "text/plain": [
- " sex ot_val\n",
- "0 0.0 0.000503\n",
- "1 1.0 0.000067"
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "# Drawing both of them\n",
+ "figure, axis = plt.subplots(1, 2)\n",
+ "figure.set_figheight(4)\n",
+ "figure.set_figwidth(12)\n",
+ "figure.tight_layout(w_pad = 5)\n",
+ "\n",
+ "def draw(y, id):\n",
+ " x = np.array(range(0, np.size(y)))\n",
+ " axis[id].bar(x, y, color=\"lightgreen\", ec='black')\n",
+ " axis[id].scatter(x, y, color=\"orange\")\n",
+ "\n",
+ "axis[0].title.set_text(\"Initial distribution\")\n",
+ "axis[1].title.set_text(\"Required distribution\")\n",
+ "draw(a, 0)\n",
+ "draw(b, 1)\n",
+ "\n",
+ "plt.show()"
]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "bs1 = pd.DataFrame({\"sex\": ot_val1.keys(), \"ot_val\": ot_val1.values()})\n",
- "display(bs1)"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Measuring bias with respect to `race`"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 33,
- "metadata": {},
- "outputs": [],
- "source": [
- "ot_val2 = ot_distance(y_true=y, y_pred=preds, prot_attr=data['race'])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 34,
- "metadata": {},
- "outputs": [
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "2b8iTs6CDCXl"
+ },
+ "source": [
+ "Here, since we can go from the initial distribution to the desired one using only permutations, the Wasserstein distance is zero."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "PeRPqgPyDCXl",
+ "outputId": "bd589c43-78e5-4232-a8bb-674d8368066c"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "0.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "import pandas as pd\n",
+ "from aif360.sklearn.metrics import ot_distance\n",
+ "\n",
+ "_a = pd.Series(a)\n",
+ "_b = pd.Series(b)\n",
+ "c = ot_distance(_a, _b, mode='continuous')\n",
+ "\n",
+ "print(c)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "V5JjvRvJDCXl"
+ },
+ "source": [
+ "### 4. Extreme case\n",
+ "\n",
+ "One more example, closer to our use case, concerns \"normalization\". It explains why the maximum Wasserstein distance we can obtain approaches 1 as the sample size increases; that is, the distance is normalized. This maximum occurs when the entire population has a 2-year recidivism value of 0 (see the section \"Compas Dataset\") and the classifier fails in every case, labeling all of them 1. That would be the worst-case scenario."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "C-XRy_ZNDCXl"
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "# Initial distribution\n",
+ "a = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.001])\n",
+ "# Required distribution\n",
+ "b = np.array([0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Wr7ugRoUDCXl",
+ "outputId": "73b33e31-cbe6-4ab1-994c-2922c8b51be2"
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ "
"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "# Drawing both of them\n",
+ "figure, axis = plt.subplots(1, 2)\n",
+ "figure.set_figheight(4)\n",
+ "figure.set_figwidth(12)\n",
+ "figure.tight_layout(w_pad = 5)\n",
+ "\n",
+ "def draw(y, id):\n",
+ " x = np.array(range(0, np.size(y)))\n",
+ " axis[id].bar(x, y, color=\"lightgreen\", ec='black')\n",
+ " axis[id].scatter(x, y, color=\"orange\")\n",
+ "\n",
+ "axis[0].title.set_text(\"Initial distribution\")\n",
+ "axis[1].title.set_text(\"Required distribution\")\n",
+ "draw(a, 0)\n",
+ "draw(b, 1)\n",
+ "\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "RuuNFQUcDCXl",
+ "outputId": "b0dc6795-09a7-4160-b8b5-fbef7b182f11"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "0.9375\n"
+ ]
+ }
+ ],
+ "source": [
+ "import pandas as pd\n",
+ "from aif360.sklearn.metrics import ot_distance\n",
+ "\n",
+ "_a = pd.Series(a)\n",
+ "_b = pd.Series(b)\n",
+ "c = ot_distance(_a, _b)\n",
+ "\n",
+ "print(c)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "982f8APNDCXm"
+ },
+ "source": [
+ "## Usage\n",
+ "\n",
+ "The type of outcomes must be provided using the `mode` keyword argument. Definitions of the four supported outcome types are provided below:\n",
+ "- Binary: Yes/no outcomes. Outcomes must be 0 or 1.\n",
+ "- Continuous: Continuous outcomes. Outcomes could be any real number.\n",
+ "- Nominal: Multiclass outcomes with no rank or order between them. Outcomes must be a finite set of integers.\n",
+ "- Ordinal: Multiclass outcomes that are ranked in a specific order. Outcomes must be positive integers."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "AXZLAng1DCXm"
+ },
+ "source": [
+ "## Compas Dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "JbzgBkZ6DCXm"
+ },
+ "source": [
+ "We'll demonstrate scanning for bias with `ot_distance` using the Compas dataset. We scan for bias in the predictions of an `sklearn` logistic regression model with respect to different groups."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "S_a5CpBmDCXm"
+ },
+ "outputs": [],
+ "source": [
+ "from aif360.algorithms.preprocessing.optim_preproc_helpers.data_preproc_functions import load_preproc_data_compas\n",
+ "\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "\n",
+ "np.random.seed(0)\n",
+ "dataset_orig = load_preproc_data_compas()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "MetVVGgEDCXm"
+ },
+ "source": [
+ "We scan for bias first with respect to `sex`, and then with respect to `age_cat`.\n",
+ "\n",
+ "To scan for bias with respect to a feature that is one-hot encoded (in this case, age category), we need to convert it to nominal or ordinal format."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "nKjqZpvrDCXm",
+ "outputId": "eb2d18a6-3d40-4a89-f9de-89b7b82441cd"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "
sex
\n",
+ "
race
\n",
+ "
priors_count=0
\n",
+ "
priors_count=1 to 3
\n",
+ "
priors_count=More than 3
\n",
+ "
c_charge_degree=F
\n",
+ "
c_charge_degree=M
\n",
+ "
age_cat
\n",
+ "
two_year_recid
\n",
+ "
\n",
+ " \n",
+ " \n",
+ "
\n",
+ "
0
\n",
+ "
0.0
\n",
+ "
0.0
\n",
+ "
1.0
\n",
+ "
0.0
\n",
+ "
0.0
\n",
+ "
1.0
\n",
+ "
0.0
\n",
+ "
1
\n",
+ "
1.0
\n",
+ "
\n",
+ "
\n",
+ "
1
\n",
+ "
0.0
\n",
+ "
0.0
\n",
+ "
0.0
\n",
+ "
0.0
\n",
+ "
1.0
\n",
+ "
1.0
\n",
+ "
0.0
\n",
+ "
0
\n",
+ "
1.0
\n",
+ "
\n",
+ "
\n",
+ "
2
\n",
+ "
0.0
\n",
+ "
1.0
\n",
+ "
0.0
\n",
+ "
0.0
\n",
+ "
1.0
\n",
+ "
1.0
\n",
+ "
0.0
\n",
+ "
1
\n",
+ "
1.0
\n",
+ "
\n",
+ "
\n",
+ "
3
\n",
+ "
1.0
\n",
+ "
1.0
\n",
+ "
1.0
\n",
+ "
0.0
\n",
+ "
0.0
\n",
+ "
0.0
\n",
+ "
1.0
\n",
+ "
1
\n",
+ "
0.0
\n",
+ "
\n",
+ "
\n",
+ "
4
\n",
+ "
0.0
\n",
+ "
1.0
\n",
+ "
1.0
\n",
+ "
0.0
\n",
+ "
0.0
\n",
+ "
1.0
\n",
+ "
0.0
\n",
+ "
1
\n",
+ "
0.0
\n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " sex race priors_count=0 priors_count=1 to 3 priors_count=More than 3 \\\n",
+ "0 0.0 0.0 1.0 0.0 0.0 \n",
+ "1 0.0 0.0 0.0 0.0 1.0 \n",
+ "2 0.0 1.0 0.0 0.0 1.0 \n",
+ "3 1.0 1.0 1.0 0.0 0.0 \n",
+ "4 0.0 1.0 1.0 0.0 0.0 \n",
+ "\n",
+ " c_charge_degree=F c_charge_degree=M age_cat two_year_recid \n",
+ "0 1.0 0.0 1 1.0 \n",
+ "1 1.0 0.0 0 1.0 \n",
+ "2 1.0 0.0 1 1.0 \n",
+ "3 0.0 1.0 1 0.0 \n",
+ "4 1.0 0.0 1 0.0 "
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dataset_orig_df = pd.DataFrame(dataset_orig.features, columns=dataset_orig.feature_names)\n",
+ "# Binning the features corresponding to age ('reshaping' them into one ordinal column)\n",
+ "age_cat_cols = ['age_cat=Less than 25', 'age_cat=25 to 45', 'age_cat=Greater than 45']\n",
+ "age_cat = np.argmax(dataset_orig_df[age_cat_cols].values, axis=1).reshape(-1, 1)\n",
+ "df = dataset_orig_df.drop(age_cat_cols, axis=1)\n",
+ "df['age_cat'] = age_cat\n",
+ "df['two_year_recid'] = dataset_orig.labels\n",
+ "df.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "gdQM2Pf0DCXm"
+ },
+ "source": [
+ "### Measuring bias with respect to `sex`\n",
+ "\n",
+ "\n",
+ "We train a logistic regression model on the dataset, and scan its results for bias with respect to `sex` using `ot_distance`.\n",
+ "\n",
+ "The arguments are as follows:\n",
+ "- `y_true`: ground truth labels;\n",
+ "- `y_pred`: predicted labels;\n",
+ "- `prot_attr`: the values of the sensitive attributes (with respect to which the classifier may be introducing bias);\n",
+ "- `num_iters`: maximum number of iterations performed when calculating the Earth Mover's Distance;\n",
+ "- `mode`: mode of the labels, one of binary, nominal, ordinal and continuous; in our case the labels are binary."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "1UMII6BdDCXn"
+ },
+ "outputs": [],
+ "source": [
+ "from aif360.sklearn.metrics import ot_distance\n",
+ "from sklearn.linear_model import LogisticRegression\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "\n",
+ "X = df.drop('two_year_recid', axis=1)\n",
+ "y = df['two_year_recid']\n",
+ "clf = LogisticRegression(solver='lbfgs', max_iter=10000, C=1.0, penalty='l2')\n",
+ "clf.fit(X, y)\n",
+ "preds = pd.Series(clf.predict_proba(X)[:,0])\n",
+ "\n",
+ "ot_val1 = ot_distance(y_true=y, y_pred=preds, prot_attr=df['sex'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "dJ9XucBpDCXn",
+ "outputId": "14f8db55-0698-4582-f0ab-a0d5b51d32c0"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "
sex
\n",
+ "
ot_val
\n",
+ "
\n",
+ " \n",
+ " \n",
+ "
\n",
+ "
0
\n",
+ "
0.0
\n",
+ "
0.000209
\n",
+ "
\n",
+ "
\n",
+ "
1
\n",
+ "
1.0
\n",
+ "
0.001647
\n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " sex ot_val\n",
+ "0 0.0 0.000209\n",
+ "1 1.0 0.001647"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "bs1 = pd.DataFrame({\"sex\": ot_val1.keys(), \"ot_val\": ot_val1.values()})\n",
+ "display(bs1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "qMuMdXc9DCXn"
+ },
+ "source": [
+ "We get the bias value for each of the protected groups: in this case, Male (`0`) and Female (`1`).\n",
+ "\n",
+ "These values range from 0 to 1 and can be interpreted as the difference in percent between the ground truth distribution and the distribution of the protected group: for example, a value of 0.3 would mean a 30% difference."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "AZHqN9JTDCXn"
+ },
+ "source": [
+ "### Measuring bias with respect to `age_cat`\n",
+ "\n",
+ "Now we measure the bias of the same classifier with respect to the age category."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "usxcylLxDCXn"
+ },
+ "outputs": [],
+ "source": [
+ "ot_val2 = ot_distance(y_true=y, y_pred=preds, prot_attr=df['age_cat'])"
+ ]
+ },
{
- "data": {
- "text/html": [
- "