{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Obtaining an explanation\n",
"\n",
"In this notebook we show how to obtain an explanation for a given model and input image. We use CIFAR10 as the dataset and an EfficientNet-B2 model as the classifier."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import random\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import plotly.express as px\n",
"import plotly.graph_objects as go\n",
"import torch\n",
"from efficientnet_pytorch import EfficientNet\n",
"from IPython.display import display\n",
"from plotly.subplots import make_subplots\n",
"from torch.utils.data.dataloader import DataLoader\n",
"\n",
"from ReVel.LLEs import get_xai_model\n",
"from ReVel.load_data import load_data\n",
"from ReVel.perturbations import get_perturbation\n",
"from ReVel.revel.revel import ReVel\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"n_classes = {'CIFAR10':10, 'CIFAR100':100, 'EMNIST':47,'FashionMNIST':10}\n",
"\n",
"# Seed the global RNGs so stochastic steps (e.g. explainer sampling, to the extent it\n",
"# uses these generators) are repeatable on Restart & Run All.\n",
"SEED = 42\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"torch.manual_seed(SEED)\n",
"\n",
"# Load the experiment configuration.\n",
"with open('config.json', 'r', encoding='utf-8') as f:\n",
"    config = json.load(f)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/CIFAR10/test/cifar-10-python.tar.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 170498071/170498071 [00:09<00:00, 17383534.76it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ./data/CIFAR10/test/cifar-10-python.tar.gz to ./data/CIFAR10/test\n",
"Loaded the pretrained model.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_96406/798391298.py:12: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" state_dict = torch.load(f\"../../../models/classifier_{config['dataset_W']}.pt\",map_location=device)\n"
]
}
],
"source": [
"num_classes = n_classes[config[\"dataset_W\"]]\n",
"perturbation = get_perturbation(name=config[\"perturbation_W\"],\n",
"                                dim=config[\"features_W\"],\n",
"                                num_classes=num_classes,\n",
"                                final_size=(224,224),\n",
"                                kernel=config[\"kernel_W\"],\n",
"                                max_dist=config[\"maxDist_W\"],\n",
"                                ratio=config[\"ratio_W\"],)\n",
"Test = load_data(config[\"dataset_W\"],perturbation = perturbation,train=False,dir=\"./data\")\n",
"TestLoader = iter(DataLoader(Test, batch_size=1, shuffle=False))\n",
"classifier = EfficientNet.from_name(\"efficientnet-b2\",num_classes=len(Test.classes))\n",
"# The checkpoint is passed straight to load_state_dict, so it only needs tensors:\n",
"# weights_only=True avoids executing arbitrary pickled code from the file.\n",
"model_path = f\"../../../models/classifier_{config['dataset_W']}.pt\"\n",
"state_dict = torch.load(model_path, map_location=device, weights_only=True)\n",
"classifier.load_state_dict(state_dict)\n",
"classifier.to(device)\n",
"classifier.eval()\n",
"print(\"Loaded the pretrained model.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Explanation methods\n",
"\n",
"To obtain an explanation we need to choose an explanation method. The explanation method is chosen from LIME [[Ribeiro et al., 2016](https://arxiv.org/abs/1602.04938)] and SHAP [[Lundberg and Lee, 2017](https://arxiv.org/abs/1705.07874)], each available in different configurations."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Build the explainer (an LLE from ReVel.LLEs; LIME/SHAP variant selected in config).\n",
"explainer = get_xai_model(\n",
"    name=config[\"xai_model_W\"],\n",
"    perturbation=perturbation,\n",
"    dim=config[\"features_W\"],\n",
"    max_examples=config[\"evaluations_W\"],\n",
"    sigma=config[\"sigma_W\"],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We also choose the following image to explain:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Select the input to be explained.\n",
"# It must be the first draw from the (unshuffled) TestLoader.\n",
"inp, target = next(TestLoader)\n",
"target_idx = torch.argmax(target,dim=1)[0]  # target holds one score per class\n",
"clase = Test.classes[target_idx]\n",
"\n",
"# Display the image with its class name as a centred title.\n",
"figure = go.Figure()\n",
"figure.add_trace(go.Image(z=inp[0].permute(1,2,0)))\n",
"figure.update_layout(title_text=f\"{clase}\",title_x=0.5)\n",
"figure.show(renderer=\"svg\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the following code, we can visualize the probabilities generated by our model for the image: "
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Cast the input to float32 before the forward pass.\n",
"inp = inp.float()\n",
"with torch.no_grad():  # inference only: no autograd graph needed\n",
"    probabilities = torch.softmax(classifier(inp.to(device)),dim=1)[0]\n",
"args = torch.argsort(probabilities,descending=True)\n",
"# Top-10 classes by predicted probability.\n",
"prob_classes = [(Test.classes[arg],float(probabilities[arg])) for arg in args[:10]]\n",
"fig = px.bar(prob_classes,x=0,y=1,title=\"Probabilities of the classes\")\n",
"fig.update_layout(title_x=0.5)\n",
"fig.update_xaxes(title=\"Class\")\n",
"fig.update_yaxes(range=[0,1],title=\"Probability\")\n",
"fig.show(renderer=\"svg\")"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/svg+xml": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Build a perturbed version of `inp`: the segments listed in `indexes` are\n",
"# (presumably) replaced by the neutral image.\n",
"image_hwc = inp.numpy()[0].transpose(1,2,0)  # (C,H,W) -> (H,W,C)\n",
"neutralImage = explainer.perturbation.fn_neutral_image(image_hwc)\n",
"segments = explainer.perturbation.segmentation_fn(image_hwc)\n",
"\n",
"# Segment ids perturbed in this demo.\n",
"SEGMENTS_TO_PERTURB = [8, 9, 14, 15]\n",
"# NOTE: this rebinds `perturbation` (previously the perturbation object used to build\n",
"# `explainer`) to the perturbed image; the next cell reads it under this name.\n",
"perturbation = explainer.perturbation.perturbation(img=image_hwc,neutral=neutralImage,segments=segments,indexes=SEGMENTS_TO_PERTURB)\n",
"\n",
"# Show the perturbed image; it is channels-last (H,W,C), which is what the\n",
"# next cell's permute(2,0,1) assumes.\n",
"figure = go.Figure()\n",
"figure.add_trace(go.Image(z=perturbation))\n",
"figure.show(renderer=\"svg\")\n",
"# Show the segmentation map (one segment id per pixel; presumably shape (H,W)).\n",
"segs_fig = px.imshow(segments,title=\"Segmentation\")\n",
"segs_fig.show(renderer=\"svg\")\n"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Convert the perturbed (H,W,C) image into a (1,C,H,W) float batch for the classifier.\n",
"img_perturbated = torch.tensor(perturbation).float().permute(2,0,1).unsqueeze(0)\n",
"with torch.no_grad():  # inference only\n",
"    probabilities = torch.softmax(classifier(img_perturbated.to(device)),dim=1)[0]\n",
"args = torch.argsort(probabilities,descending=True)\n",
"\n",
"prob_classes = [(Test.classes[arg],float(probabilities[arg])) for arg in args[:10]]\n",
"fig = px.bar(prob_classes,x=0,y=1,title=\"Probabilities of the classes (perturbed image)\")\n",
"fig.update_layout(title_x=0.5)\n",
"fig.update_xaxes(title=\"Class\")\n",
"fig.update_yaxes(range=[0,1],title=\"Probability\")\n",
"fig.show(renderer=\"svg\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Obtaining explanation\n",
"- Define the pipeline for the image forward pass and for the feature-space forward pass."
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"def classify(image,model=classifier):\n",
"    '''\n",
"    Run the classifier on a single channels-last image.\n",
"\n",
"    :param image: A numpy array or torch tensor of shape HxWxC.\n",
"    :param model: The network to evaluate (defaults to the global `classifier`,\n",
"        bound when this cell runs).\n",
"    :return: The raw model output for a batch of one image, i.e. a tensor of shape\n",
"        (1, num_classes). No softmax is applied here.\n",
"    '''\n",
"    if isinstance(image, np.ndarray):\n",
"        image = torch.Tensor(np.expand_dims(image,0))\n",
"    else:\n",
"        # Match the float32 dtype that torch.Tensor(...) gives numpy input.\n",
"        image = torch.unsqueeze(image,0).float()\n",
"    # Move to the model's device for both input types, not only numpy input.\n",
"    image = image.to(device)\n",
"    # (N,H,W,C) -> (N,C,H,W)\n",
"    image = image.permute(0,3,1,2)\n",
"    return model(image)\n",
"\n",
"def model_fordward(X:np.ndarray,explainator=explainer,model=classify,img=inp):\n",
"    '''\n",
"    Classify the original image after perturbing the segments switched off in `X`.\n",
"\n",
"    :param X: Feature vector of length F, one entry per segment; entries equal to 0\n",
"        mark the segments to perturb.\n",
"    :param explainator: Explainer whose `perturbation` object provides the neutral image,\n",
"        segmentation and perturbation functions.\n",
"    :param model: Callable taking an HxWxC image (e.g. `classify`).\n",
"    :param img: The original image as a (1,C,H,W) tensor. NOTE: the default is the global\n",
"        `inp` captured when this cell runs, so pass `img` explicitly for other images.\n",
"    :return: Whatever `model` returns for the perturbed image.\n",
"    '''\n",
"    img = img[0].permute(1,2,0)  # (1,C,H,W) -> (H,W,C)\n",
"    neutral = explainator.perturbation.fn_neutral_image(img)\n",
"    avoid = [i for i, value in enumerate(X) if value == 0]\n",
"    segments = explainator.perturbation.segmentation_fn(img.numpy())\n",
"    perturbed = explainator.perturbation.perturbation(img,neutral,segments=segments,indexes=avoid)\n",
"    return model(perturbed)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 50/50 [00:02<00:00, 17.86it/s]\n",
"100%|██████████| 50/50 [00:03<00:00, 16.07it/s]\n"
]
}
],
"source": [
"# Explain the same image twice with identical settings; both results are plotted\n",
"# below to show how much the explanation varies between runs.\n",
"image = inp[0].permute(1,2,0).numpy()  # (1,C,H,W) -> (H,W,C)\n",
"explaination = explainer.explain_instance(image,classify,segments)\n",
"explaination2 = explainer.explain_instance(image,classify,segments)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"image = inp[0].permute(1,2,0).numpy()\n",
"\n",
"revel = ReVel(model_f=classify,\n",
"              model_g=model_fordward,\n",
"              instance=image,\n",
"              lle=explainer,\n",
"              n_classes=num_classes,\n",
"              segments=segments,)\n",
"\n",
"# Show the importance masks of the TOP_K most probable classes in a grid.\n",
"TOP_K = 10\n",
"N_COLS = 5\n",
"with torch.no_grad():  # inference only\n",
"    probabilities = torch.softmax(classifier(inp.to(device)),dim=1)[0]\n",
"ids = torch.argsort(probabilities,descending=True)[:TOP_K]\n",
"\n",
"figures = []\n",
"figures2 = []\n",
"for class_idx in ids:\n",
"    # Scale by 255 (not 256) so a mask value of 1.0 stays inside go.Image's 0-255 range;\n",
"    # assumes coloured_importance_mask returns values in [0, 1] - TODO confirm.\n",
"    figures.append(go.Image(z=revel.coloured_importance_mask(explaination,segments,class_idx)*255))\n",
"    figures2.append(go.Image(z=revel.coloured_importance_mask(explaination2,segments,class_idx)*255))\n",
"\n",
"clases = [f\"{Test.classes[c]}: {100*probabilities[c]:.2f}%\" for c in ids]\n",
"n_rows = (len(ids) + N_COLS - 1) // N_COLS\n",
"fig = make_subplots(rows=n_rows, cols=N_COLS, subplot_titles=clases)\n",
"fig2 = make_subplots(rows=n_rows, cols=N_COLS, subplot_titles=clases)\n",
"for k, (trace, trace2) in enumerate(zip(figures, figures2)):\n",
"    row, col = divmod(k, N_COLS)\n",
"    fig.add_trace(trace,row=row+1,col=col+1)\n",
"    fig2.add_trace(trace2,row=row+1,col=col+1)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/svg+xml": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Render both explanation grids.\n",
"for grid in (fig, fig2):\n",
"    grid.show(renderer=\"svg\")"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 50/50 [00:03<00:00, 15.01it/s]\n",
"100%|██████████| 50/50 [00:02<00:00, 19.54it/s]\n",
"100%|██████████| 50/50 [00:02<00:00, 20.16it/s]\n"
]
}
],
"source": [
"# Evaluate the explanation; the saved output has one result row per round.\n",
"N_EVAL_ROUNDS = 3\n",
"df = revel.evaluate(N_EVAL_ROUNDS)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" conciseness | \n",
" local_fidelity | \n",
" local_concordance | \n",
" prescriptivity | \n",
" robustness | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 0.740329 | \n",
" 0.919741 | \n",
" 0.883064 | \n",
" 0.888679 | \n",
" 0.831222 | \n",
"
\n",
" \n",
" | 1 | \n",
" 0.720869 | \n",
" 0.934867 | \n",
" 0.930382 | \n",
" 0.897691 | \n",
" 0.831222 | \n",
"
\n",
" \n",
" | 2 | \n",
" 0.689639 | \n",
" 0.930308 | \n",
" 0.897731 | \n",
" 0.918164 | \n",
" 0.831222 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" conciseness local_fidelity local_concordance prescriptivity robustness\n",
"0 0.740329 0.919741 0.883064 0.888679 0.831222\n",
"1 0.720869 0.934867 0.930382 0.897691 0.831222\n",
"2 0.689639 0.930308 0.897731 0.918164 0.831222"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"conciseness 0.716946\n",
"local_fidelity 0.928305\n",
"local_concordance 0.903726\n",
"prescriptivity 0.901511\n",
"robustness 0.831222\n",
"dtype: float64"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Per-round metrics, then the column-wise mean across rounds.\n",
"display(df)\n",
"df.mean()"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/50 [00:00, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 50/50 [00:02<00:00, 18.82it/s]\n",
"100%|██████████| 50/50 [00:02<00:00, 18.63it/s]\n",
"100%|██████████| 50/50 [00:02<00:00, 17.04it/s]\n",
"100%|██████████| 50/50 [00:02<00:00, 19.24it/s]\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" id | \n",
" conciseness | \n",
" local_fidelity | \n",
" local_concordance | \n",
" prescriptivity | \n",
" robustness | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 1 | \n",
" 0.732805 | \n",
" 0.930852 | \n",
" 0.899099 | \n",
" 0.585245 | \n",
" 0.900865 | \n",
"
\n",
" \n",
" | 1 | \n",
" 1 | \n",
" 0.733857 | \n",
" 0.918754 | \n",
" 0.863023 | \n",
" 0.824418 | \n",
" 0.900865 | \n",
"
\n",
" \n",
" | 0 | \n",
" 2 | \n",
" 0.662039 | \n",
" 0.923757 | \n",
" 0.916792 | \n",
" 0.823079 | \n",
" 0.879404 | \n",
"
\n",
" \n",
" | 1 | \n",
" 2 | \n",
" 0.724136 | \n",
" 0.923744 | \n",
" 0.885252 | \n",
" 0.818978 | \n",
" 0.879404 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" id conciseness local_fidelity local_concordance prescriptivity \\\n",
"0 1 0.732805 0.930852 0.899099 0.585245 \n",
"1 1 0.733857 0.918754 0.863023 0.824418 \n",
"0 2 0.662039 0.923757 0.916792 0.823079 \n",
"1 2 0.724136 0.923744 0.885252 0.818978 \n",
"\n",
" robustness \n",
"0 0.900865 \n",
"1 0.900865 \n",
"0 0.879404 \n",
"1 0.879404 "
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Evaluate the explainer on a few randomly drawn test images.\n",
"N_IMAGES = 2                 # number of test images to evaluate\n",
"N_EVAL_ROUNDS_PER_IMAGE = 2  # argument to revel.evaluate (one result row per round)\n",
"# Seeded generator: the random draw is repeatable and the global TestLoader is untouched.\n",
"shuffled_loader = iter(DataLoader(Test, batch_size=1, shuffle=True,\n",
"                                  generator=torch.Generator().manual_seed(0)))\n",
"\n",
"results = []\n",
"for image_id in range(1, N_IMAGES + 1):\n",
"    batch, _ = next(shuffled_loader)\n",
"    batch = batch.float()                # (1,C,H,W), the layout model_fordward expects\n",
"    instance = batch[0].permute(1,2,0)   # (H,W,C), the layout classify expects\n",
"    instance_segments = explainer.perturbation.segmentation_fn(instance.numpy())\n",
"    # Bind the current image explicitly: model_fordward's default `img` is the global\n",
"    # `inp`, so without this every iteration would perturb that same picture.\n",
"    model_g = lambda x, img=batch: model_fordward(x,explainator=explainer,model=classify,img=img)\n",
"    revel = ReVel(model_f=classify,\n",
"                  model_g=model_g,\n",
"                  instance=instance,\n",
"                  lle=explainer,\n",
"                  n_classes=num_classes,\n",
"                  segments=instance_segments)\n",
"    df_local = revel.evaluate(N_EVAL_ROUNDS_PER_IMAGE)\n",
"    df_local.insert(0, \"id\", image_id)\n",
"    results.append(df_local)\n",
"\n",
"# Concatenate once (not inside the loop) so metric columns keep a numeric dtype.\n",
"df = pd.concat(results, ignore_index=True)\n",
"display(df)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" conciseness | \n",
" local_fidelity | \n",
" local_concordance | \n",
" prescriptivity | \n",
" robustness | \n",
"
\n",
" \n",
" | id | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
"
\n",
" \n",
" \n",
" \n",
" | 1 | \n",
" 0.733331 | \n",
" 0.924803 | \n",
" 0.881061 | \n",
" 0.704831 | \n",
" 0.900865 | \n",
"
\n",
" \n",
" | 2 | \n",
" 0.693088 | \n",
" 0.923750 | \n",
" 0.901022 | \n",
" 0.821029 | \n",
" 0.879404 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" conciseness local_fidelity local_concordance prescriptivity robustness\n",
"id \n",
"1 0.733331 0.924803 0.881061 0.704831 0.900865\n",
"2 0.693088 0.923750 0.901022 0.821029 0.879404"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"conciseness 0.713210\n",
"local_fidelity 0.924277\n",
"local_concordance 0.891042\n",
"prescriptivity 0.762930\n",
"robustness 0.890135\n",
"dtype: float64"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Average the evaluation rounds per image, then average across images.\n",
"df_example = df.groupby(\"id\").mean()\n",
"display(df_example, df_example.mean(axis=0))\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.20"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "13b772770439038447b9da6362dc127f555d438a28e9435454732856e4838794"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}