Getting started

Note

This package is part of a wider Julia XAI ecosystem. For an introduction to this ecosystem, please refer to the Getting started guide.

For this first example, we have already loaded a pre-trained LeNet5 model to look at explanations on the MNIST dataset.

using Flux

model
Chain(
  Conv((5, 5), 1 => 6, relu),           # 156 parameters
  MaxPool((2, 2)),
  Conv((5, 5), 6 => 16, relu),          # 2_416 parameters
  MaxPool((2, 2)),
  Flux.flatten,
  Dense(256 => 120, relu),              # 30_840 parameters
  Dense(120 => 84, relu),               # 10_164 parameters
  Dense(84 => 10),                      # 850 parameters
)                   # Total: 10 arrays, 44_426 parameters, 174.344 KiB.
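
For reference, a LeNet-5-style architecture matching the summary above could be defined in Flux as sketched below. This is only an illustrative, untrained sketch; the guide assumes that pre-trained weights have already been loaded into model.

# Illustrative sketch of a LeNet-5-style architecture matching the summary above.
# Note: defined like this, the model is untrained (random weights).
model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    Flux.flatten,
    Dense(256 => 120, relu),
    Dense(120 => 84, relu),
    Dense(84 => 10),
)

An untrained model defined this way still runs all of the following examples, but its heatmaps will not be meaningful.
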
Supported models

ExplainableAI.jl can be used on any differentiable classifier.
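
For example, a plain multi-layer perceptron can be analyzed in exactly the same way as the convolutional model above. The following is a purely hypothetical sketch (the layer sizes and input are made up, and it uses the analyze and InputTimesGradient API introduced below):

using Flux, ExplainableAI, Zygote

# Hypothetical example: a small MLP classifier on 100-dimensional inputs
# with 5 output classes; any differentiable Flux model works the same way.
mlp    = Chain(Dense(100 => 50, relu), Dense(50 => 5))
x_demo = rand(Float32, 100, 1)           # batch dimension last
analyze(x_demo, InputTimesGradient(mlp)) # returns an Explanation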

Preparing the input data

We use MLDatasets to load a single image from the MNIST dataset:

using MLDatasets
using ImageCore, ImageIO, ImageShow

index = 10
x, y = MNIST(Float32, :test)[index]

convert2image(MNIST, x)

By convention in Flux.jl, this input needs to be reshaped to WHCN format by adding a color channel and a batch dimension.

input = reshape(x, 28, 28, 1, :);
Input format

For any explanation of a model, ExplainableAI.jl assumes the batch dimension to come last in the input.

For the purpose of heatmapping, the input is assumed to be in WHCN order (width, height, channels, batch), which is Flux.jl's convention.
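
The same rule applies to non-image data: features come first and the batch dimension comes last. As a purely hypothetical illustration (this tabular input is not part of the MNIST example):

# Hypothetical tabular input: 8 features per sample, batch of 32 samples,
# with the batch dimension last.
tabular_input = rand(Float32, 8, 32);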

Explanations

We can now select an analyzer of our choice and call analyze to get an Explanation. Note that for gradient-based analyzers, a backend for automatic differentiation must be loaded; by default, this is Zygote.jl:

using ExplainableAI
using Zygote

analyzer = InputTimesGradient(model)
expl = analyze(input, analyzer);

The return value expl is of type Explanation and bundles the following data:

  • expl.val: numerical output of the analyzer, e.g. an attribution or gradient
  • expl.output: model output for the given analyzer input
  • expl.output_selection: index of the output used for the explanation
  • expl.analyzer: symbol corresponding to the analyzer used, e.g. :Gradient or :LRP
  • expl.heatmap: symbol indicating a preset heatmapping style, e.g. :attribution, :sensitivity or :cam
  • expl.extras: optional named tuple that can be used by analyzers to return additional information
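
These fields can be accessed directly. As a small illustrative check (the exact values depend on the analyzer and input):

# The attribution has the same WHCN shape as the input, and output_selection
# records which output neuron was explained.
size(expl.val)
expl.output_selection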

We used InputTimesGradient, so expl.analyzer is :InputTimesGradient.

expl.analyzer
:InputTimesGradient

By default, the explanation is computed for the maximally activated output neuron. Since our digit is a 9 and Julia's indexing is 1-based, the output neuron at index 10 of our trained model is maximally activated.
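
We can confirm this by inspecting the model output stored in the explanation; the sketch below simply uses Julia's built-in argmax:

# Index of the maximally activated output neuron; for this digit it is 10.
argmax(vec(expl.output))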

Finally, we obtain the result of the analyzer in the form of an array.

expl.val
28×28×1×1 Array{Float32, 4}:
[:, :, 1, 1] =
 -0.0   0.0   0.0  -0.0  -0.0  -0.0  -0.0  …  -0.0       0.0         0.0
 -0.0   0.0   0.0  -0.0  -0.0  -0.0  -0.0     -0.0       0.0         0.0
 -0.0   0.0   0.0  -0.0  -0.0  -0.0  -0.0     -0.0       0.0         0.0
  0.0   0.0   0.0  -0.0  -0.0  -0.0  -0.0      0.0       0.0         0.0
  0.0   0.0  -0.0  -0.0  -0.0   0.0   0.0      0.0       0.0         0.0
 -0.0   0.0   0.0  -0.0   0.0  -0.0  -0.0  …   0.0       0.0         0.0
 -0.0   0.0   0.0   0.0   0.0   0.0   0.0      0.0       0.0         0.0
 -0.0   0.0   0.0   0.0  -0.0   0.0  -0.0      0.0       0.0        -0.0
 -0.0   0.0   0.0   0.0  -0.0  -0.0  -0.0      0.0       0.0        -0.0
 -0.0   0.0   0.0   0.0  -0.0   0.0  -0.0      0.385115  0.0714216   0.0
  ⋮                             ⋮          ⋱   ⋮                    
  0.0  -0.0  -0.0   0.0   0.0   0.0  -0.0     -0.0       0.0         0.0
 -0.0  -0.0   0.0   0.0   0.0   0.0  -0.0  …  -0.0       0.0         0.0
 -0.0   0.0   0.0   0.0   0.0   0.0  -0.0      0.0       0.0         0.0
 -0.0   0.0   0.0   0.0  -0.0   0.0  -0.0      0.0       0.0         0.0
 -0.0   0.0   0.0  -0.0  -0.0   0.0  -0.0      0.0       0.0         0.0
  0.0   0.0  -0.0  -0.0  -0.0   0.0  -0.0      0.0       0.0         0.0
  0.0   0.0   0.0  -0.0  -0.0   0.0   0.0  …   0.0       0.0         0.0
  0.0   0.0  -0.0  -0.0   0.0   0.0   0.0      0.0       0.0         0.0
  0.0   0.0   0.0   0.0   0.0   0.0  -0.0      0.0       0.0         0.0

Heatmapping basics

Since the array expl.val is not very informative at first sight, we can visualize Explanations by computing a heatmap using either VisionHeatmaps.jl or TextHeatmaps.jl.

using VisionHeatmaps

heatmap(expl)

If we are only interested in the heatmap, we can combine analysis and heatmapping into a single function call:

heatmap(input, analyzer)

For a more detailed explanation of the heatmap function, refer to the heatmapping section.

List of analyzers

Neuron selection

By passing an additional index to our call to analyze, we can compute an explanation with respect to a specific output neuron. Let's see why the output wasn't interpreted as a 4 (output neuron at index 5):

expl = analyze(input, analyzer, 5)
heatmap(expl)

This heatmap shows us that the "upper loop" of the hand-drawn 9 has negative relevance with respect to the output neuron corresponding to digit 4!

Note

The output neuron can also be specified when calling heatmap:

heatmap(input, analyzer, 5)

Analyzing batches

ExplainableAI.jl also supports explanations of input batches:

batchsize = 20
xs, _ = MNIST(Float32, :test)[1:batchsize]
batch = reshape(xs, 28, 28, 1, :) # reshape to WHCN format
expl = analyze(batch, analyzer);

This will return a single Explanation expl for the entire batch. Calling heatmap on expl will detect the batch dimension and return a vector of heatmaps.

heatmap(expl)
(a vector displayed as a row to save space)
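
The returned vector of heatmaps can be indexed and iterated like any other Julia vector. For example, each heatmap could be written to a PNG file, assuming FileIO is available (ImageIO, loaded above, provides the PNG backend):

using FileIO

heatmaps = heatmap(expl)
for (i, h) in enumerate(heatmaps)
    save("heatmap_$i.png", h)   # write each heatmap in the batch to disk
end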

For more information on heatmapping batches, refer to the heatmapping documentation.


This page was generated using Literate.jl.