Assigning LRP Rules to Layers

In this example, we will show how to assign LRP rules to specific layers. For this purpose, we first define a small VGG-like convolutional neural network:

using RelevancePropagation
using Flux

model = Chain(
    Chain(
        Conv((3, 3), 3 => 8, relu; pad=1),
        Conv((3, 3), 8 => 8, relu; pad=1),
        MaxPool((2, 2)),
        Conv((3, 3), 8 => 16, relu; pad=1),
        Conv((3, 3), 16 => 16, relu; pad=1),
        MaxPool((2, 2)),
    ),
    Chain(Flux.flatten, Dense(1024 => 512, relu), Dropout(0.5), Dense(512 => 100, relu)),
);

Manually assigning rules

When creating an LRP analyzer, we can assign individual rules to each layer. As shown above, our model is a Chain of two Flux Chains. Using flatten_model, we can flatten it into a single Chain:

model_flat = flatten_model(model)
Chain(
  Conv((3, 3), 3 => 8, relu, pad=1),    # 224 parameters
  Conv((3, 3), 8 => 8, relu, pad=1),    # 584 parameters
  MaxPool((2, 2)),
  Conv((3, 3), 8 => 16, relu, pad=1),   # 1_168 parameters
  Conv((3, 3), 16 => 16, relu, pad=1),  # 2_320 parameters
  MaxPool((2, 2)),
  Flux.flatten,
  Dense(1024 => 512, relu),             # 524_800 parameters
  Dropout(0.5),
  Dense(512 => 100, relu),              # 51_300 parameters
)                   # Total: 12 arrays, 580_396 parameters, 2.215 MiB.

This allows us to define an LRP analyzer using an array of rules whose length matches the flattened chain:

rules = [
    FlatRule(),
    ZPlusRule(),
    ZeroRule(),
    ZPlusRule(),
    ZPlusRule(),
    ZeroRule(),
    PassRule(),
    EpsilonRule(),
    PassRule(),
    EpsilonRule(),
];

Printing the LRP analyzer shows a summary of how layers and rules were matched:

LRP(model_flat, rules)
LRP(
  Conv((3, 3), 3 => 8, relu, pad=1)   => FlatRule(),
  Conv((3, 3), 8 => 8, relu, pad=1)   => ZPlusRule(),
  MaxPool((2, 2))                     => ZeroRule(),
  Conv((3, 3), 8 => 16, relu, pad=1)  => ZPlusRule(),
  Conv((3, 3), 16 => 16, relu, pad=1) => ZPlusRule(),
  MaxPool((2, 2))                     => ZeroRule(),
  Flux.flatten                        => PassRule(),
  Dense(1024 => 512, relu)            => EpsilonRule{Float32}(1.0f-6),
  Dropout(0.5)                        => PassRule(),
  Dense(512 => 100, relu)             => EpsilonRule{Float32}(1.0f-6),
)
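
To check that the analyzer runs, we can apply it to a dummy input. The following is a usage sketch rather than part of the original example: it assumes that RelevancePropagation.jl exports the analyze function of the XAIBase interface, and it uses a random 32×32 RGB input in WHCN order with a batch dimension, which matches the Dense(1024 => 512) layer after two 2×2 poolings.

analyzer = LRP(model_flat, rules)    # the analyzer defined above
input = rand(Float32, 32, 32, 3, 1)  # random 32×32 RGB image, batch size 1
expl = analyze(input, analyzer)      # returns an Explanation; relevances have the same size as the input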

However, this approach only works for models that can be fully flattened. For unflattened models and models containing Parallel and SkipConnection layers, we can compose rules using ChainTuple, ParallelTuple, and SkipConnectionTuple, which match the structure of the model:

rules = ChainTuple(
    ChainTuple(FlatRule(), ZPlusRule(), ZeroRule(), ZPlusRule(), ZPlusRule(), ZeroRule()),
    ChainTuple(PassRule(), EpsilonRule(), PassRule(), EpsilonRule()),
)

analyzer = LRP(model, rules; flatten=false)
LRP(
  ChainTuple(
    Conv((3, 3), 3 => 8, relu, pad=1)   => FlatRule(),
    Conv((3, 3), 8 => 8, relu, pad=1)   => ZPlusRule(),
    MaxPool((2, 2))                     => ZeroRule(),
    Conv((3, 3), 8 => 16, relu, pad=1)  => ZPlusRule(),
    Conv((3, 3), 16 => 16, relu, pad=1) => ZPlusRule(),
    MaxPool((2, 2))                     => ZeroRule(),
  ),
  ChainTuple(
    Flux.flatten             => PassRule(),
    Dense(1024 => 512, relu) => EpsilonRule{Float32}(1.0f-6),
    Dropout(0.5)             => PassRule(),
    Dense(512 => 100, relu)  => EpsilonRule{Float32}(1.0f-6),
  ),
)

Keyword argument `flatten`

We used the LRP keyword argument flatten=false to demonstrate that the structure of the model can be preserved. For performance reasons, the default flatten=true is recommended.
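
As a minimal sketch of the default behavior (assuming flatten=true flattens the nested model internally, as shown above), a flat vector of rules, one per layer of the flattened chain, can be passed together with the unflattened model:

analyzer_default = LRP(model, [
    FlatRule(), ZPlusRule(), ZeroRule(), ZPlusRule(), ZPlusRule(), ZeroRule(),
    PassRule(), EpsilonRule(), PassRule(), EpsilonRule(),
])  # default flatten=true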

Custom composites

Instead of manually defining a list of rules, we can also define a Composite. A Composite constructs a list of LRP rules by sequentially applying the composite primitives it contains.

To obtain the same set of rules as in the previous example, we can define

composite = Composite(
    GlobalTypeMap( # the following maps of layer types to LRP rules are applied globally
        Conv                 => ZPlusRule(),   # apply ZPlusRule on all Conv layers
        Dense                => EpsilonRule(), # apply EpsilonRule on all Dense layers
        Dropout              => PassRule(),    # apply PassRule on all Dropout layers
        MaxPool              => ZeroRule(),    # apply ZeroRule on all MaxPool layers
        typeof(Flux.flatten) => PassRule(),    # apply PassRule on all flatten layers
    ),
    FirstLayerMap( # the following rule is applied to the first layer
        FlatRule(),
    ),
);

We now construct an LRP analyzer from the composite:

analyzer = LRP(model, composite; flatten=false)
LRP(
  ChainTuple(
    Conv((3, 3), 3 => 8, relu, pad=1)   => FlatRule(),
    Conv((3, 3), 8 => 8, relu, pad=1)   => ZPlusRule(),
    MaxPool((2, 2))                     => ZeroRule(),
    Conv((3, 3), 8 => 16, relu, pad=1)  => ZPlusRule(),
    Conv((3, 3), 16 => 16, relu, pad=1) => ZPlusRule(),
    MaxPool((2, 2))                     => ZeroRule(),
  ),
  ChainTuple(
    Flux.flatten             => PassRule(),
    Dense(1024 => 512, relu) => EpsilonRule{Float32}(1.0f-6),
    Dropout(0.5)             => PassRule(),
    Dense(512 => 100, relu)  => EpsilonRule{Float32}(1.0f-6),
  ),
)

As you can see, this analyzer contains the same rules as our previous one. To compute rules for a model without creating an analyzer, use lrp_rules:

lrp_rules(model, composite)
ChainTuple(
  ChainTuple(
    FlatRule(),
    ZPlusRule(),
    ZeroRule(),
    ZPlusRule(),
    ZPlusRule(),
    ZeroRule(),
  ),
  ChainTuple(
    PassRule(),
    EpsilonRule{Float32}(1.0f-6),
    PassRule(),
    EpsilonRule{Float32}(1.0f-6),
  ),
)

Composite primitives

The following Composite primitives can be used to construct a Composite.

To apply a single rule to matching layers, use map primitives such as LayerMap (a rule at a given layer index) and FirstLayerMap (a rule on the first layer).

To apply a set of rules to layers based on their type, use type-map primitives such as GlobalTypeMap and FirstLayerTypeMap. The full list of available primitives is documented in the API reference.

Primitives are applied sequentially in the order in which they were passed to the Composite; later primitives overwrite rules assigned by earlier ones.
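
As a small illustration using only the primitives shown above, swapping the order of FirstLayerMap and GlobalTypeMap from the earlier composite changes the result: the type map is now applied last and overwrites the FlatRule on the first Conv layer. This is a sketch of the ordering behavior, not a recommended composite:

composite_swapped = Composite(
    FirstLayerMap(FlatRule()),           # applied first
    GlobalTypeMap(Conv => ZPlusRule()),  # applied last, overwrites the FlatRule on the first Conv layer
);
lrp_rules(model, composite_swapped)      # the first Conv layer now gets ZPlusRule()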

Assigning a rule to a specific layer

To assign a rule to a specific layer, we can use LayerMap, which maps an LRP-rule to all layers in the model at the given index.

To display indices, use the show_layer_indices helper function:

show_layer_indices(model)
ChainTuple(
  ChainTuple(
    (1, 1),
    (1, 2),
    (1, 3),
    (1, 4),
    (1, 5),
    (1, 6),
  ),
  ChainTuple(
    (2, 1),
    (2, 2),
    (2, 3),
    (2, 4),
  ),
)

Let's demonstrate LayerMap by assigning a specific rule to the last Conv layer at index (1, 5):

composite = Composite(LayerMap((1, 5), EpsilonRule()))

LRP(model, composite; flatten=false)
LRP(
  ChainTuple(
    Conv((3, 3), 3 => 8, relu, pad=1)   => ZeroRule(),
    Conv((3, 3), 8 => 8, relu, pad=1)   => ZeroRule(),
    MaxPool((2, 2))                     => ZeroRule(),
    Conv((3, 3), 8 => 16, relu, pad=1)  => ZeroRule(),
    Conv((3, 3), 16 => 16, relu, pad=1) => EpsilonRule{Float32}(1.0f-6),
    MaxPool((2, 2))                     => ZeroRule(),
  ),
  ChainTuple(
    Flux.flatten             => ZeroRule(),
    Dense(1024 => 512, relu) => ZeroRule(),
    Dropout(0.5)             => ZeroRule(),
    Dense(512 => 100, relu)  => ZeroRule(),
  ),
)

This approach also works with Parallel layers.
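
As a minimal sketch with a hypothetical toy model containing a Parallel layer, show_layer_indices can be used in the same way to look up the indices inside the parallel branches before assigning rules with LayerMap:

model_par = Chain(
    Dense(10 => 10, relu),
    Parallel(+, Dense(10 => 10, relu), Dense(10 => 10, relu)),
    Dense(10 => 2),
)
show_layer_indices(model_par)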

Composite presets

RelevancePropagation.jl provides a set of default composites. A list of all implemented default composites can be found in the API reference. As an example, we use the EpsilonPlusFlat composite:

composite = EpsilonPlusFlat()
Composite(
  GlobalTypeMap(  # all layers
    Flux.Conv               => ZPlusRule(),
    Flux.ConvTranspose      => ZPlusRule(),
    Flux.CrossCor           => ZPlusRule(),
    Flux.Dense              => EpsilonRule{Float32}(1.0f-6),
    Flux.Scale              => EpsilonRule{Float32}(1.0f-6),
    Flux.LayerNorm          => LayerNormRule(),
    typeof(NNlib.dropout)   => PassRule(),
    Flux.AlphaDropout       => PassRule(),
    Flux.Dropout            => PassRule(),
    Flux.BatchNorm          => PassRule(),
    typeof(Flux.flatten)    => PassRule(),
    typeof(MLUtils.flatten) => PassRule(),
    typeof(identity)        => PassRule(),
  ),
  FirstLayerTypeMap(  # first layer
    Flux.Conv          => FlatRule(),
    Flux.ConvTranspose => FlatRule(),
    Flux.CrossCor      => FlatRule(),
  ),
)

analyzer = LRP(model, composite; flatten=false)
LRP(
  ChainTuple(
    Conv((3, 3), 3 => 8, relu, pad=1)   => FlatRule(),
    Conv((3, 3), 8 => 8, relu, pad=1)   => ZPlusRule(),
    MaxPool((2, 2))                     => ZeroRule(),
    Conv((3, 3), 8 => 16, relu, pad=1)  => ZPlusRule(),
    Conv((3, 3), 16 => 16, relu, pad=1) => ZPlusRule(),
    MaxPool((2, 2))                     => ZeroRule(),
  ),
  ChainTuple(
    Flux.flatten             => PassRule(),
    Dense(1024 => 512, relu) => EpsilonRule{Float32}(1.0f-6),
    Dropout(0.5)             => PassRule(),
    Dense(512 => 100, relu)  => EpsilonRule{Float32}(1.0f-6),
  ),
)

This page was generated using Literate.jl.