evaluation
Package
Episodes
ray.rllib.evaluation.episode.Episode
Tracks the current state of a (possibly multi-agent) episode.
Attributes:
Name | Type | Description |
---|---|---|
new_batch_builder | func | Create a new MultiAgentSampleBatchBuilder. |
add_extra_batch | func | Return a built MultiAgentBatch to the sampler. |
batch_builder | obj | Batch builder for the current episode. |
total_reward | float | Summed reward across all agents in this episode. |
length | int | Length of this episode. |
episode_id | int | Unique id identifying this trajectory. |
agent_rewards | dict | Summed rewards broken down by agent. |
custom_metrics | dict | Dict to which you can add custom metrics. |
user_data | dict | Dict that you can use for temporary storage, e.g. in between two custom callbacks referring to the same episode. |
hist_data | dict | Dict mapping str keys to List[float] for storage of per-timestep float data throughout the episode. |
Use case 1: Model-based rollouts in multi-agent: A custom compute_actions() function in a policy can inspect the current episode state and perform a number of rollouts based on the policies and state of other agents in the environment.
Use case 2: Returning extra rollout data. The model rollouts can be returned to the sampler by calling:
>>> batch = episode.new_batch_builder()
>>> for transition in rollout:  # pseudocode: iterate over the model rollout's transitions
...     batch.add_values(...)  # see sampler for usage
>>> episode.extra_batches.add(batch.build_and_reset())
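As a concrete illustration of the custom_metrics, user_data, and hist_data attributes above, here is a minimal callback sketch. The DefaultCallbacks hooks and Episode accessors are the ones documented on this page; the "speed" info key is purely illustrative and assumes your env reports it.
import numpy as np
from ray.rllib.agents.callbacks import DefaultCallbacks

class MyCallbacks(DefaultCallbacks):
    def on_episode_step(self, *, worker, base_env, episode, env_index=None, **kwargs):
        # Temporary, per-episode storage (starts empty for each new episode).
        info = episode.last_info_for() or {}
        episode.user_data.setdefault("speeds", []).append(info.get("speed", 0.0))

    def on_episode_end(self, *, worker, base_env, policies, episode, env_index=None, **kwargs):
        speeds = episode.user_data.get("speeds", [])
        # Scalar custom metric, aggregated over episodes in the results.
        episode.custom_metrics["mean_speed"] = float(np.mean(speeds)) if speeds else 0.0
        # Per-timestep float data, reported as histogram data.
        episode.hist_data["speeds"] = speeds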
__init__(self, policies, policy_mapping_fn, batch_builder_factory, extra_batch_callback, env_id, *, worker=None)
special
Initializes an Episode instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
policies | PolicyMap | The PolicyMap object (mapping PolicyIDs to Policy objects) to use for determining which policy is used for which agent. | required |
policy_mapping_fn | Callable[[Any, Episode, RolloutWorker], str] | The mapping function mapping AgentIDs to PolicyIDs. | required |
batch_builder_factory | Callable[[], MultiAgentSampleBatchBuilder] | Factory returning a new MultiAgentSampleBatchBuilder. | required |
extra_batch_callback | Callable[[Union[SampleBatch, MultiAgentBatch]]] | Callback via which extra (e.g. model-based rollout) batches are returned to the sampler. | required |
env_id | Union[int, str] | The environment's ID in which this episode runs. | required |
worker | Optional[RolloutWorker] | The RolloutWorker instance in which this episode runs. | None |
Source code in ray/rllib/evaluation/episode.py
def __init__(
self,
policies: PolicyMap,
policy_mapping_fn: Callable[[AgentID, "Episode", "RolloutWorker"],
PolicyID],
batch_builder_factory: Callable[[],
"MultiAgentSampleBatchBuilder"],
extra_batch_callback: Callable[[SampleBatchType], None],
env_id: EnvID,
*,
worker: Optional["RolloutWorker"] = None,
):
"""Initializes an Episode instance.
Args:
policies: The PolicyMap object (mapping PolicyIDs to Policy
objects) to use for determining, which policy is used for
which agent.
policy_mapping_fn: The mapping function mapping AgentIDs to
PolicyIDs.
batch_builder_factory:
extra_batch_callback:
env_id: The environment's ID in which this episode runs.
worker: The RolloutWorker instance, in which this episode runs.
"""
self.new_batch_builder: Callable[
[], "MultiAgentSampleBatchBuilder"] = batch_builder_factory
self.add_extra_batch: Callable[[SampleBatchType],
None] = extra_batch_callback
self.batch_builder: "MultiAgentSampleBatchBuilder" = \
batch_builder_factory()
self.total_reward: float = 0.0
self.length: int = 0
self.episode_id: int = random.randrange(2e9)
self.env_id = env_id
self.worker = worker
self.agent_rewards: Dict[AgentID, float] = defaultdict(float)
self.custom_metrics: Dict[str, float] = {}
self.user_data: Dict[str, Any] = {}
self.hist_data: Dict[str, List[float]] = {}
self.media: Dict[str, Any] = {}
self.policy_map: PolicyMap = policies
self._policies = self.policy_map # backward compatibility
self.policy_mapping_fn: Callable[[AgentID, "Episode", "RolloutWorker"],
PolicyID] = policy_mapping_fn
self._next_agent_index: int = 0
self._agent_to_index: Dict[AgentID, int] = {}
self._agent_to_policy: Dict[AgentID, PolicyID] = {}
self._agent_to_rnn_state: Dict[AgentID, List[Any]] = {}
self._agent_to_last_obs: Dict[AgentID, EnvObsType] = {}
self._agent_to_last_raw_obs: Dict[AgentID, EnvObsType] = {}
self._agent_to_last_done: Dict[AgentID, bool] = {}
self._agent_to_last_info: Dict[AgentID, EnvInfoDict] = {}
self._agent_to_last_action: Dict[AgentID, EnvActionType] = {}
self._agent_to_last_extra_action_outs: Dict[AgentID, dict] = {}
self._agent_to_prev_action: Dict[AgentID, EnvActionType] = {}
self._agent_reward_history: Dict[AgentID, List[int]] = defaultdict(
list)
get_agents(self)
Returns list of agent IDs that have appeared in this episode.
Returns:
Type | Description |
---|---|
List[Any] | The list of all agent IDs that have appeared so far in this episode. |
last_action_for(self, agent_id='agent0')
Returns the last action for the specified AgentID, or zeros.
The "last" action is the most recent one taken by the agent.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
agent_id | Any | The agent's ID to get the last action for. | 'agent0' |
Returns:
Type | Description |
---|---|
Any | Last action the specified AgentID has executed. Zeros in case the agent has never performed any actions in the episode. |
Source code in ray/rllib/evaluation/episode.py
@DeveloperAPI
def last_action_for(self,
agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType:
"""Returns the last action for the specified AgentID, or zeros.
The "last" action is the most recent one taken by the agent.
Args:
agent_id: The agent's ID to get the last action for.
Returns:
Last action the specified AgentID has executed.
Zeros in case the agent has never performed any actions in the
episode.
"""
# Agent has already taken at least one action in the episode.
if agent_id in self._agent_to_last_action:
return flatten_to_single_ndarray(
self._agent_to_last_action[agent_id])
# Agent has not acted yet, return all zeros.
else:
policy_id = self.policy_for(agent_id)
policy = self.policy_map[policy_id]
flat = flatten_to_single_ndarray(policy.action_space.sample())
if hasattr(policy.action_space, "dtype"):
return np.zeros_like(flat, dtype=policy.action_space.dtype)
return np.zeros_like(flat)
last_done_for(self, agent_id='agent0')
Returns the last done flag for the specified AgentID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
agent_id | Any | The agent's ID to get the last done flag for. | 'agent0' |
Returns:
Type | Description |
---|---|
bool | Last done flag for the specified AgentID. |
Source code in ray/rllib/evaluation/episode.py
@DeveloperAPI
def last_done_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> bool:
"""Returns the last done flag for the specified AgentID.
Args:
agent_id: The agent's ID to get the last done flag for.
Returns:
Last done flag for the specified AgentID.
"""
if agent_id not in self._agent_to_last_done:
self._agent_to_last_done[agent_id] = False
return self._agent_to_last_done[agent_id]
last_extra_action_outs_for(self, agent_id='agent0')
Returns the last extra-action outputs for the specified agent.
This data is returned by a call to
Policy.compute_actions_from_input_dict
as the 3rd return value
(1st return value = action; 2nd return value = RNN state outs).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
agent_id | Any | The agent's ID to get the last extra-action outs for. | 'agent0' |
Returns:
Type | Description |
---|---|
dict | The last extra-action outs for the specified AgentID. |
Source code in ray/rllib/evaluation/episode.py
@DeveloperAPI
def last_extra_action_outs_for(
self,
agent_id: AgentID = _DUMMY_AGENT_ID,
) -> dict:
"""Returns the last extra-action outputs for the specified agent.
This data is returned by a call to
`Policy.compute_actions_from_input_dict` as the 3rd return value
(1st return value = action; 2nd return value = RNN state outs).
Args:
agent_id: The agent's ID to get the last extra-action outs for.
Returns:
The last extra-action outs for the specified AgentID.
"""
return self._agent_to_last_extra_action_outs[agent_id]
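For instance (the exact keys in the returned dict depend on the Policy; action log-probabilities are common but not guaranteed, hence the use of .get()):
>>> outs = episode.last_extra_action_outs_for("agent0")
>>> logp = outs.get("action_logp")  # may be absent for some policies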
last_info_for(self, agent_id='agent0')
Returns the last info for the specified AgentID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
agent_id | Any | The agent's ID to get the last info for. | 'agent0' |
Returns:
Type | Description |
---|---|
Optional[dict] | Last info dict the specified AgentID has seen. None in case the agent has never made any observations in the episode. |
Source code in ray/rllib/evaluation/episode.py
@DeveloperAPI
def last_info_for(self, agent_id: AgentID = _DUMMY_AGENT_ID
) -> Optional[EnvInfoDict]:
"""Returns the last info for the specified AgentID.
Args:
agent_id: The agent's ID to get the last info for.
Returns:
Last info dict the specified AgentID has seen.
None in case the agent has never made any observations in the
episode.
"""
return self._agent_to_last_info.get(agent_id)
last_observation_for(self, agent_id='agent0')
Returns the last observation for the specified AgentID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
agent_id | Any | The agent's ID to get the last observation for. | 'agent0' |
Returns:
Type | Description |
---|---|
Optional[Any] | Last observation the specified AgentID has seen. None in case the agent has never made any observations in the episode. |
Source code in ray/rllib/evaluation/episode.py
@DeveloperAPI
def last_observation_for(
self, agent_id: AgentID = _DUMMY_AGENT_ID) -> Optional[EnvObsType]:
"""Returns the last observation for the specified AgentID.
Args:
agent_id: The agent's ID to get the last observation for.
Returns:
Last observation the specified AgentID has seen. None in case
the agent has never made any observations in the episode.
"""
return self._agent_to_last_obs.get(agent_id)
last_raw_obs_for(self, agent_id='agent0')
Returns the last un-preprocessed obs for the specified AgentID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
agent_id | Any | The agent's ID to get the last un-preprocessed observation for. | 'agent0' |
Returns:
Type | Description |
---|---|
Optional[Any] | Last un-preprocessed observation the specified AgentID has seen. None in case the agent has never made any observations in the episode. |
Source code in ray/rllib/evaluation/episode.py
@DeveloperAPI
def last_raw_obs_for(
self, agent_id: AgentID = _DUMMY_AGENT_ID) -> Optional[EnvObsType]:
"""Returns the last un-preprocessed obs for the specified AgentID.
Args:
agent_id: The agent's ID to get the last un-preprocessed
observation for.
Returns:
Last un-preprocessed observation the specified AgentID has seen.
None in case the agent has never made any observations in the
episode.
"""
return self._agent_to_last_raw_obs.get(agent_id)
last_reward_for(self, agent_id='agent0')
Returns the last reward for the specified agent, or zero.
The "last" reward is the one received most recently by the agent.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
agent_id | Any | The agent's ID to get the last reward for. | 'agent0' |
Returns:
Type | Description |
---|---|
float | Last reward for the specified AgentID. Zero in case the agent has never performed any actions (and thus received rewards) in the episode. |
Source code in ray/rllib/evaluation/episode.py
@DeveloperAPI
def last_reward_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> float:
"""Returns the last reward for the specified agent, or zero.
The "last" reward is the one received most recently by the agent.
Args:
agent_id: The agent's ID to get the last reward for.
Returns:
Last reward for the specified AgentID.
Zero in case the agent has never performed any actions
(and thus received rewards) in the episode.
"""
history = self._agent_reward_history[agent_id]
# We are at t > 0 -> Return previously received reward.
if len(history) >= 1:
return history[-1]
# We're at t=0, so there is no previous reward, just return zero.
else:
return 0.0
policy_for(self, agent_id='agent0')
Returns and stores the policy ID for the specified agent.
If the agent is new, the policy mapping fn will be called to bind the agent to a policy for the duration of the entire episode (even if the policy_mapping_fn is changed in the meantime!).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
agent_id | Any | The agent ID to lookup the policy ID for. | 'agent0' |
Returns:
Type | Description |
---|---|
str | The policy ID for the specified agent. |
Source code in ray/rllib/evaluation/episode.py
@DeveloperAPI
def policy_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> PolicyID:
"""Returns and stores the policy ID for the specified agent.
If the agent is new, the policy mapping fn will be called to bind the
agent to a policy for the duration of the entire episode (even if the
policy_mapping_fn is changed in the meantime!).
Args:
agent_id: The agent ID to lookup the policy ID for.
Returns:
The policy ID for the specified agent.
"""
# Perform a new policy_mapping_fn lookup and bind AgentID for the
# duration of this episode to the returned PolicyID.
if agent_id not in self._agent_to_policy:
# Try new API: pass in agent_id and episode as named args.
# New signature should be: (agent_id, episode, worker, **kwargs)
try:
policy_id = self._agent_to_policy[agent_id] = \
self.policy_mapping_fn(agent_id, self, worker=self.worker)
except TypeError as e:
if "positional argument" in e.args[0] or \
"unexpected keyword argument" in e.args[0]:
if log_once("policy_mapping_new_signature"):
deprecation_warning(
old="policy_mapping_fn(agent_id)",
new="policy_mapping_fn(agent_id, episode, "
"worker, **kwargs)")
policy_id = self._agent_to_policy[agent_id] = \
self.policy_mapping_fn(agent_id)
else:
raise e
# Use already determined PolicyID.
else:
policy_id = self._agent_to_policy[agent_id]
# PolicyID not found in policy map -> Error.
if policy_id not in self.policy_map:
raise KeyError("policy_mapping_fn returned invalid policy id "
f"'{policy_id}'!")
return policy_id
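For reference, a minimal mapping function following the new (agent_id, episode, worker, **kwargs) signature described above; the agent ID prefix and policy IDs used here are illustrative only:
def my_policy_mapping_fn(agent_id, episode, worker, **kwargs):
    # Called once per new agent; Episode.policy_for() caches the result
    # for the rest of the episode.
    if agent_id.startswith("car_"):
        return "car_policy1" if episode.episode_id % 2 == 0 else "car_policy2"
    return "traffic_light_policy"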
prev_action_for(self, agent_id='agent0')
Returns the previous action for the specified agent, or zeros.
The "previous" action is the one taken one timestep before the most recent action taken by the agent.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
agent_id | Any | The agent's ID to get the previous action for. | 'agent0' |
Returns:
Type | Description |
---|---|
Any | Previous action the specified AgentID has executed. Zero in case the agent has never performed any actions (or only one) in the episode. |
Source code in ray/rllib/evaluation/episode.py
@DeveloperAPI
def prev_action_for(self,
agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType:
"""Returns the previous action for the specified agent, or zeros.
The "previous" action is the one taken one timestep before the
most recent action taken by the agent.
Args:
agent_id: The agent's ID to get the previous action for.
Returns:
Previous action the specified AgentID has executed.
Zero in case the agent has never performed any actions (or only
one) in the episode.
"""
# We are at t > 1 -> There has been a previous action by this agent.
if agent_id in self._agent_to_prev_action:
return flatten_to_single_ndarray(
self._agent_to_prev_action[agent_id])
# We're at t <= 1, so return all zeros.
else:
return np.zeros_like(self.last_action_for(agent_id))
prev_reward_for(self, agent_id='agent0')
Returns the previous reward for the specified agent, or zero.
The "previous" reward is the one received one timestep before the most recently received reward of the agent.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
agent_id | Any | The agent's ID to get the previous reward for. | 'agent0' |
Returns:
Type | Description |
---|---|
float | Previous reward for the specified AgentID. Zero in case the agent has never performed any actions (or only one) in the episode. |
Source code in ray/rllib/evaluation/episode.py
@DeveloperAPI
def prev_reward_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> float:
"""Returns the previous reward for the specified agent, or zero.
The "previous" reward is the one received one timestep before the
most recently received reward of the agent.
Args:
agent_id: The agent's ID to get the previous reward for.
Returns:
Previous reward for the specified AgentID.
Zero in case the agent has never performed any actions (or only
one) in the episode.
"""
history = self._agent_reward_history[agent_id]
# We are at t > 1 -> Return reward prior to most recent (last) one.
if len(history) >= 2:
return history[-2]
# We're at t <= 1, so there is no previous reward, just return zero.
else:
return 0.0
rnn_state_for(self, agent_id='agent0')
Returns the last RNN state for the specified agent.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
agent_id | Any | The agent's ID to get the most recent RNN state for. | 'agent0' |
Returns:
Type | Description |
---|---|
List[Any] | Most recent RNN state of the specified AgentID. |
Source code in ray/rllib/evaluation/episode.py
@DeveloperAPI
def rnn_state_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> List[Any]:
"""Returns the last RNN state for the specified agent.
Args:
agent_id: The agent's ID to get the most recent RNN state for.
Returns:
Most recent RNN state of the specified AgentID.
"""
if agent_id not in self._agent_to_rnn_state:
policy_id = self.policy_for(agent_id)
policy = self.policy_map[policy_id]
self._agent_to_rnn_state[agent_id] = policy.get_initial_state()
return self._agent_to_rnn_state[agent_id]
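For example, a model-based rollout (Use case 1 above) might seed itself with another agent's current state; "other_agent" is an illustrative AgentID:
>>> state = episode.rnn_state_for("other_agent")
>>> obs = episode.last_observation_for("other_agent")
>>> action = episode.last_action_for("other_agent")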
soft_reset(self)
Clears rewards and metrics, but retains RNN and other state.
This is used to carry state across multiple logical episodes in the same env (i.e., if soft_horizon is set).
Source code in ray/rllib/evaluation/episode.py
@DeveloperAPI
def soft_reset(self) -> None:
"""Clears rewards and metrics, but retains RNN and other state.
This is used to carry state across multiple logical episodes in the
same env (i.e., if `soft_horizon` is set).
"""
self.length = 0
self.episode_id = random.randrange(2e9)
self.total_reward = 0.0
self.agent_rewards = defaultdict(float)
self._agent_reward_history = defaultdict(list)
ray.rllib.evaluation.episode.MultiAgentEpisode (Episode)
Rollouts
ray.rllib.evaluation.rollout_worker.RolloutWorker (ParallelIteratorWorker)
Common experience collection class.
This class wraps a policy instance and an environment class to collect experiences from the environment. You can create many replicas of this class as Ray actors to scale RL training.
This class supports vectorized and multi-agent policy evaluation (e.g., VectorEnv, MultiAgentEnv, etc.)
Examples:
>>> # Create a rollout worker and use it to collect experiences.
>>> worker = RolloutWorker(
... env_creator=lambda _: gym.make("CartPole-v0"),
... policy_spec=PGTFPolicy)
>>> print(worker.sample())
SampleBatch({
"obs": [[...]], "actions": [[...]], "rewards": [[...]],
"dones": [[...]], "new_obs": [[...]]})
>>> # Creating a multi-agent rollout worker
>>> worker = RolloutWorker(
... env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25),
... policy_spec={
... # Use an ensemble of two policies for car agents
... "car_policy1":
... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.99}),
... "car_policy2":
... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.95}),
... # Use a single shared policy for all traffic lights
... "traffic_light_policy":
... (PGTFPolicy, Box(...), Discrete(...), {}),
... },
... policy_mapping_fn=lambda agent_id, episode, **kwargs:
... random.choice(["car_policy1", "car_policy2"])
... if agent_id.startswith("car_") else "traffic_light_policy")
>>> print(worker.sample())
MultiAgentBatch({
"car_policy1": SampleBatch(...),
"car_policy2": SampleBatch(...),
"traffic_light_policy": SampleBatch(...)})
__init__(self, *, env_creator, validate_env=None, policy_spec=None, policy_mapping_fn=None, policies_to_train=None, tf_session_creator=None, rollout_fragment_length=100, count_steps_by='env_steps', batch_mode='truncate_episodes', episode_horizon=None, preprocessor_pref='deepmind', sample_async=False, compress_observations=False, num_envs=1, observation_fn=None, observation_filter='NoFilter', clip_rewards=None, normalize_actions=True, clip_actions=False, env_config=None, model_config=None, policy_config=None, worker_index=0, num_workers=0, record_env=False, log_dir=None, log_level=None, callbacks=None, input_creator=lambda ioctx: ioctx.default_sampler_input(), input_evaluation=frozenset(), output_creator=lambda ioctx: NoopOutput(), remote_worker_envs=False, remote_env_batch_wait_ms=0, soft_horizon=False, no_done_at_end=False, seed=None, extra_python_environs=None, fake_sampler=False, spaces=None, policy=None, monitor_path=None)
special
Initializes a RolloutWorker instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
env_creator | Callable[[ray.rllib.env.env_context.EnvContext], Any] | Function that returns a gym.Env given an EnvContext wrapped configuration. | required |
validate_env | Optional[Callable[[Any, ray.rllib.env.env_context.EnvContext], NoneType]] | Optional callable to validate the generated environment (only on worker=0). | None |
policy_spec | Union[type, Dict[str, ray.rllib.policy.policy.PolicySpec]] | The MultiAgentPolicyConfigDict mapping policy IDs (str) to PolicySpec's or a single policy class to use. If a dict is specified, then we are in multi-agent mode and a policy_mapping_fn can also be set (if not, will map all agents to DEFAULT_POLICY_ID). | None |
policy_mapping_fn | Optional[Callable[[Any, Episode], str]] | A callable that maps agent ids to policy ids in multi-agent mode. This function will be called each time a new agent appears in an episode, to bind that agent to a policy for the duration of the episode. If not provided, will map all agents to DEFAULT_POLICY_ID. | None |
policies_to_train | Optional[List[str]] | Optional list of policies to train, or None for all policies. | None |
tf_session_creator | Optional[Callable[[], tf1.Session]] | A function that returns a TF session. This is optional and only useful with TFPolicy. | None |
rollout_fragment_length | int | The target number of steps (measured in count_steps_by) to include in each sample batch returned from this worker. | 100 |
count_steps_by | str | The unit in which to count fragment lengths. One of "env_steps" or "agent_steps". | 'env_steps' |
batch_mode | str | One of "truncate_episodes" or "complete_episodes". With "truncate_episodes", each call to sample() returns a batch of at most rollout_fragment_length * num_envs in size; episodes may be truncated to meet this size. With "complete_episodes", each call returns a batch of at least rollout_fragment_length * num_envs in size; episodes are never truncated, but multiple episodes may be packed into one batch. | 'truncate_episodes' |
episode_horizon | Optional[int] | Horizon at which to stop episodes (even if the environment itself has not returned a "done" signal). | None |
preprocessor_pref | str | Whether to use RLlib preprocessors ("rllib") or deepmind ("deepmind"), when applicable. | 'deepmind' |
sample_async | bool | Whether to compute samples asynchronously in the background, which improves throughput but can cause samples to be slightly off-policy. | False |
compress_observations | bool | If true, compress the observations. They can be decompressed with rllib/utils/compression. | False |
num_envs | int | If more than one, will create multiple envs and vectorize the computation of actions. This has no effect if the env already implements VectorEnv. | 1 |
observation_fn | Optional[ObservationFunction] | Optional multi-agent observation function. | None |
observation_filter | str | Name of observation filter to use. | 'NoFilter' |
clip_rewards | Union[bool, float] | True for clipping rewards to [-1.0, 1.0] prior to experience postprocessing. None: Clip for Atari only. float: Clip to [-clip_rewards; +clip_rewards]. | None |
normalize_actions | bool | Whether to normalize actions to the action space's bounds. | True |
clip_actions | bool | Whether to clip action values to the range specified by the policy action space. | False |
env_config | Optional[dict] | Config to pass to the env creator. | None |
model_config | Optional[dict] | Config to use when creating the policy model. | None |
policy_config | Optional[dict] | Config to pass to the policy. In the multi-agent case, this config will be merged with the per-policy configs specified by policy_spec. | None |
worker_index | int | For remote workers, this should be set to a non-zero and unique value. This index is passed to created envs through EnvContext so that envs can be configured per worker. | 0 |
num_workers | int | For remote workers, how many workers altogether have been created? | 0 |
record_env | Union[bool, str] | Write out episode stats and videos using gym.wrappers.Monitor to this directory if specified. If True, use the default output dir in ~/ray_results/.... If False, do not record anything. | False |
log_dir | Optional[str] | Directory where logs can be placed. | None |
log_level | Optional[str] | Set the root log level on creation. | None |
callbacks | Type[DefaultCallbacks] | Custom sub-class of DefaultCallbacks for training/policy/rollout-worker callbacks. | None |
input_creator | Callable[[ray.rllib.offline.io_context.IOContext], ray.rllib.offline.input_reader.InputReader] | Function that returns an InputReader object for loading previously generated experiences. | lambda ioctx: ioctx.default_sampler_input() |
input_evaluation | List[str] | How to evaluate the policy performance. This only makes sense to set when the input is reading offline data. Possible values: "is" (the step-wise importance sampling estimator), "wis" (the weighted step-wise IS estimator), "simulation" (run the environment in the background, but use this data for evaluation only and never for learning). | frozenset() |
output_creator | Callable[[ray.rllib.offline.io_context.IOContext], ray.rllib.offline.output_writer.OutputWriter] | Function that returns an OutputWriter object for saving generated experiences. | lambda ioctx: NoopOutput() |
remote_worker_envs | bool | If using num_envs_per_worker > 1, whether to create those new envs in remote processes instead of in the current process. This adds overheads, but can make sense if your envs are expensive to step/reset (e.g., for StarCraft). Use this cautiously; overheads are significant! | False |
remote_env_batch_wait_ms | int | Timeout that remote workers are waiting when polling environments. 0 (continue when at least one env is ready) is a reasonable default, but the optimal value could be obtained by measuring your environment step/reset and model inference perf. | 0 |
soft_horizon | bool | Calculate rewards but don't reset the environment when the horizon is hit. | False |
no_done_at_end | bool | Ignore the done=True at the end of the episode and instead record done=False. | False |
seed | int | Set the seed of both np and tf to this value to ensure each remote worker has unique exploration behavior. | None |
extra_python_environs | Optional[dict] | Extra Python environment variables to be set for this worker. | None |
fake_sampler | bool | Use a fake (inf speed) sampler for testing. | False |
spaces | Optional[Dict[str, Tuple[gym.spaces.space.Space, gym.spaces.space.Space]]] | An optional space dict mapping policy IDs to (obs_space, action_space)-tuples. This is used in case no Env is created on this RolloutWorker. | None |
policy | | Obsoleted arg. Use policy_spec instead. | None |
monitor_path | | Obsoleted arg. Use record_env instead. | None |
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def __init__(
self,
*,
env_creator: Callable[[EnvContext], EnvType],
validate_env: Optional[Callable[[EnvType, EnvContext],
None]] = None,
policy_spec: Optional[Union[type, Dict[PolicyID,
PolicySpec]]] = None,
policy_mapping_fn: Optional[Callable[[AgentID, "Episode"],
PolicyID]] = None,
policies_to_train: Optional[List[PolicyID]] = None,
tf_session_creator: Optional[Callable[[], "tf1.Session"]] = None,
rollout_fragment_length: int = 100,
count_steps_by: str = "env_steps",
batch_mode: str = "truncate_episodes",
episode_horizon: Optional[int] = None,
preprocessor_pref: str = "deepmind",
sample_async: bool = False,
compress_observations: bool = False,
num_envs: int = 1,
observation_fn: Optional["ObservationFunction"] = None,
observation_filter: str = "NoFilter",
clip_rewards: Optional[Union[bool, float]] = None,
normalize_actions: bool = True,
clip_actions: bool = False,
env_config: Optional[EnvConfigDict] = None,
model_config: Optional[ModelConfigDict] = None,
policy_config: Optional[PartialTrainerConfigDict] = None,
worker_index: int = 0,
num_workers: int = 0,
record_env: Union[bool, str] = False,
log_dir: Optional[str] = None,
log_level: Optional[str] = None,
callbacks: Type["DefaultCallbacks"] = None,
input_creator: Callable[[
IOContext
], InputReader] = lambda ioctx: ioctx.default_sampler_input(),
input_evaluation: List[str] = frozenset([]),
output_creator: Callable[
[IOContext], OutputWriter] = lambda ioctx: NoopOutput(),
remote_worker_envs: bool = False,
remote_env_batch_wait_ms: int = 0,
soft_horizon: bool = False,
no_done_at_end: bool = False,
seed: int = None,
extra_python_environs: Optional[dict] = None,
fake_sampler: bool = False,
spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
gym.spaces.Space]]] = None,
policy=None,
monitor_path=None,
):
"""Initializes a RolloutWorker instance.
Args:
env_creator: Function that returns a gym.Env given an EnvContext
wrapped configuration.
validate_env: Optional callable to validate the generated
environment (only on worker=0).
policy_spec: The MultiAgentPolicyConfigDict mapping policy IDs
(str) to PolicySpec's or a single policy class to use.
If a dict is specified, then we are in multi-agent mode and a
policy_mapping_fn can also be set (if not, will map all agents
to DEFAULT_POLICY_ID).
policy_mapping_fn: A callable that maps agent ids to policy ids in
multi-agent mode. This function will be called each time a new
agent appears in an episode, to bind that agent to a policy
for the duration of the episode. If not provided, will map all
agents to DEFAULT_POLICY_ID.
policies_to_train: Optional list of policies to train, or None
for all policies.
tf_session_creator: A function that returns a TF session.
This is optional and only useful with TFPolicy.
rollout_fragment_length: The target number of steps
(measured in `count_steps_by`) to include in each sample
batch returned from this worker.
count_steps_by: The unit in which to count fragment
lengths. One of env_steps or agent_steps.
batch_mode: One of the following batch modes:
- "truncate_episodes": Each call to sample() will return a
batch of at most `rollout_fragment_length * num_envs` in size.
The batch will be exactly `rollout_fragment_length * num_envs`
in size if postprocessing does not change batch sizes. Episodes
may be truncated in order to meet this size requirement.
- "complete_episodes": Each call to sample() will return a
batch of at least `rollout_fragment_length * num_envs` in
size. Episodes will not be truncated, but multiple episodes
may be packed within one batch to meet the batch size. Note
that when `num_envs > 1`, episode steps will be buffered
until the episode completes, and hence batches may contain
significant amounts of off-policy data.
episode_horizon: Horizon at which to stop episodes (even if the
environment itself has not returned a "done" signal).
preprocessor_pref: Whether to use RLlib preprocessors
("rllib") or deepmind ("deepmind"), when applicable.
sample_async: Whether to compute samples asynchronously in
the background, which improves throughput but can cause samples
to be slightly off-policy.
compress_observations: If true, compress the observations.
They can be decompressed with rllib/utils/compression.
num_envs: If more than one, will create multiple envs
and vectorize the computation of actions. This has no effect if
the env already implements VectorEnv.
observation_fn: Optional multi-agent observation function.
observation_filter: Name of observation filter to use.
clip_rewards: True for clipping rewards to [-1.0, 1.0] prior
to experience postprocessing. None: Clip for Atari only.
float: Clip to [-clip_rewards; +clip_rewards].
normalize_actions: Whether to normalize actions to the
action space's bounds.
clip_actions: Whether to clip action values to the range
specified by the policy action space.
env_config: Config to pass to the env creator.
model_config: Config to use when creating the policy model.
policy_config: Config to pass to the
policy. In the multi-agent case, this config will be merged
with the per-policy configs specified by `policy_spec`.
worker_index: For remote workers, this should be set to a
non-zero and unique value. This index is passed to created envs
through EnvContext so that envs can be configured per worker.
num_workers: For remote workers, how many workers altogether
have been created?
record_env: Write out episode stats and videos
using gym.wrappers.Monitor to this directory if specified. If
True, use the default output dir in ~/ray_results/.... If
False, do not record anything.
log_dir: Directory where logs can be placed.
log_level: Set the root log level on creation.
callbacks: Custom sub-class of
DefaultCallbacks for training/policy/rollout-worker callbacks.
input_creator: Function that returns an InputReader object for
loading previously generated experiences.
input_evaluation: How to evaluate the policy
performance. This only makes sense to set when the input is
reading offline data. The possible values include:
- "is": the step-wise importance sampling estimator.
- "wis": the weighted step-wise is estimator.
- "simulation": run the environment in the background, but
use this data for evaluation only and never for learning.
output_creator: Function that returns an OutputWriter object for
saving generated experiences.
remote_worker_envs: If using num_envs_per_worker > 1,
whether to create those new envs in remote processes instead of
in the current process. This adds overheads, but can make sense
if your envs are expensive to step/reset (e.g., for StarCraft).
Use this cautiously, overheads are significant!
remote_env_batch_wait_ms: Timeout that remote workers
are waiting when polling environments. 0 (continue when at
least one env is ready) is a reasonable default, but optimal
value could be obtained by measuring your environment
step / reset and model inference perf.
soft_horizon: Calculate rewards but don't reset the
environment when the horizon is hit.
no_done_at_end: Ignore the done=True at the end of the
episode and instead record done=False.
seed: Set the seed of both np and tf to this value to ensure
each remote worker has unique exploration behavior.
extra_python_environs: Extra Python environment variables to be set.
fake_sampler: Use a fake (inf speed) sampler for testing.
spaces: An optional space dict mapping policy IDs
to (obs_space, action_space)-tuples. This is used in case no
Env is created on this RolloutWorker.
policy: Obsoleted arg. Use `policy_spec` instead.
monitor_path: Obsoleted arg. Use `record_env` instead.
"""
# Deprecated args.
if policy is not None:
deprecation_warning("policy", "policy_spec", error=False)
policy_spec = policy
assert policy_spec is not None, \
"Must provide `policy_spec` when creating RolloutWorker!"
# Do quick translation into MultiAgentPolicyConfigDict.
if not isinstance(policy_spec, dict):
policy_spec = {
DEFAULT_POLICY_ID: PolicySpec(policy_class=policy_spec)
}
policy_spec = {
pid: spec if isinstance(spec, PolicySpec) else PolicySpec(*spec)
for pid, spec in policy_spec.copy().items()
}
if monitor_path is not None:
deprecation_warning("monitor_path", "record_env", error=False)
record_env = monitor_path
self._original_kwargs: dict = locals().copy()
del self._original_kwargs["self"]
global _global_worker
_global_worker = self
# set extra environs first
if extra_python_environs:
for key, value in extra_python_environs.items():
os.environ[key] = str(value)
def gen_rollouts():
while True:
yield self.sample()
ParallelIteratorWorker.__init__(self, gen_rollouts, False)
policy_config = policy_config or {}
if (tf1 and policy_config.get("framework") in ["tf2", "tfe"]
# This eager check is necessary for certain all-framework tests
# that use tf's eager_mode() context generator.
and not tf1.executing_eagerly()):
tf1.enable_eager_execution()
if log_level:
logging.getLogger("ray.rllib").setLevel(log_level)
if worker_index > 1:
disable_log_once_globally() # only need 1 worker to log
elif log_level == "DEBUG":
enable_periodic_logging()
env_context = EnvContext(
env_config or {},
worker_index=worker_index,
vector_index=0,
num_workers=num_workers,
)
self.env_context = env_context
self.policy_config: PartialTrainerConfigDict = policy_config
if callbacks:
self.callbacks: "DefaultCallbacks" = callbacks()
else:
from ray.rllib.agents.callbacks import DefaultCallbacks # noqa
self.callbacks: DefaultCallbacks = DefaultCallbacks()
self.worker_index: int = worker_index
self.num_workers: int = num_workers
model_config: ModelConfigDict = \
model_config or self.policy_config.get("model") or {}
# Default policy mapping fn is to always return DEFAULT_POLICY_ID,
# independent on the agent ID and the episode passed in.
self.policy_mapping_fn = \
lambda agent_id, episode, worker, **kwargs: DEFAULT_POLICY_ID
# If provided, set it here.
self.set_policy_mapping_fn(policy_mapping_fn)
self.env_creator: Callable[[EnvContext], EnvType] = env_creator
self.rollout_fragment_length: int = rollout_fragment_length * num_envs
self.count_steps_by: str = count_steps_by
self.batch_mode: str = batch_mode
self.compress_observations: bool = compress_observations
self.preprocessing_enabled: bool = False \
if policy_config.get("_disable_preprocessor_api") else True
self.observation_filter = observation_filter
self.last_batch: Optional[SampleBatchType] = None
self.global_vars: Optional[dict] = None
self.fake_sampler: bool = fake_sampler
# Update the global seed for numpy/random/tf-eager/torch if we are not
# the local worker, otherwise, this was already done in the Trainer
# object itself.
if self.worker_index > 0:
update_global_seed_if_necessary(
policy_config.get("framework"), seed)
# A single environment provided by the user (via config.env). This may
# also remain None.
# 1) Create the env using the user provided env_creator. This may
# return a gym.Env (incl. MultiAgentEnv), an already vectorized
# VectorEnv, BaseEnv, ExternalEnv, or an ActorHandle (remote env).
# 2) Wrap - if applicable - with Atari/recording/rendering wrappers.
# 3) Seed the env, if necessary.
# 4) Vectorize the existing single env by creating more clones of
# this env and wrapping it with the RLlib BaseEnv class.
self.env = None
# Create a (single) env for this worker.
if not (worker_index == 0 and num_workers > 0
and not policy_config.get("create_env_on_driver")):
# Run the `env_creator` function passing the EnvContext.
self.env = env_creator(copy.deepcopy(self.env_context))
if self.env is not None:
# Validate environment (general validation function).
_validate_env(self.env, env_context=self.env_context)
# Custom validation function given.
if validate_env is not None:
validate_env(self.env, self.env_context)
# We can't auto-wrap a BaseEnv.
if isinstance(self.env, (BaseEnv, ray.actor.ActorHandle)):
def wrap(env):
return env
# Atari type env and "deepmind" preprocessor pref.
elif is_atari(self.env) and \
not model_config.get("custom_preprocessor") and \
preprocessor_pref == "deepmind":
# Deepmind wrappers already handle all preprocessing.
self.preprocessing_enabled = False
# If clip_rewards not explicitly set to False, switch it
# on here (clip between -1.0 and 1.0).
if clip_rewards is None:
clip_rewards = True
# Framestacking is used.
use_framestack = model_config.get("framestack") is True
def wrap(env):
env = wrap_deepmind(
env,
dim=model_config.get("dim"),
framestack=use_framestack)
env = record_env_wrapper(env, record_env, log_dir,
policy_config)
return env
# gym.Env -> Wrap with gym Monitor.
else:
def wrap(env):
return record_env_wrapper(env, record_env, log_dir,
policy_config)
# Wrap env through the correct wrapper.
self.env: EnvType = wrap(self.env)
# Ideally, we would use the same make_sub_env() function below
# to create self.env, but wrap(env) and self.env has a cyclic
# dependency on each other right now, so we would settle on
# duplicating the random seed setting logic for now.
_update_env_seed_if_necessary(self.env, seed, worker_index, 0)
def make_sub_env(vector_index):
# Used to created additional environments during environment
# vectorization.
# Create the env context (config dict + meta-data) for
# this particular sub-env within the vectorized one.
env_ctx = env_context.copy_with_overrides(
worker_index=worker_index,
vector_index=vector_index,
remote=remote_worker_envs)
# Create the sub-env.
env = env_creator(env_ctx)
# Validate first.
_validate_env(env, env_context=env_ctx)
# Custom validation function given by user.
if validate_env is not None:
validate_env(env, env_ctx)
# Use our wrapper, defined above.
env = wrap(env)
# Make sure a deterministic random seed is set on
# all the sub-environments if specified.
_update_env_seed_if_necessary(env, seed, worker_index,
vector_index)
return env
self.make_sub_env_fn = make_sub_env
self.spaces = spaces
policy_dict = _determine_spaces_for_multi_agent_dict(
policy_spec,
self.env,
spaces=self.spaces,
policy_config=policy_config)
# List of IDs of those policies, which should be trained.
# By default, these are all policies found in the policy_dict.
self.policies_to_train: List[PolicyID] = policies_to_train or list(
policy_dict.keys())
self.set_policies_to_train(self.policies_to_train)
self.policy_map: PolicyMap = None
self.preprocessors: Dict[PolicyID, Preprocessor] = None
# Check available number of GPUs.
num_gpus = policy_config.get("num_gpus", 0) if \
self.worker_index == 0 else \
policy_config.get("num_gpus_per_worker", 0)
# Error if we don't find enough GPUs.
if ray.is_initialized() and \
ray.worker._mode() != ray.worker.LOCAL_MODE and \
not policy_config.get("_fake_gpus"):
devices = []
if policy_config.get("framework") in ["tf2", "tf", "tfe"]:
devices = get_tf_gpu_devices()
elif policy_config.get("framework") == "torch":
devices = list(range(torch.cuda.device_count()))
if len(devices) < num_gpus:
raise RuntimeError(
ERR_MSG_NO_GPUS.format(len(devices), devices) +
HOWTO_CHANGE_CONFIG)
# Warn, if running in local-mode and actual GPUs (not faked) are
# requested.
elif ray.is_initialized() and \
ray.worker._mode() == ray.worker.LOCAL_MODE and \
num_gpus > 0 and not policy_config.get("_fake_gpus"):
logger.warning(
"You are running ray with `local_mode=True`, but have "
f"configured {num_gpus} GPUs to be used! In local mode, "
f"Policies are placed on the CPU and the `num_gpus` setting "
f"is ignored.")
self._build_policy_map(
policy_dict,
policy_config,
session_creator=tf_session_creator,
seed=seed)
# Update Policy's view requirements from Model, only if Policy directly
# inherited from base `Policy` class. At this point here, the Policy
must have its Model (if any) defined and ready to output an initial
# state.
for pol in self.policy_map.values():
if not pol._model_init_state_automatically_added:
pol._update_model_view_requirements_from_init_state()
self.multiagent: bool = set(
self.policy_map.keys()) != {DEFAULT_POLICY_ID}
if self.multiagent and self.env is not None:
if not isinstance(self.env,
(BaseEnv, ExternalMultiAgentEnv, MultiAgentEnv,
ray.actor.ActorHandle)):
raise ValueError(
f"Have multiple policies {self.policy_map}, but the "
f"env {self.env} is not a subclass of BaseEnv, "
f"MultiAgentEnv, ActorHandle, or ExternalMultiAgentEnv!")
self.filters: Dict[PolicyID, Filter] = {
policy_id: get_filter(self.observation_filter,
policy.observation_space.shape)
for (policy_id, policy) in self.policy_map.items()
}
if self.worker_index == 0:
logger.info("Built filter map: {}".format(self.filters))
# Vectorize environment, if any.
self.num_envs: int = num_envs
# This RolloutWorker has no env.
if self.env is None:
self.async_env = None
# Use a custom env-vectorizer and call it providing self.env.
elif "custom_vector_env" in policy_config:
self.async_env = policy_config["custom_vector_env"](self.env)
# Default: Vectorize self.env via the make_sub_env function. This adds
# further clones of self.env and creates a RLlib BaseEnv (which is
# vectorized under the hood).
else:
# Always use vector env for consistency even if num_envs = 1.
self.async_env: BaseEnv = BaseEnv.to_base_env(
self.env,
make_env=self.make_sub_env_fn,
num_envs=num_envs,
remote_envs=remote_worker_envs,
remote_env_batch_wait_ms=remote_env_batch_wait_ms,
policy_config=policy_config,
)
# `truncate_episodes`: Allow a batch to contain more than one episode
# (fragments) and always make the batch `rollout_fragment_length`
# long.
if self.batch_mode == "truncate_episodes":
pack = True
# `complete_episodes`: Never cut episodes and sampler will return
# exactly one (complete) episode per poll.
elif self.batch_mode == "complete_episodes":
rollout_fragment_length = float("inf")
pack = False
else:
raise ValueError("Unsupported batch mode: {}".format(
self.batch_mode))
# Create the IOContext for this worker.
self.io_context: IOContext = IOContext(log_dir, policy_config,
worker_index, self)
self.reward_estimators: List[OffPolicyEstimator] = []
for method in input_evaluation:
if method == "simulation":
logger.warning(
"Requested 'simulation' input evaluation method: "
"will discard all sampler outputs and keep only metrics.")
sample_async = True
elif method == "is":
ise = ImportanceSamplingEstimator.\
create_from_io_context(self.io_context)
self.reward_estimators.append(ise)
elif method == "wis":
wise = WeightedImportanceSamplingEstimator.\
create_from_io_context(self.io_context)
self.reward_estimators.append(wise)
else:
raise ValueError(
"Unknown evaluation method: {}".format(method))
render = False
if policy_config.get("render_env") is True and \
(num_workers == 0 or worker_index == 1):
render = True
if self.env is None:
self.sampler = None
elif sample_async:
self.sampler = AsyncSampler(
worker=self,
env=self.async_env,
clip_rewards=clip_rewards,
rollout_fragment_length=rollout_fragment_length,
count_steps_by=count_steps_by,
callbacks=self.callbacks,
horizon=episode_horizon,
multiple_episodes_in_batch=pack,
normalize_actions=normalize_actions,
clip_actions=clip_actions,
blackhole_outputs="simulation" in input_evaluation,
soft_horizon=soft_horizon,
no_done_at_end=no_done_at_end,
observation_fn=observation_fn,
sample_collector_class=policy_config.get("sample_collector"),
render=render,
)
# Start the Sampler thread.
self.sampler.start()
else:
self.sampler = SyncSampler(
worker=self,
env=self.async_env,
clip_rewards=clip_rewards,
rollout_fragment_length=rollout_fragment_length,
count_steps_by=count_steps_by,
callbacks=self.callbacks,
horizon=episode_horizon,
multiple_episodes_in_batch=pack,
normalize_actions=normalize_actions,
clip_actions=clip_actions,
soft_horizon=soft_horizon,
no_done_at_end=no_done_at_end,
observation_fn=observation_fn,
sample_collector_class=policy_config.get("sample_collector"),
render=render,
)
self.input_reader: InputReader = input_creator(self.io_context)
self.output_writer: OutputWriter = output_creator(self.io_context)
logger.debug(
"Created rollout worker with env {} ({}), policies {}".format(
self.async_env, self.env, self.policy_map))
add_policy(self, *, policy_id, policy_cls, observation_space=None, action_space=None, config=None, policy_mapping_fn=None, policies_to_train=None)
Adds a new policy to this RolloutWorker.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
policy_id | str | ID of the policy to add. | required |
policy_cls | Type[ray.rllib.policy.policy.Policy] | The Policy class to use for constructing the new Policy. | required |
observation_space | Optional[gym.spaces.space.Space] | The observation space of the policy to add. | None |
action_space | Optional[gym.spaces.space.Space] | The action space of the policy to add. | None |
config | Optional[dict] | The config overrides for the policy to add. | None |
policy_config | | The base config of the Trainer object owning this RolloutWorker. | required |
policy_mapping_fn | Optional[Callable[[Any, Episode], str]] | An optional (updated) policy mapping function to use from here on. Note that already ongoing episodes will not change their mapping but will use the old mapping till the end of the episode. | None |
policies_to_train | Optional[List[str]] | An optional list of policy IDs to be trained. If None, will keep the existing list in place. Policies whose IDs are not in the list will not be updated. | None |
Returns:
Type | Description |
---|---|
Policy | The newly added policy. |
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def add_policy(
self,
*,
policy_id: PolicyID,
policy_cls: Type[Policy],
observation_space: Optional[gym.spaces.Space] = None,
action_space: Optional[gym.spaces.Space] = None,
config: Optional[PartialTrainerConfigDict] = None,
policy_mapping_fn: Optional[Callable[[AgentID, "Episode"],
PolicyID]] = None,
policies_to_train: Optional[List[PolicyID]] = None,
) -> Policy:
"""Adds a new policy to this RolloutWorker.
Args:
policy_id: ID of the policy to add.
policy_cls: The Policy class to use for constructing the new
Policy.
observation_space: The observation space of the policy to add.
action_space: The action space of the policy to add.
config: The config overrides for the policy to add.
policy_config: The base config of the Trainer object owning this
RolloutWorker.
policy_mapping_fn: An optional (updated) policy mapping function
to use from here on. Note that already ongoing episodes will
not change their mapping but will use the old mapping till
the end of the episode.
policies_to_train: An optional list of policy IDs to be trained.
If None, will keep the existing list in place. Policies,
whose IDs are not in the list will not be updated.
Returns:
The newly added policy.
"""
if policy_id in self.policy_map:
raise ValueError(f"Policy ID '{policy_id}' already in policy map!")
policy_dict = _determine_spaces_for_multi_agent_dict(
{
policy_id: PolicySpec(policy_cls, observation_space,
action_space, config or {})
},
self.env,
spaces=self.spaces,
policy_config=self.policy_config,
)
self._build_policy_map(
policy_dict,
self.policy_config,
seed=self.policy_config.get("seed"))
new_policy = self.policy_map[policy_id]
self.filters[policy_id] = get_filter(
self.observation_filter, new_policy.observation_space.shape)
self.set_policy_mapping_fn(policy_mapping_fn)
self.set_policies_to_train(policies_to_train)
return new_policy
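A usage sketch (the spaces, config values, policy IDs, and the PGTFPolicy class from the class-level example above are illustrative; any Policy class matching your framework and env works):
>>> from gym.spaces import Box, Discrete
>>> new_pol = worker.add_policy(
...     policy_id="car_policy3",
...     policy_cls=PGTFPolicy,
...     observation_space=Box(-1.0, 1.0, (4,)),
...     action_space=Discrete(2),
...     config={"gamma": 0.9},
...     policy_mapping_fn=lambda agent_id, episode, **kw: "car_policy3",
... )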
apply(self, func, *args)
Calls the given function with this rollout worker instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
func | Callable[[RolloutWorker, Optional[Any]], ~T] | The function to call with this RolloutWorker as first argument. | required |
args | | Optional additional args to pass to the function call. | () |
Returns:
Type | Description |
---|---|
~T | The return value of the function call. |
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def apply(self, func: Callable[["RolloutWorker", Optional[Any]], T],
*args) -> T:
"""Calls the given function with this rollout worker instance.
Args:
func: The function to call with this RolloutWorker as first
argument.
args: Optional additional args to pass to the function call.
Returns:
The return value of the function call.
"""
return func(self, *args)
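For example (assuming the worker holds an environment and can sample):
>>> batch = worker.apply(lambda w: w.sample())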
apply_gradients(self, grads)
Applies the given gradients to this worker's models.
Uses the Policy's/ies' apply_gradients method(s) to perform the operations.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
grads | Union[List[Tuple[Any, Any]], List[Any], Dict[str, Union[List[Tuple[Any, Any]], List[Any]]]] | Single ModelGradients (single-agent case) or a dict mapping PolicyIDs to the respective model gradients structs. | required |
Examples:
>>> samples = worker.sample()
>>> grads, info = worker.compute_gradients(samples)
>>> worker.apply_gradients(grads)
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def apply_gradients(
self,
grads: Union[ModelGradients, Dict[PolicyID, ModelGradients]],
) -> None:
"""Applies the given gradients to this worker's models.
Uses the Policy's/ies' apply_gradients method(s) to perform the
operations.
Args:
grads: Single ModelGradients (single-agent case) or a dict
mapping PolicyIDs to the respective model gradients
structs.
Examples:
>>> samples = worker.sample()
>>> grads, info = worker.compute_gradients(samples)
>>> worker.apply_gradients(grads)
"""
if log_once("apply_gradients"):
logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))
# Grads is a dict (mapping PolicyIDs to ModelGradients).
# Multi-agent case.
if isinstance(grads, dict):
for pid, g in grads.items():
if pid in self.policies_to_train:
self.policy_map[pid].apply_gradients(g)
# Grads is a ModelGradients type. Single-agent case.
elif DEFAULT_POLICY_ID in self.policies_to_train:
self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)
as_remote(num_cpus=None, num_gpus=None, memory=None, object_store_memory=None, resources=None)
classmethod
Returns RolloutWorker class as a @ray.remote using given options.
The returned class can then be used to instantiate ray actors.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
num_cpus | Optional[int] | The number of CPUs to allocate for the remote actor. | None |
num_gpus | Union[int, float] | The number of GPUs to allocate for the remote actor. This could be a fraction as well. | None |
memory | Optional[int] | The heap memory request for the remote actor. | None |
object_store_memory | Optional[int] | The object store memory for the remote actor. | None |
resources | Optional[dict] | The default custom resources to allocate for the remote actor. | None |
Returns:
Type | Description |
---|---|
type | The @ray.remote decorated RolloutWorker class. |
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
@classmethod
def as_remote(cls,
num_cpus: Optional[int] = None,
num_gpus: Optional[Union[int, float]] = None,
memory: Optional[int] = None,
object_store_memory: Optional[int] = None,
resources: Optional[dict] = None) -> type:
"""Returns RolloutWorker class as a `@ray.remote using given options`.
The returned class can then be used to instantiate ray actors.
Args:
num_cpus: The number of CPUs to allocate for the remote actor.
num_gpus: The number of GPUs to allocate for the remote actor.
This could be a fraction as well.
memory: The heap memory request for the remote actor.
object_store_memory: The object store memory for the remote actor.
resources: The default custom resources to allocate for the remote
actor.
Returns:
The `@ray.remote` decorated RolloutWorker class.
"""
return ray.remote(
num_cpus=num_cpus,
num_gpus=num_gpus,
memory=memory,
object_store_memory=object_store_memory,
resources=resources)(cls)
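A usage sketch (the env and policy choices mirror the class-level example above and are illustrative):
>>> RemoteRolloutWorker = RolloutWorker.as_remote(num_cpus=1)
>>> remote_worker = RemoteRolloutWorker.remote(
...     env_creator=lambda _: gym.make("CartPole-v0"),
...     policy_spec=PGTFPolicy)
>>> batch = ray.get(remote_worker.sample.remote())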
compute_gradients(self, samples)
Returns a gradient computed w.r.t the specified samples.
Uses the Policy's/ies' compute_gradients method(s) to perform the calculations.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
samples | Union[SampleBatch, MultiAgentBatch] | The SampleBatch or MultiAgentBatch to compute gradients for using this worker's policies. | required |
Returns:
Type | Description |
---|---|
Tuple[Union[List[Tuple[Any, Any]], List[Any]], dict] | In the single-agent case, a tuple consisting of ModelGradients and an info dict of the worker's policy. In the multi-agent case, a tuple consisting of a dict mapping PolicyID to ModelGradients and a dict mapping PolicyID to extra metadata info. Note that the first return value (grads) can be applied as is to a compatible worker using the worker's apply_gradients() method. |
Examples:
>>> batch = worker.sample()
>>> grads, info = worker.compute_gradients(batch)
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def compute_gradients(
self, samples: SampleBatchType) -> Tuple[ModelGradients, dict]:
"""Returns a gradient computed w.r.t the specified samples.
Uses the Policy's/ies' compute_gradients method(s) to perform the
calculations.
Args:
samples: The SampleBatch or MultiAgentBatch to compute gradients
for using this worker's policies.
Returns:
In the single-agent case, a tuple consisting of ModelGradients and
info dict of the worker's policy.
In the multi-agent case, a tuple consisting of a dict mapping
PolicyID to ModelGradients and a dict mapping PolicyID to extra
metadata info.
Note that the first return value (grads) can be applied as is to a
compatible worker using the worker's `apply_gradients()` method.
Examples:
>>> batch = worker.sample()
>>> grads, info = worker.compute_gradients(batch)
"""
if log_once("compute_gradients"):
logger.info("Compute gradients on:\n\n{}\n".format(
summarize(samples)))
# MultiAgentBatch -> Calculate gradients for all policies.
if isinstance(samples, MultiAgentBatch):
grad_out, info_out = {}, {}
if self.policy_config.get("framework") == "tf":
for pid, batch in samples.policy_batches.items():
if pid not in self.policies_to_train:
continue
policy = self.policy_map[pid]
builder = TFRunBuilder(policy.get_session(),
"compute_gradients")
grad_out[pid], info_out[pid] = (
policy._build_compute_gradients(builder, batch))
grad_out = {k: builder.get(v) for k, v in grad_out.items()}
info_out = {k: builder.get(v) for k, v in info_out.items()}
else:
for pid, batch in samples.policy_batches.items():
if pid not in self.policies_to_train:
continue
grad_out[pid], info_out[pid] = (
self.policy_map[pid].compute_gradients(batch))
# SampleBatch -> Calculate gradients for the default policy.
else:
grad_out, info_out = (
self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples))
info_out["batch_count"] = samples.count
if log_once("grad_out"):
logger.info("Compute grad info:\n\n{}\n".format(
summarize(info_out)))
return grad_out, info_out
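As noted above, the computed gradients can be applied to a compatible worker. A minimal sketch, assuming `worker` is an already constructed RolloutWorker:
>>> batch = worker.sample()
>>> grads, info = worker.compute_gradients(batch)
>>> # Apply the gradients back to this (or another compatible) worker's policies.
>>> worker.apply_gradients(grads)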
creation_args(self)
find_free_port(self)
for_policy(self, func, policy_id='default_policy', **kwargs)
Calls the given function with the specified policy as first arg.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
func |
Callable[[ray.rllib.policy.policy.Policy, Optional[Any]], ~T] |
The function to call with the policy as first arg. |
required |
policy_id |
Optional[str] |
The PolicyID of the policy to call the function with. |
'default_policy' |
Returns:
Type | Description |
---|---|
~T |
The return value of the function call. |
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def for_policy(self,
func: Callable[[Policy, Optional[Any]], T],
policy_id: Optional[PolicyID] = DEFAULT_POLICY_ID,
**kwargs) -> T:
"""Calls the given function with the specified policy as first arg.
Args:
func: The function to call with the policy as first arg.
policy_id: The PolicyID of the policy to call the function with.
Keyword Args:
kwargs: Additional kwargs to be passed to the call.
Returns:
The return value of the function call.
"""
return func(self.policy_map[policy_id], **kwargs)
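A minimal sketch, assuming `worker` is an already constructed RolloutWorker holding a default policy:
>>> # Query the default policy's weights through a custom function.
>>> weights = worker.for_policy(lambda policy: policy.get_weights())
>>> # Extra kwargs are forwarded to the function after the policy arg.
>>> worker.for_policy(lambda policy, scale: scale, scale=2.0)
2.0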
foreach_env(self, func)
Calls the given function with each sub-environment as arg.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
func |
Callable[[Any], ~T] |
The function to call for each underlying sub-environment (as only arg). |
required |
Returns:
Type | Description |
---|---|
List[~T] |
The list of return values of all calls to `func([env])`. |
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def foreach_env(self, func: Callable[[EnvType], T]) -> List[T]:
"""Calls the given function with each sub-environment as arg.
Args:
func: The function to call for each underlying
sub-environment (as only arg).
Returns:
The list of return values of all calls to `func([env])`.
"""
if self.async_env is None:
return []
envs = self.async_env.get_sub_environments()
# Empty list (not implemented): Call function directly on the
# BaseEnv.
if not envs:
return [func(self.async_env)]
# Call function on all underlying (vectorized) sub environments.
else:
return [func(e) for e in envs]
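A minimal sketch, assuming `worker` wraps one or more gym-style sub-environments:
>>> # Collect the action space of every sub-environment.
>>> spaces = worker.foreach_env(lambda env: env.action_space)
>>> # Reset all sub-environments (return values are collected in a list).
>>> _ = worker.foreach_env(lambda env: env.reset())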
foreach_env_with_context(self, func)
Calls given function with each sub-env plus env_ctx as args.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
func |
Callable[[Any, ray.rllib.env.env_context.EnvContext], ~T] |
The function to call for each underlying sub-environment and its EnvContext (as the args). |
required |
Returns:
Type | Description |
---|---|
List[~T] |
The list of return values of all calls to `func([env, ctx])`. |
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def foreach_env_with_context(
self, func: Callable[[EnvType, EnvContext], T]) -> List[T]:
"""Calls given function with each sub-env plus env_ctx as args.
Args:
func: The function to call for each underlying
sub-environment and its EnvContext (as the args).
Returns:
The list of return values of all calls to `func([env, ctx])`.
"""
if self.async_env is None:
return []
envs = self.async_env.get_sub_environments()
# Empty list (not implemented): Call function directly on the
# BaseEnv.
if not envs:
return [func(self.async_env, self.env_context)]
# Call function on all underlying (vectorized) sub environments.
else:
ret = []
for i, e in enumerate(envs):
ctx = self.env_context.copy_with_overrides(vector_index=i)
ret.append(func(e, ctx))
return ret
foreach_policy(self, func, **kwargs)
Calls the given function with each (policy, policy_id) tuple.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
func |
Callable[[ray.rllib.policy.policy.Policy, str, Optional[Any]], ~T] |
The function to call with each (policy, policy ID) tuple. |
required |
Returns:
Type | Description |
---|---|
List[~T] |
The list of return values of all calls to `func([policy, pid, **kwargs])`. |
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def foreach_policy(self,
func: Callable[[Policy, PolicyID, Optional[Any]], T],
**kwargs) -> List[T]:
"""Calls the given function with each (policy, policy_id) tuple.
Args:
func: The function to call with each (policy, policy ID) tuple.
Keyword Args:
kwargs: Additional kwargs to be passed to the call.
Returns:
The list of return values of all calls to
`func([policy, pid, **kwargs])`.
"""
return [
func(policy, pid, **kwargs)
for pid, policy in self.policy_map.items()
]
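A minimal sketch, assuming a single-agent setup with only the default policy:
>>> # List all policy IDs held by this worker.
>>> worker.foreach_policy(lambda policy, pid: pid)
['default_policy']
>>> # Collect each policy's weights keyed by policy ID.
>>> all_weights = dict(
...     worker.foreach_policy(lambda policy, pid: (pid, policy.get_weights())))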
foreach_trainable_policy(self, func, **kwargs)
Calls the given function with each (policy, policy_id) tuple.
Only those policies whose IDs can be found in `self.policies_to_train` will be called on.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
func |
Callable[[ray.rllib.policy.policy.Policy, str, Optional[Any]], ~T] |
The function to call with each (policy, policy ID) tuple, for only those policies that are in `self.policies_to_train`. |
required |
Returns:
Type | Description |
---|---|
List[~T] |
The list of return values of all calls to `func([policy, pid, **kwargs])`. |
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def foreach_trainable_policy(
self, func: Callable[[Policy, PolicyID, Optional[Any]], T],
**kwargs) -> List[T]:
"""
Calls the given function with each (policy, policy_id) tuple.
Only those policies/IDs will be called on, which can be found in
`self.policies_to_train`.
Args:
func: The function to call with each (policy, policy ID) tuple,
for only those policies that are in `self.policies_to_train`.
Keyword Args:
kwargs: Additional kwargs to be passed to the call.
Returns:
The list of return values of all calls to
`func([policy, pid, **kwargs])`.
"""
return [
func(policy, pid, **kwargs)
for pid, policy in self.policy_map.items()
if pid in self.policies_to_train
]
get_filters(self, flush_after=False)
Returns a snapshot of filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
flush_after |
bool |
Clears the filter buffer state. |
False |
Returns:
Type | Description |
---|---|
Dict |
Dict for serializable filters |
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def get_filters(self, flush_after: bool = False) -> Dict:
"""Returns a snapshot of filters.
Args:
flush_after: Clears the filter buffer state.
Returns:
Dict for serializable filters
"""
return_filters = {}
for k, f in self.filters.items():
return_filters[k] = f.as_serializable()
if flush_after:
f.clear_buffer()
return return_filters
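A minimal sketch of snapshotting and re-syncing filters, assuming `worker` is an already constructed RolloutWorker:
>>> # Snapshot the current observation filters and clear their buffers.
>>> filters = worker.get_filters(flush_after=True)
>>> # Later, push the (possibly aggregated) filter state back to the worker.
>>> worker.sync_filters(filters)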
get_global_vars(self)
Returns the current global_vars dict of this worker.
Returns:
Type | Description |
---|---|
dict |
The current global_vars dict of this worker. |
Examples:
Source code in ray/rllib/evaluation/rollout_worker.py
get_host(self)
get_metrics(self)
Returns the thus-far collected metrics from this worker's rollouts.
Returns:
Type | Description |
---|---|
List[Union[ray.rllib.evaluation.metrics.RolloutMetrics, ray.rllib.offline.off_policy_estimator.OffPolicyEstimate]] |
List of RolloutMetrics and/or OffPolicyEstimate objects collected thus-far. |
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def get_metrics(self) -> List[Union[RolloutMetrics, OffPolicyEstimate]]:
"""Returns the thus-far collected metrics from this worker's rollouts.
Returns:
List of RolloutMetrics and/or OffPolicyEstimate objects
collected thus-far.
"""
# Get metrics from sampler (if any).
if self.sampler is not None:
out = self.sampler.get_metrics()
else:
out = []
# Get metrics from our reward-estimators (if any).
for m in self.reward_estimators:
out.extend(m.get_metrics())
return out
get_node_ip(self)
get_policy(self, policy_id='default_policy')
Return policy for the specified id, or None.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
policy_id |
str |
ID of the policy to return. None for DEFAULT_POLICY_ID (in the single agent case). |
'default_policy' |
Returns:
Type | Description |
---|---|
Optional[ray.rllib.policy.policy.Policy] |
The policy under the given ID (or None if not found). |
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def get_policy(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> \
Optional[Policy]:
"""Return policy for the specified id, or None.
Args:
policy_id: ID of the policy to return. None for DEFAULT_POLICY_ID
(in the single agent case).
Returns:
The policy under the given ID (or None if not found).
"""
return self.policy_map.get(policy_id)
get_weights(self, policies=None)
Returns each policy's model weights of this worker.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
policies |
Optional[List[str]] |
List of PolicyIDs to get the weights from. Use None for all policies. |
None |
Returns:
Type | Description |
---|---|
Dict[str, dict] |
Dict mapping PolicyIDs to ModelWeights. |
Examples:
>>> weights = worker.get_weights()
>>> print(weights)
{"default_policy": {"layer1": array(...), "layer2": ...}}
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def get_weights(
self,
policies: Optional[List[PolicyID]] = None,
) -> Dict[PolicyID, ModelWeights]:
"""Returns each policies' model weights of this worker.
Args:
policies: List of PolicyIDs to get the weights from.
Use None for all policies.
Returns:
Dict mapping PolicyIDs to ModelWeights.
Examples:
>>> weights = worker.get_weights()
>>> print(weights)
{"default_policy": {"layer1": array(...), "layer2": ...}}
"""
if policies is None:
policies = list(self.policy_map.keys())
policies = force_list(policies)
return {
pid: policy.get_weights()
for pid, policy in self.policy_map.items() if pid in policies
}
learn_on_batch(self, samples)
Update policies based on the given batch.
This is equivalent to apply_gradients(compute_gradients(samples)), but can be optimized to avoid pulling gradients into CPU memory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
samples |
Union[SampleBatch, MultiAgentBatch] |
The SampleBatch or MultiAgentBatch to learn on. |
required |
Returns:
Type | Description |
---|---|
Dict |
Dictionary of extra metadata from compute_gradients(). |
Examples:
>>> batch = worker.sample()
>>> info = worker.learn_on_batch(batch)
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def learn_on_batch(self, samples: SampleBatchType) -> Dict:
"""Update policies based on the given batch.
This is the equivalent to apply_gradients(compute_gradients(samples)),
but can be optimized to avoid pulling gradients into CPU memory.
Args:
samples: The SampleBatch or MultiAgentBatch to learn on.
Returns:
Dictionary of extra metadata from compute_gradients().
Examples:
>>> batch = worker.sample()
>>> info = worker.learn_on_batch(batch)
"""
if log_once("learn_on_batch"):
logger.info(
"Training on concatenated sample batches:\n\n{}\n".format(
summarize(samples)))
if isinstance(samples, MultiAgentBatch):
info_out = {}
builders = {}
to_fetch = {}
for pid, batch in samples.policy_batches.items():
if pid not in self.policies_to_train:
continue
# Decompress SampleBatch, in case some columns are compressed.
batch.decompress_if_needed()
policy = self.policy_map[pid]
tf_session = policy.get_session()
if tf_session and hasattr(policy, "_build_learn_on_batch"):
builders[pid] = TFRunBuilder(tf_session, "learn_on_batch")
to_fetch[pid] = policy._build_learn_on_batch(
builders[pid], batch)
else:
info_out[pid] = policy.learn_on_batch(batch)
info_out.update(
{pid: builders[pid].get(v)
for pid, v in to_fetch.items()})
else:
info_out = {
DEFAULT_POLICY_ID: self.policy_map[DEFAULT_POLICY_ID]
.learn_on_batch(samples)
}
if log_once("learn_out"):
logger.debug("Training out:\n\n{}\n".format(summarize(info_out)))
return info_out
remove_policy(self, *, policy_id='default_policy', policy_mapping_fn=None, policies_to_train=None)
Removes a policy from this RolloutWorker.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
policy_id |
str |
ID of the policy to be removed. None for DEFAULT_POLICY_ID. |
'default_policy' |
policy_mapping_fn |
Optional[Callable[[Any], str]] |
An optional (updated) policy mapping function to use from here on. Note that already ongoing episodes will not change their mapping but will use the old mapping till the end of the episode. |
None |
policies_to_train |
Optional[List[str]] |
An optional list of policy IDs to be trained. If None, will keep the existing list in place. Policies whose IDs are not in the list will not be updated. |
None |
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def remove_policy(
self,
*,
policy_id: PolicyID = DEFAULT_POLICY_ID,
policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
policies_to_train: Optional[List[PolicyID]] = None,
) -> None:
"""Removes a policy from this RolloutWorker.
Args:
policy_id: ID of the policy to be removed. None for
DEFAULT_POLICY_ID.
policy_mapping_fn: An optional (updated) policy mapping function
to use from here on. Note that already ongoing episodes will
not change their mapping but will use the old mapping till
the end of the episode.
policies_to_train: An optional list of policy IDs to be trained.
If None, will keep the existing list in place. Policies,
whose IDs are not in the list will not be updated.
"""
if policy_id not in self.policy_map:
raise ValueError(f"Policy ID '{policy_id}' not in policy map!")
del self.policy_map[policy_id]
del self.preprocessors[policy_id]
self.set_policy_mapping_fn(policy_mapping_fn)
self.set_policies_to_train(policies_to_train)
restore(self, objs)
Restores this RolloutWorker's state from a sequence of bytes.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
objs |
bytes |
The byte sequence to restore this worker's state from. |
required |
Examples:
>>> state = worker.save()
>>> new_worker = RolloutWorker(...)
>>> new_worker.restore(state)
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def restore(self, objs: bytes) -> None:
"""Restores this RolloutWorker's state from a sequence of bytes.
Args:
objs: The byte sequence to restore this worker's state from.
Examples:
>>> state = worker.save()
>>> new_worker = RolloutWorker(...)
>>> new_worker.restore(state)
"""
objs = pickle.loads(objs)
self.sync_filters(objs["filters"])
for pid, state in objs["state"].items():
if pid not in self.policy_map:
pol_spec = objs.get("policy_specs", {}).get(pid)
if not pol_spec:
logger.warning(
f"PolicyID '{pid}' was probably added on-the-fly (not"
" part of the static `multagent.policies` config) and"
" no PolicySpec objects found in the pickled policy "
"state. Will not add `{pid}`, but ignore it for now.")
else:
self.add_policy(
policy_id=pid,
policy_cls=pol_spec.policy_class,
observation_space=pol_spec.observation_space,
action_space=pol_spec.action_space,
config=pol_spec.config,
)
else:
self.policy_map[pid].set_state(state)
sample(self)
Returns a batch of experience sampled from this worker.
This method must be implemented by subclasses.
Returns:
Type | Description |
---|---|
Union[SampleBatch, MultiAgentBatch] |
A columnar batch of experiences (e.g., tensors). |
Examples:
>>> print(worker.sample())
SampleBatch({"obs": [1, 2, 3], "action": [0, 1, 0], ...})
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def sample(self) -> SampleBatchType:
"""Returns a batch of experience sampled from this worker.
This method must be implemented by subclasses.
Returns:
A columnar batch of experiences (e.g., tensors).
Examples:
>>> print(worker.sample())
SampleBatch({"obs": [1, 2, 3], "action": [0, 1, 0], ...})
"""
if self.fake_sampler and self.last_batch is not None:
return self.last_batch
elif self.input_reader is None:
raise ValueError("RolloutWorker has no `input_reader` object! "
"Cannot call `sample()`. You can try setting "
"`create_env_on_driver` to True.")
if log_once("sample_start"):
logger.info("Generating sample batch of size {}".format(
self.rollout_fragment_length))
batches = [self.input_reader.next()]
steps_so_far = batches[0].count if \
self.count_steps_by == "env_steps" else \
batches[0].agent_steps()
# In truncate_episodes mode, never pull more than 1 batch per env.
# This avoids over-running the target batch size.
if self.batch_mode == "truncate_episodes":
max_batches = self.num_envs
else:
max_batches = float("inf")
while (steps_so_far < self.rollout_fragment_length
and len(batches) < max_batches):
batch = self.input_reader.next()
steps_so_far += batch.count if \
self.count_steps_by == "env_steps" else \
batch.agent_steps()
batches.append(batch)
batch = batches[0].concat_samples(batches) if len(batches) > 1 else \
batches[0]
self.callbacks.on_sample_end(worker=self, samples=batch)
# Always do writes prior to compression for consistency and to allow
# for better compression inside the writer.
self.output_writer.write(batch)
# Do off-policy estimation, if needed.
if self.reward_estimators:
for sub_batch in batch.split_by_episode():
for estimator in self.reward_estimators:
estimator.process(sub_batch)
if log_once("sample_end"):
logger.info("Completed sample batch:\n\n{}\n".format(
summarize(batch)))
if self.compress_observations:
batch.compress(bulk=self.compress_observations == "bulk")
if self.fake_sampler:
self.last_batch = batch
return batch
sample_and_learn(self, expected_batch_size, num_sgd_iter, sgd_minibatch_size, standardize_fields)
Sample a batch and learn on it.
This is typically used in combination with distributed allreduce.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
expected_batch_size |
int |
Expected number of samples to learn on. |
required |
num_sgd_iter |
int |
Number of SGD iterations. |
required |
sgd_minibatch_size |
int |
SGD minibatch size. |
required |
standardize_fields |
List[str] |
List of sample fields to normalize. |
required |
Returns:
Type | Description |
---|---|
Tuple[dict, int] |
A tuple consisting of a dictionary of extra metadata returned from the policies' `learn_on_batch()` and the number of samples learned on. |
Source code in ray/rllib/evaluation/rollout_worker.py
def sample_and_learn(self, expected_batch_size: int, num_sgd_iter: int,
sgd_minibatch_size: str,
standardize_fields: List[str]) -> Tuple[dict, int]:
"""Sample and batch and learn on it.
This is typically used in combination with distributed allreduce.
Args:
expected_batch_size: Expected number of samples to learn on.
num_sgd_iter: Number of SGD iterations.
sgd_minibatch_size: SGD minibatch size.
standardize_fields: List of sample fields to normalize.
Returns:
A tuple consisting of a dictionary of extra metadata returned from
the policies' `learn_on_batch()` and the number of samples
learned on.
"""
batch = self.sample()
assert batch.count == expected_batch_size, \
("Batch size possibly out of sync between workers, expected:",
expected_batch_size, "got:", batch.count)
logger.info("Executing distributed minibatch SGD "
"with epoch size {}, minibatch size {}".format(
batch.count, sgd_minibatch_size))
info = do_minibatch_sgd(batch, self.policy_map, self, num_sgd_iter,
sgd_minibatch_size, standardize_fields)
return info, batch.count
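A minimal sketch, assuming `worker` is configured so that `sample()` returns exactly `rollout_fragment_length` env steps (required by the batch-size assertion above); the SGD settings below are illustrative:
>>> info, count = worker.sample_and_learn(
...     expected_batch_size=worker.rollout_fragment_length,
...     num_sgd_iter=1,
...     sgd_minibatch_size=128,
...     standardize_fields=["advantages"])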
sample_with_count(self)
Same as sample() but returns the count as a separate value.
Returns:
Type | Description |
---|---|
Tuple[Union[SampleBatch, MultiAgentBatch], int] |
A columnar batch of experiences (e.g., tensors) and the size of the collected batch. |
Examples:
>>> print(worker.sample_with_count())
(SampleBatch({"obs": [1, 2, 3], "action": [0, 1, 0], ...}), 3)
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
@ray.method(num_returns=2)
def sample_with_count(self) -> Tuple[SampleBatchType, int]:
"""Same as sample() but returns the count as a separate value.
Returns:
A columnar batch of experiences (e.g., tensors) and the
size of the collected batch.
Examples:
>>> print(worker.sample_with_count())
(SampleBatch({"obs": [1, 2, 3], "action": [0, 1, 0], ...}), 3)
"""
batch = self.sample()
return batch, batch.count
save(self)
Serializes this RolloutWorker's current state and returns it.
Returns:
Type | Description |
---|---|
bytes |
The current state of this RolloutWorker as a serialized, pickled byte sequence. |
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def save(self) -> bytes:
"""Serializes this RolloutWorker's current state and returns it.
Returns:
The current state of this RolloutWorker as a serialized, pickled
byte sequence.
"""
filters = self.get_filters(flush_after=True)
state = {}
policy_specs = {}
for pid in self.policy_map:
state[pid] = self.policy_map[pid].get_state()
policy_specs[pid] = self.policy_map.policy_specs[pid]
return pickle.dumps({
"filters": filters,
"state": state,
"policy_specs": policy_specs,
})
set_global_vars(self, global_vars)
Updates this worker's and all its policies' global vars.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
global_vars |
dict |
The new global_vars dict. |
required |
Examples:
>>> global_vars = worker.set_global_vars({"timestep": 4242})
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def set_global_vars(self, global_vars: dict) -> None:
"""Updates this worker's and all its policies' global vars.
Args:
global_vars: The new global_vars dict.
Examples:
>>> global_vars = worker.set_global_vars({"timestep": 4242})
"""
self.foreach_policy(lambda p, _: p.on_global_var_update(global_vars))
self.global_vars = global_vars
set_policies_to_train(self, policies_to_train=None)
Sets `self.policies_to_train` to a new list of PolicyIDs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
policies_to_train |
Optional[List[str]] |
The new list of policy IDs to train with. If None, will keep the existing list in place. |
None |
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def set_policies_to_train(
self, policies_to_train: Optional[List[PolicyID]] = None) -> None:
"""Sets `self.policies_to_train` to a new list of PolicyIDs.
Args:
policies_to_train: The new list of policy IDs to train with.
If None, will keep the existing list in place.
"""
if policies_to_train is not None:
self.policies_to_train = policies_to_train
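A minimal sketch, assuming a multi-agent worker with hypothetical policy IDs "pol1" and "pol2":
>>> # Freeze "pol2" (e.g. for self-play) and only keep training "pol1".
>>> worker.set_policies_to_train(["pol1"])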
set_policy_mapping_fn(self, policy_mapping_fn=None)
Sets `self.policy_mapping_fn` to a new callable (if provided).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
policy_mapping_fn |
Optional[Callable[[Any, Episode], str]] |
The new mapping function to use. If None, will keep the existing mapping function in place. |
None |
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def set_policy_mapping_fn(
self,
policy_mapping_fn: Optional[Callable[[AgentID, "Episode"],
PolicyID]] = None,
) -> None:
"""Sets `self.policy_mapping_fn` to a new callable (if provided).
Args:
policy_mapping_fn: The new mapping function to use. If None,
will keep the existing mapping function in place.
"""
if policy_mapping_fn is not None:
self.policy_mapping_fn = policy_mapping_fn
if not callable(self.policy_mapping_fn):
raise ValueError("`policy_mapping_fn` must be a callable!")
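A minimal sketch, assuming integer agent IDs and the same hypothetical policy IDs "pol1" and "pol2" as above:
>>> worker.set_policy_mapping_fn(
...     lambda agent_id, episode, **kwargs:
...         "pol1" if agent_id % 2 == 0 else "pol2")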
set_weights(self, weights, global_vars=None)
Sets each policy's model weights of this worker.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
weights |
Dict[str, dict] |
Dict mapping PolicyIDs to the new weights to be used. |
required |
global_vars |
Optional[Dict] |
An optional global vars dict to set this worker to. If None, do not update the global_vars. |
None |
Examples:
>>> weights = worker.get_weights()
>>> # Set `global_vars` (timestep) as well.
>>> worker.set_weights(weights, {"timestep": 42})
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def set_weights(self,
weights: Dict[PolicyID, ModelWeights],
global_vars: Optional[Dict] = None) -> None:
"""Sets each policies' model weights of this worker.
Args:
weights: Dict mapping PolicyIDs to the new weights to be used.
global_vars: An optional global vars dict to set this
worker to. If None, do not update the global_vars.
Examples:
>>> weights = worker.get_weights()
>>> # Set `global_vars` (timestep) as well.
>>> worker.set_weights(weights, {"timestep": 42})
"""
for pid, w in weights.items():
self.policy_map[pid].set_weights(w)
if global_vars:
self.set_global_vars(global_vars)
setup_torch_data_parallel(self, url, world_rank, world_size, backend)
Join a torch process group for distributed SGD.
Source code in ray/rllib/evaluation/rollout_worker.py
def setup_torch_data_parallel(self, url: str, world_rank: int,
world_size: int, backend: str) -> None:
"""Join a torch process group for distributed SGD."""
logger.info("Joining process group, url={}, world_rank={}, "
"world_size={}, backend={}".format(url, world_rank,
world_size, backend))
torch.distributed.init_process_group(
backend=backend,
init_method=url,
rank=world_rank,
world_size=world_size)
for pid, policy in self.policy_map.items():
if not isinstance(policy, TorchPolicy):
raise ValueError(
"This policy does not support torch distributed", policy)
policy.distributed_world_size = world_size
stop(self)
Releases all resources used by this RolloutWorker.
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def stop(self) -> None:
"""Releases all resources used by this RolloutWorker."""
# If we have an env -> Release its resources.
if self.env is not None:
self.async_env.stop()
# Close all policies' sessions (if tf static graph).
for policy in self.policy_map.values():
sess = policy.get_session()
# Closes the tf session, if any.
if sess is not None:
sess.close()
sync_filters(self, new_filters)
Changes self's filter to given and rebases any accumulated delta.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
new_filters |
dict |
Filters with new state to update local copy. |
required |
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def sync_filters(self, new_filters: dict) -> None:
"""Changes self's filter to given and rebases any accumulated delta.
Args:
new_filters: Filters with new state to update local copy.
"""
assert all(k in new_filters for k in self.filters)
for k in self.filters:
self.filters[k].sync(new_filters[k])
Sample Batches
ray.rllib.policy.sample_batch.SampleBatch (dict)
Wrapper around a dictionary with string keys and array-like values.
For example, {"obs": [1, 2, 3], "reward": [0, -1, 1]} is a batch of three samples, each with an "obs" and "reward" attribute.
__init__(self, *args, **kwargs)
special
Constructs a sample batch (same params as dict constructor).
Note: All *args and those **kwargs not listed below will be passed as-is to the parent dict constructor.
Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def __init__(self, *args, **kwargs):
"""Constructs a sample batch (same params as dict constructor).
Note: All *args and those **kwargs not listed below will be passed
as-is to the parent dict constructor.
Keyword Args:
_time_major (Optional[bool]): Whether data in this sample batch
is time-major. This is False by default and only relevant
if the data contains sequences.
_max_seq_len (Optional[bool]): The max sequence chunk length
if the data contains sequences.
_zero_padded (Optional[bool]): Whether the data in this batch
contains sequences AND these sequences are right-zero-padded
according to the `_max_seq_len` setting.
_is_training (Optional[bool]): Whether this batch is used for
training. If False, batch may be used for e.g. action
computations (inference).
"""
# Possible seq_lens (TxB or BxT) setup.
self.time_major = kwargs.pop("_time_major", None)
# Maximum seq len value.
self.max_seq_len = kwargs.pop("_max_seq_len", None)
# Is already right-zero-padded?
self.zero_padded = kwargs.pop("_zero_padded", False)
# Whether this batch is used for training (vs inference).
self._is_training = kwargs.pop("_is_training", None)
# Call super constructor. This will make the actual data accessible
# by column name (str) via e.g. self["some-col"].
dict.__init__(self, *args, **kwargs)
self.accessed_keys = set()
self.added_keys = set()
self.deleted_keys = set()
self.intercepted_values = {}
self.get_interceptor = None
# Clear out None seq-lens.
seq_lens_ = self.get(SampleBatch.SEQ_LENS)
if seq_lens_ is None or \
(isinstance(seq_lens_, list) and len(seq_lens_) == 0):
self.pop(SampleBatch.SEQ_LENS, None)
# Numpyfy seq_lens if list.
elif isinstance(seq_lens_, list):
self[SampleBatch.SEQ_LENS] = seq_lens_ = \
np.array(seq_lens_, dtype=np.int32)
if self.max_seq_len is None and seq_lens_ is not None and \
not (tf and tf.is_tensor(seq_lens_)) and \
len(seq_lens_) > 0:
self.max_seq_len = max(seq_lens_)
if self._is_training is None:
self._is_training = self.pop("is_training", False)
lengths = []
copy_ = {k: v for k, v in self.items() if k != SampleBatch.SEQ_LENS}
for k, v in copy_.items():
assert isinstance(k, str), self
# TODO: Drop support for lists as values.
# Convert lists of int|float into numpy arrays make sure all data
# has same length.
if isinstance(v, list):
self[k] = np.array(v)
# Try to infer the "length" of the SampleBatch by finding the first
# value that is actually a ndarray/tensor. This would fail if
# all values are nested dicts/tuples of more complex underlying
# structures.
len_ = len(v) if isinstance(
v,
(list, np.ndarray)) or (torch and torch.is_tensor(v)) else None
if len_:
lengths.append(len_)
if self.get(SampleBatch.SEQ_LENS) is not None and \
not (tf and tf.is_tensor(self[SampleBatch.SEQ_LENS])) and \
len(self[SampleBatch.SEQ_LENS]) > 0:
self.count = sum(self[SampleBatch.SEQ_LENS])
else:
self.count = lengths[0] if lengths else 0
# A convenience map for slicing this batch into sub-batches along
# the time axis. This helps reduce repeated iterations through the
# batch's seq_lens array to find good slicing points. Built lazily
# when needed.
self._slice_map = []
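A minimal construction sketch showing how `seq_lens` affects `count` and `max_seq_len`:
>>> import numpy as np
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({"obs": np.array([[1.0], [2.0], [3.0]]),
...                      "seq_lens": [1, 2]})
>>> batch.count  # sum over seq_lens
3
>>> batch.max_seq_len  # max over seq_lens
2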
columns(self, keys)
Returns a list of the batch-data in the specified columns.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
keys |
List[str] |
List of column names for which to return the data. |
required |
Returns:
Type | Description |
---|---|
List[any] |
The list of data items ordered by the order of column
names in `keys`. |
Examples:
>>> batch = SampleBatch({"a": [1], "b": [2], "c": [3]})
>>> print(batch.columns(["a", "b"]))
[[1], [2]]
Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def columns(self, keys: List[str]) -> List[any]:
"""Returns a list of the batch-data in the specified columns.
Args:
keys (List[str]): List of column names for which to return the data.
Returns:
List[any]: The list of data items ordered by the order of column
names in `keys`.
Examples:
>>> batch = SampleBatch({"a": [1], "b": [2], "c": [3]})
>>> print(batch.columns(["a", "b"]))
[[1], [2]]
"""
# TODO: (sven) Make this work for nested data as well.
out = []
for k in keys:
out.append(self[k])
return out
compress(self, bulk=False, columns=frozenset({'new_obs', 'obs'}))
Compresses the data buffers (by column) in place.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
bulk |
bool |
Whether to compress across the batch dimension (0) as well. If False will compress n separate list items, where n is the batch size. |
False |
columns |
Set[str] |
The columns to compress. Default: Only compress the obs and new_obs columns. |
frozenset({'new_obs', 'obs'}) |
Returns:
Type | Description |
---|---|
SampleBatch |
This very (now compressed) SampleBatch. |
Source code in ray/rllib/policy/sample_batch.py
@DeveloperAPI
def compress(self,
bulk: bool = False,
columns: Set[str] = frozenset(["obs", "new_obs"])) -> None:
"""Compresses the data buffers (by column) in place.
Args:
bulk (bool): Whether to compress across the batch dimension (0)
as well. If False will compress n separate list items, where n
is the batch size.
columns (Set[str]): The columns to compress. Default: Only
compress the obs and new_obs columns.
Returns:
SampleBatch: This very (now compressed) SampleBatch.
"""
def _compress_in_place(path, value):
if path[0] not in columns:
return
curr = self
for i, p in enumerate(path):
if i == len(path) - 1:
if bulk:
curr[p] = pack(value)
else:
curr[p] = np.array([pack(o) for o in value])
curr = curr[p]
tree.map_structure_with_path(_compress_in_place, self)
return self
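A minimal sketch, compressing and later restoring the `obs` column of a small batch (both methods operate in place and return the batch):
>>> import numpy as np
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({"obs": np.random.rand(10, 4).astype(np.float32)})
>>> batch = batch.compress(bulk=True)  # "obs" becomes a single packed blob
>>> batch = batch.decompress_if_needed()  # restores the original array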
concat(self, other)
Concatenates `other` to this one and returns a new SampleBatch.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
other |
SampleBatch |
The other SampleBatch object to concat to this one. |
required |
Returns:
Type | Description |
---|---|
SampleBatch |
The new SampleBatch, resulting from concatenating `other` to `self`. |
Examples:
>>> b1 = SampleBatch({"a": np.array([1, 2])})
>>> b2 = SampleBatch({"a": np.array([3, 4, 5])})
>>> print(b1.concat(b2))
{"a": np.array([1, 2, 3, 4, 5])}
Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def concat(self, other: "SampleBatch") -> "SampleBatch":
"""Concatenates `other` to this one and returns a new SampleBatch.
Args:
other (SampleBatch): The other SampleBatch object to concat to this
one.
Returns:
SampleBatch: The new SampleBatch, resulting from concating `other`
to `self`.
Examples:
>>> b1 = SampleBatch({"a": np.array([1, 2])})
>>> b2 = SampleBatch({"a": np.array([3, 4, 5])})
>>> print(b1.concat(b2))
{"a": np.array([1, 2, 3, 4, 5])}
"""
return self.concat_samples([self, other])
concat_samples(samples)
staticmethod
Concatenates n SampleBatches or MultiAgentBatches.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
samples |
Union[List[SampleBatch], List[MultiAgentBatch]] |
List of SampleBatches or MultiAgentBatches to be concatenated. |
required |
Returns:
Type | Description |
---|---|
Union[SampleBatch, MultiAgentBatch] |
A new (concatenated) SampleBatch or MultiAgentBatch. |
Examples:
>>> b1 = SampleBatch({"a": np.array([1, 2]),
... "b": np.array([10, 11])})
>>> b2 = SampleBatch({"a": np.array([3]),
... "b": np.array([12])})
>>> print(SampleBatch.concat_samples([b1, b2]))
{"a": np.array([1, 2, 3]), "b": np.array([10, 11, 12])}
Source code in ray/rllib/policy/sample_batch.py
@staticmethod
@PublicAPI
def concat_samples(
samples: Union[List["SampleBatch"], List["MultiAgentBatch"]],
) -> Union["SampleBatch", "MultiAgentBatch"]:
"""Concatenates n SampleBatches or MultiAgentBatches.
Args:
samples (Union[List[SampleBatch], List[MultiAgentBatch]]): List of
SampleBatches or MultiAgentBatches to be concatenated.
Returns:
Union[SampleBatch, MultiAgentBatch]: A new (concatenated)
SampleBatch or MultiAgentBatch.
Examples:
>>> b1 = SampleBatch({"a": np.array([1, 2]),
... "b": np.array([10, 11])})
>>> b2 = SampleBatch({"a": np.array([3]),
... "b": np.array([12])})
>>> print(SampleBatch.concat_samples([b1, b2]))
{"a": np.array([1, 2, 3]), "b": np.array([10, 11, 12])}
"""
if any(isinstance(s, MultiAgentBatch) for s in samples):
return MultiAgentBatch.concat_samples(samples)
concatd_seq_lens = []
concat_samples = []
zero_padded = samples[0].zero_padded
max_seq_len = samples[0].max_seq_len
time_major = samples[0].time_major
for s in samples:
if s.count > 0:
assert s.zero_padded == zero_padded
assert s.time_major == time_major
if zero_padded:
assert s.max_seq_len == max_seq_len
concat_samples.append(s)
if s.get(SampleBatch.SEQ_LENS) is not None:
concatd_seq_lens.extend(s[SampleBatch.SEQ_LENS])
# If we don't have any samples (0 or only empty SampleBatches),
# return an empty SampleBatch here.
if len(concat_samples) == 0:
return SampleBatch()
# Collect the concat'd data.
concatd_data = {}
def concat_key(*values):
return concat_aligned(values, time_major)
try:
for k in concat_samples[0].keys():
if k == "infos":
concatd_data[k] = concat_aligned(
[s[k] for s in concat_samples], time_major=time_major)
else:
concatd_data[k] = tree.map_structure(
concat_key, *[c[k] for c in concat_samples])
except Exception:
raise ValueError(f"Cannot concat data under key '{k}', b/c "
"sub-structures under that key don't match. "
f"`samples`={samples}")
# Return a new (concat'd) SampleBatch.
return SampleBatch(
concatd_data,
seq_lens=concatd_seq_lens,
_time_major=time_major,
_zero_padded=zero_padded,
_max_seq_len=max_seq_len,
)
copy(self, shallow=False)
Creates a deep or shallow copy of this SampleBatch and returns it.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
shallow |
bool |
Whether the copying should be done shallowly. |
False |
Returns:
Type | Description |
---|---|
SampleBatch |
A deep or shallow copy of this SampleBatch object. |
Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def copy(self, shallow: bool = False) -> "SampleBatch":
"""Creates a deep or shallow copy of this SampleBatch and returns it.
Args:
shallow (bool): Whether the copying should be done shallowly.
Returns:
SampleBatch: A deep or shallow copy of this SampleBatch object.
"""
copy_ = {k: v for k, v in self.items()}
data = tree.map_structure(
lambda v: (np.array(v, copy=not shallow) if
isinstance(v, np.ndarray) else v),
copy_,
)
copy_ = SampleBatch(data)
copy_.set_get_interceptor(self.get_interceptor)
copy_.added_keys = self.added_keys
copy_.deleted_keys = self.deleted_keys
copy_.accessed_keys = self.accessed_keys
return copy_
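A minimal sketch illustrating deep vs. shallow copies:
>>> import numpy as np
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({"a": np.array([1, 2, 3])})
>>> deep = batch.copy()  # data buffers are copied
>>> deep["a"][0] = 99
>>> int(batch["a"][0])  # original is unaffected
1
>>> shallow = batch.copy(shallow=True)  # data buffers are shared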
decompress_if_needed(self, columns=frozenset({'new_obs', 'obs'}))
Decompresses data buffers (per column if not compressed) in place.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
columns |
Set[str] |
The columns to decompress. Default: Only decompress the obs and new_obs columns. |
frozenset({'new_obs', 'obs'}) |
Returns:
Type | Description |
---|---|
SampleBatch |
This very (now uncompressed) SampleBatch. |
Source code in ray/rllib/policy/sample_batch.py
@DeveloperAPI
def decompress_if_needed(self,
columns: Set[str] = frozenset(
["obs", "new_obs"])) -> "SampleBatch":
"""Decompresses data buffers (per column if not compressed) in place.
Args:
columns (Set[str]): The columns to decompress. Default: Only
decompress the obs and new_obs columns.
Returns:
SampleBatch: This very (now uncompressed) SampleBatch.
"""
def _decompress_in_place(path, value):
if path[0] not in columns:
return
curr = self
for p in path[:-1]:
curr = curr[p]
# Bulk compressed.
if is_compressed(value):
curr[path[-1]] = unpack(value)
# Non bulk compressed.
elif len(value) > 0 and is_compressed(value[0]):
curr[path[-1]] = np.array([unpack(o) for o in value])
tree.map_structure_with_path(_decompress_in_place, self)
return self
get(self, key, default=None)
get_single_step_input_dict(self, view_requirements, index='last')
Creates a single-timestep SampleBatch at the given index from `self`.
For usage as input-dict for model (action or value function) calls.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
view_requirements |
Dict[str, ViewRequirement] |
A view requirements dict from the model for which to produce the input_dict. |
required |
index |
Union[str, int] |
An integer index value indicating the position in the trajectory for which to generate the compute_actions input dict. Set to "last" to generate the dict at the very end of the trajectory (e.g. for value estimation). Note that "last" is different from -1, as "last" will use the final NEXT_OBS as observation input. |
'last' |
Returns:
Type | Description |
---|---|
SampleBatch |
The (single-timestep) input dict for ModelV2 calls. |
Source code in ray/rllib/policy/sample_batch.py
@ExperimentalAPI
def get_single_step_input_dict(
self,
view_requirements: ViewRequirementsDict,
index: Union[str, int] = "last",
) -> "SampleBatch":
"""Creates single ts SampleBatch at given index from `self`.
For usage as input-dict for model (action or value function) calls.
Args:
view_requirements: A view requirements dict from the model for
which to produce the input_dict.
index: An integer index value indicating the
position in the trajectory for which to generate the
compute_actions input dict. Set to "last" to generate the dict
at the very end of the trajectory (e.g. for value estimation).
Note that "last" is different from -1, as "last" will use the
final NEXT_OBS as observation input.
Returns:
The (single-timestep) input dict for ModelV2 calls.
"""
last_mappings = {
SampleBatch.OBS: SampleBatch.NEXT_OBS,
SampleBatch.PREV_ACTIONS: SampleBatch.ACTIONS,
SampleBatch.PREV_REWARDS: SampleBatch.REWARDS,
}
input_dict = {}
for view_col, view_req in view_requirements.items():
if view_req.used_for_compute_actions is False:
continue
# Create batches of size 1 (single-agent input-dict).
data_col = view_req.data_col or view_col
if index == "last":
data_col = last_mappings.get(data_col, data_col)
# Range needed.
if view_req.shift_from is not None:
# Batch repeat value > 1: We have single frames in the
# batch at each timestep (for the `data_col`).
data = self[view_col][-1]
traj_len = len(self[data_col])
missing_at_end = traj_len % view_req.batch_repeat_value
# Index into the observations column must be shifted by
# -1 b/c index=0 for observations means the current (last
# seen) observation (after having taken an action).
obs_shift = -1 if data_col in [
SampleBatch.OBS, SampleBatch.NEXT_OBS
] else 0
from_ = view_req.shift_from + obs_shift
to_ = view_req.shift_to + obs_shift + 1
if to_ == 0:
to_ = None
input_dict[view_col] = np.array([
np.concatenate(
[data,
self[data_col][-missing_at_end:]])[from_:to_]
])
# Single index.
else:
input_dict[view_col] = tree.map_structure(
lambda v: v[-1:], # keep as array (w/ 1 element)
self[data_col],
)
# Single index somewhere inside the trajectory (non-last).
else:
input_dict[view_col] = self[data_col][index:index + 1
if index != -1 else None]
return SampleBatch(input_dict, seq_lens=np.array([1], dtype=np.int32))
right_zero_pad(self, max_seq_len, exclude_states=True)
Right (adding zeros at end) zero-pads this SampleBatch in-place.
This will set the `self.zero_padded` flag to True and `self.max_seq_len` to the given `max_seq_len` value.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
max_seq_len |
int |
The max (total) length to zero pad to. |
required |
exclude_states |
bool |
If False, also right-zero-pad all `state_in_x` data. If True, leave `state_in_x` keys as-is. |
True |
Returns:
Type | Description |
---|---|
SampleBatch |
This very (now right-zero-padded) SampleBatch. |
Exceptions:
Type | Description |
---|---|
ValueError |
If self[SampleBatch.SEQ_LENS] is None (not defined). |
Examples:
>>> batch = SampleBatch({"a": [1, 2, 3], "seq_lens": [1, 2]})
>>> print(batch.right_zero_pad(max_seq_len=4))
{"a": [1, 0, 0, 0, 2, 3, 0, 0], "seq_lens": [1, 2]}
>>> batch = SampleBatch({"a": [1, 2, 3],
... "state_in_0": [1.0, 3.0],
... "seq_lens": [1, 2]})
>>> print(batch.right_zero_pad(max_seq_len=5))
{"a": [1, 0, 0, 0, 0, 2, 3, 0, 0, 0],
"state_in_0": [1.0, 3.0], # <- all state-ins remain as-is
"seq_lens": [1, 2]}
Source code in ray/rllib/policy/sample_batch.py
def right_zero_pad(self, max_seq_len: int, exclude_states: bool = True):
"""Right (adding zeros at end) zero-pads this SampleBatch in-place.
This will set the `self.zero_padded` flag to True and
`self.max_seq_len` to the given `max_seq_len` value.
Args:
max_seq_len: The max (total) length to zero pad to.
exclude_states: If False, also right-zero-pad all
`state_in_x` data. If True, leave `state_in_x` keys
as-is.
Returns:
SampleBatch: This very (now right-zero-padded) SampleBatch.
Raises:
ValueError: If self[SampleBatch.SEQ_LENS] is None (not defined).
Examples:
>>> batch = SampleBatch({"a": [1, 2, 3], "seq_lens": [1, 2]})
>>> print(batch.right_zero_pad(max_seq_len=4))
{"a": [1, 0, 0, 0, 2, 3, 0, 0], "seq_lens": [1, 2]}
>>> batch = SampleBatch({"a": [1, 2, 3],
... "state_in_0": [1.0, 3.0],
... "seq_lens": [1, 2]})
>>> print(batch.right_zero_pad(max_seq_len=5))
{"a": [1, 0, 0, 0, 0, 2, 3, 0, 0, 0],
"state_in_0": [1.0, 3.0], # <- all state-ins remain as-is
"seq_lens": [1, 2]}
"""
seq_lens = self.get(SampleBatch.SEQ_LENS)
if seq_lens is None:
raise ValueError(
"Cannot right-zero-pad SampleBatch if no `seq_lens` field "
"present! SampleBatch={self}")
length = len(seq_lens) * max_seq_len
def _zero_pad_in_place(path, value):
# Skip "state_in_..." columns and "seq_lens".
if (exclude_states is True and path[0].startswith("state_in_")) \
or path[0] == SampleBatch.SEQ_LENS:
return
# Generate zero-filled primer of len=max_seq_len.
if value.dtype == np.object or value.dtype.type is np.str_:
f_pad = [None] * length
else:
# Make sure type doesn't change.
f_pad = np.zeros(
(length, ) + np.shape(value)[1:], dtype=value.dtype)
# Fill primer with data.
f_pad_base = f_base = 0
for len_ in self[SampleBatch.SEQ_LENS]:
f_pad[f_pad_base:f_pad_base + len_] = value[f_base:f_base +
len_]
f_pad_base += max_seq_len
f_base += len_
assert f_base == len(value), value
# Update our data in-place.
curr = self
for i, p in enumerate(path):
if i == len(path) - 1:
curr[p] = f_pad
curr = curr[p]
self_as_dict = {k: v for k, v in self.items()}
tree.map_structure_with_path(_zero_pad_in_place, self_as_dict)
# Set flags to indicate we are now zero-padded (and to what extent).
self.zero_padded = True
self.max_seq_len = max_seq_len
return self
rows(self)
Returns an iterator over data rows, i.e. dicts with column values.
Note that if `seq_lens` is set in self, we set it to [1] in the rows.
Yields:
Dict[str, TensorType]: The column values of the row in this iteration.
Examples:
>>> batch = SampleBatch({
... "a": [1, 2, 3],
... "b": [4, 5, 6],
... "seq_lens": [1, 2]
... })
>>> for row in batch.rows():
print(row)
{"a": 1, "b": 4, "seq_lens": [1]}
{"a": 2, "b": 5, "seq_lens": [1]}
{"a": 3, "b": 6, "seq_lens": [1]}
Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def rows(self) -> Iterator[Dict[str, TensorType]]:
"""Returns an iterator over data rows, i.e. dicts with column values.
Note that if `seq_lens` is set in self, we set it to [1] in the rows.
Yields:
Dict[str, TensorType]: The column values of the row in this
iteration.
Examples:
>>> batch = SampleBatch({
... "a": [1, 2, 3],
... "b": [4, 5, 6],
... "seq_lens": [1, 2]
... })
>>> for row in batch.rows():
print(row)
{"a": 1, "b": 4, "seq_lens": [1]}
{"a": 2, "b": 5, "seq_lens": [1]}
{"a": 3, "b": 6, "seq_lens": [1]}
"""
# Do we add seq_lens=[1] to each row?
seq_lens = None if self.get(
SampleBatch.SEQ_LENS) is None else np.array([1])
self_as_dict = {k: v for k, v in self.items()}
for i in range(self.count):
yield tree.map_structure_with_path(
lambda p, v: v[i] if p[0] != self.SEQ_LENS else seq_lens,
self_as_dict,
)
shuffle(self)
Shuffles the rows of this batch in-place.
Returns:
Type | Description |
---|---|
SampleBatch |
This very (now shuffled) SampleBatch. |
Exceptions:
Type | Description |
---|---|
ValueError |
If self[SampleBatch.SEQ_LENS] is defined. |
Examples:
>>> batch = SampleBatch({"a": [1, 2, 3, 4]})
>>> print(batch.shuffle())
{"a": [4, 1, 3, 2]}
Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def shuffle(self) -> None:
"""Shuffles the rows of this batch in-place.
Returns:
SampleBatch: This very (now shuffled) SampleBatch.
Raises:
ValueError: If self[SampleBatch.SEQ_LENS] is defined.
Examples:
>>> batch = SampleBatch({"a": [1, 2, 3, 4]})
>>> print(batch.shuffle())
{"a": [4, 1, 3, 2]}
"""
# Shuffling the data when we have `seq_lens` defined is probably
# a bad idea!
if self.get(SampleBatch.SEQ_LENS) is not None:
raise ValueError(
"SampleBatch.shuffle not possible when your data has "
"`seq_lens` defined!")
# Get a permutation over the single items once and use the same
# permutation for all the data (otherwise, data would become
# meaningless).
permutation = np.random.permutation(self.count)
def _permutate_in_place(path, value):
curr = self
for i, p in enumerate(path):
if i == len(path) - 1:
curr[p] = value[permutation]
# Translate into list (tuples are immutable).
if isinstance(curr[p], tuple):
curr[p] = list(curr[p])
curr = curr[p]
tree.map_structure_with_path(_permutate_in_place, self)
return self
size_bytes(self)
Returns sum over number of bytes of all data buffers.
For numpy arrays, we use `.nbytes`. For all other value types, we use sys.getsizeof(...).
Returns:
Type | Description |
---|---|
int |
The overall size in bytes of the data buffer (all columns). |
Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def size_bytes(self) -> int:
"""Returns sum over number of bytes of all data buffers.
For numpy arrays, we use `.nbytes`. For all other value types, we use
sys.getsizeof(...).
Returns:
int: The overall size in bytes of the data buffer (all columns).
"""
return sum(
v.nbytes if isinstance(v, np.ndarray) else sys.getsizeof(v)
for v in tree.flatten(self))
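A minimal sketch; the single float32 column below occupies 100 * 4 * 4 = 1600 bytes:
>>> import numpy as np
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({"obs": np.zeros((100, 4), dtype=np.float32)})
>>> batch.size_bytes()
1600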
split_by_episode(self)
Splits by the `eps_id` column and returns a list of new batches.
Returns:
Type | Description |
---|---|
List[SampleBatch] |
List of batches, one per distinct episode. |
Exceptions:
Type | Description |
---|---|
KeyError |
If the `eps_id` AND `dones` columns are not present. |
Examples:
>>> batch = SampleBatch({"a": [1, 2, 3], "eps_id": [0, 0, 1]})
>>> print(batch.split_by_episode())
[{"a": [1, 2], "eps_id": [0, 0]}, {"a": [3], "eps_id": [1]}]
Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def split_by_episode(self) -> List["SampleBatch"]:
"""Splits by `eps_id` column and returns list of new batches.
Returns:
List[SampleBatch]: List of batches, one per distinct episode.
Raises:
KeyError: If the `eps_id` AND `dones` columns are not present.
Examples:
>>> batch = SampleBatch({"a": [1, 2, 3], "eps_id": [0, 0, 1]})
>>> print(batch.split_by_episode())
[{"a": [1, 2], "eps_id": [0, 0]}, {"a": [3], "eps_id": [1]}]
"""
# No eps_id in data -> Make sure there are no "dones" in the middle
# and add eps_id automatically.
if SampleBatch.EPS_ID not in self:
# TODO: (sven) Shouldn't we rather split by DONEs then and not
# add fake eps-ids (0s) at all?
if SampleBatch.DONES in self:
assert not any(self[SampleBatch.DONES][:-1])
self[SampleBatch.EPS_ID] = np.repeat(0, self.count)
return [self]
# Produce a new slice whenever we find a new episode ID.
slices = []
cur_eps_id = self[SampleBatch.EPS_ID][0]
offset = 0
for i in range(self.count):
next_eps_id = self[SampleBatch.EPS_ID][i]
if next_eps_id != cur_eps_id:
slices.append(self[offset:i])
offset = i
cur_eps_id = next_eps_id
# Add final slice.
slices.append(self[offset:self.count])
# TODO: (sven) Are these checks necessary? Should be all ok according
# to above logic.
for s in slices:
slen = len(set(s[SampleBatch.EPS_ID]))
assert slen == 1, (s, slen)
assert sum(s.count for s in slices) == self.count, (slices, self.count)
return slices
timeslices(self, size=None, num_slices=None, k=None)
Returns SampleBatches, each one representing a k-slice of this one.
Will start from timestep 0 and produce slices of size=k.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
size |
Optional[int] |
The size (in timesteps) of each returned SampleBatch. |
None |
num_slices |
Optional[int] |
The number of slices to produce. |
None |
k |
int |
Obsoleted: use `size` or `num_slices` instead. The size (in timesteps) of each returned SampleBatch. |
None |
Returns:
Type | Description |
---|---|
List[SampleBatch] |
The list of `num_slices` (new) SampleBatches, or n (new) SampleBatches each one of size `size`. |
Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def timeslices(self,
size: Optional[int] = None,
num_slices: Optional[int] = None,
k: Optional[int] = None) -> List["SampleBatch"]:
"""Returns SampleBatches, each one representing a k-slice of this one.
Will start from timestep 0 and produce slices of size=k.
Args:
size (Optional[int]): The size (in timesteps) of each returned
SampleBatch.
num_slices (Optional[int]): The number of slices to produce.
k (int): Obsoleted: Use size or num_slices instead!
The size (in timesteps) of each returned SampleBatch.
Returns:
List[SampleBatch]: The list of `num_slices` (new) SampleBatches
or n (new) SampleBatches each one of size `size`.
"""
if size is None and num_slices is None:
deprecation_warning("k", "size or num_slices")
assert k is not None
size = k
if size is None:
assert isinstance(num_slices, int)
slices = []
left = len(self)
start = 0
while left:
len_ = left // (num_slices - len(slices))
stop = start + len_
slices.append(self[start:stop])
left -= len_
start = stop
return slices
else:
assert isinstance(size, int)
slices = []
left = len(self)
start = 0
while left:
stop = start + size
slices.append(self[start:stop])
left -= size
start = stop
return slices
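A minimal sketch contrasting the `size` and `num_slices` arguments:
>>> import numpy as np
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({"a": np.arange(6)})
>>> [len(s) for s in batch.timeslices(size=2)]
[2, 2, 2]
>>> [len(s) for s in batch.timeslices(num_slices=2)]
[3, 3]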
to_device(self, device, framework='torch')
Source code in ray/rllib/policy/sample_batch.py
def to_device(self, device, framework="torch"):
"""TODO: transfer batch to given device as framework tensor."""
if framework == "torch":
assert torch is not None
for k, v in self.items():
if isinstance(v, np.ndarray) and v.dtype != np.object:
self[k] = torch.from_numpy(v).to(device)
else:
raise NotImplementedError
return self
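A minimal sketch, assuming PyTorch is installed; numeric columns are converted to torch tensors on the given device:
>>> import numpy as np, torch
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({"obs": np.array([[1.0], [2.0]])})
>>> batch = batch.to_device("cpu")  # or e.g. "cuda:0" if available
>>> torch.is_tensor(batch["obs"])
True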
ray.rllib.evaluation.sample_batch_builder.SampleBatchBuilder
Util to build a SampleBatch incrementally.
For efficiency, SampleBatches hold values in column form (as arrays). However, it is useful to add data one row (dict) at a time.
add_batch(self, batch)
add_values(self, **values)
build_and_reset(self)
Returns a sample batch including all previously added values.
Source code in ray/rllib/evaluation/sample_batch_builder.py
def build_and_reset(self) -> SampleBatch:
"""Returns a sample batch including all previously added values."""
batch = SampleBatch(
{k: to_float_array(v)
for k, v in self.buffers.items()})
if SampleBatch.UNROLL_ID not in batch:
batch[SampleBatch.UNROLL_ID] = np.repeat(
SampleBatchBuilder._next_unroll_id, batch.count)
SampleBatchBuilder._next_unroll_id += 1
self.buffers.clear()
self.count = 0
return batch
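A minimal sketch, adding rows one at a time and then building the final batch:
>>> from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder
>>> builder = SampleBatchBuilder()
>>> builder.add_values(obs=1, action=0, reward=0.5)
>>> builder.add_values(obs=2, action=1, reward=-1.0)
>>> batch = builder.build_and_reset()
>>> batch.count
2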
ray.rllib.policy.sample_batch.MultiAgentBatch
A batch of experiences from multiple agents in the environment.
Attributes:
Name | Type | Description |
---|---|---|
policy_batches |
Dict[PolicyID, SampleBatch] |
Mapping from policy ids to SampleBatches of experiences. |
count |
int |
The number of env steps in this batch. |
__init__(self, policy_batches, env_steps)
special
Initialize a MultiAgentBatch object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
policy_batches |
Dict[PolicyID, SampleBatch] |
Mapping from policy ids to SampleBatches of experiences. |
required |
env_steps |
int |
The number of environment steps in the environment this batch contains. This will be less than the number of transitions this batch contains across all policies in total. |
required |
Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def __init__(self, policy_batches: Dict[PolicyID, SampleBatch],
env_steps: int):
"""Initialize a MultiAgentBatch object.
Args:
policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy
ids to SampleBatches of experiences.
env_steps (int): The number of environment steps in the environment
this batch contains. This will be less than the number of
transitions this batch contains across all policies in total.
"""
for v in policy_batches.values():
assert isinstance(v, SampleBatch)
self.policy_batches = policy_batches
# Called "count" for uniformity with SampleBatch.
# Prefer to access this via the `env_steps()` method when possible
# for clarity.
self.count = env_steps
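A minimal construction sketch with two hypothetical policy IDs; note that agent steps can exceed env steps:
>>> from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
>>> ma_batch = MultiAgentBatch(
...     {"pol1": SampleBatch({"obs": [1, 2]}),
...      "pol2": SampleBatch({"obs": [1]})},
...     env_steps=2)
>>> ma_batch.env_steps(), ma_batch.agent_steps()
(2, 3)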
agent_steps(self)
The number of agent steps (there are >= 1 agent steps per env step).
Returns:
Type | Description |
---|---|
int |
The number of agent steps total in this batch. |
Source code in ray/rllib/policy/sample_batch.py
compress(self, bulk=False, columns=frozenset({'new_obs', 'obs'}))
Compresses each policy batch (per column) in place.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
bulk |
bool |
Whether to compress across the batch dimension (0) as well. If False will compress n separate list items, where n is the batch size. |
False |
columns |
Set[str] |
Set of column names to compress. |
frozenset({'new_obs', 'obs'}) |
Source code in ray/rllib/policy/sample_batch.py
@DeveloperAPI
def compress(self,
bulk: bool = False,
columns: Set[str] = frozenset(["obs", "new_obs"])) -> None:
"""Compresses each policy batch (per column) in place.
Args:
bulk (bool): Whether to compress across the batch dimension (0)
as well. If False will compress n separate list items, where n
is the batch size.
columns (Set[str]): Set of column names to compress.
"""
for batch in self.policy_batches.values():
batch.compress(bulk=bulk, columns=columns)
concat_samples(samples)
staticmethod
Concatenates a list of MultiAgentBatches into a new MultiAgentBatch.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
samples |
List[MultiAgentBatch] |
List of MultiagentBatch objects to concatenate. |
required |
Returns:
Type | Description |
---|---|
MultiAgentBatch |
A new MultiAgentBatch consisting of the concatenated inputs. |
Source code in ray/rllib/policy/sample_batch.py
@staticmethod
@PublicAPI
def concat_samples(samples: List["MultiAgentBatch"]) -> "MultiAgentBatch":
"""Concatenates a list of MultiAgentBatches into a new MultiAgentBatch.
Args:
samples (List[MultiAgentBatch]): List of MultiagentBatch objects
to concatenate.
Returns:
MultiAgentBatch: A new MultiAgentBatch consisting of the
concatenated inputs.
"""
policy_batches = collections.defaultdict(list)
env_steps = 0
for s in samples:
# Some batches in `samples` are not MultiAgentBatch.
if not isinstance(s, MultiAgentBatch):
# If empty SampleBatch: ok (just ignore).
if isinstance(s, SampleBatch) and len(s) <= 0:
continue
# Otherwise: Error.
raise ValueError(
"`MultiAgentBatch.concat_samples()` can only concat "
"MultiAgentBatch types, not {}!".format(type(s).__name__))
for key, batch in s.policy_batches.items():
policy_batches[key].append(batch)
env_steps += s.env_steps()
out = {}
for key, batches in policy_batches.items():
out[key] = SampleBatch.concat_samples(batches)
return MultiAgentBatch(out, env_steps)
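A small sketch (hypothetical policy id and data) showing how two MultiAgentBatches combine: the per-policy SampleBatches are concatenated and the env-step counts are summed.
import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch

a = MultiAgentBatch({"pol_0": SampleBatch({"rewards": np.ones(2)})}, env_steps=2)
b = MultiAgentBatch({"pol_0": SampleBatch({"rewards": np.zeros(3)})}, env_steps=3)
combined = MultiAgentBatch.concat_samples([a, b])
print(combined.env_steps())                   # -> 5
print(len(combined.policy_batches["pol_0"]))  # -> 5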
copy(self)
Deep-copies self into a new MultiAgentBatch.
Returns:
Type | Description |
---|---|
MultiAgentBatch |
The copy of self with deep-copied data. |
decompress_if_needed(self, columns=frozenset({'new_obs', 'obs'}))
Decompresses each policy batch (per column), if already compressed.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
columns |
Set[str] |
Set of column names to decompress. |
frozenset({'new_obs', 'obs'}) |
Returns:
Type | Description |
---|---|
MultiAgentBatch |
This very MultiAgentBatch. |
Source code in ray/rllib/policy/sample_batch.py
@DeveloperAPI
def decompress_if_needed(self,
columns: Set[str] = frozenset(
["obs", "new_obs"])) -> "MultiAgentBatch":
"""Decompresses each policy batch (per column), if already compressed.
Args:
columns (Set[str]): Set of column names to decompress.
Returns:
MultiAgentBatch: This very MultiAgentBatch.
"""
for batch in self.policy_batches.values():
batch.decompress_if_needed(columns)
return self
env_steps(self)
The number of env steps (there are >= 1 agent steps per env step).
Returns:
Type | Description |
---|---|
int |
The number of environment steps contained in this batch. |
size_bytes(self)
Returns:
Type | Description |
---|---|
int |
The overall size in bytes of all policy batches (all columns). |
timeslices(self, k)
Returns k-step batches holding data for each agent at those steps.
For example, suppose we have agent1 observations [a1t1, a1t2, a1t3], for agent2, [a2t1, a2t3], and for agent3, [a3t3] only.
Calling timeslices(1) would return three MultiAgentBatches containing [a1t1, a2t1], [a1t2], and [a1t3, a2t3, a3t3].
Calling timeslices(2) would return two MultiAgentBatches containing [a1t1, a1t2, a2t1], and [a1t3, a2t3, a3t3].
This method is used to implement "lockstep" replay mode. Note that this method does not guarantee each batch contains only data from a single unroll. Batches might contain data from multiple different envs.
Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def timeslices(self, k: int) -> List["MultiAgentBatch"]:
"""Returns k-step batches holding data for each agent at those steps.
For examples, suppose we have agent1 observations [a1t1, a1t2, a1t3],
for agent2, [a2t1, a2t3], and for agent3, [a3t3] only.
Calling timeslices(1) would return three MultiAgentBatches containing
[a1t1, a2t1], [a1t2], and [a1t3, a2t3, a3t3].
Calling timeslices(2) would return two MultiAgentBatches containing
[a1t1, a1t2, a2t1], and [a1t3, a2t3, a3t3].
This method is used to implement "lockstep" replay mode. Note that this
method does not guarantee each batch contains only data from a single
unroll. Batches might contain data from multiple different envs.
"""
from ray.rllib.evaluation.sample_batch_builder import \
SampleBatchBuilder
# Build a sorted set of (eps_id, t, policy_id, data...)
steps = []
for policy_id, batch in self.policy_batches.items():
for row in batch.rows():
steps.append((row[SampleBatch.EPS_ID], row[SampleBatch.T],
row[SampleBatch.AGENT_INDEX], policy_id, row))
steps.sort()
finished_slices = []
cur_slice = collections.defaultdict(SampleBatchBuilder)
cur_slice_size = 0
def finish_slice():
nonlocal cur_slice_size
assert cur_slice_size > 0
batch = MultiAgentBatch(
{k: v.build_and_reset()
for k, v in cur_slice.items()}, cur_slice_size)
cur_slice_size = 0
finished_slices.append(batch)
# For each unique env timestep.
for _, group in itertools.groupby(steps, lambda x: x[:2]):
# Accumulate into the current slice.
for _, _, _, policy_id, row in group:
cur_slice[policy_id].add_values(**row)
cur_slice_size += 1
# Slice has reached target number of env steps.
if cur_slice_size >= k:
finish_slice()
assert cur_slice_size == 0
if cur_slice_size > 0:
finish_slice()
assert len(finished_slices) > 0, finished_slices
return finished_slices
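A lockstep-slicing sketch with a single hypothetical policy. The eps_id, t, and agent_index columns are filled in by hand here because timeslices() needs them to group rows by env timestep:
import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch

b = SampleBatch({
    "obs": np.zeros((4, 2)),
    "eps_id": np.array([0, 0, 0, 0]),
    "t": np.array([0, 1, 2, 3]),
    "agent_index": np.array([0, 0, 0, 0]),
})
ma = MultiAgentBatch({"default_policy": b}, env_steps=4)
slices = ma.timeslices(2)
print(len(slices))            # -> 2 (two MultiAgentBatches of 2 env steps each)
print(slices[0].env_steps())  # -> 2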
wrap_as_needed(policy_batches, env_steps)
staticmethod
Returns SampleBatch or MultiAgentBatch, depending on given policies.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
policy_batches |
Dict[PolicyID, SampleBatch] |
Mapping from policy ids to SampleBatch. |
required |
env_steps |
int |
Number of env steps in the batch. |
required |
Returns:
Type | Description |
---|---|
Union[SampleBatch, MultiAgentBatch] |
The single default policy's SampleBatch or a MultiAgentBatch (more than one policy). |
Source code in ray/rllib/policy/sample_batch.py
@staticmethod
@PublicAPI
def wrap_as_needed(
policy_batches: Dict[PolicyID, SampleBatch],
env_steps: int) -> Union[SampleBatch, "MultiAgentBatch"]:
"""Returns SampleBatch or MultiAgentBatch, depending on given policies.
Args:
policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy
ids to SampleBatch.
env_steps (int): Number of env steps in the batch.
Returns:
Union[SampleBatch, MultiAgentBatch]: The single default policy's
SampleBatch or a MultiAgentBatch (more than one policy).
"""
if len(policy_batches) == 1 and DEFAULT_POLICY_ID in policy_batches:
return policy_batches[DEFAULT_POLICY_ID]
return MultiAgentBatch(
policy_batches=policy_batches, env_steps=env_steps)
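A sketch illustrating the two possible return types; DEFAULT_POLICY_ID ("default_policy") is the id RLlib uses in single-policy setups, while the other ids and data below are made up:
import numpy as np
from ray.rllib.policy.sample_batch import (DEFAULT_POLICY_ID, MultiAgentBatch,
                                           SampleBatch)

single = SampleBatch({"rewards": np.ones(3)})
# Only the default policy -> the plain SampleBatch is returned as-is.
out1 = MultiAgentBatch.wrap_as_needed({DEFAULT_POLICY_ID: single}, env_steps=3)
assert isinstance(out1, SampleBatch)
# More than one policy (hypothetical ids) -> a MultiAgentBatch is returned.
out2 = MultiAgentBatch.wrap_as_needed(
    {"pol_0": single, "pol_1": single}, env_steps=3)
assert isinstance(out2, MultiAgentBatch)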
ray.rllib.evaluation.sample_batch_builder.MultiAgentSampleBatchBuilder
Util to build SampleBatches for each policy in a multi-agent env.
Input data is per-agent, while output data is per-policy. There is an M:N mapping between agents and policies. We retain one local batch builder per agent. When an agent is done, its local batch is appended to the batch of the policy controlling that agent.
__init__(self, policy_map, clip_rewards, callbacks)
special
Initialize a MultiAgentSampleBatchBuilder.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
policy_map |
Dict[str,Policy] |
Maps policy ids to policy instances. |
required |
clip_rewards |
Union[bool,float] |
Whether to clip rewards before postprocessing (at +/-1.0) or the actual value to +/- clip. |
required |
callbacks |
DefaultCallbacks |
RLlib callbacks. |
required |
Source code in ray/rllib/evaluation/sample_batch_builder.py
def __init__(self, policy_map: Dict[PolicyID, Policy], clip_rewards: bool,
callbacks: "DefaultCallbacks"):
"""Initialize a MultiAgentSampleBatchBuilder.
Args:
policy_map (Dict[str,Policy]): Maps policy ids to policy instances.
clip_rewards (Union[bool,float]): Whether to clip rewards before
postprocessing (at +/-1.0) or the actual value to +/- clip.
callbacks (DefaultCallbacks): RLlib callbacks.
"""
if log_once("MultiAgentSampleBatchBuilder"):
deprecation_warning(
old="MultiAgentSampleBatchBuilder", error=False)
self.policy_map = policy_map
self.clip_rewards = clip_rewards
# Build the Policies' SampleBatchBuilders.
self.policy_builders = {
k: SampleBatchBuilder()
for k in policy_map.keys()
}
# Whenever we observe a new agent, add a new SampleBatchBuilder for
# this agent.
self.agent_builders = {}
# Internal agent-to-policy map.
self.agent_to_policy = {}
self.callbacks = callbacks
# Number of "inference" steps taken in the environment.
# Regardless of the number of agents involved in each of these steps.
self.count = 0
add_values(self, agent_id, policy_id, **values)
Add the given dictionary (row) of values to this batch.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
agent_id |
obj |
Unique id for the agent we are adding values for. |
required |
policy_id |
obj |
Unique id for policy controlling the agent. |
required |
values |
dict |
Row of values to add for this agent. |
{} |
Source code in ray/rllib/evaluation/sample_batch_builder.py
@DeveloperAPI
def add_values(self, agent_id: AgentID, policy_id: AgentID,
**values: Any) -> None:
"""Add the given dictionary (row) of values to this batch.
Args:
agent_id (obj): Unique id for the agent we are adding values for.
policy_id (obj): Unique id for policy controlling the agent.
values (dict): Row of values to add for this agent.
"""
if agent_id not in self.agent_builders:
self.agent_builders[agent_id] = SampleBatchBuilder()
self.agent_to_policy[agent_id] = policy_id
# Include the current agent id for multi-agent algorithms.
if agent_id != _DUMMY_AGENT_ID:
values["agent_id"] = agent_id
self.agent_builders[agent_id].add_values(**values)
build_and_reset(self, episode=None)
Returns the accumulated sample batches for each policy.
Any unprocessed rows will be first postprocessed with a policy postprocessor. The internal state of this builder will be reset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
episode |
Optional[Episode] |
The Episode object that holds this MultiAgentBatchBuilder object or None. |
None |
Returns:
Type | Description |
---|---|
MultiAgentBatch |
Returns the accumulated sample batches for each policy. |
Source code in ray/rllib/evaluation/sample_batch_builder.py
@DeveloperAPI
def build_and_reset(self,
episode: Optional[Episode] = None) -> MultiAgentBatch:
"""Returns the accumulated sample batches for each policy.
Any unprocessed rows will be first postprocessed with a policy
postprocessor. The internal state of this builder will be reset.
Args:
episode (Optional[Episode]): The Episode object that
holds this MultiAgentBatchBuilder object or None.
Returns:
MultiAgentBatch: Returns the accumulated sample batches for each
policy.
"""
self.postprocess_batch_so_far(episode)
policy_batches = {}
for policy_id, builder in self.policy_builders.items():
if builder.count > 0:
policy_batches[policy_id] = builder.build_and_reset()
old_count = self.count
self.count = 0
return MultiAgentBatch.wrap_as_needed(policy_batches, old_count)
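A rough usage sketch (note the class is deprecated, see the warning emitted in __init__ above). It is not runnable standalone: policy_map is assumed to come from a live RolloutWorker (e.g. worker.policy_map), because build_and_reset() calls the real Policies' postprocessors, and the exact columns those postprocessors expect depend on the algorithm. Agent/policy ids and values are made up:
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.evaluation.sample_batch_builder import \
    MultiAgentSampleBatchBuilder

# `policy_map` is assumed to map policy ids to already-built Policy objects.
builder = MultiAgentSampleBatchBuilder(
    policy_map=policy_map, clip_rewards=False, callbacks=DefaultCallbacks())

# Rows go in per agent ...
builder.add_values(agent_id="agent_0", policy_id="pol_0",
                   t=0, eps_id=0, agent_index=0, obs=[0.0, 0.0],
                   actions=0, rewards=1.0, dones=True)
# ... and come back out grouped (and postprocessed) per policy.
ma_batch = builder.build_and_reset(episode=None)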
has_pending_agent_data(self)
Returns whether there is pending unprocessed data.
Returns:
Type | Description |
---|---|
bool |
True if there is at least one per-agent builder (with data in it). |
postprocess_batch_so_far(self, episode=None)
Apply policy postprocessors to any unprocessed rows.
This pushes the postprocessed per-agent batches onto the per-policy builders, clearing per-agent state.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
episode |
Optional[Episode] |
The Episode object that holds this MultiAgentBatchBuilder object. |
None |
Source code in ray/rllib/evaluation/sample_batch_builder.py
def postprocess_batch_so_far(self,
episode: Optional[Episode] = None) -> None:
"""Apply policy postprocessors to any unprocessed rows.
This pushes the postprocessed per-agent batches onto the per-policy
builders, clearing per-agent state.
Args:
episode (Optional[Episode]): The Episode object that
holds this MultiAgentBatchBuilder object.
"""
# Materialize the batches so far.
pre_batches = {}
for agent_id, builder in self.agent_builders.items():
pre_batches[agent_id] = (
self.policy_map[self.agent_to_policy[agent_id]],
builder.build_and_reset())
# Apply postprocessor.
post_batches = {}
if self.clip_rewards is True:
for _, (_, pre_batch) in pre_batches.items():
pre_batch["rewards"] = np.sign(pre_batch["rewards"])
elif self.clip_rewards:
for _, (_, pre_batch) in pre_batches.items():
pre_batch["rewards"] = np.clip(
pre_batch["rewards"],
a_min=-self.clip_rewards,
a_max=self.clip_rewards)
for agent_id, (_, pre_batch) in pre_batches.items():
other_batches = pre_batches.copy()
del other_batches[agent_id]
policy = self.policy_map[self.agent_to_policy[agent_id]]
if any(pre_batch["dones"][:-1]) or len(set(
pre_batch["eps_id"])) > 1:
raise ValueError(
"Batches sent to postprocessing must only contain steps "
"from a single trajectory.", pre_batch)
# Call the Policy's Exploration's postprocess method.
post_batches[agent_id] = pre_batch
if getattr(policy, "exploration", None) is not None:
policy.exploration.postprocess_trajectory(
policy, post_batches[agent_id], policy.get_session())
post_batches[agent_id] = policy.postprocess_trajectory(
post_batches[agent_id], other_batches, episode)
if log_once("after_post"):
logger.info(
"Trajectory fragment after postprocess_trajectory():\n\n{}\n".
format(summarize(post_batches)))
# Append into policy batches and reset
from ray.rllib.evaluation.rollout_worker import get_global_worker
for agent_id, post_batch in sorted(post_batches.items()):
self.callbacks.on_postprocess_trajectory(
worker=get_global_worker(),
episode=episode,
agent_id=agent_id,
policy_id=self.agent_to_policy[agent_id],
policies=self.policy_map,
postprocessed_batch=post_batch,
original_batches=pre_batches)
self.policy_builders[self.agent_to_policy[agent_id]].add_batch(
post_batch)
self.agent_builders.clear()
self.agent_to_policy.clear()
total(self)
Returns the total number of steps taken in the env (all agents).
Returns:
Type | Description |
---|---|
int |
The number of steps taken in total in the environment over all agents. |
Samplers
ray.rllib.evaluation.sampler.SyncSampler (SamplerInput)
Sync SamplerInput that collects experiences when get_data() is called.
__init__(self, *, worker, env, clip_rewards, rollout_fragment_length, count_steps_by='env_steps', callbacks, horizon=None, multiple_episodes_in_batch=False, normalize_actions=True, clip_actions=False, soft_horizon=False, no_done_at_end=False, observation_fn=None, sample_collector_class=None, render=False, policies=None, policy_mapping_fn=None, preprocessors=None, obs_filters=None, tf_sess=None)
special
Initializes a SyncSampler instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
worker |
RolloutWorker |
The RolloutWorker that will use this Sampler for sampling. |
required |
env |
BaseEnv |
Any Env object. Will be converted into an RLlib BaseEnv. |
required |
clip_rewards |
Union[bool, float] |
True for +/-1.0 clipping, actual float value for +/- value clipping. False for no clipping. |
required |
rollout_fragment_length |
int |
The length of a fragment to collect before building a SampleBatch from the data and resetting the SampleBatchBuilder object. |
required |
count_steps_by |
str |
One of "env_steps" (default) or "agent_steps". Use "agent_steps", if you want rollout lengths to be counted by individual agent steps. In a multi-agent env, a single env_step contains one or more agent_steps, depending on how many agents are present at any given time in the ongoing episode. |
'env_steps' |
callbacks |
DefaultCallbacks |
The Callbacks object to use when episode events happen during rollout. |
required |
horizon |
int |
Hard-reset the Env after this many timesteps. |
None |
multiple_episodes_in_batch |
bool |
Whether to pack multiple episodes into each batch. This guarantees batches will be exactly rollout_fragment_length in size. |
False |
normalize_actions |
bool |
Whether to normalize actions to the action space's bounds. |
True |
clip_actions |
bool |
Whether to clip actions according to the given action_space's bounds. |
False |
soft_horizon |
bool |
If True, calculate bootstrapped values as if episode had ended, but don't physically reset the environment when the horizon is hit. |
False |
no_done_at_end |
bool |
Ignore the done=True at the end of the episode and instead record done=False. |
False |
observation_fn |
Optional[ObservationFunction] |
Optional multi-agent observation func to use for preprocessing observations. |
None |
sample_collector_class |
Optional[Type[ray.rllib.evaluation.collectors.sample_collector.SampleCollector]] |
An optional SampleCollector sub-class to use to collect, store, and retrieve environment-, model-, and sampler data. |
None |
render |
bool |
Whether to try to render the environment after each step. |
False |
Source code in ray/rllib/evaluation/sampler.py
def __init__(
self,
*,
worker: "RolloutWorker",
env: BaseEnv,
clip_rewards: Union[bool, float],
rollout_fragment_length: int,
count_steps_by: str = "env_steps",
callbacks: "DefaultCallbacks",
horizon: int = None,
multiple_episodes_in_batch: bool = False,
normalize_actions: bool = True,
clip_actions: bool = False,
soft_horizon: bool = False,
no_done_at_end: bool = False,
observation_fn: Optional["ObservationFunction"] = None,
sample_collector_class: Optional[Type[SampleCollector]] = None,
render: bool = False,
# Obsolete.
policies=None,
policy_mapping_fn=None,
preprocessors=None,
obs_filters=None,
tf_sess=None,
):
"""Initializes a SyncSampler instance.
Args:
worker: The RolloutWorker that will use this Sampler for sampling.
env: Any Env object. Will be converted into an RLlib BaseEnv.
clip_rewards: True for +/-1.0 clipping,
actual float value for +/- value clipping. False for no
clipping.
rollout_fragment_length: The length of a fragment to collect
before building a SampleBatch from the data and resetting
the SampleBatchBuilder object.
count_steps_by: One of "env_steps" (default) or "agent_steps".
Use "agent_steps", if you want rollout lengths to be counted
by individual agent steps. In a multi-agent env,
a single env_step contains one or more agent_steps, depending
on how many agents are present at any given time in the
ongoing episode.
callbacks: The Callbacks object to use when episode
events happen during rollout.
horizon: Hard-reset the Env after this many timesteps.
multiple_episodes_in_batch: Whether to pack multiple
episodes into each batch. This guarantees batches will be
exactly `rollout_fragment_length` in size.
normalize_actions: Whether to normalize actions to the
action space's bounds.
clip_actions: Whether to clip actions according to the
given action_space's bounds.
soft_horizon: If True, calculate bootstrapped values as if
episode had ended, but don't physically reset the environment
when the horizon is hit.
no_done_at_end: Ignore the done=True at the end of the
episode and instead record done=False.
observation_fn: Optional multi-agent observation func to use for
preprocessing observations.
sample_collector_class: An optional Samplecollector sub-class to
use to collect, store, and retrieve environment-, model-,
and sampler data.
render: Whether to try to render the environment after each step.
"""
# All of the following arguments are deprecated. They will instead be
# provided via the passed in `worker` arg, e.g. `worker.policy_map`.
if log_once("deprecated_sync_sampler_args"):
if policies is not None:
deprecation_warning(old="policies")
if policy_mapping_fn is not None:
deprecation_warning(old="policy_mapping_fn")
if preprocessors is not None:
deprecation_warning(old="preprocessors")
if obs_filters is not None:
deprecation_warning(old="obs_filters")
if tf_sess is not None:
deprecation_warning(old="tf_sess")
self.base_env = BaseEnv.to_base_env(env)
self.rollout_fragment_length = rollout_fragment_length
self.horizon = horizon
self.extra_batches = queue.Queue()
self.perf_stats = _PerfStats()
if not sample_collector_class:
sample_collector_class = SimpleListCollector
self.sample_collector = sample_collector_class(
worker.policy_map,
clip_rewards,
callbacks,
multiple_episodes_in_batch,
rollout_fragment_length,
count_steps_by=count_steps_by)
self.render = render
# Create the rollout generator to use for calls to `get_data()`.
self._env_runner = _env_runner(
worker, self.base_env, self.extra_batches.put, self.horizon,
normalize_actions, clip_actions, multiple_episodes_in_batch,
callbacks, self.perf_stats, soft_horizon, no_done_at_end,
observation_fn, self.sample_collector, self.render)
self.metrics_queue = queue.Queue()
get_data(self)
Called by self.next() to return the next batch of data.
Override this in child classes.
Returns:
Type | Description |
---|---|
Union[SampleBatch, MultiAgentBatch] |
The next batch of data. |
get_extra_batches(self)
Returns list of extra batches since the last call to this method.
The list will contain all SampleBatches or MultiAgentBatches that the user has provided thus far. Users can add these "extra batches" to an episode by calling the episode's add_extra_batch([SampleBatchType]) method. This can be done from inside an overridden Policy.compute_actions_from_input_dict(..., episodes) or from a custom callback's on_episode_[start|step|end]() methods.
Returns:
Type | Description |
---|---|
List[Union[SampleBatch, MultiAgentBatch]] |
List of SampleBatches or MultiAgentBatches provided thus far by the user since the last call to this method. |
get_metrics(self)
Returns list of episode metrics since the last call to this method.
The list will contain one RolloutMetrics object per completed episode.
Returns:
Type | Description |
---|---|
List[ray.rllib.evaluation.metrics.RolloutMetrics] |
List of RolloutMetrics objects, one per completed episode since the last call to this method. |
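A usage sketch, assuming sampler is an already-constructed SyncSampler wired to a RolloutWorker and env (see __init__ above); nothing is created here:
batch = sampler.get_data()           # next rollout fragment (SampleBatch/MultiAgentBatch)
metrics = sampler.get_metrics()      # RolloutMetrics of episodes completed so far
extra = sampler.get_extra_batches()  # batches added via episode.add_extra_batch()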
ray.rllib.evaluation.sampler.AsyncSampler (Thread, SamplerInput)
Async SamplerInput that collects experiences in a background thread and queues them. Once started, experiences are continuously collected and put into a Queue, from where they can be dequeued by the caller of get_data().
__init__(self, *, worker, env, clip_rewards, rollout_fragment_length, count_steps_by='env_steps', callbacks, horizon=None, multiple_episodes_in_batch=False, normalize_actions=True, clip_actions=False, soft_horizon=False, no_done_at_end=False, observation_fn=None, sample_collector_class=None, render=False, blackhole_outputs=False, policies=None, policy_mapping_fn=None, preprocessors=None, obs_filters=None, tf_sess=None)
special
Initializes an AsyncSampler instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
worker |
RolloutWorker |
The RolloutWorker that will use this Sampler for sampling. |
required |
env |
BaseEnv |
Any Env object. Will be converted into an RLlib BaseEnv. |
required |
clip_rewards |
Union[bool, float] |
True for +/-1.0 clipping, actual float value for +/- value clipping. False for no clipping. |
required |
rollout_fragment_length |
int |
The length of a fragment to collect before building a SampleBatch from the data and resetting the SampleBatchBuilder object. |
required |
count_steps_by |
str |
One of "env_steps" (default) or "agent_steps". Use "agent_steps", if you want rollout lengths to be counted by individual agent steps. In a multi-agent env, a single env_step contains one or more agent_steps, depending on how many agents are present at any given time in the ongoing episode. |
'env_steps' |
horizon |
Optional[int] |
Hard-reset the Env after this many timesteps. |
None |
multiple_episodes_in_batch |
bool |
Whether to pack multiple episodes into each batch. This guarantees batches will be exactly rollout_fragment_length in size. |
False |
normalize_actions |
bool |
Whether to normalize actions to the action space's bounds. |
True |
clip_actions |
bool |
Whether to clip actions according to the given action_space's bounds. |
False |
blackhole_outputs |
bool |
Whether to collect samples, but then not further process or store them (throw away all samples). |
False |
soft_horizon |
bool |
If True, calculate bootstrapped values as if episode had ended, but don't physically reset the environment when the horizon is hit. |
False |
no_done_at_end |
bool |
Ignore the done=True at the end of the episode and instead record done=False. |
False |
observation_fn |
Optional[ObservationFunction] |
Optional multi-agent observation func to use for preprocessing observations. |
None |
sample_collector_class |
Optional[Type[ray.rllib.evaluation.collectors.sample_collector.SampleCollector]] |
An optional SampleCollector sub-class to use to collect, store, and retrieve environment-, model-, and sampler data. |
None |
render |
bool |
Whether to try to render the environment after each step. |
False |
Source code in ray/rllib/evaluation/sampler.py
def __init__(
self,
*,
worker: "RolloutWorker",
env: BaseEnv,
clip_rewards: Union[bool, float],
rollout_fragment_length: int,
count_steps_by: str = "env_steps",
callbacks: "DefaultCallbacks",
horizon: Optional[int] = None,
multiple_episodes_in_batch: bool = False,
normalize_actions: bool = True,
clip_actions: bool = False,
soft_horizon: bool = False,
no_done_at_end: bool = False,
observation_fn: Optional["ObservationFunction"] = None,
sample_collector_class: Optional[Type[SampleCollector]] = None,
render: bool = False,
blackhole_outputs: bool = False,
# Obsolete.
policies=None,
policy_mapping_fn=None,
preprocessors=None,
obs_filters=None,
tf_sess=None,
):
"""Initializes an AsyncSampler instance.
Args:
worker: The RolloutWorker that will use this Sampler for sampling.
env: Any Env object. Will be converted into an RLlib BaseEnv.
clip_rewards: True for +/-1.0 clipping,
actual float value for +/- value clipping. False for no
clipping.
rollout_fragment_length: The length of a fragment to collect
before building a SampleBatch from the data and resetting
the SampleBatchBuilder object.
count_steps_by: One of "env_steps" (default) or "agent_steps".
Use "agent_steps", if you want rollout lengths to be counted
by individual agent steps. In a multi-agent env,
a single env_step contains one or more agent_steps, depending
on how many agents are present at any given time in the
ongoing episode.
horizon: Hard-reset the Env after this many timesteps.
multiple_episodes_in_batch: Whether to pack multiple
episodes into each batch. This guarantees batches will be
exactly `rollout_fragment_length` in size.
normalize_actions: Whether to normalize actions to the
action space's bounds.
clip_actions: Whether to clip actions according to the
given action_space's bounds.
blackhole_outputs: Whether to collect samples, but then
not further process or store them (throw away all samples).
soft_horizon: If True, calculate bootstrapped values as if
episode had ended, but don't physically reset the environment
when the horizon is hit.
no_done_at_end: Ignore the done=True at the end of the
episode and instead record done=False.
observation_fn: Optional multi-agent observation func to use for
preprocessing observations.
sample_collector_class: An optional SampleCollector sub-class to
use to collect, store, and retrieve environment-, model-,
and sampler data.
render: Whether to try to render the environment after each step.
"""
# All of the following arguments are deprecated. They will instead be
# provided via the passed in `worker` arg, e.g. `worker.policy_map`.
if log_once("deprecated_async_sampler_args"):
if policies is not None:
deprecation_warning(old="policies")
if policy_mapping_fn is not None:
deprecation_warning(old="policy_mapping_fn")
if preprocessors is not None:
deprecation_warning(old="preprocessors")
if obs_filters is not None:
deprecation_warning(old="obs_filters")
if tf_sess is not None:
deprecation_warning(old="tf_sess")
self.worker = worker
for _, f in worker.filters.items():
assert getattr(f, "is_concurrent", False), \
"Observation Filter must support concurrent updates."
self.base_env = BaseEnv.to_base_env(env)
threading.Thread.__init__(self)
self.queue = queue.Queue(5)
self.extra_batches = queue.Queue()
self.metrics_queue = queue.Queue()
self.rollout_fragment_length = rollout_fragment_length
self.horizon = horizon
self.clip_rewards = clip_rewards
self.daemon = True
self.multiple_episodes_in_batch = multiple_episodes_in_batch
self.callbacks = callbacks
self.normalize_actions = normalize_actions
self.clip_actions = clip_actions
self.blackhole_outputs = blackhole_outputs
self.soft_horizon = soft_horizon
self.no_done_at_end = no_done_at_end
self.perf_stats = _PerfStats()
self.shutdown = False
self.observation_fn = observation_fn
self.render = render
if not sample_collector_class:
sample_collector_class = SimpleListCollector
self.sample_collector = sample_collector_class(
self.worker.policy_map,
self.clip_rewards,
self.callbacks,
self.multiple_episodes_in_batch,
self.rollout_fragment_length,
count_steps_by=count_steps_by)
get_data(self)
Called by self.next() to return the next batch of data.
Override this in child classes.
Returns:
Type | Description |
---|---|
Union[SampleBatch, MultiAgentBatch] |
The next batch of data. |
get_extra_batches(self)
Returns list of extra batches since the last call to this method.
The list will contain all SampleBatches or MultiAgentBatches that the user has provided thus far. Users can add these "extra batches" to an episode by calling the episode's add_extra_batch([SampleBatchType]) method. This can be done from inside an overridden Policy.compute_actions_from_input_dict(..., episodes) or from a custom callback's on_episode_[start|step|end]() methods.
Returns:
Type | Description |
---|---|
List[Union[SampleBatch, MultiAgentBatch]] |
List of SampleBatches or MultiAgentBatches provided thus far by the user since the last call to this method. |
get_metrics(self)
Returns list of episode metrics since the last call to this method.
The list will contain one RolloutMetrics object per completed episode.
Returns:
Type | Description |
---|---|
List[ray.rllib.evaluation.metrics.RolloutMetrics] |
List of RolloutMetrics objects, one per completed episode since the last call to this method. |
run(self)
Method representing the thread's activity.
You may override this method in a subclass. The standard run() method invokes the callable object passed to the object's constructor as the target argument, if any, with sequential and keyword arguments taken from the args and kwargs arguments, respectively.
Utility Functions
ray.rllib.evaluation.postprocessing.compute_advantages(rollout, last_r, gamma=0.9, lambda_=1.0, use_gae=True, use_critic=True)
Given a rollout, compute its value targets and the advantages.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
rollout |
SampleBatch |
SampleBatch of a single trajectory. |
required |
last_r |
float |
Value estimation for last observation. |
required |
gamma |
float |
Discount factor. |
0.9 |
lambda_ |
float |
Parameter for GAE. |
1.0 |
use_gae |
bool |
Whether to use Generalized Advantage Estimation (GAE). |
True |
use_critic |
bool |
Whether to use critic (value estimates). Setting this to False will use 0 as baseline. |
True |
Returns:
Type | Description |
---|---|
SampleBatch with experience from rollout and processed rewards. |
Source code in ray/rllib/evaluation/postprocessing.py
@DeveloperAPI
def compute_advantages(rollout: SampleBatch,
last_r: float,
gamma: float = 0.9,
lambda_: float = 1.0,
use_gae: bool = True,
use_critic: bool = True):
"""Given a rollout, compute its value targets and the advantages.
Args:
rollout: SampleBatch of a single trajectory.
last_r: Value estimation for last observation.
gamma: Discount factor.
lambda_: Parameter for GAE.
use_gae: Using Generalized Advantage Estimation.
use_critic: Whether to use critic (value estimates). Setting
this to False will use 0 as baseline.
Returns:
SampleBatch with experience from rollout and processed rewards.
"""
assert SampleBatch.VF_PREDS in rollout or not use_critic, \
"use_critic=True but values not found"
assert use_critic or not use_gae, \
"Can't use gae without using a value function"
if use_gae:
vpred_t = np.concatenate(
[rollout[SampleBatch.VF_PREDS],
np.array([last_r])])
delta_t = (
rollout[SampleBatch.REWARDS] + gamma * vpred_t[1:] - vpred_t[:-1])
# This formula for the advantage comes from:
# "Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438
rollout[Postprocessing.ADVANTAGES] = discount_cumsum(
delta_t, gamma * lambda_)
rollout[Postprocessing.VALUE_TARGETS] = (
rollout[Postprocessing.ADVANTAGES] +
rollout[SampleBatch.VF_PREDS]).astype(np.float32)
else:
rewards_plus_v = np.concatenate(
[rollout[SampleBatch.REWARDS],
np.array([last_r])])
discounted_returns = discount_cumsum(rewards_plus_v,
gamma)[:-1].astype(np.float32)
if use_critic:
rollout[Postprocessing.
ADVANTAGES] = discounted_returns - rollout[SampleBatch.
VF_PREDS]
rollout[Postprocessing.VALUE_TARGETS] = discounted_returns
else:
rollout[Postprocessing.ADVANTAGES] = discounted_returns
rollout[Postprocessing.VALUE_TARGETS] = np.zeros_like(
rollout[Postprocessing.ADVANTAGES])
rollout[Postprocessing.ADVANTAGES] = rollout[
Postprocessing.ADVANTAGES].astype(np.float32)
return rollout
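A small GAE sketch on a hand-made toy trajectory (all values arbitrary); only the rewards and vf_preds columns used above are required:
import numpy as np
from ray.rllib.evaluation.postprocessing import Postprocessing, compute_advantages
from ray.rllib.policy.sample_batch import SampleBatch

rollout = SampleBatch({
    SampleBatch.REWARDS: np.array([1.0, 0.0, 1.0], dtype=np.float32),
    SampleBatch.VF_PREDS: np.array([0.5, 0.4, 0.3], dtype=np.float32),
})
rollout = compute_advantages(
    rollout, last_r=0.0, gamma=0.99, lambda_=0.95, use_gae=True)
print(rollout[Postprocessing.ADVANTAGES])     # GAE advantages, float32
print(rollout[Postprocessing.VALUE_TARGETS])  # advantages + vf_preds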
ray.rllib.evaluation.metrics.collect_metrics(local_worker=None, remote_workers=None, to_be_collected=None, timeout_seconds=180)
Gathers episode metrics from RolloutWorker instances.
Source code in ray/rllib/evaluation/metrics.py
@DeveloperAPI
def collect_metrics(local_worker: Optional["RolloutWorker"] = None,
remote_workers: Optional[List[ActorHandle]] = None,
to_be_collected: Optional[List[ObjectRef]] = None,
timeout_seconds: int = 180) -> ResultDict:
"""Gathers episode metrics from RolloutWorker instances."""
if remote_workers is None:
remote_workers = []
if to_be_collected is None:
to_be_collected = []
episodes, to_be_collected = collect_episodes(
local_worker,
remote_workers,
to_be_collected,
timeout_seconds=timeout_seconds)
metrics = summarize_episodes(episodes, episodes)
return metrics
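A usage sketch, assuming trainer is an already-built Trainer (e.g. a PPOTrainer), whose WorkerSet provides the local and remote RolloutWorkers:
from ray.rllib.evaluation.metrics import collect_metrics

metrics = collect_metrics(
    local_worker=trainer.workers.local_worker(),
    remote_workers=trainer.workers.remote_workers(),
    timeout_seconds=60)
print(metrics["episode_reward_mean"])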