Recurrent Independent Mechanisms

Implementation of the paper Recurrent Independent Mechanisms (https://arxiv.org/pdf/1909.10893.pdf)

From dido1998 · Updated May 22, 2026

An implementation of [Recurrent Independent Mechanisms (Goyal et al. 2019)](https://arxiv.org/pdf/1909.10893.pdf) in PyTorch. The project is written primarily in Python, distributed under the Other license, and first published in 2020. Key topics include attention, deep-learning, generalization, grus, and lstms.

Paper Summary

This paper aims to build models that generalize to environments that differ, in specific factors of variation, from the environment they were trained on. To achieve this, the authors build recurrent networks that are modular: each module is independent of the others, and the modules interact only sparsely through attention. In this way each module can learn a different aspect of the environment and is only responsible for maintaining similar performance on that same aspect in a different environment.

These modules are modeled using LSTMs or GRUs. The total number of modules is fixed at Kt. At each time-step a fixed number (Ka) of modules is selected to be active, using an input attention mechanism: the top Ka modules that produce the highest attention scores for the input are selected. The other modules are fed a null input (all zeros).
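The selection step can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the repository's code: it assumes, following the paper's description, that each module scores the real input against the null input with scaled dot-product attention, and that the Ka modules placing the most weight on the real input become active. The function name `select_active_modules` and the toy vectors are hypothetical.

```python
import math


def select_active_modules(queries, input_key, null_key, k_active):
    """Return (active indices, per-module attention weight on the real input).

    queries:   one query vector per module (list of lists of floats)
    input_key: key vector computed from the real input
    null_key:  key vector computed from the all-zero null input
    """
    scale = math.sqrt(len(input_key))
    input_weights = []
    for q in queries:
        s_input = sum(a * b for a, b in zip(q, input_key)) / scale
        s_null = sum(a * b for a, b in zip(q, null_key)) / scale
        # Softmax over the two candidates (real input, null input).
        m = max(s_input, s_null)
        e_input, e_null = math.exp(s_input - m), math.exp(s_null - m)
        input_weights.append(e_input / (e_input + e_null))
    # Rank modules by how strongly they attend to the real input.
    ranked = sorted(range(len(queries)), key=lambda i: input_weights[i], reverse=True)
    return sorted(ranked[:k_active]), input_weights


# Toy setting: Kt = 4 modules, Ka = 2 active modules (assumed values).
queries = [[2.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [1.0, 1.0]]
input_key = [1.0, 0.5]
null_key = [0.0, 0.0]  # assumption: a bias-free linear map sends the null input to zero
active, weights = select_active_modules(queries, input_key, null_key, k_active=2)
print(active)  # [0, 3]
```

Modules 0 and 3 have the largest dot product with the input key, so they are the ones selected; the remaining modules would receive the null input.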

Once the new state of each module has been computed from its attended input (the normal LSTM or GRU computation), the modules can interact with each other through a second attention mechanism, called the communication attention mechanism. Only the states of the active modules are updated by this mechanism, but an active module can attend to both active and inactive modules when updating its state.
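The communication step can be sketched in the same way. The version below is a simplified illustration, not the repository's implementation: it assumes a single attention head, uses each state directly as query, key, and value (no learned projections), and adds the attention read-out to the state as a residual update, as described in the paper. The function name `communicate` is hypothetical.

```python
import math


def communicate(states, active):
    """Update only the active modules by attending over every module's state."""
    d = len(states[0])
    new_states = [list(h) for h in states]  # inactive modules are copied unchanged
    for i in active:
        # Scaled dot-product scores of module i against all modules.
        scores = [sum(a * b for a, b in zip(states[i], h)) / math.sqrt(d) for h in states]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Attention read-out: weighted sum of all module states (active and inactive).
        context = [sum(w * h[j] for w, h in zip(weights, states)) for j in range(d)]
        # Residual update of the active module's state.
        new_states[i] = [a + c for a, c in zip(states[i], context)]
    return new_states


states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
updated = communicate(states, active=[0, 2])
print(updated[1])  # [0.0, 1.0], the inactive module keeps its state
```

All reads use the states from before the update, so the active modules are updated simultaneously rather than one after another.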

The image below has been taken from the original paper.

<p align="center"> <img width="560" height="300" src="https://github.com/dido1998/Recurrent-Independent-Mechanisms/blob/master/README-RES/rim_image.png"> </p>

Updates

8/3/2020 : Implemented GroupLSTMCell and GroupGRUCell, which eliminate the need for Kt separate LSTM or GRU cells. Previously, computing the LSTM or GRU operation required looping over Kt cells. Now, GroupLSTMCell and GroupGRUCell compute the operation for all units at once (in parallel) without a loop. This speeds up the RIM computation, as shown below.

<p align="center"> <img width="500" height="450" src="https://github.com/dido1998/Recurrent-Independent-Mechanisms/blob/master/README-RES/time_comparison.png"> </p>

7/3/2020 : Added support for n-layered and bidirectional RIM similar to nn.LSTM and nn.GRU.

Setup

  • For using RIM as a standalone replacement for LSTMs or GRUs

Running the installation instructions below will automatically install the required libraries.

  • For running the experiments below
    • Install tqdm using pip install tqdm
    • For running the RL experiments
      • Install gym-minigrid using pip install gym-minigrid
      • Install torch_ac using pip install "torch_ac>=1.1.0"
      • Install tensorboardX using pip install "tensorboardX>=1.6"

Installation

git clone https://github.com/dido1998/Recurrent-Independent-Mechanisms.git
cd Recurrent-Independent-Mechanisms
pip install -e .

This will allow you to use RIMs from anywhere on your system.
This code was tested with Python 3.6.

Documentation

RIMCell

A single RIM cell similar to nn.LSTMCell or nn.GRUCell.

class RIM.RIMCell(device,
input_size,
hidden_size,
num_units,
k,
rnn_cell,
input_key_size = 64,
input_value_size = 400,
input_query_size = 64,
num_input_heads = 1,
input_dropout = 0.1,
comm_key_size = 32,
comm_value_size = 100,
comm_query_size = 32,
num_comm_heads = 4,
comm_dropout = 0.1
)

For a description of the RIMCell, please check the paper.

Parameters

| Parameter | Description |
|---|---|
| device | torch.device('cuda') or torch.device('cpu'). |
| input_size | The number of expected input features. |
| hidden_size | The number of hidden features in each unit. |
| num_units | Number of total RIM units. |
| k | Number of active RIMs at every time-step. |
| rnn_cell | 'LSTM' or 'GRU' |
| input_key_size | Number of features in the input key. |
| input_value_size | Number of features in the input value. |
| input_query_size | Number of features in the input query. |
| num_input_heads | Number of input attention heads. |
| input_dropout | Dropout applied to the input attention probabilities. |
| comm_key_size | Number of features in the communication key. |
| comm_value_size | Number of features in the communication value. |
| comm_query_size | Number of features in the communication query. |
| num_comm_heads | Number of communication attention heads. |
| comm_dropout | Dropout applied to the communication attention probabilities. |

Inputs

| Input | Description |
|---|---|
| x | Input of shape (batch_size, 1, input_size). |
| hs | Hidden state for the current time-step of shape (batch_size, num_units, hidden_size). |
| cs | This is given if rnn_cell == 'LSTM' else it is None. Cell state for the current time-step of shape (batch_size, num_units, hidden_size). |

Outputs

| Output | Description |
|---|---|
| hs | The new hidden state of shape (batch_size, num_units, hidden_size). |
| cs | This is only returned if rnn_cell == 'LSTM'. The new cell state of shape (batch_size, num_units, hidden_size). |

Example

import torch
from RIM import RIMCell

timesteps = 50
batch_size = 32
num_units = 6
k = 4
input_size = 32
hidden_size = 64
device = torch.device('cuda')

# Model definition. Each argument is defined as above.
# The cell is moved to the same device as the tensors it will receive.
rim_model = RIMCell(device, input_size, hidden_size, num_units, k, 'LSTM').to(device)

# Create the hidden states and cell states on the model's device
hs = torch.randn(batch_size, num_units, hidden_size, device=device)
cs = torch.randn(batch_size, num_units, hidden_size, device=device)

# Create the input and split it into per-time-step slices of shape (batch_size, 1, input_size)
xs = torch.randn(batch_size, timesteps, input_size, device=device)
xs = torch.split(xs, 1, 1)

for x in xs:
    hs, cs = rim_model(x, hs, cs)

RIM

A recurrent network made up of RIM cells similar to nn.LSTM or nn.GRU.

class RIM.RIM(device,
input_size,
hidden_size,
num_units,
k,
rnn_cell,
n_layers,
bidirectional,
**kwargs
)

Parameters

| Parameter | Description |
|---|---|
| device | 'cpu' or 'cuda'. |
| input_size | Input feature size. |
| hidden_size | Hidden feature size of each RIM unit. |
| num_units | Number of RIM units. |
| k | Number of active RIMs at each time-step. |
| rnn_cell | 'LSTM' or 'GRU' |
| n_layers | Number of RIM layers. |
| bidirectional | True or False |

The keyword arguments are the same as for RIM.RIMCell.

Inputs

| Input | Description |
|---|---|
| x | Input of shape (seq_len, batch_size, input_size). |
| hs | Hidden state of shape (num_layers * num_directions, batch_size, hidden_size * num_units). If not provided, it is randomly initialized. |
| cs | Provided only if rnn_cell == 'LSTM'. Shape is the same as hs. If not provided, it is randomly initialized. |

Outputs

| Output | Description |
|---|---|
| output | Output of shape (seq_len, batch_size, num_directions * hidden_size * num_units). |
| hs | Hidden state of shape (num_directions * num_layers, batch_size, hidden_size * num_units). |
| cs | Returned if rnn_cell == 'LSTM'. Cell state with the same shape as hs. |

Example

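import torch
from RIM import RIM

# 4-layer bidirectional RIM with 6 units of hidden size 24, 4 of them active per time-step
rim_model = RIM('cuda', 16, 24, 6, 4, 'LSTM', 4, True)
# Input of shape (seq_len, batch_size, input_size)
x = torch.randn(7, 4, 16).cuda()
out, h, c = rim_model(x)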
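
To make the shape formulas in the Outputs table concrete, the small helper below evaluates them for the configuration used in the example above. It is plain arithmetic on the documented shapes and does not call the RIM library; the function name `rim_output_shapes` is hypothetical.

```python
def rim_output_shapes(seq_len, batch_size, hidden_size, num_units, n_layers, bidirectional):
    """Evaluate the documented output shapes of RIM.RIM for a given configuration."""
    num_directions = 2 if bidirectional else 1
    state_shape = (num_directions * n_layers, batch_size, hidden_size * num_units)
    return {
        "output": (seq_len, batch_size, num_directions * hidden_size * num_units),
        "hs": state_shape,
        "cs": state_shape,  # only returned when rnn_cell == 'LSTM'
    }


# Same values as the example: x of shape (7, 4, 16), hidden_size 24, 6 units, 4 layers, bidirectional
shapes = rim_output_shapes(seq_len=7, batch_size=4, hidden_size=24,
                           num_units=6, n_layers=4, bidirectional=True)
print(shapes["output"])  # (7, 4, 288)
print(shapes["hs"])      # (8, 4, 144)
```

Note that input_size (16 in the example) does not appear in any output shape; only the hidden size, the number of units, the number of layers, and the number of directions do.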

Gym MiniGrid

The MiniGrid environments are provided by the gym-minigrid package. This section reports results for the Gym MiniGrid environments solved using PPO.

You need to cd into the minigrid_experiments directory to run these experiments.

Training

python3.6 train.py --algo ppo --env <Any of the available envs in the minigrid repo> \
                   --model <name of the directory to store the trained model and related files> \
                   --use_rim \
                   --frames <num_frames>

You can also use a2c for training by changing the --algo option accordingly. If --use_rim is not specified, the model will use a single LSTM for training. I recommend using 80000 frames for task-1, 1000000 for task-2 and 300000 for task-3. I recommend keeping the other parameters the same for convergence. If you tweak the other parameters and get better results, let me know :)
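As a convenience, the recommended frame budgets can be kept next to the training environments and turned into a command line. The sketch below only assembles the flags documented above (--algo, --env, --model, --use_rim, --frames); the helper name `build_train_command`, the model directory name, and the lowercase `-v0` spelling of the environment IDs are assumptions. The task-to-environment mapping is taken from the task descriptions later in this README.

```python
import shlex

# Recommended frame budgets from the text, keyed by each task's training environment.
RECOMMENDED_FRAMES = {
    "MiniGrid-Empty-5x5-v0": 80000,          # task 1
    "MiniGrid-MultiRoom-N2-S4-v0": 1000000,  # task 2
    "MiniGrid-DoorKey-5x5-v0": 300000,       # task 3
}


def build_train_command(env, model_dir, use_rim=True, algo="ppo"):
    """Assemble the train.py invocation as an argument list."""
    cmd = ["python3.6", "train.py", "--algo", algo, "--env", env, "--model", model_dir]
    if use_rim:
        cmd.append("--use_rim")
    cmd += ["--frames", str(RECOMMENDED_FRAMES[env])]
    return cmd


# "doorkey_rim" is a hypothetical directory name for the trained model.
cmd = build_train_command("MiniGrid-DoorKey-5x5-v0", "doorkey_rim")
print(shlex.join(cmd))
```

The resulting list can be passed to `subprocess.run(cmd)` from inside the minigrid_experiments directory; leaving out `--use_rim` (via `use_rim=False`) trains the single-LSTM baseline.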

Evaluation

python3.6 evaluate.py --env <Any of the available envs in the minigrid repo> \
                      --model <directory where model is stored> \
                      --use_rim

Pass the --use_rim flag when your model was trained using an RIM. For a simple LSTM model, omit the --use_rim flag.

Visualization

python3.6 visualize.py --env <Any of the available envs in the minigrid repo> \
                        --model <directory where model is stored> \
                        --gif <name of the gif file> \
                        --use_rim

The --use_rim flag is used in the same way as in evaluation.

For all the tables, the model is trained on the star-marked column and only evaluated on the other columns.

I report the mean return per episode in each case.

The environment names used below are the same as the ones in the minigrid repo.

Task 1

The models shown in the gif have been trained on the MiniGrid-Empty-5x5-v0 environment.

[GIFs: LSTM, RIM]

| Model | MiniGrid-Empty-5x5-v0 * | MiniGrid-Empty-16x16-v0 |
|---|---|---|
| RIM (Kt = 4, Ka = 3) | 0.91 | 0.92 |
| RIM (Kt = 4, Ka = 2) | 0.92 | 0.95 |
| LSTM | 0.80 | 0.84 |

Task 2

The models shown in the gif have been trained on the MiniGrid-MultiRoom-N2-S4-v0 (2 rooms) environment.

[GIFs: LSTM, RIM]

| Model | MiniGrid-MultiRoom-N2-S4-v0 (2 rooms) * | MiniGrid-MultiRoom-N2-S5-v0 (4 rooms) | MiniGrid-MultiRoom-N6-v0 (6 rooms) |
|---|---|---|---|
| RIM (Kt = 4, Ka = 3) | 0.81 | 0.66 | 0.05 |
| RIM (Kt = 4, Ka = 2) | 0.81 | 0.10 | 0.00 |
| LSTM | 0.82 | 0.04 | 0.00 |

Task 3

The models shown in the gif have been trained on the MiniGrid-DoorKey-5x5-v0 environment.

[GIFs: LSTM, RIM]

| Model | MiniGrid-DoorKey-5x5-v0 * | MiniGrid-DoorKey-6x6-v0 | MiniGrid-DoorKey-8x8-v0 | MiniGrid-DoorKey-16x16-v0 |
|---|---|---|---|---|
| RIM (Kt = 4, Ka = 3) | 0.90 | 0.68 | 0.38 | 0.18 |
| RIM (Kt = 4, Ka = 2) | 0.85 | 0.62 | 0.29 | 0.13 |
| LSTM | 0.90 | 0.63 | 0.35 | 0.12 |

Insight: Task 2 and Task 3 demonstrate the importance of the hyper-parameter Ka (number of active modules per time-step). Reducing Ka from 3 to 2 lowers performance, most drastically in task 2, where the mean return on the 4-room environment falls from 0.66 to 0.10. The RIM with Ka = 2 is the best performing model for task 1, but task 1 is a comparatively simple task. It would be interesting to see what causes each RIM to activate in each environment.
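
The size of the effect can be read directly off the tables. The short script below copies the mean returns for the two RIM rows of the Task 2 and Task 3 tables and computes the per-environment drop when Ka goes from 3 to 2. It is simple arithmetic on the reported numbers, not a new experiment; the helper name `return_drop` is hypothetical.

```python
# Mean return per episode, copied from the Task 2 and Task 3 tables above.
results = {
    "task 2": {"Ka=3": [0.81, 0.66, 0.05], "Ka=2": [0.81, 0.10, 0.00]},
    "task 3": {"Ka=3": [0.90, 0.68, 0.38, 0.18], "Ka=2": [0.85, 0.62, 0.29, 0.13]},
}


def return_drop(task):
    """Per-environment drop in mean return when Ka goes from 3 to 2."""
    rows = results[task]
    return [round(a - b, 2) for a, b in zip(rows["Ka=3"], rows["Ka=2"])]


for task in results:
    print(task, return_drop(task))
# task 2 [0.0, 0.56, 0.05]
# task 3 [0.05, 0.06, 0.09, 0.05]
```

The largest gap by far is on the 4-room environment of task 2; in task 3 the drop stays below 0.1 on every environment.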

Sequential MNIST Task

Results for MNIST task:

The model has been trained on the MNIST dataset with an individual image size of 14*14.

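| Model | Kt | Ka | h | 16*16 | 19*19 | 24*24 |
|---|---|---|---|---|---|---|
| RIM | 6 | 6 | 600 | 80.31 | 56.19 | 37.45 |
| RIM | 6 | 5 | 600 | 88.67 | 59.32 | 28.85 |
| RIM | 6 | 4 | 600 | 87.89 | 69.75 | 46.23 |
| LSTM | - | - | 600 | 80.43 | 39.74 | 20.48 |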

This task can be run using:

python3.6 main.py --args

args has the following options:

| Arguments | Description |
|---|---|
| cuda | To use GPU or not. |
| epochs | Number of epochs to train. |
| batch_size | Batch size for training. |
| hidden_size | Per RIM hidden size. |
| input_size | Input feature size. |
| model | LSTM or RIM. |
| train | Set to True for training and False for testing. |
| rnn_cell | LSTM or GRU. |
| key_size_input | Input key size. |
| value_size_input | Input value size. |
| query_size_input | Input query size. |
| num_input_heads | Number of heads in input attention. |
| input_dropout | Input dropout value. |
| key_size_comm | Communication key size. |
| value_size_comm | Communication value size. |
| query_size_comm | Communication query size. |
| num_comm_heads | Number of heads in communication attention. |
| comm_dropout | Communication dropout value. |
| num_units | Number of RIMs (Kt). |
| k | Number of active RIMs (Ka). |
| size | Image size for training. |
| loadsaved | Load saved model for training from log_dir. |
| log_dir | Directory path to save meta data. |

Contact

For any issues/questions, you can open a GitHub issue or contact me directly.

This article is auto-generated from dido1998/Recurrent-Independent-Mechanisms via the GitHub API. Last fetched: 6/27/2026