Creating an LRP Analyzer
This package is part of the Julia-XAI ecosystem. For an introduction to the ecosystem, please refer to the Getting started guide.
We start out by defining a small convolutional neural network:
using RelevancePropagation
using Flux
model = Chain(
    Chain(
        Conv((3, 3), 3 => 8, relu; pad=1),
        Conv((3, 3), 8 => 8, relu; pad=1),
        MaxPool((2, 2)),
        Conv((3, 3), 8 => 16; pad=1),
        BatchNorm(16, relu),
        Conv((3, 3), 16 => 8, relu; pad=1),
        BatchNorm(8, relu),
    ),
    Chain(Flux.flatten, Dense(2048 => 512, relu), Dropout(0.5), Dense(512 => 100, softmax)),
);
This model contains two chains: the convolutional layers and the fully connected layers.
Model preparation
- Use strip_softmax to strip the output softmax from your model. Otherwise model checks will fail.
- Use canonize to fuse linear layers.
- Don't just call LRP(model); instead, use a Composite to apply LRP rules to your model. Read Assigning rules to layers for more information.
- By default, LRP will call flatten_model to flatten your model. This reduces computational overhead.
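Putting these steps together, a typical preparation pipeline might look like the following sketch (the EpsilonPlusFlat preset is just an example; choose a composite that fits your architecture):
model = strip_softmax(model)      # remove the output softmax so model checks pass
model = canonize(model)           # fuse BatchNorm layers into preceding linear layers
composite = EpsilonPlusFlat()     # example composite preset
analyzer = LRP(model, composite)  # flatten_model is applied automatically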
Stripping the output softmax
When using LRP, it is recommended to explain output logits instead of probabilities. This can be done by stripping the output softmax activation from the model using the strip_softmax
function:
model = strip_softmax(model)
Chain(
  Chain(
    Conv((3, 3), 3 => 8, relu, pad=1),  # 224 parameters
    Conv((3, 3), 8 => 8, relu, pad=1),  # 584 parameters
    MaxPool((2, 2)),
    Conv((3, 3), 8 => 16, pad=1),       # 1_168 parameters
    BatchNorm(16, relu),                # 32 parameters, plus 32
    Conv((3, 3), 16 => 8, relu, pad=1), # 1_160 parameters
    BatchNorm(8, relu),                 # 16 parameters, plus 16
  ),
  Chain(
    Flux.flatten,
    Dense(2048 => 512, relu),           # 1_049_088 parameters
    Dropout(0.5),
    Dense(512 => 100),                  # 51_300 parameters
  ),
)                   # Total: 16 trainable arrays, 1_103_572 parameters,
                    # plus 4 non-trainable, 48 parameters, summarysize 4.211 MiB.
If you don't remove the output softmax, model checks will fail.
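Note that strip_softmax returns a new model. In the example above, it removed the softmax activation of the final Dense layer; a trailing softmax layer in a Chain should be handled as well. A minimal sketch, assuming this behavior:
strip_softmax(Dense(512 => 100, softmax))        # activation removed: Dense(512 => 100)
strip_softmax(Chain(Dense(512 => 100), softmax)) # trailing softmax layer removed (assumed)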
Model canonization
LRP is not invariant to a model's implementation. Applying the GammaRule
to two linear layers in a row will yield different results than first fusing the two layers into one linear layer and then applying the rule. This fusing is called "canonization" and can be done using the canonize
function:
model_canonized = canonize(model)
Chain(
  Conv((3, 3), 3 => 8, relu, pad=1),    # 224 parameters
  Conv((3, 3), 8 => 8, relu, pad=1),    # 584 parameters
  MaxPool((2, 2)),
  Conv((3, 3), 8 => 16, relu, pad=1),   # 1_168 parameters
  Conv((3, 3), 16 => 8, relu, pad=1),   # 1_160 parameters
  BatchNorm(8, relu),                   # 16 parameters, plus 16
  Flux.flatten,
  Dense(2048 => 512, relu),             # 1_049_088 parameters
  Dropout(0.5),
  Dense(512 => 100),                    # 51_300 parameters
)                   # Total: 14 trainable arrays, 1_103_540 parameters,
                    # plus 2 non-trainable, 16 parameters, summarysize 4.211 MiB.
After canonization, the first BatchNorm
layer has been fused into the preceding Conv
layer. The second BatchNorm
layer wasn't fused since its preceding Conv
layer has a ReLU activation function.
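For reference, folding a BatchNorm into the preceding linear layer amounts to rescaling the layer's weights by γ/√(σ² + ϵ) and shifting its bias accordingly. The following is only an illustrative sketch for a Dense layer; the helper name fuse_dense_batchnorm is made up, and this is not how canonize is implemented internally:
using Flux

# Fold a trained BatchNorm into a preceding Dense layer (illustrative sketch).
# Assumes the Dense layer has identity activation and a bias vector, and that
# the BatchNorm has affine parameters and tracked statistics.
function fuse_dense_batchnorm(d::Dense, bn::BatchNorm)
    scale = bn.γ ./ sqrt.(bn.σ² .+ bn.ϵ)
    W = scale .* d.weight                  # rescale each output row
    b = scale .* (d.bias .- bn.μ) .+ bn.β  # shift the bias
    return Dense(W, b, bn.λ)               # BatchNorm's activation moves onto the fused layer
end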
Flattening the model
RelevancePropagation.jl's LRP implementation supports nested Flux Chains and Parallel layers. However, it is recommended to flatten the model before analyzing it.
LRP is implemented by first running a forward pass through the model, keeping track of the intermediate activations, followed by a backward pass that computes the relevances.
To keep the LRP implementation simple and maintainable, RelevancePropagation.jl does not pre-compute "nested" activations. Instead, for every internal chain, a new forward pass is run to compute activations.
By "flattening" a model, this overhead can be avoided. For this purpose, RelevancePropagation.jl provides the function flatten_model
:
model_flat = flatten_model(model)
Chain(
  Conv((3, 3), 3 => 8, relu, pad=1),    # 224 parameters
  Conv((3, 3), 8 => 8, relu, pad=1),    # 584 parameters
  MaxPool((2, 2)),
  Conv((3, 3), 8 => 16, pad=1),         # 1_168 parameters
  BatchNorm(16, relu),                  # 32 parameters, plus 32
  Conv((3, 3), 16 => 8, relu, pad=1),   # 1_160 parameters
  BatchNorm(8, relu),                   # 16 parameters, plus 16
  Flux.flatten,
  Dense(2048 => 512, relu),             # 1_049_088 parameters
  Dropout(0.5),
  Dense(512 => 100),                    # 51_300 parameters
)                   # Total: 16 trainable arrays, 1_103_572 parameters,
                    # plus 4 non-trainable, 48 parameters, summarysize 4.211 MiB.
This function is called by default when creating an LRP analyzer. Note that we pass the unflattened model to the analyzer, but analyzer.model
is flattened:
analyzer = LRP(model)
analyzer.model
Chain(
  Conv((3, 3), 3 => 8, relu, pad=1),    # 224 parameters
  Conv((3, 3), 8 => 8, relu, pad=1),    # 584 parameters
  MaxPool((2, 2)),
  Conv((3, 3), 8 => 16, pad=1),         # 1_168 parameters
  BatchNorm(16, relu),                  # 32 parameters, plus 32
  Conv((3, 3), 16 => 8, relu, pad=1),   # 1_160 parameters
  BatchNorm(8, relu),                   # 16 parameters, plus 16
  Flux.flatten,
  Dense(2048 => 512, relu),             # 1_049_088 parameters
  Dropout(0.5),
  Dense(512 => 100),                    # 51_300 parameters
)                   # Total: 16 trainable arrays, 1_103_572 parameters,
                    # plus 4 non-trainable, 48 parameters, summarysize 4.211 MiB.
If this flattening is not desired, it can be disabled by passing the keyword argument flatten=false
to the LRP
constructor.
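For example, to keep the nested structure of the model:
analyzer_nested = LRP(model; flatten=false)  # analyzer_nested.model keeps the nested Chains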
LRP rules
The following examples will be run on a pre-trained LeNet-5 model:
using BSON
model = BSON.load("../model.bson", @__MODULE__)[:model] # load pre-trained LeNet-5 model
Chain(
  Conv((5, 5), 1 => 6, relu),   # 156 parameters
  MaxPool((2, 2)),
  Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
  MaxPool((2, 2)),
  Flux.flatten,
  Dense(256 => 120, relu),      # 30_840 parameters
  Dense(120 => 84, relu),       # 10_164 parameters
  Dense(84 => 10),              # 850 parameters
)                   # Total: 10 arrays, 44_426 parameters, 174.344 KiB.
We also load the MNIST dataset:
using MLDatasets
using ImageCore, ImageIO, ImageShow
index = 10
x, y = MNIST(Float32, :test)[index]
input = reshape(x, 28, 28, 1, :)
convert2image(MNIST, x)
By default, the LRP
constructor will assign the ZeroRule
to all layers.
analyzer = LRP(model)
LRP(
  Conv((5, 5), 1 => 6, relu) => ZeroRule(),
  MaxPool((2, 2)) => ZeroRule(),
  Conv((5, 5), 6 => 16, relu) => ZeroRule(),
  MaxPool((2, 2)) => ZeroRule(),
  Flux.flatten => ZeroRule(),
  Dense(256 => 120, relu) => ZeroRule(),
  Dense(120 => 84, relu) => ZeroRule(),
  Dense(84 => 10) => ZeroRule(),
)
This analyzer will return heatmaps that look identical to the InputTimesGradient analyzer from ExplainableAI.jl. We can visualize Explanations by computing a heatmap, using VisionHeatmaps.jl for images or TextHeatmaps.jl for text.
using VisionHeatmaps
heatmap(input, analyzer)
LRP's strength lies in assigning different rules to different layers, based on their functionality in the neural network[1]. RelevancePropagation.jl implements many LRP rules out of the box, but it is also possible to implement custom rules.
To assign different rules to different layers, use one of the composite presets, or create your own composite, as described in Assigning rules to layers.
composite = EpsilonPlusFlat() # using composite preset EpsilonPlusFlat
Composite(
  GlobalTypeMap(  # all layers
    Flux.Conv               => ZPlusRule(),
    Flux.ConvTranspose      => ZPlusRule(),
    Flux.CrossCor           => ZPlusRule(),
    Flux.Dense              => EpsilonRule{Float32}(1.0f-6),
    Flux.Scale              => EpsilonRule{Float32}(1.0f-6),
    Flux.LayerNorm          => LayerNormRule(),
    typeof(NNlib.dropout)   => PassRule(),
    Flux.AlphaDropout       => PassRule(),
    Flux.Dropout            => PassRule(),
    Flux.BatchNorm          => PassRule(),
    typeof(Flux.flatten)    => PassRule(),
    typeof(MLUtils.flatten) => PassRule(),
    typeof(identity)        => PassRule(),
  ),
  FirstLayerTypeMap(  # first layer
    Flux.Conv          => FlatRule(),
    Flux.ConvTranspose => FlatRule(),
    Flux.CrossCor      => FlatRule(),
  ),
)
analyzer = LRP(model, composite)
LRP(
  Conv((5, 5), 1 => 6, relu) => FlatRule(),
  MaxPool((2, 2)) => ZeroRule(),
  Conv((5, 5), 6 => 16, relu) => ZPlusRule(),
  MaxPool((2, 2)) => ZeroRule(),
  Flux.flatten => PassRule(),
  Dense(256 => 120, relu) => EpsilonRule{Float32}(1.0f-6),
  Dense(120 => 84, relu) => EpsilonRule{Float32}(1.0f-6),
  Dense(84 => 10) => EpsilonRule{Float32}(1.0f-6),
)
heatmap(input, analyzer)
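Instead of using a preset, you can also assemble a composite yourself from primitives such as GlobalTypeMap and FirstLayerTypeMap. A minimal sketch, loosely mirroring the preset above (adjust the type map to your model's layers):
composite = Composite(
    GlobalTypeMap(                       # assign rules by layer type
        Flux.Dense           => EpsilonRule(),  # default ϵ = 1.0f-6
        Flux.Conv            => ZPlusRule(),
        typeof(Flux.flatten) => PassRule(),
    ),
    FirstLayerTypeMap(Flux.Conv => FlatRule()),  # overwrite the rule of the first Conv layer
)
analyzer = LRP(model, composite)
Layers not covered by any map, such as MaxPool here, fall back to the default ZeroRule, as in the analyzer printout above.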
Computing layerwise relevances
If you are interested in computing layerwise relevances, call analyze
with an LRP analyzer and the keyword argument layerwise_relevances=true
.
The layerwise relevances can be accessed in the extras
field of the returned Explanation
:
expl = analyze(input, analyzer; layerwise_relevances=true)
expl.extras.layerwise_relevances
(Float32[-1.5046089f-6 -1.5046089f-6 … 4.4148962f-8 0.0; -1.5046089f-6 -1.5046089f-6 … 4.4148962f-8 0.0; … ; 6.1168203f-6 6.1168203f-6 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;;], Float32[-3.7615224f-5 0.0 … 1.103724f-6 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0001529205 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;;], Float32[-3.7615224f-5 0.00021888175 … 0.000114222385 1.103724f-6; 0.00018855373 0.00027439542 … 0.00020195934 -3.4702516f-5; … ; -2.6229336f-5 7.008412f-5 … -2.8691686f-6 0.0; 0.0001529205 0.00029437395 … 0.0 0.0;;; 0.0 0.0 … -0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 -0.0; 0.0 0.0 … -0.0 0.0;;; -0.0 0.0 … -0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; -0.0 0.0 … 0.0 0.0;;; -0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; -0.0 0.0 … 0.0 0.0; 0.0 -0.0 … 0.0 0.0;;; 0.0 -0.0 … -0.0 0.0; -0.0 0.0 … 0.0 0.0; … ; 0.0 -0.0 … -0.0 0.0; 0.0 0.0 … -0.0 0.0;;; 0.0 -0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 -0.0; 0.0 0.0 … 0.0 0.0;;;;], Float32[0.0 0.0 … 0.0 0.0; -0.0027488603 0.0 … 0.02671153 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.03954778 0.0 … 0.0 -0.0014172087;;; -0.0014038438 0.0 … 0.0 -0.00047513167; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.004709776 … 0.0 0.0;;; 0.0 0.0 … 0.0 -0.002364168; 0.0 0.0 … 0.0 0.0; … ; 0.00063266495 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; … ;;; 0.0 0.0 … 0.0 -0.014013007; 0.0 0.014055459 … 0.0 0.0; … ; 0.0 0.0 … -0.00013713303 0.0; 0.037856653 0.0 … 0.0 0.0;;; -0.0033661663 0.0 … 0.008260983 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.01444692 0.0;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 -0.038602687; … ; 0.0 0.0014011612 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;;], Float32[-0.0027488603 -0.0 0.0006903657 0.02671153; -0.0043666936 -0.01234732 0.0 -0.013395859; -0.010977116 0.04276983 -0.0 -9.9327546f-5; 0.03954778 0.0 -0.0 -0.0014172087;;; -0.0014038438 0.017887121 0.0 -0.0004751317; -0.0 0.016042996 -0.0 -0.0; 0.011004027 0.0 -0.0 -0.0; 0.004709776 -0.009777178 -0.0 0.0;;; -0.0 -0.025314394 -0.011204087 -0.002364168; 0.027246127 -0.016887693 0.0 -0.010472963; 0.00095062767 -0.00031839436 -0.022801049 0.0036320935; 0.00063266495 -0.0028789637 0.011017141 0.0;;; … ;;; 0.014055459 0.011828159 -0.0 -0.014013007; 0.01990776 0.017337693 -0.0001569362 -0.0; -0.011918854 -0.00459134 0.0 -0.004262696; 0.037856653 0.023232974 0.0 -0.00013713304;;; -0.0033661663 -0.0 0.01047569 0.008260983; -0.0 0.010837908 0.06946186 0.0038659107; 0.0 -0.0063336077 0.029801883 0.02022359; -0.0 -0.001508267 -0.0047802655 0.01444692;;; -0.0 0.060284954 0.0021190182 -0.038602687; -0.0 0.0 0.015209901 0.0; -0.037314404 0.0054231985 0.0 0.0; 0.0014011612 0.0 0.0 0.0;;;;], Float32[-0.0027488603; -0.0043666936; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; -0.0; 0.0;;], Float32[0.028250216; 0.026568355; … ; -0.0; 0.032229893;;], Float32[0.0; 0.0; … ; 0.0; 1.0;;])
Note that the layerwise relevances are only kept for layers in the outermost Chain
of the model. Since we used a flattened model, we obtained all relevances.
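As the printed shapes suggest, the tuple holds one relevance array for the input plus one per layer, ordered from input to output. A small sketch of how you might inspect them:
lwr = expl.extras.layerwise_relevances
length(lwr)    # 9 entries for the 8-layer LeNet-5: the input plus one per layer
size(lwr[1])   # relevance at the input, (28, 28, 1, 1)
size(lwr[end]) # relevance at the output, (10, 1)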
Performance tips
Using LRP with a GPU
All LRP analyzers support GPU backends, building on top of Flux.jl's GPU support. Using a GPU only requires moving the input array and model weights to the GPU.
For example, using CUDA.jl:
using CUDA, cuDNN
using Flux
using RelevancePropagation
# move input array and model weights to GPU
input = input |> gpu # or gpu(input)
model = model |> gpu # or gpu(model)
# analyzers don't require calling `gpu`
analyzer = LRP(model)
# explanations are computed on the GPU
expl = analyze(input, analyzer)
Some operations, like saving, require moving explanations back to the CPU. This can be done using Flux's cpu
function:
val = expl.val |> cpu # or cpu(expl.val)
using BSON
BSON.@save "explanation.bson" val
Using LRP without a GPU
Using Julia's package extension mechanism, RelevancePropagation.jl's LRP implementation can optionally make use of Tullio.jl and LoopVectorization.jl for faster LRP rules on dense layers.
This only requires loading the packages before loading RelevancePropagation.jl:
using LoopVectorization, Tullio
using RelevancePropagation
This page was generated using Literate.jl.
[1] G. Montavon et al., Layer-Wise Relevance Propagation: An Overview