Example Implementations
The following examples demonstrate how to implement XAI methods using the XAIBase.jl interface. To test our methods, we load a small, pre-trained LeNet-5 model and the MNIST dataset:
using Flux
using BSON
model = BSON.load("model.bson", @__MODULE__)[:model] # load pre-trained LeNet-5 model
Chain(
  Conv((5, 5), 1 => 6, relu),   # 156 parameters
  MaxPool((2, 2)),
  Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
  MaxPool((2, 2)),
  Flux.flatten,
  Dense(256 => 120, relu),      # 30_840 parameters
  Dense(120 => 84, relu),       # 10_164 parameters
  Dense(84 => 10),              # 850 parameters
)                   # Total: 10 arrays, 44_426 parameters, 174.344 KiB.
using MLDatasets
using ImageCore, ImageIO, ImageShow
index = 10
x, y = MNIST(Float32, :test)[index]
# By convention in Flux.jl, the input needs to be reshaped to WHCN format
# by adding color-channel and batch dimensions.
input = reshape(x, 28, 28, 1, :);
convert2image(MNIST, x)
Example 1: Random explanation
To get started, we implement a nonsensical method that returns a random explanation in the shape of the input.
using XAIBase
import XAIBase: call_analyzer
struct RandomAnalyzer{M} <: AbstractXAIMethod
    model::M
end

function call_analyzer(input, method::RandomAnalyzer, output_selector::AbstractOutputSelector; kwargs...)
    output = method.model(input)               # forward pass through the wrapped model
    output_selection = output_selector(output) # e.g. index of the maximally activated output neuron
    val = rand(size(input)...)                 # random "explanation" in the shape of the input
    return Explanation(val, input, output, output_selection, :RandomAnalyzer, :sensitivity, nothing)
end
call_analyzer (generic function with 2 methods)
We can directly use XAIBase's analyze function to compute the random explanation:
analyzer = RandomAnalyzer(model)
expl = analyze(input, analyzer)
Explanation{Array{Float64, 4}, Array{Float32, 4}, Matrix{Float32}, Vector{CartesianIndex{2}}, Nothing}([0.7405504332868575 0.29448073762762217 … 0.660789356015917 0.09218085145340982; 0.15038993655822874 0.9565562574303522 … 0.6848470506806823 0.3639803886921913; … ; 0.44837868838800066 0.6360935119343811 … 0.7248519536246463 0.8219807542413412; 0.30096254672314016 0.21385818293632386 … 0.9954215837740947 0.8429159801652091;;;;], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;;], Float32[-17.578014; -14.809775; … ; -0.22507292; 21.765005;;], CartesianIndex{2}[CartesianIndex(10, 1)], :RandomAnalyzer, :sensitivity, nothing)
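The fields of the returned Explanation mirror the positional arguments of the constructor above and can be accessed directly. As a small usage sketch, assuming analyze also accepts an output index as a third argument (as it does in ExplainableAI.jl), a specific output neuron can be explained instead of the maximally activated one:
expl.val              # the random explanation, same shape as the input
expl.output           # the model output
expl.output_selection # index of the explained output neuron

expl2 = analyze(input, analyzer, 2) # assumed signature: explain output neuron 2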
Using either VisionHeatmaps.jl or TextHeatmaps.jl, which provide package extensions on XAIBase's Explanation type, we can visualize the explanations:
using VisionHeatmaps # load heatmapping functionality
heatmap(expl.val)
As expected, the explanation is just noise.
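Since the heatmapping packages extend their functions for XAIBase's Explanation type, heatmap should also accept the explanation object directly, picking defaults based on the method's heatmapping preset; a minimal sketch under that assumption:
heatmap(expl) # assumed extension method dispatching on the Explanation itself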
Example 2: Input sensitivity
In this second example, we naively reimplement the Gradient analyzer from ExplainableAI.jl.
using XAIBase
import XAIBase: call_analyzer
using Zygote: gradient
struct MyGradient{M} <: AbstractXAIMethod
    model::M
end

function call_analyzer(input, method::MyGradient, output_selector::AbstractOutputSelector; kwargs...)
    output = method.model(input)                # forward pass through the wrapped model
    output_selection = output_selector(output)  # index of the output neuron to explain
    # gradient of the selected output w.r.t. the input
    grad = gradient((x) -> only(method.model(x)[output_selection]), input)
    val = only(grad)
    return Explanation(val, input, output, output_selection, :MyGradient, :sensitivity, nothing)
end
call_analyzer (generic function with 3 methods)
ExplainableAI.jl implements the Gradient analyzer more efficiently: it supports batched inputs and requires only a single forward and backward pass through the model.
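For illustration, such a single-pass variant could be sketched with Zygote's pullback, seeding the backward pass with a one-hot mask over the selected outputs. The OnePassGradient name is hypothetical and this is a sketch, not ExplainableAI.jl's actual implementation:
using XAIBase
import XAIBase: call_analyzer
using Zygote: pullback

struct OnePassGradient{M} <: AbstractXAIMethod
    model::M
end

function call_analyzer(input, method::OnePassGradient, output_selector::AbstractOutputSelector; kwargs...)
    # single forward pass that also records the pullback for the backward pass
    output, back = pullback(method.model, input)
    output_selection = output_selector(output)
    # a one-hot mask over the selected outputs seeds a single backward pass,
    # yielding the gradients for all samples in the batch at once
    seed = zero(output)
    seed[output_selection] .= 1
    val = only(back(seed))
    return Explanation(val, input, output, output_selection, :OnePassGradient, :sensitivity, nothing)
end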
Once again, we can directly use XAIBase's analyze and VisionHeatmaps' heatmap functions
using VisionHeatmaps
analyzer = MyGradient(model)
expl = analyze(input, analyzer)
heatmap(expl.val)
heatmap(expl.val, colorscheme=:twilight, reduce=:norm, rangescale=:centered)
and make use of all the features provided by the Julia-XAI ecosystem.
For an introduction to the Julia-XAI ecosystem, please refer to the Getting started guide.