fix(algo): fix no return in algo_wrapper::learn (#122)
rockmagma02 authored and zmsn-2077 committed Mar 15, 2023
1 parent 555acbb commit 77807ee
Showing 9 changed files with 97 additions and 9 deletions.
8 changes: 8 additions & 0 deletions omnisafe/adapter/online_adapter.py
@@ -123,3 +123,11 @@ def reset(self) -> Tuple[torch.Tensor, Dict]:
info (Dict): contains auxiliary diagnostic information (helpful for debugging, and sometimes learning).
"""
return self._env.reset()

def save(self) -> Dict[str, torch.nn.Module]:
"""Save the environment.
Returns:
Dict[str, torch.nn.Module]: the saved environment.
"""
return self._env.save()
5 changes: 1 addition & 4 deletions omnisafe/algorithms/algo_wrapper.py
@@ -96,10 +96,7 @@ def learn(self):
env_id=self.env_id,
cfgs=cfgs,
)
agent.learn()
ep_ret = agent.logger.get_stats('Metrics/EpRet')
ep_len = agent.logger.get_stats('Metrics/EpLen')
ep_cost = agent.logger.get_stats('Metrics/EpCost')
ep_ret, ep_cost, ep_len = agent.learn()
return ep_ret, ep_len, ep_cost

# def evaluate(self, num_episodes: int = 10, horizon: int = 1000, cost_criteria: float = 1.0):
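With learn() now propagating the statistics it collects, a caller can read the final episode return, length, and cost straight from the wrapper instead of digging into the logger afterwards. A rough usage sketch, assuming the class defined in omnisafe/algorithms/algo_wrapper.py is AlgoWrapper; the constructor arguments and environment id are assumptions, not part of this diff:

# Illustrative only: constructor arguments and environment id are assumptions.
from omnisafe.algorithms.algo_wrapper import AlgoWrapper

agent = AlgoWrapper('PPO', env_id='SafetyPointGoal1-v0')
ep_ret, ep_len, ep_cost = agent.learn()  # return order matches the wrapper above
print(f'return={ep_ret:.2f}  length={ep_len:.1f}  cost={ep_cost:.2f}')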
3 changes: 2 additions & 1 deletion omnisafe/algorithms/base_algo.py
@@ -15,6 +15,7 @@
"""Implementation of the Policy Gradient algorithm."""

from abc import ABC, abstractmethod
from typing import Tuple, Union

import torch

@@ -63,5 +64,5 @@ def _init_log(self) -> None:
"""Initialize the logger."""

@abstractmethod
def learn(self) -> None:
def learn(self) -> Tuple[Union[int, float], ...]:
"""Learn the policy."""
12 changes: 10 additions & 2 deletions omnisafe/algorithms/on_policy/base/policy_gradient.py
@@ -15,7 +15,7 @@
"""Implementation of the Policy Gradient algorithm."""

import time
from typing import Dict, Tuple
from typing import Dict, Tuple, Union

import torch
import torch.nn as nn
@@ -94,9 +94,12 @@ def _init_log(self) -> None:
config=self._cfgs,
)

obs_normalizer = self._env.save()['obs_normalizer']
what_to_save = {
'pi': self._actor_critic.actor,
'obs_normalizer': obs_normalizer,
}

self._logger.setup_torch_saver(what_to_save)
self._logger.torch_save()

@@ -134,7 +137,7 @@ def _init_log(self) -> None:
self._logger.register_key('Time/Epoch')
self._logger.register_key('Time/FPS')

def learn(self) -> None:
def learn(self) -> Tuple[Union[int, float], ...]:
"""This is main function for algorithm update, divided into the following steps:
- :meth:`rollout`: collect interactive data from environment.
@@ -184,8 +187,13 @@ def learn(self) -> None:
if (epoch + 1) % self._cfgs.save_freq == 0:
self._logger.torch_save()

ep_ret = self._logger.get_stats('Metrics/EpRet')[0]
ep_cost = self._logger.get_stats('Metrics/EpCost')[0]
ep_len = self._logger.get_stats('Metrics/EpLen')[0]
self._logger.close()

return ep_ret, ep_cost, ep_len

def _update(self) -> None:
data = self._buf.get()
obs, act, logp, target_value_r, target_value_c, adv_r, adv_c = (
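Since the observation normalizer is now bundled into what_to_save next to the actor, a checkpoint restored later can apply the same input preprocessing the policy saw during training. A loading sketch, assuming Logger.torch_save() writes this dict with torch.save; the file path and the actor's predict() signature are illustrative, not confirmed by this diff:

import torch

checkpoint = torch.load('./save_dir/torch_save/epoch-500.pt')  # hypothetical path
actor = checkpoint['pi']                       # the saved actor module
obs_normalizer = checkpoint['obs_normalizer']  # the saved normalizer module

raw_obs = torch.zeros(60)  # placeholder; use the environment's real observation shape
obs = obs_normalizer.normalize(raw_obs.unsqueeze(0)).squeeze(0)  # batch, normalize, un-batch
action = actor.predict(obs, deterministic=True)  # assumed actor API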
2 changes: 1 addition & 1 deletion omnisafe/configs/on-policy/PDO.yaml
@@ -31,7 +31,7 @@ defaults:
# Number of epochs
epochs: 500
# Number of steps per epoch
steps_per_epoch: 32000
steps_per_epoch: 32768
# Number of update iteration for Actor network
actor_iters: 10
# Number of update iteration for Critic network
2 changes: 1 addition & 1 deletion omnisafe/configs/on-policy/PPO.yaml
@@ -31,7 +31,7 @@ defaults:
# Number of epochs
epochs: 500
# Number of steps per epoch
steps_per_epoch: 32000
steps_per_epoch: 32768
# Number of update iteration for Actor network
actor_iters: 40
# Number of update iteration for Critic network
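Both configs bump steps_per_epoch from 32000 to 32768. The commit does not state why, but 32768 = 2**15, so the per-epoch batch divides evenly by any power-of-two number of parallel environments or minibatch size, which 32000 does not. A quick check (illustrative, not omnisafe code):

# 32000 leaves a remainder for the larger power-of-two sizes; 32768 never does.
for steps in (32000, 32768):
    for size in (64, 128, 256, 512, 1024, 2048):
        if steps % size:
            print(f'{steps} is not divisible by {size}')
# Prints only for 32000 with 512, 1024 and 2048.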
10 changes: 10 additions & 0 deletions omnisafe/envs/core.py
@@ -172,6 +172,13 @@ def render(self) -> Any:
Any: the render frames; we recommend using `np.ndarray`, from which a video can be constructed with moviepy.
"""

def save(self) -> Dict[str, torch.nn.Module]:
"""Save the important components of the environment.
Returns:
Dict[str, torch.nn.Module]: the saved components.
"""
return {}

@abstractmethod
def close(self) -> None:
"""Close the environment."""
@@ -230,6 +237,9 @@ def sample_action(self) -> torch.Tensor:
def render(self) -> Any:
return self._env.render()

def save(self) -> Dict[str, torch.nn.Module]:
return self._env.save()

def close(self) -> None:
self._env.close()

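The base class gives save() a do-nothing default (an empty dict), so existing environments need no changes, while an environment with learned state can override it. A standalone sketch of the pattern with stand-in classes (not omnisafe code):

from typing import Dict
import torch

class EnvSketch:
    def save(self) -> Dict[str, torch.nn.Module]:
        # Default introduced above: nothing to persist.
        return {}

class EnvWithEncoderSketch(EnvSketch):
    def __init__(self) -> None:
        self._encoder = torch.nn.Linear(4, 4)  # hypothetical learned component

    def save(self) -> Dict[str, torch.nn.Module]:
        saved = super().save()
        saved['encoder'] = self._encoder
        return saved

print(list(EnvWithEncoderSketch().save()))  # ['encoder']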
15 changes: 15 additions & 0 deletions omnisafe/envs/wrapper.py
@@ -132,6 +132,11 @@ def single_reset(self, idx: int, seed: Optional[int] = None) -> Tuple[torch.Tens
obs = self._obs_normalizer.normalize(obs.unsqueeze(0)).squeeze(0)
return obs, info

def save(self) -> Dict[str, torch.nn.Module]:
saved = super().save()
saved['obs_normalizer'] = self._obs_normalizer
return saved


class RewardNormalize(Wrapper):
"""Normalize the reward.
@@ -166,6 +171,11 @@ def step(
reward = self._reward_normalizer.normalize(reward)
return obs, reward, cost, terminated, truncated, info

def save(self) -> Dict[str, torch.nn.Module]:
saved = super().save()
saved['reward_normalizer'] = self._reward_normalizer
return saved


class CostNormalize(Wrapper):
"""Normalize the cost.
@@ -198,6 +208,11 @@ def step(
cost = self._cost_normalizer.normalize(cost)
return obs, reward, cost, terminated, truncated, info

def save(self) -> Dict[str, torch.nn.Module]:
saved = super().save()
saved['cost_normalizer'] = self._cost_normalizer
return saved


class ActionScale(Wrapper):
"""Scale the action space to a given range.
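Because every wrapper calls super().save() before adding its own entry, the dicts compose as wrappers stack, which is how policy_gradient.py above can pull 'obs_normalizer' out of the outermost layer. A composition sketch with stand-in classes (omnisafe's real wrappers take an environment plus other arguments not reproduced here):

from typing import Dict, Optional
import torch

class WrapperSketch:
    def __init__(self, env: Optional['WrapperSketch']) -> None:
        self._env = env

    def save(self) -> Dict[str, torch.nn.Module]:
        # Innermost layer contributes nothing, mirroring the new default in core.py.
        return self._env.save() if self._env is not None else {}

class ObsNormalizeSketch(WrapperSketch):
    def save(self) -> Dict[str, torch.nn.Module]:
        saved = super().save()
        saved['obs_normalizer'] = torch.nn.Identity()  # placeholder for the normalizer module
        return saved

class CostNormalizeSketch(WrapperSketch):
    def save(self) -> Dict[str, torch.nn.Module]:
        saved = super().save()
        saved['cost_normalizer'] = torch.nn.Identity()  # placeholder for the normalizer module
        return saved

env = CostNormalizeSketch(ObsNormalizeSketch(WrapperSketch(None)))
print(sorted(env.save()))  # ['cost_normalizer', 'obs_normalizer']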
49 changes: 49 additions & 0 deletions omnisafe/evaluator.py
@@ -0,0 +1,49 @@
# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Evaluator."""


class Evaluator: # pylint: disable=too-many-instance-attributes
"""This class includes common evaluation methods for safe RL algorithms."""

def __init__(self) -> None:
pass

def load_saved_model(self, save_dir: str, model_name: str) -> None:
"""Load saved model from save_dir.
Args:
save_dir (str): The directory of saved model.
model_name (str): The name of saved model.
"""

def load_running_model(self, env, actor) -> None:
"""Load running model from env and actor.
Args:
env (gym.Env): The environment.
actor (omnisafe.actor.Actor): The actor.
"""

def evaluate(self, num_episode: int, render: bool = False) -> None:
"""Evaluate the model.
Args:
num_episode (int): The number of episodes to evaluate.
render (bool): Whether to render the environment.
"""
