Creating an LRP Analyzer
This package is part of the Julia-XAI ecosystem. For an introduction to the ecosystem, please refer to the Getting started guide.
We start out by loading a small convolutional neural network:
using RelevancePropagation
using Flux
model = Chain(
Chain(
Conv((3, 3), 3 => 8, relu; pad=1),
Conv((3, 3), 8 => 8, relu; pad=1),
MaxPool((2, 2)),
Conv((3, 3), 8 => 16; pad=1),
BatchNorm(16, relu),
Conv((3, 3), 16 => 8, relu; pad=1),
BatchNorm(8, relu),
),
Chain(Flux.flatten, Dense(2048 => 512, relu), Dropout(0.5), Dense(512 => 100, softmax)),
);
This model contains two chains: the convolutional layers and the fully connected layers.
Model preparation
- Use strip_softmax to strip the output softmax from your model. Otherwise model checks will fail.
- Use canonize to fuse linear layers.
- Don't just call LRP(model); instead use a Composite to apply LRP rules to your model. Read Assigning rules to layers for more information.
- By default, LRP will call flatten_model to flatten your model. This reduces computational overhead (a combined sketch of these steps follows this list).
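Putting these steps together, a typical preparation pipeline might look like the following sketch. It uses the EpsilonPlusFlat composite preset introduced further down; any other composite works just as well, and each step is covered in detail in the sections below.
using Flux, RelevancePropagation

model = strip_softmax(model)      # explain logits instead of probabilities
model = canonize(model)           # fuse BatchNorm layers into preceding linear layers
composite = EpsilonPlusFlat()     # assign LRP rules per layer type
analyzer = LRP(model, composite)  # flatten_model is applied by default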
Stripping the output softmax
When using LRP, it is recommended to explain output logits instead of probabilities. This can be done by stripping the output softmax activation from the model using the strip_softmax function:
model = strip_softmax(model)
Chain(
Chain(
Conv((3, 3), 3 => 8, relu, pad=1), # 224 parameters
Conv((3, 3), 8 => 8, relu, pad=1), # 584 parameters
MaxPool((2, 2)),
Conv((3, 3), 8 => 16, pad=1), # 1_168 parameters
BatchNorm(16, relu), # 32 parameters, plus 32
Conv((3, 3), 16 => 8, relu, pad=1), # 1_160 parameters
BatchNorm(8, relu), # 16 parameters, plus 16
),
Chain(
Flux.flatten,
Dense(2048 => 512, relu), # 1_049_088 parameters
Dropout(0.5),
Dense(512 => 100), # 51_300 parameters
),
) # Total: 16 trainable arrays, 1_103_572 parameters,
# plus 4 non-trainable, 48 parameters, summarysize 4.211 MiB.
If you don't remove the output softmax, model checks will fail.
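For illustration, the sketch below re-attaches a softmax and attempts to construct an analyzer; based on the statement above, the constructor's model checks are expected to fail. The exact error type and message are not specified here and may differ.
# Hypothetical illustration: the model checks run by the LRP constructor
# are expected to reject a model that still ends in softmax.
try
    LRP(Chain(model, softmax))
catch err
    @warn "Model check failed as expected" err
end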
Model canonization
LRP is not invariant to a model's implementation. Applying the GammaRule to two linear layers in a row will yield different results than first fusing the two layers into one linear layer and then applying the rule. This fusing is called "canonization" and can be done using the canonize function:
model_canonized = canonize(model)
Chain(
Conv((3, 3), 3 => 8, relu, pad=1), # 224 parameters
Conv((3, 3), 8 => 8, relu, pad=1), # 584 parameters
MaxPool((2, 2)),
Conv((3, 3), 8 => 16, relu, pad=1), # 1_168 parameters
Conv((3, 3), 16 => 8, relu, pad=1), # 1_160 parameters
BatchNorm(8, relu), # 16 parameters, plus 16
Flux.flatten,
Dense(2048 => 512, relu), # 1_049_088 parameters
Dropout(0.5),
Dense(512 => 100), # 51_300 parameters
) # Total: 14 trainable arrays, 1_103_540 parameters,
# plus 2 non-trainable, 16 parameters, summarysize 4.211 MiB.
After canonization, the first BatchNorm layer has been fused into the preceding Conv layer. The second BatchNorm layer wasn't fused since its preceding Conv layer has a ReLU activation function.
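The folding itself is plain linear algebra. Below is a minimal sketch of the idea for a Dense layer with identity activation followed by an affine BatchNorm in inference mode; fuse_dense_batchnorm is a hypothetical helper for illustration, not the package's implementation (which also handles Conv layers and further edge cases).
using Flux

# BN(Wx + b) = γ .* (Wx + b .- μ) ./ sqrt.(σ² .+ ϵ) .+ β, so the BatchNorm
# statistics can be folded into the weights and bias of the preceding layer.
function fuse_dense_batchnorm(d::Dense, bn::BatchNorm)
    scale = bn.γ ./ sqrt.(bn.σ² .+ bn.ϵ)     # per-output rescaling factor
    W = scale .* d.weight                    # rescale each output row
    b = scale .* (d.bias .- bn.μ) .+ bn.β    # fold mean and shift into the bias
    return Dense(W, b, bn.λ)                 # BatchNorm's activation moves to the fused layer
end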
Flattening the model
RelevancePropagation.jl's LRP implementation supports nested Flux Chains and Parallel layers. However, it is recommended to flatten the model before analyzing it.
LRP is implemented by first running a forward pass through the model, keeping track of the intermediate activations, followed by a backward pass that computes the relevances.
To keep the LRP implementation simple and maintainable, RelevancePropagation.jl does not pre-compute "nested" activations. Instead, for every internal chain, a new forward pass is run to compute activations.
By "flattening" a model, this overhead can be avoided. For this purpose, RelevancePropagation.jl provides the function flatten_model:
model_flat = flatten_model(model)
Chain(
Conv((3, 3), 3 => 8, relu, pad=1), # 224 parameters
Conv((3, 3), 8 => 8, relu, pad=1), # 584 parameters
MaxPool((2, 2)),
Conv((3, 3), 8 => 16, pad=1), # 1_168 parameters
BatchNorm(16, relu), # 32 parameters, plus 32
Conv((3, 3), 16 => 8, relu, pad=1), # 1_160 parameters
BatchNorm(8, relu), # 16 parameters, plus 16
Flux.flatten,
Dense(2048 => 512, relu), # 1_049_088 parameters
Dropout(0.5),
Dense(512 => 100), # 51_300 parameters
) # Total: 16 trainable arrays, 1_103_572 parameters,
# plus 4 non-trainable, 48 parameters, summarysize 4.211 MiB.
This function is called by default when creating an LRP analyzer. Note that we pass the unflattened model to the analyzer, but analyzer.model is flattened:
analyzer = LRP(model)
analyzer.model
Chain(
Conv((3, 3), 3 => 8, relu, pad=1), # 224 parameters
Conv((3, 3), 8 => 8, relu, pad=1), # 584 parameters
MaxPool((2, 2)),
Conv((3, 3), 8 => 16, pad=1), # 1_168 parameters
BatchNorm(16, relu), # 32 parameters, plus 32
Conv((3, 3), 16 => 8, relu, pad=1), # 1_160 parameters
BatchNorm(8, relu), # 16 parameters, plus 16
Flux.flatten,
Dense(2048 => 512, relu), # 1_049_088 parameters
Dropout(0.5),
Dense(512 => 100), # 51_300 parameters
) # Total: 16 trainable arrays, 1_103_572 parameters,
# plus 4 non-trainable, 48 parameters, summarysize 4.211 MiB.
If this flattening is not desired, it can be disabled by passing the keyword argument flatten=false to the LRP constructor.
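For example, the following one-liner keeps the nested structure, at the cost of the extra forward passes described above:
analyzer_nested = LRP(model; flatten=false)  # analyzer_nested.model keeps the nested Chains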
LRP rules
The following examples will be run on a pre-trained LeNet-5 model:
using BSON
model = BSON.load("../model.bson", @__MODULE__)[:model] # load pre-trained LeNet-5 model
Chain(
Conv((5, 5), 1 => 6, relu), # 156 parameters
MaxPool((2, 2)),
Conv((5, 5), 6 => 16, relu), # 2_416 parameters
MaxPool((2, 2)),
Flux.flatten,
Dense(256 => 120, relu), # 30_840 parameters
Dense(120 => 84, relu), # 10_164 parameters
Dense(84 => 10), # 850 parameters
) # Total: 10 arrays, 44_426 parameters, 174.344 KiB.
We also load the MNIST dataset:
using MLDatasets
using ImageCore, ImageIO, ImageShow
index = 10
x, y = MNIST(Float32, :test)[index]
input = reshape(x, 28, 28, 1, :)
convert2image(MNIST, x)
By default, the LRP constructor will assign the ZeroRule to all layers.
analyzer = LRP(model)
LRP(
Conv((5, 5), 1 => 6, relu) => ZeroRule(),
MaxPool((2, 2)) => ZeroRule(),
Conv((5, 5), 6 => 16, relu) => ZeroRule(),
MaxPool((2, 2)) => ZeroRule(),
Flux.flatten => ZeroRule(),
Dense(256 => 120, relu) => ZeroRule(),
Dense(120 => 84, relu) => ZeroRule(),
Dense(84 => 10) => ZeroRule(),
)
This analyzer will return heatmaps that look identical to the InputTimesGradient analyzer from ExplainableAI.jl. We can visualize Explanations by computing a heatmap using either VisionHeatmaps.jl or TextHeatmaps.jl, for images or text respectively.
using VisionHeatmaps
heatmap(input, analyzer)
LRP's strength lies in assigning different rules to different layers, based on their functionality in the neural network[1]. RelevancePropagation.jl implements many LRP rules out of the box, but it is also possible to implement custom rules.
To assign different rules to different layers, use one of the composite presets, or create your own composite, as described in Assigning rules to layers.
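For illustration, a small hand-written composite might look like the following sketch; it only assumes the Composite, GlobalTypeMap, and FirstLayerTypeMap constructors that also appear in the preset printed below, and the rule assignment is deliberately kept short.
# Sketch of a custom composite: ε-rule on dense layers, z⁺-rule on convolutions,
# a flat rule on the first convolution, and pass-through for reshaping layers.
my_composite = Composite(
    GlobalTypeMap(
        Flux.Conv => ZPlusRule(),
        Flux.Dense => EpsilonRule(),
        Flux.BatchNorm => PassRule(),
        typeof(Flux.flatten) => PassRule(),
    ),
    FirstLayerTypeMap(Flux.Conv => FlatRule()),
)
Below, we instead use the EpsilonPlusFlat preset, which bundles a similar but more complete mapping: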
composite = EpsilonPlusFlat() # using composite preset EpsilonPlusFlat
Composite(
GlobalTypeMap( # all layers
Flux.Conv => ZPlusRule(),
Flux.ConvTranspose => ZPlusRule(),
Flux.CrossCor => ZPlusRule(),
Flux.Dense => EpsilonRule{Float32}(1.0f-6),
Flux.Scale => EpsilonRule{Float32}(1.0f-6),
Flux.LayerNorm => LayerNormRule(),
typeof(NNlib.dropout) => PassRule(),
Flux.AlphaDropout => PassRule(),
Flux.Dropout => PassRule(),
Flux.BatchNorm => PassRule(),
typeof(Flux.flatten) => PassRule(),
typeof(MLUtils.flatten) => PassRule(),
typeof(identity) => PassRule(),
),
FirstLayerTypeMap( # first layer
Flux.Conv => FlatRule(),
Flux.ConvTranspose => FlatRule(),
Flux.CrossCor => FlatRule(),
),
)
analyzer = LRP(model, composite)
LRP(
Conv((5, 5), 1 => 6, relu) => FlatRule(),
MaxPool((2, 2)) => ZeroRule(),
Conv((5, 5), 6 => 16, relu) => ZPlusRule(),
MaxPool((2, 2)) => ZeroRule(),
Flux.flatten => PassRule(),
Dense(256 => 120, relu) => EpsilonRule{Float32}(1.0f-6),
Dense(120 => 84, relu) => EpsilonRule{Float32}(1.0f-6),
Dense(84 => 10) => EpsilonRule{Float32}(1.0f-6),
)
heatmap(input, analyzer)
Computing layerwise relevances
If you are interested in computing layerwise relevances, call analyze with an LRP analyzer and the keyword argument layerwise_relevances=true.
The layerwise relevances can be accessed in the extras field of the returned Explanation:
expl = analyze(input, analyzer; layerwise_relevances=true)
expl.extras.layerwise_relevances
(Float32[-1.5046089f-6 -1.5046089f-6 … 4.4148962f-8 0.0; … ;;;;], Float32[-3.7615224f-5 0.0 … 1.103724f-6 0.0; … ;;;;], …, Float32[0.0; 0.0; … ; 0.0; 1.0;;])
Note that the layerwise relevances are only kept for layers in the outermost Chain of the model. Since we used a flattened model, we obtained all relevances.
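As a quick sanity check, the structure of the returned tuple can be inspected directly. For the flattened LeNet-5 used here it holds one array for the input relevance and one per layer; the values in the comments are what this particular example produces.
length(expl.extras.layerwise_relevances)       # 9: input relevance + one entry per layer
size(first(expl.extras.layerwise_relevances))  # (28, 28, 1, 1), matches the input
size(last(expl.extras.layerwise_relevances))   # (10, 1), matches the model output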
Performance tips
Using LRP with a GPU
All LRP analyzers support GPU backends, building on top of Flux.jl's GPU support. Using a GPU only requires moving the input array and model weights to the GPU.
For example, using CUDA.jl:
using CUDA, cuDNN
using Flux
using RelevancePropagation
# move input array and model weights to GPU
input = input |> gpu # or gpu(input)
model = model |> gpu # or gpu(model)
# analyzers don't require calling `gpu`
analyzer = LRP(model)
# explanations are computed on the GPU
expl = analyze(input, analyzer)
Some operations, like saving, require moving explanations back to the CPU. This can be done using Flux's cpu function:
val = expl.val |> cpu # or cpu(expl.val)
using BSON
BSON.@save "explanation.bson" val
Using LRP without a GPU
Using Julia's package extension mechanism, RelevancePropagation.jl's LRP implementation can optionally make use of Tullio.jl and LoopVectorization.jl for faster LRP rules on dense layers.
This only requires loading the packages before loading RelevancePropagation.jl:
using LoopVectorization, Tullio
using RelevancePropagation
This page was generated using Literate.jl.
[1] G. Montavon et al., Layer-Wise Relevance Propagation: An Overview