This repository contains the PyTorch implementation for our paper titled StARformer: Transformer with State-Action-Reward Representations for Visual Reinforcement Learning (ECCV 2022) and StARformer: Transformer with State-Action-Reward Representations for Robot Learning (IEEE T-PAMI).
[Installation] [Usage] [Citation] [Update Notes]
We learn local State-Action-Reward representations (StAR-representations) to improve (long) sequence modeling for reinforcement learning (and imitation learning).
For details and unnormalized numbers, please check the supplementary at the end of the paper, or here for convenience.
Dependencies can be installed via Conda. For example, to install the environment used for Atari and DMC (image input):
conda env create -f atari_and_dmc/conda_env.yml
Then activate it by
conda activate starformer
- Atari: To run on Atari environments, please install the Atari ROMs.
- DMC: Install dmc2gym by `pip install git+https://github.com/denisyarats/dmc2gym.git`. Make sure you have MuJoCo installed; mujoco-py has already been installed in the conda env for you, but it's good to check that the two are compatible.
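One quick way to check the install is to see whether the relevant packages are even importable in the activated env. The helper below is a hypothetical convenience, not part of this repo, and the module names are the usual import names for these packages:

```python
import importlib.util

# Hypothetical helper: reports which MuJoCo-related packages are importable,
# so a missing or broken install shows up before you launch training.
def check_mujoco_stack():
    packages = ("mujoco_py", "dmc2gym", "dm_control")
    return {name: importlib.util.find_spec(name) is not None for name in packages}

print(check_mujoco_stack())
```

Any `False` entry means the package is not visible to the current interpreter.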
Please follow these instructions for datasets.
See run.sh
or below:
- atari:
python run_star_atari.py --seed 123 --data_dir_prefix [data_directory] --epochs 10 --num_steps 500000 --num_buffers 50 --batch_size 64 --seq_len 30 --model_type 'star' --game 'Breakout'
`[data_directory]` is where you place the Atari dataset.
- dmc:
python run_star_dmc.py --seed 123 --data_dir_prefix [data_directory] --epochs 10 --seq_len 30 --model_type 'star' --batch_size 64 --domain ball_in_cup --task catch --lr 1e-4
Similarly, `[data_directory]` is where you place the DMC dataset. You can collect any replay buffer you desire and modify `StateActionReturnDatasetDMC` in `run_star_dmc.py` to make it compatible with your buffers.
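As a rough sketch of what such a dataset needs to provide (this is an illustration, not the repo's actual `StateActionReturnDatasetDMC`): it slices a flat replay buffer into fixed-length windows of states, actions, rewards, and timesteps. Episode boundaries and tensor conversion are ignored here for brevity:

```python
# Illustrative sketch only -- not the repo's StateActionReturnDatasetDMC.
# Assumes a replay buffer stored as parallel lists of equal length.
class ReplayWindows:
    def __init__(self, states, actions, rewards, seq_len):
        assert len(states) == len(actions) == len(rewards)
        self.states, self.actions, self.rewards = states, actions, rewards
        self.seq_len = seq_len

    def __len__(self):
        # number of length-seq_len windows that fit in the buffer
        return max(0, len(self.states) - self.seq_len + 1)

    def __getitem__(self, i):
        j = i + self.seq_len
        return (self.states[i:j], self.actions[i:j],
                self.rewards[i:j], list(range(i, j)))  # timesteps
```

A real implementation would additionally avoid windows that straddle episode boundaries and return PyTorch tensors.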
Available model types:
- `'star'` (imitation)
- `'star_rwd'` (offline RL)
- `'star_fusion'` (see Figure 4a in our paper)
- `'star_stack'` (see Figure 4b in our paper)
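The offline-RL variant conditions on reward signals; in Decision-Transformer-style setups (which this code builds on) that signal is the return-to-go, i.e. the suffix sum of future rewards at each step. A minimal sketch of the computation:

```python
def returns_to_go(rewards):
    """Suffix sums: rtg[t] = sum of rewards from step t to the end of the episode."""
    rtg = [0.0] * len(rewards)
    running = 0.0
    for t in range(len(rewards) - 1, -1, -1):
        running += rewards[t]
        rtg[t] = running
    return rtg

print(returns_to_go([1.0, 0.0, 2.0]))  # [3.0, 2.0, 2.0]
```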
With `num_steps=500000`, `batch_size=64`, `model_type=star_rwd`, on a single NVIDIA 3090Ti (24GB):
- `--seq_len=10`: 9685MB, ~25min/epoch
- `--seq_len=20`: 17033MB, ~50min/epoch
- `--seq_len=30`: 24007MB, ~66min/epoch
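Memory grows roughly linearly with sequence length in the measurements above, so a quick least-squares fit over those three points can give a back-of-the-envelope estimate for other settings (purely illustrative; actual usage depends on your hardware and drivers):

```python
def linear_fit(xs, ys):
    # ordinary least squares for y = a*x + b
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    a = (sum((x - mx) * (y - my) for x, y in zip(xs, ys))
         / sum((x - mx) ** 2 for x in xs))
    return a, my - a * mx

# fit on the measured (seq_len, MB) points from the table above
a, b = linear_fit([10, 20, 30], [9685, 17033, 24007])
print(f"~{a:.0f} MB per sequence step; seq_len=15 estimate: {a * 15 + b:.0f} MB")
```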
If you run out of memory, you can reduce `batch_size`.
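If a smaller `batch_size` hurts optimization, gradient accumulation (not built into these scripts; you would add it to the training loop yourself) recovers the effective batch size: averaging the gradients of k micro-batches of size b matches one batch of size k*b when per-batch gradients are means. A toy numerical check of that identity:

```python
def mean_grad(samples):
    # stand-in for a per-batch gradient: the mean over the batch
    return sum(samples) / len(samples)

data = [float(i) for i in range(64)]                          # one full batch of 64
full = mean_grad(data)
micro = [mean_grad(data[i:i + 16]) for i in range(0, 64, 16)]  # 4 micro-batches of 16
accumulated = sum(micro) / len(micro)                          # average of micro-batch grads
print(full, accumulated)  # identical
```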
If you find our paper useful for your research, please consider citing:
@InProceedings{starformer,
author="Shang, Jinghuan and Kahatapitiya, Kumara and Li, Xiang and Ryoo, Michael S.",
title="StARformer: Transformer with State-Action-Reward Representations for Visual Reinforcement Learning",
booktitle="Computer Vision -- ECCV 2022",
year="2022",
publisher="Springer Nature Switzerland",
pages="462--479",
}
@ARTICLE{starformer-robot,
author={Shang, Jinghuan and Li, Xiang and Kahatapitiya, Kumara and Lee, Yu-Cheol and Ryoo, Michael S.},
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
title={StARformer: Transformer with State-Action-Reward Representations for Robot Learning},
year={2022},
pages={1-16},
doi={10.1109/TPAMI.2022.3204708}
}
- Apr 6, 2023:
  - fix bug in `run_star_atari.py`
  - fix conda env
  - provide GPU usage reference
  - fix bug in
- Nov 26, 2022:
  - update code for dmc environments
  - clean conda env file
This code is based on Decision-Transformer.