
agents package

ray.rllib.agents.trainer.Trainer (Trainable)

An RLlib algorithm responsible for optimizing one or more Policies.

Trainers contain a WorkerSet under self.workers. A WorkerSet is normally composed of a single local worker (self.workers.local_worker()), used to compute and apply learning updates, and optionally one or more remote workers (self.workers.remote_workers()), used to generate environment samples in parallel.

Each worker (remote or local) contains a PolicyMap, which itself may contain either one policy for single-agent training or one or more policies for multi-agent training. Policies are synchronized automatically from time to time using ray.remote calls. The exact synchronization logic depends on the specific algorithm (Trainer) used, but this usually happens from the local worker to all remote workers and after each training update.

You can write your own Trainer sub-classes by using the rllib.agents.trainer_template.py::build_trainer() utility function. This allows you to provide a custom execution_plan. You can find the different built-in algorithms' execution plans in their respective main py files, e.g. rllib.agents.dqn.dqn.py or rllib.agents.impala.impala.py.
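
A rough sketch of that pattern (assuming the PG policy class and the execution operators are importable from the paths below, which may vary across RLlib versions):

from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.rollout_ops import ConcatBatches, ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep


def my_execution_plan(workers, config, **kwargs):
    # Collect rollouts from all workers, concatenate them into train batches
    # and apply one update per batch on the local worker's policy.
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    train_op = rollouts \
        .combine(ConcatBatches(min_batch_size=config["train_batch_size"])) \
        .for_each(TrainOneStep(workers))
    # Emit results in the usual RLlib metrics format.
    return StandardMetricsReporting(train_op, workers, config)


MyTrainer = build_trainer(
    name="MyTrainer",
    default_policy=PGTFPolicy,
    execution_plan=my_execution_plan,
)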

The most important API methods a Trainer exposes are train(), evaluate(), save() and restore(). Trainer objects retain internal model state between calls to train(), so you should create a new Trainer instance for each training session.
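
For example, a minimal training loop (a sketch assuming the built-in PPOTrainer and the CartPole-v0 gym env are available):

import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
trainer = PPOTrainer(env="CartPole-v0", config={"num_workers": 2})

# Internal state (weights, counters, etc.) is carried over from one
# train() call to the next.
for i in range(3):
    result = trainer.train()
    print(i, result["episode_reward_mean"])

# Persist and restore the full Trainer state.
checkpoint_path = trainer.save()
trainer.restore(checkpoint_path)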

__init__(self, config=None, env=None, logger_creator=None, remote_checkpoint_dir=None, sync_function_tpl=None) special

Initializes a Trainer instance.

Parameters:

config (Optional[dict], default: None)
    Algorithm-specific configuration dict.

env (Union[str, Any], default: None)
    Name of the environment to use (e.g. a gym-registered str), a full class path (e.g. "ray.rllib.examples.env.random_env.RandomEnv"), or an Env class directly. Note that this arg can also be specified via the "env" key in config.

logger_creator (Optional[Callable[[], ray.tune.logger.Logger]], default: None)
    Callable that creates a ray.tune.Logger object. If unspecified, a default logger is created.
Source code in ray/rllib/agents/trainer.py
@PublicAPI
def __init__(self,
             config: Optional[PartialTrainerConfigDict] = None,
             env: Optional[Union[str, EnvType]] = None,
             logger_creator: Optional[Callable[[], Logger]] = None,
             remote_checkpoint_dir: Optional[str] = None,
             sync_function_tpl: Optional[str] = None):
    """Initializes a Trainer instance.

    Args:
        config: Algorithm-specific configuration dict.
        env: Name of the environment to use (e.g. a gym-registered str),
            a full class path (e.g.
            "ray.rllib.examples.env.random_env.RandomEnv"), or an Env
            class directly. Note that this arg can also be specified via
            the "env" key in `config`.
        logger_creator: Callable that creates a ray.tune.Logger
            object. If unspecified, a default logger is created.
    """

    # User provided (partial) config (this may be w/o the default
    # Trainer's `COMMON_CONFIG` (see above)). Will get merged with
    # COMMON_CONFIG in self.setup().
    config = config or {}

    # Trainers allow env ids to be passed directly to the constructor.
    self._env_id = self._register_if_needed(
        env or config.get("env"), config)
    # The env creator callable, taking an EnvContext (config dict)
    # as arg and returning an RLlib supported Env type (e.g. a gym.Env).
    self.env_creator: Callable[[EnvContext], EnvType] = None

    # Placeholder for a local replay buffer instance.
    self.local_replay_buffer = None

    # Create a default logger creator if no logger_creator is specified
    if logger_creator is None:
        # Default logdir prefix containing the agent's name and the
        # env id.
        timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
        logdir_prefix = "{}_{}_{}".format(self._name, self._env_id,
                                          timestr)
        if not os.path.exists(DEFAULT_RESULTS_DIR):
            os.makedirs(DEFAULT_RESULTS_DIR)
        logdir = tempfile.mkdtemp(
            prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)

        # Allow users to more precisely configure the created logger
        # via "logger_config.type".
        if config.get(
                "logger_config") and "type" in config["logger_config"]:

            def default_logger_creator(config):
                """Creates a custom logger with the default prefix."""
                cfg = config["logger_config"].copy()
                cls = cfg.pop("type")
                # Provide default for logdir, in case the user does
                # not specify this in the "logger_config" dict.
                logdir_ = cfg.pop("logdir", logdir)
                return from_config(cls=cls, _args=[cfg], logdir=logdir_)

        # If no `type` given, use tune's UnifiedLogger as last resort.
        else:

            def default_logger_creator(config):
                """Creates a Unified logger with the default prefix."""
                return UnifiedLogger(config, logdir, loggers=None)

        logger_creator = default_logger_creator

    super().__init__(config, logger_creator, remote_checkpoint_dir,
                     sync_function_tpl)

add_policy(self, policy_id, policy_cls, *, observation_space=None, action_space=None, config=None, policy_mapping_fn=None, policies_to_train=None, evaluation_workers=True)

Adds a new policy to this Trainer.

Parameters:

policy_id (PolicyID, required)
    ID of the policy to add.

policy_cls (Type[Policy], required)
    The Policy class to use for constructing the new Policy.

observation_space (Optional[gym.spaces.Space], default: None)
    The observation space of the policy to add.

action_space (Optional[gym.spaces.Space], default: None)
    The action space of the policy to add.

config (Optional[PartialTrainerConfigDict], default: None)
    The config overrides for the policy to add.

policy_mapping_fn (Optional[Callable[[AgentID], PolicyID]], default: None)
    An optional (updated) policy mapping function to use from here on. Note that already ongoing episodes will not change their mapping but will use the old mapping till the end of the episode.

policies_to_train (Optional[List[PolicyID]], default: None)
    An optional list of policy IDs to be trained. If None, the existing list is kept in place. Policies whose IDs are not in the list will not be updated.

evaluation_workers (bool, default: True)
    Whether to add the new policy also to the evaluation WorkerSet.

Returns:

Policy
    The newly added policy (the copy that got added to the local worker).

Source code in ray/rllib/agents/trainer.py
@PublicAPI
def add_policy(
        self,
        policy_id: PolicyID,
        policy_cls: Type[Policy],
        *,
        observation_space: Optional[gym.spaces.Space] = None,
        action_space: Optional[gym.spaces.Space] = None,
        config: Optional[PartialTrainerConfigDict] = None,
        policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID],
                                             PolicyID]] = None,
        policies_to_train: Optional[List[PolicyID]] = None,
        evaluation_workers: bool = True,
) -> Policy:
    """Adds a new policy to this Trainer.

    Args:
        policy_id (PolicyID): ID of the policy to add.
        policy_cls (Type[Policy]): The Policy class to use for
            constructing the new Policy.
        observation_space (Optional[gym.spaces.Space]): The observation
            space of the policy to add.
        action_space (Optional[gym.spaces.Space]): The action space
            of the policy to add.
        config (Optional[PartialTrainerConfigDict]): The config overrides
            for the policy to add.
        policy_mapping_fn (Optional[Callable[[AgentID], PolicyID]]): An
            optional (updated) policy mapping function to use from here on.
            Note that already ongoing episodes will not change their
            mapping but will use the old mapping till the end of the
            episode.
        policies_to_train (Optional[List[PolicyID]]): An optional list of
            policy IDs to be trained. If None, will keep the existing list
            in place. Policies, whose IDs are not in the list will not be
            updated.
        evaluation_workers (bool): Whether to add the new policy also
            to the evaluation WorkerSet.

    Returns:
        Policy: The newly added policy (the copy that got added to the
            local worker).
    """

    def fn(worker: RolloutWorker):
        # `foreach_worker` function: Adds the policy to the worker (and
        # maybe changes its policy_mapping_fn - if provided here).
        worker.add_policy(
            policy_id=policy_id,
            policy_cls=policy_cls,
            observation_space=observation_space,
            action_space=action_space,
            config=config,
            policy_mapping_fn=policy_mapping_fn,
            policies_to_train=policies_to_train,
        )

    # Run foreach_worker fn on all workers (incl. evaluation workers).
    self.workers.foreach_worker(fn)
    if evaluation_workers and self.evaluation_workers is not None:
        self.evaluation_workers.foreach_worker(fn)

    # Return newly added policy (from the local rollout worker).
    return self.get_policy(policy_id)
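
A usage sketch (assuming a running multi-agent Trainer whose existing policy is registered under "main"; the exact policy_mapping_fn signature may differ between RLlib versions):

# Clone the class and spaces of the existing "main" policy into a new
# "opponent" policy and map agent 1 to it from the next episode on.
main_policy = trainer.get_policy("main")

new_policy = trainer.add_policy(
    policy_id="opponent",
    policy_cls=type(main_policy),
    observation_space=main_policy.observation_space,
    action_space=main_policy.action_space,
    policy_mapping_fn=lambda agent_id, episode, **kwargs: (
        "opponent" if agent_id == 1 else "main"),
    policies_to_train=["main", "opponent"],
)

# The returned object is the copy living in the local worker's PolicyMap.
assert new_policy is trainer.get_policy("opponent")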

cleanup(self)

Subclasses should override this for any cleanup on stop.

If any Ray actors are launched in the Trainable (i.e., with a RLlib trainer), be sure to kill the Ray actor process here.

You can kill a Ray actor by calling actor.__ray_terminate__.remote() on the actor.

.. versionadded:: 0.8.7

Source code in ray/rllib/agents/trainer.py
@override(Trainable)
def cleanup(self) -> None:
    # Stop all workers.
    if hasattr(self, "workers"):
        self.workers.stop()
    # Stop all optimizers.
    if hasattr(self, "optimizer") and self.optimizer:
        self.optimizer.stop()

collect_metrics(self, selected_workers=None)

Collects metrics from the remote workers of this agent.

This is the same data as returned by a call to train().

Source code in ray/rllib/agents/trainer.py
@DeveloperAPI
def collect_metrics(self,
                    selected_workers: List[ActorHandle] = None) -> dict:
    """Collects metrics from the remote workers of this agent.

    This is the same data as returned by a call to train().
    """
    return self.optimizer.collect_metrics(
        self.config["collect_metrics_timeout"],
        min_history=self.config["metrics_smoothing_episodes"],
        selected_workers=selected_workers)

compute_actions(self, observations, state=None, *, prev_action=None, prev_reward=None, info=None, policy_id='default_policy', full_fetch=False, explore=None, timestep=None, episodes=None, unsquash_actions=None, clip_actions=None, normalize_actions=None, **kwargs)

Computes an action for the specified policy on the local Worker.

Note that you can also access the policy object through self.get_policy(policy_id) and call compute_actions() on it directly.

Parameters:

observations (TensorStructType, required)
    Observations from the environment, as a dict mapping agent ids (or other keys) to individual observations.

state (Optional[List[Union[Any, dict, tuple]]], default: None)
    RNN hidden state, if any. If state is not None, then all of compute_single_action(...) is returned (computed action, rnn state(s), logits dictionary). Otherwise compute_single_action(...)[0] is returned (computed action).

prev_action (Union[Any, dict, tuple], default: None)
    Previous action value, if any.

prev_reward (Union[Any, dict, tuple], default: None)
    Previous reward, if any.

info (Optional[dict], default: None)
    Env info dict, if any.

policy_id (str, default: 'default_policy')
    Policy to query (only applies to multi-agent).

full_fetch (bool, default: False)
    Whether to return extra action fetch results. This is always set to True if RNN state is specified.

explore (Optional[bool], default: None)
    Whether to pick an exploitation or exploration action (default: None -> use self.config["explore"]).

timestep (Optional[int], default: None)
    The current (sampling) time step.

episodes (Optional[List[ray.rllib.evaluation.episode.Episode]], default: None)
    This provides access to all of the internal episodes' state, which may be useful for model-based or multi-agent algorithms.

unsquash_actions (Optional[bool], default: None)
    Should actions be unsquashed according to the env's/Policy's action space? If None, use self.config["normalize_actions"].

clip_actions (Optional[bool], default: None)
    Should actions be clipped according to the env's/Policy's action space? If None, use self.config["clip_actions"].

Returns:

any or tuple
    The computed action if full_fetch=False, or the full output of policy.compute_actions() if full_fetch=True or we have an RNN-based Policy.

Source code in ray/rllib/agents/trainer.py
@PublicAPI
def compute_actions(
        self,
        observations: TensorStructType,
        state: Optional[List[TensorStructType]] = None,
        *,
        prev_action: Optional[TensorStructType] = None,
        prev_reward: Optional[TensorStructType] = None,
        info: Optional[EnvInfoDict] = None,
        policy_id: PolicyID = DEFAULT_POLICY_ID,
        full_fetch: bool = False,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        episodes: Optional[List[Episode]] = None,
        unsquash_actions: Optional[bool] = None,
        clip_actions: Optional[bool] = None,
        # Deprecated.
        normalize_actions=None,
        **kwargs,
):
    """Computes an action for the specified policy on the local Worker.

    Note that you can also access the policy object through
    self.get_policy(policy_id) and call compute_actions() on it directly.

    Args:
        observation: Observation from the environment.
        state: RNN hidden state, if any. If state is not None,
            then all of compute_single_action(...) is returned
            (computed action, rnn state(s), logits dictionary).
            Otherwise compute_single_action(...)[0] is returned
            (computed action).
        prev_action: Previous action value, if any.
        prev_reward: Previous reward, if any.
        info: Env info dict, if any.
        policy_id: Policy to query (only applies to multi-agent).
        full_fetch: Whether to return extra action fetch results.
            This is always set to True if RNN state is specified.
        explore: Whether to pick an exploitation or exploration
            action (default: None -> use self.config["explore"]).
        timestep: The current (sampling) time step.
        episodes: This provides access to all of the internal episodes'
            state, which may be useful for model-based or multi-agent
            algorithms.
        unsquash_actions: Should actions be unsquashed according
            to the env's/Policy's action space? If None, use
            self.config["normalize_actions"].
        clip_actions: Should actions be clipped according to the
            env's/Policy's action space? If None, use
            self.config["clip_actions"].

    Keyword Args:
        kwargs: forward compatibility placeholder

    Returns:
        any: The computed action if full_fetch=False, or
        tuple: The full output of policy.compute_actions() if
        full_fetch=True or we have an RNN-based Policy.
    """
    if normalize_actions is not None:
        deprecation_warning(
            old="Trainer.compute_actions(`normalize_actions`=...)",
            new="Trainer.compute_actions(`unsquash_actions`=...)",
            error=False)
        unsquash_actions = normalize_actions

    # Preprocess obs and states.
    state_defined = state is not None
    policy = self.get_policy(policy_id)
    filtered_obs, filtered_state = [], []
    for agent_id, ob in observations.items():
        worker = self.workers.local_worker()
        preprocessed = worker.preprocessors[policy_id].transform(ob)
        filtered = worker.filters[policy_id](preprocessed, update=False)
        filtered_obs.append(filtered)
        if state is None:
            continue
        elif agent_id in state:
            filtered_state.append(state[agent_id])
        else:
            filtered_state.append(policy.get_initial_state())

    # Batch obs and states
    obs_batch = np.stack(filtered_obs)
    if state is None:
        state = []
    else:
        state = list(zip(*filtered_state))
        state = [np.stack(s) for s in state]

    input_dict = {SampleBatch.OBS: obs_batch}
    if prev_action:
        input_dict[SampleBatch.PREV_ACTIONS] = prev_action
    if prev_reward:
        input_dict[SampleBatch.PREV_REWARDS] = prev_reward
    if info:
        input_dict[SampleBatch.INFOS] = info
    for i, s in enumerate(state):
        input_dict[f"state_in_{i}"] = s

    # Batch compute actions
    actions, states, infos = policy.compute_actions_from_input_dict(
        input_dict=input_dict,
        explore=explore,
        timestep=timestep,
        episodes=episodes,
    )

    # Unbatch actions for the environment into a multi-agent dict.
    single_actions = space_utils.unbatch(actions)
    actions = {}
    for key, a in zip(observations, single_actions):
        # If we work in normalized action space (normalize_actions=True),
        # we re-translate here into the env's action space.
        if unsquash_actions:
            a = space_utils.unsquash_action(a, policy.action_space_struct)
        # Clip, according to env's action space.
        elif clip_actions:
            a = space_utils.clip_action(a, policy.action_space_struct)
        actions[key] = a

    # Unbatch states into a multi-agent dict.
    unbatched_states = {}
    for idx, agent_id in enumerate(observations):
        unbatched_states[agent_id] = [s[idx] for s in states]

    # Return only actions or full tuple
    if state_defined or full_fetch:
        return actions, unbatched_states, infos
    else:
        return actions
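
A small sketch of batched inference via this method (env_obs_0 and env_obs_1 are hypothetical per-agent observations):

# Compute actions for several (unbatched) observations at once on the
# local worker. The keys of the input dict are preserved in the output.
obs_dict = {
    "agent_0": env_obs_0,
    "agent_1": env_obs_1,
}
actions = trainer.compute_actions(obs_dict, policy_id="default_policy")
# -> {"agent_0": <action>, "agent_1": <action>}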

compute_single_action(self, observation=None, state=None, *, prev_action=None, prev_reward=None, info=None, input_dict=None, policy_id='default_policy', full_fetch=False, explore=None, timestep=None, episode=None, unsquash_action=None, clip_action=None, unsquash_actions=-1, clip_actions=-1, **kwargs)

Computes an action for the specified policy on the local worker.

Note that you can also access the policy object through self.get_policy(policy_id) and call compute_single_action() on it directly.

Parameters:

observation (Union[Any, dict, tuple], default: None)
    Single (unbatched) observation from the environment.

state (Optional[List[Union[Any, dict, tuple]]], default: None)
    List of all RNN hidden (single, unbatched) state tensors.

prev_action (Union[Any, dict, tuple], default: None)
    Single (unbatched) previous action value.

prev_reward (Optional[float], default: None)
    Single (unbatched) previous reward value.

info (Optional[dict], default: None)
    Env info dict, if any.

input_dict (Optional[ray.rllib.policy.sample_batch.SampleBatch], default: None)
    An optional SampleBatch that holds all the values for: obs, state, prev_action, and prev_reward, plus maybe custom defined views of the current env trajectory. Note that only one of obs or input_dict must be non-None.

policy_id (str, default: 'default_policy')
    Policy to query (only applies to multi-agent).

full_fetch (bool, default: False)
    Whether to return extra action fetch results. This is always set to True if state is specified.

explore (Optional[bool], default: None)
    Whether to apply exploration to the action. Default: None -> use self.config["explore"].

timestep (Optional[int], default: None)
    The current (sampling) time step.

episode (Optional[ray.rllib.evaluation.episode.Episode], default: None)
    This provides access to all of the internal episodes' state, which may be useful for model-based or multi-agent algorithms.

unsquash_action (Optional[bool], default: None)
    Should actions be unsquashed according to the env's/Policy's action space? If None, use the value of self.config["normalize_actions"].

clip_action (Optional[bool], default: None)
    Should actions be clipped according to the env's/Policy's action space? If None, use the value of self.config["clip_actions"].

Returns:

Union[Any, dict, tuple, Tuple[Union[Any, dict, tuple], List[Any], Dict[str, Any]]]
    The computed action if full_fetch=False, or a tuple of (action, RNN state outs, extra action fetches), i.e. the full output of policy.compute_single_action(), if full_fetch=True or we have an RNN-based Policy.

Exceptions:

KeyError
    If the policy_id cannot be found in this Trainer's local worker.

Source code in ray/rllib/agents/trainer.py
@PublicAPI
def compute_single_action(
        self,
        observation: Optional[TensorStructType] = None,
        state: Optional[List[TensorStructType]] = None,
        *,
        prev_action: Optional[TensorStructType] = None,
        prev_reward: Optional[float] = None,
        info: Optional[EnvInfoDict] = None,
        input_dict: Optional[SampleBatch] = None,
        policy_id: PolicyID = DEFAULT_POLICY_ID,
        full_fetch: bool = False,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        episode: Optional[Episode] = None,
        unsquash_action: Optional[bool] = None,
        clip_action: Optional[bool] = None,

        # Deprecated args.
        unsquash_actions=DEPRECATED_VALUE,
        clip_actions=DEPRECATED_VALUE,

        # Kwargs placeholder for future compatibility.
        **kwargs,
) -> Union[TensorStructType, Tuple[TensorStructType, List[TensorType],
                                   Dict[str, TensorType]]]:
    """Computes an action for the specified policy on the local worker.

    Note that you can also access the policy object through
    self.get_policy(policy_id) and call compute_single_action() on it
    directly.

    Args:
        observation: Single (unbatched) observation from the
            environment.
        state: List of all RNN hidden (single, unbatched) state tensors.
        prev_action: Single (unbatched) previous action value.
        prev_reward: Single (unbatched) previous reward value.
        info: Env info dict, if any.
        input_dict: An optional SampleBatch that holds all the values
            for: obs, state, prev_action, and prev_reward, plus maybe
            custom defined views of the current env trajectory. Note
            that only one of `obs` or `input_dict` must be non-None.
        policy_id: Policy to query (only applies to multi-agent).
            Default: "default_policy".
        full_fetch: Whether to return extra action fetch results.
            This is always set to True if `state` is specified.
        explore: Whether to apply exploration to the action.
            Default: None -> use self.config["explore"].
        timestep: The current (sampling) time step.
        episode: This provides access to all of the internal episodes'
            state, which may be useful for model-based or multi-agent
            algorithms.
        unsquash_action: Should actions be unsquashed according to the
            env's/Policy's action space? If None, use the value of
            self.config["normalize_actions"].
        clip_action: Should actions be clipped according to the
            env's/Policy's action space? If None, use the value of
            self.config["clip_actions"].

    Keyword Args:
        kwargs: forward compatibility placeholder

    Returns:
        The computed action if full_fetch=False, or a tuple of (action,
            RNN state outs, extra action fetches), i.e. the full output of
            policy.compute_single_action(), if full_fetch=True or we have
            an RNN-based Policy.

    Raises:
        KeyError: If the `policy_id` cannot be found in this Trainer's
            local worker.
    """
    if clip_actions != DEPRECATED_VALUE:
        deprecation_warning(
            old="Trainer.compute_single_action(`clip_actions`=...)",
            new="Trainer.compute_single_action(`clip_action`=...)",
            error=False)
        clip_action = clip_actions
    if unsquash_actions != DEPRECATED_VALUE:
        deprecation_warning(
            old="Trainer.compute_single_action(`unsquash_actions`=...)",
            new="Trainer.compute_single_action(`unsquash_action`=...)",
            error=False)
        unsquash_action = unsquash_actions

    # User provided an input-dict: Assert that `obs`, `prev_a|r`, `state`
    # are all None.
    err_msg = "Provide either `input_dict` OR [`observation`, ...] as " \
              "args to Trainer.compute_single_action!"
    if input_dict is not None:
        assert observation is None and prev_action is None and \
               prev_reward is None and state is None, err_msg
        observation = input_dict[SampleBatch.OBS]
    else:
        assert observation is not None, err_msg

    # Get the policy to compute the action for (in the multi-agent case,
    # Trainer may hold >1 policies).
    policy = self.get_policy(policy_id)
    if policy is None:
        raise KeyError(
            f"PolicyID '{policy_id}' not found in PolicyMap of the "
            f"Trainer's local worker!")
    local_worker = self.workers.local_worker()

    # Check the preprocessor and preprocess, if necessary.
    pp = local_worker.preprocessors[policy_id]
    if pp and type(pp).__name__ != "NoPreprocessor":
        observation = pp.transform(observation)
    observation = local_worker.filters[policy_id](
        observation, update=False)

    # Input-dict.
    if input_dict is not None:
        input_dict[SampleBatch.OBS] = observation
        action, state, extra = policy.compute_single_action(
            input_dict=input_dict,
            explore=explore,
            timestep=timestep,
            episode=episode,
        )
    # Individual args.
    else:
        action, state, extra = policy.compute_single_action(
            obs=observation,
            state=state,
            prev_action=prev_action,
            prev_reward=prev_reward,
            info=info,
            explore=explore,
            timestep=timestep,
            episode=episode,
        )

    # If we work in normalized action space (normalize_actions=True),
    # we re-translate here into the env's action space.
    if unsquash_action:
        action = space_utils.unsquash_action(action,
                                             policy.action_space_struct)
    # Clip, according to env's action space.
    elif clip_action:
        action = space_utils.clip_action(action,
                                         policy.action_space_struct)

    # Return 3-Tuple: Action, states, and extra-action fetches.
    if state or full_fetch:
        return action, state, extra
    # Ensure backward compatibility.
    else:
        return action
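
For example, a simple greedy rollout with a trained Trainer (a sketch assuming CartPole-v0 and the classic gym step() API):

import gym

env = gym.make("CartPole-v0")
obs = env.reset()
done = False
total_reward = 0.0
while not done:
    # Deterministic (exploration switched off) action from the default policy.
    action = trainer.compute_single_action(obs, explore=False)
    obs, reward, done, info = env.step(action)
    total_reward += reward
print("Episode reward:", total_reward)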

default_resource_request(config) classmethod

Provides a static resource requirement for the given configuration.

This can be overridden by sub-classes to set the correct trial resource allocation, so the user does not need to.

.. code-block:: python

@classmethod
def default_resource_request(cls, config):
    return PlacementGroupFactory([{"CPU": 1}, {"CPU": 1}])

Parameters:

config (Dict[str, Any], required)
    The Trainable's config dict.

Returns:

Union[Resources, PlacementGroupFactory]
    A Resources object or PlacementGroupFactory consumed by Tune for queueing.

Source code in ray/rllib/agents/trainer.py
@classmethod
@override(Trainable)
def default_resource_request(
        cls, config: PartialTrainerConfigDict) -> \
        Union[Resources, PlacementGroupFactory]:

    # Default logic for RLlib algorithms (Trainers):
    # Create one bundle per individual worker (local or remote).
    # Use `num_cpus_for_driver` and `num_gpus` for the local worker and
    # `num_cpus_per_worker` and `num_gpus_per_worker` for the remote
    # workers to determine their CPU/GPU resource needs.

    # Convenience config handles.
    cf = dict(cls.get_default_config(), **config)
    eval_cf = cf["evaluation_config"]

    # TODO(ekl): add custom resources here once tune supports them
    # Return PlacementGroupFactory containing all needed resources
    # (already properly defined as device bundles).
    return PlacementGroupFactory(
        bundles=[{
            # Local worker.
            "CPU": cf["num_cpus_for_driver"],
            "GPU": 0 if cf["_fake_gpus"] else cf["num_gpus"],
        }] + [
            {
                # RolloutWorkers.
                "CPU": cf["num_cpus_per_worker"],
                "GPU": cf["num_gpus_per_worker"],
            } for _ in range(cf["num_workers"])
        ] + ([
            {
                # Evaluation workers.
                # Note: The local eval worker is located on the driver CPU.
                "CPU": eval_cf.get("num_cpus_per_worker",
                                   cf["num_cpus_per_worker"]),
                "GPU": eval_cf.get("num_gpus_per_worker",
                                   cf["num_gpus_per_worker"]),
            } for _ in range(cf["evaluation_num_workers"])
        ] if cf["evaluation_interval"] else []),
        strategy=config.get("placement_strategy", "PACK"))
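
Sub-classes can override this if the default bundle logic above does not fit. A sketch (assuming one CPU for the local worker plus one CPU per configured rollout worker):

from ray.rllib.agents.trainer import Trainer
from ray.tune.utils.placement_groups import PlacementGroupFactory


class MyTrainer(Trainer):
    @classmethod
    def default_resource_request(cls, config):
        cf = dict(cls.get_default_config(), **config)
        # One bundle for the local worker (driver), one per rollout worker.
        return PlacementGroupFactory(
            [{"CPU": 1}] + [{"CPU": 1}] * cf["num_workers"])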

evaluate(self, episodes_left_fn=None)

Evaluates current policy under evaluation_config settings.

Note that this default implementation does not do anything beyond merging evaluation_config with the normal trainer config.

Parameters:

episodes_left_fn (Optional[Callable[[int], int]], default: None)
    An optional callable taking the already run num episodes as only arg and returning the number of episodes left to run. It's used to find out whether evaluation should continue.
Source code in ray/rllib/agents/trainer.py
@PublicAPI
def evaluate(self, episodes_left_fn: Optional[Callable[[int], int]] = None
             ) -> dict:
    """Evaluates current policy under `evaluation_config` settings.

    Note that this default implementation does not do anything beyond
    merging evaluation_config with the normal trainer config.

    Args:
        episodes_left_fn: An optional callable taking the already run
            num episodes as only arg and returning the number of
            episodes left to run. It's used to find out whether
            evaluation should continue.
    """
    # In case we are evaluating (in a thread) parallel to training,
    # we may have to re-enable eager mode here (gets disabled in the
    # thread).
    if self.config.get("framework") in ["tf2", "tfe"] and \
            not tf.executing_eagerly():
        tf1.enable_eager_execution()

    # Call the `_before_evaluate` hook.
    self._before_evaluate()

    # Sync weights to the evaluation WorkerSet.
    if self.evaluation_workers is not None:
        self._sync_weights_to_workers(worker_set=self.evaluation_workers)
        self._sync_filters_if_needed(self.evaluation_workers)

    if self.config["custom_eval_function"]:
        logger.info("Running custom eval function {}".format(
            self.config["custom_eval_function"]))
        metrics = self.config["custom_eval_function"](
            self, self.evaluation_workers)
        if not metrics or not isinstance(metrics, dict):
            raise ValueError("Custom eval function must return "
                             "dict of metrics, got {}.".format(metrics))
    else:
        # How many episodes do we need to run?
        # In "auto" mode (only for parallel eval + training): Run one
        # episode per eval worker.
        num_episodes = self.config["evaluation_num_episodes"] if \
            self.config["evaluation_num_episodes"] != "auto" else \
            (self.config["evaluation_num_workers"] or 1)

        # Default done-function returns True, whenever num episodes
        # have been completed.
        if episodes_left_fn is None:

            def episodes_left_fn(num_episodes_done):
                return num_episodes - num_episodes_done

        logger.info(
            f"Evaluating current policy for {num_episodes} episodes.")

        metrics = None
        # No evaluation worker set ->
        # Do evaluation using the local worker. Expect error due to the
        # local worker not having an env.
        if self.evaluation_workers is None:
            try:
                for _ in range(num_episodes):
                    self.workers.local_worker().sample()
                metrics = collect_metrics(self.workers.local_worker())
            except ValueError as e:
                if "RolloutWorker has no `input_reader` object" in \
                        e.args[0]:
                    raise ValueError(
                        "Cannot evaluate w/o an evaluation worker set in "
                        "the Trainer or w/o an env on the local worker!\n"
                        "Try one of the following:\n1) Set "
                        "`evaluation_interval` >= 0 to force creating a "
                        "separate evaluation worker set.\n2) Set "
                        "`create_env_on_driver=True` to force the local "
                        "(non-eval) worker to have an environment to "
                        "evaluate on.")
                else:
                    raise e

        # Evaluation worker set only has local worker.
        elif self.config["evaluation_num_workers"] == 0:
            for _ in range(num_episodes):
                self.evaluation_workers.local_worker().sample()

        # Evaluation worker set has n remote workers.
        else:
            # How many episodes have we run (across all eval workers)?
            num_episodes_done = 0
            round_ = 0
            while True:
                episodes_left_to_do = episodes_left_fn(num_episodes_done)
                if episodes_left_to_do <= 0:
                    break

                round_ += 1
                batches = ray.get([
                    w.sample.remote() for i, w in enumerate(
                        self.evaluation_workers.remote_workers())
                    if i < episodes_left_to_do
                ])
                # Per our config for the evaluation workers
                # (`rollout_fragment_length=1` and
                # `batch_mode=complete_episode`), we know that we'll have
                # exactly one episode per returned batch.
                num_episodes_done += len(batches)
                logger.info(
                    f"Ran round {round_} of parallel evaluation "
                    f"({num_episodes_done}/{num_episodes} episodes done)")
        if metrics is None:
            metrics = collect_metrics(
                self.evaluation_workers.local_worker(),
                self.evaluation_workers.remote_workers())
    return {"evaluation": metrics}

export_policy_checkpoint(self, export_dir, filename_prefix='model', policy_id='default_policy')

Exports policy model checkpoint to a local directory.

Parameters:

export_dir (str, required)
    Writable local directory.

filename_prefix (str, default: 'model')
    File name prefix of checkpoint files.

policy_id (str, default: 'default_policy')
    Optional policy id to export.

Examples:

>>> trainer = MyTrainer()
>>> for _ in range(10):
>>>     trainer.train()
>>> trainer.export_policy_checkpoint("/tmp/export_dir")
Source code in ray/rllib/agents/trainer.py
@DeveloperAPI
def export_policy_checkpoint(
        self,
        export_dir: str,
        filename_prefix: str = "model",
        policy_id: PolicyID = DEFAULT_POLICY_ID,
) -> None:
    """Exports policy model checkpoint to a local directory.

    Args:
        export_dir: Writable local directory.
        filename_prefix: file name prefix of checkpoint files.
        policy_id: Optional policy id to export.

    Example:
        >>> trainer = MyTrainer()
        >>> for _ in range(10):
        >>>     trainer.train()
        >>> trainer.export_policy_checkpoint("/tmp/export_dir")
    """
    self.get_policy(policy_id).export_checkpoint(export_dir,
                                                 filename_prefix)

export_policy_model(self, export_dir, policy_id='default_policy', onnx=None)

Exports policy model with given policy_id to a local directory.

Parameters:

export_dir (str, required)
    Writable local directory.

policy_id (str, default: 'default_policy')
    Optional policy id to export.

onnx (Optional[int], default: None)
    If given, will export model in ONNX format. The value of this parameter sets the ONNX OpSet version to use. If None, the output format will be DL framework specific.

Examples:

>>> trainer = MyTrainer()
>>> for _ in range(10):
>>>     trainer.train()
>>> trainer.export_policy_model("/tmp/dir")
>>> trainer.export_policy_model("/tmp/dir/onnx", onnx=1)
Source code in ray/rllib/agents/trainer.py
@DeveloperAPI
def export_policy_model(self,
                        export_dir: str,
                        policy_id: PolicyID = DEFAULT_POLICY_ID,
                        onnx: Optional[int] = None) -> None:
    """Exports policy model with given policy_id to a local directory.

    Args:
        export_dir: Writable local directory.
        policy_id: Optional policy id to export.
        onnx: If given, will export model in ONNX format. The
            value of this parameter set the ONNX OpSet version to use.
            If None, the output format will be DL framework specific.

    Example:
        >>> trainer = MyTrainer()
        >>> for _ in range(10):
        >>>     trainer.train()
        >>> trainer.export_policy_model("/tmp/dir")
        >>> trainer.export_policy_model("/tmp/dir/onnx", onnx=1)
    """
    self.get_policy(policy_id).export_model(export_dir, onnx)

get_default_policy_class(self, config)

Returns a default Policy class to use, given a config.

This class will be used inside RolloutWorkers' PolicyMaps in case the policy class is not provided by the user in any single- or multi-agent PolicySpec.

This method is experimental and currently only used if the Trainer class was not created using the build_trainer utility and if the Trainer sub-class does not override _init() and create its own WorkerSet in _init().

Source code in ray/rllib/agents/trainer.py
@ExperimentalAPI
def get_default_policy_class(self, config: PartialTrainerConfigDict):
    """Returns a default Policy class to use, given a config.

    This class will be used inside RolloutWorkers' PolicyMaps in case
    the policy class is not provided by the user in any single- or
    multi-agent PolicySpec.

    This method is experimental and currently only used, iff the Trainer
    class was not created using the `build_trainer` utility and if
    the Trainer sub-class does not override `_init()` and create it's
    own WorkerSet in `_init()`.
    """
    return getattr(self, "_policy_class", None)

get_policy(self, policy_id='default_policy')

Return policy for the specified id, or None.

Parameters:

policy_id (str, default: 'default_policy')
    ID of the policy to return.
Source code in ray/rllib/agents/trainer.py
@PublicAPI
def get_policy(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> Policy:
    """Return policy for the specified id, or None.

    Args:
        policy_id: ID of the policy to return.
    """
    return self.workers.local_worker().get_policy(policy_id)

get_weights(self, policies=None)

Return a dictionary of policy ids to weights.

Parameters:

policies (Optional[List[str]], default: None)
    Optional list of policies to return weights for, or None for all policies.
Source code in ray/rllib/agents/trainer.py
@PublicAPI
def get_weights(self, policies: Optional[List[PolicyID]] = None) -> dict:
    """Return a dictionary of policy ids to weights.

    Args:
        policies: Optional list of policies to return weights for,
            or None for all policies.
    """
    return self.workers.local_worker().get_weights(policies)
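
A common pattern is to copy weights from one Trainer into another via get_weights() / set_weights(). A sketch (trainer_a and trainer_b are assumed to be two Trainers with identically structured policies):

# Sync all policy weights from trainer_a into trainer_b.
weights = trainer_a.get_weights()          # {policy_id: weights}
trainer_b.set_weights(weights)

# Or restrict the copy to a single policy.
default_only = trainer_a.get_weights(["default_policy"])
trainer_b.set_weights(default_only)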

import_model(self, import_file)

Imports a model from import_file.

Note: Currently, only h5 files are supported.

Parameters:

import_file (str, required)
    The file to import the model from.

Returns:

A dict that maps ExportFormats to successfully exported models.

Source code in ray/rllib/agents/trainer.py
def import_model(self, import_file: str):
    """Imports a model from import_file.

    Note: Currently, only h5 files are supported.

    Args:
        import_file (str): The file to import the model from.

    Returns:
        A dict that maps ExportFormats to successfully exported models.
    """
    # Check for existence.
    if not os.path.exists(import_file):
        raise FileNotFoundError(
            "`import_file` '{}' does not exist! Can't import Model.".
            format(import_file))
    # Get the format of the given file.
    import_format = "h5"  # TODO(sven): Support checkpoint loading.

    ExportFormat.validate([import_format])
    if import_format != ExportFormat.H5:
        raise NotImplementedError
    else:
        return self.import_policy_model_from_h5(import_file)

import_policy_model_from_h5(self, import_file, policy_id='default_policy')

Imports a policy's model with given policy_id from a local h5 file.

Parameters:

import_file (str, required)
    The h5 file to import from.

policy_id (str, default: 'default_policy')
    Optional policy id to import into.

Examples:

>>> trainer = MyTrainer()
>>> trainer.import_policy_model_from_h5("/tmp/weights.h5")
>>> for _ in range(10):
>>>     trainer.train()
Source code in ray/rllib/agents/trainer.py
@DeveloperAPI
def import_policy_model_from_h5(
        self,
        import_file: str,
        policy_id: PolicyID = DEFAULT_POLICY_ID,
) -> None:
    """Imports a policy's model with given policy_id from a local h5 file.

    Args:
        import_file: The h5 file to import from.
        policy_id: Optional policy id to import into.

    Example:
        >>> trainer = MyTrainer()
        >>> trainer.import_policy_model_from_h5("/tmp/weights.h5")
        >>> for _ in range(10):
        >>>     trainer.train()
    """
    self.get_policy(policy_id).import_model_from_h5(import_file)
    # Sync new weights to remote workers.
    self._sync_weights_to_workers(worker_set=self.workers)

load_checkpoint(self, checkpoint_path)

Subclasses should override this to implement restore().

Warning: In this method, do not rely on absolute paths. The absolute path of the checkpoint_dir used in Trainable.save_checkpoint may be changed.

If Trainable.save_checkpoint returned a prefixed string, the prefix of the checkpoint string returned by Trainable.save_checkpoint may be changed. This is because trial pausing depends on temporary directories.

The directory structure under the checkpoint_dir provided to Trainable.save_checkpoint is preserved.

See the example below.

.. code-block:: python

class Example(Trainable):
    def save_checkpoint(self, checkpoint_path):
        print(checkpoint_path)
        return os.path.join(checkpoint_path, "my/check/point")

    def load_checkpoint(self, checkpoint):
        print(checkpoint)

>>> trainer = Example()
>>> obj = trainer.save_to_object()  # This is used when PAUSED.
<logdir>/tmpc8k_c_6hsave_to_object/checkpoint_0/my/check/point
>>> trainer.restore_from_object(obj)  # Note the different prefix.
<logdir>/tmpb87b5axfrestore_from_object/checkpoint_0/my/check/point

.. versionadded:: 0.8.7

Parameters:

checkpoint (str or dict, required)
    If dict, the return value is as returned by save_checkpoint. If a string, then it is a checkpoint path that may have a different prefix than that returned by save_checkpoint. The directory structure underneath the checkpoint_dir of save_checkpoint is preserved.
Source code in ray/rllib/agents/trainer.py
@override(Trainable)
def load_checkpoint(self, checkpoint_path: str) -> None:
    extra_data = pickle.load(open(checkpoint_path, "rb"))
    self.__setstate__(extra_data)

log_result(self, result)

Subclasses can optionally override this to customize logging.

The logging here is done on the worker process rather than the driver. You may want to turn off driver logging via the loggers parameter in tune.run when overriding this function.

.. versionadded:: 0.8.7

Parameters:

result (dict, required)
    Training result returned by step().
Source code in ray/rllib/agents/trainer.py
@override(Trainable)
def log_result(self, result: ResultDict) -> None:
    # Log after the callback is invoked, so that the user has a chance
    # to mutate the result.
    self.callbacks.on_train_result(trainer=self, result=result)
    # Then log according to Trainable's logging logic.
    Trainable.log_result(self, result)

remove_policy(self, policy_id='default_policy', *, policy_mapping_fn=None, policies_to_train=None, evaluation_workers=True)

Removes a policy from this Trainer.

Parameters:

policy_id (Optional[PolicyID], default: 'default_policy')
    ID of the policy to be removed.

policy_mapping_fn (Optional[Callable[[AgentID], PolicyID]], default: None)
    An optional (updated) policy mapping function to use from here on. Note that already ongoing episodes will not change their mapping but will use the old mapping till the end of the episode.

policies_to_train (Optional[List[PolicyID]], default: None)
    An optional list of policy IDs to be trained. If None, the existing list is kept in place. Policies whose IDs are not in the list will not be updated.

evaluation_workers (bool, default: True)
    Whether to also remove the policy from the evaluation WorkerSet.
Source code in ray/rllib/agents/trainer.py
@PublicAPI
def remove_policy(
        self,
        policy_id: PolicyID = DEFAULT_POLICY_ID,
        *,
        policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
        policies_to_train: Optional[List[PolicyID]] = None,
        evaluation_workers: bool = True,
) -> None:
    """Removes a new policy from this Trainer.

    Args:
        policy_id (Optional[PolicyID]): ID of the policy to be removed.
        policy_mapping_fn (Optional[Callable[[AgentID], PolicyID]]): An
            optional (updated) policy mapping function to use from here on.
            Note that already ongoing episodes will not change their
            mapping but will use the old mapping till the end of the
            episode.
        policies_to_train (Optional[List[PolicyID]]): An optional list of
            policy IDs to be trained. If None, will keep the existing list
            in place. Policies, whose IDs are not in the list will not be
            updated.
        evaluation_workers (bool): Whether to also remove the policy from
            the evaluation WorkerSet.
    """

    def fn(worker):
        worker.remove_policy(
            policy_id=policy_id,
            policy_mapping_fn=policy_mapping_fn,
            policies_to_train=policies_to_train,
        )

    self.workers.foreach_worker(fn)
    if evaluation_workers and self.evaluation_workers is not None:
        self.evaluation_workers.foreach_worker(fn)
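
For example (a sketch mirroring the add_policy() example above, with hypothetical policy IDs "main" and "opponent"):

# Stop using and training the "opponent" policy; map all agents back to
# "main" from the next episode on.
trainer.remove_policy(
    "opponent",
    policy_mapping_fn=lambda agent_id, episode, **kwargs: "main",
    policies_to_train=["main"],
)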

resource_help(config) classmethod

Returns a help string for configuring this trainable's resources.

Parameters:

config (dict, required)
    The Trainer's config dict.
Source code in ray/rllib/agents/trainer.py
@classmethod
@override(Trainable)
def resource_help(cls, config: TrainerConfigDict) -> str:
    return ("\n\nYou can adjust the resource requests of RLlib agents by "
            "setting `num_workers`, `num_gpus`, and other configs. See "
            "the DEFAULT_CONFIG defined by each agent for more info.\n\n"
            "The config of this agent is: {}".format(config))

save_checkpoint(self, checkpoint_dir)

Subclasses should override this to implement save().

Warning: Do not rely on absolute paths in the implementation of Trainable.save_checkpoint and Trainable.load_checkpoint.

Use validate_save_restore to catch Trainable.save_checkpoint/ Trainable.load_checkpoint errors before execution.

from ray.tune.utils import validate_save_restore
validate_save_restore(MyTrainableClass)
validate_save_restore(MyTrainableClass, use_object_store=True)

.. versionadded:: 0.8.7

Parameters:

tmp_checkpoint_dir (str, required)
    The directory where the checkpoint file must be stored. In a Tune run, if the trial is paused, the provided path may be temporary and moved.

Returns:

str or dict
    A dict or string. If string, the return value is expected to be prefixed by tmp_checkpoint_dir. If dict, the return value will be automatically serialized by Tune and passed to Trainable.load_checkpoint().

Examples:

>>> print(trainable1.save_checkpoint("/tmp/checkpoint_1"))
"/tmp/checkpoint_1/my_checkpoint_file"
>>> print(trainable2.save_checkpoint("/tmp/checkpoint_2"))
{"some": "data"}
>>> trainable.save_checkpoint("/tmp/bad_example")
"/tmp/NEW_CHECKPOINT_PATH/my_checkpoint_file" # This will error.
Source code in ray/rllib/agents/trainer.py
@override(Trainable)
def save_checkpoint(self, checkpoint_dir: str) -> str:
    checkpoint_path = os.path.join(checkpoint_dir,
                                   "checkpoint-{}".format(self.iteration))
    pickle.dump(self.__getstate__(), open(checkpoint_path, "wb"))

    return checkpoint_path
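
At the user-facing level, this hook backs Trainer.save(); together with load_checkpoint() it powers Trainer.restore(). A sketch (assuming a trained trainer as above and an identically configured new Trainer):

from ray.rllib.agents.ppo import PPOTrainer

checkpoint_path = trainer.save()      # internally calls save_checkpoint()
print("Saved checkpoint to", checkpoint_path)

new_trainer = PPOTrainer(env="CartPole-v0", config={"num_workers": 2})
new_trainer.restore(checkpoint_path)  # internally calls load_checkpoint()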

set_weights(self, weights)

Set policy weights by policy id.

Parameters:

weights (Dict[str, dict], required)
    Map of policy ids to weights to set.
Source code in ray/rllib/agents/trainer.py
@PublicAPI
def set_weights(self, weights: Dict[PolicyID, dict]):
    """Set policy weights by policy id.

    Args:
        weights: Map of policy ids to weights to set.
    """
    self.workers.local_worker().set_weights(weights)

setup(self, config)

Subclasses should override this for custom initialization.

.. versionadded:: 0.8.7

Parameters:

config (dict, required)
    Hyperparameters and other configs given. Copy of self.config.
Source code in ray/rllib/agents/trainer.py
@override(Trainable)
def setup(self, config: PartialTrainerConfigDict):

    # Setup our config: Merge the user-supplied config (which could
    # be a partial config dict with the class' default).
    self.config = self.merge_trainer_configs(
        self.get_default_config(), config, self._allow_unknown_configs)

    # Setup the "env creator" callable.
    env = self._env_id
    if env:
        self.config["env"] = env

        # An already registered env.
        if _global_registry.contains(ENV_CREATOR, env):
            self.env_creator = _global_registry.get(ENV_CREATOR, env)

        # A class path specifier.
        elif "." in env:

            def env_creator_from_classpath(env_context):
                try:
                    env_obj = from_config(env, env_context)
                except ValueError:
                    raise EnvError(
                        ERR_MSG_INVALID_ENV_DESCRIPTOR.format(env))
                return env_obj

            self.env_creator = env_creator_from_classpath
        # Try gym/PyBullet/Vizdoom.
        else:
            self.env_creator = functools.partial(
                gym_env_creator, env_descriptor=env)
    # No env -> Env creator always returns None.
    else:
        self.env_creator = lambda env_config: None

    # Check and resolve DL framework settings.
    # Tf-eager (tf2|tfe), possibly with tracing set to True. Recommend
    # setting tracing to True for speedups.
    if tf1 and self.config["framework"] in ["tf2", "tfe"]:
        if self.config["framework"] == "tf2" and tfv < 2:
            raise ValueError(
                "You configured `framework`=tf2, but your installed pip "
                "tf-version is < 2.0! Make sure your TensorFlow version "
                "is >= 2.x.")
        if not tf1.executing_eagerly():
            tf1.enable_eager_execution()
        logger.info(
            f"Executing eagerly (framework='{self.config['framework']}'),"
            f" with eager_tracing={self.config['eager_tracing']}. For "
            "production workloads, make sure to set `eager_tracing=True` "
            "in order to match the speed of tf-static-graph "
            "(framework='tf'). For debugging purposes, "
            "`eager_tracing=False` is the best choice.")
    # Tf-static-graph (framework=tf): Recommend upgrading to tf2 and
    # enabling eager tracing for similar speed.
    elif tf1 and self.config["framework"] == "tf":
        logger.info(
            "Your framework setting is 'tf', meaning you are using static"
            "-graph mode. Set framework='tf2' to enable eager execution "
            "with tf2.x. You may also want to then set "
            "`eager_tracing=True` in order to reach similar execution "
            "speed as with static-graph mode.")

    # Set Trainer's seed after we have - if necessary - enabled
    # tf eager-execution.
    update_global_seed_if_necessary(
        config.get("framework"), config.get("seed"))

    self._validate_config(self.config, trainer_obj_or_none=self)
    if not callable(self.config["callbacks"]):
        raise ValueError(
            "`callbacks` must be a callable method that "
            "returns a subclass of DefaultCallbacks, got {}".format(
                self.config["callbacks"]))
    self.callbacks = self.config["callbacks"]()
    log_level = self.config.get("log_level")
    if log_level in ["WARN", "ERROR"]:
        logger.info("Current log_level is {}. For more information, "
                    "set 'log_level': 'INFO' / 'DEBUG' or use the -v and "
                    "-vv flags.".format(log_level))
    if self.config.get("log_level"):
        logging.getLogger("ray.rllib").setLevel(self.config["log_level"])

    # Create local replay buffer if necessary.
    self.local_replay_buffer = (
        self._create_local_replay_buffer_if_necessary(self.config))

    # Deprecated way of implementing Trainer sub-classes (or "templates"
    # via the soon-to-be deprecated `build_trainer` utility function).
    # Instead, sub-classes should override the Trainable's `setup()`
    # method and call super().setup() from within that override at some
    # point.
    self.workers = None
    self.train_exec_impl = None

    # Old design: Override `Trainer._init` (or use `build_trainer()`, which
    # will do this for you).
    try:
        self._init(self.config, self.env_creator)
    # New design: Override `Trainable.setup()` (as intended by Trainable)
    # and do or don't call super().setup() from within your override.
    # By default, `super().setup()` will create both worker sets:
    # "rollout workers" for collecting samples for training and - if
    # applicable - "evaluation workers" for evaluation runs in between or
    # parallel to training.
    # TODO: Deprecate `_init()` and remove this try/except block.
    except NotImplementedError:
        # Only if user did not override `_init()`:
        # - Create rollout workers here automatically.
        # - Run the execution plan to create the local iterator to `next()`
        #   in each training iteration.
        # This matches the behavior of using `build_trainer()`, which
        # should no longer be used.
        self.workers = self._make_workers(
            env_creator=self.env_creator,
            validate_env=self.validate_env,
            policy_class=self.get_default_policy_class(self.config),
            config=self.config,
            num_workers=self.config["num_workers"])
        self.train_exec_impl = self.execution_plan(
            self.workers, self.config, **self._kwargs_for_execution_plan())

    # Evaluation WorkerSet setup.
    self.evaluation_workers = None
    self.evaluation_metrics = {}
    # User would like to setup a separate evaluation worker set.
    if self.config.get("evaluation_num_workers", 0) > 0 or \
            self.config.get("evaluation_interval"):
        # Update env_config with evaluation settings:
        extra_config = copy.deepcopy(self.config["evaluation_config"])
        # Assert that user has not unset "in_evaluation".
        assert "in_evaluation" not in extra_config or \
            extra_config["in_evaluation"] is True
        evaluation_config = merge_dicts(self.config, extra_config)
        # Validate evaluation config.
        self._validate_config(evaluation_config, trainer_obj_or_none=self)
        # Switch on complete_episode rollouts (evaluations are
        # always done on n complete episodes) and set the
        # `in_evaluation` flag. Also, make sure our rollout fragments
        # are short so we don't have more than one episode in one rollout.
        evaluation_config.update({
            "batch_mode": "complete_episodes",
            "rollout_fragment_length": 1,
            "in_evaluation": True,
        })
        logger.debug("using evaluation_config: {}".format(extra_config))
        # Create a separate evaluation worker set for evaluation.
        # If evaluation_num_workers=0, use the evaluation set's local
        # worker for evaluation, otherwise, use its remote workers
        # (parallelized evaluation).
        self.evaluation_workers = self._make_workers(
            env_creator=self.env_creator,
            validate_env=None,
            policy_class=self.get_default_policy_class(self.config),
            config=evaluation_config,
            num_workers=self.config["evaluation_num_workers"])
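
For reference, a minimal config sketch (values are illustrative only) that makes setup() create this evaluation WorkerSet; per the check above, either a set evaluation_interval or a positive evaluation_num_workers is enough:

config = {
    "env": "CartPole-v0",
    # Evaluate every 5 training iterations ...
    "evaluation_interval": 5,
    # ... using 2 dedicated remote evaluation workers
    # (0 would evaluate on the evaluation set's local worker instead).
    "evaluation_num_workers": 2,
    # Overrides merged on top of the main config for evaluation only.
    "evaluation_config": {
        "explore": False,
    },
}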

step(self)

Implements the main Trainer.train() logic.

Attempts a single training step up to n times, catching RayErrors that result from worker failures. After n failed attempts, fails gracefully.

Override this method in your Trainer sub-classes if you would like to handle worker failures yourself; otherwise, override self.step_attempt() to keep the built-in retry logic (n attempts with worker-failure catching).

Returns:

Type Description
dict

The results dict with stats/infos on sampling, training, and - if required - evaluation.

Source code in ray/rllib/agents/trainer.py
@override(Trainable)
def step(self) -> ResultDict:
    """Implements the main `Trainer.train()` logic.

    Takes n attempts to perform a single training step. Thereby
    catches RayErrors resulting from worker failures. After n attempts,
    fails gracefully.

    Override this method in your Trainer sub-classes if you would like to
    handle worker failures yourself. Otherwise, override
    `self.step_attempt()` to keep the n attempts (catch worker failures).

    Returns:
        The results dict with stats/infos on sampling, training,
        and - if required - evaluation.
    """
    result = None
    for _ in range(1 + MAX_WORKER_FAILURE_RETRIES):
        # Try to train one step.
        try:
            result = self.step_attempt()
        # @ray.remote RolloutWorker failure -> Try to recover,
        # if necessary.
        except RayError as e:
            if self.config["ignore_worker_failures"]:
                logger.exception(
                    "Error in train call, attempting to recover")
                self.try_recover_from_step_attempt()
            else:
                logger.info(
                    "Worker crashed during call to train(). To attempt to "
                    "continue training without the failed worker, set "
                    "`'ignore_worker_failures': True`.")
                raise e
        # Any other exception.
        except Exception as e:
            # Allow log messages to propagate.
            time.sleep(0.5)
            raise e
        else:
            break

    # Still no result (even after n retries).
    if result is None:
        raise RuntimeError("Failed to recover from worker crash.")

    if hasattr(self, "workers") and isinstance(self.workers, WorkerSet):
        self._sync_filters_if_needed(self.workers)

    return result
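
As a usage sketch (PPOTrainer and the env name are just example choices), the retry behavior above is governed by the ignore_worker_failures config key; when it is True, step() attempts to recover via try_recover_from_step_attempt() instead of re-raising:

from ray.rllib.agents.ppo import PPOTrainer

trainer = PPOTrainer(
    env="CartPole-v0",
    config={
        # Keep going if a remote rollout worker crashes; step() then calls
        # try_recover_from_step_attempt() and retries the training step.
        "ignore_worker_failures": True,
    })
for _ in range(10):
    result = trainer.train()  # internally runs step()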

step_attempt(self)

Attempts a single training step, including evaluation, if required.

Override this method in your Trainer sub-classes if you would like to keep the built-in retry logic (n attempts with worker-failure catching), or override step() directly if you would like to handle worker failures yourself.

Returns:

Type Description
dict

The results dict with stats/infos on sampling, training, and - if required - evaluation.

Source code in ray/rllib/agents/trainer.py
@ExperimentalAPI
def step_attempt(self) -> ResultDict:
    """Attempts a single training step, including evaluation, if required.

    Override this method in your Trainer sub-classes if you would like to
    keep the n attempts (catch worker failures) or override `step()`
    directly if you would like to handle worker failures yourself.

    Returns:
        The results dict with stats/infos on sampling, training,
        and - if required - evaluation.
    """

    # self._iteration gets incremented after this function returns,
    # meaning that e. g. the first time this function is called,
    # self._iteration will be 0.
    evaluate_this_iter = \
        self.config["evaluation_interval"] and \
        (self._iteration + 1) % self.config["evaluation_interval"] == 0

    # No evaluation necessary, just run the next training iteration.
    if not evaluate_this_iter:
        step_results = next(self.train_exec_impl)
    # We have to evaluate in this training iteration.
    else:
        # No parallelism.
        if not self.config["evaluation_parallel_to_training"]:
            step_results = next(self.train_exec_impl)

        # Kick off evaluation-loop (and parallel train() call,
        # if requested).
        # Parallel eval + training.
        if self.config["evaluation_parallel_to_training"]:
            with concurrent.futures.ThreadPoolExecutor() as executor:
                train_future = executor.submit(
                    lambda: next(self.train_exec_impl))
                if self.config["evaluation_num_episodes"] == "auto":

                    # Run at least one `evaluate()` (num_episodes_done
                    # must be > 0), even if the training is very fast.
                    def episodes_left_fn(num_episodes_done):
                        if num_episodes_done > 0 and \
                                train_future.done():
                            return 0
                        else:
                            return self.config["evaluation_num_workers"]

                    evaluation_metrics = self.evaluate(
                        episodes_left_fn=episodes_left_fn)
                else:
                    evaluation_metrics = self.evaluate()
                # Collect the training results from the future.
                step_results = train_future.result()
        # Sequential: train (already done above), then eval.
        else:
            evaluation_metrics = self.evaluate()

        # Add evaluation results to train results.
        assert isinstance(evaluation_metrics, dict), \
            "Trainer.evaluate() needs to return a dict."
        step_results.update(evaluation_metrics)

    # Check `env_task_fn` for possible update of the env's task.
    if self.config["env_task_fn"] is not None:
        if not callable(self.config["env_task_fn"]):
            raise ValueError(
                "`env_task_fn` must be None or a callable taking "
                "[train_results, env, env_ctx] as args!")

        def fn(env, env_context, task_fn):
            new_task = task_fn(step_results, env, env_context)
            cur_task = env.get_task()
            if cur_task != new_task:
                env.set_task(new_task)

        fn = functools.partial(fn, task_fn=self.config["env_task_fn"])
        self.workers.foreach_env_with_context(fn)

    return step_results
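
The evaluation branches above are driven by a handful of config keys. A minimal sketch (values illustrative) of running evaluation in parallel to the train step:

config = {
    # Evaluate every training iteration ...
    "evaluation_interval": 1,
    # ... in a background thread, parallel to the train step.
    "evaluation_parallel_to_training": True,
    # Parallel evaluation requires dedicated remote evaluation workers.
    "evaluation_num_workers": 2,
    # "auto": run evaluation (at least one episode) for as long as the
    # parallel train step takes.
    "evaluation_num_episodes": "auto",
}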

try_recover_from_step_attempt(self)

Try to identify and remove any unhealthy workers.

This method is called after an unexpected remote error is encountered from a worker during the call to self.step_attempt() (within self.step()). It issues check requests to all current workers and removes any that respond with an error. If no healthy workers remain, an error is raised. Otherwise, it tries to re-build the execution plan with the remaining (healthy) workers.

Source code in ray/rllib/agents/trainer.py
def try_recover_from_step_attempt(self) -> None:
    """Try to identify and remove any unhealthy workers.

    This method is called after an unexpected remote error is encountered
    from a worker during the call to `self.step_attempt()` (within
    `self.step()`). It issues check requests to all current workers and
    removes any that respond with error. If no healthy workers remain,
    an error is raised. Otherwise, tries to re-build the execution plan
    with the remaining (healthy) workers.
    """

    workers = getattr(self, "workers", None)
    if not isinstance(workers, WorkerSet):
        return

    logger.info("Health checking all workers...")
    checks = []
    for ev in workers.remote_workers():
        _, obj_ref = ev.sample_with_count.remote()
        checks.append(obj_ref)

    healthy_workers = []
    for i, obj_ref in enumerate(checks):
        w = workers.remote_workers()[i]
        try:
            ray.get(obj_ref)
            healthy_workers.append(w)
            logger.info("Worker {} looks healthy".format(i + 1))
        except RayError:
            logger.exception("Removing unhealthy worker {}".format(i + 1))
            try:
                w.__ray_terminate__.remote()
            except Exception:
                logger.exception("Error terminating unhealthy worker")

    if len(healthy_workers) < 1:
        raise RuntimeError(
            "Not enough healthy workers remain to continue.")

    logger.warning("Recreating execution plan after failure.")
    workers.reset(healthy_workers)
    if self.train_exec_impl is not None:
        if callable(self.execution_plan):
            self.train_exec_impl = self.execution_plan(
                workers, self.config, **self._kwargs_for_execution_plan())

validate_env(env, env_context) staticmethod

Env validator function for this Trainer class.

Override this in child classes to define custom validation behavior.

Parameters:

Name Type Description Default
env Any

The (sub-)environment to validate. This is normally a single sub-environment (e.g. a gym.Env) within a vectorized setup.

required
env_context EnvContext

The EnvContext to configure the environment.

required
Source code in ray/rllib/agents/trainer.py
@ExperimentalAPI
@staticmethod
def validate_env(env: EnvType, env_context: EnvContext) -> None:
    """Env validator function for this Trainer class.

    Override this in child classes to define custom validation
    behavior.

    Args:
        env: The (sub-)environment to validate. This is normally a
            single sub-environment (e.g. a gym.Env) within a vectorized
            setup.
        env_context: The EnvContext to configure the environment.

    Raises:
        Exception in case something is wrong with the given environment.
    """
    pass
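
A minimal sketch of overriding this hook in a custom Trainer sub-class (sub-classing PPOTrainer and checking for a Box observation space are illustrative choices only):

import gym
from ray.rllib.agents.ppo import PPOTrainer


class MyPPOTrainer(PPOTrainer):
    @staticmethod
    def validate_env(env, env_context) -> None:
        # Example validation only: require a Box observation space and
        # raise (as per the docstring above) if something is wrong.
        if not isinstance(env.observation_space, gym.spaces.Box):
            raise ValueError(
                "MyPPOTrainer requires a Box observation space, "
                "got: {}".format(env.observation_space))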

ray.rllib.agents.callbacks.DefaultCallbacks

Abstract base class for RLlib callbacks (similar to Keras callbacks).

These callbacks can be used for custom metrics and custom postprocessing.

By default, all of these callbacks are no-ops. To configure custom training callbacks, subclass DefaultCallbacks and then set {"callbacks": YourCallbacksClass} in the trainer config.
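
A minimal registration sketch (MyCallbacks and the choice of PPOTrainer are illustrative; override any of the hooks documented below inside the sub-class):

from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.agents.ppo import PPOTrainer


class MyCallbacks(DefaultCallbacks):
    """Override any of the no-op hooks below, e.g. on_episode_end()."""


trainer = PPOTrainer(
    env="CartPole-v0",
    # Pass the class itself (not an instance) via the trainer config.
    config={"callbacks": MyCallbacks},
)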

on_episode_end(self, *, worker, base_env, policies, episode, **kwargs)

Runs when an episode is done.

Parameters:

Name Type Description Default
worker RolloutWorker

Reference to the current rollout worker.

required
base_env BaseEnv

BaseEnv running the episode. The underlying sub environment objects can be retrieved by calling base_env.get_sub_environments().

required
policies Dict[str, ray.rllib.policy.policy.Policy]

Mapping of policy id to policy objects. In single agent mode there will only be a single "default_policy".

required
episode Episode

Episode object which contains episode state. You can use the episode.user_data dict to store temporary data, and episode.custom_metrics to store custom metrics for the episode.

required
kwargs

Forward compatibility placeholder.

{}
Source code in ray/rllib/agents/callbacks.py
def on_episode_end(self, *, worker: "RolloutWorker", base_env: BaseEnv,
                   policies: Dict[PolicyID, Policy], episode: Episode,
                   **kwargs) -> None:
    """Runs when an episode is done.

    Args:
        worker: Reference to the current rollout worker.
        base_env: BaseEnv running the episode. The underlying
            sub environment objects can be retrieved by calling
            `base_env.get_sub_environments()`.
        policies: Mapping of policy id to policy
            objects. In single agent mode there will only be a single
            "default_policy".
        episode: Episode object which contains episode
            state. You can use the `episode.user_data` dict to store
            temporary data, and `episode.custom_metrics` to store custom
            metrics for the episode.
        kwargs: Forward compatibility placeholder.
    """

    if self.legacy_callbacks.get("on_episode_end"):
        self.legacy_callbacks["on_episode_end"]({
            "env": base_env,
            "policy": policies,
            "episode": episode,
        })
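
A sketch of a typical override (the metric name is illustrative): record a per-episode custom metric once the episode is done. Values accumulated in episode.user_data during the episode can be summarized here in the same way.

from ray.rllib.agents.callbacks import DefaultCallbacks


class MyCallbacks(DefaultCallbacks):
    def on_episode_end(self, *, worker, base_env, policies, episode,
                       **kwargs):
        # Store the episode length as a custom metric; it shows up
        # under custom_metrics in the training results.
        episode.custom_metrics["episode_length"] = episode.length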

on_episode_start(self, *, worker, base_env, policies, episode, **kwargs)

Callback run on the rollout worker before each episode starts.

Parameters:

Name Type Description Default
worker RolloutWorker

Reference to the current rollout worker.

required
base_env BaseEnv

BaseEnv running the episode. The underlying sub environment objects can be retrieved by calling base_env.get_sub_environments().

required
policies Dict[str, ray.rllib.policy.policy.Policy]

Mapping of policy id to policy objects. In single agent mode there will only be a single "default" policy.

required
episode Episode

Episode object which contains the episode's state. You can use the episode.user_data dict to store temporary data, and episode.custom_metrics to store custom metrics for the episode.

required
kwargs

Forward compatibility placeholder.

{}
Source code in ray/rllib/agents/callbacks.py
def on_episode_start(self, *, worker: "RolloutWorker", base_env: BaseEnv,
                     policies: Dict[PolicyID, Policy], episode: Episode,
                     **kwargs) -> None:
    """Callback run on the rollout worker before each episode starts.

    Args:
        worker: Reference to the current rollout worker.
        base_env: BaseEnv running the episode. The underlying
            sub environment objects can be retrieved by calling
            `base_env.get_sub_environments()`.
        policies: Mapping of policy id to policy objects. In single
            agent mode there will only be a single "default" policy.
        episode: Episode object which contains the episode's
            state. You can use the `episode.user_data` dict to store
            temporary data, and `episode.custom_metrics` to store custom
            metrics for the episode.
        kwargs: Forward compatibility placeholder.
    """

    if self.legacy_callbacks.get("on_episode_start"):
        self.legacy_callbacks["on_episode_start"]({
            "env": base_env,
            "policy": policies,
            "episode": episode,
        })
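
A sketch of initializing per-episode storage here (the user_data key is illustrative); it can be filled in on_episode_step() and summarized in on_episode_end():

from ray.rllib.agents.callbacks import DefaultCallbacks


class MyCallbacks(DefaultCallbacks):
    def on_episode_start(self, *, worker, base_env, policies, episode,
                         **kwargs):
        # Per-episode scratch space; not reported as a metric by itself.
        episode.user_data["obs_norms"] = []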

on_episode_step(self, *, worker, base_env, policies=None, episode, **kwargs)

Runs on each episode step.

Parameters:

Name Type Description Default
worker RolloutWorker

Reference to the current rollout worker.

required
base_env BaseEnv

BaseEnv running the episode. The underlying sub environment objects can be retrieved by calling base_env.get_sub_environments().

required
policies Optional[Dict[str, ray.rllib.policy.policy.Policy]]

Mapping of policy id to policy objects. In single agent mode there will only be a single "default_policy".

None
episode Episode

Episode object which contains episode state. You can use the episode.user_data dict to store temporary data, and episode.custom_metrics to store custom metrics for the episode.

required
kwargs

Forward compatibility placeholder.

{}
Source code in ray/rllib/agents/callbacks.py
def on_episode_step(self,
                    *,
                    worker: "RolloutWorker",
                    base_env: BaseEnv,
                    policies: Optional[Dict[PolicyID, Policy]] = None,
                    episode: Episode,
                    **kwargs) -> None:
    """Runs on each episode step.

    Args:
        worker: Reference to the current rollout worker.
        base_env: BaseEnv running the episode. The underlying
            sub environment objects can be retrieved by calling
            `base_env.get_sub_environments()`.
        policies: Mapping of policy id
            to policy objects. In single agent mode there will only be a
            single "default_policy".
        episode: Episode object which contains episode
            state. You can use the `episode.user_data` dict to store
            temporary data, and `episode.custom_metrics` to store custom
            metrics for the episode.
        kwargs: Forward compatibility placeholder.
    """

    if self.legacy_callbacks.get("on_episode_step"):
        self.legacy_callbacks["on_episode_step"]({
            "env": base_env,
            "episode": episode
        })
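
A sketch of per-step data collection (assumes a single-agent env with array observations; episode.last_observation_for() returns the most recent observation for the default agent):

import numpy as np

from ray.rllib.agents.callbacks import DefaultCallbacks


class MyCallbacks(DefaultCallbacks):
    def on_episode_step(self, *, worker, base_env, policies=None,
                        episode, **kwargs):
        # Append a per-step value to the storage created in
        # on_episode_start(); summarize it in on_episode_end().
        obs = episode.last_observation_for()
        episode.user_data["obs_norms"].append(float(np.linalg.norm(obs)))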

on_learn_on_batch(self, *, policy, train_batch, result, **kwargs)

Called at the beginning of Policy.learn_on_batch().

Note: This is called before 0-padding via pad_batch_to_sequences_of_same_size.

Parameters:

Name Type Description Default
policy Policy

Reference to the current Policy object.

required
train_batch SampleBatch

SampleBatch to be trained on. You can mutate this object to modify the samples generated.

required
result dict

A results dict to add custom metrics to.

required
kwargs

Forward compatibility placeholder.

{}
Source code in ray/rllib/agents/callbacks.py
def on_learn_on_batch(self, *, policy: Policy, train_batch: SampleBatch,
                      result: dict, **kwargs) -> None:
    """Called at the beginning of Policy.learn_on_batch().

    Note: This is called before 0-padding via
    `pad_batch_to_sequences_of_same_size`.

    Args:
        policy: Reference to the current Policy object.
        train_batch: SampleBatch to be trained on. You can
            mutate this object to modify the samples generated.
        result: A results dict to add custom metrics to.
        kwargs: Forward compatibility placeholder.
    """

    pass
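
A sketch of adding a train-batch statistic to the results dict (the metric name is illustrative; "actions" is a standard SampleBatch column):

import numpy as np

from ray.rllib.agents.callbacks import DefaultCallbacks


class MyCallbacks(DefaultCallbacks):
    def on_learn_on_batch(self, *, policy, train_batch, result,
                          **kwargs) -> None:
        # Computed on the pre-padding batch (see note above).
        result["sum_actions_in_train_batch"] = float(
            np.sum(train_batch["actions"]))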

on_postprocess_trajectory(self, *, worker, episode, agent_id, policy_id, policies, postprocessed_batch, original_batches, **kwargs)

Called immediately after a policy's postprocess_fn is called.

You can use this callback to do additional postprocessing for a policy, including looking at the trajectory data of other agents in multi-agent settings.

Parameters:

Name Type Description Default
worker RolloutWorker

Reference to the current rollout worker.

required
episode Episode

Episode object.

required
agent_id Any

Id of the current agent.

required
policy_id str

Id of the current policy for the agent.

required
policies Dict[str, ray.rllib.policy.policy.Policy]

Mapping of policy id to policy objects. In single agent mode there will only be a single "default_policy".

required
postprocessed_batch SampleBatch

The postprocessed sample batch for this agent. You can mutate this object to apply your own trajectory postprocessing.

required
original_batches Dict[Any, ray.rllib.policy.sample_batch.SampleBatch]

Mapping of agents to their unpostprocessed trajectory data. You should not mutate this object.

required
kwargs

Forward compatibility placeholder.

{}
Source code in ray/rllib/agents/callbacks.py
def on_postprocess_trajectory(
        self, *, worker: "RolloutWorker", episode: Episode,
        agent_id: AgentID, policy_id: PolicyID,
        policies: Dict[PolicyID, Policy], postprocessed_batch: SampleBatch,
        original_batches: Dict[AgentID, SampleBatch], **kwargs) -> None:
    """Called immediately after a policy's postprocess_fn is called.

    You can use this callback to do additional postprocessing for a policy,
    including looking at the trajectory data of other agents in multi-agent
    settings.

    Args:
        worker: Reference to the current rollout worker.
        episode: Episode object.
        agent_id: Id of the current agent.
        policy_id: Id of the current policy for the agent.
        policies: Mapping of policy id to policy objects. In single
            agent mode there will only be a single "default_policy".
        postprocessed_batch: The postprocessed sample batch
            for this agent. You can mutate this object to apply your own
            trajectory postprocessing.
        original_batches: Mapping of agents to their unpostprocessed
            trajectory data. You should not mutate this object.
        kwargs: Forward compatibility placeholder.
    """

    if self.legacy_callbacks.get("on_postprocess_traj"):
        self.legacy_callbacks["on_postprocess_traj"]({
            "episode": episode,
            "agent_id": agent_id,
            "pre_batch": original_batches[agent_id],
            "post_batch": postprocessed_batch,
            "all_pre_batches": original_batches,
        })
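
A sketch that counts how many postprocessed batches an episode produced (the metric name is illustrative):

from ray.rllib.agents.callbacks import DefaultCallbacks


class MyCallbacks(DefaultCallbacks):
    def on_postprocess_trajectory(
            self, *, worker, episode, agent_id, policy_id, policies,
            postprocessed_batch, original_batches, **kwargs):
        # Count postprocessed batches per episode as a custom metric.
        if "num_batches" not in episode.custom_metrics:
            episode.custom_metrics["num_batches"] = 0
        episode.custom_metrics["num_batches"] += 1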

on_sample_end(self, *, worker, samples, **kwargs)

Called at the end of RolloutWorker.sample().

Parameters:

Name Type Description Default
worker RolloutWorker

Reference to the current rollout worker.

required
samples SampleBatch

Batch to be returned. You can mutate this object to modify the samples generated.

required
kwargs

Forward compatibility placeholder.

{}
Source code in ray/rllib/agents/callbacks.py
def on_sample_end(self, *, worker: "RolloutWorker", samples: SampleBatch,
                  **kwargs) -> None:
    """Called at the end of RolloutWorker.sample().

    Args:
        worker: Reference to the current rollout worker.
        samples: Batch to be returned. You can mutate this
            object to modify the samples generated.
        kwargs: Forward compatibility placeholder.
    """

    if self.legacy_callbacks.get("on_sample_end"):
        self.legacy_callbacks["on_sample_end"]({
            "worker": worker,
            "samples": samples,
        })
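
A sketch that only inspects the sampled batch (samples.count is the number of timesteps in the batch):

import logging

from ray.rllib.agents.callbacks import DefaultCallbacks

logger = logging.getLogger(__name__)


class MyCallbacks(DefaultCallbacks):
    def on_sample_end(self, *, worker, samples, **kwargs):
        # Log the size of each returned sample batch; mutate `samples`
        # here if you want to modify the generated data.
        logger.info("Returned sample batch of size %d", samples.count)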

on_train_result(self, *, trainer, result, **kwargs)

Called at the end of Trainable.train().

Parameters:

Name Type Description Default
trainer Trainer

Current trainer instance.

required
result dict

Dict of results returned from trainer.train() call. You can mutate this object to add additional metrics.

required
kwargs

Forward compatibility placeholder.

{}
Source code in ray/rllib/agents/callbacks.py
def on_train_result(self, *, trainer: "Trainer", result: dict,
                    **kwargs) -> None:
    """Called at the end of Trainable.train().

    Args:
        trainer: Current trainer instance.
        result: Dict of results returned from trainer.train() call.
            You can mutate this object to add additional metrics.
        kwargs: Forward compatibility placeholder.
    """

    if self.legacy_callbacks.get("on_train_result"):
        self.legacy_callbacks["on_train_result"]({
            "trainer": trainer,
            "result": result,
        })
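
A sketch of post-processing the training results (the added key is illustrative and simply ends up in the reported results dict):

from ray.rllib.agents.callbacks import DefaultCallbacks


class MyCallbacks(DefaultCallbacks):
    def on_train_result(self, *, trainer, result, **kwargs):
        # Mutate the result dict to add custom, iteration-level info.
        result["callback_ok"] = True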