From 79b620aeef190a580fe3e6ca615e9a0a478cdd1a Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Wed, 11 Dec 2024 17:32:30 +0100 Subject: [PATCH] [RLlib; docs] Docs do-over (new API stack): Env pages vol 01. (#49165) Signed-off-by: ujjawal-khare --- doc/source/rllib/images/envs/env_runners.svg | 1 + .../rllib/images/envs/external_env_logo.svg | 1 + .../images/envs/hierarchical_env_logo.svg | 1 + .../images/envs/multi_agent_env_logo.svg | 1 + .../envs/multi_agent_episode_simultaneous.svg | 1 + .../envs/multi_agent_episode_turn_based.svg | 1 + .../rllib/images/envs/multi_agent_setup.svg | 1 + .../images/envs/single_agent_env_logo.svg | 1 + .../rllib/images/envs/single_agent_setup.svg | 1 + doc/source/rllib/single-agent-episode.rst | 4 +- rllib/BUILD | 22 +- rllib/env/multi_agent_env.py | 19 +- rllib/env/multi_agent_episode.py | 286 ++++--------- rllib/examples/centralized_critic.py | 2 +- rllib/examples/envs/agents_act_in_sequence.py | 87 ++++ .../envs/agents_act_simultaneously.py | 108 +++++ .../classes/coin_game_non_vectorized_env.py | 344 ---------------- .../envs/classes/coin_game_vectorized_env.py | 386 ------------------ .../matrix_sequential_social_dilemma.py | 314 -------------- .../envs/classes/multi_agent/__init__.py | 35 ++ .../{ => multi_agent}/bandit_envs_discrete.py | 0 .../bandit_envs_recommender_system.py | 0 .../guess_the_number_game.py} | 23 +- .../{ => multi_agent}/pettingzoo_chess.py | 0 .../{ => multi_agent}/pettingzoo_connect4.py | 0 .../multi_agent/rock_paper_scissors.py | 125 ++++++ .../envs/classes/multi_agent/tic_tac_toe.py | 144 +++++++ .../{ => multi_agent}/two_step_game.py | 0 .../examples/envs/classes/utils/interfaces.py | 23 -- rllib/examples/envs/classes/utils/mixins.py | 72 ---- ...ock_paper_scissors_heuristic_vs_learned.py | 4 +- .../rock_paper_scissors_learned_vs_learned.py | 4 +- .../two_step_game_with_grouped_agents.py | 4 +- 33 files changed, 636 insertions(+), 1379 deletions(-) create mode 100644 doc/source/rllib/images/envs/env_runners.svg create mode 100644 doc/source/rllib/images/envs/external_env_logo.svg create mode 100644 doc/source/rllib/images/envs/hierarchical_env_logo.svg create mode 100644 doc/source/rllib/images/envs/multi_agent_env_logo.svg create mode 100644 doc/source/rllib/images/envs/multi_agent_episode_simultaneous.svg create mode 100644 doc/source/rllib/images/envs/multi_agent_episode_turn_based.svg create mode 100644 doc/source/rllib/images/envs/multi_agent_setup.svg create mode 100644 doc/source/rllib/images/envs/single_agent_env_logo.svg create mode 100644 doc/source/rllib/images/envs/single_agent_setup.svg create mode 100644 rllib/examples/envs/agents_act_in_sequence.py create mode 100644 rllib/examples/envs/agents_act_simultaneously.py delete mode 100644 rllib/examples/envs/classes/coin_game_non_vectorized_env.py delete mode 100644 rllib/examples/envs/classes/coin_game_vectorized_env.py delete mode 100644 rllib/examples/envs/classes/matrix_sequential_social_dilemma.py create mode 100644 rllib/examples/envs/classes/multi_agent/__init__.py rename rllib/examples/envs/classes/{ => multi_agent}/bandit_envs_discrete.py (100%) rename rllib/examples/envs/classes/{ => multi_agent}/bandit_envs_recommender_system.py (100%) rename rllib/examples/envs/classes/{multi_agent.py => multi_agent/guess_the_number_game.py} (78%) rename rllib/examples/envs/classes/{ => multi_agent}/pettingzoo_chess.py (100%) rename rllib/examples/envs/classes/{ => multi_agent}/pettingzoo_connect4.py (100%) create mode 100644 
rllib/examples/envs/classes/multi_agent/rock_paper_scissors.py create mode 100644 rllib/examples/envs/classes/multi_agent/tic_tac_toe.py rename rllib/examples/envs/classes/{ => multi_agent}/two_step_game.py (100%) delete mode 100644 rllib/examples/envs/classes/utils/interfaces.py delete mode 100644 rllib/examples/envs/classes/utils/mixins.py diff --git a/doc/source/rllib/images/envs/env_runners.svg b/doc/source/rllib/images/envs/env_runners.svg new file mode 100644 index 000000000000..fda0dc654f9d --- /dev/null +++ b/doc/source/rllib/images/envs/env_runners.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/doc/source/rllib/images/envs/external_env_logo.svg b/doc/source/rllib/images/envs/external_env_logo.svg new file mode 100644 index 000000000000..d445af31638d --- /dev/null +++ b/doc/source/rllib/images/envs/external_env_logo.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/doc/source/rllib/images/envs/hierarchical_env_logo.svg b/doc/source/rllib/images/envs/hierarchical_env_logo.svg new file mode 100644 index 000000000000..d184b4c81771 --- /dev/null +++ b/doc/source/rllib/images/envs/hierarchical_env_logo.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/doc/source/rllib/images/envs/multi_agent_env_logo.svg b/doc/source/rllib/images/envs/multi_agent_env_logo.svg new file mode 100644 index 000000000000..8fb72d6f0185 --- /dev/null +++ b/doc/source/rllib/images/envs/multi_agent_env_logo.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/doc/source/rllib/images/envs/multi_agent_episode_simultaneous.svg b/doc/source/rllib/images/envs/multi_agent_episode_simultaneous.svg new file mode 100644 index 000000000000..6da2e8b8a12a --- /dev/null +++ b/doc/source/rllib/images/envs/multi_agent_episode_simultaneous.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/doc/source/rllib/images/envs/multi_agent_episode_turn_based.svg b/doc/source/rllib/images/envs/multi_agent_episode_turn_based.svg new file mode 100644 index 000000000000..4e555590ad4b --- /dev/null +++ b/doc/source/rllib/images/envs/multi_agent_episode_turn_based.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/doc/source/rllib/images/envs/multi_agent_setup.svg b/doc/source/rllib/images/envs/multi_agent_setup.svg new file mode 100644 index 000000000000..dc5e9294a2a6 --- /dev/null +++ b/doc/source/rllib/images/envs/multi_agent_setup.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/doc/source/rllib/images/envs/single_agent_env_logo.svg b/doc/source/rllib/images/envs/single_agent_env_logo.svg new file mode 100644 index 000000000000..f2d8512a2421 --- /dev/null +++ b/doc/source/rllib/images/envs/single_agent_env_logo.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/doc/source/rllib/images/envs/single_agent_setup.svg b/doc/source/rllib/images/envs/single_agent_setup.svg new file mode 100644 index 000000000000..6413c696721f --- /dev/null +++ b/doc/source/rllib/images/envs/single_agent_setup.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/doc/source/rllib/single-agent-episode.rst b/doc/source/rllib/single-agent-episode.rst index 87fa50790d45..4c2a1f89d847 100644 --- a/doc/source/rllib/single-agent-episode.rst +++ b/doc/source/rllib/single-agent-episode.rst @@ -139,14 +139,14 @@ episodes (one non-finalized the other finalized): :align: left **Complex observations in a non-finalized episode**: Each individual observation is a (complex) dict matching the - gym environment's observation space. 
There are three such observation items stored in the episode so far. + gymnasium environment's observation space. There are three such observation items stored in the episode so far. .. figure:: images/episodes/sa_episode_finalized.svg :width: 600 :align: left **Complex observations in a finalized episode**: The entire observation record is a single (complex) dict matching the - gym environment's observation space. At the leafs of the structure are `NDArrays` holding the individual values of the leaf. + gymnasium environment's observation space. At the leafs of the structure are `NDArrays` holding the individual values of the leaf. Note that these `NDArrays` have an extra batch dim (axis=0), whose length matches the length of the episode stored (here 3). diff --git a/rllib/BUILD b/rllib/BUILD index 1bc105d96fc7..1592a8bb4222 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -2362,6 +2362,22 @@ py_test( # subdirectory: envs/ # .................................... +py_test( + name = "examples/envs/agents_act_simultaneously", + main = "examples/envs/agents_act_simultaneously.py", + tags = ["team:rllib", "exclusive", "examples"], + size = "medium", + srcs = ["examples/envs/agents_act_simultaneously.py"], + args = ["--enable-new-api-stack", "--num-agents=2", "--stop-iters=3"] +) +py_test( + name = "examples/envs/agents_act_in_sequence", + main = "examples/envs/agents_act_in_sequence.py", + tags = ["team:rllib", "exclusive", "examples"], + size = "medium", + srcs = ["examples/envs/agents_act_in_sequence.py"], + args = ["--enable-new-api-stack", "--num-agents=2", "--stop-iters=3"] +) py_test( name = "examples/envs/custom_env_render_method", main = "examples/envs/custom_env_render_method.py", @@ -2370,7 +2386,6 @@ py_test( srcs = ["examples/envs/custom_env_render_method.py"], args = ["--enable-new-api-stack", "--num-agents=0"] ) - py_test( name = "examples/envs/custom_env_render_method_multi_agent", main = "examples/envs/custom_env_render_method.py", @@ -2379,7 +2394,6 @@ py_test( srcs = ["examples/envs/custom_env_render_method.py"], args = ["--enable-new-api-stack", "--num-agents=2"] ) - py_test( name = "examples/envs/custom_gym_env", main = "examples/envs/custom_gym_env.py", @@ -2388,7 +2402,6 @@ py_test( srcs = ["examples/envs/custom_gym_env.py"], args = ["--enable-new-api-stack", "--as-test"] ) - py_test( name = "examples/envs/env_connecting_to_rllib_w_tcp_client", main = "examples/envs/env_connecting_to_rllib_w_tcp_client.py", @@ -2397,7 +2410,6 @@ py_test( srcs = ["examples/envs/env_connecting_to_rllib_w_tcp_client.py"], args = ["--enable-new-api-stack", "--as-test", "--port=12346"] ) - py_test( name = "examples/envs/env_rendering_and_recording", srcs = ["examples/envs/env_rendering_and_recording.py"], @@ -2405,7 +2417,6 @@ py_test( size = "medium", args = ["--enable-new-api-stack", "--env=CartPole-v1", "--stop-iters=2"] ) - py_test( name = "examples/envs/env_w_protobuf_observations", main = "examples/envs/env_w_protobuf_observations.py", @@ -2414,7 +2425,6 @@ py_test( srcs = ["examples/envs/env_w_protobuf_observations.py"], args = ["--enable-new-api-stack", "--as-test"] ) - #@OldAPIStack py_test( name = "examples/envs/greyscale_env", diff --git a/rllib/env/multi_agent_env.py b/rllib/env/multi_agent_env.py index 6293ecd8f818..c21acec528c2 100644 --- a/rllib/env/multi_agent_env.py +++ b/rllib/env/multi_agent_env.py @@ -6,11 +6,7 @@ from ray.rllib.env.base_env import BaseEnv from ray.rllib.env.env_context import EnvContext -from ray.rllib.utils.annotations import ( - OldAPIStack, - override, - 
PublicAPI, -) +from ray.rllib.utils.annotations import OldAPIStack, override from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.typing import ( AgentID, @@ -21,6 +17,7 @@ MultiEnvDict, ) from ray.util import log_once +from ray.util.annotations import DeveloperAPI, PublicAPI # If the obs space is Dict type, look for the global state under this key. ENV_STATE = "state" @@ -28,7 +25,7 @@ logger = logging.getLogger(__name__) -@PublicAPI +@PublicAPI(stability="beta") class MultiAgentEnv(gym.Env): """An environment that hosts multiple independent agents. @@ -166,11 +163,15 @@ def get_observation_space(self, agent_id: AgentID) -> gym.Space: return self.observation_spaces[agent_id] # @OldAPIStack behavior. + # `self.observation_space` is a `gym.spaces.Dict` AND contains `agent_id`. if ( isinstance(self.observation_space, gym.spaces.Dict) and agent_id in self.observation_space.spaces ): return self.observation_space[agent_id] + # `self.observation_space` is not a `gym.spaces.Dict` OR doesn't contain + # `agent_id` -> The defined space is most likely meant to be the space + # for all agents. else: return self.observation_space @@ -179,11 +180,15 @@ def get_action_space(self, agent_id: AgentID) -> gym.Space: return self.action_spaces[agent_id] # @OldAPIStack behavior. + # `self.action_space` is a `gym.spaces.Dict` AND contains `agent_id`. if ( isinstance(self.action_space, gym.spaces.Dict) and agent_id in self.action_space.spaces ): return self.action_space[agent_id] + # `self.action_space` is not a `gym.spaces.Dict` OR doesn't contain + # `agent_id` -> The defined space is most likely meant to be the space + # for all agents. else: return self.action_space @@ -321,7 +326,7 @@ def to_base_env( return env -@PublicAPI +@DeveloperAPI def make_multi_agent( env_name_or_creator: Union[str, EnvCreator], ) -> Type["MultiAgentEnv"]: diff --git a/rllib/env/multi_agent_episode.py b/rllib/env/multi_agent_episode.py index 9f3b7240b586..7f9e41ed639c 100644 --- a/rllib/env/multi_agent_episode.py +++ b/rllib/env/multi_agent_episode.py @@ -75,9 +75,10 @@ class MultiAgentEpisode: "is_terminated", "is_truncated", "agent_episodes", - "_temporary_timestep_data", - "_start_time", "_last_step_time", + "_len_lookback_buffers", + "_start_time", + "_temporary_timestep_data", ) SKIP_ENV_TS_TAG = "S" @@ -206,6 +207,7 @@ def __init__( # lookback buffer. if len_lookback_buffer == "auto": len_lookback_buffer = len(rewards or []) + self._len_lookback_buffers = len_lookback_buffer self.observation_space = observation_space or {} self.action_space = action_space or {} @@ -218,7 +220,7 @@ def __init__( self.env_t_started = env_t_started or 0 self.env_t = ( (len(rewards) if rewards is not None else 0) - - len_lookback_buffer + - self._len_lookback_buffers + self.env_t_started ) self.agent_t_started = defaultdict(int, agent_t_started or {}) @@ -280,7 +282,6 @@ def __init__( terminateds=terminateds, truncateds=truncateds, extra_model_outputs=extra_model_outputs, - len_lookback_buffer=len_lookback_buffer, ) # Caches for temporary per-timestep data. May be used to store custom metrics @@ -489,7 +490,7 @@ def add_env_step( # ------------------------------------------------------------------------ # We have an observation, but no action -> # a) Action (and extra model outputs) must be hanging already. Also use - # collected hanging rewards. + # collected hanging rewards and extra_model_outputs. # b) The observation is the first observation for this agent ID. 
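            # Illustration with a hypothetical agent id "a0" (names made up, for
            # clarity only):
            # * env step t:   "a0" appears with an obs but no action
            #                 -> case b) below: `add_env_reset()` is called on
            #                 "a0"'s SingleAgentEpisode.
            # * env step t+1: "a0" sends an action, but the env returns no new obs
            #                 for "a0" -> the action (plus extra model outputs and
            #                 any reward) is kept as "hanging" data in the caches.
            # * env step t+2: a new obs for "a0" arrives -> case a) below: the
            #                 hanging action/reward are popped and logged together
            #                 with this obs as one agent step.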
elif _observation is not None and _action is None: _action = self._hanging_actions_end.pop(agent_id, None) @@ -515,9 +516,13 @@ def add_env_step( # This must be the agent's initial observation. else: # Prepend n skip tags to this agent's mapping + the initial [0]. + assert agent_id not in self.env_t_to_agent_t self.env_t_to_agent_t[agent_id].extend( [self.SKIP_ENV_TS_TAG] * self.env_t + [0] ) + self.env_t_to_agent_t[ + agent_id + ].lookback = self._len_lookback_buffers # Make `add_env_reset` call and continue with next agent. sa_episode.add_env_reset(observation=_observation, infos=_infos) # Add possible reward to begin cache. @@ -956,24 +961,29 @@ def cut(self, len_lookback_buffer: int = 0) -> "MultiAgentEpisode": else slice(None, 0) # -> empty slice ) + observations = self.get_observations( + indices=indices_obs_and_infos, return_list=True + ) + infos = self.get_infos(indices=indices_obs_and_infos, return_list=True) + actions = self.get_actions(indices=indices_rest, return_list=True) + rewards = self.get_rewards(indices=indices_rest, return_list=True) + extra_model_outputs = self.get_extra_model_outputs( + key=None, # all keys + indices=indices_rest, + return_list=True, + ) successor = MultiAgentEpisode( # Same ID. id_=self.id_, - observations=self.get_observations( - indices=indices_obs_and_infos, return_list=True - ), + observations=observations, observation_space=self.observation_space, - infos=self.get_infos(indices=indices_obs_and_infos, return_list=True), - actions=self.get_actions(indices=indices_rest, return_list=True), + infos=infos, + actions=actions, action_space=self.action_space, - rewards=self.get_rewards(indices=indices_rest, return_list=True), + rewards=rewards, # List of MADicts, mapping agent IDs to their respective extra model output # dicts. - extra_model_outputs=self.get_extra_model_outputs( - key=None, # all keys - indices=indices_rest, - return_list=True, - ), + extra_model_outputs=extra_model_outputs, terminateds=self.get_terminateds(), truncateds=self.get_truncateds(), # Continue with `self`'s current timesteps. @@ -1955,7 +1965,6 @@ def _init_single_agent_episodes( terminateds: Union[MultiAgentDict, bool] = False, truncateds: Union[MultiAgentDict, bool] = False, extra_model_outputs: Optional[List[MultiAgentDict]] = None, - len_lookback_buffer: int, ): if observations is None: return @@ -1979,7 +1988,7 @@ def _init_single_agent_episodes( rewards_per_agent = defaultdict(list) extra_model_outputs_per_agent = defaultdict(list) done_per_agent = defaultdict(bool) - len_lookback_buffer_per_agent = defaultdict(lambda: len_lookback_buffer) + len_lookback_buffer_per_agent = defaultdict(lambda: self._len_lookback_buffers) all_agent_ids = set( agent_episode_ids.keys() if agent_episode_ids is not None else [] @@ -2058,7 +2067,11 @@ def _init_single_agent_episodes( # If we are still in the global lookback buffer segment, deduct 1 # from this agents' lookback buffer, b/c we don't want the agent # to use this (missing) obs/data in its single-agent lookback. 
- if len(self.env_t_to_agent_t[agent_id]) - len_lookback_buffer <= 0: + if ( + len(self.env_t_to_agent_t[agent_id]) + - self._len_lookback_buffers + <= 0 + ): len_lookback_buffer_per_agent[agent_id] -= 1 self._hanging_rewards_end[agent_id] += rew.get(agent_id, 0.0) @@ -2076,7 +2089,7 @@ def _init_single_agent_episodes( == len(extra_model_outputs_per_agent[agent_id]) + 1 == len(rewards_per_agent[agent_id]) + 1 ) - self.env_t_to_agent_t[agent_id].lookback = len_lookback_buffer + self.env_t_to_agent_t[agent_id].lookback = self._len_lookback_buffers # Now create the individual episodes from the collected per-agent data. for agent_id, agent_obs in observations_per_agent.items(): @@ -2181,12 +2194,16 @@ def _get_data_by_agent_steps( one_hot_discrete, extra_model_outputs_key, ): + # Return requested data by agent-steps. ret = {} + # For each agent, we retrieve the data through passing the given indices into + # the SingleAgentEpisode of that agent. for agent_id, sa_episode in self.agent_episodes.items(): if agent_id not in agent_ids: continue inf_lookback_buffer = getattr(sa_episode, what) hanging_val = self._get_hanging_value(what, agent_id) + # User wants a specific `extra_model_outputs` key. if extra_model_outputs_key is not None: inf_lookback_buffer = inf_lookback_buffer[extra_model_outputs_key] hanging_val = hanging_val[extra_model_outputs_key] @@ -2208,59 +2225,11 @@ def _get_data_by_env_steps_as_list( what: str, indices: Union[int, slice, List[int]], agent_ids: Collection[AgentID], - neg_index_as_lookback: bool = False, - fill: Optional[Any] = None, - one_hot_discrete: bool = False, - extra_model_outputs_key: Optional[str] = None, + neg_index_as_lookback: bool, + fill: Any, + one_hot_discrete, + extra_model_outputs_key: str, ) -> List[MultiAgentDict]: - """Returns data from the episode based on env step indices, as a list. - - The returned list contains n MultiAgentDict objects, one for each env timestep - defined via `indices`. - - Args: - what: A (str) descriptor of what data to collect. Must be one of - "observations", "infos", "actions", "rewards", or "extra_model_outputs". - indices: A single int is interpreted as an index, from which to return the - individual data stored at this (env step) index. - A list of ints is interpreted as a list of indices from which to gather - individual data in a batch of size len(indices). - A slice object is interpreted as a range of data to be returned. - Thereby, negative indices by default are interpreted as "before the end" - unless the `neg_index_as_lookback=True` option is used, in which case - negative indices are interpreted as "before ts=0", meaning going back - into the lookback buffer. - agent_ids: A collection of AgentIDs to filter for. Only data for those - agents will be returned, all other agents will be ignored. - neg_index_as_lookback: If True, negative values in `indices` are - interpreted as "before ts=0", meaning going back into the lookback - buffer. For example, a buffer with data [4, 5, 6, 7, 8, 9], - where [4, 5, 6] is the lookback buffer range (ts=0 item is 7), will - respond to `get(-1, neg_index_as_lookback=True)` with `6` and to - `get(slice(-2, 1), neg_index_as_lookback=True)` with `[5, 6, 7]`. - fill: An optional float value to use for filling up the returned results at - the boundaries. This filling only happens if the requested index range's - start/stop boundaries exceed the buffer's boundaries (including the - lookback buffer on the left side). 
This comes in very handy, if users - don't want to worry about reaching such boundaries and want to zero-pad. - For example, a buffer with data [10, 11, 12, 13, 14] and lookback - buffer size of 2 (meaning `10` and `11` are part of the lookback buffer) - will respond to `indices=slice(-7, -2)` and `fill=0.0` - with `[0.0, 0.0, 10, 11, 12]`. - one_hot_discrete: If True, will return one-hot vectors (instead of - int-values) for those sub-components of a (possibly complex) space - that are Discrete or MultiDiscrete. Note that if `fill=0` and the - requested `indices` are out of the range of our data, the returned - one-hot vectors will actually be zero-hot (all slots zero). - extra_model_outputs_key: Only if what is "extra_model_outputs", this - specifies the sub-key (str) inside the extra_model_outputs dict, e.g. - STATE_OUT or ACTION_DIST_INPUTS. - - Returns: - A list of MultiAgentDict, where each item in the list corresponds to one - env timestep defined via `indices`. - """ - # Collect indices for each agent first, so we can construct the list in # the next step. agent_indices = {} @@ -2293,7 +2262,11 @@ def _get_data_by_env_steps_as_list( hanging_val, filter_for_skip_indices=idxes[i], ) - if what == "extra_model_outputs" and not inf_lookback_buffer: + if ( + what == "extra_model_outputs" + and not inf_lookback_buffer + and not hanging_val + ): continue agent_value = self._get_single_agent_data_by_index( what=what, @@ -2302,8 +2275,8 @@ def _get_data_by_env_steps_as_list( index_incl_lookback=indices_to_use, fill=fill, one_hot_discrete=one_hot_discrete, - hanging_val=hanging_val, extra_model_outputs_key=extra_model_outputs_key, + hanging_val=hanging_val, ) if agent_value is not None: ret2[agent_id] = agent_value @@ -2316,60 +2289,11 @@ def _get_data_by_env_steps( what: str, indices: Union[int, slice, List[int]], agent_ids: Collection[AgentID], - neg_index_as_lookback: bool = False, - fill: Optional[Any] = None, - one_hot_discrete: bool = False, - extra_model_outputs_key: Optional[str] = None, + neg_index_as_lookback: bool, + fill: Any, + one_hot_discrete: bool, + extra_model_outputs_key: str, ) -> MultiAgentDict: - """Returns data from the episode based on env step indices, as a MultiAgentDict. - - The returned dict maps AgentID keys to individual or batched values, where the - batch size matches the env timesteps defined via `indices`. - - Args: - what: A (str) descriptor of what data to collect. Must be one of - "observations", "infos", "actions", "rewards", or "extra_model_outputs". - indices: A single int is interpreted as an index, from which to return the - individual data stored at this (env step) index. - A list of ints is interpreted as a list of indices from which to gather - individual data in a batch of size len(indices). - A slice object is interpreted as a range of data to be returned. - Thereby, negative indices by default are interpreted as "before the end" - unless the `neg_index_as_lookback=True` option is used, in which case - negative indices are interpreted as "before ts=0", meaning going back - into the lookback buffer. - agent_ids: A collection of AgentIDs to filter for. Only data for those - agents will be returned, all other agents will be ignored. - neg_index_as_lookback: If True, negative values in `indices` are - interpreted as "before ts=0", meaning going back into the lookback - buffer. 
For example, a buffer with data [4, 5, 6, 7, 8, 9], - where [4, 5, 6] is the lookback buffer range (ts=0 item is 7), will - respond to `get(-1, neg_index_as_lookback=True)` with `6` and to - `get(slice(-2, 1), neg_index_as_lookback=True)` with `[5, 6, 7]`. - fill: An optional float value to use for filling up the returned results at - the boundaries. This filling only happens if the requested index range's - start/stop boundaries exceed the buffer's boundaries (including the - lookback buffer on the left side). This comes in very handy, if users - don't want to worry about reaching such boundaries and want to zero-pad. - For example, a buffer with data [10, 11, 12, 13, 14] and lookback - buffer size of 2 (meaning `10` and `11` are part of the lookback buffer) - will respond to `indices=slice(-7, -2)` and `fill=0.0` - with `[0.0, 0.0, 10, 11, 12]`. - one_hot_discrete: If True, will return one-hot vectors (instead of - int-values) for those sub-components of a (possibly complex) space - that are Discrete or MultiDiscrete. Note that if `fill=0` and the - requested `indices` are out of the range of our data, the returned - one-hot vectors will actually be zero-hot (all slots zero). - extra_model_outputs_key: Only if what is "extra_model_outputs", this - specifies the sub-key (str) inside the extra_model_outputs dict, e.g. - STATE_OUT or ACTION_DIST_INPUTS. - - Returns: - A single MultiAgentDict with individual leaf-values (in case `indices` is an - int), or batched leaf-data (in case `indices` is a list of ints or a slice - object). In the latter case, the batch size matches the env timesteps - defined via `indices`. - """ ignore_last_ts = what not in ["observations", "infos"] ret = {} for agent_id, sa_episode in self.agent_episodes.items(): @@ -2412,8 +2336,8 @@ def _get_data_by_env_steps( index_incl_lookback=agent_indices, fill=fill, one_hot_discrete=one_hot_discrete, - hanging_val=hanging_val, extra_model_outputs_key=extra_model_outputs_key, + hanging_val=hanging_val, ) if agent_values is not None: ret[agent_id] = agent_values @@ -2426,56 +2350,11 @@ def _get_single_agent_data_by_index( inf_lookback_buffer: InfiniteLookbackBuffer, agent_id: AgentID, index_incl_lookback: Union[int, str], - fill: Optional[Any] = None, - one_hot_discrete: bool = False, - extra_model_outputs_key: Optional[str] = None, - hanging_val: Optional[Any] = None, + fill: Any, + one_hot_discrete: dict, + extra_model_outputs_key: str, + hanging_val: Any, ) -> Any: - """Returns single data item from the episode based on given (env step) index. - - Args: - what: A (str) descriptor of what data to collect. Must be one of - "observations", "infos", "actions", "rewards", or "extra_model_outputs". - inf_lookback_buffer: The InfiniteLookbackBuffer to use for extracting the - data. - index_incl_lookback: An int specifying, which index to pull from the given - `inf_lookback_buffer`, but disregarding the special logic of the - lookback buffer. Meaning if the `index_incl_lookback` is 0, then the - first value in the lookback buffer should be returned, not the first - value after the lookback buffer (which would be normal behavior for - pulling items from an InfiniteLookbackBuffer object). - If the value is `self.SKIP_ENV_TS_TAG`, either None is returned (if - `fill` is None) or the provided `fill` value. - agent_id: The individual agent ID to pull data for. Used to lookup the - `SingleAgentEpisode` object for this agent in `self`. - fill: An optional float value to use for filling up the returned results at - the boundaries. 
This filling only happens if the requested index range's - start/stop boundaries exceed the buffer's boundaries (including the - lookback buffer on the left side). This comes in very handy, if users - don't want to worry about reaching such boundaries and want to zero-pad. - For example, a buffer with data [10, 11, 12, 13, 14] and lookback - buffer size of 2 (meaning `10` and `11` are part of the lookback buffer) - will respond to `index_incl_lookback=-6` and `fill=0.0` - with `0.0`. - one_hot_discrete: If True, will return one-hot vectors (instead of - int-values) for those sub-components of a (possibly complex) space - that are Discrete or MultiDiscrete. Note that if `fill=0` and the - requested `index_incl_lookback` is out of the range of our data, the - returned one-hot vectors will actually be zero-hot (all slots zero). - extra_model_outputs_key: Only if what is "extra_model_outputs", this - specifies the sub-key (str) inside the extra_model_outputs dict, e.g. - STATE_OUT or ACTION_DIST_INPUTS. - hanging_val: In case we are pulling actions, rewards, or extra_model_outputs - data, there might be information "hanging" (cached). For example, - if an agent receives an observation o0 and then immediately sends an - action a0 back, but then does NOT immediately reveive a next - observation, a0 is now cached (not fully logged yet with this - episode). The currently cached value must be provided here to be able - to return it in case the index is -1 (most recent timestep). - - Returns: - A data item corresponding to the provided args. - """ sa_episode = self.agent_episodes[agent_id] if index_incl_lookback == self.SKIP_ENV_TS_TAG: @@ -2494,29 +2373,40 @@ def _get_single_agent_data_by_index( ), **one_hot_discrete, ) + # No skip timestep -> Provide value at given index for this agent. - else: - if what == "extra_model_outputs": - # Special case: extra_model_outputs and key=None (return all keys as - # a dict). Note that `inf_lookback_buffer` is NOT an infinite lookback - # buffer, but a dict mapping keys to individual infinite lookback - # buffers. - if extra_model_outputs_key is None: - assert hanging_val is None or isinstance(hanging_val, dict) - return { - key: sub_buffer.get( - indices=index_incl_lookback - sub_buffer.lookback, - neg_index_as_lookback=True, - fill=fill, - _add_last_ts_value=( - None if hanging_val is None else hanging_val[key] - ), - **one_hot_discrete, - ) - for key, sub_buffer in inf_lookback_buffer.items() - } - # Extract data directly from the infinite lookback buffer object. + # Special case: extra_model_outputs and key=None (return all keys as + # a dict). Note that `inf_lookback_buffer` is NOT an infinite lookback + # buffer, but a dict mapping keys to individual infinite lookback + # buffers. + elif what == "extra_model_outputs" and extra_model_outputs_key is None: + assert hanging_val is None or isinstance(hanging_val, dict) + ret = {} + if inf_lookback_buffer: + for key, sub_buffer in inf_lookback_buffer.items(): + ret[key] = sub_buffer.get( + indices=index_incl_lookback - sub_buffer.lookback, + neg_index_as_lookback=True, + fill=fill, + _add_last_ts_value=( + None if hanging_val is None else hanging_val[key] + ), + **one_hot_discrete, + ) + else: + for key in hanging_val.keys(): + ret[key] = InfiniteLookbackBuffer().get( + indices=index_incl_lookback, + neg_index_as_lookback=True, + fill=fill, + _add_last_ts_value=hanging_val[key], + **one_hot_discrete, + ) + return ret + + # Extract data directly from the infinite lookback buffer object. 
+ else: return inf_lookback_buffer.get( indices=index_incl_lookback - inf_lookback_buffer.lookback, neg_index_as_lookback=True, diff --git a/rllib/examples/centralized_critic.py b/rllib/examples/centralized_critic.py index 01b274e92477..14380b789908 100644 --- a/rllib/examples/centralized_critic.py +++ b/rllib/examples/centralized_critic.py @@ -35,7 +35,7 @@ ) from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy from ray.rllib.evaluation.postprocessing import compute_advantages, Postprocessing -from ray.rllib.examples.envs.classes.two_step_game import TwoStepGame +from ray.rllib.examples.envs.classes.multi_agent.two_step_game import TwoStepGame from ray.rllib.examples._old_api_stack.models.centralized_critic_models import ( CentralizedCriticModel, TorchCentralizedCriticModel, diff --git a/rllib/examples/envs/agents_act_in_sequence.py b/rllib/examples/envs/agents_act_in_sequence.py new file mode 100644 index 000000000000..c2872a6e4aca --- /dev/null +++ b/rllib/examples/envs/agents_act_in_sequence.py @@ -0,0 +1,87 @@ +"""Example of running a multi-agent experiment w/ agents taking turns (sequence). + +This example: + - demonstrates how to write your own (multi-agent) environment using RLlib's + MultiAgentEnv API. + - shows how to implement the `reset()` and `step()` methods of the env such that + the agents act in a fixed sequence (taking turns). + - shows how to configure and setup this environment class within an RLlib + Algorithm config. + - runs the experiment with the configured algo, trying to solve the environment. + + +How to run this script +---------------------- +`python [script file name].py --enable-new-api-stack` + +For debugging, use the following additional command line options +`--no-tune --num-env-runners=0` +which should allow you to set breakpoints anywhere in the RLlib code and +have the execution stop there for inspection and debugging. + +For logging to your WandB account, use: +`--wandb-key=[your WandB API key] --wandb-project=[some project name] +--wandb-run-name=[optional: WandB run name (within the defined project)]` + + +Results to expect +----------------- +You should see results similar to the following in your console output: ++---------------------------+----------+--------+------------------+--------+ +| Trial name | status | iter | total time (s) | ts | +|---------------------------+----------+--------+------------------+--------+ +| PPO_TicTacToe_957aa_00000 | RUNNING | 25 | 96.7452 | 100000 | ++---------------------------+----------+--------+------------------+--------+ ++-------------------+------------------+------------------+ +| combined return | return player2 | return player1 | +|-------------------+------------------+------------------| +| -2 | 1.15 | -0.85 | ++-------------------+------------------+------------------+ + +Note that even though we are playing a zero-sum game, the overall return should start +at some negative values due to the misplacement penalty of our (simplified) TicTacToe +game. 
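For orientation, the following is a minimal sketch of the turn-taking pattern this
example relies on (class and agent names are made up for illustration; the actual env
used below is the TicTacToe class from
rllib/examples/envs/classes/multi_agent/tic_tac_toe.py). The key idea: only the agent
whose turn it is appears in the returned observation dict, so only that agent is asked
for an action in the next `step()` call:

    import gymnasium as gym
    from ray.rllib.env.multi_agent_env import MultiAgentEnv

    class TwoPlayerTurnBasedSketch(MultiAgentEnv):
        def __init__(self, config=None):
            super().__init__()
            self.agents = self.possible_agents = ["player1", "player2"]
            # Dummy per-agent spaces; a real game would encode the board here.
            self.observation_spaces = {a: gym.spaces.Discrete(2) for a in self.agents}
            self.action_spaces = {a: gym.spaces.Discrete(2) for a in self.agents}

        def reset(self, *, seed=None, options=None):
            self._turn = "player1"
            self._num_moves = 0
            # Only the starting player gets an obs -> only it acts next.
            return {self._turn: 0}, {}

        def step(self, action_dict):
            # `action_dict` contains exactly one entry: the acting player's move.
            acting = self._turn
            self._num_moves += 1
            # Hand the turn to the other player.
            self._turn = "player2" if acting == "player1" else "player1"
            obs = {self._turn: 0}
            # The (dummy) reward goes to the player that just acted; RLlib caches
            # it until that player's next observation arrives.
            rewards = {acting: 0.0}
            terminateds = {"__all__": self._num_moves >= 10}
            return obs, rewards, terminateds, {"__all__": False}, {}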
+""" +from ray.rllib.examples.envs.classes.multi_agent.tic_tac_toe import TicTacToe +from ray.rllib.utils.test_utils import ( + add_rllib_example_script_args, + run_rllib_example_script_experiment, +) +from ray.tune.registry import get_trainable_cls, register_env # noqa + + +parser = add_rllib_example_script_args( + default_reward=-4.0, default_iters=50, default_timesteps=100000 +) +parser.set_defaults( + enable_new_api_stack=True, + num_agents=2, +) + + +if __name__ == "__main__": + args = parser.parse_args() + + assert args.num_agents == 2, "Must set --num-agents=2 when running this script!" + + # You can also register the env creator function explicitly with: + # register_env("tic_tac_toe", lambda cfg: TicTacToe()) + + # Or allow the RLlib user to set more c'tor options via their algo config: + # config.environment(env_config={[c'tor arg name]: [value]}) + # register_env("tic_tac_toe", lambda cfg: TicTacToe(cfg)) + + base_config = ( + get_trainable_cls(args.algo) + .get_default_config() + .environment(TicTacToe) + .multi_agent( + # Define two policies. + policies={"player1", "player2"}, + # Map agent "player1" to policy "player1" and agent "player2" to policy + # "player2". + policy_mapping_fn=lambda agent_id, episode, **kw: agent_id, + ) + ) + + run_rllib_example_script_experiment(base_config, args) diff --git a/rllib/examples/envs/agents_act_simultaneously.py b/rllib/examples/envs/agents_act_simultaneously.py new file mode 100644 index 000000000000..1e9ce55ce29b --- /dev/null +++ b/rllib/examples/envs/agents_act_simultaneously.py @@ -0,0 +1,108 @@ +"""Example of running a multi-agent experiment w/ agents always acting simultaneously. + +This example: + - demonstrates how to write your own (multi-agent) environment using RLlib's + MultiAgentEnv API. + - shows how to implement the `reset()` and `step()` methods of the env such that + the agents act simultaneously. + - shows how to configure and setup this environment class within an RLlib + Algorithm config. + - runs the experiment with the configured algo, trying to solve the environment. + + +How to run this script +---------------------- +`python [script file name].py --enable-new-api-stack --sheldon-cooper-mode` + +For debugging, use the following additional command line options +`--no-tune --num-env-runners=0` +which should allow you to set breakpoints anywhere in the RLlib code and +have the execution stop there for inspection and debugging. + +For logging to your WandB account, use: +`--wandb-key=[your WandB API key] --wandb-project=[some project name] +--wandb-run-name=[optional: WandB run name (within the defined project)]` + + +Results to expect +----------------- +You should see results similar to the following in your console output: + ++-----------------------------------+----------+--------+------------------+-------+ +| Trial name | status | iter | total time (s) | ts | +|-----------------------------------+----------+--------+------------------+-------+ +| PPO_RockPaperScissors_8cef7_00000 | RUNNING | 3 | 16.5348 | 12000 | ++-----------------------------------+----------+--------+------------------+-------+ ++-------------------+------------------+------------------+ +| combined return | return player2 | return player1 | +|-------------------+------------------+------------------| +| 0 | -0.15 | 0.15 | ++-------------------+------------------+------------------+ + +Note that b/c we are playing a zero-sum game, the overall return remains 0.0 at +all times. 
+""" +from ray.rllib.examples.envs.classes.multi_agent.rock_paper_scissors import ( + RockPaperScissors, +) +from ray.rllib.connectors.env_to_module.flatten_observations import FlattenObservations +from ray.rllib.utils.test_utils import ( + add_rllib_example_script_args, + run_rllib_example_script_experiment, +) +from ray.tune.registry import get_trainable_cls, register_env # noqa + + +parser = add_rllib_example_script_args( + default_reward=0.9, default_iters=50, default_timesteps=100000 +) +parser.set_defaults( + enable_new_api_stack=True, + num_agents=2, +) +parser.add_argument( + "--sheldon-cooper-mode", + action="store_true", + help="Whether to add two more actions to the game: Lizard and Spock. " + "Watch here for more details :) https://www.youtube.com/watch?v=x5Q6-wMx-K8", +) + + +if __name__ == "__main__": + args = parser.parse_args() + + assert args.num_agents == 2, "Must set --num-agents=2 when running this script!" + + # You can also register the env creator function explicitly with: + # register_env("env", lambda cfg: RockPaperScissors({"sheldon_cooper_mode": False})) + + # Or you can hard code certain settings into the Env's constructor (`config`). + # register_env( + # "rock-paper-scissors-w-sheldon-mode-activated", + # lambda config: RockPaperScissors({**config, **{"sheldon_cooper_mode": True}}), + # ) + + # Or allow the RLlib user to set more c'tor options via their algo config: + # config.environment(env_config={[c'tor arg name]: [value]}) + # register_env("rock-paper-scissors", lambda cfg: RockPaperScissors(cfg)) + + base_config = ( + get_trainable_cls(args.algo) + .get_default_config() + .environment( + RockPaperScissors, + env_config={"sheldon_cooper_mode": args.sheldon_cooper_mode}, + ) + .env_runners( + env_to_module_connector=lambda env: FlattenObservations(multi_agent=True), + ) + .multi_agent( + # Define two policies. + policies={"player1", "player2"}, + # Map agent "player1" to policy "player1" and agent "player2" to policy + # "player2". + policy_mapping_fn=lambda agent_id, episode, **kw: agent_id, + ) + ) + + run_rllib_example_script_experiment(base_config, args) diff --git a/rllib/examples/envs/classes/coin_game_non_vectorized_env.py b/rllib/examples/envs/classes/coin_game_non_vectorized_env.py deleted file mode 100644 index b94b984c85a4..000000000000 --- a/rllib/examples/envs/classes/coin_game_non_vectorized_env.py +++ /dev/null @@ -1,344 +0,0 @@ -########## -# Contribution by the Center on Long-Term Risk: -# https://github.com/longtermrisk/marltoolbox -########## - -import copy - -try: - # This works in Python<3.9 - from collections import Iterable -except ImportError: - # This works in Python>=3.9 - from collections.abc import Iterable - -import gymnasium as gym -import logging -import numpy as np -from gymnasium.spaces import Discrete -from gymnasium.utils import seeding -from ray.rllib.env.multi_agent_env import MultiAgentEnv -from ray.rllib.utils import override -from typing import Dict, Optional - -from ray.rllib.examples.envs.classes.utils.interfaces import InfoAccumulationInterface - -logger = logging.getLogger(__name__) - - -class CoinGame(InfoAccumulationInterface, MultiAgentEnv, gym.Env): - """ - Coin Game environment. 
- """ - - NAME = "CoinGame" - NUM_AGENTS = 2 - NUM_ACTIONS = 4 - action_space = Discrete(NUM_ACTIONS) - observation_space = None - MOVES = [ - np.array([0, 1]), - np.array([0, -1]), - np.array([1, 0]), - np.array([-1, 0]), - ] - - def __init__(self, config: Optional[Dict] = None): - if config is None: - config = {} - - self._validate_config(config) - - self._load_config(config) - self.player_red_id, self.player_blue_id = self.players_ids - self.n_features = self.grid_size**2 * (2 * self.NUM_AGENTS) - self.observation_space = gym.spaces.Box( - low=0, high=1, shape=(self.grid_size, self.grid_size, 4), dtype="uint8" - ) - - self.step_count_in_current_episode = None - if self.output_additional_info: - self._init_info() - - def _validate_config(self, config): - if "players_ids" in config: - assert isinstance(config["players_ids"], Iterable) - assert len(config["players_ids"]) == self.NUM_AGENTS - - def _load_config(self, config): - self.players_ids = config.get("players_ids", ["player_red", "player_blue"]) - self.max_steps = config.get("max_steps", 20) - self.grid_size = config.get("grid_size", 3) - self.output_additional_info = config.get("output_additional_info", True) - self.asymmetric = config.get("asymmetric", False) - self.both_players_can_pick_the_same_coin = config.get( - "both_players_can_pick_the_same_coin", True - ) - - @override(gym.Env) - def reset(self, *, seed=None, options=None): - self.np_random, seed = seeding.np_random(seed) - - self.step_count_in_current_episode = 0 - - if self.output_additional_info: - self._reset_info() - - self._randomize_color_and_player_positions() - self._generate_coin() - obs = self._generate_observation() - - return {self.player_red_id: obs[0], self.player_blue_id: obs[1]}, {} - - def _randomize_color_and_player_positions(self): - # Reset coin color and the players and coin positions - self.red_coin = self.np_random.integers(low=0, high=2) - self.red_pos = self.np_random.integers(low=0, high=self.grid_size, size=(2,)) - self.blue_pos = self.np_random.integers(low=0, high=self.grid_size, size=(2,)) - self.coin_pos = np.zeros(shape=(2,), dtype=np.int8) - - self._players_do_not_overlap_at_start() - - def _players_do_not_overlap_at_start(self): - while self._same_pos(self.red_pos, self.blue_pos): - self.blue_pos = self.np_random.integers(self.grid_size, size=2) - - def _generate_coin(self): - self._switch_between_coin_color_at_each_generation() - self._coin_position_different_from_players_positions() - - def _switch_between_coin_color_at_each_generation(self): - self.red_coin = 1 - self.red_coin - - def _coin_position_different_from_players_positions(self): - success = 0 - while success < self.NUM_AGENTS: - self.coin_pos = self.np_random.integers(self.grid_size, size=2) - success = 1 - self._same_pos(self.red_pos, self.coin_pos) - success += 1 - self._same_pos(self.blue_pos, self.coin_pos) - - def _generate_observation(self): - obs = np.zeros((self.grid_size, self.grid_size, 4)) - obs[self.red_pos[0], self.red_pos[1], 0] = 1 - obs[self.blue_pos[0], self.blue_pos[1], 1] = 1 - if self.red_coin: - obs[self.coin_pos[0], self.coin_pos[1], 2] = 1 - else: - obs[self.coin_pos[0], self.coin_pos[1], 3] = 1 - - obs = self._get_obs_invariant_to_the_player_trained(obs) - - return obs - - @override(gym.Env) - def step(self, actions: Dict): - """ - :param actions: Dict containing both actions for player_1 and player_2 - :return: observations, rewards, done, info - """ - actions = self._from_RLlib_API_to_list(actions) - - self.step_count_in_current_episode += 1 - 
self._move_players(actions) - reward_list, generate_new_coin = self._compute_reward() - if generate_new_coin: - self._generate_coin() - observations = self._generate_observation() - - return self._to_RLlib_API(observations, reward_list) - - def _same_pos(self, x, y): - return (x == y).all() - - def _move_players(self, actions): - self.red_pos = (self.red_pos + self.MOVES[actions[0]]) % self.grid_size - self.blue_pos = (self.blue_pos + self.MOVES[actions[1]]) % self.grid_size - - def _compute_reward(self): - - reward_red = 0.0 - reward_blue = 0.0 - generate_new_coin = False - red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue = ( - False, - False, - False, - False, - ) - - red_first_if_both = None - if not self.both_players_can_pick_the_same_coin: - if self._same_pos(self.red_pos, self.coin_pos) and self._same_pos( - self.blue_pos, self.coin_pos - ): - red_first_if_both = bool(self.np_random.integers(low=0, high=2)) - - if self.red_coin: - if self._same_pos(self.red_pos, self.coin_pos) and ( - red_first_if_both is None or red_first_if_both - ): - generate_new_coin = True - reward_red += 1 - if self.asymmetric: - reward_red += 3 - red_pick_any = True - red_pick_red = True - if self._same_pos(self.blue_pos, self.coin_pos) and ( - red_first_if_both is None or not red_first_if_both - ): - generate_new_coin = True - reward_red += -2 - reward_blue += 1 - blue_pick_any = True - else: - if self._same_pos(self.red_pos, self.coin_pos) and ( - red_first_if_both is None or red_first_if_both - ): - generate_new_coin = True - reward_red += 1 - reward_blue += -2 - if self.asymmetric: - reward_red += 3 - red_pick_any = True - if self._same_pos(self.blue_pos, self.coin_pos) and ( - red_first_if_both is None or not red_first_if_both - ): - generate_new_coin = True - reward_blue += 1 - blue_pick_blue = True - blue_pick_any = True - - reward_list = [reward_red, reward_blue] - - if self.output_additional_info: - self._accumulate_info( - red_pick_any=red_pick_any, - red_pick_red=red_pick_red, - blue_pick_any=blue_pick_any, - blue_pick_blue=blue_pick_blue, - ) - - return reward_list, generate_new_coin - - def _from_RLlib_API_to_list(self, actions): - """ - Format actions from dict of players to list of lists - """ - actions = [actions[player_id] for player_id in self.players_ids] - return actions - - def _get_obs_invariant_to_the_player_trained(self, observation): - """ - We want to be able to use a policy trained as player 1, - for evaluation as player 2 and vice versa. 
- """ - - # player_red_observation contains - # [Red pos, Blue pos, Red coin pos, Blue coin pos] - player_red_observation = observation - # After modification, player_blue_observation will contain - # [Blue pos, Red pos, Blue coin pos, Red coin pos] - player_blue_observation = copy.deepcopy(observation) - player_blue_observation[..., 0] = observation[..., 1] - player_blue_observation[..., 1] = observation[..., 0] - player_blue_observation[..., 2] = observation[..., 3] - player_blue_observation[..., 3] = observation[..., 2] - - return [player_red_observation, player_blue_observation] - - def _to_RLlib_API(self, observations, rewards): - state = { - self.player_red_id: observations[0], - self.player_blue_id: observations[1], - } - rewards = { - self.player_red_id: rewards[0], - self.player_blue_id: rewards[1], - } - - epi_is_done = self.step_count_in_current_episode >= self.max_steps - if self.step_count_in_current_episode > self.max_steps: - logger.warning( - "step_count_in_current_episode > self.max_steps: " - f"{self.step_count_in_current_episode} > {self.max_steps}" - ) - - done = { - self.player_red_id: epi_is_done, - self.player_blue_id: epi_is_done, - "__all__": epi_is_done, - } - - if epi_is_done and self.output_additional_info: - player_red_info, player_blue_info = self._get_episode_info() - info = { - self.player_red_id: player_red_info, - self.player_blue_id: player_blue_info, - } - else: - info = {} - - return state, rewards, done, done, info - - @override(InfoAccumulationInterface) - def _get_episode_info(self): - """ - Output the following information: - pick_speed is the fraction of steps during which the player picked a - coin. - pick_own_color is the fraction of coins picked by the player which have - the same color as the player. - """ - player_red_info, player_blue_info = {}, {} - - if len(self.red_pick) > 0: - red_pick = sum(self.red_pick) - player_red_info["pick_speed"] = red_pick / len(self.red_pick) - if red_pick > 0: - player_red_info["pick_own_color"] = sum(self.red_pick_own) / red_pick - - if len(self.blue_pick) > 0: - blue_pick = sum(self.blue_pick) - player_blue_info["pick_speed"] = blue_pick / len(self.blue_pick) - if blue_pick > 0: - player_blue_info["pick_own_color"] = sum(self.blue_pick_own) / blue_pick - - return player_red_info, player_blue_info - - @override(InfoAccumulationInterface) - def _reset_info(self): - self.red_pick.clear() - self.red_pick_own.clear() - self.blue_pick.clear() - self.blue_pick_own.clear() - - @override(InfoAccumulationInterface) - def _accumulate_info( - self, red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue - ): - - self.red_pick.append(red_pick_any) - self.red_pick_own.append(red_pick_red) - self.blue_pick.append(blue_pick_any) - self.blue_pick_own.append(blue_pick_blue) - - @override(InfoAccumulationInterface) - def _init_info(self): - self.red_pick = [] - self.red_pick_own = [] - self.blue_pick = [] - self.blue_pick_own = [] - - -class AsymCoinGame(CoinGame): - NAME = "AsymCoinGame" - - def __init__(self, config: Optional[dict] = None): - if config is None: - config = {} - - if "asymmetric" in config: - assert config["asymmetric"] - else: - config["asymmetric"] = True - super().__init__(config) diff --git a/rllib/examples/envs/classes/coin_game_vectorized_env.py b/rllib/examples/envs/classes/coin_game_vectorized_env.py deleted file mode 100644 index d5b296c31268..000000000000 --- a/rllib/examples/envs/classes/coin_game_vectorized_env.py +++ /dev/null @@ -1,386 +0,0 @@ -########## -# Contribution by the Center on 
Long-Term Risk: -# https://github.com/longtermrisk/marltoolbox -# Some parts are originally from: -# https://github.com/julianstastny/openspiel-social-dilemmas/ -# blob/master/games/coin_game_gym.py -########## - -import copy -from collections import Iterable - -import numpy as np -from numba import jit, prange -from numba.typed import List -from ray.rllib.examples.envs.classes.coin_game_non_vectorized_env import CoinGame -from ray.rllib.utils import override - - -class VectorizedCoinGame(CoinGame): - """ - Vectorized Coin Game environment. - """ - - def __init__(self, config=None): - if config is None: - config = {} - - super().__init__(config) - - self.batch_size = config.get("batch_size", 1) - self.force_vectorized = config.get("force_vectorize", False) - assert self.grid_size == 3, "hardcoded in the generate_state function" - - @override(CoinGame) - def _randomize_color_and_player_positions(self): - # Reset coin color and the players and coin positions - self.red_coin = np.random.randint(2, size=self.batch_size) - self.red_pos = np.random.randint(self.grid_size, size=(self.batch_size, 2)) - self.blue_pos = np.random.randint(self.grid_size, size=(self.batch_size, 2)) - self.coin_pos = np.zeros((self.batch_size, 2), dtype=np.int8) - - self._players_do_not_overlap_at_start() - - @override(CoinGame) - def _players_do_not_overlap_at_start(self): - for i in range(self.batch_size): - while _same_pos(self.red_pos[i], self.blue_pos[i]): - self.blue_pos[i] = np.random.randint(self.grid_size, size=2) - - @override(CoinGame) - def _generate_coin(self): - generate = np.ones(self.batch_size, dtype=bool) - self.coin_pos = generate_coin( - self.batch_size, - generate, - self.red_coin, - self.red_pos, - self.blue_pos, - self.coin_pos, - self.grid_size, - ) - - @override(CoinGame) - def _generate_observation(self): - obs = generate_observations_wt_numba_optimization( - self.batch_size, - self.red_pos, - self.blue_pos, - self.coin_pos, - self.red_coin, - self.grid_size, - ) - - obs = self._get_obs_invariant_to_the_player_trained(obs) - obs, _ = self._optional_unvectorize(obs) - return obs - - def _optional_unvectorize(self, obs, rewards=None): - if self.batch_size == 1 and not self.force_vectorized: - obs = [one_obs[0, ...] 
for one_obs in obs] - if rewards is not None: - rewards[0], rewards[1] = rewards[0][0], rewards[1][0] - return obs, rewards - - @override(CoinGame) - def step(self, actions: Iterable): - - actions = self._from_RLlib_API_to_list(actions) - self.step_count_in_current_episode += 1 - - ( - self.red_pos, - self.blue_pos, - rewards, - self.coin_pos, - observation, - self.red_coin, - red_pick_any, - red_pick_red, - blue_pick_any, - blue_pick_blue, - ) = vectorized_step_wt_numba_optimization( - actions, - self.batch_size, - self.red_pos, - self.blue_pos, - self.coin_pos, - self.red_coin, - self.grid_size, - self.asymmetric, - self.max_steps, - self.both_players_can_pick_the_same_coin, - ) - - if self.output_additional_info: - self._accumulate_info( - red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue - ) - - obs = self._get_obs_invariant_to_the_player_trained(observation) - - obs, rewards = self._optional_unvectorize(obs, rewards) - - return self._to_RLlib_API(obs, rewards) - - @override(CoinGame) - def _get_episode_info(self): - - player_red_info, player_blue_info = {}, {} - - if len(self.red_pick) > 0: - red_pick = sum(self.red_pick) - player_red_info["pick_speed"] = red_pick / ( - len(self.red_pick) * self.batch_size - ) - if red_pick > 0: - player_red_info["pick_own_color"] = sum(self.red_pick_own) / red_pick - - if len(self.blue_pick) > 0: - blue_pick = sum(self.blue_pick) - player_blue_info["pick_speed"] = blue_pick / ( - len(self.blue_pick) * self.batch_size - ) - if blue_pick > 0: - player_blue_info["pick_own_color"] = sum(self.blue_pick_own) / blue_pick - - return player_red_info, player_blue_info - - @override(CoinGame) - def _from_RLlib_API_to_list(self, actions): - - ac_red = actions[self.player_red_id] - ac_blue = actions[self.player_blue_id] - if not isinstance(ac_red, Iterable): - assert not isinstance(ac_blue, Iterable) - ac_red, ac_blue = [ac_red], [ac_blue] - actions = [ac_red, ac_blue] - actions = np.array(actions).T - return actions - - def _save_env(self): - env_save_state = { - "red_pos": self.red_pos, - "blue_pos": self.blue_pos, - "coin_pos": self.coin_pos, - "red_coin": self.red_coin, - "grid_size": self.grid_size, - "asymmetric": self.asymmetric, - "batch_size": self.batch_size, - "step_count_in_current_episode": self.step_count_in_current_episode, - "max_steps": self.max_steps, - "red_pick": self.red_pick, - "red_pick_own": self.red_pick_own, - "blue_pick": self.blue_pick, - "blue_pick_own": self.blue_pick_own, - "both_players_can_pick_the_same_coin": self.both_players_can_pick_the_same_coin, # noqa: E501 - } - return copy.deepcopy(env_save_state) - - def _load_env(self, env_state): - for k, v in env_state.items(): - self.__setattr__(k, v) - - -class AsymVectorizedCoinGame(VectorizedCoinGame): - NAME = "AsymCoinGame" - - def __init__(self, config=None): - if config is None: - config = {} - - if "asymmetric" in config: - assert config["asymmetric"] - else: - config["asymmetric"] = True - super().__init__(config) - - -@jit(nopython=True) -def move_players(batch_size, actions, red_pos, blue_pos, grid_size): - moves = List( - [ - np.array([0, 1]), - np.array([0, -1]), - np.array([1, 0]), - np.array([-1, 0]), - ] - ) - - for j in prange(batch_size): - red_pos[j] = (red_pos[j] + moves[actions[j, 0]]) % grid_size - blue_pos[j] = (blue_pos[j] + moves[actions[j, 1]]) % grid_size - return red_pos, blue_pos - - -@jit(nopython=True) -def compute_reward( - batch_size, - red_pos, - blue_pos, - coin_pos, - red_coin, - asymmetric, - both_players_can_pick_the_same_coin, -): - 
reward_red = np.zeros(batch_size) - reward_blue = np.zeros(batch_size) - generate = np.zeros(batch_size, dtype=np.bool_) - red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue = 0, 0, 0, 0 - - for i in prange(batch_size): - red_first_if_both = None - if not both_players_can_pick_the_same_coin: - if _same_pos(red_pos[i], coin_pos[i]) and _same_pos( - blue_pos[i], coin_pos[i] - ): - red_first_if_both = bool(np.random.randint(0, 1)) - - if red_coin[i]: - if _same_pos(red_pos[i], coin_pos[i]) and ( - red_first_if_both is None or red_first_if_both - ): - generate[i] = True - reward_red[i] += 1 - if asymmetric: - reward_red[i] += 3 - red_pick_any += 1 - red_pick_red += 1 - if _same_pos(blue_pos[i], coin_pos[i]) and ( - red_first_if_both is None or not red_first_if_both - ): - generate[i] = True - reward_red[i] += -2 - reward_blue[i] += 1 - blue_pick_any += 1 - else: - if _same_pos(red_pos[i], coin_pos[i]) and ( - red_first_if_both is None or red_first_if_both - ): - generate[i] = True - reward_red[i] += 1 - reward_blue[i] += -2 - if asymmetric: - reward_red[i] += 3 - red_pick_any += 1 - if _same_pos(blue_pos[i], coin_pos[i]) and ( - red_first_if_both is None or not red_first_if_both - ): - generate[i] = True - reward_blue[i] += 1 - blue_pick_any += 1 - blue_pick_blue += 1 - reward = [reward_red, reward_blue] - - return reward, generate, red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue - - -@jit(nopython=True) -def _same_pos(x, y): - return (x == y).all() - - -@jit(nopython=True) -def _flatten_index(pos, grid_size): - y_pos, x_pos = pos - idx = grid_size * y_pos - idx += x_pos - return idx - - -@jit(nopython=True) -def _unflatten_index(pos, grid_size): - x_idx = pos % grid_size - y_idx = pos // grid_size - return np.array([y_idx, x_idx]) - - -@jit(nopython=True) -def generate_coin( - batch_size, generate, red_coin, red_pos, blue_pos, coin_pos, grid_size -): - red_coin[generate] = 1 - red_coin[generate] - for i in prange(batch_size): - if generate[i]: - coin_pos[i] = place_coin(red_pos[i], blue_pos[i], grid_size) - return coin_pos - - -@jit(nopython=True) -def place_coin(red_pos_i, blue_pos_i, grid_size): - red_pos_flat = _flatten_index(red_pos_i, grid_size) - blue_pos_flat = _flatten_index(blue_pos_i, grid_size) - possible_coin_pos = np.array( - [x for x in range(9) if ((x != blue_pos_flat) and (x != red_pos_flat))] - ) - flat_coin_pos = np.random.choice(possible_coin_pos) - return _unflatten_index(flat_coin_pos, grid_size) - - -@jit(nopython=True) -def generate_observations_wt_numba_optimization( - batch_size, red_pos, blue_pos, coin_pos, red_coin, grid_size -): - obs = np.zeros((batch_size, grid_size, grid_size, 4)) - for i in prange(batch_size): - obs[i, red_pos[i][0], red_pos[i][1], 0] = 1 - obs[i, blue_pos[i][0], blue_pos[i][1], 1] = 1 - if red_coin[i]: - obs[i, coin_pos[i][0], coin_pos[i][1], 2] = 1 - else: - obs[i, coin_pos[i][0], coin_pos[i][1], 3] = 1 - return obs - - -@jit(nopython=True) -def vectorized_step_wt_numba_optimization( - actions, - batch_size, - red_pos, - blue_pos, - coin_pos, - red_coin, - grid_size: int, - asymmetric: bool, - max_steps: int, - both_players_can_pick_the_same_coin: bool, -): - red_pos, blue_pos = move_players(batch_size, actions, red_pos, blue_pos, grid_size) - - ( - reward, - generate, - red_pick_any, - red_pick_red, - blue_pick_any, - blue_pick_blue, - ) = compute_reward( - batch_size, - red_pos, - blue_pos, - coin_pos, - red_coin, - asymmetric, - both_players_can_pick_the_same_coin, - ) - - coin_pos = generate_coin( - batch_size, generate, 
red_coin, red_pos, blue_pos, coin_pos, grid_size - ) - - obs = generate_observations_wt_numba_optimization( - batch_size, red_pos, blue_pos, coin_pos, red_coin, grid_size - ) - - return ( - red_pos, - blue_pos, - reward, - coin_pos, - obs, - red_coin, - red_pick_any, - red_pick_red, - blue_pick_any, - blue_pick_blue, - ) diff --git a/rllib/examples/envs/classes/matrix_sequential_social_dilemma.py b/rllib/examples/envs/classes/matrix_sequential_social_dilemma.py deleted file mode 100644 index bcf7500b9005..000000000000 --- a/rllib/examples/envs/classes/matrix_sequential_social_dilemma.py +++ /dev/null @@ -1,314 +0,0 @@ -########## -# Contribution by the Center on Long-Term Risk: -# https://github.com/longtermrisk/marltoolbox -# Some parts are originally from: -# https://github.com/alshedivat/lola/tree/master/lola -########## - -import logging -from abc import ABC -from collections import Iterable -from typing import Dict, Optional - -import numpy as np -from gymnasium.spaces import Discrete -from gymnasium.utils import seeding -from ray.rllib.env.multi_agent_env import MultiAgentEnv -from ray.rllib.examples.envs.classes.utils.interfaces import InfoAccumulationInterface -from ray.rllib.examples.envs.classes.utils.mixins import ( - TwoPlayersTwoActionsInfoMixin, - NPlayersNDiscreteActionsInfoMixin, -) - -logger = logging.getLogger(__name__) - - -class MatrixSequentialSocialDilemma(InfoAccumulationInterface, MultiAgentEnv, ABC): - """ - A multi-agent abstract class for two player matrix games. - - PAYOUT_MATRIX: Numpy array. Along the dimension N, the action of the - Nth player change. The last dimension is used to select the player - whose reward you want to know. - - max_steps: number of step in one episode - - players_ids: list of the RLlib agent id of each player - - output_additional_info: ask the environment to aggregate information - about the last episode and output them as info at the end of the - episode. 
- """ - - def __init__(self, config: Optional[Dict] = None): - if config is None: - config = {} - - assert "reward_randomness" not in config.keys() - assert self.PAYOUT_MATRIX is not None - if "players_ids" in config: - assert ( - isinstance(config["players_ids"], Iterable) - and len(config["players_ids"]) == self.NUM_AGENTS - ) - - self.players_ids = config.get("players_ids", ["player_row", "player_col"]) - self.player_row_id, self.player_col_id = self.players_ids - self.max_steps = config.get("max_steps", 20) - self.output_additional_info = config.get("output_additional_info", True) - - self.step_count_in_current_episode = None - - # To store info about the fraction of each states - if self.output_additional_info: - self._init_info() - - def reset(self, *, seed=None, options=None): - self.np_random, seed = seeding.np_random(seed) - - self.step_count_in_current_episode = 0 - if self.output_additional_info: - self._reset_info() - return { - self.player_row_id: self.NUM_STATES - 1, - self.player_col_id: self.NUM_STATES - 1, - }, {} - - def step(self, actions: dict): - """ - :param actions: Dict containing both actions for player_1 and player_2 - :return: observations, rewards, done, info - """ - self.step_count_in_current_episode += 1 - action_player_row = actions[self.player_row_id] - action_player_col = actions[self.player_col_id] - - if self.output_additional_info: - self._accumulate_info(action_player_row, action_player_col) - - observations = self._produce_observations_invariant_to_the_player_trained( - action_player_row, action_player_col - ) - rewards = self._get_players_rewards(action_player_row, action_player_col) - epi_is_done = self.step_count_in_current_episode >= self.max_steps - if self.step_count_in_current_episode > self.max_steps: - logger.warning("self.step_count_in_current_episode >= self.max_steps") - info = self._get_info_for_current_epi(epi_is_done) - - return self._to_RLlib_API(observations, rewards, epi_is_done, info) - - def _produce_observations_invariant_to_the_player_trained( - self, action_player_0: int, action_player_1: int - ): - """ - We want to be able to use a policy trained as player 1 - for evaluation as player 2 and vice versa. - """ - return [ - action_player_0 * self.NUM_ACTIONS + action_player_1, - action_player_1 * self.NUM_ACTIONS + action_player_0, - ] - - def _get_players_rewards(self, action_player_0: int, action_player_1: int): - return [ - self.PAYOUT_MATRIX[action_player_0][action_player_1][0], - self.PAYOUT_MATRIX[action_player_0][action_player_1][1], - ] - - def _to_RLlib_API( - self, observations: list, rewards: list, epi_is_done: bool, info: dict - ): - - observations = { - self.player_row_id: observations[0], - self.player_col_id: observations[1], - } - - rewards = {self.player_row_id: rewards[0], self.player_col_id: rewards[1]} - - if info is None: - info = {} - else: - info = {self.player_row_id: info, self.player_col_id: info} - - done = { - self.player_row_id: epi_is_done, - self.player_col_id: epi_is_done, - "__all__": epi_is_done, - } - - return observations, rewards, done, done, info - - def _get_info_for_current_epi(self, epi_is_done): - if epi_is_done and self.output_additional_info: - info_for_current_epi = self._get_episode_info() - else: - info_for_current_epi = None - return info_for_current_epi - - def __str__(self): - return self.NAME - - -class IteratedMatchingPennies( - TwoPlayersTwoActionsInfoMixin, MatrixSequentialSocialDilemma -): - """ - A two-agent environment for the Matching Pennies game. 
- """ - - NUM_AGENTS = 2 - NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS**NUM_AGENTS + 1 - ACTION_SPACE = Discrete(NUM_ACTIONS) - OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array([[[+1, -1], [-1, +1]], [[-1, +1], [+1, -1]]]) - NAME = "IMP" - - -class IteratedPrisonersDilemma( - TwoPlayersTwoActionsInfoMixin, MatrixSequentialSocialDilemma -): - """ - A two-agent environment for the Prisoner's Dilemma game. - """ - - NUM_AGENTS = 2 - NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS**NUM_AGENTS + 1 - ACTION_SPACE = Discrete(NUM_ACTIONS) - OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array([[[-1, -1], [-3, +0]], [[+0, -3], [-2, -2]]]) - NAME = "IPD" - - -class IteratedAsymPrisonersDilemma( - TwoPlayersTwoActionsInfoMixin, MatrixSequentialSocialDilemma -): - """ - A two-agent environment for the Asymmetric Prisoner's Dilemma game. - """ - - NUM_AGENTS = 2 - NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS**NUM_AGENTS + 1 - ACTION_SPACE = Discrete(NUM_ACTIONS) - OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array([[[+0, -1], [-3, +0]], [[+0, -3], [-2, -2]]]) - NAME = "IPD" - - -class IteratedStagHunt(TwoPlayersTwoActionsInfoMixin, MatrixSequentialSocialDilemma): - """ - A two-agent environment for the Stag Hunt game. - """ - - NUM_AGENTS = 2 - NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS**NUM_AGENTS + 1 - ACTION_SPACE = Discrete(NUM_ACTIONS) - OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array([[[3, 3], [0, 2]], [[2, 0], [1, 1]]]) - NAME = "IteratedStagHunt" - - -class IteratedChicken(TwoPlayersTwoActionsInfoMixin, MatrixSequentialSocialDilemma): - """ - A two-agent environment for the Chicken game. - """ - - NUM_AGENTS = 2 - NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS**NUM_AGENTS + 1 - ACTION_SPACE = Discrete(NUM_ACTIONS) - OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array([[[+0, +0], [-1.0, +1.0]], [[+1, -1], [-10, -10]]]) - NAME = "IteratedChicken" - - -class IteratedAsymChicken(TwoPlayersTwoActionsInfoMixin, MatrixSequentialSocialDilemma): - """ - A two-agent environment for the Asymmetric Chicken game. - """ - - NUM_AGENTS = 2 - NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS**NUM_AGENTS + 1 - ACTION_SPACE = Discrete(NUM_ACTIONS) - OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array([[[+2.0, +0], [-1.0, +1.0]], [[+2.5, -1], [-10, -10]]]) - NAME = "AsymmetricIteratedChicken" - - -class IteratedBoS(TwoPlayersTwoActionsInfoMixin, MatrixSequentialSocialDilemma): - """ - A two-agent environment for the BoS game. - """ - - NUM_AGENTS = 2 - NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS**NUM_AGENTS + 1 - ACTION_SPACE = Discrete(NUM_ACTIONS) - OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array( - [[[+3.0, +2.0], [+0.0, +0.0]], [[+0.0, +0.0], [+2.0, +3.0]]] - ) - NAME = "IteratedBoS" - - -class IteratedAsymBoS(TwoPlayersTwoActionsInfoMixin, MatrixSequentialSocialDilemma): - """ - A two-agent environment for the BoS game. 
- """ - - NUM_AGENTS = 2 - NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS**NUM_AGENTS + 1 - ACTION_SPACE = Discrete(NUM_ACTIONS) - OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array( - [[[+4.0, +1.0], [+0.0, +0.0]], [[+0.0, +0.0], [+2.0, +2.0]]] - ) - NAME = "AsymmetricIteratedBoS" - - -def define_greed_fear_matrix_game(greed, fear): - class GreedFearGame(TwoPlayersTwoActionsInfoMixin, MatrixSequentialSocialDilemma): - NUM_AGENTS = 2 - NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS**NUM_AGENTS + 1 - ACTION_SPACE = Discrete(NUM_ACTIONS) - OBSERVATION_SPACE = Discrete(NUM_STATES) - R = 3 - P = 1 - T = R + greed - S = P - fear - PAYOUT_MATRIX = np.array([[[R, R], [S, T]], [[T, S], [P, P]]]) - NAME = "IteratedGreedFear" - - def __str__(self): - return f"{self.NAME} with greed={greed} and fear={fear}" - - return GreedFearGame - - -class IteratedBoSAndPD( - NPlayersNDiscreteActionsInfoMixin, MatrixSequentialSocialDilemma -): - """ - A two-agent environment for the BOTS + PD game. - """ - - NUM_AGENTS = 2 - NUM_ACTIONS = 3 - NUM_STATES = NUM_ACTIONS**NUM_AGENTS + 1 - ACTION_SPACE = Discrete(NUM_ACTIONS) - OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array( - [ - [[3.5, +1], [+0, +0], [-3, +2]], - [[+0.0, +0], [+1, +3], [-3, +2]], - [[+2.0, -3], [+2, -3], [-1, -1]], - ] - ) - NAME = "IteratedBoSAndPD" diff --git a/rllib/examples/envs/classes/multi_agent/__init__.py b/rllib/examples/envs/classes/multi_agent/__init__.py new file mode 100644 index 000000000000..b7fb660ccd46 --- /dev/null +++ b/rllib/examples/envs/classes/multi_agent/__init__.py @@ -0,0 +1,35 @@ +from ray.rllib.env.multi_agent_env import make_multi_agent +from ray.rllib.examples.envs.classes.cartpole_with_dict_observation_space import ( + CartPoleWithDictObservationSpace, +) +from ray.rllib.examples.envs.classes.multi_agent.guess_the_number_game import ( + GuessTheNumberGame, +) +from ray.rllib.examples.envs.classes.multi_agent.two_step_game import ( + TwoStepGame, + TwoStepGameWithGroupedAgents, +) +from ray.rllib.examples.envs.classes.nested_space_repeat_after_me_env import ( + NestedSpaceRepeatAfterMeEnv, +) +from ray.rllib.examples.envs.classes.stateless_cartpole import StatelessCartPole + + +# Backward compatibility. 
+__all__ = [ + "GuessTheNumberGame", + "TwoStepGame", + "TwoStepGameWithGroupedAgents", +] + + +MultiAgentCartPole = make_multi_agent("CartPole-v1") +MultiAgentMountainCar = make_multi_agent("MountainCarContinuous-v0") +MultiAgentPendulum = make_multi_agent("Pendulum-v1") +MultiAgentStatelessCartPole = make_multi_agent(lambda config: StatelessCartPole(config)) +MultiAgentCartPoleWithDictObservationSpace = make_multi_agent( + lambda config: CartPoleWithDictObservationSpace(config) +) +MultiAgentNestedSpaceRepeatAfterMeEnv = make_multi_agent( + lambda config: NestedSpaceRepeatAfterMeEnv(config) +) diff --git a/rllib/examples/envs/classes/bandit_envs_discrete.py b/rllib/examples/envs/classes/multi_agent/bandit_envs_discrete.py similarity index 100% rename from rllib/examples/envs/classes/bandit_envs_discrete.py rename to rllib/examples/envs/classes/multi_agent/bandit_envs_discrete.py diff --git a/rllib/examples/envs/classes/bandit_envs_recommender_system.py b/rllib/examples/envs/classes/multi_agent/bandit_envs_recommender_system.py similarity index 100% rename from rllib/examples/envs/classes/bandit_envs_recommender_system.py rename to rllib/examples/envs/classes/multi_agent/bandit_envs_recommender_system.py diff --git a/rllib/examples/envs/classes/multi_agent.py b/rllib/examples/envs/classes/multi_agent/guess_the_number_game.py similarity index 78% rename from rllib/examples/envs/classes/multi_agent.py rename to rllib/examples/envs/classes/multi_agent/guess_the_number_game.py index b0b8b588c30c..eaac3e4becc5 100644 --- a/rllib/examples/envs/classes/multi_agent.py +++ b/rllib/examples/envs/classes/multi_agent/guess_the_number_game.py @@ -1,25 +1,6 @@ import gymnasium as gym -from ray.rllib.env.multi_agent_env import MultiAgentEnv, make_multi_agent -from ray.rllib.examples.envs.classes.cartpole_with_dict_observation_space import ( - CartPoleWithDictObservationSpace, -) -from ray.rllib.examples.envs.classes.nested_space_repeat_after_me_env import ( - NestedSpaceRepeatAfterMeEnv, -) -from ray.rllib.examples.envs.classes.stateless_cartpole import StatelessCartPole - - -MultiAgentCartPole = make_multi_agent("CartPole-v1") -MultiAgentMountainCar = make_multi_agent("MountainCarContinuous-v0") -MultiAgentPendulum = make_multi_agent("Pendulum-v1") -MultiAgentStatelessCartPole = make_multi_agent(lambda config: StatelessCartPole(config)) -MultiAgentCartPoleWithDictObservationSpace = make_multi_agent( - lambda config: CartPoleWithDictObservationSpace(config) -) -MultiAgentNestedSpaceRepeatAfterMeEnv = make_multi_agent( - lambda config: NestedSpaceRepeatAfterMeEnv(config) -) +from ray.rllib.env.multi_agent_env import MultiAgentEnv class GuessTheNumberGame(MultiAgentEnv): @@ -40,7 +21,7 @@ class GuessTheNumberGame(MultiAgentEnv): MAX_NUMBER = 3 MAX_STEPS = 20 - def __init__(self, config): + def __init__(self, config=None): super().__init__() self._agent_ids = {0, 1} diff --git a/rllib/examples/envs/classes/pettingzoo_chess.py b/rllib/examples/envs/classes/multi_agent/pettingzoo_chess.py similarity index 100% rename from rllib/examples/envs/classes/pettingzoo_chess.py rename to rllib/examples/envs/classes/multi_agent/pettingzoo_chess.py diff --git a/rllib/examples/envs/classes/pettingzoo_connect4.py b/rllib/examples/envs/classes/multi_agent/pettingzoo_connect4.py similarity index 100% rename from rllib/examples/envs/classes/pettingzoo_connect4.py rename to rllib/examples/envs/classes/multi_agent/pettingzoo_connect4.py diff --git a/rllib/examples/envs/classes/multi_agent/rock_paper_scissors.py 
b/rllib/examples/envs/classes/multi_agent/rock_paper_scissors.py
new file mode 100644
index 000000000000..aa363ae75f2b
--- /dev/null
+++ b/rllib/examples/envs/classes/multi_agent/rock_paper_scissors.py
@@ -0,0 +1,125 @@
+# __sphinx_doc_1_begin__
+import gymnasium as gym
+
+from ray.rllib.env.multi_agent_env import MultiAgentEnv
+
+
+class RockPaperScissors(MultiAgentEnv):
+    """Two-player environment for the famous rock paper scissors game.
+
+    # __sphinx_doc_1_end__
+    Optionally, the "Sheldon Cooper extension" can be activated by passing
+    `sheldon_cooper_mode=True` into the constructor, in which case two more moves
+    are allowed: Spock and Lizard. Spock is poisoned by Lizard, disproven by Paper, but
+    crushes Rock and smashes Scissors. Lizard poisons Spock and eats Paper, but is
+    decapitated by Scissors and crushed by Rock.
+
+    # __sphinx_doc_2_begin__
+    Both players always move simultaneously over the course of 10 timesteps in total.
+    The winner of each timestep receives a reward of +1.0, the losing player -1.0.
+
+    The observation of each player is the opponent's last action.
+    """
+
+    ROCK = 0
+    PAPER = 1
+    SCISSORS = 2
+    LIZARD = 3
+    SPOCK = 4
+
+    WIN_MATRIX = {
+        (ROCK, ROCK): (0, 0),
+        (ROCK, PAPER): (-1, 1),
+        (ROCK, SCISSORS): (1, -1),
+        (PAPER, ROCK): (1, -1),
+        (PAPER, PAPER): (0, 0),
+        (PAPER, SCISSORS): (-1, 1),
+        (SCISSORS, ROCK): (-1, 1),
+        (SCISSORS, PAPER): (1, -1),
+        (SCISSORS, SCISSORS): (0, 0),
+    }
+    # __sphinx_doc_2_end__
+
+    WIN_MATRIX.update(
+        {
+            # Sheldon Cooper mode:
+            (LIZARD, LIZARD): (0, 0),
+            (LIZARD, SPOCK): (1, -1),  # Lizard poisons Spock
+            (LIZARD, ROCK): (-1, 1),  # Rock crushes lizard
+            (LIZARD, PAPER): (1, -1),  # Lizard eats paper
+            (LIZARD, SCISSORS): (-1, 1),  # Scissors decapitate lizard
+            (ROCK, LIZARD): (1, -1),  # Rock crushes lizard
+            (PAPER, LIZARD): (-1, 1),  # Lizard eats paper
+            (SCISSORS, LIZARD): (1, -1),  # Scissors decapitate lizard
+            (SPOCK, SPOCK): (0, 0),
+            (SPOCK, LIZARD): (-1, 1),  # Lizard poisons Spock
+            (SPOCK, ROCK): (1, -1),  # Spock vaporizes rock
+            (SPOCK, PAPER): (-1, 1),  # Paper disproves Spock
+            (SPOCK, SCISSORS): (1, -1),  # Spock smashes scissors
+            (ROCK, SPOCK): (-1, 1),  # Spock vaporizes rock
+            (PAPER, SPOCK): (1, -1),  # Paper disproves Spock
+            (SCISSORS, SPOCK): (-1, 1),  # Spock smashes scissors
+        }
+    )
+
+    # __sphinx_doc_3_begin__
+    def __init__(self, config=None):
+        super().__init__()
+
+        self.agents = self.possible_agents = ["player1", "player2"]
+
+        # The observations are always the last taken actions. Hence observation- and
+        # action spaces are identical.
+        self.observation_spaces = self.action_spaces = {
+            "player1": gym.spaces.Discrete(3),
+            "player2": gym.spaces.Discrete(3),
+        }
+        self.last_move = None
+        self.num_moves = 0
+        # __sphinx_doc_3_end__

+        self.sheldon_cooper_mode = False
+        if config and config.get("sheldon_cooper_mode"):
+            self.sheldon_cooper_mode = True
+            self.action_spaces = self.observation_spaces = {
+                "player1": gym.spaces.Discrete(5),
+                "player2": gym.spaces.Discrete(5),
+            }
+
+    # __sphinx_doc_4_begin__
+    def reset(self, *, seed=None, options=None):
+        self.num_moves = 0
+
+        # The first observation should not matter (none of the agents has moved yet).
+        # Set them to 0.
+        return {
+            "player1": 0,
+            "player2": 0,
+        }, {}  # <- empty infos dict
+
+    # __sphinx_doc_4_end__
+
+    # __sphinx_doc_5_begin__
+    def step(self, action_dict):
+        self.num_moves += 1
+
+        move1 = action_dict["player1"]
+        move2 = action_dict["player2"]
+
+        # Set the next observations (simply use the other player's action).
+        # Note that because we return observations for both players in this dict,
+        # we expect both players to act in the next `step()` (simultaneous stepping).
+        observations = {"player1": move2, "player2": move1}
+
+        # Compute rewards for each player based on the win-matrix.
+        r1, r2 = self.WIN_MATRIX[move1, move2]
+        rewards = {"player1": r1, "player2": r2}
+
+        # Terminate the entire episode (for all agents) once 10 moves have been made.
+        terminateds = {"__all__": self.num_moves >= 10}
+
+        # Leave truncateds and infos empty.
+        return observations, rewards, terminateds, {}, {}
+
+
+# __sphinx_doc_5_end__
diff --git a/rllib/examples/envs/classes/multi_agent/tic_tac_toe.py b/rllib/examples/envs/classes/multi_agent/tic_tac_toe.py
new file mode 100644
index 000000000000..ceb08422092f
--- /dev/null
+++ b/rllib/examples/envs/classes/multi_agent/tic_tac_toe.py
@@ -0,0 +1,144 @@
+# __sphinx_doc_1_begin__
+import gymnasium as gym
+import numpy as np
+
+from ray.rllib.env.multi_agent_env import MultiAgentEnv
+
+
+class TicTacToe(MultiAgentEnv):
+    """A two-player game in which each player tries to complete a row, column, or diagonal.
+
+    The observation space is Box(-1.0, 1.0, (9,)), where each index represents a distinct
+    field on the 3x3 board; a value of 0.0 means the field is empty, -1.0 means
+    the opponent owns the field, and 1.0 means the player itself occupies the field:
+    ----------
+    | 0| 1| 2|
+    ----------
+    | 3| 4| 5|
+    ----------
+    | 6| 7| 8|
+    ----------
+
+    The action space is Discrete(9) and actions landing on an already occupied field
+    are penalized with a reward of -5.0 and leave the board unchanged.
+
+    Once a player completes a row, column, or diagonal, they receive a +5.0 reward and
+    the losing player receives -5.0. A full board without a winner yields 0.0 for both.
+    """
+
+    # __sphinx_doc_1_end__
+
+    # __sphinx_doc_2_begin__
+    def __init__(self, config=None):
+        super().__init__()
+
+        # Define the agents in the game.
+        self.agents = self.possible_agents = ["player1", "player2"]
+
+        # Each agent observes a 9D tensor, representing the 3x3 fields of the board.
+        # A 0 means an empty field, a 1 represents a piece of player 1, a -1 a piece of
+        # player 2.
+        self.observation_spaces = {
+            "player1": gym.spaces.Box(-1.0, 1.0, (9,), np.float32),
+            "player2": gym.spaces.Box(-1.0, 1.0, (9,), np.float32),
+        }
+        # Each player has 9 actions, encoding the 9 fields each player can place a piece
+        # on during their turn.
+        self.action_spaces = {
+            "player1": gym.spaces.Discrete(9),
+            "player2": gym.spaces.Discrete(9),
+        }
+
+        self.board = None
+        self.current_player = None
+
+    # __sphinx_doc_2_end__
+
+    # __sphinx_doc_3_begin__
+    def reset(self, *, seed=None, options=None):
+        self.board = [
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+        ]
+        # Pick a random player to start the game.
+        self.current_player = np.random.choice(["player1", "player2"])
+        # Return observations dict (only with the starting player, which is the one
+        # we expect to act next).
+        return {
+            self.current_player: np.array(self.board, np.float32),
+        }, {}
+
+    # __sphinx_doc_3_end__
+
+    # __sphinx_doc_4_begin__
+    def step(self, action_dict):
+        action = action_dict[self.current_player]
+
+        # Create a rewards-dict (containing the rewards of the agent that just acted).
+        rewards = {self.current_player: 0.0}
+        # Create a terminateds-dict with the special `__all__` agent ID, indicating that
+        # if True, the episode ends for all agents.
+        terminateds = {"__all__": False}
+
+        opponent = "player1" if self.current_player == "player2" else "player2"
+
+        # Penalize trying to place a piece on an already occupied field.
+        if self.board[action] != 0:
+            rewards[self.current_player] -= 5.0
+        # Change the board according to the (valid) action taken.
+        else:
+            self.board[action] = 1 if self.current_player == "player1" else -1
+
+            # After having placed a new piece, figure out whether the current player
+            # won or not.
+            if self.current_player == "player1":
+                win_val = [1, 1, 1]
+            else:
+                win_val = [-1, -1, -1]
+            if (
+                # Horizontal win.
+                self.board[:3] == win_val
+                or self.board[3:6] == win_val
+                or self.board[6:] == win_val
+                # Vertical win.
+                or self.board[0:7:3] == win_val
+                or self.board[1:8:3] == win_val
+                or self.board[2:9:3] == win_val
+                # Diagonal win.
+                or self.board[::4] == win_val
+                or self.board[2:7:2] == win_val
+            ):
+                # Final reward is +5 for victory and -5 for a loss.
+                rewards[self.current_player] += 5.0
+                rewards[opponent] = -5.0
+
+                # Episode is done and needs to be reset for a new game.
+                terminateds["__all__"] = True
+
+            # The board might also be full w/o any player having won/lost.
+            # In this case, we simply end the episode and none of the players receives
+            # +5 or -5 reward.
+            elif 0 not in self.board:
+                terminateds["__all__"] = True
+
+        # Flip players and return an observations dict with only the next player to
+        # make a move in it.
+        self.current_player = opponent
+
+        return (
+            {self.current_player: np.array(self.board, np.float32)},
+            rewards,
+            terminateds,
+            {},
+            {},
+        )
+
+
+# __sphinx_doc_4_end__
diff --git a/rllib/examples/envs/classes/two_step_game.py b/rllib/examples/envs/classes/multi_agent/two_step_game.py
similarity index 100%
rename from rllib/examples/envs/classes/two_step_game.py
rename to rllib/examples/envs/classes/multi_agent/two_step_game.py
diff --git a/rllib/examples/envs/classes/utils/interfaces.py b/rllib/examples/envs/classes/utils/interfaces.py
deleted file mode 100644
index 6cb398cfa3fe..000000000000
--- a/rllib/examples/envs/classes/utils/interfaces.py
+++ /dev/null
@@ -1,23 +0,0 @@
-##########
-# Contribution by the Center on Long-Term Risk:
-# https://github.com/longtermrisk/marltoolbox
-##########
-from abc import ABC, abstractmethod
-
-
-class InfoAccumulationInterface(ABC):
-    @abstractmethod
-    def _init_info(self):
-        raise NotImplementedError()
-
-    @abstractmethod
-    def _reset_info(self):
-        raise NotImplementedError()
-
-    @abstractmethod
-    def _get_episode_info(self):
-        raise NotImplementedError()
-
-    @abstractmethod
-    def _accumulate_info(self, *args, **kwargs):
-        raise NotImplementedError()
diff --git a/rllib/examples/envs/classes/utils/mixins.py b/rllib/examples/envs/classes/utils/mixins.py
deleted file mode 100644
index 236381f980c9..000000000000
--- a/rllib/examples/envs/classes/utils/mixins.py
+++ /dev/null
@@ -1,72 +0,0 @@
-##########
-# Contribution by the Center on Long-Term Risk:
-# https://github.com/longtermrisk/marltoolbox
-##########
-from abc import ABC
-import numpy as np
-
-from ray.rllib.examples.envs.classes.utils.interfaces import InfoAccumulationInterface
-
-
-class TwoPlayersTwoActionsInfoMixin(InfoAccumulationInterface, ABC):
-    """
-    Mixin class to add logging capability in a two player discrete game.
-    Logs the frequency of each state.
- """ - - def _init_info(self): - self.cc_count = [] - self.dd_count = [] - self.cd_count = [] - self.dc_count = [] - - def _reset_info(self): - self.cc_count.clear() - self.dd_count.clear() - self.cd_count.clear() - self.dc_count.clear() - - def _get_episode_info(self): - return { - "CC": np.mean(self.cc_count).item(), - "DD": np.mean(self.dd_count).item(), - "CD": np.mean(self.cd_count).item(), - "DC": np.mean(self.dc_count).item(), - } - - def _accumulate_info(self, ac0, ac1): - self.cc_count.append(ac0 == 0 and ac1 == 0) - self.cd_count.append(ac0 == 0 and ac1 == 1) - self.dc_count.append(ac0 == 1 and ac1 == 0) - self.dd_count.append(ac0 == 1 and ac1 == 1) - - -class NPlayersNDiscreteActionsInfoMixin(InfoAccumulationInterface, ABC): - """ - Mixin class to add logging capability in N player games with - discrete actions. - Logs the frequency of action profiles used - (action profile: the set of actions used during one step by all players). - """ - - def _init_info(self): - self.info_counters = {"n_steps_accumulated": 0} - - def _reset_info(self): - self.info_counters = {"n_steps_accumulated": 0} - - def _get_episode_info(self): - info = {} - if self.info_counters["n_steps_accumulated"] > 0: - for k, v in self.info_counters.items(): - if k != "n_steps_accumulated": - info[k] = v / self.info_counters["n_steps_accumulated"] - - return info - - def _accumulate_info(self, *actions): - id = "_".join([str(a) for a in actions]) - if id not in self.info_counters: - self.info_counters[id] = 0 - self.info_counters[id] += 1 - self.info_counters["n_steps_accumulated"] += 1 diff --git a/rllib/examples/multi_agent/rock_paper_scissors_heuristic_vs_learned.py b/rllib/examples/multi_agent/rock_paper_scissors_heuristic_vs_learned.py index 0e9247123c7d..6f474e8e3c69 100644 --- a/rllib/examples/multi_agent/rock_paper_scissors_heuristic_vs_learned.py +++ b/rllib/examples/multi_agent/rock_paper_scissors_heuristic_vs_learned.py @@ -71,7 +71,7 @@ register_env( - "RockPaperScissors", + "pettingzoo_rps", lambda _: ParallelPettingZooEnv(rps_v2.parallel_env()), ) @@ -84,7 +84,7 @@ base_config = ( get_trainable_cls(args.algo) .get_default_config() - .environment("RockPaperScissors") + .environment("pettingzoo_rps") .env_runners( env_to_module_connector=lambda env: ( # `agent_ids=...`: Only flatten obs for the learning RLModule. 
diff --git a/rllib/examples/multi_agent/rock_paper_scissors_learned_vs_learned.py b/rllib/examples/multi_agent/rock_paper_scissors_learned_vs_learned.py index ae8f299cd06f..adf88dba985b 100644 --- a/rllib/examples/multi_agent/rock_paper_scissors_learned_vs_learned.py +++ b/rllib/examples/multi_agent/rock_paper_scissors_learned_vs_learned.py @@ -45,7 +45,7 @@ register_env( - "RockPaperScissors", + "pettingzoo_rps", lambda _: ParallelPettingZooEnv(rps_v2.parallel_env()), ) @@ -58,7 +58,7 @@ base_config = ( get_trainable_cls(args.algo) .get_default_config() - .environment("RockPaperScissors") + .environment("pettingzoo_rps") .env_runners( env_to_module_connector=lambda env: FlattenObservations(multi_agent=True), ) diff --git a/rllib/examples/multi_agent/two_step_game_with_grouped_agents.py b/rllib/examples/multi_agent/two_step_game_with_grouped_agents.py index 11daa3f218d5..0981eb2575f1 100644 --- a/rllib/examples/multi_agent/two_step_game_with_grouped_agents.py +++ b/rllib/examples/multi_agent/two_step_game_with_grouped_agents.py @@ -43,7 +43,9 @@ from ray.rllib.connectors.env_to_module import FlattenObservations from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec from ray.rllib.core.rl_module.rl_module import RLModuleSpec -from ray.rllib.examples.envs.classes.two_step_game import TwoStepGameWithGroupedAgents +from ray.rllib.examples.envs.classes.multi_agent.two_step_game import ( + TwoStepGameWithGroupedAgents, +) from ray.rllib.utils.test_utils import ( add_rllib_example_script_args, run_rllib_example_script_experiment,