From d1e171e14bb5c904e28294b7adcc0697fc32bdff Mon Sep 17 00:00:00 2001 From: zmsn-2077 <73586554+zmsn-2077@users.noreply.github.com> Date: Sun, 18 Dec 2022 17:42:06 +0800 Subject: [PATCH] feat(wrapper): separated wrapper for different algorithmic environments (#44) --- docs/source/spelling_wordlist.txt | 11 + examples/train_policy.py | 9 +- omnisafe/__init__.py | 5 +- omnisafe/algorithms/__init__.py | 5 + omnisafe/{ => algorithms}/algo_wrapper.py | 9 +- omnisafe/algorithms/off_policy/ddpg.py | 434 ++++++++++++++++++ omnisafe/algorithms/on_policy/cpo.py | 8 +- omnisafe/algorithms/on_policy/cppo_pid.py | 11 +- omnisafe/algorithms/on_policy/cup.py | 7 +- omnisafe/algorithms/on_policy/focops.py | 9 +- omnisafe/algorithms/on_policy/natural_pg.py | 6 +- omnisafe/algorithms/on_policy/npg_lag.py | 11 +- omnisafe/algorithms/on_policy/pcpo.py | 6 +- omnisafe/algorithms/on_policy/pdo.py | 11 +- .../algorithms/on_policy/policy_gradient.py | 11 +- omnisafe/algorithms/on_policy/ppo.py | 7 +- omnisafe/algorithms/on_policy/ppo_lag.py | 7 +- omnisafe/algorithms/on_policy/trpo.py | 6 +- omnisafe/algorithms/on_policy/trpo_lag.py | 6 +- omnisafe/common/base_buffer.py | 6 +- omnisafe/common/lagrange.py | 2 +- omnisafe/configs/off-policy/DDPG.yaml | 2 +- omnisafe/evaluator.py | 2 +- omnisafe/models/actor/cholesky_actor.py | 128 ++++++ .../models/actor/gaussian_annealing_actor.py | 7 +- omnisafe/models/actor/mlp_actor.py | 76 +++ omnisafe/models/actor_q_critic.py | 137 ++++++ omnisafe/models/constraint_actor_critic.py | 4 +- omnisafe/models/constraint_actor_q_critic.py | 75 +++ omnisafe/models/critic/q_critic.py | 3 +- omnisafe/models/critic/v_critic.py | 1 + omnisafe/utils/vtrace.py | 8 +- omnisafe/wrappers/__init__.py | 18 + .../{algorithms => wrappers}/env_wrapper.py | 0 omnisafe/wrappers/off_policy_wrapper.py | 151 ++++++ omnisafe/wrappers/on_policy_wrapper.py | 138 ++++++ omnisafe/wrappers/wrapper_registry.py | 72 +++ tests/test_policy.py | 5 +- tests/test_safety_gym_envs.py | 4 +- 39 files changed, 1350 insertions(+), 68 deletions(-) rename omnisafe/{ => algorithms}/algo_wrapper.py (94%) create mode 100644 omnisafe/algorithms/off_policy/ddpg.py create mode 100644 omnisafe/models/actor/cholesky_actor.py create mode 100644 omnisafe/models/actor/mlp_actor.py create mode 100644 omnisafe/models/actor_q_critic.py create mode 100644 omnisafe/models/constraint_actor_q_critic.py create mode 100644 omnisafe/wrappers/__init__.py rename omnisafe/{algorithms => wrappers}/env_wrapper.py (100%) create mode 100644 omnisafe/wrappers/off_policy_wrapper.py create mode 100644 omnisafe/wrappers/on_policy_wrapper.py create mode 100644 omnisafe/wrappers/wrapper_registry.py diff --git a/docs/source/spelling_wordlist.txt b/docs/source/spelling_wordlist.txt index ea6397940..c8b85c7e9 100644 --- a/docs/source/spelling_wordlist.txt +++ b/docs/source/spelling_wordlist.txt @@ -127,3 +127,14 @@ FOCOPS Kakade QCritic yaml +polyak +MSE +Daan +Wierstra +Pritzel +Heess +mul +logprob +Tanh +Eq +chol diff --git a/examples/train_policy.py b/examples/train_policy.py index 4ee98cfc8..675d51455 100644 --- a/examples/train_policy.py +++ b/examples/train_policy.py @@ -41,6 +41,11 @@ keys = [k[2:] for k in unparsed_args[0::2]] values = list(unparsed_args[1::2]) unparsed_dict = dict(zip(keys, values)) - env = omnisafe.Env(args.env_id) - agent = omnisafe.Agent(args.algo, env, parallel=args.parallel, custom_cfgs=unparsed_dict) + # env = omnisafe.Env(args.env_id) + agent = omnisafe.Agent( + args.algo, + args.env_id, + parallel=args.parallel, + 
custom_cfgs=unparsed_dict, + ) agent.learn() diff --git a/omnisafe/__init__.py b/omnisafe/__init__.py index 81282f0ad..77f8e63e0 100644 --- a/omnisafe/__init__.py +++ b/omnisafe/__init__.py @@ -14,6 +14,7 @@ # ============================================================================== """OmniSafe: A comprehensive and reliable benchmark for safe reinforcement learning.""" -from omnisafe.algo_wrapper import AlgoWrapper as Agent -from omnisafe.algorithms.env_wrapper import EnvWrapper as Env +from omnisafe.algorithms.algo_wrapper import AlgoWrapper as Agent + +# from omnisafe.algorithms.env_wrapper import EnvWrapper as Env from omnisafe.version import __version__ diff --git a/omnisafe/algorithms/__init__.py b/omnisafe/algorithms/__init__.py index d7be2ba09..13f758cac 100644 --- a/omnisafe/algorithms/__init__.py +++ b/omnisafe/algorithms/__init__.py @@ -14,9 +14,13 @@ # ============================================================================== """Safe Reinforcement Learning algorithms.""" +# Off Policy Safe +from omnisafe.algorithms.off_policy.ddpg import DDPG + # On Policy Safe from omnisafe.algorithms.on_policy.cpo import CPO from omnisafe.algorithms.on_policy.cppo_pid import CPPOPid +from omnisafe.algorithms.on_policy.cup import CUP from omnisafe.algorithms.on_policy.focops import FOCOPS from omnisafe.algorithms.on_policy.natural_pg import NaturalPG from omnisafe.algorithms.on_policy.npg_lag import NPGLag @@ -45,6 +49,7 @@ 'PPOLag', 'TRPO', 'TRPOLag', + 'CUP', ], 'model-based': ['MBPPOLag', 'SafeLoop'], } diff --git a/omnisafe/algo_wrapper.py b/omnisafe/algorithms/algo_wrapper.py similarity index 94% rename from omnisafe/algo_wrapper.py rename to omnisafe/algorithms/algo_wrapper.py index 7e9aeb812..def42e212 100644 --- a/omnisafe/algo_wrapper.py +++ b/omnisafe/algorithms/algo_wrapper.py @@ -28,11 +28,10 @@ class AlgoWrapper: """Algo Wrapper for algo""" - def __init__(self, algo, env, parallel=1, custom_cfgs=None): + def __init__(self, algo, env_id, parallel=1, custom_cfgs=None): self.algo = algo - self.env = env self.parallel = parallel - self.env_id = env.env_id + self.env_id = env_id # algo_type will set in _init_checks() self.algo_type = None self.custom_cfgs = custom_cfgs @@ -69,12 +68,12 @@ def learn(self): sys.exit() default_cfgs = get_default_kwargs_yaml(self.algo, self.env_id, self.algo_type) - exp_name = os.path.join(self.env.env_id, self.algo) + exp_name = os.path.join(self.env_id, self.algo) default_cfgs.update(exp_name=exp_name, env_id=self.env_id) cfgs = recursive_update(default_cfgs, self.custom_cfgs) check_all_configs(cfgs, self.algo_type) agent = registry.get(self.algo)( - env=self.env, + env_id=self.env_id, cfgs=cfgs, ) agent.learn() diff --git a/omnisafe/algorithms/off_policy/ddpg.py b/omnisafe/algorithms/off_policy/ddpg.py new file mode 100644 index 000000000..16f65cabd --- /dev/null +++ b/omnisafe/algorithms/off_policy/ddpg.py @@ -0,0 +1,434 @@ +# Copyright 2022 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
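The examples/train_policy.py and AlgoWrapper changes above replace the old two-step pattern (construct omnisafe.Env, hand it to omnisafe.Agent) with a single call that takes the environment id directly; the algorithm now builds its own wrapper internally through the wrapper registry. A minimal usage sketch of the new entry point, mirroring the updated tests (algorithm name and config values here are illustrative):

import omnisafe

# The env_id string is passed straight to the agent; no omnisafe.Env object is created.
custom_cfgs = {'epochs': 1, 'steps_per_epoch': 2000}
agent = omnisafe.Agent('PPOLag', 'SafetyPointGoal1-v0', parallel=1, custom_cfgs=custom_cfgs)
agent.learn()
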
+# ============================================================================== +"""Implementation of the DDPG algorithm.""" + +import time +from copy import deepcopy + +import numpy as np +import torch + +from omnisafe.algorithms import registry +from omnisafe.common.base_buffer import BaseBuffer +from omnisafe.common.logger import Logger +from omnisafe.models.constraint_actor_q_critic import ConstraintActorQCritic +from omnisafe.utils import core, distributed_utils +from omnisafe.utils.tools import get_flat_params_from +from omnisafe.wrappers import wrapper_registry + + +@registry.register +class DDPG: + """Continuous control with deep reinforcement learning (DDPG) Algorithm. + + References: + Paper Name: Continuous control with deep reinforcement learning. + Paper author: Timothy P. Lillicrap, Jonathan J. Hunt, Alexander Pritzel, Nicolas Heess, Tom Erez, Yuval Tassa, David Silver, Daan Wierstra. + Paper URL: https://arxiv.org/abs/1509.02971 + + """ + + def __init__( + self, + env_id: str, + cfgs=None, + algo: str = 'DDPG', + wrapper_type: str = 'OffPolicyEnvWrapper', + ): + """Initialize DDPG.""" + self.env = wrapper_registry.get(wrapper_type)( + env_id, + use_cost=cfgs.use_cost, + max_ep_len=cfgs.max_ep_len, + ) + self.env_id = env_id + self.algo = algo + self.cfgs = deepcopy(cfgs) + + # Set up for learning and rolling out schedule + self.steps_per_epoch = cfgs.steps_per_epoch + self.local_steps_per_epoch = cfgs.steps_per_epoch + self.epochs = cfgs.epochs + self.total_steps = self.epochs * self.steps_per_epoch + self.start_steps = cfgs.start_steps + # The steps in each process should be integer + assert cfgs.steps_per_epoch % distributed_utils.num_procs() == 0 + # Ensure local each local process can experience at least one complete episode + assert self.env.max_ep_len <= self.local_steps_per_epoch, ( + f'Reduce number of cores ({distributed_utils.num_procs()}) or increase ' + f'batch size {self.steps_per_epoch}.' + ) + # Ensure valid number for iteration + assert cfgs.update_every > 0 + self.max_ep_len = cfgs.max_ep_len + if hasattr(self.env, '_max_episode_steps'): + self.max_ep_len = self.env.env._max_episode_steps + self.update_after = cfgs.update_after + self.update_every = cfgs.update_every + self.num_test_episodes = cfgs.num_test_episodes + + self.env.set_rollout_cfgs( + determinstic=False, + rand_a=True, + ) + + # Set up logger and save configuration to disk + self.logger = Logger(exp_name=cfgs.exp_name, data_dir=cfgs.data_dir, seed=cfgs.seed) + self.logger.save_config(cfgs._asdict()) + # Set seed + seed = cfgs.seed + 10000 * distributed_utils.proc_id() + torch.manual_seed(seed) + np.random.seed(seed) + self.env.set_seed(seed=seed) + # Setup actor-critic module + self.actor_critic = ConstraintActorQCritic( + observation_space=self.env.observation_space, + action_space=self.env.action_space, + scale_rewards=cfgs.scale_rewards, + standardized_obs=cfgs.standardized_obs, + model_cfgs=cfgs.model_cfgs, + ) + # Set PyTorch + MPI. 
+ self._init_mpi() + # Set up experience buffer + # obs_dim, act_dim, size, batch_size + self.buf = BaseBuffer( + obs_dim=self.env.observation_space.shape, + act_dim=self.env.action_space.shape, + size=cfgs.replay_buffer_cfgs.size, + batch_size=cfgs.replay_buffer_cfgs.batch_size, + ) + # Set up optimizer for policy and q-function + self.actor_optimizer = core.set_optimizer( + 'Adam', module=self.actor_critic.actor, learning_rate=cfgs.actor_lr + ) + self.critic_optimizer = core.set_optimizer( + 'Adam', module=self.actor_critic.critic, learning_rate=cfgs.critic_lr + ) + if cfgs.use_cost: + self.cost_critic_optimizer = core.set_optimizer( + 'Adam', module=self.actor_critic.cost_critic, learning_rate=cfgs.critic_lr + ) + # Set up scheduler for policy learning rate decay + self.scheduler = self.set_learning_rate_scheduler() + # Set up target network for off_policy training + self._ac_training_setup() + torch.set_num_threads(10) + # Set up model saving + what_to_save = { + 'pi': self.actor_critic.actor, + 'obs_oms': self.actor_critic.obs_oms, + } + self.logger.setup_torch_saver(what_to_save=what_to_save) + self.logger.torch_save() + # Set up timer + self.start_time = time.time() + self.epoch_time = time.time() + self.logger.log('Start with training.') + + def set_learning_rate_scheduler(self): + """Set up learning rate scheduler.""" + scheduler = None + if self.cfgs.linear_lr_decay: + # Linear anneal + def linear_anneal(epoch): + return 1 - epoch / self.cfgs.epochs + + scheduler = torch.optim.lr_scheduler.LambdaLR( + optimizer=self.actor_optimizer, lr_lambda=linear_anneal + ) + return scheduler + + def _init_mpi(self): + """ + Initialize MPI specifics + """ + if distributed_utils.num_procs() > 1: + # Avoid slowdowns from PyTorch + MPI combo + distributed_utils.setup_torch_for_mpi() + start = time.time() + self.logger.log('INFO: Sync actor critic parameters') + # Sync params across cores: only once necessary, grads are averaged! + distributed_utils.sync_params(self.actor_critic) + self.logger.log(f'Done! (took {time.time()-start:0.3f} sec.)') + + def algorithm_specific_logs(self): + """ + Use this method to collect log information. + """ + + def _ac_training_setup(self): + """Set up target network for off_policy training.""" + self.ac_targ = deepcopy(self.actor_critic) + # Freeze target networks with respect to optimizer (only update via polyak averaging) + for param in self.ac_targ.actor.parameters(): + param.requires_grad = False + for param in self.ac_targ.critic.parameters(): + param.requires_grad = False + for param in self.ac_targ.cost_critic.parameters(): + param.requires_grad = False + if self.algo in ['SAC', 'TD3', 'SACLag', 'TD3Lag']: + # Freeze target networks with respect to optimizer (only update via polyak averaging) + for param in self.ac_targ.critic_.parameters(): + param.requires_grad = False + + def check_distributed_parameters(self): + """ + Check if parameters are synchronized across all processes. + """ + if distributed_utils.num_procs() > 1: + self.logger.log('Check if distributed parameters are synchronous..') + modules = {'Policy': self.actor_critic.actor.net, 'Value': self.actor_critic.critic.net} + for key, module in modules.items(): + flat_params = get_flat_params_from(module).numpy() + global_min = distributed_utils.mpi_min(np.sum(flat_params)) + global_max = distributed_utils.mpi_max(np.sum(flat_params)) + assert np.allclose(global_min, global_max), f'{key} not synced.' 
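_ac_training_setup above freezes the target networks so that they are only moved toward the online networks by polyak averaging (the polyak_update_target method defined further down in this file). A self-contained sketch of that soft update, assuming a polyak coefficient close to one:

import torch
from torch import nn

def soft_update(online: nn.Module, target: nn.Module, polyak: float = 0.995) -> None:
    """In-place soft update: theta_targ <- polyak * theta_targ + (1 - polyak) * theta."""
    with torch.no_grad():
        for param, param_targ in zip(online.parameters(), target.parameters()):
            param_targ.data.mul_(polyak)
            param_targ.data.add_((1 - polyak) * param.data)

online_q = nn.Linear(8, 1)
target_q = nn.Linear(8, 1)
for p in target_q.parameters():
    p.requires_grad = False  # target is never trained directly, only nudged toward the online net
soft_update(online_q, target_q)
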
+ + def compute_loss_pi(self, data: dict): + """ + computing pi/actor loss + + Returns: + torch.Tensor + """ + action, _ = self.actor_critic.actor.predict(data['obs'], deterministic=True) + loss_pi = self.actor_critic.critic(data['obs'], action) + pi_info = {} + return -loss_pi.mean(), pi_info + + def compute_loss_v(self, data): + """ + computing value loss + + Returns: + torch.Tensor + """ + obs, act, rew, obs_next, done = ( + data['obs'], + data['act'], + data['rew'], + data['obs_next'], + data['done'], + ) + q = self.actor_critic.critic(obs, act) + # Bellman backup for Q function + with torch.no_grad(): + act_targ, _ = self.ac_targ.actor.predict(obs, deterministic=True) + q_targ = self.ac_targ.critic(obs_next, act_targ) + backup = rew + self.cfgs.gamma * (1 - done) * q_targ + # MSE loss against Bellman backup + loss_q = ((q - backup) ** 2).mean() + # Useful info for logging + q_info = dict(Q1Vals=q.detach().numpy()) + return loss_q, q_info + + def compute_loss_c(self, data): + """ + computing cost loss + + Returns: + torch.Tensor + """ + obs, act, cost, obs_next, done = ( + data['obs'], + data['act'], + data['rew'], + data['obs_next'], + data['done'], + ) + qc = self.actor_critic.cost_critic(obs, act) + + # Bellman backup for Q function + with torch.no_grad(): + action, _ = self.ac_targ.pi.predict(obs_next, deterministic=True) + qc_targ = self.ac_targ.c(obs_next, action) + backup = cost + self.cfgs.gamma * (1 - done) * qc_targ + # MSE loss against Bellman backup + loss_qc = ((qc - backup) ** 2).mean() + # Useful info for logging + qc_info = dict(QCosts=qc.detach().numpy()) + + return loss_qc, qc_info + + def learn(self): + """ + This is main function for algorithm update, divided into the following steps: + (1). self.rollout: collect interactive data from environment + (2). self.update: perform actor/critic updates + (3). log epoch/update information for visualization and terminal log print. + + Returns: + model and environment + """ + + for steps in range(0, self.local_steps_per_epoch * self.epochs, self.update_every): + # Until start_steps have elapsed, randomly sample actions + # from a uniform distribution for better exploration. Afterwards, + # use the learned policy (with some noise, via act_noise). + use_rand_action = steps < self.start_steps + self.env.roll_out( + self.actor_critic, + self.buf, + self.logger, + deterministic=False, + use_rand_action=use_rand_action, + ep_steps=self.update_every, + ) + + # Update handling + if steps >= self.update_after: + for _ in range(self.update_every): + batch = self.buf.sample_batch() + self.update(data=batch) + + # End of epoch handling + if steps % self.steps_per_epoch == 0 and steps: + epoch = steps // self.steps_per_epoch + if self.cfgs.exploration_noise_anneal: + self.actor_critic.anneal_exploration(frac=epoch / self.epochs) + # if self.cfgs.use_cost_critic: + # if self.use_cost_decay: + # self.cost_limit_decay(epoch) + + # Save model to disk + if (epoch + 1) % self.cfgs.save_freq == 0: + self.logger.torch_save(itr=epoch) + + # Test the performance of the deterministic version of the agent. + self.test_agent() + # Log info about epoch + self.log(epoch, steps) + return self.actor_critic + + def update(self, data): + """update""" + # First run one gradient descent step for Q. 
+ self.update_value_net(data) + if self.cfgs.use_cost: + self.update_cost_net(data) + for param in self.actor_critic.cost_critic.parameters(): + param.requires_grad = False + + # Freeze Q-network so you don't waste computational effort + # computing gradients for it during the policy learning step. + for param in self.actor_critic.critic.parameters(): + param.requires_grad = False + + # Next run one gradient descent step for pi. + self.update_policy_net(data) + + # Unfreeze Q-network so you can optimize it at next DDPG step. + for param in self.actor_critic.critic.parameters(): + param.requires_grad = True + + if self.cfgs.use_cost: + for param in self.actor_critic.cost_critic.parameters(): + param.requires_grad = True + + # Finally, update target networks by polyak averaging. + self.polyak_update_target() + + def polyak_update_target(self): + """polyak update target network""" + with torch.no_grad(): + for param, param_targ in zip(self.actor_critic.parameters(), self.ac_targ.parameters()): + # Notes: We use an in-place operations "mul_", "add_" to update target + # params, as opposed to "mul" and "add", which would make new tensors. + param_targ.data.mul_(self.cfgs.polyak) + param_targ.data.add_((1 - self.cfgs.polyak) * param.data) + + def update_policy_net(self, data) -> None: + """update policy network""" + # Train policy with one steps of gradient descent + self.actor_optimizer.zero_grad() + loss_pi, _ = self.compute_loss_pi(data) + loss_pi.backward() + self.actor_optimizer.step() + self.logger.store(**{'Loss/Pi': loss_pi.item()}) + + def update_value_net(self, data: dict) -> None: + """update value network""" + # Train value critic with one steps of gradient descent + self.critic_optimizer.zero_grad() + loss_q, q_info = self.compute_loss_v(data) + loss_q.backward() + self.critic_optimizer.step() + self.logger.store(**{'Loss/Value': loss_q.item(), 'Q1Vals': q_info['Q1Vals']}) + + def update_cost_net(self, data): + """update cost network""" + # Train cost critic with one steps of gradient descent + self.cost_critic_optimizer.zero_grad() + loss_qc, qc_info = self.compute_loss_c(data) + loss_qc.backward() + self.cost_critic_optimizer.step() + self.logger.store(**{'Loss/Cost': loss_qc.item(), 'QCosts': qc_info['QCosts']}) + + def test_agent(self): + """test agent""" + for _ in range(self.num_test_episodes): + # self.env.set_rollout_cfgs(deterministic=True, rand_a=False) + self.env.roll_out( + self.actor_critic, + self.buf, + self.logger, + deterministic=True, + use_rand_action=False, + ep_steps=self.max_ep_len, + ) + + def log(self, epoch, total_steps): + """Log info about epoch""" + fps = self.cfgs.steps_per_epoch / (time.time() - self.epoch_time) + # Step the actor learning rate scheduler if provided + if self.scheduler and self.cfgs.linear_lr_decay: + current_lr = self.scheduler.get_last_lr()[0] + self.scheduler.step() + else: + current_lr = self.cfgs.actor_lr + + self.logger.log_tabular('Epoch', epoch) + self.logger.log_tabular('Metrics/EpRet') + self.logger.log_tabular('Metrics/EpCosts') + self.logger.log_tabular('Metrics/EpLen') + self.logger.log_tabular('Test/EpRet') + self.logger.log_tabular('Test/EpCosts') + self.logger.log_tabular('Test/EpLen') + self.logger.log_tabular('Values/V', min_and_max=True) + self.logger.log_tabular('Q1Vals') + if self.cfgs.use_cost: + self.logger.log_tabular('Values/C', min_and_max=True) + self.logger.log_tabular('QCosts') + self.logger.log_tabular('Loss/Pi', std=False) + self.logger.log_tabular('Loss/Value') + if self.cfgs.use_cost: + 
self.logger.log_tabular('Loss/Cost') + self.logger.log_tabular('Misc/Seed', self.cfgs.seed) + self.logger.log_tabular('LR', current_lr) + if self.cfgs.scale_rewards: + reward_scale_mean = self.actor_critic.ret_oms.mean.item() + reward_scale_stddev = self.actor_critic.ret_oms.std.item() + self.logger.log_tabular('Misc/RewScaleMean', reward_scale_mean) + self.logger.log_tabular('Misc/RewScaleStddev', reward_scale_stddev) + if self.cfgs.exploration_noise_anneal: + noise_std = np.exp(self.actor_critic.pi.log_std[0].item()) + self.logger.log_tabular('Misc/ExplorationNoiseStd', noise_std) + self.algorithm_specific_logs() + self.logger.log_tabular('TotalEnvSteps', total_steps) + self.logger.log_tabular('Time', int(time.time() - self.start_time)) + self.logger.log_tabular('FPS', int(fps)) + + self.logger.dump_tabular() diff --git a/omnisafe/algorithms/on_policy/cpo.py b/omnisafe/algorithms/on_policy/cpo.py index 3cc668af9..207bd0abc 100644 --- a/omnisafe/algorithms/on_policy/cpo.py +++ b/omnisafe/algorithms/on_policy/cpo.py @@ -41,14 +41,16 @@ class CPO(TRPO): def __init__( self, - env, + env_id, cfgs, algo='CPO', + wrapper_type: str = 'OnPolicyEnvWrapper', ): super().__init__( - env=env, + env_id=env_id, cfgs=cfgs, algo=algo, + wrapper_type=wrapper_type, ) self.cost_limit = cfgs.cost_limit self.loss_pi_cost_before = 0.0 @@ -239,7 +241,7 @@ def update_policy_net( # point in trust region is feasible and safety boundary doesn't intersect # ==> entire trust region is feasible optim_case = 3 - elif cost < 0 and B >= 0: + elif cost < 0 and B >= 0: # pylint: disable=chained-comparison # x = 0 is feasible and safety boundary intersects # ==> most of trust region is feasible optim_case = 2 diff --git a/omnisafe/algorithms/on_policy/cppo_pid.py b/omnisafe/algorithms/on_policy/cppo_pid.py index 5db95c74c..85581a066 100644 --- a/omnisafe/algorithms/on_policy/cppo_pid.py +++ b/omnisafe/algorithms/on_policy/cppo_pid.py @@ -32,13 +32,20 @@ class CPPOPid(PolicyGradient, PIDLagrangian): """ - def __init__(self, env, cfgs, algo: str = 'CPPO-PID'): + def __init__( + self, + env_id, + cfgs, + algo: str = 'CPPO-PID', + wrapper_type: str = 'OnPolicyEnvWrapper', + ): PolicyGradient.__init__( self, - env=env, + env_id=env_id, cfgs=cfgs, algo=algo, + wrapper_type=wrapper_type, ) PIDLagrangian.__init__(self, **self.cfgs.PID_cfgs._asdict()) diff --git a/omnisafe/algorithms/on_policy/cup.py b/omnisafe/algorithms/on_policy/cup.py index cdaeb910a..721382d26 100644 --- a/omnisafe/algorithms/on_policy/cup.py +++ b/omnisafe/algorithms/on_policy/cup.py @@ -35,17 +35,19 @@ class CUP(PolicyGradient, Lagrange): def __init__( self, - env, + env_id, cfgs, algo='CUP', + wrapper_type: str = 'OnPolicyEnvWrapper', ): r"""The :meth:`init` function.""" PolicyGradient.__init__( self, - env=env, + env_id=env_id, cfgs=cfgs, algo=algo, + wrapper_type=wrapper_type, ) Lagrange.__init__( @@ -54,7 +56,6 @@ def __init__( lagrangian_multiplier_init=self.cfgs.lagrange_cfgs.lagrangian_multiplier_init, lambda_lr=self.cfgs.lagrange_cfgs.lambda_lr, lambda_optimizer=self.cfgs.lagrange_cfgs.lambda_optimizer, - lagrangian_upper_bound=self.cfgs.lagrange_cfgs.lagrangian_upper_bound, ) self.lam = self.cfgs.lam self.eta = self.cfgs.eta diff --git a/omnisafe/algorithms/on_policy/focops.py b/omnisafe/algorithms/on_policy/focops.py index b4f60d5f6..1b80e1962 100644 --- a/omnisafe/algorithms/on_policy/focops.py +++ b/omnisafe/algorithms/on_policy/focops.py @@ -13,10 +13,8 @@ # limitations under the License. 
# ============================================================================== """Implementation of the FOCOPS algorithm.""" -import time import torch -from torch.distributions.normal import Normal from omnisafe.algorithms import registry from omnisafe.algorithms.on_policy.policy_gradient import PolicyGradient @@ -37,17 +35,19 @@ class FOCOPS(PolicyGradient, Lagrange): def __init__( self, - env, + env_id, cfgs, algo='FOCOPS', + wrapper_type: str = 'OnPolicyEnvWrapper', ): r"""The :meth:`init` function.""" PolicyGradient.__init__( self, - env=env, + env_id=env_id, cfgs=cfgs, algo=algo, + wrapper_type=wrapper_type, ) Lagrange.__init__( @@ -122,7 +122,6 @@ def slice_data(self, data) -> dict: 'adv': adv[i * batch_size : (i + 1) * batch_size], 'discounted_ret': discounted_ret[i * batch_size : (i + 1) * batch_size], 'cost_adv': cost_adv[i * batch_size : (i + 1) * batch_size], - 'target_v': target_v[i * batch_size : (i + 1) * batch_size], } ) diff --git a/omnisafe/algorithms/on_policy/natural_pg.py b/omnisafe/algorithms/on_policy/natural_pg.py index 24c740223..4b86d0ee2 100644 --- a/omnisafe/algorithms/on_policy/natural_pg.py +++ b/omnisafe/algorithms/on_policy/natural_pg.py @@ -40,14 +40,16 @@ class NaturalPG(PolicyGradient): def __init__( self, - env, + env_id, cfgs, algo: str = 'NaturalPolicyGradient', + wrapper_type: str = 'OnPolicyEnvWrapper', ): super().__init__( - env=env, + env_id=env_id, cfgs=cfgs, algo=algo, + wrapper_type=wrapper_type, ) self.cg_damping = cfgs.cg_damping self.cg_iters = cfgs.cg_iters diff --git a/omnisafe/algorithms/on_policy/npg_lag.py b/omnisafe/algorithms/on_policy/npg_lag.py index c5852d302..aa10d0b8a 100644 --- a/omnisafe/algorithms/on_policy/npg_lag.py +++ b/omnisafe/algorithms/on_policy/npg_lag.py @@ -29,14 +29,21 @@ class NPGLag(NaturalPG, Lagrange): """ - def __init__(self, env, cfgs, algo: str = 'NPG-LAG'): + def __init__( + self, + env_id, + cfgs, + algo: str = 'NPG-Lag', + wrapper_type: str = 'OnPolicyEnvWrapper', + ): """initialize""" NaturalPG.__init__( self, - env=env, + env_id=env_id, cfgs=cfgs, algo=algo, + wrapper_type=wrapper_type, ) Lagrange.__init__( self, diff --git a/omnisafe/algorithms/on_policy/pcpo.py b/omnisafe/algorithms/on_policy/pcpo.py index b1f6191d3..edef0fdf4 100644 --- a/omnisafe/algorithms/on_policy/pcpo.py +++ b/omnisafe/algorithms/on_policy/pcpo.py @@ -40,14 +40,16 @@ class PCPO(TRPO): def __init__( self, - env, + env_id, cfgs, algo='PCPO', + wrapper_type: str = 'OnPolicyEnvWrapper', ): super().__init__( - env=env, + env_id=env_id, cfgs=cfgs, algo=algo, + wrapper_type=wrapper_type, ) self.cost_limit = self.cfgs.cost_limit diff --git a/omnisafe/algorithms/on_policy/pdo.py b/omnisafe/algorithms/on_policy/pdo.py index 66d9e1401..633dbebe7 100644 --- a/omnisafe/algorithms/on_policy/pdo.py +++ b/omnisafe/algorithms/on_policy/pdo.py @@ -29,13 +29,20 @@ class PDO(PolicyGradient, Lagrange): """ - def __init__(self, env, cfgs, algo='PDO'): + def __init__( + self, + env_id, + cfgs, + algo='PDO', + wrapper_type: str = 'OnPolicyEnvWrapper', + ): """initialization""" PolicyGradient.__init__( self, - env=env, + env_id=env_id, cfgs=cfgs, algo=algo, + wrapper_type=wrapper_type, ) Lagrange.__init__( self, diff --git a/omnisafe/algorithms/on_policy/policy_gradient.py b/omnisafe/algorithms/on_policy/policy_gradient.py index dc1d7d259..0c84eb44b 100644 --- a/omnisafe/algorithms/on_policy/policy_gradient.py +++ b/omnisafe/algorithms/on_policy/policy_gradient.py @@ -26,11 +26,11 @@ from omnisafe.models.constraint_actor_critic import 
ConstraintActorCritic from omnisafe.utils import core, distributed_utils from omnisafe.utils.tools import get_flat_params_from +from omnisafe.wrappers import wrapper_registry -# pylint: disable-next=too-many-instance-attributes @registry.register -class PolicyGradient: +class PolicyGradient: # pylint: disable=too-many-instance-attributes """The Policy Gradient algorithm. References: @@ -43,9 +43,10 @@ class PolicyGradient: # pylint: disable-next=too-many-locals def __init__( self, - env, + env_id, cfgs=None, algo: str = 'PolicyGradient', + wrapper_type: str = 'OnPolicyEnvWrapper', ) -> None: r"""Initialize the algorithm. @@ -56,7 +57,7 @@ def __init__( cfgs: (default: :const:`None`) This is a dictionary of the algorithm hyper-parameters. """ - self.env = env + self.env = wrapper_registry.get(wrapper_type)(env_id) self.algo = algo self.cfgs = deepcopy(cfgs) @@ -73,7 +74,7 @@ def __init__( self.logger = Logger(exp_name=cfgs.exp_name, data_dir=cfgs.data_dir, seed=cfgs.seed) self.logger.save_config(cfgs._asdict()) # Set seed - seed = cfgs.seed + 10000 * distributed_utils.proc_id() + seed = int(cfgs.seed) + 10000 * distributed_utils.proc_id() torch.manual_seed(seed) np.random.seed(seed) self.env.env.reset(seed=seed) diff --git a/omnisafe/algorithms/on_policy/ppo.py b/omnisafe/algorithms/on_policy/ppo.py index 0b025c2c8..e24335344 100644 --- a/omnisafe/algorithms/on_policy/ppo.py +++ b/omnisafe/algorithms/on_policy/ppo.py @@ -30,19 +30,22 @@ class PPO(PolicyGradient): Paper URL: https://arxiv.org/pdf/1707.06347.pdf """ + # pylint: disable-next=too-many-arguments def __init__( self, - env, + env_id, cfgs, algo='ppo', clip=0.2, + wrapper_type: str = 'OnPolicyEnvWrapper', ): """Initialize PPO.""" self.clip = clip super().__init__( - env=env, + env_id=env_id, cfgs=cfgs, algo=algo, + wrapper_type=wrapper_type, ) def compute_loss_pi(self, data: dict): diff --git a/omnisafe/algorithms/on_policy/ppo_lag.py b/omnisafe/algorithms/on_policy/ppo_lag.py index 1af0124a7..5ce205555 100644 --- a/omnisafe/algorithms/on_policy/ppo_lag.py +++ b/omnisafe/algorithms/on_policy/ppo_lag.py @@ -32,20 +32,23 @@ class PPOLag(PolicyGradient, Lagrange): """ + # pylint: disable-next=too-many-arguments def __init__( self, - env, + env_id, cfgs, algo='PPO-Lag', clip=0.2, + wrapper_type: str = 'OnPolicyEnvWrapper', ): """Initialize PPO-Lag algorithm.""" self.clip = clip PolicyGradient.__init__( self, - env=env, + env_id=env_id, cfgs=cfgs, algo=algo, + wrapper_type=wrapper_type, ) Lagrange.__init__( self, diff --git a/omnisafe/algorithms/on_policy/trpo.py b/omnisafe/algorithms/on_policy/trpo.py index 412a7978a..630a7dcd8 100644 --- a/omnisafe/algorithms/on_policy/trpo.py +++ b/omnisafe/algorithms/on_policy/trpo.py @@ -39,14 +39,16 @@ class TRPO(NaturalPG): def __init__( self, - env, + env_id, cfgs, algo='TRPO', + wrapper_type: str = 'OnPolicyEnvWrapper', ): super().__init__( - env=env, + env_id=env_id, cfgs=cfgs, algo=algo, + wrapper_type=wrapper_type, ) # pylint: disable-next=too-many-arguments,too-many-locals,arguments-differ diff --git a/omnisafe/algorithms/on_policy/trpo_lag.py b/omnisafe/algorithms/on_policy/trpo_lag.py index 8241394b4..e3282863f 100644 --- a/omnisafe/algorithms/on_policy/trpo_lag.py +++ b/omnisafe/algorithms/on_policy/trpo_lag.py @@ -34,16 +34,18 @@ class TRPOLag(TRPO, Lagrange): def __init__( self, - env, + env_id, cfgs, algo: str = 'TRPO-Lag', + wrapper_type: str = 'OnPolicyEnvWrapper', ): """initialize""" TRPO.__init__( self, - env=env, + env_id=env_id, cfgs=cfgs, algo=algo, + wrapper_type=wrapper_type, 
) Lagrange.__init__( self, diff --git a/omnisafe/common/base_buffer.py b/omnisafe/common/base_buffer.py index ca1d50dc2..a973a8ca8 100644 --- a/omnisafe/common/base_buffer.py +++ b/omnisafe/common/base_buffer.py @@ -27,7 +27,7 @@ class BaseBuffer: def __init__(self, obs_dim, act_dim, size, batch_size): """init""" self.obs_buf = np.zeros(combined_shape(size, obs_dim), dtype=np.float32) - self.obs2_buf = np.zeros(combined_shape(size, obs_dim), dtype=np.float32) + self.obs_next_buf = np.zeros(combined_shape(size, obs_dim), dtype=np.float32) self.act_buf = np.zeros(combined_shape(size, act_dim), dtype=np.float32) self.rew_buf = np.zeros(size, dtype=np.float32) self.cost_buf = np.zeros(size, dtype=np.float32) @@ -39,7 +39,7 @@ def __init__(self, obs_dim, act_dim, size, batch_size): def store(self, obs, act, rew, cost, next_obs, done): """store""" self.obs_buf[self.ptr] = obs - self.obs2_buf[self.ptr] = next_obs + self.obs_next_buf[self.ptr] = next_obs self.act_buf[self.ptr] = act self.rew_buf[self.ptr] = rew self.cost_buf[self.ptr] = cost @@ -52,7 +52,7 @@ def sample_batch(self): idxs = np.random.randint(0, self.size, size=self.batch_size) batch = dict( obs=self.obs_buf[idxs], - obs2=self.obs2_buf[idxs], + obs_next=self.obs_next_buf[idxs], act=self.act_buf[idxs], rew=self.rew_buf[idxs], cost=self.cost_buf[idxs], diff --git a/omnisafe/common/lagrange.py b/omnisafe/common/lagrange.py index 343aa86ad..e6f11f0ed 100644 --- a/omnisafe/common/lagrange.py +++ b/omnisafe/common/lagrange.py @@ -28,7 +28,7 @@ def __init__( lagrangian_multiplier_init: float, lambda_lr: float, lambda_optimizer: str, - lagrangian_upper_bound: None, + lagrangian_upper_bound=None, ): """init""" self.cost_limit = cost_limit diff --git a/omnisafe/configs/off-policy/DDPG.yaml b/omnisafe/configs/off-policy/DDPG.yaml index a06459e27..48c76ae4c 100644 --- a/omnisafe/configs/off-policy/DDPG.yaml +++ b/omnisafe/configs/off-policy/DDPG.yaml @@ -37,7 +37,7 @@ defaults: # Optional Configuration ## Whether to use cost critic use_cost: False - use_linear_lr_decay: False + linear_lr_decay: False exploration_noise_anneal: False reward_penalty: False use_max_grad_norm: False diff --git a/omnisafe/evaluator.py b/omnisafe/evaluator.py index 171f8b554..8e5f901e5 100644 --- a/omnisafe/evaluator.py +++ b/omnisafe/evaluator.py @@ -22,9 +22,9 @@ from gymnasium.spaces import Box, Discrete from gymnasium.utils.save_video import save_video -from omnisafe.algorithms.env_wrapper import EnvWrapper from omnisafe.models.actor import ActorBuilder from omnisafe.utils.online_mean_std import OnlineMeanStd +from omnisafe.wrappers.on_policy_wrapper import OnPolicyEnvWrapper as EnvWrapper class Evaluator: diff --git a/omnisafe/models/actor/cholesky_actor.py b/omnisafe/models/actor/cholesky_actor.py new file mode 100644 index 000000000..ee05933c7 --- /dev/null +++ b/omnisafe/models/actor/cholesky_actor.py @@ -0,0 +1,128 @@ +# Copyright 2022 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
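The Lagrange signature above now gives lagrangian_upper_bound a default of None, which is why the CUP constructor earlier in this patch can drop that keyword. A minimal sketch of how a multiplier update with an optional upper bound can look; this is an illustration under those assumptions, not the exact omnisafe implementation:

import torch

class TinyLagrange:
    """Scalar Lagrange multiplier updated by gradient ascent on the constraint violation."""

    def __init__(self, cost_limit, multiplier_init=0.001, lambda_lr=0.035, lagrangian_upper_bound=None):
        self.cost_limit = cost_limit
        self.lagrangian_upper_bound = lagrangian_upper_bound
        self.lagrangian_multiplier = torch.nn.Parameter(torch.as_tensor(float(multiplier_init)))
        self.lambda_optimizer = torch.optim.Adam([self.lagrangian_multiplier], lr=lambda_lr)

    def update_lagrange_multiplier(self, mean_ep_cost: float) -> None:
        # Minimizing -lambda * (Jc - d) increases lambda when costs exceed the limit.
        self.lambda_optimizer.zero_grad()
        loss = -self.lagrangian_multiplier * (mean_ep_cost - self.cost_limit)
        loss.backward()
        self.lambda_optimizer.step()
        with torch.no_grad():
            # Clamp to [0, upper_bound]; max=None means no upper bound is enforced.
            self.lagrangian_multiplier.clamp_(min=0.0, max=self.lagrangian_upper_bound)

lagrange = TinyLagrange(cost_limit=25.0)           # no upper bound, as in the updated CUP call
lagrange.update_lagrange_multiplier(mean_ep_cost=30.0)
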
+# ============================================================================== +"""Implementation of CholeskyActor.""" + +import torch +import torch.nn.functional as F +from torch import nn +from torch.distributions import MultivariateNormal + +from omnisafe.utils.model_utils import build_mlp_network, initialize_layer + + +# pylint: disable-next=too-many-instance-attributes +class MLPCholeskyActor(nn.Module): + """Implementation of CholeskyActor.""" + + COV_MIN = 1e-4 # last exp is 1e-2 + MEAN_CLAMP_MIN = -5 + MEAN_CLAMP_MAX = 5 + COV_CLAMP_MIN = -5 + COV_CLAMP_MAX = 20 + + # pylint: disable-next=too-many-arguments + def __init__( + self, + obs_dim, + act_dim, + act_limit, + hidden_sizes, + activation, + cov_min, + mu_clamp_min, + mu_clamp_max, + cov_clamp_min, + cov_clamp_max, + weight_initialization_mode, + ): + """Initialize.""" + super().__init__() + pi_sizes = [obs_dim] + hidden_sizes + self.act_limit = act_limit + self.act_low = torch.nn.Parameter( + torch.as_tensor(-act_limit), requires_grad=False + ) # (1, act_dim) + self.act_high = torch.nn.Parameter( + torch.as_tensor(act_limit), requires_grad=False + ) # (1, act_dim) + self.act_dim = act_dim + self.obs_dim = obs_dim + self.cov_min = cov_min + self.mu_clamp_min = mu_clamp_min + self.mu_clamp_max = mu_clamp_max + self.cov_clamp_min = cov_clamp_min + self.cov_clamp_max = cov_clamp_max + + self.net = build_mlp_network(pi_sizes, activation, activation) + self.mu_layer = nn.Linear(hidden_sizes[-1], act_dim) + self.cholesky_layer = nn.Linear(hidden_sizes[-1], (self.act_dim * (self.act_dim + 1)) // 2) + initialize_layer(weight_initialization_mode, self.mu_layer) + # initialize_layer(weight_initialization_mode,self.cholesky_layer) + nn.init.constant_(self.mu_layer.bias, 0.0) + nn.init.constant_(self.cholesky_layer.bias, 0.0) + + def predict( + self, + obs, + deterministic=False, + ): # pylint: disable=invalid-name + """ + forwards input through the network + :param obs: (B, obs_dim) + :return: mu vector (B, act_dim) and cholesky factorization of covariance matrix (B, act_dim, act_dim) + """ + if len(obs.shape) == 1: + obs = torch.unsqueeze(obs, dim=0) + B = obs.size(0) + + net_out = self.net(obs) + + clamped_mu = torch.clamp(self.mu_layer(net_out), self.mu_clamp_min, self.mu_clamp_max) + mean = torch.sigmoid(clamped_mu) # (B, act_dim) + + mean = self.act_low + (self.act_high - self.act_low) * mean + # Compute logprob from Gaussian, and then apply correction for Tanh squashing. + # NOTE: The correction formula is a little bit magic. To get an understanding + # of where it comes from, check out the original SAC paper (arXiv 1801.01290) + # and look in appendix C. This is a more numerically-stable equivalent to Eq 21. + # Try deriving it yourself as a (very difficult) exercise. :) + cholesky_vector = torch.clamp( + self.cholesky_layer(net_out), self.cov_clamp_min, self.cov_clamp_max + ) # (B, (act_dim*(act_dim+1))//2) + cholesky_diag_index = torch.arange(self.act_dim, dtype=torch.long) + 1 + # cholesky_diag_index = (cholesky_diag_index * (cholesky_diag_index + 1)) // 2 - 1 + cholesky_diag_index = ( + torch.div(cholesky_diag_index * (cholesky_diag_index + 1), 2, rounding_mode='floor') - 1 + ) + # add a small value to prevent the diagonal from being 0. 
+ cholesky_vector[:, cholesky_diag_index] = ( + F.softplus(cholesky_vector[:, cholesky_diag_index]) + self.COV_MIN + ) + tril_indices = torch.tril_indices(row=self.act_dim, col=self.act_dim, offset=0) + cholesky = torch.zeros(size=(B, self.act_dim, self.act_dim), dtype=torch.float32) + cholesky[:, tril_indices[0], tril_indices[1]] = cholesky_vector + pi_distribution = MultivariateNormal(mean, scale_tril=cholesky) + + if deterministic: + pi_action = mean + else: + pi_action = pi_distribution.rsample() + + pi_action = torch.tanh(pi_action) + pi_action = self.act_limit * pi_action + return pi_action.squeeze(), cholesky + + def forward(self, obs, deterministic=False): + """Forward.""" diff --git a/omnisafe/models/actor/gaussian_annealing_actor.py b/omnisafe/models/actor/gaussian_annealing_actor.py index 258bcf2c7..3bdf51014 100644 --- a/omnisafe/models/actor/gaussian_annealing_actor.py +++ b/omnisafe/models/actor/gaussian_annealing_actor.py @@ -36,13 +36,13 @@ def __init__( activation, weight_initialization_mode, shared=None, - satrt_std: float = 0.5, + start_std: float = 0.5, end_std: float = 0.01, ): super().__init__( obs_dim, act_dim, hidden_sizes, activation, weight_initialization_mode, shared ) - self.start_std = satrt_std + self.start_std = start_std self.end_std = end_std self._std = self.start_std * torch.ones(self.act_dim, dtype=torch.float32) @@ -76,9 +76,6 @@ def predict(self, obs, deterministic=False, need_log_prob=False): else: out = dist.sample() - action = torch.clamp(out, -1, 1) - action = self.act_min + (action + 1) * 0.5 * (self.act_max - self.act_min) - if need_log_prob: log_prob = dist.log_prob(out).sum(axis=-1) return out, log_prob diff --git a/omnisafe/models/actor/mlp_actor.py b/omnisafe/models/actor/mlp_actor.py new file mode 100644 index 000000000..ccd8a3bf2 --- /dev/null +++ b/omnisafe/models/actor/mlp_actor.py @@ -0,0 +1,76 @@ +# Copyright 2022 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
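MLPCholeskyActor above packs a flat network output into the lower triangle of the covariance factor, softplus-shifts the diagonal so it stays positive, and hands the result to MultivariateNormal(scale_tril=...). A standalone sketch of that packing step for a hypothetical batch (sizes and the random input are placeholders):

import torch
import torch.nn.functional as F
from torch.distributions import MultivariateNormal

act_dim, batch = 3, 2
# Flat output with act_dim * (act_dim + 1) // 2 entries per sample, as produced by cholesky_layer.
cholesky_vector = torch.randn(batch, act_dim * (act_dim + 1) // 2)

# Positions of the diagonal entries inside the flattened, row-major lower triangle.
diag_index = torch.arange(act_dim, dtype=torch.long) + 1
diag_index = torch.div(diag_index * (diag_index + 1), 2, rounding_mode='floor') - 1
# Keep the diagonal strictly positive so the matrix is a valid Cholesky factor.
cholesky_vector[:, diag_index] = F.softplus(cholesky_vector[:, diag_index]) + 1e-4

tril = torch.tril_indices(row=act_dim, col=act_dim, offset=0)
scale_tril = torch.zeros(batch, act_dim, act_dim)
scale_tril[:, tril[0], tril[1]] = cholesky_vector

dist = MultivariateNormal(loc=torch.zeros(batch, act_dim), scale_tril=scale_tril)
sample = dist.rsample()  # (batch, act_dim), differentiable through the reparameterization trick
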
+# ============================================================================== +"""Implementation of MLPActor.""" + +import numpy as np +import torch +from torch import nn +from torch.distributions.normal import Normal + +from omnisafe.models.base import Actor +from omnisafe.utils.model_utils import Activation, InitFunction, build_mlp_network + + +class MLPActor(Actor): + """A abstract class for actor.""" + + # pylint: disable-next=too-many-arguments + def __init__( + self, + obs_dim: int, + act_dim: int, + act_noise, + act_limit, + hidden_sizes: list, + activation: Activation, + weight_initialization_mode: InitFunction = 'xavier_uniform', + shared: nn.Module = None, + ): + super().__init__(obs_dim, act_dim, hidden_sizes, activation) + self.act_limit = act_limit + self.act_noise = act_noise + + if shared is not None: # use shared layers + action_head = build_mlp_network( + sizes=[hidden_sizes[-1], act_dim], + activation=activation, + output_activation='tanh', + weight_initialization_mode=weight_initialization_mode, + ) + self.net = nn.Sequential(shared, action_head) + else: + self.net = build_mlp_network( + [obs_dim] + list(hidden_sizes) + [act_dim], + activation=activation, + output_activation='tanh', + weight_initialization_mode=weight_initialization_mode, + ) + + def _distribution(self, obs): + mean = self.net(obs) + return Normal(mean, self._std) + + def forward(self, obs, act=None): + """forward""" + # Return output from network scaled to action space limits. + return self.act_limit * self.net(obs) + + def predict(self, obs, deterministic=False, need_log_prob=False): + if deterministic: + action = self.act_limit * self.net(obs) + else: + action = self.act_limit * self.net(obs) + action += self.act_noise * np.random.randn(self.act_dim) + return action.to(torch.float32), torch.tensor(1, dtype=torch.float32) diff --git a/omnisafe/models/actor_q_critic.py b/omnisafe/models/actor_q_critic.py new file mode 100644 index 000000000..66dc66a83 --- /dev/null +++ b/omnisafe/models/actor_q_critic.py @@ -0,0 +1,137 @@ +# Copyright 2022 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
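MLPActor above is the deterministic policy head that DDPG drives: a tanh-bounded network output scaled by the action limit, with Gaussian noise added only during exploration (in the patch, the final clipping to the action bounds happens in ActorQCritic.step). A small sketch of that behaviour with hypothetical sizes, clipping inlined to keep it self-contained:

import torch
from torch import nn

obs_dim, act_dim, act_limit, act_noise = 4, 2, 1.0, 0.1
net = nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, act_dim), nn.Tanh())

def predict(obs: torch.Tensor, deterministic: bool = False) -> torch.Tensor:
    """Scale the tanh output to the action limits; add exploration noise unless deterministic."""
    action = act_limit * net(obs)
    if not deterministic:
        action = action + act_noise * torch.randn(act_dim)
    return torch.clamp(action, -act_limit, act_limit)

obs = torch.zeros(obs_dim)
noisy_action = predict(obs)                        # used when interacting with the environment
greedy_action = predict(obs, deterministic=True)   # used when evaluating the agent
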
+# ============================================================================== +"""Implementation of ActorQCritic.""" + +import numpy as np +import torch +import torch.nn as nn +from gymnasium.spaces import Box + +from omnisafe.models.actor.mlp_actor import MLPActor +from omnisafe.models.critic.q_critic import QCritic +from omnisafe.utils.model_utils import build_mlp_network +from omnisafe.utils.online_mean_std import OnlineMeanStd + + +# pylint: disable-next=too-many-instance-attributes +class ActorQCritic(nn.Module): + """Class for ActorQCritic.""" + + # pylint: disable-next=too-many-arguments + def __init__( + self, + observation_space, + action_space, + standardized_obs: bool, + shared_weights: bool, + model_cfgs, + weight_initialization_mode='kaiming_uniform', + ) -> None: + """Initialize ActorQCritic""" + super().__init__() + + self.obs_shape = observation_space.shape + self.obs_oms = OnlineMeanStd(shape=self.obs_shape) if standardized_obs else None + self.act_dim = action_space.shape[0] + self.act_limit = action_space.high[0] + self.ac_kwargs = model_cfgs.ac_kwargs + # build policy and value functions + if isinstance(action_space, Box): + if model_cfgs.pi_type == 'dire': + actor_fn = MLPActor + act_dim = action_space.shape[0] + else: + raise ValueError + + self.obs_dim = observation_space.shape[0] + + # Use for shared weights + layer_units = [self.obs_dim] + model_cfgs.ac_kwargs.pi.hidden_sizes + + activation = model_cfgs.ac_kwargs.pi.activation + if shared_weights: + shared = build_mlp_network( + layer_units, + activation=activation, + weight_initialization_mode=weight_initialization_mode, + output_activation=activation, + ) + else: + shared = None + + self.actor = actor_fn( + obs_dim=self.obs_dim, + act_dim=act_dim, + act_noise=model_cfgs.ac_kwargs.pi.act_noise, + act_limit=self.act_limit, + hidden_sizes=model_cfgs.ac_kwargs.pi.hidden_sizes, + activation=model_cfgs.ac_kwargs.pi.activation, + weight_initialization_mode=weight_initialization_mode, + shared=shared, + ) + self.critic = QCritic( + self.obs_dim, + act_dim, + hidden_sizes=model_cfgs.ac_kwargs.val.hidden_sizes, + activation=model_cfgs.ac_kwargs.val.activation, + weight_initialization_mode=weight_initialization_mode, + shared=shared, + ) + self.critic_ = QCritic( + self.obs_dim, + act_dim, + hidden_sizes=model_cfgs.ac_kwargs.val.hidden_sizes, + activation=model_cfgs.ac_kwargs.val.activation, + weight_initialization_mode=weight_initialization_mode, + shared=shared, + ) + + def step(self, obs, deterministic=False): + """ + If training, this includes exploration noise! + Expects that obs is not pre-processed. + Args: + obs, , description + Returns: + action, value, log_prob(action) + Note: + Training mode can be activated with ac.train() + Evaluation mode is activated by ac.eval() + """ + with torch.no_grad(): + if self.obs_oms: + # Note: Update RMS in Algorithm.running_statistics() method + # self.obs_oms.update(obs) if self.training else None + obs = self.obs_oms(obs) + if isinstance(self.pi, MLPActor): + action = self.pi.predict(obs, determinstic=deterministic) + else: + action, logp_a = self.pi.predict(obs, determinstic=deterministic) + value = self.v(obs, action) + action = np.clip(action.numpy(), -self.act_limit, self.act_limit) + + return action, value.numpy(), logp_a.numpy() + + def anneal_exploration(self, frac): + """update internals of actors + 1) Updates exploration parameters for Gaussian actors update log_std + frac: progress of epochs, i.e. current epoch / total epochs + e.g. 
10 / 100 = 0.1 + """ + if hasattr(self.pi, 'set_log_std'): + self.pi.set_log_std(1 - frac) + + def forward(self, obs, act): + """Compute the value of a given state-action pair.""" diff --git a/omnisafe/models/constraint_actor_critic.py b/omnisafe/models/constraint_actor_critic.py index e2fec8a32..096889f16 100644 --- a/omnisafe/models/constraint_actor_critic.py +++ b/omnisafe/models/constraint_actor_critic.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== """Implementation of ConstraintActorCritic.""" + import torch from omnisafe.models.actor_critic import ActorCritic @@ -29,9 +30,6 @@ def __init__( action_space, standardized_obs: bool, scale_rewards: bool, - # shared_weights: bool, - # ac_kwargs: dict, - # weight_initialization_mode='kaiming_uniform', model_cfgs, ) -> None: ActorCritic.__init__( diff --git a/omnisafe/models/constraint_actor_q_critic.py b/omnisafe/models/constraint_actor_q_critic.py new file mode 100644 index 000000000..e6116866e --- /dev/null +++ b/omnisafe/models/constraint_actor_q_critic.py @@ -0,0 +1,75 @@ +# Copyright 2022 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implementation of ConstraintActorQCritic.""" + +import numpy as np +import torch + +from omnisafe.models.actor_q_critic import ActorQCritic +from omnisafe.models.critic.q_critic import QCritic + + +class ConstraintActorQCritic(ActorQCritic): + """ConstraintActorQCritic is a wrapper around ActorQCritic that adds a cost critic to the model.""" + + # pylint: disable-next=too-many-arguments + def __init__( + self, + observation_space, + action_space, + standardized_obs: bool, + model_cfgs, + ): + """Initialize ConstraintActorQCritic.""" + + super().__init__( + observation_space=observation_space, + action_space=action_space, + standardized_obs=standardized_obs, + shared_weights=model_cfgs.shared_weights, + model_cfgs=model_cfgs, + ) + self.cost_critic = QCritic( + obs_dim=self.obs_dim, + act_dim=self.act_dim, + hidden_sizes=self.ac_kwargs.val.hidden_sizes, + activation=self.ac_kwargs.val.activation, + weight_initialization_mode=model_cfgs.weight_initialization_mode, + shared=model_cfgs.shared_weights, + ) + + def step(self, obs, deterministic=False): + """ + If training, this includes exploration noise! + Expects that obs is not pre-processed. 
+ Args: + obs, description + Returns: + action, value, log_prob(action) + Note: + Training mode can be activated with ac.train() + Evaluation mode is activated by ac.eval() + """ + with torch.no_grad(): + if self.obs_oms: + # Note: Update RMS in Algorithm.running_statistics() method + # self.obs_oms.update(obs) if self.training else None + obs = self.obs_oms(obs) + action, logp_a = self.actor.predict(obs, deterministic=deterministic) + value = self.critic(obs, action) + cost_value = self.cost_critic(obs, action) + action = np.clip(action.numpy(), -self.act_limit, self.act_limit) + + return action, value.numpy(), cost_value.numpy(), logp_a.numpy() diff --git a/omnisafe/models/critic/q_critic.py b/omnisafe/models/critic/q_critic.py index fed091147..2b4db763e 100644 --- a/omnisafe/models/critic/q_critic.py +++ b/omnisafe/models/critic/q_critic.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== """Implementation of QCritic.""" +from typing import Optional import torch import torch.nn as nn @@ -59,7 +60,7 @@ def __init__( def forward( self, obs: torch.Tensor, - act: torch.Tensor, + act: Optional[torch.Tensor] = None, ): """Forward.""" obs = self.obs_encoder(obs) diff --git a/omnisafe/models/critic/v_critic.py b/omnisafe/models/critic/v_critic.py index 29072d9d9..82cf0795c 100644 --- a/omnisafe/models/critic/v_critic.py +++ b/omnisafe/models/critic/v_critic.py @@ -61,6 +61,7 @@ def __init__( def forward( self, obs: torch.Tensor, + act: torch.Tensor = None, ) -> torch.Tensor: """Forward.""" return torch.squeeze(self.net(obs), -1) diff --git a/omnisafe/utils/vtrace.py b/omnisafe/utils/vtrace.py index d88874140..b9ed68cb2 100644 --- a/omnisafe/utils/vtrace.py +++ b/omnisafe/utils/vtrace.py @@ -48,14 +48,12 @@ def calculate_v_trace( assert c_bar <= rho_bar sequence_length = policy_action_probs.shape[0] - # print('sequence_length:', sequence_length) + # pylint: disable-next=assignment-from-no-return rhos = np.divide(policy_action_probs, behavior_action_probs) - clip_rhos = np.minimum(rhos, rho_bar) - clip_cs = np.minimum(rhos, c_bar) - # values_plus_1 = np.append(values, bootstrap_value) + clip_rhos = np.minimum(rhos, rho_bar) # pylint: disable=assignment-from-no-return + clip_cs = np.minimum(rhos, c_bar) # pylint: disable=assignment-from-no-return v_s = np.copy(values[:-1]) # copy all values except bootstrap value - # v_s = np.zeros_like(values) last_v_s = values[-1] # bootstrap from last state # calculate v_s diff --git a/omnisafe/wrappers/__init__.py b/omnisafe/wrappers/__init__.py new file mode 100644 index 000000000..9eb8141a6 --- /dev/null +++ b/omnisafe/wrappers/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2022 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Environment wrappers.""" + +from omnisafe.wrappers.off_policy_wrapper import OffPolicyEnvWrapper +from omnisafe.wrappers.on_policy_wrapper import OnPolicyEnvWrapper diff --git a/omnisafe/algorithms/env_wrapper.py b/omnisafe/wrappers/env_wrapper.py similarity index 100% rename from omnisafe/algorithms/env_wrapper.py rename to omnisafe/wrappers/env_wrapper.py diff --git a/omnisafe/wrappers/off_policy_wrapper.py b/omnisafe/wrappers/off_policy_wrapper.py new file mode 100644 index 000000000..ee42c0f54 --- /dev/null +++ b/omnisafe/wrappers/off_policy_wrapper.py @@ -0,0 +1,151 @@ +# Copyright 2022 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""env_wrapper""" + +import safety_gymnasium +import torch + +from omnisafe.wrappers.wrapper_registry import WRAPPER_REGISTRY + + +# pylint: disable=too-many-instance-attributes +@WRAPPER_REGISTRY.register +class OffPolicyEnvWrapper: + """OffPolicyEnvWrapper""" + + def __init__( + self, + env_id, + use_cost, + max_ep_len, + render_mode=None, + ): + # check env_id is str + self.env = safety_gymnasium.make(env_id, render_mode=render_mode) + self.env_id = env_id + self.render_mode = render_mode + self.metadata = self.env.metadata + self.use_cost = use_cost + + if hasattr(self.env, '_max_episode_steps'): + self.max_ep_len = self.env._max_episode_steps + else: + self.max_ep_len = max_ep_len + self.observation_space = self.env.observation_space + self.action_space = self.env.action_space + self.seed = None + self.curr_o, _ = self.env.reset(seed=self.seed) + self.ep_ret = 0 + self.ep_cost = 0 + self.ep_len = 0 + # self.deterministic = False + self.local_steps_per_epoch = None + self.cost_gamma = None + self.use_cost = None + self.penalty_param = None + + def make(self): + """create environments""" + return self.env + + def reset(self, seed=None): + """reset environment""" + self.curr_o, info = self.env.reset(seed=seed) + return self.curr_o, info + + def render(self): + """render environment""" + return self.env.render() + + def set_seed(self, seed): + """set environment seed""" + self.seed = seed + + def set_rollout_cfgs(self, **kwargs): + """set rollout configs""" + for key, value in kwargs.items(): + setattr(self, key, value) + + def step(self, action): + """engine step""" + next_obs, reward, cost, terminated, truncated, info = self.env.step(action) + return next_obs, reward, cost, terminated, truncated, info + + # pylint: disable=too-many-arguments, too-many-locals + def roll_out( + self, + actor_critic, + buf, + logger, + deterministic, + use_rand_action, + ep_steps, + ): + """collect data and store to experience buffer.""" + for _ in range(ep_steps): + ep_ret = self.ep_ret + ep_len = self.ep_len + ep_cost = self.ep_cost + obs = self.curr_o + action, value, cost_value, _ = actor_critic.step( + torch.as_tensor(obs, dtype=torch.float32), 
deterministic=deterministic + ) + # Store values for statistic purpose + if self.use_cost: + logger.store(**{'Values/V': value, 'Values/C': cost_value}) + else: + logger.store(**{'Values/V': value}) + if use_rand_action: + action = self.env.action_space.sample() + # Step the env + # pylint: disable=unused-variable + obs_next, reward, cost, done, truncated, info = self.step(action) + ep_ret += reward + ep_cost += cost + ep_len += 1 + self.ep_len = ep_len + self.ep_ret = ep_ret + self.ep_cost = ep_cost + # Ignore the "done" signal if it comes from hitting the time + # horizon (that is, when it's an artificial terminal signal + # that isn't based on the agent's state) + self.curr_o = obs_next + if not deterministic: + done = False if ep_len >= self.max_ep_len else done + buf.store(obs, action, reward, cost, obs_next, done) + if done or ep_len >= self.max_ep_len: + logger.store( + **{ + 'Metrics/EpRet': ep_ret, + 'Metrics/EpLen': ep_len, + 'Metrics/EpCosts': ep_cost, + } + ) + self.curr_o, _ = self.env.reset(seed=self.seed) + self.ep_ret, self.ep_cost, self.ep_len = 0, 0, 0 + + else: + if done or ep_len >= self.max_ep_len: + logger.store( + **{ + 'Test/EpRet': ep_ret, + 'Test/EpLen': ep_len, + 'Test/EpCosts': ep_cost, + } + ) + self.curr_o, _ = self.env.reset(seed=self.seed) + self.ep_ret, self.ep_cost, self.ep_len = 0, 0, 0 + + return diff --git a/omnisafe/wrappers/on_policy_wrapper.py b/omnisafe/wrappers/on_policy_wrapper.py new file mode 100644 index 000000000..8284c8d2c --- /dev/null +++ b/omnisafe/wrappers/on_policy_wrapper.py @@ -0,0 +1,138 @@ +# Copyright 2022 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
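OffPolicyEnvWrapper.roll_out above owns the interaction loop: uniform random actions during the warm-up phase, noisy policy actions afterwards, transitions stored in the replay buffer, and episode statistics reset whenever an episode terminates or hits the horizon. A condensed sketch of that control flow; the function and argument names are illustrative, and the env is assumed to follow the safety-gymnasium step signature used in this patch:

import torch

def roll_out_sketch(env, actor_critic, buffer, max_ep_len, ep_steps, use_rand_action):
    """Condensed stand-in for OffPolicyEnvWrapper.roll_out."""
    obs, _ = env.reset()
    ep_ret, ep_cost, ep_len = 0.0, 0.0, 0
    for _ in range(ep_steps):
        if use_rand_action:
            action = env.action_space.sample()  # uniform exploration until start_steps elapse
        else:
            action, *_ = actor_critic.step(torch.as_tensor(obs, dtype=torch.float32))
        next_obs, reward, cost, done, truncated, _ = env.step(action)
        ep_ret, ep_cost, ep_len = ep_ret + reward, ep_cost + cost, ep_len + 1
        # Ignore the artificial "done" that only comes from hitting the time horizon.
        done = False if ep_len >= max_ep_len else done
        buffer.store(obs, action, reward, cost, next_obs, done)
        obs = next_obs
        if done or ep_len >= max_ep_len:
            obs, _ = env.reset()
            ep_ret, ep_cost, ep_len = 0.0, 0.0, 0
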
+# ==============================================================================
+"""Environment wrapper for on-policy algorithms."""
+
+import safety_gymnasium
+import torch
+
+from omnisafe.wrappers.wrapper_registry import WRAPPER_REGISTRY
+
+
+@WRAPPER_REGISTRY.register
+class OnPolicyEnvWrapper:  # pylint: disable=too-many-instance-attributes
+    """Environment wrapper for on-policy algorithms."""
+
+    def __init__(self, env_id, render_mode=None):
+        # env_id should be the id of a registered safety_gymnasium task
+        self.env = safety_gymnasium.make(env_id, render_mode=render_mode)
+        self.env_id = env_id
+        self.render_mode = render_mode
+        self.metadata = self.env.metadata
+
+        if hasattr(self.env, '_max_episode_steps'):
+            self.max_ep_len = self.env._max_episode_steps
+        else:
+            self.max_ep_len = 1000
+        self.observation_space = self.env.observation_space
+        self.action_space = self.env.action_space
+        self.seed = None
+        self.curr_o, _ = self.env.reset(seed=self.seed)
+        self.rand_a = True
+        self.ep_steps = 1000
+        self.ep_ret = 0
+        self.ep_costs = 0
+        self.ep_len = 0
+        self.deterministic = False
+        self.local_steps_per_epoch = None
+        self.cost_gamma = None
+        self.use_cost = None
+        self.penalty_param = None
+
+    def make(self):
+        """Return the wrapped environment instance."""
+        return self.env
+
+    def reset(self, seed=None):
+        """Reset the environment and cache the initial observation."""
+        self.curr_o, info = self.env.reset(seed=seed)
+        return self.curr_o, info
+
+    def render(self):
+        """Render the environment."""
+        return self.env.render()
+
+    def set_seed(self, seed):
+        """Set the environment seed."""
+        self.seed = seed
+
+    def set_rollout_cfgs(self, **kwargs):
+        """Set rollout configs as attributes on the wrapper."""
+        for key, value in kwargs.items():
+            setattr(self, key, value)
+
+    def step(self, action):
+        """Take one environment step."""
+        next_obs, reward, cost, terminated, truncated, info = self.env.step(action)
+        return next_obs, reward, cost, terminated, truncated, info
+
+    # pylint: disable-next=too-many-locals
+    def roll_out(self, agent, buf, logger):
+        """Collect data and store it to the experience buffer."""
+        obs, _ = self.env.reset()
+        ep_ret, ep_costs, ep_len = 0.0, 0.0, 0
+        for step_i in range(self.local_steps_per_epoch):
+            action, value, cost_value, logp = agent.step(torch.as_tensor(obs, dtype=torch.float32))
+            next_obs, reward, cost, done, truncated, _ = self.step(action)
+            ep_ret += reward
+            ep_costs += (self.cost_gamma**ep_len) * cost
+            ep_len += 1
+
+            # Save and log
+            # Notes:
+            #   - raw observations are stored to buffer (later transformed)
+            #   - reward scaling is performed in buffer
+            buf.store(
+                obs=obs,
+                act=action,
+                rew=reward,
+                val=value,
+                logp=logp,
+                cost=cost,
+                cost_val=cost_value,
+            )
+
+            # Store values for statistics
+            if self.use_cost:
+                logger.store(**{'Values/V': value, 'Values/C': cost_value})
+            else:
+                logger.store(**{'Values/V': value})
+
+            # Update observation
+            obs = next_obs
+
+            timeout = ep_len == self.max_ep_len
+            terminal = done or timeout or truncated
+            epoch_ended = step_i == self.local_steps_per_epoch - 1
+
+            if terminal or epoch_ended:
+                if timeout or epoch_ended:
+                    _, value, cost_value, _ = agent(torch.as_tensor(obs, dtype=torch.float32))
+                else:
+                    value, cost_value = 0.0, 0.0
+
+                # Automatically compute GAE in buffer
+                buf.finish_path(value, cost_value, penalty_param=float(self.penalty_param))
+
+                # Only save EpRet / EpLen if trajectory finished
+                if terminal:
+                    logger.store(
+                        **{
+                            'Metrics/EpRet': ep_ret,
+                            'Metrics/EpLen': ep_len,
+                            'Metrics/EpCost': ep_costs,
+                        }
+                    )
+                ep_ret, ep_costs, ep_len = 0.0, 0.0, 0
+                obs, _ = self.env.reset()
diff --git a/omnisafe/wrappers/wrapper_registry.py b/omnisafe/wrappers/wrapper_registry.py
new file mode 100644
index 000000000..15576a86f
--- /dev/null
+++ b/omnisafe/wrappers/wrapper_registry.py
@@ -0,0 +1,72 @@
+# Copyright 2022 OmniSafe Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Registry for environment wrappers."""
+
+import inspect
+
+
+class WrapperRegistry:
+    """A registry to map strings to classes.
+    Args:
+        name (str): Registry name.
+    """
+
+    def __init__(self, name):
+        self._name = name
+        self._module_dict = {}
+
+    def __repr__(self):
+        format_str = (
+            self.__class__.__name__ + f'(name={self._name}, items={list(self._module_dict.keys())})'
+        )
+        return format_str
+
+    @property
+    def name(self):
+        """Return the name of the registry."""
+        return self._name
+
+    @property
+    def module_dict(self):
+        """Return a dict mapping names to classes."""
+        return self._module_dict
+
+    def get(self, key):
+        """Get the class that has been registered under the given key."""
+        return self._module_dict.get(key, None)
+
+    def _register_module(self, module_class):
+        """Register a module class.
+        Args:
+            module_class (type): Wrapper class to be registered.
+        """
+        if not inspect.isclass(module_class):
+            raise TypeError(f'module must be a class, but got {type(module_class)}')
+        module_name = module_class.__name__
+        if module_name in self._module_dict:
+            raise KeyError(f'{module_name} is already registered in {self.name}')
+        self._module_dict[module_name] = module_class
+
+    def register(self, cls):
+        """Register a module class."""
+        self._register_module(cls)
+        return cls
+
+
+WRAPPER_REGISTRY = WrapperRegistry('OmniSafe-Wrappers')
+
+
+register = WRAPPER_REGISTRY.register
+get = WRAPPER_REGISTRY.get
diff --git a/tests/test_policy.py b/tests/test_policy.py
index fc986f388..a0d329bbc 100644
--- a/tests/test_policy.py
+++ b/tests/test_policy.py
@@ -37,7 +37,6 @@ def test_on_policy(algo):
     """Test algorithms"""
     env_id = 'SafetyPointGoal1-v0'
-    custom_cfgs = {'epochs': 1, 'steps_per_epoch': 1000, 'pi_iters': 1, 'critic_iters': 1}
-    env = omnisafe.Env(env_id)
-    agent = omnisafe.Agent(algo, env, custom_cfgs=custom_cfgs, parallel=1)
+    custom_cfgs = {'epochs': 1, 'steps_per_epoch': 2000, 'pi_iters': 1, 'critic_iters': 1}
+    agent = omnisafe.Agent(algo, env_id, custom_cfgs=custom_cfgs, parallel=1)
     agent.learn()
diff --git a/tests/test_safety_gym_envs.py b/tests/test_safety_gym_envs.py
index f71c031e4..6ca2c4863 100644
--- a/tests/test_safety_gym_envs.py
+++ b/tests/test_safety_gym_envs.py
@@ -16,6 +16,7 @@
 import helpers
 
 import omnisafe
+from omnisafe.wrappers.env_wrapper import EnvWrapper as Env
 
 
 @helpers.parametrize(
@@ -30,7 +31,6 @@ def test_on_policy(algo, agent_id, env_id, level):
     # env_id = 'PointGoal1'
     custom_cfgs = {'epochs': 1, 'steps_per_epoch': 1000, 'pi_iters': 1, 'critic_iters': 1}
-    env = omnisafe.Env(env_id)
-    agent = omnisafe.Agent(algo, env, custom_cfgs=custom_cfgs, parallel=1)
+    agent = omnisafe.Agent(algo, env_id, custom_cfgs=custom_cfgs, parallel=1)
     # agent.set_seed(seed=0)
     agent.learn()
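
With this change the public entry point takes the environment id directly, as the updated tests show, and the algorithm wrapper resolves the matching environment wrapper from the algorithm's type. A minimal usage sketch under that assumption, reusing the config keys from tests/test_policy.py, with 'PPOLag' and 'DDPG' standing in for whichever on-policy and off-policy algorithms are registered:

import omnisafe

# On-policy: the env id string replaces the previously required omnisafe.Env object.
agent = omnisafe.Agent(
    'PPOLag',
    'SafetyPointGoal1-v0',
    parallel=1,
    custom_cfgs={'epochs': 1, 'steps_per_epoch': 2000, 'pi_iters': 1, 'critic_iters': 1},
)
agent.learn()

# Off-policy algorithms go through the same call and train via OffPolicyEnvWrapper;
# default off-policy configs are used here since their keys differ from the on-policy ones.
off_policy_agent = omnisafe.Agent('DDPG', 'SafetyPointGoal1-v0', parallel=1)
off_policy_agent.learn()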
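
The wrapper registry can also be used directly: classes are stored under their class names, so a wrapper can be looked up by string. A sketch of that flow, assuming safety_gymnasium provides the 'SafetyPointGoal1-v0' task used in the tests; the values passed to set_rollout_cfgs are illustrative placeholders only:

# Importing the registry pulls in omnisafe/wrappers/__init__.py, whose imports register
# OnPolicyEnvWrapper and OffPolicyEnvWrapper through the @WRAPPER_REGISTRY.register decorator.
from omnisafe.wrappers.wrapper_registry import WRAPPER_REGISTRY

wrapper_cls = WRAPPER_REGISTRY.get('OnPolicyEnvWrapper')
env = wrapper_cls('SafetyPointGoal1-v0')

# Rollout settings are attached after construction; these values are placeholders.
env.set_rollout_cfgs(
    local_steps_per_epoch=1000,
    cost_gamma=1.0,
    use_cost=False,
    penalty_param=0.0,
)

obs, info = env.reset(seed=0)
action = env.action_space.sample()
next_obs, reward, cost, terminated, truncated, info = env.step(action)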