
evaluation Package

Episodes

ray.rllib.evaluation.episode.Episode

Tracks the current state of a (possibly multi-agent) episode.

Attributes:

Name Type Description
new_batch_builder func

Create a new MultiAgentSampleBatchBuilder.

add_extra_batch func

Return a built MultiAgentBatch to the sampler.

batch_builder obj

Batch builder for the current episode.

total_reward float

Summed reward across all agents in this episode.

length int

Length of this episode.

episode_id int

Unique id identifying this trajectory.

agent_rewards dict

Summed rewards broken down by agent.

custom_metrics dict

Dict where you can add custom metrics.

user_data dict

Dict that you can use for temporary storage. E.g. in between two custom callbacks referring to the same episode.

hist_data dict

Dict mapping str keys to List[float] for storage of per-timestep float data throughout the episode.
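
For example, a custom callback might fill these dicts roughly as follows. This is a minimal sketch: the exact DefaultCallbacks hook signatures vary between RLlib versions, and the "speed" info key is a hypothetical example, not part of this API.

from ray.rllib.agents.callbacks import DefaultCallbacks

class MyCallbacks(DefaultCallbacks):
    def on_episode_step(self, *, episode, **kwargs):
        # Temporary, per-episode storage for later aggregation.
        info = episode.last_info_for() or {}
        episode.user_data.setdefault("speeds", []).append(info.get("speed", 0.0))

    def on_episode_end(self, *, episode, **kwargs):
        speeds = episode.user_data.get("speeds", [])
        # Scalar summary -> reported as a custom metric.
        episode.custom_metrics["mean_speed"] = sum(speeds) / max(len(speeds), 1)
        # Per-timestep series -> reported as histogram data.
        episode.hist_data["speeds"] = speeds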

Use case 1: Model-based rollouts in multi-agent: A custom compute_actions() function in a policy can inspect the current episode state and perform a number of rollouts based on the policies and state of other agents in the environment.

Use case 2: Returning extra rollouts data. The model rollouts can be returned back to the sampler by calling:

>>> batch = episode.new_batch_builder()
>>> for transition in transitions:  # pseudocode; see sampler for usage
...     batch.add_values(...)
>>> episode.add_extra_batch(batch.build_and_reset())

__init__(self, policies, policy_mapping_fn, batch_builder_factory, extra_batch_callback, env_id, *, worker=None) special

Initializes an Episode instance.

Parameters:

Name Type Description Default
policies PolicyMap

The PolicyMap object (mapping PolicyIDs to Policy objects) to use for determining which policy is used for which agent.

required
policy_mapping_fn Callable[[Any, Episode, RolloutWorker], str]

The mapping function mapping AgentIDs to PolicyIDs.

required
batch_builder_factory Callable[[], MultiAgentSampleBatchBuilder]

Factory callable that returns a new MultiAgentSampleBatchBuilder (e.g. for building extra batches to be returned to the sampler).

required
extra_batch_callback Callable[[Union[SampleBatch, MultiAgentBatch]]]

Callback for returning a built extra MultiAgentBatch to the sampler.

required
env_id Union[int, str]

The environment's ID in which this episode runs.

required
worker Optional[RolloutWorker]

The RolloutWorker instance, in which this episode runs.

None
Source code in ray/rllib/evaluation/episode.py
def __init__(
        self,
        policies: PolicyMap,
        policy_mapping_fn: Callable[[AgentID, "Episode", "RolloutWorker"],
                                    PolicyID],
        batch_builder_factory: Callable[[],
                                        "MultiAgentSampleBatchBuilder"],
        extra_batch_callback: Callable[[SampleBatchType], None],
        env_id: EnvID,
        *,
        worker: Optional["RolloutWorker"] = None,
):
    """Initializes an Episode instance.

    Args:
        policies: The PolicyMap object (mapping PolicyIDs to Policy
            objects) to use for determining which policy is used for
            which agent.
        policy_mapping_fn: The mapping function mapping AgentIDs to
            PolicyIDs.
        batch_builder_factory: Factory callable that returns a new
            MultiAgentSampleBatchBuilder (e.g. for building extra
            batches to be returned to the sampler).
        extra_batch_callback: Callback for returning a built extra
            MultiAgentBatch to the sampler.
        env_id: The environment's ID in which this episode runs.
        worker: The RolloutWorker instance, in which this episode runs.
    """
    self.new_batch_builder: Callable[
        [], "MultiAgentSampleBatchBuilder"] = batch_builder_factory
    self.add_extra_batch: Callable[[SampleBatchType],
                                   None] = extra_batch_callback
    self.batch_builder: "MultiAgentSampleBatchBuilder" = \
        batch_builder_factory()
    self.total_reward: float = 0.0
    self.length: int = 0
    self.episode_id: int = random.randrange(2e9)
    self.env_id = env_id
    self.worker = worker
    self.agent_rewards: Dict[AgentID, float] = defaultdict(float)
    self.custom_metrics: Dict[str, float] = {}
    self.user_data: Dict[str, Any] = {}
    self.hist_data: Dict[str, List[float]] = {}
    self.media: Dict[str, Any] = {}
    self.policy_map: PolicyMap = policies
    self._policies = self.policy_map  # backward compatibility
    self.policy_mapping_fn: Callable[[AgentID, "Episode", "RolloutWorker"],
                                     PolicyID] = policy_mapping_fn
    self._next_agent_index: int = 0
    self._agent_to_index: Dict[AgentID, int] = {}
    self._agent_to_policy: Dict[AgentID, PolicyID] = {}
    self._agent_to_rnn_state: Dict[AgentID, List[Any]] = {}
    self._agent_to_last_obs: Dict[AgentID, EnvObsType] = {}
    self._agent_to_last_raw_obs: Dict[AgentID, EnvObsType] = {}
    self._agent_to_last_done: Dict[AgentID, bool] = {}
    self._agent_to_last_info: Dict[AgentID, EnvInfoDict] = {}
    self._agent_to_last_action: Dict[AgentID, EnvActionType] = {}
    self._agent_to_last_extra_action_outs: Dict[AgentID, dict] = {}
    self._agent_to_prev_action: Dict[AgentID, EnvActionType] = {}
    self._agent_reward_history: Dict[AgentID, List[int]] = defaultdict(
        list)

get_agents(self)

Returns list of agent IDs that have appeared in this episode.

Returns:

Type Description
List[Any]

The list of all agent IDs that have appeared so far in this episode.

Source code in ray/rllib/evaluation/episode.py
@DeveloperAPI
def get_agents(self) -> List[AgentID]:
    """Returns list of agent IDs that have appeared in this episode.

    Returns:
        The list of all agent IDs that have appeared so far in this
        episode.
    """
    return list(self._agent_to_index.keys())

last_action_for(self, agent_id='agent0')

Returns the last action for the specified AgentID, or zeros.

The "last" action is the most recent one taken by the agent.

Parameters:

Name Type Description Default
agent_id Any

The agent's ID to get the last action for.

'agent0'

Returns:

Type Description
Any

Last action the specified AgentID has executed. Zeros in case the agent has never performed any actions in the episode.

Source code in ray/rllib/evaluation/episode.py
@DeveloperAPI
def last_action_for(self,
                    agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType:
    """Returns the last action for the specified AgentID, or zeros.

    The "last" action is the most recent one taken by the agent.

    Args:
        agent_id: The agent's ID to get the last action for.

    Returns:
        Last action the specified AgentID has executed.
        Zeros in case the agent has never performed any actions in the
        episode.
    """
    # Agent has already taken at least one action in the episode.
    if agent_id in self._agent_to_last_action:
        return flatten_to_single_ndarray(
            self._agent_to_last_action[agent_id])
    # Agent has not acted yet, return all zeros.
    else:
        policy_id = self.policy_for(agent_id)
        policy = self.policy_map[policy_id]
        flat = flatten_to_single_ndarray(policy.action_space.sample())
        if hasattr(policy.action_space, "dtype"):
            return np.zeros_like(flat, dtype=policy.action_space.dtype)
        return np.zeros_like(flat)

last_done_for(self, agent_id='agent0')

Returns the last done flag for the specified AgentID.

Parameters:

Name Type Description Default
agent_id Any

The agent's ID to get the last done flag for.

'agent0'

Returns:

Type Description
bool

Last done flag for the specified AgentID.

Source code in ray/rllib/evaluation/episode.py
@DeveloperAPI
def last_done_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> bool:
    """Returns the last done flag for the specified AgentID.

    Args:
        agent_id: The agent's ID to get the last done flag for.

    Returns:
        Last done flag for the specified AgentID.
    """
    if agent_id not in self._agent_to_last_done:
        self._agent_to_last_done[agent_id] = False
    return self._agent_to_last_done[agent_id]

last_extra_action_outs_for(self, agent_id='agent0')

Returns the last extra-action outputs for the specified agent.

This data is returned by a call to Policy.compute_actions_from_input_dict as the 3rd return value (1st return value = action; 2nd return value = RNN state outs).

Parameters:

Name Type Description Default
agent_id Any

The agent's ID to get the last extra-action outs for.

'agent0'

Returns:

Type Description
dict

The last extra-action outs for the specified AgentID.

Source code in ray/rllib/evaluation/episode.py
@DeveloperAPI
def last_extra_action_outs_for(
        self,
        agent_id: AgentID = _DUMMY_AGENT_ID,
) -> dict:
    """Returns the last extra-action outputs for the specified agent.

    This data is returned by a call to
    `Policy.compute_actions_from_input_dict` as the 3rd return value
    (1st return value = action; 2nd return value = RNN state outs).

    Args:
        agent_id: The agent's ID to get the last extra-action outs for.

    Returns:
        The last extra-action outs for the specified AgentID.
    """
    return self._agent_to_last_extra_action_outs[agent_id]
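
For illustration, the returned dict can be inspected e.g. from a callback. Which keys are present depends entirely on the policy, so the key names below ("vf_preds", "action_dist_inputs") are assumptions, not guarantees.

# Sketch: read the most recent extra-action outputs for the default agent.
outs = episode.last_extra_action_outs_for()
value_estimate = outs.get("vf_preds")         # e.g. value-function prediction
dist_inputs = outs.get("action_dist_inputs")  # e.g. action distribution logits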

last_info_for(self, agent_id='agent0')

Returns the last info for the specified AgentID.

Parameters:

Name Type Description Default
agent_id Any

The agent's ID to get the last info for.

'agent0'

Returns:

Type Description
Optional[dict]

Last info dict the specified AgentID has seen. None in case the agent has never made any observations in the episode.

Source code in ray/rllib/evaluation/episode.py
@DeveloperAPI
def last_info_for(self, agent_id: AgentID = _DUMMY_AGENT_ID
                  ) -> Optional[EnvInfoDict]:
    """Returns the last info for the specified AgentID.

    Args:
        agent_id: The agent's ID to get the last info for.

    Returns:
        Last info dict the specified AgentID has seen.
        None in case the agent has never made any observations in the
        episode.
    """
    return self._agent_to_last_info.get(agent_id)

last_observation_for(self, agent_id='agent0')

Returns the last observation for the specified AgentID.

Parameters:

Name Type Description Default
agent_id Any

The agent's ID to get the last observation for.

'agent0'

Returns:

Type Description
Optional[Any]

Last observation the specified AgentID has seen. None in case the agent has never made any observations in the episode.

Source code in ray/rllib/evaluation/episode.py
@DeveloperAPI
def last_observation_for(
        self, agent_id: AgentID = _DUMMY_AGENT_ID) -> Optional[EnvObsType]:
    """Returns the last observation for the specified AgentID.

    Args:
        agent_id: The agent's ID to get the last observation for.

    Returns:
        Last observation the specified AgentID has seen. None in case
        the agent has never made any observations in the episode.
    """

    return self._agent_to_last_obs.get(agent_id)

last_raw_obs_for(self, agent_id='agent0')

Returns the last un-preprocessed obs for the specified AgentID.

Parameters:

Name Type Description Default
agent_id Any

The agent's ID to get the last un-preprocessed observation for.

'agent0'

Returns:

Type Description
Optional[Any]

Last un-preprocessed observation the specified AgentID has seen. None in case the agent has never made any observations in the episode.

Source code in ray/rllib/evaluation/episode.py
@DeveloperAPI
def last_raw_obs_for(
        self, agent_id: AgentID = _DUMMY_AGENT_ID) -> Optional[EnvObsType]:
    """Returns the last un-preprocessed obs for the specified AgentID.

    Args:
        agent_id: The agent's ID to get the last un-preprocessed
            observation for.

    Returns:
        Last un-preprocessed observation the specified AgentID has seen.
        None in case the agent has never made any observations in the
        episode.
    """
    return self._agent_to_last_raw_obs.get(agent_id)

last_reward_for(self, agent_id='agent0')

Returns the last reward for the specified agent, or zero.

The "last" reward is the one received most recently by the agent.

Parameters:

Name Type Description Default
agent_id Any

The agent's ID to get the last reward for.

'agent0'

Returns:

Type Description
float

Last reward for the specified AgentID. Zero in case the agent has never performed any actions (and thus received rewards) in the episode.

Source code in ray/rllib/evaluation/episode.py
@DeveloperAPI
def last_reward_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> float:
    """Returns the last reward for the specified agent, or zero.

    The "last" reward is the one received most recently by the agent.

    Args:
        agent_id: The agent's ID to get the last reward for.

    Returns:
        Last reward for the specified AgentID.
        Zero in case the agent has never performed any actions
        (and thus received rewards) in the episode.
    """

    history = self._agent_reward_history[agent_id]
    # We are at t > 0 -> Return previously received reward.
    if len(history) >= 1:
        return history[-1]
    # We're at t=0, so there is no previous reward, just return zero.
    else:
        return 0.0
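
The per-agent accessors above are typically combined, e.g. inside a callback, to summarize each agent's most recent transition. A short sketch over the documented getters (the results could then feed episode.custom_metrics or episode.user_data):

# Sketch: summarize the most recent step of every agent seen so far.
for agent_id in episode.get_agents():
    obs = episode.last_observation_for(agent_id)     # None if never observed
    action = episode.last_action_for(agent_id)       # zeros if agent never acted
    reward = episode.last_reward_for(agent_id)       # 0.0 before the first reward
    prev_reward = episode.prev_reward_for(agent_id)  # 0.0 at t <= 1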

policy_for(self, agent_id='agent0')

Returns and stores the policy ID for the specified agent.

If the agent is new, the policy mapping fn will be called to bind the agent to a policy for the duration of the entire episode (even if the policy_mapping_fn is changed in the meantime!).

Parameters:

Name Type Description Default
agent_id Any

The agent ID to lookup the policy ID for.

'agent0'

Returns:

Type Description
str

The policy ID for the specified agent.

Source code in ray/rllib/evaluation/episode.py
@DeveloperAPI
def policy_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> PolicyID:
    """Returns and stores the policy ID for the specified agent.

    If the agent is new, the policy mapping fn will be called to bind the
    agent to a policy for the duration of the entire episode (even if the
    policy_mapping_fn is changed in the meantime!).

    Args:
        agent_id: The agent ID to lookup the policy ID for.

    Returns:
        The policy ID for the specified agent.
    """

    # Perform a new policy_mapping_fn lookup and bind AgentID for the
    # duration of this episode to the returned PolicyID.
    if agent_id not in self._agent_to_policy:
        # Try new API: pass in agent_id and episode as named args.
        # New signature should be: (agent_id, episode, worker, **kwargs)
        try:
            policy_id = self._agent_to_policy[agent_id] = \
                self.policy_mapping_fn(agent_id, self, worker=self.worker)
        except TypeError as e:
            if "positional argument" in e.args[0] or \
                    "unexpected keyword argument" in e.args[0]:
                if log_once("policy_mapping_new_signature"):
                    deprecation_warning(
                        old="policy_mapping_fn(agent_id)",
                        new="policy_mapping_fn(agent_id, episode, "
                        "worker, **kwargs)")
                policy_id = self._agent_to_policy[agent_id] = \
                    self.policy_mapping_fn(agent_id)
            else:
                raise e
    # Use already determined PolicyID.
    else:
        policy_id = self._agent_to_policy[agent_id]

    # PolicyID not found in policy map -> Error.
    if policy_id not in self.policy_map:
        raise KeyError("policy_mapping_fn returned invalid policy id "
                       f"'{policy_id}'!")
    return policy_id
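
A sketch of a policy_mapping_fn using the new (agent_id, episode, worker, **kwargs) signature referenced above, given an ongoing episode object; the policy IDs and the "car_" agent-ID prefix are placeholders:

def policy_mapping_fn(agent_id, episode, worker, **kwargs):
    # Bind car agents to one policy, everything else to another.
    return "car_policy" if agent_id.startswith("car_") else "traffic_light_policy"

# Once an agent has been mapped, the binding is cached for this episode:
pid = episode.policy_for("car_0")  # calls policy_mapping_fn only once per agent
policy = episode.policy_map[pid]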

prev_action_for(self, agent_id='agent0')

Returns the previous action for the specified agent, or zeros.

The "previous" action is the one taken one timestep before the most recent action taken by the agent.

Parameters:

Name Type Description Default
agent_id Any

The agent's ID to get the previous action for.

'agent0'

Returns:

Type Description
Any

Previous action the specified AgentID has executed. Zero in case the agent has never performed any actions (or only one) in the episode.

Source code in ray/rllib/evaluation/episode.py
@DeveloperAPI
def prev_action_for(self,
                    agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType:
    """Returns the previous action for the specified agent, or zeros.

    The "previous" action is the one taken one timestep before the
    most recent action taken by the agent.

    Args:
        agent_id: The agent's ID to get the previous action for.

    Returns:
        Previous action the specified AgentID has executed.
        Zero in case the agent has never performed any actions (or only
        one) in the episode.
    """
    # We are at t > 1 -> There has been a previous action by this agent.
    if agent_id in self._agent_to_prev_action:
        return flatten_to_single_ndarray(
            self._agent_to_prev_action[agent_id])
    # We're at t <= 1, so return all zeros.
    else:
        return np.zeros_like(self.last_action_for(agent_id))

prev_reward_for(self, agent_id='agent0')

Returns the previous reward for the specified agent, or zero.

The "previous" reward is the one received one timestep before the most recently received reward of the agent.

Parameters:

Name Type Description Default
agent_id Any

The agent's ID to get the previous reward for.

'agent0'

Returns:

Type Description
float

Previous reward for the specified AgentID. Zero in case the agent has never performed any actions (or only one) in the episode.

Source code in ray/rllib/evaluation/episode.py
@DeveloperAPI
def prev_reward_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> float:
    """Returns the previous reward for the specified agent, or zero.

    The "previous" reward is the one received one timestep before the
    most recently received reward of the agent.

    Args:
        agent_id: The agent's ID to get the previous reward for.

    Returns:
        Previous reward for the specified AgentID.
        Zero in case the agent has never performed any actions (or only
        one) in the episode.
    """

    history = self._agent_reward_history[agent_id]
    # We are at t > 1 -> Return reward prior to most recent (last) one.
    if len(history) >= 2:
        return history[-2]
    # We're at t <= 1, so there is no previous reward, just return zero.
    else:
        return 0.0

rnn_state_for(self, agent_id='agent0')

Returns the last RNN state for the specified agent.

Parameters:

Name Type Description Default
agent_id Any

The agent's ID to get the most recent RNN state for.

'agent0'

Returns:

Type Description
List[Any]

Most recent RNN state of the specified AgentID.

Source code in ray/rllib/evaluation/episode.py
@DeveloperAPI
def rnn_state_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> List[Any]:
    """Returns the last RNN state for the specified agent.

    Args:
        agent_id: The agent's ID to get the most recent RNN state for.

    Returns:
        Most recent RNN state of the specified AgentID.
    """

    if agent_id not in self._agent_to_rnn_state:
        policy_id = self.policy_for(agent_id)
        policy = self.policy_map[policy_id]
        self._agent_to_rnn_state[agent_id] = policy.get_initial_state()
    return self._agent_to_rnn_state[agent_id]
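
In the model-based rollout use case (use case 1 above), a custom compute_actions() can read other agents' current state from the episode. A rough sketch; the agent ID "opponent_0" is hypothetical:

# Sketch: gather another agent's latest state before a model-based rollout.
opponent_obs = episode.last_raw_obs_for("opponent_0")  # un-preprocessed obs
opponent_state = episode.rnn_state_for("opponent_0")   # initial state if unseen
opponent_policy = episode.policy_map[episode.policy_for("opponent_0")]
# The opponent's policy (plus its RNN state and last obs) can now be used to
# simulate that agent's next moves inside the custom compute_actions().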

soft_reset(self)

Clears rewards and metrics, but retains RNN and other state.

This is used to carry state across multiple logical episodes in the same env (i.e., if soft_horizon is set).

Source code in ray/rllib/evaluation/episode.py
@DeveloperAPI
def soft_reset(self) -> None:
    """Clears rewards and metrics, but retains RNN and other state.

    This is used to carry state across multiple logical episodes in the
    same env (i.e., if `soft_horizon` is set).
    """
    self.length = 0
    self.episode_id = random.randrange(2e9)
    self.total_reward = 0.0
    self.agent_rewards = defaultdict(float)
    self._agent_reward_history = defaultdict(list)

ray.rllib.evaluation.episode.MultiAgentEpisode (Episode)

Rollouts

ray.rllib.evaluation.rollout_worker.RolloutWorker (ParallelIteratorWorker)

Common experience collection class.

This class wraps a policy instance and an environment class to collect experiences from the environment. You can create many replicas of this class as Ray actors to scale RL training.

This class supports vectorized and multi-agent policy evaluation (e.g., VectorEnv, MultiAgentEnv).

Examples:

>>> # Create a rollout worker and use it to collect experiences.
>>> worker = RolloutWorker(
...   env_creator=lambda _: gym.make("CartPole-v0"),
...   policy_spec=PGTFPolicy)
>>> print(worker.sample())
SampleBatch({
    "obs": [[...]], "actions": [[...]], "rewards": [[...]],
    "dones": [[...]], "new_obs": [[...]]})
>>> # Creating a multi-agent rollout worker
>>> worker = RolloutWorker(
...   env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25),
...   policy_spec={
...       # Use an ensemble of two policies for car agents
...       "car_policy1":
...         (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.99}),
...       "car_policy2":
...         (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.95}),
...       # Use a single shared policy for all traffic lights
...       "traffic_light_policy":
...         (PGTFPolicy, Box(...), Discrete(...), {}),
...   },
...   policy_mapping_fn=lambda agent_id, episode, **kwargs:
...     random.choice(["car_policy1", "car_policy2"])
...     if agent_id.startswith("car_") else "traffic_light_policy")
>>> print(worker.sample())
MultiAgentBatch({
    "car_policy1": SampleBatch(...),
    "car_policy2": SampleBatch(...),
    "traffic_light_policy": SampleBatch(...)})

__init__(self, *, env_creator, validate_env=None, policy_spec=None, policy_mapping_fn=None, policies_to_train=None, tf_session_creator=None, rollout_fragment_length=100, count_steps_by='env_steps', batch_mode='truncate_episodes', episode_horizon=None, preprocessor_pref='deepmind', sample_async=False, compress_observations=False, num_envs=1, observation_fn=None, observation_filter='NoFilter', clip_rewards=None, normalize_actions=True, clip_actions=False, env_config=None, model_config=None, policy_config=None, worker_index=0, num_workers=0, record_env=False, log_dir=None, log_level=None, callbacks=None, input_creator=<function RolloutWorker.<lambda> at 0x11d91bd40>, input_evaluation=frozenset(), output_creator=<function RolloutWorker.<lambda> at 0x11d91bdd0>, remote_worker_envs=False, remote_env_batch_wait_ms=0, soft_horizon=False, no_done_at_end=False, seed=None, extra_python_environs=None, fake_sampler=False, spaces=None, policy=None, monitor_path=None) special

Initializes a RolloutWorker instance.

Parameters:

Name Type Description Default
env_creator Callable[[ray.rllib.env.env_context.EnvContext], Any]

Function that returns a gym.Env given an EnvContext wrapped configuration.

required
validate_env Optional[Callable[[Any, ray.rllib.env.env_context.EnvContext], NoneType]]

Optional callable to validate the generated environment (only on worker=0).

None
policy_spec Union[type, Dict[str, ray.rllib.policy.policy.PolicySpec]]

The MultiAgentPolicyConfigDict mapping policy IDs (str) to PolicySpec's or a single policy class to use. If a dict is specified, then we are in multi-agent mode and a policy_mapping_fn can also be set (if not, will map all agents to DEFAULT_POLICY_ID).

None
policy_mapping_fn Optional[Callable[[Any, Episode], str]]

A callable that maps agent ids to policy ids in multi-agent mode. This function will be called each time a new agent appears in an episode, to bind that agent to a policy for the duration of the episode. If not provided, will map all agents to DEFAULT_POLICY_ID.

None
policies_to_train Optional[List[str]]

Optional list of policies to train, or None for all policies.

None
tf_session_creator Optional[Callable[[], tf1.Session]]

A function that returns a TF session. This is optional and only useful with TFPolicy.

None
rollout_fragment_length int

The target number of steps (measured in count_steps_by) to include in each sample batch returned from this worker.

100
count_steps_by str

The unit in which to count fragment lengths. One of env_steps or agent_steps.

'env_steps'
batch_mode str

One of the following batch modes:
- "truncate_episodes": Each call to sample() will return a batch of at most rollout_fragment_length * num_envs in size. The batch will be exactly rollout_fragment_length * num_envs in size if postprocessing does not change batch sizes. Episodes may be truncated in order to meet this size requirement.
- "complete_episodes": Each call to sample() will return a batch of at least rollout_fragment_length * num_envs in size. Episodes will not be truncated, but multiple episodes may be packed within one batch to meet the batch size. Note that when num_envs > 1, episode steps will be buffered until the episode completes, and hence batches may contain significant amounts of off-policy data.

'truncate_episodes'
episode_horizon Optional[int]

Horizon at which to stop episodes (even if the environment itself has not returned a "done" signal).

None
preprocessor_pref str

Whether to use RLlib preprocessors ("rllib") or deepmind ("deepmind"), when applicable.

'deepmind'
sample_async bool

Whether to compute samples asynchronously in the background, which improves throughput but can cause samples to be slightly off-policy.

False
compress_observations bool

If true, compress the observations. They can be decompressed with rllib/utils/compression.

False
num_envs int

If more than one, will create multiple envs and vectorize the computation of actions. This has no effect if the env already implements VectorEnv.

1
observation_fn Optional[ObservationFunction]

Optional multi-agent observation function.

None
observation_filter str

Name of observation filter to use.

'NoFilter'
clip_rewards Union[bool, float]

True for clipping rewards to [-1.0, 1.0] prior to experience postprocessing. None: Clip for Atari only. float: Clip to [-clip_rewards; +clip_rewards].

None
normalize_actions bool

Whether to normalize actions to the action space's bounds.

True
clip_actions bool

Whether to clip action values to the range specified by the policy action space.

False
env_config Optional[dict]

Config to pass to the env creator.

None
model_config Optional[dict]

Config to use when creating the policy model.

None
policy_config Optional[dict]

Config to pass to the policy. In the multi-agent case, this config will be merged with the per-policy configs specified by policy_spec.

None
worker_index int

For remote workers, this should be set to a non-zero and unique value. This index is passed to created envs through EnvContext so that envs can be configured per worker.

0
num_workers int

For remote workers, how many workers altogether have been created?

0
record_env Union[bool, str]

Write out episode stats and videos using gym.wrappers.Monitor to this directory if specified. If True, use the default output dir in ~/ray_results/.... If False, do not record anything.

False
log_dir Optional[str]

Directory where logs can be placed.

None
log_level Optional[str]

Set the root log level on creation.

None
callbacks Type[DefaultCallbacks]

Custom sub-class of DefaultCallbacks for training/policy/rollout-worker callbacks.

None
input_creator Callable[[ray.rllib.offline.io_context.IOContext], ray.rllib.offline.input_reader.InputReader]

Function that returns an InputReader object for loading previously generated experiences.

<function RolloutWorker.<lambda> at 0x11d91bd40>
input_evaluation List[str]

How to evaluate the policy performance. This only makes sense to set when the input is reading offline data. The possible values include:
- "is": the step-wise importance sampling estimator.
- "wis": the weighted step-wise IS estimator.
- "simulation": run the environment in the background, but use this data for evaluation only and never for learning.

frozenset()
output_creator Callable[[ray.rllib.offline.io_context.IOContext], ray.rllib.offline.output_writer.OutputWriter]

Function that returns an OutputWriter object for saving generated experiences.

<function RolloutWorker.<lambda> at 0x11d91bdd0>
remote_worker_envs bool

If using num_envs_per_worker > 1, whether to create those new envs in remote processes instead of in the current process. This adds overheads, but can make sense if your envs are expensive to step/reset (e.g., for StarCraft). Use this cautiously; the overheads are significant!

False
remote_env_batch_wait_ms int

Timeout that remote workers wait when polling environments. 0 (continue when at least one env is ready) is a reasonable default, but the optimal value can be found by measuring your environment's step/reset times and model inference performance.

0
soft_horizon bool

Calculate rewards but don't reset the environment when the horizon is hit.

False
no_done_at_end bool

Ignore the done=True at the end of the episode and instead record done=False.

False
seed int

Set the seed of both np and tf to this value to ensure each remote worker has unique exploration behavior.

None
extra_python_environs Optional[dict]

Extra Python environment variables to set.

None
fake_sampler bool

Use a fake (inf speed) sampler for testing.

False
spaces Optional[Dict[str, Tuple[gym.spaces.space.Space, gym.spaces.space.Space]]]

An optional space dict mapping policy IDs to (obs_space, action_space)-tuples. This is used in case no Env is created on this RolloutWorker.

None
policy

Obsoleted arg. Use policy_spec instead.

None
monitor_path

Obsoleted arg. Use record_env instead.

None
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def __init__(
        self,
        *,
        env_creator: Callable[[EnvContext], EnvType],
        validate_env: Optional[Callable[[EnvType, EnvContext],
                                        None]] = None,
        policy_spec: Optional[Union[type, Dict[PolicyID,
                                               PolicySpec]]] = None,
        policy_mapping_fn: Optional[Callable[[AgentID, "Episode"],
                                             PolicyID]] = None,
        policies_to_train: Optional[List[PolicyID]] = None,
        tf_session_creator: Optional[Callable[[], "tf1.Session"]] = None,
        rollout_fragment_length: int = 100,
        count_steps_by: str = "env_steps",
        batch_mode: str = "truncate_episodes",
        episode_horizon: Optional[int] = None,
        preprocessor_pref: str = "deepmind",
        sample_async: bool = False,
        compress_observations: bool = False,
        num_envs: int = 1,
        observation_fn: Optional["ObservationFunction"] = None,
        observation_filter: str = "NoFilter",
        clip_rewards: Optional[Union[bool, float]] = None,
        normalize_actions: bool = True,
        clip_actions: bool = False,
        env_config: Optional[EnvConfigDict] = None,
        model_config: Optional[ModelConfigDict] = None,
        policy_config: Optional[PartialTrainerConfigDict] = None,
        worker_index: int = 0,
        num_workers: int = 0,
        record_env: Union[bool, str] = False,
        log_dir: Optional[str] = None,
        log_level: Optional[str] = None,
        callbacks: Type["DefaultCallbacks"] = None,
        input_creator: Callable[[
            IOContext
        ], InputReader] = lambda ioctx: ioctx.default_sampler_input(),
        input_evaluation: List[str] = frozenset([]),
        output_creator: Callable[
            [IOContext], OutputWriter] = lambda ioctx: NoopOutput(),
        remote_worker_envs: bool = False,
        remote_env_batch_wait_ms: int = 0,
        soft_horizon: bool = False,
        no_done_at_end: bool = False,
        seed: int = None,
        extra_python_environs: Optional[dict] = None,
        fake_sampler: bool = False,
        spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
                                              gym.spaces.Space]]] = None,
        policy=None,
        monitor_path=None,
):
    """Initializes a RolloutWorker instance.

    Args:
        env_creator: Function that returns a gym.Env given an EnvContext
            wrapped configuration.
        validate_env: Optional callable to validate the generated
            environment (only on worker=0).
        policy_spec: The MultiAgentPolicyConfigDict mapping policy IDs
            (str) to PolicySpec's or a single policy class to use.
            If a dict is specified, then we are in multi-agent mode and a
            policy_mapping_fn can also be set (if not, will map all agents
            to DEFAULT_POLICY_ID).
        policy_mapping_fn: A callable that maps agent ids to policy ids in
            multi-agent mode. This function will be called each time a new
            agent appears in an episode, to bind that agent to a policy
            for the duration of the episode. If not provided, will map all
            agents to DEFAULT_POLICY_ID.
        policies_to_train: Optional list of policies to train, or None
            for all policies.
        tf_session_creator: A function that returns a TF session.
            This is optional and only useful with TFPolicy.
        rollout_fragment_length: The target number of steps
            (measured in `count_steps_by`) to include in each sample
            batch returned from this worker.
        count_steps_by: The unit in which to count fragment
            lengths. One of env_steps or agent_steps.
        batch_mode: One of the following batch modes:
            - "truncate_episodes": Each call to sample() will return a
            batch of at most `rollout_fragment_length * num_envs` in size.
            The batch will be exactly `rollout_fragment_length * num_envs`
            in size if postprocessing does not change batch sizes. Episodes
            may be truncated in order to meet this size requirement.
            - "complete_episodes": Each call to sample() will return a
            batch of at least `rollout_fragment_length * num_envs` in
            size. Episodes will not be truncated, but multiple episodes
            may be packed within one batch to meet the batch size. Note
            that when `num_envs > 1`, episode steps will be buffered
            until the episode completes, and hence batches may contain
            significant amounts of off-policy data.
        episode_horizon: Horizon at which to stop episodes (even if the
            environment itself has not returned a "done" signal).
        preprocessor_pref: Whether to use RLlib preprocessors
            ("rllib") or deepmind ("deepmind"), when applicable.
        sample_async: Whether to compute samples asynchronously in
            the background, which improves throughput but can cause samples
            to be slightly off-policy.
        compress_observations: If true, compress the observations.
            They can be decompressed with rllib/utils/compression.
        num_envs: If more than one, will create multiple envs
            and vectorize the computation of actions. This has no effect
            if the env already implements VectorEnv.
        observation_fn: Optional multi-agent observation function.
        observation_filter: Name of observation filter to use.
        clip_rewards: True for clipping rewards to [-1.0, 1.0] prior
            to experience postprocessing. None: Clip for Atari only.
            float: Clip to [-clip_rewards; +clip_rewards].
        normalize_actions: Whether to normalize actions to the
            action space's bounds.
        clip_actions: Whether to clip action values to the range
            specified by the policy action space.
        env_config: Config to pass to the env creator.
        model_config: Config to use when creating the policy model.
        policy_config: Config to pass to the
            policy. In the multi-agent case, this config will be merged
            with the per-policy configs specified by `policy_spec`.
        worker_index: For remote workers, this should be set to a
            non-zero and unique value. This index is passed to created envs
            through EnvContext so that envs can be configured per worker.
        num_workers: For remote workers, how many workers altogether
            have been created?
        record_env: Write out episode stats and videos
            using gym.wrappers.Monitor to this directory if specified. If
            True, use the default output dir in ~/ray_results/.... If
            False, do not record anything.
        log_dir: Directory where logs can be placed.
        log_level: Set the root log level on creation.
        callbacks: Custom sub-class of
            DefaultCallbacks for training/policy/rollout-worker callbacks.
        input_creator: Function that returns an InputReader object for
            loading previous generated experiences.
        input_evaluation: How to evaluate the policy
            performance. This only makes sense to set when the input is
            reading offline data. The possible values include:
            - "is": the step-wise importance sampling estimator.
            - "wis": the weighted step-wise is estimator.
            - "simulation": run the environment in the background, but
            use this data for evaluation only and never for learning.
        output_creator: Function that returns an OutputWriter object for
            saving generated experiences.
        remote_worker_envs: If using num_envs_per_worker > 1,
            whether to create those new envs in remote processes instead of
            in the current process. This adds overheads, but can make sense
            if your envs are expensive to step/reset (e.g., for StarCraft).
            Use this cautiously, overheads are significant!
        remote_env_batch_wait_ms: Timeout that remote workers
            are waiting when polling environments. 0 (continue when at
            least one env is ready) is a reasonable default, but optimal
            value could be obtained by measuring your environment
            step / reset and model inference perf.
        soft_horizon: Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end: Ignore the done=True at the end of the
            episode and instead record done=False.
        seed: Set the seed of both np and tf to this value to ensure
            each remote worker has unique exploration behavior.
        extra_python_environs: Extra Python environment variables to set.
        fake_sampler: Use a fake (inf speed) sampler for testing.
        spaces: An optional space dict mapping policy IDs
            to (obs_space, action_space)-tuples. This is used in case no
            Env is created on this RolloutWorker.
        policy: Obsoleted arg. Use `policy_spec` instead.
        monitor_path: Obsoleted arg. Use `record_env` instead.
    """

    # Deprecated args.
    if policy is not None:
        deprecation_warning("policy", "policy_spec", error=False)
        policy_spec = policy
    assert policy_spec is not None, \
        "Must provide `policy_spec` when creating RolloutWorker!"

    # Do quick translation into MultiAgentPolicyConfigDict.
    if not isinstance(policy_spec, dict):
        policy_spec = {
            DEFAULT_POLICY_ID: PolicySpec(policy_class=policy_spec)
        }
    policy_spec = {
        pid: spec if isinstance(spec, PolicySpec) else PolicySpec(*spec)
        for pid, spec in policy_spec.copy().items()
    }

    if monitor_path is not None:
        deprecation_warning("monitor_path", "record_env", error=False)
        record_env = monitor_path

    self._original_kwargs: dict = locals().copy()
    del self._original_kwargs["self"]

    global _global_worker
    _global_worker = self

    # set extra environs first
    if extra_python_environs:
        for key, value in extra_python_environs.items():
            os.environ[key] = str(value)

    def gen_rollouts():
        while True:
            yield self.sample()

    ParallelIteratorWorker.__init__(self, gen_rollouts, False)

    policy_config = policy_config or {}
    if (tf1 and policy_config.get("framework") in ["tf2", "tfe"]
            # This eager check is necessary for certain all-framework tests
            # that use tf's eager_mode() context generator.
            and not tf1.executing_eagerly()):
        tf1.enable_eager_execution()

    if log_level:
        logging.getLogger("ray.rllib").setLevel(log_level)

    if worker_index > 1:
        disable_log_once_globally()  # only need 1 worker to log
    elif log_level == "DEBUG":
        enable_periodic_logging()

    env_context = EnvContext(
        env_config or {},
        worker_index=worker_index,
        vector_index=0,
        num_workers=num_workers,
    )
    self.env_context = env_context
    self.policy_config: PartialTrainerConfigDict = policy_config
    if callbacks:
        self.callbacks: "DefaultCallbacks" = callbacks()
    else:
        from ray.rllib.agents.callbacks import DefaultCallbacks  # noqa
        self.callbacks: DefaultCallbacks = DefaultCallbacks()
    self.worker_index: int = worker_index
    self.num_workers: int = num_workers
    model_config: ModelConfigDict = \
        model_config or self.policy_config.get("model") or {}

    # Default policy mapping fn is to always return DEFAULT_POLICY_ID,
    # independent on the agent ID and the episode passed in.
    self.policy_mapping_fn = \
        lambda agent_id, episode, worker, **kwargs: DEFAULT_POLICY_ID
    # If provided, set it here.
    self.set_policy_mapping_fn(policy_mapping_fn)

    self.env_creator: Callable[[EnvContext], EnvType] = env_creator
    self.rollout_fragment_length: int = rollout_fragment_length * num_envs
    self.count_steps_by: str = count_steps_by
    self.batch_mode: str = batch_mode
    self.compress_observations: bool = compress_observations
    self.preprocessing_enabled: bool = False \
        if policy_config.get("_disable_preprocessor_api") else True
    self.observation_filter = observation_filter
    self.last_batch: Optional[SampleBatchType] = None
    self.global_vars: Optional[dict] = None
    self.fake_sampler: bool = fake_sampler

    # Update the global seed for numpy/random/tf-eager/torch if we are not
    # the local worker, otherwise, this was already done in the Trainer
    # object itself.
    if self.worker_index > 0:
        update_global_seed_if_necessary(
            policy_config.get("framework"), seed)

    # A single environment provided by the user (via config.env). This may
    # also remain None.
    # 1) Create the env using the user provided env_creator. This may
    #    return a gym.Env (incl. MultiAgentEnv), an already vectorized
    #    VectorEnv, BaseEnv, ExternalEnv, or an ActorHandle (remote env).
    # 2) Wrap - if applicable - with Atari/recording/rendering wrappers.
    # 3) Seed the env, if necessary.
    # 4) Vectorize the existing single env by creating more clones of
    #    this env and wrapping it with the RLlib BaseEnv class.
    self.env = None

    # Create a (single) env for this worker.
    if not (worker_index == 0 and num_workers > 0
            and not policy_config.get("create_env_on_driver")):
        # Run the `env_creator` function passing the EnvContext.
        self.env = env_creator(copy.deepcopy(self.env_context))

    if self.env is not None:
        # Validate environment (general validation function).
        _validate_env(self.env, env_context=self.env_context)
        # Custom validation function given.
        if validate_env is not None:
            validate_env(self.env, self.env_context)
        # We can't auto-wrap a BaseEnv.
        if isinstance(self.env, (BaseEnv, ray.actor.ActorHandle)):

            def wrap(env):
                return env

        # Atari type env and "deepmind" preprocessor pref.
        elif is_atari(self.env) and \
                not model_config.get("custom_preprocessor") and \
                preprocessor_pref == "deepmind":

            # Deepmind wrappers already handle all preprocessing.
            self.preprocessing_enabled = False

            # If clip_rewards not explicitly set to False, switch it
            # on here (clip between -1.0 and 1.0).
            if clip_rewards is None:
                clip_rewards = True

            # Framestacking is used.
            use_framestack = model_config.get("framestack") is True

            def wrap(env):
                env = wrap_deepmind(
                    env,
                    dim=model_config.get("dim"),
                    framestack=use_framestack)
                env = record_env_wrapper(env, record_env, log_dir,
                                         policy_config)
                return env

        # gym.Env -> Wrap with gym Monitor.
        else:

            def wrap(env):
                return record_env_wrapper(env, record_env, log_dir,
                                          policy_config)

        # Wrap env through the correct wrapper.
        self.env: EnvType = wrap(self.env)
        # Ideally, we would use the same make_sub_env() function below
        # to create self.env, but wrap(env) and self.env has a cyclic
        # dependency on each other right now, so we would settle on
        # duplicating the random seed setting logic for now.
        _update_env_seed_if_necessary(self.env, seed, worker_index, 0)

    def make_sub_env(vector_index):
        # Used to created additional environments during environment
        # vectorization.

        # Create the env context (config dict + meta-data) for
        # this particular sub-env within the vectorized one.
        env_ctx = env_context.copy_with_overrides(
            worker_index=worker_index,
            vector_index=vector_index,
            remote=remote_worker_envs)
        # Create the sub-env.
        env = env_creator(env_ctx)
        # Validate first.
        _validate_env(env, env_context=env_ctx)
        # Custom validation function given by user.
        if validate_env is not None:
            validate_env(env, env_ctx)
        # Use our wrapper, defined above.
        env = wrap(env)

        # Make sure a deterministic random seed is set on
        # all the sub-environments if specified.
        _update_env_seed_if_necessary(env, seed, worker_index,
                                      vector_index)
        return env

    self.make_sub_env_fn = make_sub_env
    self.spaces = spaces

    policy_dict = _determine_spaces_for_multi_agent_dict(
        policy_spec,
        self.env,
        spaces=self.spaces,
        policy_config=policy_config)

    # List of IDs of those policies, which should be trained.
    # By default, these are all policies found in the policy_dict.
    self.policies_to_train: List[PolicyID] = policies_to_train or list(
        policy_dict.keys())
    self.set_policies_to_train(self.policies_to_train)

    self.policy_map: PolicyMap = None
    self.preprocessors: Dict[PolicyID, Preprocessor] = None

    # Check available number of GPUs.
    num_gpus = policy_config.get("num_gpus", 0) if \
        self.worker_index == 0 else \
        policy_config.get("num_gpus_per_worker", 0)
    # Error if we don't find enough GPUs.
    if ray.is_initialized() and \
            ray.worker._mode() != ray.worker.LOCAL_MODE and \
            not policy_config.get("_fake_gpus"):

        devices = []
        if policy_config.get("framework") in ["tf2", "tf", "tfe"]:
            devices = get_tf_gpu_devices()
        elif policy_config.get("framework") == "torch":
            devices = list(range(torch.cuda.device_count()))

        if len(devices) < num_gpus:
            raise RuntimeError(
                ERR_MSG_NO_GPUS.format(len(devices), devices) +
                HOWTO_CHANGE_CONFIG)
    # Warn, if running in local-mode and actual GPUs (not faked) are
    # requested.
    elif ray.is_initialized() and \
            ray.worker._mode() == ray.worker.LOCAL_MODE and \
            num_gpus > 0 and not policy_config.get("_fake_gpus"):
        logger.warning(
            "You are running ray with `local_mode=True`, but have "
            f"configured {num_gpus} GPUs to be used! In local mode, "
            f"Policies are placed on the CPU and the `num_gpus` setting "
            f"is ignored.")

    self._build_policy_map(
        policy_dict,
        policy_config,
        session_creator=tf_session_creator,
        seed=seed)

    # Update Policy's view requirements from Model, only if Policy directly
    # inherited from base `Policy` class. At this point here, the Policy
    # must have it's Model (if any) defined and ready to output an initial
    # state.
    for pol in self.policy_map.values():
        if not pol._model_init_state_automatically_added:
            pol._update_model_view_requirements_from_init_state()

    self.multiagent: bool = set(
        self.policy_map.keys()) != {DEFAULT_POLICY_ID}
    if self.multiagent and self.env is not None:
        if not isinstance(self.env,
                          (BaseEnv, ExternalMultiAgentEnv, MultiAgentEnv,
                           ray.actor.ActorHandle)):
            raise ValueError(
                f"Have multiple policies {self.policy_map}, but the "
                f"env {self.env} is not a subclass of BaseEnv, "
                f"MultiAgentEnv, ActorHandle, or ExternalMultiAgentEnv!")

    self.filters: Dict[PolicyID, Filter] = {
        policy_id: get_filter(self.observation_filter,
                              policy.observation_space.shape)
        for (policy_id, policy) in self.policy_map.items()
    }
    if self.worker_index == 0:
        logger.info("Built filter map: {}".format(self.filters))

    # Vectorize environment, if any.
    self.num_envs: int = num_envs
    # This RolloutWorker has no env.
    if self.env is None:
        self.async_env = None
    # Use a custom env-vectorizer and call it providing self.env.
    elif "custom_vector_env" in policy_config:
        self.async_env = policy_config["custom_vector_env"](self.env)
    # Default: Vectorize self.env via the make_sub_env function. This adds
    # further clones of self.env and creates a RLlib BaseEnv (which is
    # vectorized under the hood).
    else:
        # Always use vector env for consistency even if num_envs = 1.
        self.async_env: BaseEnv = BaseEnv.to_base_env(
            self.env,
            make_env=self.make_sub_env_fn,
            num_envs=num_envs,
            remote_envs=remote_worker_envs,
            remote_env_batch_wait_ms=remote_env_batch_wait_ms,
            policy_config=policy_config,
        )

    # `truncate_episodes`: Allow a batch to contain more than one episode
    # (fragments) and always make the batch `rollout_fragment_length`
    # long.
    if self.batch_mode == "truncate_episodes":
        pack = True
    # `complete_episodes`: Never cut episodes and sampler will return
    # exactly one (complete) episode per poll.
    elif self.batch_mode == "complete_episodes":
        rollout_fragment_length = float("inf")
        pack = False
    else:
        raise ValueError("Unsupported batch mode: {}".format(
            self.batch_mode))

    # Create the IOContext for this worker.
    self.io_context: IOContext = IOContext(log_dir, policy_config,
                                           worker_index, self)
    self.reward_estimators: List[OffPolicyEstimator] = []
    for method in input_evaluation:
        if method == "simulation":
            logger.warning(
                "Requested 'simulation' input evaluation method: "
                "will discard all sampler outputs and keep only metrics.")
            sample_async = True
        elif method == "is":
            ise = ImportanceSamplingEstimator.\
                create_from_io_context(self.io_context)
            self.reward_estimators.append(ise)
        elif method == "wis":
            wise = WeightedImportanceSamplingEstimator.\
                create_from_io_context(self.io_context)
            self.reward_estimators.append(wise)
        else:
            raise ValueError(
                "Unknown evaluation method: {}".format(method))

    render = False
    if policy_config.get("render_env") is True and \
            (num_workers == 0 or worker_index == 1):
        render = True

    if self.env is None:
        self.sampler = None
    elif sample_async:
        self.sampler = AsyncSampler(
            worker=self,
            env=self.async_env,
            clip_rewards=clip_rewards,
            rollout_fragment_length=rollout_fragment_length,
            count_steps_by=count_steps_by,
            callbacks=self.callbacks,
            horizon=episode_horizon,
            multiple_episodes_in_batch=pack,
            normalize_actions=normalize_actions,
            clip_actions=clip_actions,
            blackhole_outputs="simulation" in input_evaluation,
            soft_horizon=soft_horizon,
            no_done_at_end=no_done_at_end,
            observation_fn=observation_fn,
            sample_collector_class=policy_config.get("sample_collector"),
            render=render,
        )
        # Start the Sampler thread.
        self.sampler.start()
    else:
        self.sampler = SyncSampler(
            worker=self,
            env=self.async_env,
            clip_rewards=clip_rewards,
            rollout_fragment_length=rollout_fragment_length,
            count_steps_by=count_steps_by,
            callbacks=self.callbacks,
            horizon=episode_horizon,
            multiple_episodes_in_batch=pack,
            normalize_actions=normalize_actions,
            clip_actions=clip_actions,
            soft_horizon=soft_horizon,
            no_done_at_end=no_done_at_end,
            observation_fn=observation_fn,
            sample_collector_class=policy_config.get("sample_collector"),
            render=render,
        )

    self.input_reader: InputReader = input_creator(self.io_context)
    self.output_writer: OutputWriter = output_creator(self.io_context)

    logger.debug(
        "Created rollout worker with env {} ({}), policies {}".format(
            self.async_env, self.env, self.policy_map))

add_policy(self, *, policy_id, policy_cls, observation_space=None, action_space=None, config=None, policy_mapping_fn=None, policies_to_train=None)

Adds a new policy to this RolloutWorker.

Parameters:

Name Type Description Default
policy_id str

ID of the policy to add.

required
policy_cls Type[ray.rllib.policy.policy.Policy]

The Policy class to use for constructing the new Policy.

required
observation_space Optional[gym.spaces.space.Space]

The observation space of the policy to add.

None
action_space Optional[gym.spaces.space.Space]

The action space of the policy to add.

None
config Optional[dict]

The config overrides for the policy to add.

None
policy_mapping_fn Optional[Callable[[Any, Episode], str]]

An optional (updated) policy mapping function to use from here on. Note that already ongoing episodes will not change their mapping but will use the old mapping till the end of the episode.

None
policies_to_train Optional[List[str]]

An optional list of policy IDs to be trained. If None, will keep the existing list in place. Policies whose IDs are not in the list will not be updated.

None

Returns:

Type Description
Policy

The newly added policy.

Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def add_policy(
        self,
        *,
        policy_id: PolicyID,
        policy_cls: Type[Policy],
        observation_space: Optional[gym.spaces.Space] = None,
        action_space: Optional[gym.spaces.Space] = None,
        config: Optional[PartialTrainerConfigDict] = None,
        policy_mapping_fn: Optional[Callable[[AgentID, "Episode"],
                                             PolicyID]] = None,
        policies_to_train: Optional[List[PolicyID]] = None,
) -> Policy:
    """Adds a new policy to this RolloutWorker.

    Args:
        policy_id: ID of the policy to add.
        policy_cls: The Policy class to use for constructing the new
            Policy.
        observation_space: The observation space of the policy to add.
        action_space: The action space of the policy to add.
        config: The config overrides for the policy to add.
        policy_mapping_fn: An optional (updated) policy mapping function
            to use from here on. Note that already ongoing episodes will
            not change their mapping but will use the old mapping till
            the end of the episode.
        policies_to_train: An optional list of policy IDs to be trained.
            If None, will keep the existing list in place. Policies,
            whose IDs are not in the list will not be updated.

    Returns:
        The newly added policy.
    """
    if policy_id in self.policy_map:
        raise ValueError(f"Policy ID '{policy_id}' already in policy map!")
    policy_dict = _determine_spaces_for_multi_agent_dict(
        {
            policy_id: PolicySpec(policy_cls, observation_space,
                                  action_space, config or {})
        },
        self.env,
        spaces=self.spaces,
        policy_config=self.policy_config,
    )
    self._build_policy_map(
        policy_dict,
        self.policy_config,
        seed=self.policy_config.get("seed"))
    new_policy = self.policy_map[policy_id]

    self.filters[policy_id] = get_filter(
        self.observation_filter, new_policy.observation_space.shape)

    self.set_policy_mapping_fn(policy_mapping_fn)
    self.set_policies_to_train(policies_to_train)

    return new_policy

apply(self, func, *args)

Calls the given function with this rollout worker instance.

Parameters:

Name Type Description Default
func Callable[[RolloutWorker, Optional[Any]], ~T]

The function to call with this RolloutWorker as first argument.

required
args

Optional additional args to pass to the function call.

()

Returns:

Type Description
~T

The return value of the function call.
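
For illustration only (assuming a constructed `worker`), `apply()` runs an arbitrary callable against this worker:

>>> # Count the policies held by this worker.
>>> num_policies = worker.apply(lambda w: len(w.policy_map))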

Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def apply(self, func: Callable[["RolloutWorker", Optional[Any]], T],
          *args) -> T:
    """Calls the given function with this rollout worker instance.

    Args:
        func: The function to call with this RolloutWorker as first
            argument.
        args: Optional additional args to pass to the function call.

    Returns:
        The return value of the function call.
    """
    return func(self, *args)

apply_gradients(self, grads)

Applies the given gradients to this worker's models.

Uses the Policy's/ies' apply_gradients method(s) to perform the operations.

Parameters:

Name Type Description Default
grads Union[List[Tuple[Any, Any]], List[Any], Dict[str, Union[List[Tuple[Any, Any]], List[Any]]]]

Single ModelGradients (single-agent case) or a dict mapping PolicyIDs to the respective model gradients structs.

required

Examples:

>>> samples = worker.sample()
>>> grads, info = worker.compute_gradients(samples)
>>> worker.apply_gradients(grads)
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def apply_gradients(
        self,
        grads: Union[ModelGradients, Dict[PolicyID, ModelGradients]],
) -> None:
    """Applies the given gradients to this worker's models.

    Uses the Policy's/ies' apply_gradients method(s) to perform the
    operations.

    Args:
        grads: Single ModelGradients (single-agent case) or a dict
            mapping PolicyIDs to the respective model gradients
            structs.

    Examples:
        >>> samples = worker.sample()
        >>> grads, info = worker.compute_gradients(samples)
        >>> worker.apply_gradients(grads)
    """
    if log_once("apply_gradients"):
        logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))
    # Grads is a dict (mapping PolicyIDs to ModelGradients).
    # Multi-agent case.
    if isinstance(grads, dict):
        for pid, g in grads.items():
            if pid in self.policies_to_train:
                self.policy_map[pid].apply_gradients(g)
    # Grads is a ModelGradients type. Single-agent case.
    elif DEFAULT_POLICY_ID in self.policies_to_train:
        self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)

as_remote(num_cpus=None, num_gpus=None, memory=None, object_store_memory=None, resources=None) classmethod

Returns RolloutWorker class as a @ray.remote using given options.

The returned class can then be used to instantiate ray actors.

Parameters:

Name Type Description Default
num_cpus Optional[int]

The number of CPUs to allocate for the remote actor.

None
num_gpus Union[int, float]

The number of GPUs to allocate for the remote actor. This could be a fraction as well.

None
memory Optional[int]

The heap memory request for the remote actor.

None
object_store_memory Optional[int]

The object store memory for the remote actor.

None
resources Optional[dict]

The default custom resources to allocate for the remote actor.

None

Returns:

Type Description
type

The @ray.remote decorated RolloutWorker class.
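
An illustrative sketch only; `env_creator` and `policy_spec` stand in for a real worker configuration and are not defined here:

>>> RemoteWorker = RolloutWorker.as_remote(num_cpus=1, num_gpus=0)
>>> remote_worker = RemoteWorker.remote(
...     env_creator=env_creator, policy_spec=policy_spec)
>>> batch = ray.get(remote_worker.sample.remote())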

Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
@classmethod
def as_remote(cls,
              num_cpus: Optional[int] = None,
              num_gpus: Optional[Union[int, float]] = None,
              memory: Optional[int] = None,
              object_store_memory: Optional[int] = None,
              resources: Optional[dict] = None) -> type:
    """Returns RolloutWorker class as a `@ray.remote using given options`.

    The returned class can then be used to instantiate ray actors.

    Args:
        num_cpus: The number of CPUs to allocate for the remote actor.
        num_gpus: The number of GPUs to allocate for the remote actor.
            This could be a fraction as well.
        memory: The heap memory request for the remote actor.
        object_store_memory: The object store memory for the remote actor.
        resources: The default custom resources to allocate for the remote
            actor.

    Returns:
        The `@ray.remote` decorated RolloutWorker class.
    """
    return ray.remote(
        num_cpus=num_cpus,
        num_gpus=num_gpus,
        memory=memory,
        object_store_memory=object_store_memory,
        resources=resources)(cls)

compute_gradients(self, samples)

Returns a gradient computed w.r.t the specified samples.

Uses the Policy's/ies' compute_gradients method(s) to perform the calculations.

Parameters:

Name Type Description Default
samples Union[SampleBatch, MultiAgentBatch]

The SampleBatch or MultiAgentBatch to compute gradients for using this worker's policies.

required

Returns:

Type Description
Tuple[Union[List[Tuple[Any, Any]], List[Any]], dict]

In the single-agent case, a tuple consisting of ModelGradients and info dict of the worker's policy. In the multi-agent case, a tuple consisting of a dict mapping PolicyID to ModelGradients and a dict mapping PolicyID to extra metadata info. Note that the first return value (grads) can be applied as is to a compatible worker using the worker's apply_gradients() method.

Examples:

>>> batch = worker.sample()
>>> grads, info = worker.compute_gradients(batch)
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def compute_gradients(
        self, samples: SampleBatchType) -> Tuple[ModelGradients, dict]:
    """Returns a gradient computed w.r.t the specified samples.

    Uses the Policy's/ies' compute_gradients method(s) to perform the
    calculations.

    Args:
        samples: The SampleBatch or MultiAgentBatch to compute gradients
            for using this worker's policies.

    Returns:
        In the single-agent case, a tuple consisting of ModelGradients and
        info dict of the worker's policy.
        In the multi-agent case, a tuple consisting of a dict mapping
        PolicyID to ModelGradients and a dict mapping PolicyID to extra
        metadata info.
        Note that the first return value (grads) can be applied as is to a
        compatible worker using the worker's `apply_gradients()` method.

    Examples:
        >>> batch = worker.sample()
        >>> grads, info = worker.compute_gradients(batch)
    """
    if log_once("compute_gradients"):
        logger.info("Compute gradients on:\n\n{}\n".format(
            summarize(samples)))
    # MultiAgentBatch -> Calculate gradients for all policies.
    if isinstance(samples, MultiAgentBatch):
        grad_out, info_out = {}, {}
        if self.policy_config.get("framework") == "tf":
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                policy = self.policy_map[pid]
                builder = TFRunBuilder(policy.get_session(),
                                       "compute_gradients")
                grad_out[pid], info_out[pid] = (
                    policy._build_compute_gradients(builder, batch))
            grad_out = {k: builder.get(v) for k, v in grad_out.items()}
            info_out = {k: builder.get(v) for k, v in info_out.items()}
        else:
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                grad_out[pid], info_out[pid] = (
                    self.policy_map[pid].compute_gradients(batch))
    # SampleBatch -> Calculate gradients for the default policy.
    else:
        grad_out, info_out = (
            self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples))

    info_out["batch_count"] = samples.count
    if log_once("grad_out"):
        logger.info("Compute grad info:\n\n{}\n".format(
            summarize(info_out)))

    return grad_out, info_out

creation_args(self)

Returns the kwargs dict used to create this worker.

Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def creation_args(self) -> dict:
    """Returns the kwargs dict used to create this worker."""
    return self._original_kwargs

find_free_port(self)

Finds a free port on the node that this worker runs on.

Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def find_free_port(self) -> int:
    """Finds a free port on the node that this worker runs on."""
    from ray.util.sgd import utils
    return utils.find_free_port()

for_policy(self, func, policy_id='default_policy', **kwargs)

Calls the given function with the specified policy as first arg.

Parameters:

Name Type Description Default
func Callable[[ray.rllib.policy.policy.Policy, Optional[Any]], ~T]

The function to call with the policy as first arg.

required
policy_id Optional[str]

The PolicyID of the policy to call the function with.

'default_policy'

Returns:

Type Description
~T

The return value of the function call.
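
Not part of the original docs; a minimal sketch assuming a constructed `worker` with a default policy:

>>> # Read the default policy's weights through `for_policy`.
>>> weights = worker.for_policy(lambda p: p.get_weights())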

Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def for_policy(self,
               func: Callable[[Policy, Optional[Any]], T],
               policy_id: Optional[PolicyID] = DEFAULT_POLICY_ID,
               **kwargs) -> T:
    """Calls the given function with the specified policy as first arg.

    Args:
        func: The function to call with the policy as first arg.
        policy_id: The PolicyID of the policy to call the function with.

    Keyword Args:
        kwargs: Additional kwargs to be passed to the call.

    Returns:
        The return value of the function call.
    """

    return func(self.policy_map[policy_id], **kwargs)

foreach_env(self, func)

Calls the given function with each sub-environment as arg.

Parameters:

Name Type Description Default
func Callable[[Any], ~T]

The function to call for each underlying sub-environment (as only arg).

required

Returns:

Type Description
List[~T]

The list of return values of all calls to func([env]).
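
A minimal illustrative sketch (assuming a worker whose environment has been created and possibly vectorized):

>>> # Gather the observation space of each sub-environment.
>>> obs_spaces = worker.foreach_env(lambda env: env.observation_space)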

Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def foreach_env(self, func: Callable[[EnvType], T]) -> List[T]:
    """Calls the given function with each sub-environment as arg.

    Args:
        func: The function to call for each underlying
            sub-environment (as only arg).

    Returns:
         The list of return values of all calls to `func([env])`.
    """

    if self.async_env is None:
        return []

    envs = self.async_env.get_sub_environments()
    # Empty list (not implemented): Call function directly on the
    # BaseEnv.
    if not envs:
        return [func(self.async_env)]
    # Call function on all underlying (vectorized) sub environments.
    else:
        return [func(e) for e in envs]

foreach_env_with_context(self, func)

Calls given function with each sub-env plus env_ctx as args.

Parameters:

Name Type Description Default
func Callable[[Any, ray.rllib.env.env_context.EnvContext], ~T]

The function to call for each underlying sub-environment and its EnvContext (as the args).

required

Returns:

Type Description
List[~T]

The list of return values of all calls to func([env, ctx]).

Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def foreach_env_with_context(
        self, func: Callable[[EnvType, EnvContext], T]) -> List[T]:
    """Calls given function with each sub-env plus env_ctx as args.

    Args:
        func: The function to call for each underlying
            sub-environment and its EnvContext (as the args).

    Returns:
         The list of return values of all calls to `func([env, ctx])`.
    """

    if self.async_env is None:
        return []

    envs = self.async_env.get_sub_environments()
    # Empty list (not implemented): Call function directly on the
    # BaseEnv.
    if not envs:
        return [func(self.async_env, self.env_context)]
    # Call function on all underlying (vectorized) sub environments.
    else:
        ret = []
        for i, e in enumerate(envs):
            ctx = self.env_context.copy_with_overrides(vector_index=i)
            ret.append(func(e, ctx))
        return ret

foreach_policy(self, func, **kwargs)

Calls the given function with each (policy, policy_id) tuple.

Parameters:

Name Type Description Default
func Callable[[ray.rllib.policy.policy.Policy, str, Optional[Any]], ~T]

The function to call with each (policy, policy ID) tuple.

required

Returns:

Type Description
List[~T]

The list of return values of all calls to func([policy, pid, **kwargs]).
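
For illustration (assuming a constructed `worker`):

>>> # Collect all policy IDs held by this worker.
>>> policy_ids = worker.foreach_policy(lambda policy, pid: pid)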

Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def foreach_policy(self,
                   func: Callable[[Policy, PolicyID, Optional[Any]], T],
                   **kwargs) -> List[T]:
    """Calls the given function with each (policy, policy_id) tuple.

    Args:
        func: The function to call with each (policy, policy ID) tuple.

    Keyword Args:
        kwargs: Additional kwargs to be passed to the call.

    Returns:
         The list of return values of all calls to
            `func([policy, pid, **kwargs])`.
    """
    return [
        func(policy, pid, **kwargs)
        for pid, policy in self.policy_map.items()
    ]

foreach_trainable_policy(self, func, **kwargs)

Calls the given function with each (policy, policy_id) tuple.

Only those policies whose IDs can be found in self.policies_to_train will be called.

Parameters:

Name Type Description Default
func Callable[[ray.rllib.policy.policy.Policy, str, Optional[Any]], ~T]

The function to call with each (policy, policy ID) tuple, for only those policies that are in self.policies_to_train.

required

Returns:

Type Description
List[~T]

The list of return values of all calls to func([policy, pid, **kwargs]).

Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def foreach_trainable_policy(
        self, func: Callable[[Policy, PolicyID, Optional[Any]], T],
        **kwargs) -> List[T]:
    """
    Calls the given function with each (policy, policy_id) tuple.


    Only those policies/IDs will be called on, which can be found in
    `self.policies_to_train`.

    Args:
        func: The function to call with each (policy, policy ID) tuple,
            for only those policies that are in `self.policies_to_train`.

    Keyword Args:
        kwargs: Additional kwargs to be passed to the call.

    Returns:
        The list of return values of all calls to
        `func([policy, pid, **kwargs])`.
    """
    return [
        func(policy, pid, **kwargs)
        for pid, policy in self.policy_map.items()
        if pid in self.policies_to_train
    ]

get_filters(self, flush_after=False)

Returns a snapshot of filters.

Parameters:

Name Type Description Default
flush_after bool

Clears the filter buffer state.

False

Returns:

Type Description
Dict

Dict of serializable filters.

Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def get_filters(self, flush_after: bool = False) -> Dict:
    """Returns a snapshot of filters.

    Args:
        flush_after: Clears the filter buffer state.

    Returns:
        Dict of serializable filters
    """
    return_filters = {}
    for k, f in self.filters.items():
        return_filters[k] = f.as_serializable()
        if flush_after:
            f.clear_buffer()
    return return_filters

get_global_vars(self)

Returns the current global_vars dict of this worker.

Returns:

Type Description
dict

The current global_vars dict of this worker.

Examples:

>>> global_vars = worker.get_global_vars()
>>> print(global_vars)
{"timestep": 424242}
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def get_global_vars(self) -> dict:
    """Returns the current global_vars dict of this worker.

    Returns:
        The current global_vars dict of this worker.

    Examples:
        >>> global_vars = worker.get_global_vars()
        >>> print(global_vars)
        {"timestep": 424242}
    """
    return self.global_vars

get_host(self)

Returns the hostname of the process running this evaluator.

Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def get_host(self) -> str:
    """Returns the hostname of the process running this evaluator."""
    return platform.node()

get_metrics(self)

Returns the thus-far collected metrics from this worker's rollouts.

Returns:

Type Description
List[Union[ray.rllib.evaluation.metrics.RolloutMetrics, ray.rllib.offline.off_policy_estimator.OffPolicyEstimate]]

List of RolloutMetrics and/or OffPolicyEstimate objects collected thus-far.

Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def get_metrics(self) -> List[Union[RolloutMetrics, OffPolicyEstimate]]:
    """Returns the thus-far collected metrics from this worker's rollouts.

    Returns:
         List of RolloutMetrics and/or OffPolicyEstimate objects
         collected thus-far.
    """

    # Get metrics from sampler (if any).
    if self.sampler is not None:
        out = self.sampler.get_metrics()
    else:
        out = []
    # Get metrics from our reward-estimators (if any).
    for m in self.reward_estimators:
        out.extend(m.get_metrics())

    return out

get_node_ip(self)

Returns the IP address of the node that this worker runs on.

Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def get_node_ip(self) -> str:
    """Returns the IP address of the node that this worker runs on."""
    return ray.util.get_node_ip_address()

get_policy(self, policy_id='default_policy')

Return policy for the specified id, or None.

Parameters:

Name Type Description Default
policy_id str

ID of the policy to return. None for DEFAULT_POLICY_ID (in the single agent case).

'default_policy'

Returns:

Type Description
Optional[ray.rllib.policy.policy.Policy]

The policy under the given ID (or None if not found).
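
A short illustrative sketch; "policy_1" is a hypothetical policy ID:

>>> default_pol = worker.get_policy()          # default (single-agent) policy
>>> other_pol = worker.get_policy("policy_1")  # None if "policy_1" is unknown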

Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def get_policy(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> \
        Optional[Policy]:
    """Return policy for the specified id, or None.

    Args:
        policy_id: ID of the policy to return. None for DEFAULT_POLICY_ID
            (in the single agent case).

    Returns:
        The policy under the given ID (or None if not found).
    """
    return self.policy_map.get(policy_id)

get_weights(self, policies=None)

Returns each policy's model weights of this worker.

Parameters:

Name Type Description Default
policies Optional[List[str]]

List of PolicyIDs to get the weights from. Use None for all policies.

None

Returns:

Type Description
Dict[str, dict]

Dict mapping PolicyIDs to ModelWeights.

Examples:

>>> weights = worker.get_weights()
>>> print(weights)
{"default_policy": {"layer1": array(...), "layer2": ...}}
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def get_weights(
        self,
        policies: Optional[List[PolicyID]] = None,
) -> Dict[PolicyID, ModelWeights]:
    """Returns each policies' model weights of this worker.

    Args:
        policies: List of PolicyIDs to get the weights from.
            Use None for all policies.

    Returns:
        Dict mapping PolicyIDs to ModelWeights.

    Examples:
        >>> weights = worker.get_weights()
        >>> print(weights)
        {"default_policy": {"layer1": array(...), "layer2": ...}}
    """
    if policies is None:
        policies = list(self.policy_map.keys())
    policies = force_list(policies)

    return {
        pid: policy.get_weights()
        for pid, policy in self.policy_map.items() if pid in policies
    }

learn_on_batch(self, samples)

Update policies based on the given batch.

This is equivalent to apply_gradients(compute_gradients(samples)), but can be optimized to avoid pulling gradients into CPU memory.

Parameters:

Name Type Description Default
samples Union[SampleBatch, MultiAgentBatch]

The SampleBatch or MultiAgentBatch to learn on.

required

Returns:

Type Description
Dict

Dictionary of extra metadata from compute_gradients().

Examples:

>>> batch = worker.sample()
>>> info = worker.learn_on_batch(batch)
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def learn_on_batch(self, samples: SampleBatchType) -> Dict:
    """Update policies based on the given batch.

    This is equivalent to apply_gradients(compute_gradients(samples)),
    but can be optimized to avoid pulling gradients into CPU memory.

    Args:
        samples: The SampleBatch or MultiAgentBatch to learn on.

    Returns:
        Dictionary of extra metadata from compute_gradients().

    Examples:
        >>> batch = worker.sample()
        >>> info = worker.learn_on_batch(batch)
    """
    if log_once("learn_on_batch"):
        logger.info(
            "Training on concatenated sample batches:\n\n{}\n".format(
                summarize(samples)))
    if isinstance(samples, MultiAgentBatch):
        info_out = {}
        builders = {}
        to_fetch = {}
        for pid, batch in samples.policy_batches.items():
            if pid not in self.policies_to_train:
                continue
            # Decompress SampleBatch, in case some columns are compressed.
            batch.decompress_if_needed()
            policy = self.policy_map[pid]
            tf_session = policy.get_session()
            if tf_session and hasattr(policy, "_build_learn_on_batch"):
                builders[pid] = TFRunBuilder(tf_session, "learn_on_batch")
                to_fetch[pid] = policy._build_learn_on_batch(
                    builders[pid], batch)
            else:
                info_out[pid] = policy.learn_on_batch(batch)
        info_out.update(
            {pid: builders[pid].get(v)
             for pid, v in to_fetch.items()})
    else:
        info_out = {
            DEFAULT_POLICY_ID: self.policy_map[DEFAULT_POLICY_ID]
            .learn_on_batch(samples)
        }
    if log_once("learn_out"):
        logger.debug("Training out:\n\n{}\n".format(summarize(info_out)))
    return info_out

remove_policy(self, *, policy_id='default_policy', policy_mapping_fn=None, policies_to_train=None)

Removes a policy from this RolloutWorker.

Parameters:

Name Type Description Default
policy_id str

ID of the policy to be removed. None for DEFAULT_POLICY_ID.

'default_policy'
policy_mapping_fn Optional[Callable[[Any], str]]

An optional (updated) policy mapping function to use from here on. Note that already ongoing episodes will not change their mapping but will use the old mapping till the end of the episode.

None
policies_to_train Optional[List[str]]

An optional list of policy IDs to be trained. If None, will keep the existing list in place. Policies whose IDs are not in the list will not be updated.

None
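
Not from the original docs; an illustrative sketch with hypothetical policy IDs:

>>> # Drop a previously added policy and route its agents to the default one.
>>> worker.remove_policy(
...     policy_id="obsolete_policy",
...     policy_mapping_fn=lambda agent_id: "default_policy",
... )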
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def remove_policy(
        self,
        *,
        policy_id: PolicyID = DEFAULT_POLICY_ID,
        policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
        policies_to_train: Optional[List[PolicyID]] = None,
) -> None:
    """Removes a policy from this RolloutWorker.

    Args:
        policy_id: ID of the policy to be removed. None for
            DEFAULT_POLICY_ID.
        policy_mapping_fn: An optional (updated) policy mapping function
            to use from here on. Note that already ongoing episodes will
            not change their mapping but will use the old mapping till
            the end of the episode.
        policies_to_train: An optional list of policy IDs to be trained.
            If None, will keep the existing list in place. Policies
            whose IDs are not in the list will not be updated.
    """
    if policy_id not in self.policy_map:
        raise ValueError(f"Policy ID '{policy_id}' not in policy map!")
    del self.policy_map[policy_id]
    del self.preprocessors[policy_id]
    self.set_policy_mapping_fn(policy_mapping_fn)
    self.set_policies_to_train(policies_to_train)

restore(self, objs)

Restores this RolloutWorker's state from a sequence of bytes.

Parameters:

Name Type Description Default
objs bytes

The byte sequence to restore this worker's state from.

required

Examples:

>>> state = worker.save()
>>> new_worker = RolloutWorker(...)
>>> new_worker.restore(state)
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def restore(self, objs: bytes) -> None:
    """Restores this RolloutWorker's state from a sequence of bytes.

    Args:
        objs: The byte sequence to restore this worker's state from.

    Examples:
        >>> state = worker.save()
        >>> new_worker = RolloutWorker(...)
        >>> new_worker.restore(state)
    """
    objs = pickle.loads(objs)
    self.sync_filters(objs["filters"])
    for pid, state in objs["state"].items():
        if pid not in self.policy_map:
            pol_spec = objs.get("policy_specs", {}).get(pid)
            if not pol_spec:
                logger.warning(
                    f"PolicyID '{pid}' was probably added on-the-fly (not"
                    " part of the static `multagent.policies` config) and"
                    " no PolicySpec objects found in the pickled policy "
                    "state. Will not add `{pid}`, but ignore it for now.")
            else:
                self.add_policy(
                    policy_id=pid,
                    policy_cls=pol_spec.policy_class,
                    observation_space=pol_spec.observation_space,
                    action_space=pol_spec.action_space,
                    config=pol_spec.config,
                )
        else:
            self.policy_map[pid].set_state(state)

sample(self)

Returns a batch of experience sampled from this worker.

This method must be implemented by subclasses.

Returns:

Type Description
Union[SampleBatch, MultiAgentBatch]

A columnar batch of experiences (e.g., tensors).

Examples:

>>> print(worker.sample())
SampleBatch({"obs": [1, 2, 3], "action": [0, 1, 0], ...})
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def sample(self) -> SampleBatchType:
    """Returns a batch of experience sampled from this worker.

    This method must be implemented by subclasses.

    Returns:
        A columnar batch of experiences (e.g., tensors).

    Examples:
        >>> print(worker.sample())
        SampleBatch({"obs": [1, 2, 3], "action": [0, 1, 0], ...})
    """

    if self.fake_sampler and self.last_batch is not None:
        return self.last_batch
    elif self.input_reader is None:
        raise ValueError("RolloutWorker has no `input_reader` object! "
                         "Cannot call `sample()`. You can try setting "
                         "`create_env_on_driver` to True.")

    if log_once("sample_start"):
        logger.info("Generating sample batch of size {}".format(
            self.rollout_fragment_length))

    batches = [self.input_reader.next()]
    steps_so_far = batches[0].count if \
        self.count_steps_by == "env_steps" else \
        batches[0].agent_steps()

    # In truncate_episodes mode, never pull more than 1 batch per env.
    # This avoids over-running the target batch size.
    if self.batch_mode == "truncate_episodes":
        max_batches = self.num_envs
    else:
        max_batches = float("inf")

    while (steps_so_far < self.rollout_fragment_length
           and len(batches) < max_batches):
        batch = self.input_reader.next()
        steps_so_far += batch.count if \
            self.count_steps_by == "env_steps" else \
            batch.agent_steps()
        batches.append(batch)
    batch = batches[0].concat_samples(batches) if len(batches) > 1 else \
        batches[0]

    self.callbacks.on_sample_end(worker=self, samples=batch)

    # Always do writes prior to compression for consistency and to allow
    # for better compression inside the writer.
    self.output_writer.write(batch)

    # Do off-policy estimation, if needed.
    if self.reward_estimators:
        for sub_batch in batch.split_by_episode():
            for estimator in self.reward_estimators:
                estimator.process(sub_batch)

    if log_once("sample_end"):
        logger.info("Completed sample batch:\n\n{}\n".format(
            summarize(batch)))

    if self.compress_observations:
        batch.compress(bulk=self.compress_observations == "bulk")

    if self.fake_sampler:
        self.last_batch = batch
    return batch

sample_and_learn(self, expected_batch_size, num_sgd_iter, sgd_minibatch_size, standardize_fields)

Sample a batch and learn on it.

This is typically used in combination with distributed allreduce.

Parameters:

Name Type Description Default
expected_batch_size int

Expected number of samples to learn on.

required
num_sgd_iter int

Number of SGD iterations.

required
sgd_minibatch_size int

SGD minibatch size.

required
standardize_fields List[str]

List of sample fields to normalize.

required

Returns:

Type Description
Tuple[dict, int]

A tuple consisting of a dictionary of extra metadata returned from the policies' learn_on_batch() and the number of samples learned on.
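
An illustrative call only; it assumes the worker's rollout settings actually yield 4000 env steps per sample() (the method asserts that the sampled count matches expected_batch_size):

>>> info, num_steps = worker.sample_and_learn(
...     expected_batch_size=4000,
...     num_sgd_iter=10,
...     sgd_minibatch_size=128,
...     standardize_fields=["advantages"],
... )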

Source code in ray/rllib/evaluation/rollout_worker.py
def sample_and_learn(self, expected_batch_size: int, num_sgd_iter: int,
                     sgd_minibatch_size: int,
                     standardize_fields: List[str]) -> Tuple[dict, int]:
    """Sample and batch and learn on it.

    This is typically used in combination with distributed allreduce.

    Args:
        expected_batch_size: Expected number of samples to learn on.
        num_sgd_iter: Number of SGD iterations.
        sgd_minibatch_size: SGD minibatch size.
        standardize_fields: List of sample fields to normalize.

    Returns:
        A tuple consisting of a dictionary of extra metadata returned from
            the policies' `learn_on_batch()` and the number of samples
            learned on.
    """
    batch = self.sample()
    assert batch.count == expected_batch_size, \
        ("Batch size possibly out of sync between workers, expected:",
         expected_batch_size, "got:", batch.count)
    logger.info("Executing distributed minibatch SGD "
                "with epoch size {}, minibatch size {}".format(
                    batch.count, sgd_minibatch_size))
    info = do_minibatch_sgd(batch, self.policy_map, self, num_sgd_iter,
                            sgd_minibatch_size, standardize_fields)
    return info, batch.count

sample_with_count(self)

Same as sample() but returns the count as a separate value.

Returns:

Type Description
Tuple[Union[SampleBatch, MultiAgentBatch], int]

A columnar batch of experiences (e.g., tensors) and the size of the collected batch.

Examples:

>>> print(worker.sample_with_count())
(SampleBatch({"obs": [1, 2, 3], "action": [0, 1, 0], ...}), 3)
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
@ray.method(num_returns=2)
def sample_with_count(self) -> Tuple[SampleBatchType, int]:
    """Same as sample() but returns the count as a separate value.

    Returns:
        A columnar batch of experiences (e.g., tensors) and the
            size of the collected batch.

    Examples:
        >>> print(worker.sample_with_count())
        (SampleBatch({"obs": [1, 2, 3], "action": [0, 1, 0], ...}), 3)
    """
    batch = self.sample()
    return batch, batch.count

save(self)

Serializes this RolloutWorker's current state and returns it.

Returns:

Type Description
bytes

The current state of this RolloutWorker as a serialized, pickled byte sequence.

Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def save(self) -> bytes:
    """Serializes this RolloutWorker's current state and returns it.

    Returns:
        The current state of this RolloutWorker as a serialized, pickled
        byte sequence.
    """
    filters = self.get_filters(flush_after=True)
    state = {}
    policy_specs = {}
    for pid in self.policy_map:
        state[pid] = self.policy_map[pid].get_state()
        policy_specs[pid] = self.policy_map.policy_specs[pid]
    return pickle.dumps({
        "filters": filters,
        "state": state,
        "policy_specs": policy_specs,
    })

set_global_vars(self, global_vars)

Updates this worker's and all its policies' global vars.

Parameters:

Name Type Description Default
global_vars dict

The new global_vars dict.

required

Examples:

>>> global_vars = worker.set_global_vars({"timestep": 4242})
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def set_global_vars(self, global_vars: dict) -> None:
    """Updates this worker's and all its policies' global vars.

    Args:
        global_vars: The new global_vars dict.

    Examples:
        >>> global_vars = worker.set_global_vars({"timestep": 4242})
    """
    self.foreach_policy(lambda p, _: p.on_global_var_update(global_vars))
    self.global_vars = global_vars

set_policies_to_train(self, policies_to_train=None)

Sets self.policies_to_train to a new list of PolicyIDs.

Parameters:

Name Type Description Default
policies_to_train Optional[List[str]]

The new list of policy IDs to train with. If None, will keep the existing list in place.

None
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def set_policies_to_train(
        self, policies_to_train: Optional[List[PolicyID]] = None) -> None:
    """Sets `self.policies_to_train` to a new list of PolicyIDs.

    Args:
        policies_to_train: The new list of policy IDs to train with.
            If None, will keep the existing list in place.
    """
    if policies_to_train is not None:
        self.policies_to_train = policies_to_train

set_policy_mapping_fn(self, policy_mapping_fn=None)

Sets self.policy_mapping_fn to a new callable (if provided).

Parameters:

Name Type Description Default
policy_mapping_fn Optional[Callable[[Any, Episode], str]]

The new mapping function to use. If None, will keep the existing mapping function in place.

None
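
For illustration; "policy_1" is a hypothetical policy ID:

>>> # From now on, map every agent to "policy_1" (ongoing episodes keep
>>> # their old mapping until they end).
>>> worker.set_policy_mapping_fn(lambda agent_id, episode, **kwargs: "policy_1")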
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def set_policy_mapping_fn(
        self,
        policy_mapping_fn: Optional[Callable[[AgentID, "Episode"],
                                             PolicyID]] = None,
) -> None:
    """Sets `self.policy_mapping_fn` to a new callable (if provided).

    Args:
        policy_mapping_fn: The new mapping function to use. If None,
            will keep the existing mapping function in place.
    """
    if policy_mapping_fn is not None:
        self.policy_mapping_fn = policy_mapping_fn
        if not callable(self.policy_mapping_fn):
            raise ValueError("`policy_mapping_fn` must be a callable!")

set_weights(self, weights, global_vars=None)

Sets each policy's model weights of this worker.

Parameters:

Name Type Description Default
weights Dict[str, dict]

Dict mapping PolicyIDs to the new weights to be used.

required
global_vars Optional[Dict]

An optional global vars dict to set this worker to. If None, do not update the global_vars.

None

Examples:

>>> weights = worker.get_weights()
>>> # Set `global_vars` (timestep) as well.
>>> worker.set_weights(weights, {"timestep": 42})
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def set_weights(self,
                weights: Dict[PolicyID, ModelWeights],
                global_vars: Optional[Dict] = None) -> None:
    """Sets each policies' model weights of this worker.

    Args:
        weights: Dict mapping PolicyIDs to the new weights to be used.
        global_vars: An optional global vars dict to set this
            worker to. If None, do not update the global_vars.

    Examples:
        >>> weights = worker.get_weights()
        >>> # Set `global_vars` (timestep) as well.
        >>> worker.set_weights(weights, {"timestep": 42})
    """
    for pid, w in weights.items():
        self.policy_map[pid].set_weights(w)
    if global_vars:
        self.set_global_vars(global_vars)

setup_torch_data_parallel(self, url, world_rank, world_size, backend)

Join a torch process group for distributed SGD.

Source code in ray/rllib/evaluation/rollout_worker.py
def setup_torch_data_parallel(self, url: str, world_rank: int,
                              world_size: int, backend: str) -> None:
    """Join a torch process group for distributed SGD."""

    logger.info("Joining process group, url={}, world_rank={}, "
                "world_size={}, backend={}".format(url, world_rank,
                                                   world_size, backend))
    torch.distributed.init_process_group(
        backend=backend,
        init_method=url,
        rank=world_rank,
        world_size=world_size)

    for pid, policy in self.policy_map.items():
        if not isinstance(policy, TorchPolicy):
            raise ValueError(
                "This policy does not support torch distributed", policy)
        policy.distributed_world_size = world_size

stop(self)

Releases all resources used by this RolloutWorker.

Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def stop(self) -> None:
    """Releases all resources used by this RolloutWorker."""

    # If we have an env -> Release its resources.
    if self.env is not None:
        self.async_env.stop()
    # Close all policies' sessions (if tf static graph).
    for policy in self.policy_map.values():
        sess = policy.get_session()
        # Closes the tf session, if any.
        if sess is not None:
            sess.close()

sync_filters(self, new_filters)

Changes self's filters to the given ones, rebasing any accumulated delta.

Parameters:

Name Type Description Default
new_filters dict

Filters with new state to update local copy.

required
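
Not part of the original docs; a sketch of the usual gather-and-broadcast pattern, with `worker_a` and `worker_b` as hypothetical workers:

>>> # Pull a serializable filter snapshot from one worker ...
>>> snapshot = worker_a.get_filters(flush_after=True)
>>> # ... and rebase another worker's filters onto it.
>>> worker_b.sync_filters(snapshot)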
Source code in ray/rllib/evaluation/rollout_worker.py
@DeveloperAPI
def sync_filters(self, new_filters: dict) -> None:
    """Changes self's filter to given and rebases any accumulated delta.

    Args:
        new_filters: Filters with new state to update local copy.
    """
    assert all(k in new_filters for k in self.filters)
    for k in self.filters:
        self.filters[k].sync(new_filters[k])

Sample Batches

ray.rllib.policy.sample_batch.SampleBatch (dict)

Wrapper around a dictionary with string keys and array-like values.

For example, {"obs": [1, 2, 3], "reward": [0, -1, 1]} is a batch of three samples, each with an "obs" and "reward" attribute.

__init__(self, *args, **kwargs) special

Constructs a sample batch (same params as dict constructor).

Note: All *args and those **kwargs not listed below will be passed as-is to the parent dict constructor.
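
A minimal construction sketch (the keys are arbitrary column names chosen for illustration, not required ones):

>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({"obs": [1, 2, 3], "action": [0, 1, 0]})
>>> print(batch.count)
3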

Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def __init__(self, *args, **kwargs):
    """Constructs a sample batch (same params as dict constructor).

    Note: All *args and those **kwargs not listed below will be passed
    as-is to the parent dict constructor.

    Keyword Args:
        _time_major (Optional[bool]): Whether data in this sample batch
            is time-major. This is False by default and only relevant
            if the data contains sequences.
        _max_seq_len (Optional[int]): The max sequence chunk length
            if the data contains sequences.
        _zero_padded (Optional[bool]): Whether the data in this batch
            contains sequences AND these sequences are right-zero-padded
            according to the `_max_seq_len` setting.
        _is_training (Optional[bool]): Whether this batch is used for
            training. If False, batch may be used for e.g. action
            computations (inference).
    """

    # Possible seq_lens (TxB or BxT) setup.
    self.time_major = kwargs.pop("_time_major", None)
    # Maximum seq len value.
    self.max_seq_len = kwargs.pop("_max_seq_len", None)
    # Is already right-zero-padded?
    self.zero_padded = kwargs.pop("_zero_padded", False)
    # Whether this batch is used for training (vs inference).
    self._is_training = kwargs.pop("_is_training", None)

    # Call super constructor. This will make the actual data accessible
    # by column name (str) via e.g. self["some-col"].
    dict.__init__(self, *args, **kwargs)

    self.accessed_keys = set()
    self.added_keys = set()
    self.deleted_keys = set()
    self.intercepted_values = {}
    self.get_interceptor = None

    # Clear out None seq-lens.
    seq_lens_ = self.get(SampleBatch.SEQ_LENS)
    if seq_lens_ is None or \
            (isinstance(seq_lens_, list) and len(seq_lens_) == 0):
        self.pop(SampleBatch.SEQ_LENS, None)
    # Numpyfy seq_lens if list.
    elif isinstance(seq_lens_, list):
        self[SampleBatch.SEQ_LENS] = seq_lens_ = \
            np.array(seq_lens_, dtype=np.int32)

    if self.max_seq_len is None and seq_lens_ is not None and \
            not (tf and tf.is_tensor(seq_lens_)) and \
            len(seq_lens_) > 0:
        self.max_seq_len = max(seq_lens_)

    if self._is_training is None:
        self._is_training = self.pop("is_training", False)

    lengths = []
    copy_ = {k: v for k, v in self.items() if k != SampleBatch.SEQ_LENS}
    for k, v in copy_.items():
        assert isinstance(k, str), self

        # TODO: Drop support for lists as values.
        # Convert lists of int|float into numpy arrays and make sure all
        # data has the same length.
        if isinstance(v, list):
            self[k] = np.array(v)

        # Try to infer the "length" of the SampleBatch by finding the first
        # value that is actually a ndarray/tensor. This would fail if
        # all values are nested dicts/tuples of more complex underlying
        # structures.
        len_ = len(v) if isinstance(
            v,
            (list, np.ndarray)) or (torch and torch.is_tensor(v)) else None
        if len_:
            lengths.append(len_)

    if self.get(SampleBatch.SEQ_LENS) is not None and \
            not (tf and tf.is_tensor(self[SampleBatch.SEQ_LENS])) and \
            len(self[SampleBatch.SEQ_LENS]) > 0:
        self.count = sum(self[SampleBatch.SEQ_LENS])
    else:
        self.count = lengths[0] if lengths else 0

    # A convenience map for slicing this batch into sub-batches along
    # the time axis. This helps reduce repeated iterations through the
    # batch's seq_lens array to find good slicing points. Built lazily
    # when needed.
    self._slice_map = []

columns(self, keys)

Returns a list of the batch-data in the specified columns.

Parameters:

Name Type Description Default
keys List[str]

List of column names for which to return the data.

required

Returns:

Type Description
List[any]

The list of data items ordered by the order of column names in keys.

Examples:

>>> batch = SampleBatch({"a": [1], "b": [2], "c": [3]})
>>> print(batch.columns(["a", "b"]))
[[1], [2]]
Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def columns(self, keys: List[str]) -> List[any]:
    """Returns a list of the batch-data in the specified columns.

    Args:
        keys (List[str]): List of column names for which to return the data.

    Returns:
        List[any]: The list of data items ordered by the order of column
            names in `keys`.

    Examples:
        >>> batch = SampleBatch({"a": [1], "b": [2], "c": [3]})
        >>> print(batch.columns(["a", "b"]))
        [[1], [2]]
    """

    # TODO: (sven) Make this work for nested data as well.
    out = []
    for k in keys:
        out.append(self[k])
    return out

compress(self, bulk=False, columns=frozenset({'new_obs', 'obs'}))

Compresses the data buffers (by column) in place.

Parameters:

Name Type Description Default
bulk bool

Whether to compress across the batch dimension (0) as well. If False, will compress n separate list items, where n is the batch size.

False
columns Set[str]

The columns to compress. Default: Only compress the obs and new_obs columns.

frozenset({'new_obs', 'obs'})

Returns:

Type Description
SampleBatch

This very (now compressed) SampleBatch.
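
An illustrative round trip (not from the original docs); it assumes numpy is imported as np and relies on RLlib's default pack/unpack compression utilities under the hood:

>>> batch = SampleBatch({"obs": np.random.rand(5, 84, 84, 3)})
>>> batch.compress(bulk=True)       # pack the whole "obs" column
>>> batch.decompress_if_needed()    # restore the original arrays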

Source code in ray/rllib/policy/sample_batch.py
@DeveloperAPI
def compress(self,
             bulk: bool = False,
             columns: Set[str] = frozenset(["obs", "new_obs"])) -> None:
    """Compresses the data buffers (by column) in place.

    Args:
        bulk (bool): Whether to compress across the batch dimension (0)
            as well. If False will compress n separate list items, where n
            is the batch size.
        columns (Set[str]): The columns to compress. Default: Only
            compress the obs and new_obs columns.

    Returns:
        SampleBatch: This very (now compressed) SampleBatch.
    """

    def _compress_in_place(path, value):
        if path[0] not in columns:
            return
        curr = self
        for i, p in enumerate(path):
            if i == len(path) - 1:
                if bulk:
                    curr[p] = pack(value)
                else:
                    curr[p] = np.array([pack(o) for o in value])
            curr = curr[p]

    tree.map_structure_with_path(_compress_in_place, self)

    return self

concat(self, other)

Concatenates other to this one and returns a new SampleBatch.

Parameters:

Name Type Description Default
other SampleBatch

The other SampleBatch object to concat to this one.

required

Returns:

Type Description
SampleBatch

The new SampleBatch, resulting from concating other to self.

Examples:

>>> b1 = SampleBatch({"a": np.array([1, 2])})
>>> b2 = SampleBatch({"a": np.array([3, 4, 5])})
>>> print(b1.concat(b2))
{"a": np.array([1, 2, 3, 4, 5])}
Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def concat(self, other: "SampleBatch") -> "SampleBatch":
    """Concatenates `other` to this one and returns a new SampleBatch.

    Args:
        other (SampleBatch): The other SampleBatch object to concat to this
            one.

    Returns:
        SampleBatch: The new SampleBatch, resulting from concating `other`
            to `self`.

    Examples:
        >>> b1 = SampleBatch({"a": np.array([1, 2])})
        >>> b2 = SampleBatch({"a": np.array([3, 4, 5])})
        >>> print(b1.concat(b2))
        {"a": np.array([1, 2, 3, 4, 5])}
    """
    return self.concat_samples([self, other])

concat_samples(samples) staticmethod

Concatenates n SampleBatches or MultiAgentBatches.

Parameters:

Name Type Description Default
samples Union[List[SampleBatch], List[MultiAgentBatch]]

List of SampleBatches or MultiAgentBatches to be concatenated.

required

Returns:

Type Description
Union[SampleBatch, MultiAgentBatch]

A new (concatenated) SampleBatch or MultiAgentBatch.

Examples:

>>> b1 = SampleBatch({"a": np.array([1, 2]),
...                   "b": np.array([10, 11])})
>>> b2 = SampleBatch({"a": np.array([3]),
...                   "b": np.array([12])})
>>> print(SampleBatch.concat_samples([b1, b2]))
{"a": np.array([1, 2, 3]), "b": np.array([10, 11, 12])}
Source code in ray/rllib/policy/sample_batch.py
@staticmethod
@PublicAPI
def concat_samples(
        samples: Union[List["SampleBatch"], List["MultiAgentBatch"]],
) -> Union["SampleBatch", "MultiAgentBatch"]:
    """Concatenates n SampleBatches or MultiAgentBatches.

    Args:
        samples (Union[List[SampleBatch], List[MultiAgentBatch]]): List of
            SampleBatches or MultiAgentBatches to be concatenated.

    Returns:
        Union[SampleBatch, MultiAgentBatch]: A new (concatenated)
            SampleBatch or MultiAgentBatch.

    Examples:
        >>> b1 = SampleBatch({"a": np.array([1, 2]),
        ...                   "b": np.array([10, 11])})
        >>> b2 = SampleBatch({"a": np.array([3]),
        ...                   "b": np.array([12])})
        >>> print(SampleBatch.concat_samples([b1, b2]))
        {"a": np.array([1, 2, 3]), "b": np.array([10, 11, 12])}
    """
    if any(isinstance(s, MultiAgentBatch) for s in samples):
        return MultiAgentBatch.concat_samples(samples)
    concatd_seq_lens = []
    concat_samples = []
    zero_padded = samples[0].zero_padded
    max_seq_len = samples[0].max_seq_len
    time_major = samples[0].time_major
    for s in samples:
        if s.count > 0:
            assert s.zero_padded == zero_padded
            assert s.time_major == time_major
            if zero_padded:
                assert s.max_seq_len == max_seq_len
            concat_samples.append(s)
            if s.get(SampleBatch.SEQ_LENS) is not None:
                concatd_seq_lens.extend(s[SampleBatch.SEQ_LENS])

    # If we don't have any samples (0 or only empty SampleBatches),
    # return an empty SampleBatch here.
    if len(concat_samples) == 0:
        return SampleBatch()

    # Collect the concat'd data.
    concatd_data = {}

    def concat_key(*values):
        return concat_aligned(values, time_major)

    try:
        for k in concat_samples[0].keys():
            if k == "infos":
                concatd_data[k] = concat_aligned(
                    [s[k] for s in concat_samples], time_major=time_major)
            else:
                concatd_data[k] = tree.map_structure(
                    concat_key, *[c[k] for c in concat_samples])
    except Exception:
        raise ValueError(f"Cannot concat data under key '{k}', b/c "
                         "sub-structures under that key don't match. "
                         f"`samples`={samples}")

    # Return a new (concat'd) SampleBatch.
    return SampleBatch(
        concatd_data,
        seq_lens=concatd_seq_lens,
        _time_major=time_major,
        _zero_padded=zero_padded,
        _max_seq_len=max_seq_len,
    )

copy(self, shallow=False)

Creates a deep or shallow copy of this SampleBatch and returns it.

Parameters:

Name Type Description Default
shallow bool

Whether the copying should be done shallowly.

False

Returns:

Type Description
SampleBatch

A deep or shallow copy of this SampleBatch object.

Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def copy(self, shallow: bool = False) -> "SampleBatch":
    """Creates a deep or shallow copy of this SampleBatch and returns it.

    Args:
        shallow (bool): Whether the copying should be done shallowly.

    Returns:
        SampleBatch: A deep or shallow copy of this SampleBatch object.
    """
    copy_ = {k: v for k, v in self.items()}
    data = tree.map_structure(
        lambda v: (np.array(v, copy=not shallow) if
                   isinstance(v, np.ndarray) else v),
        copy_,
    )
    copy_ = SampleBatch(data)
    copy_.set_get_interceptor(self.get_interceptor)
    copy_.added_keys = self.added_keys
    copy_.deleted_keys = self.deleted_keys
    copy_.accessed_keys = self.accessed_keys
    return copy_

decompress_if_needed(self, columns=frozenset({'new_obs', 'obs'}))

Decompresses data buffers (per column if not compressed) in place.

Parameters:

Name Type Description Default
columns Set[str]

The columns to decompress. Default: Only decompress the obs and new_obs columns.

frozenset({'new_obs', 'obs'})

Returns:

Type Description
SampleBatch

This very (now uncompressed) SampleBatch.

Source code in ray/rllib/policy/sample_batch.py
@DeveloperAPI
def decompress_if_needed(self,
                         columns: Set[str] = frozenset(
                             ["obs", "new_obs"])) -> "SampleBatch":
    """Decompresses data buffers (per column if not compressed) in place.

    Args:
        columns (Set[str]): The columns to decompress. Default: Only
            decompress the obs and new_obs columns.

    Returns:
        SampleBatch: This very (now uncompressed) SampleBatch.
    """

    def _decompress_in_place(path, value):
        if path[0] not in columns:
            return
        curr = self
        for p in path[:-1]:
            curr = curr[p]
        # Bulk compressed.
        if is_compressed(value):
            curr[path[-1]] = unpack(value)
        # Non bulk compressed.
        elif len(value) > 0 and is_compressed(value[0]):
            curr[path[-1]] = np.array([unpack(o) for o in value])

    tree.map_structure_with_path(_decompress_in_place, self)

    return self

get(self, key, default=None)

Return the value for key if key is in the dictionary, else default.

Source code in ray/rllib/policy/sample_batch.py
def get(self, key, default=None):
    try:
        return self.__getitem__(key)
    except KeyError:
        return default

get_single_step_input_dict(self, view_requirements, index='last')

Creates a single-timestep SampleBatch at the given index from self.

For usage as input-dict for model (action or value function) calls.

Parameters:

Name Type Description Default
view_requirements Dict[str, ViewRequirement]

A view requirements dict from the model for which to produce the input_dict.

required
index Union[str, int]

An integer index value indicating the position in the trajectory for which to generate the compute_actions input dict. Set to "last" to generate the dict at the very end of the trajectory (e.g. for value estimation). Note that "last" is different from -1, as "last" will use the final NEXT_OBS as observation input.

'last'

Returns:

Type Description
SampleBatch

The (single-timestep) input dict for ModelV2 calls.

Source code in ray/rllib/policy/sample_batch.py
@ExperimentalAPI
def get_single_step_input_dict(
        self,
        view_requirements: ViewRequirementsDict,
        index: Union[str, int] = "last",
) -> "SampleBatch":
    """Creates single ts SampleBatch at given index from `self`.

    For usage as input-dict for model (action or value function) calls.

    Args:
        view_requirements: A view requirements dict from the model for
            which to produce the input_dict.
        index: An integer index value indicating the
            position in the trajectory for which to generate the
            compute_actions input dict. Set to "last" to generate the dict
            at the very end of the trajectory (e.g. for value estimation).
            Note that "last" is different from -1, as "last" will use the
            final NEXT_OBS as observation input.

    Returns:
        The (single-timestep) input dict for ModelV2 calls.
    """
    last_mappings = {
        SampleBatch.OBS: SampleBatch.NEXT_OBS,
        SampleBatch.PREV_ACTIONS: SampleBatch.ACTIONS,
        SampleBatch.PREV_REWARDS: SampleBatch.REWARDS,
    }

    input_dict = {}
    for view_col, view_req in view_requirements.items():
        if view_req.used_for_compute_actions is False:
            continue

        # Create batches of size 1 (single-agent input-dict).
        data_col = view_req.data_col or view_col
        if index == "last":
            data_col = last_mappings.get(data_col, data_col)
            # Range needed.
            if view_req.shift_from is not None:
                # Batch repeat value > 1: We have single frames in the
                # batch at each timestep (for the `data_col`).
                data = self[view_col][-1]
                traj_len = len(self[data_col])
                missing_at_end = traj_len % view_req.batch_repeat_value
                # Index into the observations column must be shifted by
                # -1 b/c index=0 for observations means the current (last
                # seen) observation (after having taken an action).
                obs_shift = -1 if data_col in [
                    SampleBatch.OBS, SampleBatch.NEXT_OBS
                ] else 0
                from_ = view_req.shift_from + obs_shift
                to_ = view_req.shift_to + obs_shift + 1
                if to_ == 0:
                    to_ = None
                input_dict[view_col] = np.array([
                    np.concatenate(
                        [data,
                         self[data_col][-missing_at_end:]])[from_:to_]
                ])
            # Single index.
            else:
                input_dict[view_col] = tree.map_structure(
                    lambda v: v[-1:],  # keep as array (w/ 1 element)
                    self[data_col],
                )
        # Single index somewhere inside the trajectory (non-last).
        else:
            input_dict[view_col] = self[data_col][index:index + 1
                                                  if index != -1 else None]

    return SampleBatch(input_dict, seq_lens=np.array([1], dtype=np.int32))

right_zero_pad(self, max_seq_len, exclude_states=True)

Right (adding zeros at end) zero-pads this SampleBatch in-place.

This will set the self.zero_padded flag to True and self.max_seq_len to the given max_seq_len value.

Parameters:

Name Type Description Default
max_seq_len int

The max (total) length to zero pad to.

required
exclude_states bool

If False, also right-zero-pad all state_in_x data. If True, leave state_in_x keys as-is.

True

Returns:

Type Description
SampleBatch

This very (now right-zero-padded) SampleBatch.

Exceptions:

Type Description
ValueError

If self[SampleBatch.SEQ_LENS] is None (not defined).

Examples:

>>> batch = SampleBatch({"a": [1, 2, 3], "seq_lens": [1, 2]})
>>> print(batch.right_zero_pad(max_seq_len=4))
{"a": [1, 0, 0, 0, 2, 3, 0, 0], "seq_lens": [1, 2]}
>>> batch = SampleBatch({"a": [1, 2, 3],
...                      "state_in_0": [1.0, 3.0],
...                      "seq_lens": [1, 2]})
>>> print(batch.right_zero_pad(max_seq_len=5))
{"a": [1, 0, 0, 0, 0, 2, 3, 0, 0, 0],
 "state_in_0": [1.0, 3.0],  # <- all state-ins remain as-is
 "seq_lens": [1, 2]}
Source code in ray/rllib/policy/sample_batch.py
def right_zero_pad(self, max_seq_len: int, exclude_states: bool = True):
    """Right (adding zeros at end) zero-pads this SampleBatch in-place.

    This will set the `self.zero_padded` flag to True and
    `self.max_seq_len` to the given `max_seq_len` value.

    Args:
        max_seq_len: The max (total) length to zero pad to.
        exclude_states: If False, also right-zero-pad all
            `state_in_x` data. If True, leave `state_in_x` keys
            as-is.

    Returns:
        SampleBatch: This very (now right-zero-padded) SampleBatch.

    Raises:
        ValueError: If self[SampleBatch.SEQ_LENS] is None (not defined).

    Examples:
        >>> batch = SampleBatch({"a": [1, 2, 3], "seq_lens": [1, 2]})
        >>> print(batch.right_zero_pad(max_seq_len=4))
        {"a": [1, 0, 0, 0, 2, 3, 0, 0], "seq_lens": [1, 2]}

        >>> batch = SampleBatch({"a": [1, 2, 3],
        ...                      "state_in_0": [1.0, 3.0],
        ...                      "seq_lens": [1, 2]})
        >>> print(batch.right_zero_pad(max_seq_len=5))
        {"a": [1, 0, 0, 0, 0, 2, 3, 0, 0, 0],
         "state_in_0": [1.0, 3.0],  # <- all state-ins remain as-is
         "seq_lens": [1, 2]}
    """
    seq_lens = self.get(SampleBatch.SEQ_LENS)
    if seq_lens is None:
        raise ValueError(
            "Cannot right-zero-pad SampleBatch if no `seq_lens` field "
            "present! SampleBatch={self}")

    length = len(seq_lens) * max_seq_len

    def _zero_pad_in_place(path, value):
        # Skip "state_in_..." columns and "seq_lens".
        if (exclude_states is True and path[0].startswith("state_in_")) \
                or path[0] == SampleBatch.SEQ_LENS:
            return
        # Generate zero-filled primer of len=max_seq_len.
        if value.dtype == np.object or value.dtype.type is np.str_:
            f_pad = [None] * length
        else:
            # Make sure type doesn't change.
            f_pad = np.zeros(
                (length, ) + np.shape(value)[1:], dtype=value.dtype)
        # Fill primer with data.
        f_pad_base = f_base = 0
        for len_ in self[SampleBatch.SEQ_LENS]:
            f_pad[f_pad_base:f_pad_base + len_] = value[f_base:f_base +
                                                        len_]
            f_pad_base += max_seq_len
            f_base += len_
        assert f_base == len(value), value

        # Update our data in-place.
        curr = self
        for i, p in enumerate(path):
            if i == len(path) - 1:
                curr[p] = f_pad
            curr = curr[p]

    self_as_dict = {k: v for k, v in self.items()}
    tree.map_structure_with_path(_zero_pad_in_place, self_as_dict)

    # Set flags to indicate we are now zero-padded (and to what extent).
    self.zero_padded = True
    self.max_seq_len = max_seq_len

    return self

rows(self)

Returns an iterator over data rows, i.e. dicts with column values.

Note that if seq_lens is set in self, we set it to [1] in the rows.

Yields:

Type Description
Dict[str, TensorType]

The column values of the row in this iteration.

Examples:

>>> batch = SampleBatch({
...    "a": [1, 2, 3],
...    "b": [4, 5, 6],
...    "seq_lens": [1, 2]
... })
>>> for row in batch.rows():
       print(row)
{"a": 1, "b": 4, "seq_lens": [1]}
{"a": 2, "b": 5, "seq_lens": [1]}
{"a": 3, "b": 6, "seq_lens": [1]}
Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def rows(self) -> Iterator[Dict[str, TensorType]]:
    """Returns an iterator over data rows, i.e. dicts with column values.

    Note that if `seq_lens` is set in self, we set it to [1] in the rows.

    Yields:
        Dict[str, TensorType]: The column values of the row in this
            iteration.

    Examples:
        >>> batch = SampleBatch({
        ...    "a": [1, 2, 3],
        ...    "b": [4, 5, 6],
        ...    "seq_lens": [1, 2]
        ... })
        >>> for row in batch.rows():
               print(row)
        {"a": 1, "b": 4, "seq_lens": [1]}
        {"a": 2, "b": 5, "seq_lens": [1]}
        {"a": 3, "b": 6, "seq_lens": [1]}
    """

    # Do we add seq_lens=[1] to each row?
    seq_lens = None if self.get(
        SampleBatch.SEQ_LENS) is None else np.array([1])

    self_as_dict = {k: v for k, v in self.items()}

    for i in range(self.count):
        yield tree.map_structure_with_path(
            lambda p, v: v[i] if p[0] != self.SEQ_LENS else seq_lens,
            self_as_dict,
        )

shuffle(self)

Shuffles the rows of this batch in-place.

Returns:

Type Description
SampleBatch

This very (now shuffled) SampleBatch.

Exceptions:

Type Description
ValueError

If self[SampleBatch.SEQ_LENS] is defined.

Examples:

>>> batch = SampleBatch({"a": [1, 2, 3, 4]})
>>> print(batch.shuffle())
{"a": [4, 1, 3, 2]}
Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def shuffle(self) -> None:
    """Shuffles the rows of this batch in-place.

    Returns:
        SampleBatch: This very (now shuffled) SampleBatch.

    Raises:
        ValueError: If self[SampleBatch.SEQ_LENS] is defined.

    Examples:
        >>> batch = SampleBatch({"a": [1, 2, 3, 4]})
        >>> print(batch.shuffle())
        {"a": [4, 1, 3, 2]}
    """

    # Shuffling the data when we have `seq_lens` defined is probably
    # a bad idea!
    if self.get(SampleBatch.SEQ_LENS) is not None:
        raise ValueError(
            "SampleBatch.shuffle not possible when your data has "
            "`seq_lens` defined!")

    # Get a permutation over the single items once and use the same
    # permutation for all the data (otherwise, data would become
    # meaningless).
    permutation = np.random.permutation(self.count)

    def _permutate_in_place(path, value):
        curr = self
        for i, p in enumerate(path):
            if i == len(path) - 1:
                curr[p] = value[permutation]
            # Translate into list (tuples are immutable).
            if isinstance(curr[p], tuple):
                curr[p] = list(curr[p])
            curr = curr[p]

    tree.map_structure_with_path(_permutate_in_place, self)

    return self

size_bytes(self)

Returns sum over number of bytes of all data buffers.

For numpy arrays, we use .nbytes. For all other value types, we use sys.getsizeof(...).

Returns:

Type Description
int

The overall size in bytes of the data buffer (all columns).

Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def size_bytes(self) -> int:
    """Returns sum over number of bytes of all data buffers.

    For numpy arrays, we use `.nbytes`. For all other value types, we use
    sys.getsizeof(...).

    Returns:
        int: The overall size in bytes of the data buffer (all columns).
    """
    return sum(
        v.nbytes if isinstance(v, np.ndarray) else sys.getsizeof(v)
        for v in tree.flatten(self))
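
A quick sanity-check sketch with hypothetical data:

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

batch = SampleBatch({"obs": np.zeros((100, 4), dtype=np.float32)})
# 100 rows * 4 floats * 4 bytes = 1600 bytes for the single numpy column.
print(batch.size_bytes())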

split_by_episode(self)

Splits by eps_id column and returns list of new batches.

Returns:

Type Description
List[SampleBatch]

List of batches, one per distinct episode.

Exceptions:

Type Description
KeyError

If the eps_id AND dones columns are not present.

Examples:

>>> batch = SampleBatch({"a": [1, 2, 3], "eps_id": [0, 0, 1]})
>>> print(batch.split_by_episode())
[{"a": [1, 2], "eps_id": [0, 0]}, {"a": [3], "eps_id": [1]}]
Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def split_by_episode(self) -> List["SampleBatch"]:
    """Splits by `eps_id` column and returns list of new batches.

    Returns:
        List[SampleBatch]: List of batches, one per distinct episode.

    Raises:
        KeyError: If the `eps_id` AND `dones` columns are not present.

    Examples:
        >>> batch = SampleBatch({"a": [1, 2, 3], "eps_id": [0, 0, 1]})
        >>> print(batch.split_by_episode())
        [{"a": [1, 2], "eps_id": [0, 0]}, {"a": [3], "eps_id": [1]}]
    """

    # No eps_id in data -> Make sure there are no "dones" in the middle
    # and add eps_id automatically.
    if SampleBatch.EPS_ID not in self:
        # TODO: (sven) Shouldn't we rather split by DONEs then and not
        #  add fake eps-ids (0s) at all?
        if SampleBatch.DONES in self:
            assert not any(self[SampleBatch.DONES][:-1])
        self[SampleBatch.EPS_ID] = np.repeat(0, self.count)
        return [self]

    # Produce a new slice whenever we find a new episode ID.
    slices = []
    cur_eps_id = self[SampleBatch.EPS_ID][0]
    offset = 0
    for i in range(self.count):
        next_eps_id = self[SampleBatch.EPS_ID][i]
        if next_eps_id != cur_eps_id:
            slices.append(self[offset:i])
            offset = i
            cur_eps_id = next_eps_id
    # Add final slice.
    slices.append(self[offset:self.count])

    # TODO: (sven) Are these checks necessary? Should be all ok according
    #  to above logic.
    for s in slices:
        slen = len(set(s[SampleBatch.EPS_ID]))
        assert slen == 1, (s, slen)
    assert sum(s.count for s in slices) == self.count, (slices, self.count)

    return slices

timeslices(self, size=None, num_slices=None, k=None)

Returns SampleBatches, each one representing a k-slice of this one.

Will start from timestep 0 and produce slices of size=k.

Parameters:

Name Type Description Default
size Optional[int]

The size (in timesteps) of each returned SampleBatch.

None
num_slices Optional[int]

The number of slices to produce.

None
k Optional[int]

Obsolete: use size or num_slices instead. The size (in timesteps) of each returned SampleBatch.

None

Returns:

Type Description
List[SampleBatch]

The list of new SampleBatches: either num_slices of them, or as many as needed with each one containing size timesteps.

Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def timeslices(self,
               size: Optional[int] = None,
               num_slices: Optional[int] = None,
               k: Optional[int] = None) -> List["SampleBatch"]:
    """Returns SampleBatches, each one representing a k-slice of this one.

    Will start from timestep 0 and produce slices of size=k.

    Args:
        size (Optional[int]): The size (in timesteps) of each returned
            SampleBatch.
        num_slices (Optional[int]): The number of slices to produce.
        k (int): Obsoleted: Use size or num_slices instead!
            The size (in timesteps) of each returned SampleBatch.

    Returns:
        List[SampleBatch]: The list of `num_slices` (new) SampleBatches
            or n (new) SampleBatches each one of size `size`.
    """
    if size is None and num_slices is None:
        deprecation_warning("k", "size or num_slices")
        assert k is not None
        size = k

    if size is None:
        assert isinstance(num_slices, int)

        slices = []
        left = len(self)
        start = 0
        while left:
            len_ = left // (num_slices - len(slices))
            stop = start + len_
            slices.append(self[start:stop])
            left -= len_
            start = stop

        return slices

    else:
        assert isinstance(size, int)

        slices = []
        left = len(self)
        start = 0
        while left:
            stop = start + size
            slices.append(self[start:stop])
            left -= size
            start = stop

        return slices
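
For instance (hypothetical data), slicing a 10-timestep batch:

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

batch = SampleBatch({"obs": np.arange(10)})
# Fixed-size slices: five SampleBatches of 2 timesteps each.
by_size = batch.timeslices(size=2)
assert [s.count for s in by_size] == [2, 2, 2, 2, 2]
# Fixed number of (roughly equal-sized) slices.
by_num = batch.timeslices(num_slices=3)
assert sum(s.count for s in by_num) == 10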

to_device(self, device, framework='torch')

Converts all non-object numpy columns of this batch to framework tensors on the given device (in-place) and returns self. Only framework="torch" is currently supported.

Source code in ray/rllib/policy/sample_batch.py
def to_device(self, device, framework="torch"):
    """TODO: transfer batch to given device as framework tensor."""
    if framework == "torch":
        assert torch is not None
        for k, v in self.items():
            if isinstance(v, np.ndarray) and v.dtype != np.object:
                self[k] = torch.from_numpy(v).to(device)
    else:
        raise NotImplementedError
    return self
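
A minimal sketch, assuming PyTorch is installed (only framework="torch" is currently supported):

import numpy as np
import torch
from ray.rllib.policy.sample_batch import SampleBatch

batch = SampleBatch({"obs": np.zeros((4, 3), dtype=np.float32)})
# In-place: non-object numpy columns become torch tensors on the device.
batch.to_device("cpu")
assert isinstance(batch["obs"], torch.Tensor)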

ray.rllib.evaluation.sample_batch_builder.SampleBatchBuilder

Util to build a SampleBatch incrementally.

For efficiency, SampleBatches hold values in column form (as arrays). However, it is useful to add data one row (dict) at a time.

add_batch(self, batch)

Add the given batch of values to this batch.

Source code in ray/rllib/evaluation/sample_batch_builder.py
def add_batch(self, batch: SampleBatch) -> None:
    """Add the given batch of values to this batch."""

    for k, column in batch.items():
        self.buffers[k].extend(column)
    self.count += batch.count

add_values(self, **values)

Add the given dictionary (row) of values to this batch.

Source code in ray/rllib/evaluation/sample_batch_builder.py
def add_values(self, **values: Any) -> None:
    """Add the given dictionary (row) of values to this batch."""

    for k, v in values.items():
        self.buffers[k].append(v)
    self.count += 1

build_and_reset(self)

Returns a sample batch including all previously added values.

Source code in ray/rllib/evaluation/sample_batch_builder.py
def build_and_reset(self) -> SampleBatch:
    """Returns a sample batch including all previously added values."""

    batch = SampleBatch(
        {k: to_float_array(v)
         for k, v in self.buffers.items()})
    if SampleBatch.UNROLL_ID not in batch:
        batch[SampleBatch.UNROLL_ID] = np.repeat(
            SampleBatchBuilder._next_unroll_id, batch.count)
        SampleBatchBuilder._next_unroll_id += 1
    self.buffers.clear()
    self.count = 0
    return batch
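
A small end-to-end sketch of the row-wise workflow (hypothetical column names and values):

from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder

builder = SampleBatchBuilder()
# Add data one row (dict) at a time ...
builder.add_values(obs=0.1, action=0, reward=1.0)
builder.add_values(obs=0.2, action=1, reward=0.5)
# ... then materialize a column-oriented SampleBatch and reset the builder.
batch = builder.build_and_reset()
assert batch.count == 2 and builder.count == 0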

ray.rllib.policy.sample_batch.MultiAgentBatch

A batch of experiences from multiple agents in the environment.

Attributes:

Name Type Description
policy_batches Dict[PolicyID, SampleBatch]

Mapping from policy ids to SampleBatches of experiences.

count int

The number of env steps in this batch.

__init__(self, policy_batches, env_steps) special

Initialize a MultiAgentBatch object.

Parameters:

Name Type Description Default
policy_batches Dict[PolicyID, SampleBatch]

Mapping from policy ids to SampleBatches of experiences.

required
env_steps int

The number of environment steps this batch contains. This is at most the total number of transitions this batch contains across all policies.

required
Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def __init__(self, policy_batches: Dict[PolicyID, SampleBatch],
             env_steps: int):
    """Initialize a MultiAgentBatch object.

    Args:
        policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy
            ids to SampleBatches of experiences.
        env_steps (int): The number of environment steps in the environment
            this batch contains. This will be less than the number of
            transitions this batch contains across all policies in total.
    """

    for v in policy_batches.values():
        assert isinstance(v, SampleBatch)
    self.policy_batches = policy_batches
    # Called "count" for uniformity with SampleBatch.
    # Prefer to access this via the `env_steps()` method when possible
    # for clarity.
    self.count = env_steps
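
A construction sketch with hypothetical policy IDs, illustrating the env-step vs. agent-step distinction:

import numpy as np
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch

ma_batch = MultiAgentBatch(
    policy_batches={
        "policy_A": SampleBatch({"obs": np.zeros((3, 4))}),  # acted 3x
        "policy_B": SampleBatch({"obs": np.zeros((2, 4))}),  # acted 2x
    },
    env_steps=3,
)
assert ma_batch.env_steps() == 3    # environment transitions
assert ma_batch.agent_steps() == 5  # summed per-policy transitions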

agent_steps(self)

The number of agent steps (there are >= 1 agent steps per env step).

Returns:

Type Description
int

The number of agent steps total in this batch.

Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def agent_steps(self) -> int:
    """The number of agent steps (there are >= 1 agent steps per env step).

    Returns:
        int: The number of agent steps total in this batch.
    """
    ct = 0
    for batch in self.policy_batches.values():
        ct += batch.count
    return ct

compress(self, bulk=False, columns=frozenset({'new_obs', 'obs'}))

Compresses each policy batch (per column) in place.

Parameters:

Name Type Description Default
bulk bool

Whether to compress across the batch dimension (0) as well. If False will compress n separate list items, where n is the batch size.

False
columns Set[str]

Set of column names to compress.

frozenset({'new_obs', 'obs'})
Source code in ray/rllib/policy/sample_batch.py
@DeveloperAPI
def compress(self,
             bulk: bool = False,
             columns: Set[str] = frozenset(["obs", "new_obs"])) -> None:
    """Compresses each policy batch (per column) in place.

    Args:
        bulk (bool): Whether to compress across the batch dimension (0)
            as well. If False will compress n separate list items, where n
            is the batch size.
        columns (Set[str]): Set of column names to compress.
    """
    for batch in self.policy_batches.values():
        batch.compress(bulk=bulk, columns=columns)

concat_samples(samples) staticmethod

Concatenates a list of MultiAgentBatches into a new MultiAgentBatch.

Parameters:

Name Type Description Default
samples List[MultiAgentBatch]

List of MultiAgentBatch objects to concatenate.

required

Returns:

Type Description
MultiAgentBatch

A new MultiAgentBatch consisting of the concatenated inputs.

Source code in ray/rllib/policy/sample_batch.py
@staticmethod
@PublicAPI
def concat_samples(samples: List["MultiAgentBatch"]) -> "MultiAgentBatch":
    """Concatenates a list of MultiAgentBatches into a new MultiAgentBatch.

    Args:
        samples (List[MultiAgentBatch]): List of MultiAgentBatch objects
            to concatenate.

    Returns:
        MultiAgentBatch: A new MultiAgentBatch consisting of the
            concatenated inputs.
    """
    policy_batches = collections.defaultdict(list)
    env_steps = 0
    for s in samples:
        # Some batches in `samples` are not MultiAgentBatch.
        if not isinstance(s, MultiAgentBatch):
            # If empty SampleBatch: ok (just ignore).
            if isinstance(s, SampleBatch) and len(s) <= 0:
                continue
            # Otherwise: Error.
            raise ValueError(
                "`MultiAgentBatch.concat_samples()` can only concat "
                "MultiAgentBatch types, not {}!".format(type(s).__name__))
        for key, batch in s.policy_batches.items():
            policy_batches[key].append(batch)
        env_steps += s.env_steps()
    out = {}
    for key, batches in policy_batches.items():
        out[key] = SampleBatch.concat_samples(batches)
    return MultiAgentBatch(out, env_steps)
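
For example (hypothetical single-policy batches):

import numpy as np
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch

b1 = MultiAgentBatch({"p0": SampleBatch({"obs": np.zeros((2, 4))})}, env_steps=2)
b2 = MultiAgentBatch({"p0": SampleBatch({"obs": np.ones((3, 4))})}, env_steps=3)
merged = MultiAgentBatch.concat_samples([b1, b2])
assert merged.env_steps() == 5
assert merged.policy_batches["p0"].count == 5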

copy(self)

Deep-copies self into a new MultiAgentBatch.

Returns:

Type Description
MultiAgentBatch

The copy of self with deep-copied data.

Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def copy(self) -> "MultiAgentBatch":
    """Deep-copies self into a new MultiAgentBatch.

    Returns:
        MultiAgentBatch: The copy of self with deep-copied data.
    """
    return MultiAgentBatch(
        {k: v.copy()
         for (k, v) in self.policy_batches.items()}, self.count)

decompress_if_needed(self, columns=frozenset({'new_obs', 'obs'}))

Decompresses each policy batch (per column), if already compressed.

Parameters:

Name Type Description Default
columns Set[str]

Set of column names to decompress.

frozenset({'new_obs', 'obs'})

Returns:

Type Description
MultiAgentBatch

This very MultiAgentBatch.

Source code in ray/rllib/policy/sample_batch.py
@DeveloperAPI
def decompress_if_needed(self,
                         columns: Set[str] = frozenset(
                             ["obs", "new_obs"])) -> "MultiAgentBatch":
    """Decompresses each policy batch (per column), if already compressed.

    Args:
        columns (Set[str]): Set of column names to decompress.

    Returns:
        MultiAgentBatch: This very MultiAgentBatch.
    """
    for batch in self.policy_batches.values():
        batch.decompress_if_needed(columns)
    return self

env_steps(self)

The number of env steps (there are >= 1 agent steps per env step).

Returns:

Type Description
int

The number of environment steps contained in this batch.

Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def env_steps(self) -> int:
    """The number of env steps (there are >= 1 agent steps per env step).

    Returns:
        int: The number of environment steps contained in this batch.
    """
    return self.count

size_bytes(self)

Returns:

Type Description
int

The overall size in bytes of all policy batches (all columns).

Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def size_bytes(self) -> int:
    """
    Returns:
        int: The overall size in bytes of all policy batches (all columns).
    """
    return sum(b.size_bytes() for b in self.policy_batches.values())

timeslices(self, k)

Returns k-step batches holding data for each agent at those steps.

For example, suppose we have agent1 observations [a1t1, a1t2, a1t3], for agent2, [a2t1, a2t3], and for agent3, [a3t3] only.

Calling timeslices(1) would return three MultiAgentBatches containing [a1t1, a2t1], [a1t2], and [a1t3, a2t3, a3t3].

Calling timeslices(2) would return two MultiAgentBatches containing [a1t1, a1t2, a2t1], and [a1t3, a2t3, a3t3].

This method is used to implement "lockstep" replay mode. Note that this method does not guarantee each batch contains only data from a single unroll. Batches might contain data from multiple different envs.

Source code in ray/rllib/policy/sample_batch.py
@PublicAPI
def timeslices(self, k: int) -> List["MultiAgentBatch"]:
    """Returns k-step batches holding data for each agent at those steps.

    For example, suppose we have agent1 observations [a1t1, a1t2, a1t3],
    for agent2, [a2t1, a2t3], and for agent3, [a3t3] only.

    Calling timeslices(1) would return three MultiAgentBatches containing
    [a1t1, a2t1], [a1t2], and [a1t3, a2t3, a3t3].

    Calling timeslices(2) would return two MultiAgentBatches containing
    [a1t1, a1t2, a2t1], and [a1t3, a2t3, a3t3].

    This method is used to implement "lockstep" replay mode. Note that this
    method does not guarantee each batch contains only data from a single
    unroll. Batches might contain data from multiple different envs.
    """
    from ray.rllib.evaluation.sample_batch_builder import \
        SampleBatchBuilder

    # Build a sorted set of (eps_id, t, policy_id, data...)
    steps = []
    for policy_id, batch in self.policy_batches.items():
        for row in batch.rows():
            steps.append((row[SampleBatch.EPS_ID], row[SampleBatch.T],
                          row[SampleBatch.AGENT_INDEX], policy_id, row))
    steps.sort()

    finished_slices = []
    cur_slice = collections.defaultdict(SampleBatchBuilder)
    cur_slice_size = 0

    def finish_slice():
        nonlocal cur_slice_size
        assert cur_slice_size > 0
        batch = MultiAgentBatch(
            {k: v.build_and_reset()
             for k, v in cur_slice.items()}, cur_slice_size)
        cur_slice_size = 0
        finished_slices.append(batch)

    # For each unique env timestep.
    for _, group in itertools.groupby(steps, lambda x: x[:2]):
        # Accumulate into the current slice.
        for _, _, _, policy_id, row in group:
            cur_slice[policy_id].add_values(**row)
        cur_slice_size += 1
        # Slice has reached target number of env steps.
        if cur_slice_size >= k:
            finish_slice()
            assert cur_slice_size == 0

    if cur_slice_size > 0:
        finish_slice()

    assert len(finished_slices) > 0, finished_slices
    return finished_slices
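
A minimal sketch with hypothetical single-policy data; each row needs eps_id, t, and agent_index entries for the per-timestep grouping:

import numpy as np
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch

ma = MultiAgentBatch({
    "p0": SampleBatch({
        SampleBatch.EPS_ID: np.array([0, 0, 0]),
        SampleBatch.T: np.array([0, 1, 2]),
        SampleBatch.AGENT_INDEX: np.array([0, 0, 0]),
        "obs": np.array([1.0, 2.0, 3.0]),
    }),
}, env_steps=3)
# One MultiAgentBatch per env timestep (k=1).
one_step = ma.timeslices(1)
assert len(one_step) == 3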

wrap_as_needed(policy_batches, env_steps) staticmethod

Returns SampleBatch or MultiAgentBatch, depending on given policies.

Parameters:

Name Type Description Default
policy_batches Dict[PolicyID, SampleBatch]

Mapping from policy ids to SampleBatch.

required
env_steps int

Number of env steps in the batch.

required

Returns:

Type Description
Union[SampleBatch, MultiAgentBatch]

The single default policy's SampleBatch or a MultiAgentBatch (more than one policy).

Source code in ray/rllib/policy/sample_batch.py
@staticmethod
@PublicAPI
def wrap_as_needed(
        policy_batches: Dict[PolicyID, SampleBatch],
        env_steps: int) -> Union[SampleBatch, "MultiAgentBatch"]:
    """Returns SampleBatch or MultiAgentBatch, depending on given policies.

    Args:
        policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy
            ids to SampleBatch.
        env_steps (int): Number of env steps in the batch.

    Returns:
        Union[SampleBatch, MultiAgentBatch]: The single default policy's
            SampleBatch or a MultiAgentBatch (more than one policy).
    """
    if len(policy_batches) == 1 and DEFAULT_POLICY_ID in policy_batches:
        return policy_batches[DEFAULT_POLICY_ID]
    return MultiAgentBatch(
        policy_batches=policy_batches, env_steps=env_steps)
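
For illustration (assuming the usual DEFAULT_POLICY_ID, "default_policy", exported by ray.rllib.policy.sample_batch):

import numpy as np
from ray.rllib.policy.sample_batch import (
    DEFAULT_POLICY_ID, MultiAgentBatch, SampleBatch)

only_default = {DEFAULT_POLICY_ID: SampleBatch({"obs": np.zeros((2, 4))})}
assert isinstance(
    MultiAgentBatch.wrap_as_needed(only_default, env_steps=2), SampleBatch)

two_policies = {
    "p0": SampleBatch({"obs": np.zeros((2, 4))}),
    "p1": SampleBatch({"obs": np.zeros((2, 4))}),
}
assert isinstance(
    MultiAgentBatch.wrap_as_needed(two_policies, env_steps=2), MultiAgentBatch)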

ray.rllib.evaluation.sample_batch_builder.MultiAgentSampleBatchBuilder

Util to build SampleBatches for each policy in a multi-agent env.

Input data is per-agent, while output data is per-policy. There is an M:N mapping between agents and policies. We retain one local batch builder per agent. When an agent is done, then its local batch is appended into the corresponding policy batch for the agent's policy.

__init__(self, policy_map, clip_rewards, callbacks) special

Initialize a MultiAgentSampleBatchBuilder.

Parameters:

Name Type Description Default
policy_map Dict[str,Policy]

Maps policy ids to policy instances.

required
clip_rewards Union[bool,float]

If True, clip rewards at +/-1.0 before postprocessing; if a float, clip at +/- that value; if False, do not clip.

required
callbacks DefaultCallbacks

RLlib callbacks.

required
Source code in ray/rllib/evaluation/sample_batch_builder.py
def __init__(self, policy_map: Dict[PolicyID, Policy], clip_rewards: bool,
             callbacks: "DefaultCallbacks"):
    """Initialize a MultiAgentSampleBatchBuilder.

    Args:
        policy_map (Dict[str,Policy]): Maps policy ids to policy instances.
        clip_rewards (Union[bool,float]): Whether to clip rewards before
            postprocessing (at +/-1.0) or the actual value to +/- clip.
        callbacks (DefaultCallbacks): RLlib callbacks.
    """
    if log_once("MultiAgentSampleBatchBuilder"):
        deprecation_warning(
            old="MultiAgentSampleBatchBuilder", error=False)
    self.policy_map = policy_map
    self.clip_rewards = clip_rewards
    # Build the Policies' SampleBatchBuilders.
    self.policy_builders = {
        k: SampleBatchBuilder()
        for k in policy_map.keys()
    }
    # Whenever we observe a new agent, add a new SampleBatchBuilder for
    # this agent.
    self.agent_builders = {}
    # Internal agent-to-policy map.
    self.agent_to_policy = {}
    self.callbacks = callbacks
    # Number of "inference" steps taken in the environment.
    # Regardless of the number of agents involved in each of these steps.
    self.count = 0

add_values(self, agent_id, policy_id, **values)

Add the given dictionary (row) of values to this batch.

Parameters:

Name Type Description Default
agent_id obj

Unique id for the agent we are adding values for.

required
policy_id obj

Unique id for policy controlling the agent.

required
values dict

Row of values to add for this agent.

{}
Source code in ray/rllib/evaluation/sample_batch_builder.py
@DeveloperAPI
def add_values(self, agent_id: AgentID, policy_id: AgentID,
               **values: Any) -> None:
    """Add the given dictionary (row) of values to this batch.

    Args:
        agent_id (obj): Unique id for the agent we are adding values for.
        policy_id (obj): Unique id for policy controlling the agent.
        values (dict): Row of values to add for this agent.
    """

    if agent_id not in self.agent_builders:
        self.agent_builders[agent_id] = SampleBatchBuilder()
        self.agent_to_policy[agent_id] = policy_id

    # Include the current agent id for multi-agent algorithms.
    if agent_id != _DUMMY_AGENT_ID:
        values["agent_id"] = agent_id

    self.agent_builders[agent_id].add_values(**values)

build_and_reset(self, episode=None)

Returns the accumulated sample batches for each policy.

Any unprocessed rows will be first postprocessed with a policy postprocessor. The internal state of this builder will be reset.

Parameters:

Name Type Description Default
episode Optional[Episode]

The Episode object that holds this MultiAgentBatchBuilder object or None.

None

Returns:

Type Description
MultiAgentBatch

Returns the accumulated sample batches for each policy.

Source code in ray/rllib/evaluation/sample_batch_builder.py
@DeveloperAPI
def build_and_reset(self,
                    episode: Optional[Episode] = None) -> MultiAgentBatch:
    """Returns the accumulated sample batches for each policy.

    Any unprocessed rows will be first postprocessed with a policy
    postprocessor. The internal state of this builder will be reset.

    Args:
        episode (Optional[Episode]): The Episode object that
            holds this MultiAgentBatchBuilder object or None.

    Returns:
        MultiAgentBatch: Returns the accumulated sample batches for each
            policy.
    """

    self.postprocess_batch_so_far(episode)
    policy_batches = {}
    for policy_id, builder in self.policy_builders.items():
        if builder.count > 0:
            policy_batches[policy_id] = builder.build_and_reset()
    old_count = self.count
    self.count = 0
    return MultiAgentBatch.wrap_as_needed(policy_batches, old_count)

has_pending_agent_data(self)

Returns whether there is pending unprocessed data.

Returns:

Type Description
bool

True if there is at least one per-agent builder (with data in it).

Source code in ray/rllib/evaluation/sample_batch_builder.py
def has_pending_agent_data(self) -> bool:
    """Returns whether there is pending unprocessed data.

    Returns:
        bool: True if there is at least one per-agent builder (with data
            in it).
    """

    return len(self.agent_builders) > 0

postprocess_batch_so_far(self, episode=None)

Apply policy postprocessors to any unprocessed rows.

This pushes the postprocessed per-agent batches onto the per-policy builders, clearing per-agent state.

Parameters:

Name Type Description Default
episode Optional[Episode]

The Episode object that holds this MultiAgentBatchBuilder object.

None
Source code in ray/rllib/evaluation/sample_batch_builder.py
def postprocess_batch_so_far(self,
                             episode: Optional[Episode] = None) -> None:
    """Apply policy postprocessors to any unprocessed rows.

    This pushes the postprocessed per-agent batches onto the per-policy
    builders, clearing per-agent state.

    Args:
        episode (Optional[Episode]): The Episode object that
            holds this MultiAgentBatchBuilder object.
    """

    # Materialize the batches so far.
    pre_batches = {}
    for agent_id, builder in self.agent_builders.items():
        pre_batches[agent_id] = (
            self.policy_map[self.agent_to_policy[agent_id]],
            builder.build_and_reset())

    # Apply postprocessor.
    post_batches = {}
    if self.clip_rewards is True:
        for _, (_, pre_batch) in pre_batches.items():
            pre_batch["rewards"] = np.sign(pre_batch["rewards"])
    elif self.clip_rewards:
        for _, (_, pre_batch) in pre_batches.items():
            pre_batch["rewards"] = np.clip(
                pre_batch["rewards"],
                a_min=-self.clip_rewards,
                a_max=self.clip_rewards)
    for agent_id, (_, pre_batch) in pre_batches.items():
        other_batches = pre_batches.copy()
        del other_batches[agent_id]
        policy = self.policy_map[self.agent_to_policy[agent_id]]
        if any(pre_batch["dones"][:-1]) or len(set(
                pre_batch["eps_id"])) > 1:
            raise ValueError(
                "Batches sent to postprocessing must only contain steps "
                "from a single trajectory.", pre_batch)
        # Call the Policy's Exploration's postprocess method.
        post_batches[agent_id] = pre_batch
        if getattr(policy, "exploration", None) is not None:
            policy.exploration.postprocess_trajectory(
                policy, post_batches[agent_id], policy.get_session())
        post_batches[agent_id] = policy.postprocess_trajectory(
            post_batches[agent_id], other_batches, episode)

    if log_once("after_post"):
        logger.info(
            "Trajectory fragment after postprocess_trajectory():\n\n{}\n".
            format(summarize(post_batches)))

    # Append into policy batches and reset
    from ray.rllib.evaluation.rollout_worker import get_global_worker
    for agent_id, post_batch in sorted(post_batches.items()):
        self.callbacks.on_postprocess_trajectory(
            worker=get_global_worker(),
            episode=episode,
            agent_id=agent_id,
            policy_id=self.agent_to_policy[agent_id],
            policies=self.policy_map,
            postprocessed_batch=post_batch,
            original_batches=pre_batches)
        self.policy_builders[self.agent_to_policy[agent_id]].add_batch(
            post_batch)

    self.agent_builders.clear()
    self.agent_to_policy.clear()

total(self)

Returns the total number of steps taken in the env (all agents).

Returns:

Type Description
int

The number of steps taken in total in the environment over all agents.

Source code in ray/rllib/evaluation/sample_batch_builder.py
def total(self) -> int:
    """Returns the total number of steps taken in the env (all agents).

    Returns:
        int: The number of steps taken in total in the environment over all
            agents.
    """

    return sum(a.count for a in self.agent_builders.values())

Samplers

ray.rllib.evaluation.sampler.SyncSampler (SamplerInput)

Sync SamplerInput that collects experiences when get_data() is called.

__init__(self, *, worker, env, clip_rewards, rollout_fragment_length, count_steps_by='env_steps', callbacks, horizon=None, multiple_episodes_in_batch=False, normalize_actions=True, clip_actions=False, soft_horizon=False, no_done_at_end=False, observation_fn=None, sample_collector_class=None, render=False, policies=None, policy_mapping_fn=None, preprocessors=None, obs_filters=None, tf_sess=None) special

Initializes a SyncSampler instance.

Parameters:

Name Type Description Default
worker RolloutWorker

The RolloutWorker that will use this Sampler for sampling.

required
env BaseEnv

Any Env object. Will be converted into an RLlib BaseEnv.

required
clip_rewards Union[bool, float]

True for +/-1.0 clipping, actual float value for +/- value clipping. False for no clipping.

required
rollout_fragment_length int

The length of a fragment to collect before building a SampleBatch from the data and resetting the SampleBatchBuilder object.

required
count_steps_by str

One of "env_steps" (default) or "agent_steps". Use "agent_steps", if you want rollout lengths to be counted by individual agent steps. In a multi-agent env, a single env_step contains one or more agent_steps, depending on how many agents are present at any given time in the ongoing episode.

'env_steps'
callbacks DefaultCallbacks

The Callbacks object to use when episode events happen during rollout.

required
horizon int

Hard-reset the Env after this many timesteps.

None
multiple_episodes_in_batch bool

Whether to pack multiple episodes into each batch. This guarantees batches will be exactly rollout_fragment_length in size.

False
normalize_actions bool

Whether to normalize actions to the action space's bounds.

True
clip_actions bool

Whether to clip actions according to the given action_space's bounds.

False
soft_horizon bool

If True, calculate bootstrapped values as if episode had ended, but don't physically reset the environment when the horizon is hit.

False
no_done_at_end bool

Ignore the done=True at the end of the episode and instead record done=False.

False
observation_fn Optional[ObservationFunction]

Optional multi-agent observation func to use for preprocessing observations.

None
sample_collector_class Optional[Type[ray.rllib.evaluation.collectors.sample_collector.SampleCollector]]

An optional SampleCollector sub-class to use to collect, store, and retrieve environment-, model-, and sampler data.

None
render bool

Whether to try to render the environment after each step.

False
Source code in ray/rllib/evaluation/sampler.py
def __init__(
        self,
        *,
        worker: "RolloutWorker",
        env: BaseEnv,
        clip_rewards: Union[bool, float],
        rollout_fragment_length: int,
        count_steps_by: str = "env_steps",
        callbacks: "DefaultCallbacks",
        horizon: int = None,
        multiple_episodes_in_batch: bool = False,
        normalize_actions: bool = True,
        clip_actions: bool = False,
        soft_horizon: bool = False,
        no_done_at_end: bool = False,
        observation_fn: Optional["ObservationFunction"] = None,
        sample_collector_class: Optional[Type[SampleCollector]] = None,
        render: bool = False,
        # Obsolete.
        policies=None,
        policy_mapping_fn=None,
        preprocessors=None,
        obs_filters=None,
        tf_sess=None,
):
    """Initializes a SyncSampler instance.

    Args:
        worker: The RolloutWorker that will use this Sampler for sampling.
        env: Any Env object. Will be converted into an RLlib BaseEnv.
        clip_rewards: True for +/-1.0 clipping,
            actual float value for +/- value clipping. False for no
            clipping.
        rollout_fragment_length: The length of a fragment to collect
            before building a SampleBatch from the data and resetting
            the SampleBatchBuilder object.
        count_steps_by: One of "env_steps" (default) or "agent_steps".
            Use "agent_steps", if you want rollout lengths to be counted
            by individual agent steps. In a multi-agent env,
            a single env_step contains one or more agent_steps, depending
            on how many agents are present at any given time in the
            ongoing episode.
        callbacks: The Callbacks object to use when episode
            events happen during rollout.
        horizon: Hard-reset the Env after this many timesteps.
        multiple_episodes_in_batch: Whether to pack multiple
            episodes into each batch. This guarantees batches will be
            exactly `rollout_fragment_length` in size.
        normalize_actions: Whether to normalize actions to the
            action space's bounds.
        clip_actions: Whether to clip actions according to the
            given action_space's bounds.
        soft_horizon: If True, calculate bootstrapped values as if
            episode had ended, but don't physically reset the environment
            when the horizon is hit.
        no_done_at_end: Ignore the done=True at the end of the
            episode and instead record done=False.
        observation_fn: Optional multi-agent observation func to use for
            preprocessing observations.
        sample_collector_class: An optional Samplecollector sub-class to
            use to collect, store, and retrieve environment-, model-,
            and sampler data.
        render: Whether to try to render the environment after each step.
    """
    # All of the following arguments are deprecated. They will instead be
    # provided via the passed in `worker` arg, e.g. `worker.policy_map`.
    if log_once("deprecated_sync_sampler_args"):
        if policies is not None:
            deprecation_warning(old="policies")
        if policy_mapping_fn is not None:
            deprecation_warning(old="policy_mapping_fn")
        if preprocessors is not None:
            deprecation_warning(old="preprocessors")
        if obs_filters is not None:
            deprecation_warning(old="obs_filters")
        if tf_sess is not None:
            deprecation_warning(old="tf_sess")

    self.base_env = BaseEnv.to_base_env(env)
    self.rollout_fragment_length = rollout_fragment_length
    self.horizon = horizon
    self.extra_batches = queue.Queue()
    self.perf_stats = _PerfStats()
    if not sample_collector_class:
        sample_collector_class = SimpleListCollector
    self.sample_collector = sample_collector_class(
        worker.policy_map,
        clip_rewards,
        callbacks,
        multiple_episodes_in_batch,
        rollout_fragment_length,
        count_steps_by=count_steps_by)
    self.render = render

    # Create the rollout generator to use for calls to `get_data()`.
    self._env_runner = _env_runner(
        worker, self.base_env, self.extra_batches.put, self.horizon,
        normalize_actions, clip_actions, multiple_episodes_in_batch,
        callbacks, self.perf_stats, soft_horizon, no_done_at_end,
        observation_fn, self.sample_collector, self.render)
    self.metrics_queue = queue.Queue()

get_data(self)

Called by self.next() to return the next batch of data.

Override this in child classes.

Returns:

Type Description
Union[SampleBatch, MultiAgentBatch]

The next batch of data.

Source code in ray/rllib/evaluation/sampler.py
@override(SamplerInput)
def get_data(self) -> SampleBatchType:
    while True:
        item = next(self._env_runner)
        if isinstance(item, RolloutMetrics):
            self.metrics_queue.put(item)
        else:
            return item

get_extra_batches(self)

Returns list of extra batches since the last call to this method.

The list will contain all SampleBatches or MultiAgentBatches that the user has provided thus-far. Users can add these "extra batches" to an episode by calling the episode's add_extra_batch([SampleBatchType]) method. This can be done from inside an overridden Policy.compute_actions_from_input_dict(..., episodes) or from a custom callback's on_episode_[start|step|end]() methods.

Returns:

Type Description
List[Union[SampleBatch, MultiAgentBatch]]

List of SampleBatches or MultiAgentBatches provided thus far by the user since the last call to this method.

Source code in ray/rllib/evaluation/sampler.py
@override(SamplerInput)
def get_extra_batches(self) -> List[SampleBatchType]:
    extra = []
    while True:
        try:
            extra.append(self.extra_batches.get_nowait())
        except queue.Empty:
            break
    return extra

get_metrics(self)

Returns list of episode metrics since the last call to this method.

The list will contain one RolloutMetrics object per completed episode.

Returns:

Type Description
List[ray.rllib.evaluation.metrics.RolloutMetrics]

List of RolloutMetrics objects, one per completed episode since the last call to this method.

Source code in ray/rllib/evaluation/sampler.py
@override(SamplerInput)
def get_metrics(self) -> List[RolloutMetrics]:
    completed = []
    while True:
        try:
            completed.append(self.metrics_queue.get_nowait()._replace(
                perf_stats=self.perf_stats.get()))
        except queue.Empty:
            break
    return completed

ray.rllib.evaluation.sampler.AsyncSampler (Thread, SamplerInput)

Async SamplerInput that collects experiences in thread and queues them.

Once started, experiences are continuously collected in the background and put into a Queue, from which they are dequeued by calls to get_data().

__init__(self, *, worker, env, clip_rewards, rollout_fragment_length, count_steps_by='env_steps', callbacks, horizon=None, multiple_episodes_in_batch=False, normalize_actions=True, clip_actions=False, soft_horizon=False, no_done_at_end=False, observation_fn=None, sample_collector_class=None, render=False, blackhole_outputs=False, policies=None, policy_mapping_fn=None, preprocessors=None, obs_filters=None, tf_sess=None) special

Initializes an AsyncSampler instance.

Parameters:

Name Type Description Default
worker RolloutWorker

The RolloutWorker that will use this Sampler for sampling.

required
env BaseEnv

Any Env object. Will be converted into an RLlib BaseEnv.

required
clip_rewards Union[bool, float]

True for +/-1.0 clipping, actual float value for +/- value clipping. False for no clipping.

required
rollout_fragment_length int

The length of a fragment to collect before building a SampleBatch from the data and resetting the SampleBatchBuilder object.

required
count_steps_by str

One of "env_steps" (default) or "agent_steps". Use "agent_steps", if you want rollout lengths to be counted by individual agent steps. In a multi-agent env, a single env_step contains one or more agent_steps, depending on how many agents are present at any given time in the ongoing episode.

'env_steps'
horizon Optional[int]

Hard-reset the Env after this many timesteps.

None
multiple_episodes_in_batch bool

Whether to pack multiple episodes into each batch. This guarantees batches will be exactly rollout_fragment_length in size.

False
normalize_actions bool

Whether to normalize actions to the action space's bounds.

True
clip_actions bool

Whether to clip actions according to the given action_space's bounds.

False
blackhole_outputs bool

Whether to collect samples, but then not further process or store them (throw away all samples).

False
soft_horizon bool

If True, calculate bootstrapped values as if episode had ended, but don't physically reset the environment when the horizon is hit.

False
no_done_at_end bool

Ignore the done=True at the end of the episode and instead record done=False.

False
observation_fn Optional[ObservationFunction]

Optional multi-agent observation func to use for preprocessing observations.

None
sample_collector_class Optional[Type[ray.rllib.evaluation.collectors.sample_collector.SampleCollector]]

An optional SampleCollector sub-class to use to collect, store, and retrieve environment-, model-, and sampler data.

None
render bool

Whether to try to render the environment after each step.

False
Source code in ray/rllib/evaluation/sampler.py
def __init__(
        self,
        *,
        worker: "RolloutWorker",
        env: BaseEnv,
        clip_rewards: Union[bool, float],
        rollout_fragment_length: int,
        count_steps_by: str = "env_steps",
        callbacks: "DefaultCallbacks",
        horizon: Optional[int] = None,
        multiple_episodes_in_batch: bool = False,
        normalize_actions: bool = True,
        clip_actions: bool = False,
        soft_horizon: bool = False,
        no_done_at_end: bool = False,
        observation_fn: Optional["ObservationFunction"] = None,
        sample_collector_class: Optional[Type[SampleCollector]] = None,
        render: bool = False,
        blackhole_outputs: bool = False,
        # Obsolete.
        policies=None,
        policy_mapping_fn=None,
        preprocessors=None,
        obs_filters=None,
        tf_sess=None,
):
    """Initializes an AsyncSampler instance.

    Args:
        worker: The RolloutWorker that will use this Sampler for sampling.
        env: Any Env object. Will be converted into an RLlib BaseEnv.
        clip_rewards: True for +/-1.0 clipping,
            actual float value for +/- value clipping. False for no
            clipping.
        rollout_fragment_length: The length of a fragment to collect
            before building a SampleBatch from the data and resetting
            the SampleBatchBuilder object.
        count_steps_by: One of "env_steps" (default) or "agent_steps".
            Use "agent_steps", if you want rollout lengths to be counted
            by individual agent steps. In a multi-agent env,
            a single env_step contains one or more agent_steps, depending
            on how many agents are present at any given time in the
            ongoing episode.
        horizon: Hard-reset the Env after this many timesteps.
        multiple_episodes_in_batch: Whether to pack multiple
            episodes into each batch. This guarantees batches will be
            exactly `rollout_fragment_length` in size.
        normalize_actions: Whether to normalize actions to the
            action space's bounds.
        clip_actions: Whether to clip actions according to the
            given action_space's bounds.
        blackhole_outputs: Whether to collect samples, but then
            not further process or store them (throw away all samples).
        soft_horizon: If True, calculate bootstrapped values as if
            episode had ended, but don't physically reset the environment
            when the horizon is hit.
        no_done_at_end: Ignore the done=True at the end of the
            episode and instead record done=False.
        observation_fn: Optional multi-agent observation func to use for
            preprocessing observations.
        sample_collector_class: An optional SampleCollector sub-class to
            use to collect, store, and retrieve environment-, model-,
            and sampler data.
        render: Whether to try to render the environment after each step.
    """
    # All of the following arguments are deprecated. They will instead be
    # provided via the passed in `worker` arg, e.g. `worker.policy_map`.
    if log_once("deprecated_async_sampler_args"):
        if policies is not None:
            deprecation_warning(old="policies")
        if policy_mapping_fn is not None:
            deprecation_warning(old="policy_mapping_fn")
        if preprocessors is not None:
            deprecation_warning(old="preprocessors")
        if obs_filters is not None:
            deprecation_warning(old="obs_filters")
        if tf_sess is not None:
            deprecation_warning(old="tf_sess")

    self.worker = worker

    for _, f in worker.filters.items():
        assert getattr(f, "is_concurrent", False), \
            "Observation Filter must support concurrent updates."

    self.base_env = BaseEnv.to_base_env(env)
    threading.Thread.__init__(self)
    self.queue = queue.Queue(5)
    self.extra_batches = queue.Queue()
    self.metrics_queue = queue.Queue()
    self.rollout_fragment_length = rollout_fragment_length
    self.horizon = horizon
    self.clip_rewards = clip_rewards
    self.daemon = True
    self.multiple_episodes_in_batch = multiple_episodes_in_batch
    self.callbacks = callbacks
    self.normalize_actions = normalize_actions
    self.clip_actions = clip_actions
    self.blackhole_outputs = blackhole_outputs
    self.soft_horizon = soft_horizon
    self.no_done_at_end = no_done_at_end
    self.perf_stats = _PerfStats()
    self.shutdown = False
    self.observation_fn = observation_fn
    self.render = render
    if not sample_collector_class:
        sample_collector_class = SimpleListCollector
    self.sample_collector = sample_collector_class(
        self.worker.policy_map,
        self.clip_rewards,
        self.callbacks,
        self.multiple_episodes_in_batch,
        self.rollout_fragment_length,
        count_steps_by=count_steps_by)

get_data(self)

Called by self.next() to return the next batch of data.

Override this in child classes.

Returns:

Type Description
Union[SampleBatch, MultiAgentBatch]

The next batch of data.

Source code in ray/rllib/evaluation/sampler.py
@override(SamplerInput)
def get_data(self) -> SampleBatchType:
    if not self.is_alive():
        raise RuntimeError("Sampling thread has died")
    rollout = self.queue.get(timeout=600.0)

    # Propagate errors.
    if isinstance(rollout, BaseException):
        raise rollout

    return rollout

get_extra_batches(self)

Returns list of extra batches since the last call to this method.

The list will contain all SampleBatches or MultiAgentBatches that the user has provided thus-far. Users can add these "extra batches" to an episode by calling the episode's add_extra_batch([SampleBatchType]) method. This can be done from inside an overridden Policy.compute_actions_from_input_dict(..., episodes) or from a custom callback's on_episode_[start|step|end]() methods.

Returns:

Type Description
List[Union[SampleBatch, MultiAgentBatch]]

List of SampleBatches or MultiAgentBatches provided thus far by the user since the last call to this method.

Source code in ray/rllib/evaluation/sampler.py
@override(SamplerInput)
def get_extra_batches(self) -> List[SampleBatchType]:
    extra = []
    while True:
        try:
            extra.append(self.extra_batches.get_nowait())
        except queue.Empty:
            break
    return extra

get_metrics(self)

Returns list of episode metrics since the last call to this method.

The list will contain one RolloutMetrics object per completed episode.

Returns:

Type Description
List[ray.rllib.evaluation.metrics.RolloutMetrics]

List of RolloutMetrics objects, one per completed episode since the last call to this method.

Source code in ray/rllib/evaluation/sampler.py
@override(SamplerInput)
def get_metrics(self) -> List[RolloutMetrics]:
    completed = []
    while True:
        try:
            completed.append(self.metrics_queue.get_nowait()._replace(
                perf_stats=self.perf_stats.get()))
        except queue.Empty:
            break
    return completed

run(self)

Method representing the thread's activity.

You may override this method in a subclass. The standard run() method invokes the callable object passed to the object's constructor as the target argument, if any, with sequential and keyword arguments taken from the args and kwargs arguments, respectively.

Source code in ray/rllib/evaluation/sampler.py
@override(threading.Thread)
def run(self):
    try:
        self._run()
    except BaseException as e:
        self.queue.put(e)
        raise e

Utility Functions

ray.rllib.evaluation.postprocessing.compute_advantages(rollout, last_r, gamma=0.9, lambda_=1.0, use_gae=True, use_critic=True)

Given a rollout, compute its value targets and the advantages.

Parameters:

Name Type Description Default
rollout SampleBatch

SampleBatch of a single trajectory.

required
last_r float

Value estimation for last observation.

required
gamma float

Discount factor.

0.9
lambda_ float

Parameter for GAE.

1.0
use_gae bool

Whether to use Generalized Advantage Estimation (GAE).

True
use_critic bool

Whether to use critic (value estimates). Setting this to False will use 0 as baseline.

True

Returns:

Type Description

SampleBatch with experience from rollout and processed rewards.

Source code in ray/rllib/evaluation/postprocessing.py
@DeveloperAPI
def compute_advantages(rollout: SampleBatch,
                       last_r: float,
                       gamma: float = 0.9,
                       lambda_: float = 1.0,
                       use_gae: bool = True,
                       use_critic: bool = True):
    """Given a rollout, compute its value targets and the advantages.

    Args:
        rollout: SampleBatch of a single trajectory.
        last_r: Value estimation for last observation.
        gamma: Discount factor.
        lambda_: Parameter for GAE.
        use_gae: Using Generalized Advantage Estimation.
        use_critic: Whether to use critic (value estimates). Setting
            this to False will use 0 as baseline.

    Returns:
        SampleBatch with experience from rollout and processed rewards.
    """

    assert SampleBatch.VF_PREDS in rollout or not use_critic, \
        "use_critic=True but values not found"
    assert use_critic or not use_gae, \
        "Can't use gae without using a value function"

    if use_gae:
        vpred_t = np.concatenate(
            [rollout[SampleBatch.VF_PREDS],
             np.array([last_r])])
        delta_t = (
            rollout[SampleBatch.REWARDS] + gamma * vpred_t[1:] - vpred_t[:-1])
        # This formula for the advantage comes from:
        # "Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438
        rollout[Postprocessing.ADVANTAGES] = discount_cumsum(
            delta_t, gamma * lambda_)
        rollout[Postprocessing.VALUE_TARGETS] = (
            rollout[Postprocessing.ADVANTAGES] +
            rollout[SampleBatch.VF_PREDS]).astype(np.float32)
    else:
        rewards_plus_v = np.concatenate(
            [rollout[SampleBatch.REWARDS],
             np.array([last_r])])
        discounted_returns = discount_cumsum(rewards_plus_v,
                                             gamma)[:-1].astype(np.float32)

        if use_critic:
            rollout[Postprocessing.
                    ADVANTAGES] = discounted_returns - rollout[SampleBatch.
                                                               VF_PREDS]
            rollout[Postprocessing.VALUE_TARGETS] = discounted_returns
        else:
            rollout[Postprocessing.ADVANTAGES] = discounted_returns
            rollout[Postprocessing.VALUE_TARGETS] = np.zeros_like(
                rollout[Postprocessing.ADVANTAGES])

    rollout[Postprocessing.ADVANTAGES] = rollout[
        Postprocessing.ADVANTAGES].astype(np.float32)

    return rollout
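
A short GAE sketch with hypothetical value predictions and rewards:

import numpy as np
from ray.rllib.evaluation.postprocessing import Postprocessing, compute_advantages
from ray.rllib.policy.sample_batch import SampleBatch

rollout = SampleBatch({
    SampleBatch.REWARDS: np.array([1.0, 0.0, 1.0], dtype=np.float32),
    SampleBatch.VF_PREDS: np.array([0.5, 0.4, 0.6], dtype=np.float32),
})
# last_r is the value estimate for the state after the final step
# (0.0 here, e.g. because the episode terminated).
out = compute_advantages(rollout, last_r=0.0, gamma=0.99, lambda_=0.95)
print(out[Postprocessing.ADVANTAGES], out[Postprocessing.VALUE_TARGETS])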

ray.rllib.evaluation.metrics.collect_metrics(local_worker=None, remote_workers=None, to_be_collected=None, timeout_seconds=180)

Gathers episode metrics from RolloutWorker instances.

Source code in ray/rllib/evaluation/metrics.py
@DeveloperAPI
def collect_metrics(local_worker: Optional["RolloutWorker"] = None,
                    remote_workers: Optional[List[ActorHandle]] = None,
                    to_be_collected: Optional[List[ObjectRef]] = None,
                    timeout_seconds: int = 180) -> ResultDict:
    """Gathers episode metrics from RolloutWorker instances."""
    if remote_workers is None:
        remote_workers = []

    if to_be_collected is None:
        to_be_collected = []

    episodes, to_be_collected = collect_episodes(
        local_worker,
        remote_workers,
        to_be_collected,
        timeout_seconds=timeout_seconds)
    metrics = summarize_episodes(episodes, episodes)
    return metrics