Assigning LRP Rules to Layers
In this example, we will show how to assign LRP rules to specific layers. For this purpose, we first define a small VGG-like convolutional neural network:
using RelevancePropagation
using Flux
model = Chain(
Chain(
Conv((3, 3), 3 => 8, relu; pad=1),
Conv((3, 3), 8 => 8, relu; pad=1),
MaxPool((2, 2)),
Conv((3, 3), 8 => 16, relu; pad=1),
Conv((3, 3), 16 => 16, relu; pad=1),
MaxPool((2, 2)),
),
Chain(Flux.flatten, Dense(1024 => 512, relu), Dropout(0.5), Dense(512 => 100, relu)),
);
Manually assigning rules
When creating an LRP-analyzer, we can assign individual rules to each layer. As we can see above, our model is a Chain of two Flux Chains. Using flatten_model, we can flatten the model into a single Chain:
model_flat = flatten_model(model)
Chain(
Conv((3, 3), 3 => 8, relu, pad=1), # 224 parameters
Conv((3, 3), 8 => 8, relu, pad=1), # 584 parameters
MaxPool((2, 2)),
Conv((3, 3), 8 => 16, relu, pad=1), # 1_168 parameters
Conv((3, 3), 16 => 16, relu, pad=1), # 2_320 parameters
MaxPool((2, 2)),
Flux.flatten,
Dense(1024 => 512, relu), # 524_800 parameters
Dropout(0.5),
Dense(512 => 100, relu), # 51_300 parameters
) # Total: 12 arrays, 580_396 parameters, 2.215 MiB.
This allows us to define an LRP analyzer using an array of rules matching the length of the Flux chain:
rules = [
FlatRule(),
ZPlusRule(),
ZeroRule(),
ZPlusRule(),
ZPlusRule(),
ZeroRule(),
PassRule(),
EpsilonRule(),
PassRule(),
EpsilonRule(),
];
The LRP analyzer shows a summary of how layers and rules were matched:
LRP(model_flat, rules)
LRP(
Conv((3, 3), 3 => 8, relu, pad=1) => FlatRule(),
Conv((3, 3), 8 => 8, relu, pad=1) => ZPlusRule(),
MaxPool((2, 2)) => ZeroRule(),
Conv((3, 3), 8 => 16, relu, pad=1) => ZPlusRule(),
Conv((3, 3), 16 => 16, relu, pad=1) => ZPlusRule(),
MaxPool((2, 2)) => ZeroRule(),
Flux.flatten => PassRule(),
Dense(1024 => 512, relu) => EpsilonRule{Float32}(1.0f-6),
Dropout(0.5) => PassRule(),
Dense(512 => 100, relu) => EpsilonRule{Float32}(1.0f-6),
)
However, this approach only works for models that can be fully flattened. For unflattened models and models containing Parallel and SkipConnection layers, we can compose rules using ChainTuple, ParallelTuple and SkipConnectionTuple, which match the model structure:
rules = ChainTuple(
ChainTuple(FlatRule(), ZPlusRule(), ZeroRule(), ZPlusRule(), ZPlusRule(), ZeroRule()),
ChainTuple(PassRule(), EpsilonRule(), PassRule(), EpsilonRule()),
)
analyzer = LRP(model, rules; flatten=false)
LRP(
ChainTuple(
Conv((3, 3), 3 => 8, relu, pad=1) => FlatRule(),
Conv((3, 3), 8 => 8, relu, pad=1) => ZPlusRule(),
MaxPool((2, 2)) => ZeroRule(),
Conv((3, 3), 8 => 16, relu, pad=1) => ZPlusRule(),
Conv((3, 3), 16 => 16, relu, pad=1) => ZPlusRule(),
MaxPool((2, 2)) => ZeroRule(),
),
ChainTuple(
Flux.flatten => PassRule(),
Dense(1024 => 512, relu) => EpsilonRule{Float32}(1.0f-6),
Dropout(0.5) => PassRule(),
Dense(512 => 100, relu) => EpsilonRule{Float32}(1.0f-6),
),
)
We used the LRP keyword argument flatten=false to showcase that the structure of the model can be preserved. For performance reasons, the default flatten=true is recommended.
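The ParallelTuple structure mentioned above can be sketched on a small hypothetical two-branch model (this model is not part of the example; the sketch assumes ParallelTuple takes one rule per branch):

```julia
using Flux, RelevancePropagation

# Hypothetical two-branch model (not the VGG-like example model)
model_p = Chain(
    Dense(8 => 8, relu),
    Parallel(+, Dense(8 => 8, relu), Dense(8 => 8, relu)),
    Dense(8 => 2),
)

# Nest a ParallelTuple inside the ChainTuple to mirror the model structure;
# each entry of the ParallelTuple corresponds to one branch of the Parallel layer.
rules = ChainTuple(
    ZPlusRule(),
    ParallelTuple(EpsilonRule(), EpsilonRule()),
    EpsilonRule(),
)
analyzer = LRP(model_p, rules; flatten=false)
```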
Custom composites
Instead of manually defining a list of rules, we can also define a Composite. A composite constructs a list of LRP-rules by sequentially applying the composite primitives it contains.
To obtain the same set of rules as in the previous example, we can define
composite = Composite(
GlobalTypeMap( # the following maps of layer types to LRP rules are applied globally
Conv => ZPlusRule(), # apply ZPlusRule on all Conv layers
Dense => EpsilonRule(), # apply EpsilonRule on all Dense layers
Dropout => PassRule(), # apply PassRule on all Dropout layers
MaxPool => ZeroRule(), # apply ZeroRule on all MaxPool layers
typeof(Flux.flatten) => PassRule(), # apply PassRule on all flatten layers
),
FirstLayerMap( # the following rule is applied to the first layer
FlatRule(),
),
);
We now construct an LRP analyzer from composite:
analyzer = LRP(model, composite; flatten=false)
LRP(
ChainTuple(
Conv((3, 3), 3 => 8, relu, pad=1) => FlatRule(),
Conv((3, 3), 8 => 8, relu, pad=1) => ZPlusRule(),
MaxPool((2, 2)) => ZeroRule(),
Conv((3, 3), 8 => 16, relu, pad=1) => ZPlusRule(),
Conv((3, 3), 16 => 16, relu, pad=1) => ZPlusRule(),
MaxPool((2, 2)) => ZeroRule(),
),
ChainTuple(
Flux.flatten => PassRule(),
Dense(1024 => 512, relu) => EpsilonRule{Float32}(1.0f-6),
Dropout(0.5) => PassRule(),
Dense(512 => 100, relu) => EpsilonRule{Float32}(1.0f-6),
),
)
As you can see, this analyzer contains the same rules as our previous one. To compute rules for a model without creating an analyzer, use lrp_rules:
lrp_rules(model, composite)
ChainTuple(
ChainTuple(
FlatRule(),
ZPlusRule(),
ZeroRule(),
ZPlusRule(),
ZPlusRule(),
ZeroRule(),
),
ChainTuple(
PassRule(),
EpsilonRule{Float32}(1.0f-6),
PassRule(),
EpsilonRule{Float32}(1.0f-6),
),
)
Composite primitives
The following composite primitives can be used to construct a Composite.
To apply a single rule, use:
- LayerMap to apply a rule to a layer at a given index
- GlobalMap to apply a rule to all layers
- RangeMap to apply a rule to a positional range of layers
- FirstLayerMap to apply a rule to the first layer
- LastLayerMap to apply a rule to the last layer
To apply a set of rules to layers based on their type, use:
- GlobalTypeMap to apply a dictionary that maps layer types to LRP-rules
- RangeTypeMap for a TypeMap on generalized ranges
- FirstLayerTypeMap for a TypeMap on the first layer of a model
- LastLayerTypeMap for a TypeMap on the last layer
- FirstNTypeMap for a TypeMap on the first n layers
Primitives are applied sequentially, in the order they were passed to the Composite, and later primitives overwrite rules assigned by earlier ones.
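This ordering can be sketched with two of the primitives listed above (a minimal sketch on a toy model, assuming GlobalMap and LastLayerMap each take a single rule):

```julia
using Flux, RelevancePropagation

# Toy stand-in model (any Chain works here)
toy = Chain(Dense(4 => 4, relu), Dense(4 => 4, relu), Dense(4 => 2))

composite = Composite(
    GlobalMap(ZeroRule()),        # first, assign ZeroRule to every layer
    LastLayerMap(EpsilonRule()),  # then overwrite the rule on the last layer
)
lrp_rules(toy, composite)
# the last layer now gets an EpsilonRule; all other layers keep ZeroRule
```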
Assigning a rule to a specific layer
To assign a rule to a specific layer, we can use LayerMap, which maps an LRP-rule to all layers in the model at the given index. To display indices, use the show_layer_indices helper function:
show_layer_indices(model)
ChainTuple(
ChainTuple(
(1, 1),
(1, 2),
(1, 3),
(1, 4),
(1, 5),
(1, 6),
),
ChainTuple(
(2, 1),
(2, 2),
(2, 3),
(2, 4),
),
)
Let's demonstrate LayerMap by assigning a specific rule to the last Conv layer at index (1, 5):
composite = Composite(LayerMap((1, 5), EpsilonRule()))
LRP(model, composite; flatten=false)
LRP(
ChainTuple(
Conv((3, 3), 3 => 8, relu, pad=1) => ZeroRule(),
Conv((3, 3), 8 => 8, relu, pad=1) => ZeroRule(),
MaxPool((2, 2)) => ZeroRule(),
Conv((3, 3), 8 => 16, relu, pad=1) => ZeroRule(),
Conv((3, 3), 16 => 16, relu, pad=1) => EpsilonRule{Float32}(1.0f-6),
MaxPool((2, 2)) => ZeroRule(),
),
ChainTuple(
Flux.flatten => ZeroRule(),
Dense(1024 => 512, relu) => ZeroRule(),
Dropout(0.5) => ZeroRule(),
Dense(512 => 100, relu) => ZeroRule(),
),
)
This approach also works with Parallel layers.
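As a sketch of how this might look (the model below is hypothetical, and the branch index (2, 1) is an assumption; take the actual indices from the show_layer_indices output):

```julia
using Flux, RelevancePropagation

# Hypothetical model containing a Parallel layer
model_p = Chain(
    Dense(8 => 8, relu),
    Parallel(+, Dense(8 => 8, relu), Dense(8 => 8, relu)),
    Dense(8 => 2),
)
show_layer_indices(model_p)  # display the nested layer indices

# Assign an EpsilonRule to a layer inside the Parallel block;
# the index (2, 1) is assumed here, verify it against the printed indices above.
composite = Composite(LayerMap((2, 1), EpsilonRule()))
analyzer = LRP(model_p, composite; flatten=false)
```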
Composite presets
RelevancePropagation.jl provides a set of default composites. A list of all implemented default composites can be found in the API reference, e.g. the EpsilonPlusFlat composite:
composite = EpsilonPlusFlat()
Composite(
GlobalTypeMap( # all layers
Flux.Conv => ZPlusRule(),
Flux.ConvTranspose => ZPlusRule(),
Flux.CrossCor => ZPlusRule(),
Flux.Dense => EpsilonRule{Float32}(1.0f-6),
Flux.Scale => EpsilonRule{Float32}(1.0f-6),
Flux.LayerNorm => LayerNormRule(),
typeof(NNlib.dropout) => PassRule(),
Flux.AlphaDropout => PassRule(),
Flux.Dropout => PassRule(),
Flux.BatchNorm => PassRule(),
typeof(Flux.flatten) => PassRule(),
typeof(MLUtils.flatten) => PassRule(),
typeof(identity) => PassRule(),
),
FirstLayerTypeMap( # first layer
Flux.Conv => FlatRule(),
Flux.ConvTranspose => FlatRule(),
Flux.CrossCor => FlatRule(),
),
)
analyzer = LRP(model, composite; flatten=false)
LRP(
ChainTuple(
Conv((3, 3), 3 => 8, relu, pad=1) => FlatRule(),
Conv((3, 3), 8 => 8, relu, pad=1) => ZPlusRule(),
MaxPool((2, 2)) => ZeroRule(),
Conv((3, 3), 8 => 16, relu, pad=1) => ZPlusRule(),
Conv((3, 3), 16 => 16, relu, pad=1) => ZPlusRule(),
MaxPool((2, 2)) => ZeroRule(),
),
ChainTuple(
Flux.flatten => PassRule(),
Dense(1024 => 512, relu) => EpsilonRule{Float32}(1.0f-6),
Dropout(0.5) => PassRule(),
Dense(512 => 100, relu) => EpsilonRule{Float32}(1.0f-6),
),
)
This page was generated using Literate.jl.