Getting started
This package is part of a wider Julia XAI ecosystem. For an introduction to this ecosystem, please refer to the Getting started guide.
For this first example, we have already loaded a pre-trained LeNet5 model and will look at explanations on the MNIST dataset.
using Flux
Chain(
  Conv((5, 5), 1 => 6, relu),  # 156 parameters
  MaxPool((2, 2)),
  Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
  MaxPool((2, 2)),
  Flux.flatten,
  Dense(256 => 120, relu),  # 30_840 parameters
  Dense(120 => 84, relu),  # 10_164 parameters
  Dense(84 => 10),  # 850 parameters
)  # Total: 10 arrays, 44_426 parameters, 174.344 KiB.
ExplainableAI.jl can be used on any differentiable classifier.
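The parameter counts in the model summary can be checked by hand. The following is a minimal sketch in plain Julia; `conv_params` and `dense_params` are hypothetical helper names introduced here, and the feature-map arithmetic assumes unpadded, stride-1 convolutions.

```julia
# Parameter counts for the layers listed above (weights plus one bias per output).
conv_params(k, cin, cout) = k * k * cin * cout + cout
dense_params(nin, nout) = nin * nout + nout

counts = [
    conv_params(5, 1, 6),    # Conv((5, 5), 1 => 6)
    conv_params(5, 6, 16),   # Conv((5, 5), 6 => 16)
    dense_params(256, 120),  # Dense(256 => 120)
    dense_params(120, 84),   # Dense(120 => 84)
    dense_params(84, 10),    # Dense(84 => 10)
]
total = sum(counts)  # 44426

# Side length of the feature maps for a 28×28 input (assumption: no padding, stride 1).
side = 28
side = side - 4   # first 5×5 convolution: 24
side = side ÷ 2   # first 2×2 max-pooling: 12
side = side - 4   # second 5×5 convolution: 8
side = side ÷ 2   # second 2×2 max-pooling: 4

# Number of features entering the first Dense layer: 4 * 4 * 16 channels.
flat = side * side * 16  # 256
```

The value of `flat` matches the input size of `Dense(256 => 120)`, which is why the model accepts 28×28 single-channel images.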
Preparing the input data
We use MLDatasets to load a single image from the MNIST dataset:
using MLDatasets
using ImageCore, ImageIO, ImageShow
index = 10
x, y = MNIST(Float32, :test)[index]
convert2image(MNIST, x)
By convention in Flux.jl, this input needs to be reshaped to WHCN format by adding color-channel and batch dimensions.
input = reshape(x, 28, 28, 1, :);
For any explanation of a model, ExplainableAI.jl assumes the batch dimension to come last in the input.
For the purpose of heatmapping, the input is assumed to be in WHCN order (width, height, channels, batch), which is Flux.jl's convention.
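As a small illustration of the reshaping step, the sketch below uses a random array as a stand-in for an MNIST image, so it runs without downloading any data. The variable names mirror the ones used above.

```julia
# Stand-in for a single 28×28 grayscale image (random values, not MNIST data).
x = rand(Float32, 28, 28)

# Append a channel dimension of size 1; `:` lets Julia infer the batch size.
input = reshape(x, 28, 28, 1, :)

size(input)  # (28, 28, 1, 1)

# Reshaping only changes the shape; the pixel values and their order are unchanged.
input[:, :, 1, 1] == x  # true
```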
We can now select an analyzer of our choice and call analyze to get an Explanation. Note that for gradient-based analyzers, a backend for automatic differentiation must be loaded, by default Zygote.jl:
using ExplainableAI
using Zygote
analyzer = InputTimesGradient(model)
expl = analyze(input, analyzer);
The return value expl is of type Explanation and bundles the following data:
expl.val: numerical output of the analyzer, e.g. an attribution or gradient
expl.output: model output for the given analyzer input
expl.output_selection: index of the output used for the explanation
expl.analyzer: symbol corresponding to the used analyzer, e.g. :Gradient
expl.heatmap: symbol indicating a preset heatmapping style, e.g. :attribution
expl.extras: optional named tuple that can be used by analyzers to return additional information.
We used InputTimesGradient, so expl.analyzer is :InputTimesGradient.
By default, the explanation is computed for the maximally activated output neuron. Since our digit is a 9 and Julia's indexing is 1-based, the output neuron at index 10 of our trained model is maximally activated.
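The mapping between output neurons and digits can be sketched in plain Julia. The scores below are made up for illustration and are not the output of the pre-trained model.

```julia
# Hypothetical model output for one sample: ten scores, one per digit 0 to 9.
output = Float32[-3.1, -2.4, 0.2, 1.0, 2.7, -0.5, -4.0, 1.9, 0.8, 6.3]

neuron = argmax(output)  # index of the maximally activated output neuron: 10
digit = neuron - 1       # digits start at 0, Julia indices start at 1: 9
```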
Finally, we obtain the result of the analyzer in the form of an array, expl.val:
28×28×1×1 Array{Float32, 4}:
[:, :, 1, 1] =
-0.0 0.0 0.0 -0.0 -0.0 -0.0 -0.0 … -0.0 0.0 0.0
-0.0 0.0 0.0 -0.0 -0.0 -0.0 -0.0 -0.0 0.0 0.0
-0.0 0.0 0.0 -0.0 -0.0 -0.0 -0.0 -0.0 0.0 0.0
0.0 0.0 0.0 -0.0 -0.0 -0.0 -0.0 0.0 0.0 0.0
0.0 0.0 -0.0 -0.0 -0.0 0.0 0.0 0.0 0.0 0.0
-0.0 0.0 0.0 -0.0 0.0 -0.0 -0.0 … 0.0 0.0 0.0
-0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
-0.0 0.0 0.0 0.0 -0.0 0.0 -0.0 0.0 0.0 -0.0
-0.0 0.0 0.0 0.0 -0.0 -0.0 -0.0 0.0 0.0 -0.0
-0.0 0.0 0.0 0.0 -0.0 0.0 -0.0 0.385115 0.0714216 0.0
⋮ ⋮ ⋱ ⋮
0.0 -0.0 -0.0 0.0 0.0 0.0 -0.0 -0.0 0.0 0.0
-0.0 -0.0 0.0 0.0 0.0 0.0 -0.0 … -0.0 0.0 0.0
-0.0 0.0 0.0 0.0 0.0 0.0 -0.0 0.0 0.0 0.0
-0.0 0.0 0.0 0.0 -0.0 0.0 -0.0 0.0 0.0 0.0
-0.0 0.0 0.0 -0.0 -0.0 0.0 -0.0 0.0 0.0 0.0
0.0 0.0 -0.0 -0.0 -0.0 0.0 -0.0 0.0 0.0 0.0
0.0 0.0 0.0 -0.0 -0.0 0.0 0.0 … 0.0 0.0 0.0
0.0 0.0 -0.0 -0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 -0.0 0.0 0.0 0.0
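To give some intuition for these numbers, here is a hand-rolled sketch of an input-times-gradient attribution on a toy linear model. It is not the library's implementation: the weights, the input, and the function `f` are all hypothetical, and the gradient is written down analytically instead of being computed by an automatic differentiation backend.

```julia
using LinearAlgebra: dot

# Toy linear "model" with a single output and no bias (hypothetical weights).
w = [0.5, -2.0, 0.0, 1.0]
f(x) = dot(w, x)

x = [1.0, 0.0, 3.0, 0.5]

gradient = w                 # for a linear model, ∂f/∂x is the weight vector
attribution = x .* gradient  # element-wise input × gradient

# A zero input times a negative gradient yields a negative zero.
signbit(attribution[2])  # true

# For a bias-free linear model the attributions sum to the model output.
sum(attribution) ≈ f(x)  # true
```

The second entry suggests one possible reason why the array above contains both 0.0 and -0.0: background pixels with value zero keep the sign of the gradient they are multiplied with.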
Heatmapping basics
Since the array expl.val is not very informative at first sight, we can visualize Explanations by computing a heatmap using either VisionHeatmaps.jl or TextHeatmaps.jl.
using VisionHeatmaps
If we are only interested in the heatmap, we can combine analysis and heatmapping into a single function call:
heatmap(input, analyzer)
Neuron selection
By passing an additional index to our call to analyze, we can compute an explanation with respect to a specific output neuron. Let's see why the output wasn't interpreted as a 4 (output neuron at index 5):
expl = analyze(input, analyzer, 5)
This heatmap shows us that the "upper loop" of the hand-drawn 9 has negative relevance with respect to the output neuron corresponding to digit 4!
The output neuron can also be specified when calling heatmap:
heatmap(input, analyzer, 5)
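The effect of selecting a different output neuron can be illustrated with a toy linear classifier. This is only a sketch under stated assumptions: the weight matrix `W`, the input `x`, and the helper `attribution` are hypothetical and do not come from ExplainableAI.jl.

```julia
# Toy linear classifier: one row of hypothetical weights per output neuron.
W = [1.0 2.0 0.0
     -1.0 0.5 3.0]
x = [2.0, 1.0, 0.0]

# Input × gradient with respect to output neuron k; the gradient is row k of W.
attribution(k) = x .* W[k, :]

scores = W * x               # [4.0, -1.5]
selected = argmax(scores)    # neuron explained when no index is passed: 1

attribution(selected)  # [2.0, 2.0, 0.0]
attribution(2)         # [-2.0, 0.5, 0.0]
```

The first input feature supports neuron 1 but has negative relevance for neuron 2, which is the same kind of effect described for the digits 9 and 4 above.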
Analyzing batches
ExplainableAI also supports explanations of input batches:
batchsize = 20
xs, _ = MNIST(Float32, :test)[1:batchsize]
batch = reshape(xs, 28, 28, 1, :) # reshape to WHCN format
expl = analyze(batch, analyzer);
This will return a single Explanation for the entire batch. Calling heatmap on expl will detect the batch dimension and return a vector of heatmaps.
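The shape handling for batches can be sketched without any packages. The random array below is a stand-in for the MNIST images, and splitting along the last dimension is only an analogy for how one result per sample can be obtained; it is not a description of the library internals.

```julia
batchsize = 20

# Stand-in for 20 grayscale images stacked along the third dimension.
xs = rand(Float32, 28, 28, batchsize)

# Insert a channel dimension of size 1; the batch dimension stays last.
batch = reshape(xs, 28, 28, 1, :)
size(batch)  # (28, 28, 1, 20)

# One 28×28×1 slice per sample, taken along the batch dimension.
samples = collect(eachslice(batch; dims=4))
length(samples)   # 20
size(samples[1])  # (28, 28, 1)
```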
Custom heatmaps
The function heatmap automatically applies common presets for each method.
Since InputTimesGradient computes attributions, heatmaps are shown in a blue-white-red color scheme. Gradient methods, however, are typically shown in grayscale:
analyzer = Gradient(model)
heatmap(input, analyzer)
analyzer = InputTimesGradient(model)
heatmap(input, analyzer)
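One way to see why signed attributions suit a diverging color scheme while gradient magnitudes suit grayscale is to compare two ways of scaling values to the interval [0, 1] before a color lookup. The sketch below is an assumption-laden illustration in plain Julia: `scale_centered` and `scale_extrema` are hypothetical helpers and are not taken from VisionHeatmaps.jl.

```julia
# Centered scaling: symmetric around zero, so 0 maps to the middle of the colormap.
# Assumes `a` contains at least one nonzero value.
function scale_centered(a)
    m = maximum(abs, a)
    return (a .+ m) ./ (2m)
end

# Extrema scaling: stretch the observed minimum and maximum to 0 and 1.
# Assumes `a` is not constant.
function scale_extrema(a)
    lo, hi = extrema(a)
    return (a .- lo) ./ (hi - lo)
end

vals = [-1.0, 0.0, 0.5, 2.0]

scale_centered(vals)       # [0.25, 0.5, 0.625, 1.0]
scale_extrema(abs.(vals))  # [0.5, 0.0, 0.25, 1.0]
```

With centered scaling, zero relevance always lands on the neutral middle color (white in a blue-white-red scheme), and the sign decides the side. With extrema scaling of magnitudes, only the strength of the response is visible, which fits a grayscale ramp.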
Using VisionHeatmaps.jl, heatmaps can be heavily customized. Check out the heatmapping documentation for more information.
This page was generated using Literate.jl.