From 77807eec54573fef14be4d7a748f82f19582d645 Mon Sep 17 00:00:00 2001 From: Ruiyang Sun Date: Mon, 27 Feb 2023 22:06:17 +0800 Subject: [PATCH] fix(algo): fix no return in algo_wrapper::learn (#122) --- omnisafe/adapter/online_adapter.py | 8 +++ omnisafe/algorithms/algo_wrapper.py | 5 +- omnisafe/algorithms/base_algo.py | 3 +- .../on_policy/base/policy_gradient.py | 12 ++++- omnisafe/configs/on-policy/PDO.yaml | 2 +- omnisafe/configs/on-policy/PPO.yaml | 2 +- omnisafe/envs/core.py | 10 ++++ omnisafe/envs/wrapper.py | 15 ++++++ omnisafe/evaluator.py | 49 +++++++++++++++++++ 9 files changed, 97 insertions(+), 9 deletions(-) create mode 100644 omnisafe/evaluator.py diff --git a/omnisafe/adapter/online_adapter.py b/omnisafe/adapter/online_adapter.py index f2439b508..ba9277c4b 100644 --- a/omnisafe/adapter/online_adapter.py +++ b/omnisafe/adapter/online_adapter.py @@ -123,3 +123,11 @@ def reset(self) -> Tuple[torch.Tensor, Dict]: info (Dict): contains auxiliary diagnostic information (helpful for debugging, and sometimes learning). """ return self._env.reset() + + def save(self) -> Dict[str, torch.nn.Module]: + """Save the environment. + + Returns: + Dict[str, torch.nn.Module]: the saved environment. + """ + return self._env.save() diff --git a/omnisafe/algorithms/algo_wrapper.py b/omnisafe/algorithms/algo_wrapper.py index 1a75fc308..8293126d1 100644 --- a/omnisafe/algorithms/algo_wrapper.py +++ b/omnisafe/algorithms/algo_wrapper.py @@ -96,10 +96,7 @@ def learn(self): env_id=self.env_id, cfgs=cfgs, ) - agent.learn() - ep_ret = agent.logger.get_stats('Metrics/EpRet') - ep_len = agent.logger.get_stats('Metrics/EpLen') - ep_cost = agent.logger.get_stats('Metrics/EpCost') + ep_ret, ep_cost, ep_len = agent.learn() return ep_ret, ep_len, ep_cost # def evaluate(self, num_episodes: int = 10, horizon: int = 1000, cost_criteria: float = 1.0): diff --git a/omnisafe/algorithms/base_algo.py b/omnisafe/algorithms/base_algo.py index d28282115..a1113de5b 100644 --- a/omnisafe/algorithms/base_algo.py +++ b/omnisafe/algorithms/base_algo.py @@ -15,6 +15,7 @@ """Implementation of the Policy Gradient algorithm.""" from abc import ABC, abstractmethod +from typing import Tuple, Union import torch @@ -63,5 +64,5 @@ def _init_log(self) -> None: """Initialize the logger.""" @abstractmethod - def learn(self) -> None: + def learn(self) -> Tuple[Union[int, float], ...]: """Learn the policy.""" diff --git a/omnisafe/algorithms/on_policy/base/policy_gradient.py b/omnisafe/algorithms/on_policy/base/policy_gradient.py index 345508ac8..a63fd5411 100644 --- a/omnisafe/algorithms/on_policy/base/policy_gradient.py +++ b/omnisafe/algorithms/on_policy/base/policy_gradient.py @@ -15,7 +15,7 @@ """Implementation of the Policy Gradient algorithm.""" import time -from typing import Dict, Tuple +from typing import Dict, Tuple, Union import torch import torch.nn as nn @@ -94,9 +94,12 @@ def _init_log(self) -> None: config=self._cfgs, ) + obs_normalizer = self._env.save()['obs_normalizer'] what_to_save = { 'pi': self._actor_critic.actor, + 'obs_normalizer': obs_normalizer, } + self._logger.setup_torch_saver(what_to_save) self._logger.torch_save() @@ -134,7 +137,7 @@ def _init_log(self) -> None: self._logger.register_key('Time/Epoch') self._logger.register_key('Time/FPS') - def learn(self) -> None: + def learn(self) -> Tuple[Union[int, float], ...]: """This is main function for algorithm update, divided into the following steps: - :meth:`rollout`: collect interactive data from environment. 
@@ -184,8 +187,13 @@ def learn(self) -> None: if (epoch + 1) % self._cfgs.save_freq == 0: self._logger.torch_save() + ep_ret = self._logger.get_stats('Metrics/EpRet')[0] + ep_cost = self._logger.get_stats('Metrics/EpCost')[0] + ep_len = self._logger.get_stats('Metrics/EpLen')[0] self._logger.close() + return ep_ret, ep_cost, ep_len + def _update(self) -> None: data = self._buf.get() obs, act, logp, target_value_r, target_value_c, adv_r, adv_c = ( diff --git a/omnisafe/configs/on-policy/PDO.yaml b/omnisafe/configs/on-policy/PDO.yaml index 7b64a8564..f6cef7ac9 100644 --- a/omnisafe/configs/on-policy/PDO.yaml +++ b/omnisafe/configs/on-policy/PDO.yaml @@ -31,7 +31,7 @@ defaults: # Number of epochs epochs: 500 # Number of steps per epoch - steps_per_epoch: 32000 + steps_per_epoch: 32768 # Number of update iteration for Actor network actor_iters: 10 # Number of update iteration for Critic network diff --git a/omnisafe/configs/on-policy/PPO.yaml b/omnisafe/configs/on-policy/PPO.yaml index df1a2f18b..4df9beff5 100644 --- a/omnisafe/configs/on-policy/PPO.yaml +++ b/omnisafe/configs/on-policy/PPO.yaml @@ -31,7 +31,7 @@ defaults: # Number of epochs epochs: 500 # Number of steps per epoch - steps_per_epoch: 32000 + steps_per_epoch: 32768 # Number of update iteration for Actor network actor_iters: 40 # Number of update iteration for Critic network diff --git a/omnisafe/envs/core.py b/omnisafe/envs/core.py index cd92b3205..52489f8ac 100644 --- a/omnisafe/envs/core.py +++ b/omnisafe/envs/core.py @@ -172,6 +172,13 @@ def render(self) -> Any: Any: the render frames, we recommend to use `np.ndarray` which could construct video by moviepy. """ + def save(self) -> Dict[str, torch.nn.Module]: + """Save the important components of the environment. + Returns: + Dict[str, torch.nn.Module]: the saved components. + """ + return {} + @abstractmethod def close(self) -> None: """Close the environment.""" @@ -230,6 +237,9 @@ def sample_action(self) -> torch.Tensor: def render(self) -> Any: return self._env.render() + def save(self) -> Dict[str, torch.nn.Module]: + return self._env.save() + def close(self) -> None: self._env.close() diff --git a/omnisafe/envs/wrapper.py b/omnisafe/envs/wrapper.py index 632819020..0a774a1d8 100644 --- a/omnisafe/envs/wrapper.py +++ b/omnisafe/envs/wrapper.py @@ -132,6 +132,11 @@ def single_reset(self, idx: int, seed: Optional[int] = None) -> Tuple[torch.Tens obs = self._obs_normalizer.normalize(obs.unsqueeze(0)).squeeze(0) return obs, info + def save(self) -> Dict[str, torch.nn.Module]: + saved = super().save() + saved['obs_normalizer'] = self._obs_normalizer + return saved + class RewardNormalize(Wrapper): """Normalize the reward. @@ -166,6 +171,11 @@ def step( reward = self._reward_normalizer.normalize(reward) return obs, reward, cost, terminated, truncated, info + def save(self) -> Dict[str, torch.nn.Module]: + saved = super().save() + saved['reward_normalizer'] = self._reward_normalizer + return saved + class CostNormalize(Wrapper): """Normalize the cost. @@ -198,6 +208,11 @@ def step( cost = self._cost_normalizer.normalize(cost) return obs, reward, cost, terminated, truncated, info + def save(self) -> Dict[str, torch.nn.Module]: + saved = super().save() + saved['cost_normalizer'] = self._cost_normalizer + return saved + class ActionScale(Wrapper): """Scale the action space to a given range. 
diff --git a/omnisafe/evaluator.py b/omnisafe/evaluator.py new file mode 100644 index 000000000..0b73669a7 --- /dev/null +++ b/omnisafe/evaluator.py @@ -0,0 +1,49 @@ +# Copyright 2022-2023 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implementation of Evaluator.""" + + +class Evaluator: # pylint: disable=too-many-instance-attributes + """This class includes common evaluation methods for safe RL algorithms.""" + + def __init__(self) -> None: + pass + + def load_saved_model(self, save_dir: str, model_name: str) -> None: + """Load saved model from save_dir. + + Args: + save_dir (str): The directory of saved model. + model_name (str): The name of saved model. + + """ + + def load_running_model(self, env, actor) -> None: + """Load running model from env and actor. + + Args: + env (gym.Env): The environment. + actor (omnisafe.actor.Actor): The actor. + + """ + + def evaluate(self, num_episode: int, render: bool = False) -> None: + """Evaluate the model. + + Args: + num_episode (int): The number of episodes to evaluate. + render (bool): Whether to render the environment. + + """
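
Usage sketch (illustrative, not part of this patch): how the fixed return value is consumed by a caller. The top-level entry point `omnisafe.Agent` and the environment id below are assumptions for illustration; only the (ep_ret, ep_len, ep_cost) ordering comes from the patched AlgoWrapper.learn(). The patch also threads a save() hook from the ObsNormalize/RewardNormalize/CostNormalize wrappers through the environment wrapper chain and the online adapter, so the checkpoint written by the logger now carries the observation normalizer under the 'obs_normalizer' key; the Evaluator added in omnisafe/evaluator.py that would eventually consume it is still a stub.

    # Sketch only: 'omnisafe.Agent' and the environment id are assumed here;
    # the (ep_ret, ep_len, ep_cost) ordering is what the patched
    # AlgoWrapper.learn() returns.
    import omnisafe

    agent = omnisafe.Agent('PPO', 'SafetyPointGoal1-v0')

    # The wrapper now forwards the statistics returned by the underlying
    # algorithm's learn() instead of reading them back from the logger.
    ep_ret, ep_len, ep_cost = agent.learn()
    print(f'EpRet={ep_ret:.2f}  EpLen={ep_len:.1f}  EpCost={ep_cost:.2f}')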