121 changes: 6 additions & 115 deletions beginner_source/introyt/captumyt.py
@@ -97,23 +97,15 @@

Before you get started, you need to have a Python environment with:

- Python version 3.6 or higher
- For the Captum Insights example, Flask 1.1 or higher and Flask-Compress
(the latest version is recommended)
- PyTorch version 1.2 or higher (the latest version is recommended)
- TorchVision version 0.6 or higher (the latest version is recommended)
- Python version 3.9 or higher
- PyTorch (the latest version is recommended)
- TorchVision (the latest version is recommended)
- Captum (the latest version is recommended)
- Matplotlib version 3.3.4, since Captum currently uses a Matplotlib
function whose arguments have been renamed in later versions

To install Captum in an Anaconda or pip virtual environment, use the
appropriate command for your environment below:

With ``conda``:
To install Captum in a pip virtual environment, use the command below:

.. code-block:: sh

conda install pytorch torchvision captum flask-compress matplotlib=3.3.4 -c pytorch

With ``pip``:

@@ -189,7 +181,7 @@
transform = transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor()
 v2.ToImage(),
 v2.ToDtype(torch.float32, scale=True)
])

# standard ImageNet normalization
@@ -396,104 +388,3 @@
#


##########################################################################
# Visualization with Captum Insights
# ----------------------------------
#
# Captum Insights is an interpretability visualization widget built on top
# of Captum to facilitate model understanding. Captum Insights works
# across images, text, and other features to help users understand feature
# attribution. It allows you to visualize attribution for multiple
# input/output pairs, and provides visualization tools for image, text,
# and arbitrary data.
#
# In this section of the notebook, we’ll visualize multiple image
# classification inferences with Captum Insights.
#
# First, let’s gather some images and see what the model thinks of them.
# For variety, we’ll take our cat, a teapot, and a trilobite fossil:
#

imgs = ['img/cat.jpg', 'img/teapot.jpg', 'img/trilobite.jpg']

for img in imgs:
img = Image.open(img)
transformed_img = transform(img)
input_img = transform_normalize(transformed_img)
input_img = input_img.unsqueeze(0) # the model requires a dummy batch dimension

output = model(input_img)
output = F.softmax(output, dim=1)
prediction_score, pred_label_idx = torch.topk(output, 1)
pred_label_idx.squeeze_()
predicted_label = idx_to_labels[str(pred_label_idx.item())][1]
print('Predicted:', predicted_label, '/', pred_label_idx.item(), ' (', prediction_score.squeeze().item(), ')')
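The scoring step in the loop above can be isolated into a small self-contained sketch; the 1×5 logits tensor here is hypothetical, standing in for ``model(input_img)``:

```python
import torch
import torch.nn.functional as F

# hypothetical logits standing in for a model's raw output
logits = torch.tensor([[0.1, 2.5, 0.3, 0.0, -1.0]])

probs = F.softmax(logits, dim=1)   # normalize logits into probabilities
score, idx = torch.topk(probs, 1)  # top-1 prediction: (probability, class index)
idx.squeeze_()                     # drop the batch dimension in place

print(idx.item())                  # 1 (index of the largest logit)
print(score.squeeze().item())      # probability assigned to that class
```

The same ``softmax`` → ``topk`` → ``squeeze_`` sequence is what produces ``prediction_score`` and ``pred_label_idx`` in the loop.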


##########################################################################
# …and it looks like our model is identifying them all correctly - but of
# course, we want to dig deeper. For that we’ll use the Captum Insights
# widget, which we configure with an ``AttributionVisualizer`` object,
# imported below. The ``AttributionVisualizer`` expects batches of data,
# so we’ll bring in Captum’s ``Batch`` helper class. And we’ll be looking
# at images specifically, so we’ll also import ``ImageFeature``.
#
# We configure the ``AttributionVisualizer`` with the following arguments:
#
# - An array of models to be examined (in our case, just the one)
# - A scoring function, which allows Captum Insights to pull out the
# top-k predictions from a model
# - An ordered, human-readable list of classes our model is trained on
# - A list of features to look for - in our case, an ``ImageFeature``
# - A dataset, which is an iterable object returning batches of inputs
# and labels - just like you’d use for training
#

from captum.insights import AttributionVisualizer, Batch
from captum.insights.attr_vis.features import ImageFeature

# Baseline is all-zeros input - this may differ depending on your data
def baseline_func(input):
return input * 0

# merging our image transforms from above
def full_img_transform(input):
i = Image.open(input)
i = transform(i)
i = transform_normalize(i)
i = i.unsqueeze(0)
return i


input_imgs = torch.cat(list(map(lambda i: full_img_transform(i), imgs)), 0)

visualizer = AttributionVisualizer(
models=[model],
score_func=lambda o: torch.nn.functional.softmax(o, 1),
classes=list(map(lambda k: idx_to_labels[k][1], idx_to_labels.keys())),
features=[
ImageFeature(
"Photo",
baseline_transforms=[baseline_func],
input_transforms=[],
)
],
dataset=[Batch(input_imgs, labels=[282,849,69])]
)


#########################################################################
# Note that running the cell above didn’t take much time at all, unlike
# our attributions above. That’s because Captum Insights lets you
# configure different attribution algorithms in a visual widget, after
# which it will compute and display the attributions. *That* process will
# take a few minutes.
#
# Running the cell below will render the Captum Insights widget. You can
# then choose attributions methods and their arguments, filter model
# responses based on predicted class or prediction correctness, see the
# model’s predictions with associated probabilities, and view heatmaps of
# the attribution compared with the original image.
#

visualizer.render()