feat(off-policy): support off-policy lag (#204)
Gaiejj authored Apr 6, 2023
1 parent b79bea7 commit 52aaf32
Showing 18 changed files with 721 additions and 225 deletions.
62 changes: 54 additions & 8 deletions omnisafe/adapter/offpolicy_adapter.py
@@ -62,6 +62,14 @@ def __init__( # pylint: disable=too-many-arguments
seed: int,
cfgs: Config,
) -> None:
"""Initialize the off-policy adapter.
Args:
env_id (str): The environment id.
num_envs (int): The number of environments.
seed (int): The random seed.
cfgs (Config): The configuration.
"""
super().__init__(env_id, num_envs, seed, cfgs)

self._ep_ret: torch.Tensor
@@ -72,6 +80,44 @@ def __init__( # pylint: disable=too-many-arguments
self._device = cfgs.train_cfgs.device
self._reset_log()

def eval_policy( # pylint: disable=too-many-locals
self,
episode: int,
agent: ConstraintActorQCritic,
logger: Logger,
) -> None:
"""Roll out the environment and store the data in the buffer.
Args:
episode (int): Number of episodes.
agent (ConstraintActorQCritic): Agent.
logger (Logger): Logger.
"""
for _ in range(episode):
ep_ret, ep_cost, ep_len = 0.0, 0.0, 0
done = False
obs, _ = self._eval_env.reset()
obs = obs.to(self._device)
while not done:
act = agent.step(obs, deterministic=True)
obs, reward, cost, terminated, truncated, info = self._eval_env.step(act)
obs, reward, cost, terminated, truncated = (
torch.as_tensor(x, dtype=torch.float32, device=self._device)
for x in (obs, reward, cost, terminated, truncated)
)
ep_ret += info.get('original_reward', reward).cpu()
ep_cost += info.get('original_cost', cost).cpu()
ep_len += 1
done = terminated or truncated
if done:
logger.store(
**{
'Metrics/TestEpRet': ep_ret,
'Metrics/TestEpCost': ep_cost,
'Metrics/TestEpLen': ep_len,
},
)

def roll_out( # pylint: disable=too-many-locals
self,
roll_out_step: int,
@@ -100,26 +146,26 @@ def roll_out( # pylint: disable=too-many-locals
)
else:
act = agent.step(self._current_obs, deterministic=False)

next_obs, reward, cost, terminated, truncated, info = self.step(act)

self._log_value(reward=reward, cost=cost, info=info)
real_next_obs = next_obs.clone()
for idx, done in enumerate(torch.logical_or(terminated, truncated)):
if done:
real_next_obs[idx] = info['final_observation'][idx]
self._log_metrics(logger, idx)
self._reset_log(idx)

buffer.store(
obs=self._current_obs,
act=act,
reward=reward,
cost=cost,
done=terminated,
next_obs=next_obs,
done=torch.logical_and(terminated, torch.logical_xor(terminated, truncated)),
next_obs=real_next_obs,
)

self._current_obs = next_obs
for idx, done in enumerate(torch.logical_or(terminated, truncated)):
if done or self._ep_len[idx] >= self._max_ep_len:
# self.reset()
self._log_metrics(logger, idx)
self._reset_log(idx)

def _log_value(
self,
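Two of the roll_out changes above are easy to misread. The stored done flag, torch.logical_and(terminated, torch.logical_xor(terminated, truncated)), reduces to "terminated and not truncated", so bootstrapping is only cut off on genuine terminal states while time-limit truncations still bootstrap through the next observation; and real_next_obs swaps in info['final_observation'] because an auto-resetting vectorized env returns the first observation of the following episode in next_obs. A minimal standalone sketch (plain PyTorch, not omnisafe code) of the flag:

import torch

terminated = torch.tensor([True, True, False, False])
truncated = torch.tensor([True, False, True, False])

# Flag stored in the buffer by the diff above.
done = torch.logical_and(terminated, torch.logical_xor(terminated, truncated))

# Equivalent, easier-to-read form: terminated AND NOT truncated.
assert torch.equal(done, terminated & ~truncated)
print(done)  # tensor([False,  True, False, False])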
31 changes: 24 additions & 7 deletions omnisafe/adapter/online_adapter.py
@@ -60,25 +60,39 @@ def __init__( # pylint: disable=too-many-arguments
) -> None:
"""Initialize the online adapter.
OmniSafe is a framework for safe reinforcement learning. It is designed to be
compatible with existing RL algorithms. The online adapter adapts the environment
to the framework.
OmniSafe provides a set of adapters to adapt the environment to the framework.
- OnPolicyAdapter: Adapt the environment to the on-policy framework.
- OffPolicyAdapter: Adapt the environment to the off-policy framework.
- SauteAdapter: Adapt the environment to the SAUTE framework.
- SimmerAdapter: Adapt the environment to the SIMMER framework.
Args:
env_id (str): The environment id.
num_envs (int): The number of environments.
seed (int): The random seed.
cfgs (Config): The configuration.
"""
assert env_id in support_envs(), f'Env {env_id} is not supported.'

self._env_id = env_id
self._env = make(env_id, num_envs=num_envs, device=cfgs.train_cfgs.device)
self._cfgs = cfgs
self._device = cfgs.train_cfgs.device

self._env_id = env_id
self._env = make(env_id, num_envs=num_envs, device=self._device)
self._eval_env = make(env_id, num_envs=1, device=self._device)
self._wrapper(
obs_normalize=cfgs.algo_cfgs.obs_normalize,
reward_normalize=cfgs.algo_cfgs.reward_normalize,
cost_normalize=cfgs.algo_cfgs.cost_normalize,
)

self._env.set_seed(seed)
self._eval_env.set_seed(seed)

def _wrapper(
self,
@@ -115,26 +129,29 @@ def _wrapper(
obs_normalize (bool): Whether to normalize the observation.
reward_normalize (bool): Whether to normalize the reward.
cost_normalize (bool): Whether to normalize the cost.
"""
if self._env.need_time_limit_wrapper:
self._env = TimeLimit(self._env, device=self._device, time_limit=1000)
self._env = TimeLimit(self._env, time_limit=1000, device=self._device)
self._eval_env = TimeLimit(self._eval_env, time_limit=1000, device=self._device)
if self._env.need_auto_reset_wrapper:
self._env = AutoReset(self._env, device=self._device)
self._eval_env = AutoReset(self._eval_env, device=self._device)
if obs_normalize:
self._env = ObsNormalize(self._env, device=self._device)
self._eval_env = ObsNormalize(self._eval_env, device=self._device)
if reward_normalize:
self._env = RewardNormalize(self._env, device=self._device)
if cost_normalize:
self._env = CostNormalize(self._env, device=self._device)
self._env = ActionScale(self._env, device=self._device, low=-1.0, high=1.0)
self._env = ActionScale(self._env, low=-1.0, high=1.0, device=self._device)
self._eval_env = ActionScale(self._eval_env, low=-1.0, high=1.0, device=self._device)
if self._env.num_envs == 1:
self._env = Unsqueeze(self._env, device=self._device)
self._eval_env = Unsqueeze(self._eval_env, device=self._device)

@property
def action_space(self) -> OmnisafeSpace:
"""The action space of the environment.
Returns:
OmnisafeSpace: the action space.
"""
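The new _eval_env mirrors every wrapper that changes what the policy sees or emits (TimeLimit, AutoReset, ObsNormalize, ActionScale, Unsqueeze) but deliberately skips RewardNormalize and CostNormalize, so TestEpRet and TestEpCost stay on the raw scale. A schematic of the resulting nesting, using hypothetical stand-in classes rather than omnisafe's wrappers:

class Base:
    def name(self) -> str:
        return type(self).__name__


class Wrapper(Base):
    def __init__(self, inner: Base) -> None:
        self.inner = inner

    def name(self) -> str:
        return f'{type(self).__name__}({self.inner.name()})'


class TimeLimit(Wrapper): ...
class AutoReset(Wrapper): ...
class ObsNormalize(Wrapper): ...
class RewardNormalize(Wrapper): ...
class ActionScale(Wrapper): ...


train_env = ActionScale(RewardNormalize(ObsNormalize(AutoReset(TimeLimit(Base())))))
eval_env = ActionScale(ObsNormalize(AutoReset(TimeLimit(Base()))))
print(train_env.name())  # ActionScale(RewardNormalize(ObsNormalize(AutoReset(TimeLimit(Base)))))
print(eval_env.name())   # ActionScale(ObsNormalize(AutoReset(TimeLimit(Base))))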
2 changes: 1 addition & 1 deletion omnisafe/algorithms/__init__.py
@@ -19,7 +19,7 @@

from omnisafe.algorithms import off_policy, on_policy
from omnisafe.algorithms.base_algo import BaseAlgo
from omnisafe.algorithms.off_policy import DDPG, SAC, TD3
from omnisafe.algorithms.off_policy import DDPG, SAC, TD3, DDPGLag, SACLag, TD3Lag

# On-Policy Safe
from omnisafe.algorithms.on_policy import (
9 changes: 4 additions & 5 deletions omnisafe/algorithms/off_policy/__init__.py
@@ -15,12 +15,11 @@
"""Off-policy algorithms."""

from omnisafe.algorithms.off_policy.ddpg import DDPG
from omnisafe.algorithms.off_policy.ddpg_lag import DDPGLag
from omnisafe.algorithms.off_policy.sac import SAC
from omnisafe.algorithms.off_policy.sac_lag import SACLag
from omnisafe.algorithms.off_policy.td3 import TD3
from omnisafe.algorithms.off_policy.td3_lag import TD3Lag


__all__ = [
'DDPG',
'TD3',
'SAC',
]
__all__ = ['DDPG', 'TD3', 'SAC', 'DDPGLag', 'TD3Lag', 'SACLag']
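With these exports, the Lagrangian variants can be imported exactly like the unconstrained ones. A usage sketch, assuming the top-level omnisafe.Agent entry point and the SafetyPointGoal1-v0 task id from the project README are available at this commit:

# Sketch only: omnisafe.Agent and the env id are assumptions taken from the README,
# not verified against this exact commit.
import omnisafe
from omnisafe.algorithms.off_policy import DDPGLag, SACLag, TD3Lag  # new in this commit

agent = omnisafe.Agent('DDPGLag', 'SafetyPointGoal1-v0')
agent.learn()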
36 changes: 22 additions & 14 deletions omnisafe/algorithms/off_policy/ddpg.py
@@ -68,6 +68,7 @@ def _init_env(self) -> None:
self._update_cycle % self._steps_per_sample == 0
), 'The number of steps per epoch is not divisible by the number of steps per sample.'
self._samples_per_epoch = self._update_cycle // self._steps_per_sample
self._update_count = 0

def _init_model(self) -> None:
self._cfgs.model_cfgs.critic['num_critics'] = 1
@@ -114,6 +115,10 @@ def _init_log(self) -> None:
self._logger.register_key('Metrics/EpCost', window_length=50)
self._logger.register_key('Metrics/EpLen', window_length=50)

self._logger.register_key('Metrics/TestEpRet', window_length=50)
self._logger.register_key('Metrics/TestEpCost', window_length=50)
self._logger.register_key('Metrics/TestEpLen', window_length=50)

self._logger.register_key('Train/Epoch')
self._logger.register_key('Train/LR')

@@ -137,10 +142,6 @@ def _init_log(self) -> None:
self._logger.register_key('Time/Epoch')
self._logger.register_key('Time/FPS')

def _update_epoch(self) -> None:
"""Update something per epoch"""
self._actor_critic.actor_scheduler.step()

def learn(self) -> tuple[int | float, ...]:
"""This is main function for algorithm update, divided into the following steps:
@@ -185,13 +186,20 @@ def learn(self) -> tuple[int | float, ...]:
self._log_when_not_update()
update_time += time.time() - update_start

self._env.eval_policy(
episode=2,
agent=self._actor_critic,
logger=self._logger,
)

self._logger.store(**{'Time/Update': update_time})
self._logger.store(**{'Time/Rollout': roll_out_time})

if step > self._cfgs.algo_cfgs.start_learning_steps:
# update something per epoch
# e.g. update lagrange multiplier
self._update_epoch()
if (
step > self._cfgs.algo_cfgs.start_learning_steps
and self._cfgs.model_cfgs.linear_lr_decay
):
self._actor_critic.actor_scheduler.step()

self._logger.store(
**{
@@ -218,8 +226,9 @@ def learn(self) -> tuple[int | float, ...]:
return ep_ret, ep_cost, ep_len

def _update(self) -> None:
for step in range(self._steps_per_sample // self._cfgs.algo_cfgs.update_iters):
for _ in range(self._cfgs.algo_cfgs.update_iters):
data = self._buf.sample_batch()
self._update_count += 1
obs, act, reward, cost, done, next_obs = (
data['obs'],
data['act'],
@@ -229,16 +238,15 @@
data['next_obs'],
)

self._update_rewrad_critic(obs, act, reward, done, next_obs)
self._update_reward_critic(obs, act, reward, done, next_obs)
if self._cfgs.algo_cfgs.use_cost:
self._update_cost_critic(obs, act, cost, done, next_obs)

if step % self._cfgs.algo_cfgs.policy_delay == 0:
if self._update_count % self._cfgs.algo_cfgs.policy_delay == 0:
self._update_actor(obs)
self._actor_critic.polyak_update(self._cfgs.algo_cfgs.polyak)

self._actor_critic.polyak_update(self._cfgs.algo_cfgs.polyak)

def _update_rewrad_critic(
def _update_reward_critic(
self,
obs: torch.Tensor,
action: torch.Tensor,
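The _update rewrite replaces the per-call loop index with a persistent self._update_count, so the policy_delay cadence, and the polyak target update that now sits in the same branch, carries across calls instead of restarting every sample; it also simplifies the loop to update_iters iterations per call and renames the misspelled _update_rewrad_critic. A standalone sketch of the delayed-update pattern (not omnisafe code):

class DelayedUpdater:
    """TD3-style delayed actor/target updates driven by a counter that persists across calls."""

    def __init__(self, policy_delay: int = 2, polyak: float = 0.005) -> None:
        self.policy_delay = policy_delay
        self.polyak = polyak
        self.update_count = 0

    def update(self, critic_step, actor_step, polyak_step) -> None:
        self.update_count += 1
        critic_step()                                   # critics train on every batch
        if self.update_count % self.policy_delay == 0:  # actor + targets only every k-th batch
            actor_step()
            polyak_step(self.polyak)


log: list[str] = []
updater = DelayedUpdater(policy_delay=2)
for _ in range(4):
    updater.update(lambda: log.append('critic'),
                   lambda: log.append('actor'),
                   lambda tau: log.append(f'polyak({tau})'))
print(log)  # ['critic', 'critic', 'actor', 'polyak(0.005)', 'critic', 'critic', 'actor', 'polyak(0.005)']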
74 changes: 74 additions & 0 deletions omnisafe/algorithms/off_policy/ddpg_lag.py
@@ -0,0 +1,74 @@
# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the Lagrangian version of Deep Deterministic Policy Gradient algorithm."""


import torch

from omnisafe.algorithms import registry
from omnisafe.algorithms.off_policy.ddpg import DDPG
from omnisafe.common.lagrange import Lagrange


@registry.register
# pylint: disable-next=too-many-instance-attributes, too-few-public-methods
class DDPGLag(DDPG):
"""The Lagrangian version of Deep Deterministic Policy Gradient (DDPG) algorithm.
References:
- Title: Continuous control with deep reinforcement learning
- Authors: Timothy P. Lillicrap, Jonathan J. Hunt, Alexander Pritzel, Nicolas Heess,
Tom Erez, Yuval Tassa, David Silver, Daan Wierstra.
- URL: `DDPG <https://arxiv.org/abs/1509.02971>`_
"""

def _init(self) -> None:
super()._init()
self._lagrange = Lagrange(**self._cfgs.lagrange_cfgs)

def _init_log(self) -> None:
super()._init_log()
self._logger.register_key('Metrics/LagrangeMultiplier')

def _update(self) -> None:
super()._update()
Jc = self._logger.get_stats('Metrics/EpCost')[0]
self._lagrange.update_lagrange_multiplier(Jc)
self._logger.store(
**{
'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier.data.item(),
},
)

def _loss_pi(
self,
obs: torch.Tensor,
) -> torch.Tensor:
action = self._actor_critic.actor.predict(obs, deterministic=True)
loss_r = -self._actor_critic.reward_critic(obs, action)[0]
loss_c = (
self._lagrange.lagrangian_multiplier.item()
* self._actor_critic.cost_critic(obs, action)[0]
)
return (loss_r + loss_c).mean() / (1 + self._lagrange.lagrangian_multiplier.item())

def _log_when_not_update(self) -> None:
super()._log_when_not_update()
self._logger.store(
**{
'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier.data.item(),
},
)
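DDPGLag adds two pieces on top of DDPG: the policy loss (-Q_r + lambda * Q_c) / (1 + lambda), where the 1 + lambda denominator keeps the loss scale bounded as the multiplier grows, and a periodic multiplier update driven by the recent EpCost statistic. The numeric sketch below uses the textbook projected-gradient-ascent rule lambda <- max(0, lambda + eta * (Jc - cost_limit)) as a stand-in for whatever Lagrange.update_lagrange_multiplier does internally; eta and cost_limit are illustrative values.

import torch


def lagrangian_policy_loss(q_r: torch.Tensor, q_c: torch.Tensor, lam: float) -> torch.Tensor:
    """Mean of (-Q_r + lam * Q_c), rescaled by 1 / (1 + lam) as in _loss_pi above."""
    return (-q_r + lam * q_c).mean() / (1.0 + lam)


def update_multiplier(lam: float, ep_cost: float, cost_limit: float, eta: float) -> float:
    """Projected gradient ascent on the dual variable (illustrative, not omnisafe's Lagrange)."""
    return max(0.0, lam + eta * (ep_cost - cost_limit))


q_r = torch.tensor([1.0, 2.0, 3.0])  # Q_r(s, pi(s)) for a small batch
q_c = torch.tensor([0.5, 0.0, 1.5])  # Q_c(s, pi(s)) for the same batch

print(lagrangian_policy_loss(q_r, q_c, lam=0.0))   # reduces to the plain DDPG objective
print(lagrangian_policy_loss(q_r, q_c, lam=10.0))  # cost term dominates, but scale stays O(1)

lam = 0.0
for jc in (40.0, 30.0, 20.0):  # episode costs above, then below, a 25.0 limit
    lam = update_multiplier(lam, jc, cost_limit=25.0, eta=0.01)
print(lam)  # rises while Jc > cost_limit, falls back (never below 0) once Jc < cost_limit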