Heatmapping

Since numerical explanations are not very informative at first sight, we can visualize them by computing a heatmap, using either VisionHeatmaps.jl or TextHeatmaps.jl.

This page showcases the different options and presets for heatmapping, building on the basics shown in the Getting started section.

We start out by loading the same pre-trained LeNet5 model and MNIST input data:

using ExplainableAI
using VisionHeatmaps
using Zygote
using Flux

model
Chain(
  Conv((5, 5), 1 => 6, relu),           # 156 parameters
  MaxPool((2, 2)),
  Conv((5, 5), 6 => 16, relu),          # 2_416 parameters
  MaxPool((2, 2)),
  Flux.flatten,
  Dense(256 => 120, relu),              # 30_840 parameters
  Dense(120 => 84, relu),               # 10_164 parameters
  Dense(84 => 10),                      # 850 parameters
)                   # Total: 10 arrays, 44_426 parameters, 174.344 KiB.
using MLDatasets
using ImageCore, ImageIO, ImageShow

index = 10
x, y = MNIST(Float32, :test)[index]
input = reshape(x, 28, 28, 1, :) # reshape to WHCN format

img = convert2image(MNIST, x)

Automatic heatmap presets

The function heatmap automatically applies common presets for each method.

Gradient methods are typically shown in grayscale. InputTimesGradient, however, computes attributions, so its heatmaps are shown in a blue-white-red color scheme:

analyzer = Gradient(model)
heatmap(input, analyzer)
analyzer = InputTimesGradient(model)
heatmap(input, analyzer)
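To give some intuition for why InputTimesGradient produces signed attributions, here is a minimal sketch for a toy linear model f(x) = sum(w .* x), whose gradient is simply w. The weights and input are hypothetical, and this is not ExplainableAI.jl's implementation, which uses automatic differentiation on the real model.

```julia
# Toy sketch of the Input × Gradient rule for a linear model f(x) = sum(w .* x).
# For this model the gradient ∇f(x) is just w, so no AD package is needed.
# Weights and input below are hypothetical.

w = Float32[0.5, -2.0, 1.0]   # hypothetical model weights
x = Float32[1.0, 0.5, 0.0]    # hypothetical input

grad = w                      # ∇f(x) for a linear model
attribution = x .* grad       # elementwise input × gradient

println(attribution)          # contains positive and negative values
println(sum(attribution))     # equals f(x) for a linear model without bias
```

Because attributions can be positive or negative, a diverging color scheme with white at zero is a natural fit for them.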

Custom heatmap settings

Color schemes

We can partially or fully override presets by passing keyword arguments to heatmap. For example, we can use a custom color scheme from ColorSchemes.jl using the keyword argument colorscheme:

using ColorSchemes

expl = analyze(input, analyzer)
heatmap(expl; colorscheme=:jet)
heatmap(expl; colorscheme=:inferno)

Refer to the ColorSchemes.jl catalogue for a gallery of available color schemes.
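Conceptually, a color scheme maps a normalized value in [0, 1] to a color. The sketch below illustrates this with linear interpolation over a short list of RGB tuples. The three-color scheme and the `lookup_color` helper are hypothetical and only illustrate the idea; they are not how ColorSchemes.jl is implemented.

```julia
# Hypothetical sketch: look up t ∈ [0, 1] in a color scheme given as a vector
# of (r, g, b) tuples, interpolating linearly between neighboring colors.

function lookup_color(scheme::Vector{NTuple{3,Float64}}, t::Real)
    t = clamp(t, 0, 1)                   # guard against values outside [0, 1]
    n = length(scheme)
    pos = 1 + t * (n - 1)                # fractional position in 1..n
    i = min(floor(Int, pos), n - 1)      # index of the lower neighbor
    frac = pos - i                       # weight of the upper neighbor
    lo, hi = scheme[i], scheme[i + 1]
    return ntuple(k -> (1 - frac) * lo[k] + frac * hi[k], 3)
end

# A hypothetical three-color diverging scheme: blue → white → red
bwr = [(0.0, 0.0, 1.0), (1.0, 1.0, 1.0), (1.0, 0.0, 0.0)]

println(lookup_color(bwr, 0.0))    # blue
println(lookup_color(bwr, 0.5))    # white
println(lookup_color(bwr, 0.75))   # halfway between white and red
```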

Color channel reduction

Explanations have the same dimensionality as the inputs to the classifier. For images with multiple color channels, this means that the explanation also has a "color channel" dimension.

The keyword argument reduce can be used to reduce this dimension to a single scalar value for each pixel. The following presets are available:

  • :sum: sum up color channels (default setting)
  • :norm: compute 2-norm over the color channels
  • :maxabs: compute maximum(abs, x) over the color channels
heatmap(expl; reduce=:sum)
heatmap(expl; reduce=:norm)
heatmap(expl; reduce=:maxabs)

In this example, the heatmaps look identical. Since MNIST only has a single color channel, there is no need for color channel reduction.
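The three reduction presets can be sketched on a WHCN array, where dimension 3 holds the color channels. This is an illustration, not VisionHeatmaps.jl's code; the `reduce_channels` helper is hypothetical. As a design choice, the sketch returns single-channel inputs unchanged, so all three presets give the same result on MNIST-like data, in line with the observation above.

```julia
# Sketch of color channel reduction for an explanation in WHCN layout
# (width, height, channels, batch). Dimension 3 is reduced to size 1.

function reduce_channels(expl::AbstractArray{<:Real,4}, method::Symbol)
    size(expl, 3) == 1 && return expl           # single channel: nothing to reduce
    if method == :sum
        return sum(expl; dims=3)                # sum over channels
    elseif method == :norm
        return sqrt.(sum(abs2, expl; dims=3))   # 2-norm over channels
    elseif method == :maxabs
        return maximum(abs, expl; dims=3)       # largest absolute value
    else
        throw(ArgumentError("unknown reduction method: $method"))
    end
end

# Hypothetical 2×2 RGB explanation with a batch size of 1
toy_expl = zeros(Float32, 2, 2, 3, 1)
toy_expl[1, 1, :, 1] = [3, -4, 0]

for m in (:sum, :norm, :maxabs)
    println(m, " => ", reduce_channels(toy_expl, m)[1, 1, 1, 1])
end
```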

Mapping explanations onto the color scheme

To map a color-channel-reduced explanation onto a color scheme, we first need to normalize all values to the range $[0, 1]$.

For this purpose, two presets are available through the rangescale keyword argument:

  • :extrema: normalize to the minimum and maximum value of the explanation
  • :centered: normalize to the maximum absolute value of the explanation. Values of zero will be mapped to the center of the color scheme.

Depending on the color scheme, one of these presets may be more suitable than the other. The default color scheme for InputTimesGradient, seismic, is centered around zero, making :centered a good choice:

heatmap(expl; rangescale=:centered)
heatmap(expl; rangescale=:extrema)

However, for the inferno color scheme, which is not centered around zero, :extrema leads to a heatmap with higher contrast.

heatmap(expl; rangescale=:centered, colorscheme=:inferno)
heatmap(expl; rangescale=:extrema, colorscheme=:inferno)
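The two rangescale presets can be sketched as simple normalizations of a channel-reduced explanation. The helper names below are hypothetical and the code illustrates the idea rather than VisionHeatmaps.jl's implementation. It also shows why :extrema gives more contrast on a non-diverging scheme: it always uses the full range [0, 1], whereas :centered can leave part of that range unused.

```julia
# Sketch of the two normalization presets, mapping values onto [0, 1].

function normalize_extrema(x::AbstractArray)
    lo, hi = extrema(x)                  # assumes x is not constant
    return (x .- lo) ./ (hi - lo)        # min ↦ 0, max ↦ 1
end

function normalize_centered(x::AbstractArray)
    m = maximum(abs, x)                  # assumes x is not all zeros
    return (x ./ m .+ 1) ./ 2            # -m ↦ 0, 0 ↦ 0.5, m ↦ 1
end

vals = [-1.0, 0.0, 3.0]

println(normalize_extrema(vals))    # zero is not mapped to the center
println(normalize_centered(vals))   # zero is mapped to 0.5; [0, 1/3) stays unused
```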

For the full list of heatmap keyword arguments, refer to the heatmap documentation.

Heatmap overlays

Heatmaps can be overlaid onto the input image using the heatmap_overlay function from VisionHeatmaps.jl. This can be useful for visualizing the relevance of specific regions of the input:

heatmap_overlay(expl, img)

The alpha value of the heatmap can be adjusted using the alpha keyword argument:

heatmap_overlay(expl, img; alpha=0.3)
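As a rough mental model, overlaying can be thought of as standard alpha compositing, where each output pixel is alpha * heatmap + (1 - alpha) * image. Whether heatmap_overlay uses exactly this formula is an assumption; the sketch below only shows how alpha trades off the two layers, using hypothetical grayscale pixel values.

```julia
# Sketch of alpha compositing: higher alpha makes the heatmap more opaque.
blend(heat, img, alpha) = alpha .* heat .+ (1 - alpha) .* img

img_px  = [0.0, 1.0, 0.5]   # hypothetical grayscale image pixels
heat_px = [1.0, 0.0, 0.5]   # hypothetical heatmap intensities

println(blend(heat_px, img_px, 0.3))   # mostly the image
println(blend(heat_px, img_px, 0.7))   # mostly the heatmap
```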

All previously discussed keyword arguments for heatmap can also be used with heatmap_overlay:

heatmap_overlay(expl, img; alpha=0.7, colorscheme=:inferno, rangescale=:extrema)

Heatmapping batches

Heatmapping also works with input batches. Let's demonstrate this by using a batch of 25 images from the MNIST dataset:

xs, ys = MNIST(Float32, :test)[1:25]
batch = reshape(xs, 28, 28, 1, :); # reshape to WHCN format
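To make the WHCN layout (width, height, channels, batch) concrete, the sketch below uses random data as a stand-in for the 28×28×25 feature array and inserts a singleton channel dimension, just like the reshape above.

```julia
# Random data standing in for 25 grayscale images of size 28×28
xs_demo = rand(Float32, 28, 28, 25)
batch_demo = reshape(xs_demo, 28, 28, 1, :)   # add a singleton channel dimension

println(size(batch_demo))                     # (28, 28, 1, 25)

# The i-th sample is a slice along the last (batch) dimension
sample3 = batch_demo[:, :, :, 3]
println(size(sample3))                        # (28, 28, 1)
```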

The heatmap function automatically recognizes that the explanation is batched and returns a Vector of images:

heatmaps = heatmap(batch, analyzer)
(a vector displayed as a row to save space)

Images.jl's mosaic function can be used to display them in a grid:

mosaic(heatmaps; nrow=5)

When heatmapping batches, the mapping to the color scheme is applied per sample. For example, rangescale=:extrema will normalize each heatmap to the minimum and maximum value of each sample in the batch. This ensures that heatmaps don't depend on other samples in the batch.

If this behavior is not desired, heatmap can be called with the keyword argument process_batch=true, which applies the mapping to the color scheme across the whole batch:

expl = analyze(batch, analyzer)
heatmaps = heatmap(expl; process_batch=true)
mosaic(heatmaps; nrow=5)
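The difference between the two modes can be sketched with the :extrema rule on a tiny hypothetical batch of two samples. Per-sample normalization stretches each sample to the full range, while normalizing over the whole batch keeps relative magnitudes comparable across samples. The `minmax_scale` helper is hypothetical and this is not the library's code.

```julia
# Min-max scaling to [0, 1] (assumes the input is not constant)
minmax_scale(x) = (x .- minimum(x)) ./ (maximum(x) - minimum(x))

# Hypothetical batch of two 2×1 single-channel explanations (WHCN layout)
expl_batch = zeros(2, 1, 1, 2)
expl_batch[:, 1, 1, 1] = [0.0, 1.0]    # sample 1: small values
expl_batch[:, 1, 1, 2] = [0.0, 10.0]   # sample 2: large values

# Per sample: each slice along the batch dimension is scaled on its own
per_sample = mapslices(minmax_scale, expl_batch; dims=(1, 2, 3))

# Whole batch: one shared minimum and maximum for all samples
whole_batch = minmax_scale(expl_batch)

println(vec(per_sample))    # both samples span the full [0, 1] range
println(vec(whole_batch))   # sample 1 looks faint next to sample 2
```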

This can be useful when comparing heatmaps for fixed output neurons:

expl = analyze(batch, analyzer, 7) # explain digit "6"
heatmaps = heatmap(expl; process_batch=true)
mosaic(heatmaps; nrow=5)
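The output index 7 used for digit "6" follows from Julia's 1-based indexing: for MNIST's ten classes (digits 0 to 9), output neuron i corresponds to digit i - 1. The small helpers and the score vector below are hypothetical.

```julia
# Mapping between MNIST digits (0–9) and 1-based output neuron indices
digit_to_neuron(d::Integer) = d + 1
neuron_to_digit(i::Integer) = i - 1

println(digit_to_neuron(6))   # 7, the index passed to analyze above

# Reading a prediction off 10 hypothetical output scores
scores = [0.1, 0.0, 0.2, 0.0, 0.0, 0.1, 2.5, 0.0, 0.3, 0.0]
println(neuron_to_digit(argmax(scores)))   # 6
```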

Output type consistency

To obtain a singleton Vector containing a single heatmap for non-batched inputs, use the heatmap keyword argument unpack_singleton=false.
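The convention behind unpack_singleton can be illustrated with a generic pattern: a function producing one result per sample returns the bare element when there is only one, unless asked to always return a Vector. The function `heatmaps_for` below is a made-up stand-in, not part of VisionHeatmaps.jl.

```julia
# Hypothetical stand-in illustrating the "unpack singleton" convention
function heatmaps_for(nsamples::Integer; unpack_singleton::Bool=true)
    results = ["heatmap $i" for i in 1:nsamples]   # placeholder "heatmaps"
    if unpack_singleton && length(results) == 1
        return only(results)                       # bare element
    end
    return results                                 # a Vector otherwise
end

println(heatmaps_for(1))                           # prints: heatmap 1
println(heatmaps_for(1; unpack_singleton=false))   # prints: ["heatmap 1"]
println(length(heatmaps_for(3)))                   # prints: 3
```

Always requesting a Vector lets downstream code handle single inputs and batches uniformly.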

Processing heatmaps

Heatmapping makes use of the Julia-based image processing ecosystem Images.jl.

If you want to further process heatmaps, you may benefit from reading about some fundamental conventions of this ecosystem that differ from how images are typically represented in OpenCV, MATLAB, ImageJ or Python.

Saving heatmaps

Since heatmaps are regular Images.jl images, they can be saved as such:

using FileIO

img = heatmap(input, analyzer)
save("heatmap.png", img)

This page was generated using Literate.jl.