Example Implementations

The following examples demonstrate how to implement XAI methods using the XAIBase.jl interface. To test our methods, we load a small, pre-trained LeNet-5 model and the MNIST dataset:

using Flux
using BSON

model = BSON.load("model.bson", @__MODULE__)[:model] # load pre-trained LeNet-5 model
Chain(
  Conv((5, 5), 1 => 6, relu),           # 156 parameters
  MaxPool((2, 2)),
  Conv((5, 5), 6 => 16, relu),          # 2_416 parameters
  MaxPool((2, 2)),
  Flux.flatten,
  Dense(256 => 120, relu),              # 30_840 parameters
  Dense(120 => 84, relu),               # 10_164 parameters
  Dense(84 => 10),                      # 850 parameters
)                   # Total: 10 arrays, 44_426 parameters, 174.344 KiB.
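
If no pre-trained model.bson is at hand, an untrained model of the same architecture can be constructed directly. The examples below still run with it, although the explanations of an untrained model are naturally less meaningful:

using Flux

model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    Flux.flatten,
    Dense(256 => 120, relu),
    Dense(120 => 84, relu),
    Dense(84 => 10),
)
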
using MLDatasets
using ImageCore, ImageIO, ImageShow

index = 10
x, y = MNIST(Float32, :test)[index]

# By convention in Flux.jl, the input needs to be reshaped to WHCN format
# (width, height, channels, batch) by adding a color channel and a batch dimension.
input = reshape(x, 28, 28, 1, :);
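
As a quick sanity check, the reshaped input should have a singleton batch dimension, and the model should return one score per MNIST class:

size(input)        # (28, 28, 1, 1)
size(model(input)) # (10, 1)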

convert2image(MNIST, x)

Example 1: Random explanation

To get started, we implement a nonsensical method that returns a random explanation of the same shape as the input.

using XAIBase
import XAIBase: call_analyzer

struct RandomAnalyzer{M} <: AbstractXAIMethod
    model::M
end

function call_analyzer(input, method::RandomAnalyzer, output_selector::AbstractOutputSelector; kwargs...)
    output = method.model(input)
    output_selection = output_selector(output)

    # Random values of the same shape as the input serve as the "explanation"
    val = rand(size(input)...)

    # Explanation bundles the explanation values with the input, the model output,
    # the output selection, the analyzer name, a heatmapping preset, and optional extras
    return Explanation(val, input, output, output_selection, :RandomAnalyzer, :sensitivity, nothing)
end
call_analyzer (generic function with 2 methods)

We can directly use XAIBase's analyze function to compute the random explanation:

analyzer = RandomAnalyzer(model)
expl = analyze(input, analyzer)
Explanation{Array{Float64, 4}, Array{Float32, 4}, Matrix{Float32}, Vector{CartesianIndex{2}}, Nothing}([0.8124096215489804 0.9240602770853014 … 0.8460797334794672 0.7803789711154144; 0.32901242577506506 0.8953661341270125 … 0.5407905370060289 0.103183521478598; … ; 0.10689741255267082 0.5092533259670231 … 0.11100823845599916 0.39561818845831576; 0.858045199202526 0.10166396457626381 … 0.1547935930348482 0.12213735226980682;;;;], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;;], Float32[-17.578014; -14.809775; … ; -0.22507292; 21.765005;;], CartesianIndex{2}[CartesianIndex(10, 1)], :RandomAnalyzer, :sensitivity, nothing)

Using either VisionHeatmaps.jl or TextHeatmaps.jl, which extend XAIBase's Explanation type via package extensions, we can visualize the explanation:

using VisionHeatmaps # load heatmapping functionality

heatmap(expl.val)
(heatmap image of the random explanation)

As expected, the explanation is just noise.
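
Since the heatmapping packages extend XAIBase's Explanation type, the Explanation can typically also be passed to heatmap directly, in which case the metadata stored in the explanation is used:

heatmap(expl) # uses the metadata stored in the Explanation, e.g. the :sensitivity preset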

Example 2: Input sensitivity

In this second example, we naively reimplement the Gradient analyzer from ExplainableAI.jl.

using XAIBase
import XAIBase: call_analyzer

using Zygote: gradient

struct MyGradient{M} <: AbstractXAIMethod
    model::M
end

function call_analyzer(input, method::MyGradient, output_selector::AbstractOutputSelector; kwargs...)
    output = method.model(input)
    output_selection = output_selector(output)

    # Gradient of the selected output neuron with respect to the input
    grad = gradient((x) -> only(method.model(x)[output_selection]), input)
    val = only(grad)
    return Explanation(val, input, output, output_selection, :MyGradient, :sensitivity, nothing)
end
call_analyzer (generic function with 3 methods)
Note

ExplainableAI.jl implements the Gradient analyzer in a more efficient way that works with batched inputs and only requires a single forward and backward pass through the model.
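
A minimal sketch of this idea, assuming Zygote's pullback API: a single forward pass returns both the output and a pullback, which is then seeded with a one-hot mask over the selected output neurons to obtain the gradient. The name MyEfficientGradient is made up for this illustration; this is not ExplainableAI.jl's actual implementation.

using Zygote: pullback

struct MyEfficientGradient{M} <: AbstractXAIMethod
    model::M
end

function call_analyzer(input, method::MyEfficientGradient, output_selector::AbstractOutputSelector; kwargs...)
    # Single forward pass that also returns the pullback for the backward pass
    output, back = pullback(method.model, input)
    output_selection = output_selector(output)

    # Seed the backward pass with a one-hot mask over the selected output neurons
    seed = zero(output)
    seed[output_selection] .= 1
    val = only(back(seed)) # gradient with respect to the input
    return Explanation(val, input, output, output_selection, :MyEfficientGradient, :sensitivity, nothing)
end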

Once again, we can directly use XAIBase's analyze and VisionHeatmaps' heatmap functions

using VisionHeatmaps

analyzer = MyGradient(model)
expl = analyze(input, analyzer)
heatmap(expl.val)
(heatmap image of the gradient explanation)

and make use of all the features provided by the Julia-XAI ecosystem.
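
Both analyzers can also explain a specific output class instead of the maximally activated one. Assuming analyze accepts an output index as a third argument (as documented for the analyzers in ExplainableAI.jl), this looks as follows:

expl5 = analyze(input, analyzer, 5) # explain output neuron 5 instead of the maximum activation
heatmap(expl5.val)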

Note

For an introduction to the Julia-XAI ecosystem, please refer to the Getting started guide.