Skip to content

Commit

Permalink
feat(wrapper): separated wrapper for different algorithmic environments (#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
zmsn-2077 authored and XuehaiPan committed Dec 23, 2022
1 parent d4cd28b commit d1e171e
Show file tree
Hide file tree
Showing 39 changed files with 1,350 additions and 68 deletions.
11 changes: 11 additions & 0 deletions docs/source/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,14 @@ FOCOPS
Kakade
QCritic
yaml
polyak
MSE
Daan
Wierstra
Pritzel
Heess
mul
logprob
Tanh
Eq
chol
9 changes: 7 additions & 2 deletions examples/train_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@
keys = [k[2:] for k in unparsed_args[0::2]]
values = list(unparsed_args[1::2])
unparsed_dict = dict(zip(keys, values))
env = omnisafe.Env(args.env_id)
agent = omnisafe.Agent(args.algo, env, parallel=args.parallel, custom_cfgs=unparsed_dict)
# env = omnisafe.Env(args.env_id)
agent = omnisafe.Agent(
args.algo,
args.env_id,
parallel=args.parallel,
custom_cfgs=unparsed_dict,
)
agent.learn()
5 changes: 3 additions & 2 deletions omnisafe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# ==============================================================================
"""OmniSafe: A comprehensive and reliable benchmark for safe reinforcement learning."""

from omnisafe.algo_wrapper import AlgoWrapper as Agent
from omnisafe.algorithms.env_wrapper import EnvWrapper as Env
from omnisafe.algorithms.algo_wrapper import AlgoWrapper as Agent

# from omnisafe.algorithms.env_wrapper import EnvWrapper as Env
from omnisafe.version import __version__
5 changes: 5 additions & 0 deletions omnisafe/algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@
# ==============================================================================
"""Safe Reinforcement Learning algorithms."""

# Off Policy Safe
from omnisafe.algorithms.off_policy.ddpg import DDPG

# On Policy Safe
from omnisafe.algorithms.on_policy.cpo import CPO
from omnisafe.algorithms.on_policy.cppo_pid import CPPOPid
from omnisafe.algorithms.on_policy.cup import CUP
from omnisafe.algorithms.on_policy.focops import FOCOPS
from omnisafe.algorithms.on_policy.natural_pg import NaturalPG
from omnisafe.algorithms.on_policy.npg_lag import NPGLag
Expand Down Expand Up @@ -45,6 +49,7 @@
'PPOLag',
'TRPO',
'TRPOLag',
'CUP',
],
'model-based': ['MBPPOLag', 'SafeLoop'],
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,10 @@
class AlgoWrapper:
"""Algo Wrapper for algo"""

def __init__(self, algo, env, parallel=1, custom_cfgs=None):
def __init__(self, algo, env_id, parallel=1, custom_cfgs=None):
self.algo = algo
self.env = env
self.parallel = parallel
self.env_id = env.env_id
self.env_id = env_id
# algo_type will be set in _init_checks()
self.algo_type = None
self.custom_cfgs = custom_cfgs
Expand Down Expand Up @@ -69,12 +68,12 @@ def learn(self):
sys.exit()

default_cfgs = get_default_kwargs_yaml(self.algo, self.env_id, self.algo_type)
exp_name = os.path.join(self.env.env_id, self.algo)
exp_name = os.path.join(self.env_id, self.algo)
default_cfgs.update(exp_name=exp_name, env_id=self.env_id)
cfgs = recursive_update(default_cfgs, self.custom_cfgs)
check_all_configs(cfgs, self.algo_type)
agent = registry.get(self.algo)(
env=self.env,
env_id=self.env_id,
cfgs=cfgs,
)
agent.learn()
Expand Down
Loading

0 comments on commit d1e171e

Please sign in to comment.