agents package
ray.rllib.agents.trainer.Trainer (Trainable)
An RLlib algorithm responsible for optimizing one or more Policies.
Trainers contain a WorkerSet under self.workers. A WorkerSet is normally composed of a single local worker (self.workers.local_worker()), used to compute and apply learning updates, and optionally one or more remote workers (self.workers.remote_workers()), used to generate environment samples in parallel.
Each worker (remote or local) contains a PolicyMap, which itself may contain either one policy for single-agent training or one or more policies for multi-agent training. Policies are synchronized automatically from time to time using ray.remote calls. The exact synchronization logic depends on the specific algorithm (Trainer) used, but this usually happens from the local worker to all remote workers and after each training update.
You can write your own Trainer sub-classes by using the rllib.agents.trainer_template.py::build_trainer() utility function. This allows you to provide a custom execution_plan. You can find the different built-in algorithms' execution plans in their respective main py files, e.g. rllib.agents.dqn.dqn.py or rllib.agents.impala.impala.py.
The most important API methods a Trainer exposes are train(), evaluate(), save() and restore(). Trainer objects retain internal model state between calls to train(), so you should create a new Trainer instance for each training session.
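A minimal sketch of this lifecycle, assuming the PPO algorithm and the gym CartPole-v0 environment are available in your installation (the config values are illustrative):

import ray
from ray.rllib.agents import ppo

ray.init()

# Build a PPO Trainer; keys not given here fall back to the algorithm's
# default config (which itself extends COMMON_CONFIG).
trainer = ppo.PPOTrainer(
    config={"num_workers": 2, "evaluation_num_workers": 1},
    env="CartPole-v0",
)

# Each call to train() runs one training iteration and returns a results dict.
for _ in range(3):
    results = trainer.train()
    print(results["episode_reward_mean"])

# Evaluate under the (merged) evaluation_config settings.
print(trainer.evaluate())

# Save a checkpoint and restore it into the same (or a new) Trainer instance.
checkpoint_path = trainer.save()
trainer.restore(checkpoint_path)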
__init__(self, config=None, env=None, logger_creator=None, remote_checkpoint_dir=None, sync_function_tpl=None)
special
Initializes a Trainer instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config | Optional[dict] | Algorithm-specific configuration dict. | None |
env | Union[str, Any] | Name of the environment to use (e.g. a gym-registered str), a full class path (e.g. "ray.rllib.examples.env.random_env.RandomEnv"), or an Env class directly. Note that this arg can also be specified via the "env" key in `config`. | None |
logger_creator | Optional[Callable[[], ray.tune.logger.Logger]] | Callable that creates a ray.tune.Logger object. If unspecified, a default logger is created. | None |
Source code in ray/rllib/agents/trainer.py
@PublicAPI
def __init__(self,
config: Optional[PartialTrainerConfigDict] = None,
env: Optional[Union[str, EnvType]] = None,
logger_creator: Optional[Callable[[], Logger]] = None,
remote_checkpoint_dir: Optional[str] = None,
sync_function_tpl: Optional[str] = None):
"""Initializes a Trainer instance.
Args:
config: Algorithm-specific configuration dict.
env: Name of the environment to use (e.g. a gym-registered str),
a full class path (e.g.
"ray.rllib.examples.env.random_env.RandomEnv"), or an Env
class directly. Note that this arg can also be specified via
the "env" key in `config`.
logger_creator: Callable that creates a ray.tune.Logger
object. If unspecified, a default logger is created.
"""
# User provided (partial) config (this may be w/o the default
# Trainer's `COMMON_CONFIG` (see above)). Will get merged with
# COMMON_CONFIG in self.setup().
config = config or {}
# Trainers allow env ids to be passed directly to the constructor.
self._env_id = self._register_if_needed(
env or config.get("env"), config)
# The env creator callable, taking an EnvContext (config dict)
# as arg and returning an RLlib supported Env type (e.g. a gym.Env).
self.env_creator: Callable[[EnvContext], EnvType] = None
# Placeholder for a local replay buffer instance.
self.local_replay_buffer = None
# Create a default logger creator if no logger_creator is specified
if logger_creator is None:
# Default logdir prefix containing the agent's name and the
# env id.
timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
logdir_prefix = "{}_{}_{}".format(self._name, self._env_id,
timestr)
if not os.path.exists(DEFAULT_RESULTS_DIR):
os.makedirs(DEFAULT_RESULTS_DIR)
logdir = tempfile.mkdtemp(
prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)
# Allow users to more precisely configure the created logger
# via "logger_config.type".
if config.get(
"logger_config") and "type" in config["logger_config"]:
def default_logger_creator(config):
"""Creates a custom logger with the default prefix."""
cfg = config["logger_config"].copy()
cls = cfg.pop("type")
# Provide default for logdir, in case the user does
# not specify this in the "logger_config" dict.
logdir_ = cfg.pop("logdir", logdir)
return from_config(cls=cls, _args=[cfg], logdir=logdir_)
# If no `type` given, use tune's UnifiedLogger as last resort.
else:
def default_logger_creator(config):
"""Creates a Unified logger with the default prefix."""
return UnifiedLogger(config, logdir, loggers=None)
logger_creator = default_logger_creator
super().__init__(config, logger_creator, remote_checkpoint_dir,
sync_function_tpl)
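For illustration, a hedged sketch of passing a custom logger_creator (the log directory path and the use of PPOTrainer are assumptions; UnifiedLogger is the same default logger used in the source above):

import os
from ray.tune.logger import UnifiedLogger
from ray.rllib.agents import ppo

def my_logger_creator(config):
    # Log to a fixed directory instead of the default ~/ray_results location.
    logdir = "/tmp/my_rllib_logs"  # hypothetical path
    os.makedirs(logdir, exist_ok=True)
    return UnifiedLogger(config, logdir, loggers=None)

trainer = ppo.PPOTrainer(
    config={"num_workers": 0},
    env="CartPole-v0",
    logger_creator=my_logger_creator,
)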
add_policy(self, policy_id, policy_cls, *, observation_space=None, action_space=None, config=None, policy_mapping_fn=None, policies_to_train=None, evaluation_workers=True)
Adds a new policy to this Trainer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
policy_id | PolicyID | ID of the policy to add. | required |
policy_cls | Type[Policy] | The Policy class to use for constructing the new Policy. | required |
observation_space | Optional[gym.spaces.Space] | The observation space of the policy to add. | None |
action_space | Optional[gym.spaces.Space] | The action space of the policy to add. | None |
config | Optional[PartialTrainerConfigDict] | The config overrides for the policy to add. | None |
policy_mapping_fn | Optional[Callable[[AgentID], PolicyID]] | An optional (updated) policy mapping function to use from here on. Note that already ongoing episodes will not change their mapping but will use the old mapping till the end of the episode. | None |
policies_to_train | Optional[List[PolicyID]] | An optional list of policy IDs to be trained. If None, will keep the existing list in place. Policies whose IDs are not in the list will not be updated. | None |
evaluation_workers | bool | Whether to add the new policy also to the evaluation WorkerSet. | True |
Returns:
Type | Description |
---|---|
Policy | The newly added policy (the copy that got added to the local worker). |
Source code in ray/rllib/agents/trainer.py
@PublicAPI
def add_policy(
self,
policy_id: PolicyID,
policy_cls: Type[Policy],
*,
observation_space: Optional[gym.spaces.Space] = None,
action_space: Optional[gym.spaces.Space] = None,
config: Optional[PartialTrainerConfigDict] = None,
policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID],
PolicyID]] = None,
policies_to_train: Optional[List[PolicyID]] = None,
evaluation_workers: bool = True,
) -> Policy:
"""Adds a new policy to this Trainer.
Args:
policy_id (PolicyID): ID of the policy to add.
policy_cls (Type[Policy]): The Policy class to use for
constructing the new Policy.
observation_space (Optional[gym.spaces.Space]): The observation
space of the policy to add.
action_space (Optional[gym.spaces.Space]): The action space
of the policy to add.
config (Optional[PartialTrainerConfigDict]): The config overrides
for the policy to add.
policy_mapping_fn (Optional[Callable[[AgentID], PolicyID]]): An
optional (updated) policy mapping function to use from here on.
Note that already ongoing episodes will not change their
mapping but will use the old mapping till the end of the
episode.
policies_to_train (Optional[List[PolicyID]]): An optional list of
policy IDs to be trained. If None, will keep the existing list
in place. Policies, whose IDs are not in the list will not be
updated.
evaluation_workers (bool): Whether to add the new policy also
to the evaluation WorkerSet.
Returns:
Policy: The newly added policy (the copy that got added to the
local worker).
"""
def fn(worker: RolloutWorker):
# `foreach_worker` function: Adds the policy the the worker (and
# maybe changes its policy_mapping_fn - if provided here).
worker.add_policy(
policy_id=policy_id,
policy_cls=policy_cls,
observation_space=observation_space,
action_space=action_space,
config=config,
policy_mapping_fn=policy_mapping_fn,
policies_to_train=policies_to_train,
)
# Run foreach_worker fn on all workers (incl. evaluation workers).
self.workers.foreach_worker(fn)
if evaluation_workers and self.evaluation_workers is not None:
self.evaluation_workers.foreach_worker(fn)
# Return newly added policy (from the local rollout worker).
return self.get_policy(policy_id)
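A hedged usage sketch (the policy id, the reuse of the default policy's class, and the mapping function are illustrative; `trainer` is an already running multi-agent Trainer whose env provides the new policy's spaces):

# Add a second policy that shares the default policy's class, route all
# agents to it from here on, and train only the new policy.
new_policy = trainer.add_policy(
    policy_id="learnable_policy",   # illustrative id
    policy_cls=type(trainer.get_policy()),
    policy_mapping_fn=lambda agent_id, *args, **kwargs: "learnable_policy",
    policies_to_train=["learnable_policy"],
)
print(new_policy.observation_space)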
cleanup(self)
Subclasses should override this for any cleanup on stop.
If any Ray actors are launched in the Trainable (i.e., with an RLlib trainer), be sure to kill the Ray actor process here.
You can kill a Ray actor by calling actor.__ray_terminate__.remote() on the actor.
.. versionadded:: 0.8.7
collect_metrics(self, selected_workers=None)
Collects metrics from the remote workers of this agent.
This is the same data as returned by a call to train().
Source code in ray/rllib/agents/trainer.py
@DeveloperAPI
def collect_metrics(self,
selected_workers: List[ActorHandle] = None) -> dict:
"""Collects metrics from the remote workers of this agent.
This is the same data as returned by a call to train().
"""
return self.optimizer.collect_metrics(
self.config["collect_metrics_timeout"],
min_history=self.config["metrics_smoothing_episodes"],
selected_workers=selected_workers)
compute_actions(self, observations, state=None, *, prev_action=None, prev_reward=None, info=None, policy_id='default_policy', full_fetch=False, explore=None, timestep=None, episodes=None, unsquash_actions=None, clip_actions=None, normalize_actions=None, **kwargs)
Computes an action for the specified policy on the local Worker.
Note that you can also access the policy object through self.get_policy(policy_id) and call compute_actions() on it directly.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
observations | TensorStructType | Observation from the environment. | required |
state | Optional[List[Union[Any, dict, tuple]]] | RNN hidden state, if any. If state is not None, then all of compute_single_action(...) is returned (computed action, rnn state(s), logits dictionary). Otherwise compute_single_action(...)[0] is returned (computed action). | None |
prev_action | Union[Any, dict, tuple] | Previous action value, if any. | None |
prev_reward | Union[Any, dict, tuple] | Previous reward, if any. | None |
info | Optional[dict] | Env info dict, if any. | None |
policy_id | str | Policy to query (only applies to multi-agent). | 'default_policy' |
full_fetch | bool | Whether to return extra action fetch results. This is always set to True if RNN state is specified. | False |
explore | Optional[bool] | Whether to pick an exploitation or exploration action (default: None -> use self.config["explore"]). | None |
timestep | Optional[int] | The current (sampling) time step. | None |
episodes | Optional[List[ray.rllib.evaluation.episode.Episode]] | This provides access to all of the internal episodes' state, which may be useful for model-based or multi-agent algorithms. | None |
unsquash_actions | Optional[bool] | Should actions be unsquashed according to the env's/Policy's action space? If None, use self.config["normalize_actions"]. | None |
clip_actions | Optional[bool] | Should actions be clipped according to the env's/Policy's action space? If None, use self.config["clip_actions"]. | None |
Returns:
Type | Description |
---|---|
any | The computed action if full_fetch=False, or tuple: The full output of policy.compute_actions() if full_fetch=True or we have an RNN-based Policy. |
Source code in ray/rllib/agents/trainer.py
@PublicAPI
def compute_actions(
self,
observations: TensorStructType,
state: Optional[List[TensorStructType]] = None,
*,
prev_action: Optional[TensorStructType] = None,
prev_reward: Optional[TensorStructType] = None,
info: Optional[EnvInfoDict] = None,
policy_id: PolicyID = DEFAULT_POLICY_ID,
full_fetch: bool = False,
explore: Optional[bool] = None,
timestep: Optional[int] = None,
episodes: Optional[List[Episode]] = None,
unsquash_actions: Optional[bool] = None,
clip_actions: Optional[bool] = None,
# Deprecated.
normalize_actions=None,
**kwargs,
):
"""Computes an action for the specified policy on the local Worker.
Note that you can also access the policy object through
self.get_policy(policy_id) and call compute_actions() on it directly.
Args:
observation: Observation from the environment.
state: RNN hidden state, if any. If state is not None,
then all of compute_single_action(...) is returned
(computed action, rnn state(s), logits dictionary).
Otherwise compute_single_action(...)[0] is returned
(computed action).
prev_action: Previous action value, if any.
prev_reward: Previous reward, if any.
info: Env info dict, if any.
policy_id: Policy to query (only applies to multi-agent).
full_fetch: Whether to return extra action fetch results.
This is always set to True if RNN state is specified.
explore: Whether to pick an exploitation or exploration
action (default: None -> use self.config["explore"]).
timestep: The current (sampling) time step.
episodes: This provides access to all of the internal episodes'
state, which may be useful for model-based or multi-agent
algorithms.
unsquash_actions: Should actions be unsquashed according
to the env's/Policy's action space? If None, use
self.config["normalize_actions"].
clip_actions: Should actions be clipped according to the
env's/Policy's action space? If None, use
self.config["clip_actions"].
Keyword Args:
kwargs: forward compatibility placeholder
Returns:
any: The computed action if full_fetch=False, or
tuple: The full output of policy.compute_actions() if
full_fetch=True or we have an RNN-based Policy.
"""
if normalize_actions is not None:
deprecation_warning(
old="Trainer.compute_actions(`normalize_actions`=...)",
new="Trainer.compute_actions(`unsquash_actions`=...)",
error=False)
unsquash_actions = normalize_actions
# Preprocess obs and states.
state_defined = state is not None
policy = self.get_policy(policy_id)
filtered_obs, filtered_state = [], []
for agent_id, ob in observations.items():
worker = self.workers.local_worker()
preprocessed = worker.preprocessors[policy_id].transform(ob)
filtered = worker.filters[policy_id](preprocessed, update=False)
filtered_obs.append(filtered)
if state is None:
continue
elif agent_id in state:
filtered_state.append(state[agent_id])
else:
filtered_state.append(policy.get_initial_state())
# Batch obs and states
obs_batch = np.stack(filtered_obs)
if state is None:
state = []
else:
state = list(zip(*filtered_state))
state = [np.stack(s) for s in state]
input_dict = {SampleBatch.OBS: obs_batch}
if prev_action:
input_dict[SampleBatch.PREV_ACTIONS] = prev_action
if prev_reward:
input_dict[SampleBatch.PREV_REWARDS] = prev_reward
if info:
input_dict[SampleBatch.INFOS] = info
for i, s in enumerate(state):
input_dict[f"state_in_{i}"] = s
# Batch compute actions
actions, states, infos = policy.compute_actions_from_input_dict(
input_dict=input_dict,
explore=explore,
timestep=timestep,
episodes=episodes,
)
# Unbatch actions for the environment into a multi-agent dict.
single_actions = space_utils.unbatch(actions)
actions = {}
for key, a in zip(observations, single_actions):
# If we work in normalized action space (normalize_actions=True),
# we re-translate here into the env's action space.
if unsquash_actions:
a = space_utils.unsquash_action(a, policy.action_space_struct)
# Clip, according to env's action space.
elif clip_actions:
a = space_utils.clip_action(a, policy.action_space_struct)
actions[key] = a
# Unbatch states into a multi-agent dict.
unbatched_states = {}
for idx, agent_id in enumerate(observations):
unbatched_states[agent_id] = [s[idx] for s in states]
# Return only actions or full tuple
if state_defined or full_fetch:
return actions, unbatched_states, infos
else:
return actions
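A hedged multi-agent sketch (the agent ids are illustrative; `trainer` is a built Trainer and `env` a gym.Env whose observation space matches the Trainer's config):

# compute_actions() expects a dict mapping agent ids to single observations
# and returns a dict mapping the same agent ids to computed actions.
observations = {
    "agent_0": env.observation_space.sample(),
    "agent_1": env.observation_space.sample(),
}
actions = trainer.compute_actions(observations, policy_id="default_policy")
# -> {"agent_0": ..., "agent_1": ...}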
compute_single_action(self, observation=None, state=None, *, prev_action=None, prev_reward=None, info=None, input_dict=None, policy_id='default_policy', full_fetch=False, explore=None, timestep=None, episode=None, unsquash_action=None, clip_action=None, unsquash_actions=-1, clip_actions=-1, **kwargs)
Computes an action for the specified policy on the local worker.
Note that you can also access the policy object through self.get_policy(policy_id) and call compute_single_action() on it directly.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
observation | Union[Any, dict, tuple] | Single (unbatched) observation from the environment. | None |
state | Optional[List[Union[Any, dict, tuple]]] | List of all RNN hidden (single, unbatched) state tensors. | None |
prev_action | Union[Any, dict, tuple] | Single (unbatched) previous action value. | None |
prev_reward | Optional[float] | Single (unbatched) previous reward value. | None |
info | Optional[dict] | Env info dict, if any. | None |
input_dict | Optional[ray.rllib.policy.sample_batch.SampleBatch] | An optional SampleBatch that holds all the values for: obs, state, prev_action, and prev_reward, plus maybe custom defined views of the current env trajectory. Note that only one of `obs` or `input_dict` must be non-None. | None |
policy_id | str | Policy to query (only applies to multi-agent). Default: "default_policy". | 'default_policy' |
full_fetch | bool | Whether to return extra action fetch results. This is always set to True if `state` is specified. | False |
explore | Optional[bool] | Whether to apply exploration to the action. Default: None -> use self.config["explore"]. | None |
timestep | Optional[int] | The current (sampling) time step. | None |
episode | Optional[ray.rllib.evaluation.episode.Episode] | This provides access to all of the internal episodes' state, which may be useful for model-based or multi-agent algorithms. | None |
unsquash_action | Optional[bool] | Should actions be unsquashed according to the env's/Policy's action space? If None, use the value of self.config["normalize_actions"]. | None |
clip_action | Optional[bool] | Should actions be clipped according to the env's/Policy's action space? If None, use the value of self.config["clip_actions"]. | None |
Returns:
Type | Description |
---|---|
Union[Any, dict, tuple, Tuple[Union[Any, dict, tuple], List[Any], Dict[str, Any]]] | The computed action if full_fetch=False, or the full output of policy.compute_actions() if full_fetch=True or we have an RNN-based Policy. |
Exceptions:
Type | Description |
---|---|
KeyError | If the `policy_id` cannot be found in this Trainer's local worker. |
Source code in ray/rllib/agents/trainer.py
@PublicAPI
def compute_single_action(
self,
observation: Optional[TensorStructType] = None,
state: Optional[List[TensorStructType]] = None,
*,
prev_action: Optional[TensorStructType] = None,
prev_reward: Optional[float] = None,
info: Optional[EnvInfoDict] = None,
input_dict: Optional[SampleBatch] = None,
policy_id: PolicyID = DEFAULT_POLICY_ID,
full_fetch: bool = False,
explore: Optional[bool] = None,
timestep: Optional[int] = None,
episode: Optional[Episode] = None,
unsquash_action: Optional[bool] = None,
clip_action: Optional[bool] = None,
# Deprecated args.
unsquash_actions=DEPRECATED_VALUE,
clip_actions=DEPRECATED_VALUE,
# Kwargs placeholder for future compatibility.
**kwargs,
) -> Union[TensorStructType, Tuple[TensorStructType, List[TensorType],
Dict[str, TensorType]]]:
"""Computes an action for the specified policy on the local worker.
Note that you can also access the policy object through
self.get_policy(policy_id) and call compute_single_action() on it
directly.
Args:
observation: Single (unbatched) observation from the
environment.
state: List of all RNN hidden (single, unbatched) state tensors.
prev_action: Single (unbatched) previous action value.
prev_reward: Single (unbatched) previous reward value.
info: Env info dict, if any.
input_dict: An optional SampleBatch that holds all the values
for: obs, state, prev_action, and prev_reward, plus maybe
custom defined views of the current env trajectory. Note
that only one of `obs` or `input_dict` must be non-None.
policy_id: Policy to query (only applies to multi-agent).
Default: "default_policy".
full_fetch: Whether to return extra action fetch results.
This is always set to True if `state` is specified.
explore: Whether to apply exploration to the action.
Default: None -> use self.config["explore"].
timestep: The current (sampling) time step.
episode: This provides access to all of the internal episodes'
state, which may be useful for model-based or multi-agent
algorithms.
unsquash_action: Should actions be unsquashed according to the
env's/Policy's action space? If None, use the value of
self.config["normalize_actions"].
clip_action: Should actions be clipped according to the
env's/Policy's action space? If None, use the value of
self.config["clip_actions"].
Keyword Args:
kwargs: forward compatibility placeholder
Returns:
The computed action if full_fetch=False, or a tuple of a) the
full output of policy.compute_actions() if full_fetch=True
or we have an RNN-based Policy.
Raises:
KeyError: If the `policy_id` cannot be found in this Trainer's
local worker.
"""
if clip_actions != DEPRECATED_VALUE:
deprecation_warning(
old="Trainer.compute_single_action(`clip_actions`=...)",
new="Trainer.compute_single_action(`clip_action`=...)",
error=False)
clip_action = clip_actions
if unsquash_actions != DEPRECATED_VALUE:
deprecation_warning(
old="Trainer.compute_single_action(`unsquash_actions`=...)",
new="Trainer.compute_single_action(`unsquash_action`=...)",
error=False)
unsquash_action = unsquash_actions
# User provided an input-dict: Assert that `obs`, `prev_a|r`, `state`
# are all None.
err_msg = "Provide either `input_dict` OR [`observation`, ...] as " \
"args to Trainer.compute_single_action!"
if input_dict is not None:
assert observation is None and prev_action is None and \
prev_reward is None and state is None, err_msg
observation = input_dict[SampleBatch.OBS]
else:
assert observation is not None, err_msg
# Get the policy to compute the action for (in the multi-agent case,
# Trainer may hold >1 policies).
policy = self.get_policy(policy_id)
if policy is None:
raise KeyError(
f"PolicyID '{policy_id}' not found in PolicyMap of the "
f"Trainer's local worker!")
local_worker = self.workers.local_worker()
# Check the preprocessor and preprocess, if necessary.
pp = local_worker.preprocessors[policy_id]
if pp and type(pp).__name__ != "NoPreprocessor":
observation = pp.transform(observation)
observation = local_worker.filters[policy_id](
observation, update=False)
# Input-dict.
if input_dict is not None:
input_dict[SampleBatch.OBS] = observation
action, state, extra = policy.compute_single_action(
input_dict=input_dict,
explore=explore,
timestep=timestep,
episode=episode,
)
# Individual args.
else:
action, state, extra = policy.compute_single_action(
obs=observation,
state=state,
prev_action=prev_action,
prev_reward=prev_reward,
info=info,
explore=explore,
timestep=timestep,
episode=episode,
)
# If we work in normalized action space (normalize_actions=True),
# we re-translate here into the env's action space.
if unsquash_action:
action = space_utils.unsquash_action(action,
policy.action_space_struct)
# Clip, according to env's action space.
elif clip_action:
action = space_utils.clip_action(action,
policy.action_space_struct)
# Return 3-Tuple: Action, states, and extra-action fetches.
if state or full_fetch:
return action, state, extra
# Ensure backward compatibility.
else:
return action
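A hedged single-agent sketch, assuming a Trainer built for the gym CartPole-v0 environment:

import gym

env = gym.make("CartPole-v0")
obs = env.reset()
done = False
while not done:
    # Deterministic (greedy) action from the local worker's default policy.
    action = trainer.compute_single_action(obs, explore=False)
    obs, reward, done, info = env.step(action)

# With full_fetch=True, the RNN state-outs and extra action fetches
# (e.g. action distribution inputs) are returned as well.
action, state_outs, extra_fetches = trainer.compute_single_action(
    obs, full_fetch=True)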
default_resource_request(config)
classmethod
Provides a static resource requirement for the given configuration.
This can be overridden by sub-classes to set the correct trial resource allocation, so the user does not need to.
.. code-block:: python
@classmethod
def default_resource_request(cls, config):
return PlacementGroupFactory([{"CPU": 1}, {"CPU": 1}])
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config | Dict[str, Any] | The Trainable's config dict. | required |
Returns:
Type | Description |
---|---|
Union[Resources, PlacementGroupFactory] | A Resources object or PlacementGroupFactory consumed by Tune for queueing. |
Source code in ray/rllib/agents/trainer.py
@classmethod
@override(Trainable)
def default_resource_request(
cls, config: PartialTrainerConfigDict) -> \
Union[Resources, PlacementGroupFactory]:
# Default logic for RLlib algorithms (Trainers):
# Create one bundle per individual worker (local or remote).
# Use `num_cpus_for_driver` and `num_gpus` for the local worker and
# `num_cpus_per_worker` and `num_gpus_per_worker` for the remote
# workers to determine their CPU/GPU resource needs.
# Convenience config handles.
cf = dict(cls.get_default_config(), **config)
eval_cf = cf["evaluation_config"]
# TODO(ekl): add custom resources here once tune supports them
# Return PlacementGroupFactory containing all needed resources
# (already properly defined as device bundles).
return PlacementGroupFactory(
bundles=[{
# Local worker.
"CPU": cf["num_cpus_for_driver"],
"GPU": 0 if cf["_fake_gpus"] else cf["num_gpus"],
}] + [
{
# RolloutWorkers.
"CPU": cf["num_cpus_per_worker"],
"GPU": cf["num_gpus_per_worker"],
} for _ in range(cf["num_workers"])
] + ([
{
# Evaluation workers.
# Note: The local eval worker is located on the driver CPU.
"CPU": eval_cf.get("num_cpus_per_worker",
cf["num_cpus_per_worker"]),
"GPU": eval_cf.get("num_gpus_per_worker",
cf["num_gpus_per_worker"]),
} for _ in range(cf["evaluation_num_workers"])
] if cf["evaluation_interval"] else []),
strategy=config.get("placement_strategy", "PACK"))
evaluate(self, episodes_left_fn=None)
Evaluates current policy under evaluation_config settings.
Note that this default implementation does not do anything beyond merging evaluation_config with the normal trainer config.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
episodes_left_fn | Optional[Callable[[int], int]] | An optional callable taking the already run num episodes as only arg and returning the number of episodes left to run. It's used to find out whether evaluation should continue. | None |
Source code in ray/rllib/agents/trainer.py
@PublicAPI
def evaluate(self, episodes_left_fn: Optional[Callable[[int], int]] = None
) -> dict:
"""Evaluates current policy under `evaluation_config` settings.
Note that this default implementation does not do anything beyond
merging evaluation_config with the normal trainer config.
Args:
episodes_left_fn: An optional callable taking the already run
num episodes as only arg and returning the number of
episodes left to run. It's used to find out whether
evaluation should continue.
"""
# In case we are evaluating (in a thread) parallel to training,
# we may have to re-enable eager mode here (gets disabled in the
# thread).
if self.config.get("framework") in ["tf2", "tfe"] and \
not tf.executing_eagerly():
tf1.enable_eager_execution()
# Call the `_before_evaluate` hook.
self._before_evaluate()
# Sync weights to the evaluation WorkerSet.
if self.evaluation_workers is not None:
self._sync_weights_to_workers(worker_set=self.evaluation_workers)
self._sync_filters_if_needed(self.evaluation_workers)
if self.config["custom_eval_function"]:
logger.info("Running custom eval function {}".format(
self.config["custom_eval_function"]))
metrics = self.config["custom_eval_function"](
self, self.evaluation_workers)
if not metrics or not isinstance(metrics, dict):
raise ValueError("Custom eval function must return "
"dict of metrics, got {}.".format(metrics))
else:
# How many episodes do we need to run?
# In "auto" mode (only for parallel eval + training): Run one
# episode per eval worker.
num_episodes = self.config["evaluation_num_episodes"] if \
self.config["evaluation_num_episodes"] != "auto" else \
(self.config["evaluation_num_workers"] or 1)
# Default done-function returns True, whenever num episodes
# have been completed.
if episodes_left_fn is None:
def episodes_left_fn(num_episodes_done):
return num_episodes - num_episodes_done
logger.info(
f"Evaluating current policy for {num_episodes} episodes.")
metrics = None
# No evaluation worker set ->
# Do evaluation using the local worker. Expect error due to the
# local worker not having an env.
if self.evaluation_workers is None:
try:
for _ in range(num_episodes):
self.workers.local_worker().sample()
metrics = collect_metrics(self.workers.local_worker())
except ValueError as e:
if "RolloutWorker has no `input_reader` object" in \
e.args[0]:
raise ValueError(
"Cannot evaluate w/o an evaluation worker set in "
"the Trainer or w/o an env on the local worker!\n"
"Try one of the following:\n1) Set "
"`evaluation_interval` >= 0 to force creating a "
"separate evaluation worker set.\n2) Set "
"`create_env_on_driver=True` to force the local "
"(non-eval) worker to have an environment to "
"evaluate on.")
else:
raise e
# Evaluation worker set only has local worker.
elif self.config["evaluation_num_workers"] == 0:
for _ in range(num_episodes):
self.evaluation_workers.local_worker().sample()
# Evaluation worker set has n remote workers.
else:
# How many episodes have we run (across all eval workers)?
num_episodes_done = 0
round_ = 0
while True:
episodes_left_to_do = episodes_left_fn(num_episodes_done)
if episodes_left_to_do <= 0:
break
round_ += 1
batches = ray.get([
w.sample.remote() for i, w in enumerate(
self.evaluation_workers.remote_workers())
if i < episodes_left_to_do
])
# Per our config for the evaluation workers
# (`rollout_fragment_length=1` and
# `batch_mode=complete_episode`), we know that we'll have
# exactly one episode per returned batch.
num_episodes_done += len(batches)
logger.info(
f"Ran round {round_} of parallel evaluation "
f"({num_episodes_done}/{num_episodes} episodes done)")
if metrics is None:
metrics = collect_metrics(
self.evaluation_workers.local_worker(),
self.evaluation_workers.remote_workers())
return {"evaluation": metrics}
export_policy_checkpoint(self, export_dir, filename_prefix='model', policy_id='default_policy')
Exports policy model checkpoint to a local directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
export_dir | str | Writable local directory. | required |
filename_prefix | str | File name prefix of checkpoint files. | 'model' |
policy_id | str | Optional policy id to export. | 'default_policy' |
Examples:
>>> trainer = MyTrainer()
>>> for _ in range(10):
>>> trainer.train()
>>> trainer.export_policy_checkpoint("/tmp/export_dir")
Source code in ray/rllib/agents/trainer.py
@DeveloperAPI
def export_policy_checkpoint(
self,
export_dir: str,
filename_prefix: str = "model",
policy_id: PolicyID = DEFAULT_POLICY_ID,
) -> None:
"""Exports policy model checkpoint to a local directory.
Args:
export_dir: Writable local directory.
filename_prefix: file name prefix of checkpoint files.
policy_id: Optional policy id to export.
Example:
>>> trainer = MyTrainer()
>>> for _ in range(10):
>>> trainer.train()
>>> trainer.export_policy_checkpoint("/tmp/export_dir")
"""
self.get_policy(policy_id).export_checkpoint(export_dir,
filename_prefix)
export_policy_model(self, export_dir, policy_id='default_policy', onnx=None)
Exports policy model with given policy_id to a local directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
export_dir | str | Writable local directory. | required |
policy_id | str | Optional policy id to export. | 'default_policy' |
onnx | Optional[int] | If given, will export model in ONNX format. The value of this parameter sets the ONNX OpSet version to use. If None, the output format will be DL framework specific. | None |
Examples:
>>> trainer = MyTrainer()
>>> for _ in range(10):
>>> trainer.train()
>>> trainer.export_policy_model("/tmp/dir")
>>> trainer.export_policy_model("/tmp/dir/onnx", onnx=1)
Source code in ray/rllib/agents/trainer.py
@DeveloperAPI
def export_policy_model(self,
export_dir: str,
policy_id: PolicyID = DEFAULT_POLICY_ID,
onnx: Optional[int] = None) -> None:
"""Exports policy model with given policy_id to a local directory.
Args:
export_dir: Writable local directory.
policy_id: Optional policy id to export.
onnx: If given, will export model in ONNX format. The
value of this parameter set the ONNX OpSet version to use.
If None, the output format will be DL framework specific.
Example:
>>> trainer = MyTrainer()
>>> for _ in range(10):
>>> trainer.train()
>>> trainer.export_policy_model("/tmp/dir")
>>> trainer.export_policy_model("/tmp/dir/onnx", onnx=1)
"""
self.get_policy(policy_id).export_model(export_dir, onnx)
get_default_policy_class(self, config)
Returns a default Policy class to use, given a config.
This class will be used inside RolloutWorkers' PolicyMaps in case the policy class is not provided by the user in any single- or multi-agent PolicySpec.
This method is experimental and currently only used if the Trainer class was not created using the build_trainer utility and if the Trainer sub-class does not override _init() and create its own WorkerSet in _init().
Source code in ray/rllib/agents/trainer.py
@ExperimentalAPI
def get_default_policy_class(self, config: PartialTrainerConfigDict):
"""Returns a default Policy class to use, given a config.
This class will be used inside RolloutWorkers' PolicyMaps in case
the policy class is not provided by the user in any single- or
multi-agent PolicySpec.
This method is experimental and currently only used, iff the Trainer
class was not created using the `build_trainer` utility and if
the Trainer sub-class does not override `_init()` and create it's
own WorkerSet in `_init()`.
"""
return getattr(self, "_policy_class", None)
get_policy(self, policy_id='default_policy')
Return policy for the specified id, or None.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
policy_id | str | ID of the policy to return. | 'default_policy' |
get_weights(self, policies=None)
Return a dictionary of policy ids to weights.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
policies | Optional[List[str]] | Optional list of policies to return weights for, or None for all policies. | None |
import_model(self, import_file)
Imports a model from import_file.
Note: Currently, only h5 files are supported.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
import_file | str | The file to import the model from. | required |
Returns:
Type | Description |
---|---|
 | A dict that maps ExportFormats to successfully exported models. |
Source code in ray/rllib/agents/trainer.py
def import_model(self, import_file: str):
"""Imports a model from import_file.
Note: Currently, only h5 files are supported.
Args:
import_file (str): The file to import the model from.
Returns:
A dict that maps ExportFormats to successfully exported models.
"""
# Check for existence.
if not os.path.exists(import_file):
raise FileNotFoundError(
"`import_file` '{}' does not exist! Can't import Model.".
format(import_file))
# Get the format of the given file.
import_format = "h5" # TODO(sven): Support checkpoint loading.
ExportFormat.validate([import_format])
if import_format != ExportFormat.H5:
raise NotImplementedError
else:
return self.import_policy_model_from_h5(import_file)
import_policy_model_from_h5(self, import_file, policy_id='default_policy')
Imports a policy's model with given policy_id from a local h5 file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
import_file | str | The h5 file to import from. | required |
policy_id | str | Optional policy id to import into. | 'default_policy' |
Examples:
>>> trainer = MyTrainer()
>>> trainer.import_policy_model_from_h5("/tmp/weights.h5")
>>> for _ in range(10):
>>> trainer.train()
Source code in ray/rllib/agents/trainer.py
@DeveloperAPI
def import_policy_model_from_h5(
self,
import_file: str,
policy_id: PolicyID = DEFAULT_POLICY_ID,
) -> None:
"""Imports a policy's model with given policy_id from a local h5 file.
Args:
import_file: The h5 file to import from.
policy_id: Optional policy id to import into.
Example:
>>> trainer = MyTrainer()
>>> trainer.import_policy_model_from_h5("/tmp/weights.h5")
>>> for _ in range(10):
>>> trainer.train()
"""
self.get_policy(policy_id).import_model_from_h5(import_file)
# Sync new weights to remote workers.
self._sync_weights_to_workers(worker_set=self.workers)
load_checkpoint(self, checkpoint_path)
Subclasses should override this to implement restore().
!!! warning
    In this method, do not rely on absolute paths. The absolute path of the checkpoint_dir used in Trainable.save_checkpoint may be changed. If Trainable.save_checkpoint returned a prefixed string, the prefix of the checkpoint string returned by Trainable.save_checkpoint may be changed. This is because trial pausing depends on temporary directories. The directory structure under the checkpoint_dir provided to Trainable.save_checkpoint is preserved. See the example below.
.. code-block:: python
class Example(Trainable):
def save_checkpoint(self, checkpoint_path):
print(checkpoint_path)
return os.path.join(checkpoint_path, "my/check/point")
def load_checkpoint(self, checkpoint):
print(checkpoint)
>>> trainer = Example()
>>> obj = trainer.save_to_object() # This is used when PAUSED.
<logdir>/tmpc8k_c_6hsave_to_object/checkpoint_0/my/check/point
>>> trainer.restore_from_object(obj) # Note the different prefix.
<logdir>/tmpb87b5axfrestore_from_object/checkpoint_0/my/check/point
.. versionadded:: 0.8.7
Parameters:
Name | Type | Description | Default |
---|---|---|---|
checkpoint | str or dict | If dict, the return value is as returned by save_checkpoint. | required |
log_result(self, result)
Subclasses can optionally override this to customize logging.
The logging here is done on the worker process rather than the driver. You may want to turn off driver logging via the loggers parameter in tune.run when overriding this function.
.. versionadded:: 0.8.7
Parameters:
Name | Type | Description | Default |
---|---|---|---|
result | dict | Training result returned by step(). | required |
Source code in ray/rllib/agents/trainer.py
@override(Trainable)
def log_result(self, result: ResultDict) -> None:
# Log after the callback is invoked, so that the user has a chance
# to mutate the result.
self.callbacks.on_train_result(trainer=self, result=result)
# Then log according to Trainable's logging logic.
Trainable.log_result(self, result)
remove_policy(self, policy_id='default_policy', *, policy_mapping_fn=None, policies_to_train=None, evaluation_workers=True)
Removes a policy from this Trainer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
policy_id | Optional[PolicyID] | ID of the policy to be removed. | 'default_policy' |
policy_mapping_fn | Optional[Callable[[AgentID], PolicyID]] | An optional (updated) policy mapping function to use from here on. Note that already ongoing episodes will not change their mapping but will use the old mapping till the end of the episode. | None |
policies_to_train | Optional[List[PolicyID]] | An optional list of policy IDs to be trained. If None, will keep the existing list in place. Policies whose IDs are not in the list will not be updated. | None |
evaluation_workers | bool | Whether to also remove the policy from the evaluation WorkerSet. | True |
Source code in ray/rllib/agents/trainer.py
@PublicAPI
def remove_policy(
self,
policy_id: PolicyID = DEFAULT_POLICY_ID,
*,
policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
policies_to_train: Optional[List[PolicyID]] = None,
evaluation_workers: bool = True,
) -> None:
"""Removes a new policy from this Trainer.
Args:
policy_id (Optional[PolicyID]): ID of the policy to be removed.
policy_mapping_fn (Optional[Callable[[AgentID], PolicyID]]): An
optional (updated) policy mapping function to use from here on.
Note that already ongoing episodes will not change their
mapping but will use the old mapping till the end of the
episode.
policies_to_train (Optional[List[PolicyID]]): An optional list of
policy IDs to be trained. If None, will keep the existing list
in place. Policies, whose IDs are not in the list will not be
updated.
evaluation_workers (bool): Whether to also remove the policy from
the evaluation WorkerSet.
"""
def fn(worker):
worker.remove_policy(
policy_id=policy_id,
policy_mapping_fn=policy_mapping_fn,
policies_to_train=policies_to_train,
)
self.workers.foreach_worker(fn)
if evaluation_workers and self.evaluation_workers is not None:
self.evaluation_workers.foreach_worker(fn)
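A hedged sketch, mirroring the add_policy() example above (the ids and the mapping function are illustrative):

# Remove the previously added policy again and route all agents back
# to the default policy.
trainer.remove_policy(
    policy_id="learnable_policy",
    policy_mapping_fn=lambda agent_id, *args, **kwargs: "default_policy",
    policies_to_train=["default_policy"],
)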
resource_help(config)
classmethod
Returns a help string for configuring this trainable's resources.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config | dict | The Trainer's config dict. | required |
Source code in ray/rllib/agents/trainer.py
@classmethod
@override(Trainable)
def resource_help(cls, config: TrainerConfigDict) -> str:
return ("\n\nYou can adjust the resource requests of RLlib agents by "
"setting `num_workers`, `num_gpus`, and other configs. See "
"the DEFAULT_CONFIG defined by each agent for more info.\n\n"
"The config of this agent is: {}".format(config))
save_checkpoint(self, checkpoint_dir)
Subclasses should override this to implement save().
!!! warning
    Do not rely on absolute paths in the implementation of Trainable.save_checkpoint and Trainable.load_checkpoint.
Use validate_save_restore to catch Trainable.save_checkpoint / Trainable.load_checkpoint errors before execution.
>>> from ray.tune.utils import validate_save_restore
>>> validate_save_restore(MyTrainableClass)
>>> validate_save_restore(MyTrainableClass, use_object_store=True)
.. versionadded:: 0.8.7
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tmp_checkpoint_dir | str | The directory where the checkpoint file must be stored. In a Tune run, if the trial is paused, the provided path may be temporary and moved. | required |
Returns:
Type | Description |
---|---|
str | A dict or string. If string, the return value is expected to be prefixed by tmp_checkpoint_dir. |
Examples:
>>> print(trainable1.save_checkpoint("/tmp/checkpoint_1"))
"/tmp/checkpoint_1/my_checkpoint_file"
>>> print(trainable2.save_checkpoint("/tmp/checkpoint_2"))
{"some": "data"}
set_weights(self, weights)
Set policy weights by policy id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
weights | Dict[str, dict] | Map of policy ids to weights to set. | required |
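A hedged sketch of transferring weights between two Trainers (`trainer` and `other_trainer` are assumed to already exist and to hold policies with compatible model architectures):

# Read the weights of the default policy from one Trainer and load them
# into another one, e.g. to warm-start a new training session.
weights = trainer.get_weights(["default_policy"])
other_trainer.set_weights(weights)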
setup(self, config)
Subclasses should override this for custom initialization.
.. versionadded:: 0.8.7
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config | dict | Hyperparameters and other configs given. Copy of self.config. | required |
Source code in ray/rllib/agents/trainer.py
@override(Trainable)
def setup(self, config: PartialTrainerConfigDict):
# Setup our config: Merge the user-supplied config (which could
# be a partial config dict with the class' default).
self.config = self.merge_trainer_configs(
self.get_default_config(), config, self._allow_unknown_configs)
# Setup the "env creator" callable.
env = self._env_id
if env:
self.config["env"] = env
# An already registered env.
if _global_registry.contains(ENV_CREATOR, env):
self.env_creator = _global_registry.get(ENV_CREATOR, env)
# A class path specifier.
elif "." in env:
def env_creator_from_classpath(env_context):
try:
env_obj = from_config(env, env_context)
except ValueError:
raise EnvError(
ERR_MSG_INVALID_ENV_DESCRIPTOR.format(env))
return env_obj
self.env_creator = env_creator_from_classpath
# Try gym/PyBullet/Vizdoom.
else:
self.env_creator = functools.partial(
gym_env_creator, env_descriptor=env)
# No env -> Env creator always returns None.
else:
self.env_creator = lambda env_config: None
# Check and resolve DL framework settings.
# Tf-eager (tf2|tfe), possibly with tracing set to True. Recommend
# setting tracing to True for speedups.
if tf1 and self.config["framework"] in ["tf2", "tfe"]:
if self.config["framework"] == "tf2" and tfv < 2:
raise ValueError(
"You configured `framework`=tf2, but your installed pip "
"tf-version is < 2.0! Make sure your TensorFlow version "
"is >= 2.x.")
if not tf1.executing_eagerly():
tf1.enable_eager_execution()
logger.info(
f"Executing eagerly (framework='{self.config['framework']}'),"
f" with eager_tracing={self.config['eager_tracing']}. For "
"production workloads, make sure to set `eager_tracing=True` "
"in order to match the speed of tf-static-graph "
"(framework='tf'). For debugging purposes, "
"`eager_tracing=False` is the best choice.")
# Tf-static-graph (framework=tf): Recommend upgrading to tf2 and
# enabling eager tracing for similar speed.
elif tf1 and self.config["framework"] == "tf":
logger.info(
"Your framework setting is 'tf', meaning you are using static"
"-graph mode. Set framework='tf2' to enable eager execution "
"with tf2.x. You may also want to then set "
"`eager_tracing=True` in order to reach similar execution "
"speed as with static-graph mode.")
# Set Trainer's seed after we have - if necessary - enabled
# tf eager-execution.
update_global_seed_if_necessary(
config.get("framework"), config.get("seed"))
self._validate_config(self.config, trainer_obj_or_none=self)
if not callable(self.config["callbacks"]):
raise ValueError(
"`callbacks` must be a callable method that "
"returns a subclass of DefaultCallbacks, got {}".format(
self.config["callbacks"]))
self.callbacks = self.config["callbacks"]()
log_level = self.config.get("log_level")
if log_level in ["WARN", "ERROR"]:
logger.info("Current log_level is {}. For more information, "
"set 'log_level': 'INFO' / 'DEBUG' or use the -v and "
"-vv flags.".format(log_level))
if self.config.get("log_level"):
logging.getLogger("ray.rllib").setLevel(self.config["log_level"])
# Create local replay buffer if necessary.
self.local_replay_buffer = (
self._create_local_replay_buffer_if_necessary(self.config))
# Deprecated way of implementing Trainer sub-classes (or "templates"
# via the soon-to-be deprecated `build_trainer` utility function).
# Instead, sub-classes should override the Trainable's `setup()`
# method and call super().setup() from within that override at some
# point.
self.workers = None
self.train_exec_impl = None
# Old design: Override `Trainer._init` (or use `build_trainer()`, which
# will do this for you).
try:
self._init(self.config, self.env_creator)
# New design: Override `Trainable.setup()` (as indented by Trainable)
# and do or don't call super().setup() from within your override.
# By default, `super().setup()` will create both worker sets:
# "rollout workers" for collecting samples for training and - if
# applicable - "evaluation workers" for evaluation runs in between or
# parallel to training.
# TODO: Deprecate `_init()` and remove this try/except block.
except NotImplementedError:
# Only if user did not override `_init()`:
# - Create rollout workers here automatically.
# - Run the execution plan to create the local iterator to `next()`
# in each training iteration.
# This matches the behavior of using `build_trainer()`, which
# should no longer be used.
self.workers = self._make_workers(
env_creator=self.env_creator,
validate_env=self.validate_env,
policy_class=self.get_default_policy_class(self.config),
config=self.config,
num_workers=self.config["num_workers"])
self.train_exec_impl = self.execution_plan(
self.workers, self.config, **self._kwargs_for_execution_plan())
# Evaluation WorkerSet setup.
self.evaluation_workers = None
self.evaluation_metrics = {}
# User would like to setup a separate evaluation worker set.
if self.config.get("evaluation_num_workers", 0) > 0 or \
self.config.get("evaluation_interval"):
# Update env_config with evaluation settings:
extra_config = copy.deepcopy(self.config["evaluation_config"])
# Assert that user has not unset "in_evaluation".
assert "in_evaluation" not in extra_config or \
extra_config["in_evaluation"] is True
evaluation_config = merge_dicts(self.config, extra_config)
# Validate evaluation config.
self._validate_config(evaluation_config, trainer_obj_or_none=self)
# Switch on complete_episode rollouts (evaluations are
# always done on n complete episodes) and set the
# `in_evaluation` flag. Also, make sure our rollout fragments
# are short so we don't have more than one episode in one rollout.
evaluation_config.update({
"batch_mode": "complete_episodes",
"rollout_fragment_length": 1,
"in_evaluation": True,
})
logger.debug("using evaluation_config: {}".format(extra_config))
# Create a separate evaluation worker set for evaluation.
# If evaluation_num_workers=0, use the evaluation set's local
# worker for evaluation, otherwise, use its remote workers
# (parallelized evaluation).
self.evaluation_workers = self._make_workers(
env_creator=self.env_creator,
validate_env=None,
policy_class=self.get_default_policy_class(self.config),
config=evaluation_config,
num_workers=self.config["evaluation_num_workers"])
step(self)
Implements the main Trainer.train() logic.
Takes n attempts to perform a single training step. Thereby catches RayErrors resulting from worker failures. After n attempts, fails gracefully.
Override this method in your Trainer sub-classes if you would like to handle worker failures yourself. Otherwise, override self.step_attempt() to keep the n attempts (catch worker failures).
Returns:
Type | Description |
---|---|
dict | The results dict with stats/infos on sampling, training, and - if required - evaluation. |
Source code in ray/rllib/agents/trainer.py
@override(Trainable)
def step(self) -> ResultDict:
"""Implements the main `Trainer.train()` logic.
Takes n attempts to perform a single training step. Thereby
catches RayErrors resulting from worker failures. After n attempts,
fails gracefully.
Override this method in your Trainer sub-classes if you would like to
handle worker failures yourself. Otherwise, override
`self.step_attempt()` to keep the n attempts (catch worker failures).
Returns:
The results dict with stats/infos on sampling, training,
and - if required - evaluation.
"""
result = None
for _ in range(1 + MAX_WORKER_FAILURE_RETRIES):
# Try to train one step.
try:
result = self.step_attempt()
# @ray.remote RolloutWorker failure -> Try to recover,
# if necessary.
except RayError as e:
if self.config["ignore_worker_failures"]:
logger.exception(
"Error in train call, attempting to recover")
self.try_recover_from_step_attempt()
else:
logger.info(
"Worker crashed during call to train(). To attempt to "
"continue training without the failed worker, set "
"`'ignore_worker_failures': True`.")
raise e
# Any other exception.
except Exception as e:
# Allow logs messages to propagate.
time.sleep(0.5)
raise e
else:
break
# Still no result (even after n retries).
if result is None:
raise RuntimeError("Failed to recover from worker crash.")
if hasattr(self, "workers") and isinstance(self.workers, WorkerSet):
self._sync_filters_if_needed(self.workers)
return result
step_attempt(self)
Attempts a single training step, including evaluation, if required.
Override this method in your Trainer sub-classes if you would like to keep the n attempts (catch worker failures) or override step() directly if you would like to handle worker failures yourself.
Returns:
Type | Description |
---|---|
dict | The results dict with stats/infos on sampling, training, and - if required - evaluation. |
Source code in ray/rllib/agents/trainer.py
@ExperimentalAPI
def step_attempt(self) -> ResultDict:
"""Attempts a single training step, including evaluation, if required.
Override this method in your Trainer sub-classes if you would like to
keep the n attempts (catch worker failures) or override `step()`
directly if you would like to handle worker failures yourself.
Returns:
The results dict with stats/infos on sampling, training,
and - if required - evaluation.
"""
# self._iteration gets incremented after this function returns,
# meaning that e. g. the first time this function is called,
# self._iteration will be 0.
evaluate_this_iter = \
self.config["evaluation_interval"] and \
(self._iteration + 1) % self.config["evaluation_interval"] == 0
# No evaluation necessary, just run the next training iteration.
if not evaluate_this_iter:
step_results = next(self.train_exec_impl)
# We have to evaluate in this training iteration.
else:
# No parallelism.
if not self.config["evaluation_parallel_to_training"]:
step_results = next(self.train_exec_impl)
# Kick off evaluation-loop (and parallel train() call,
# if requested).
# Parallel eval + training.
if self.config["evaluation_parallel_to_training"]:
with concurrent.futures.ThreadPoolExecutor() as executor:
train_future = executor.submit(
lambda: next(self.train_exec_impl))
if self.config["evaluation_num_episodes"] == "auto":
# Run at least one `evaluate()` (num_episodes_done
# must be > 0), even if the training is very fast.
def episodes_left_fn(num_episodes_done):
if num_episodes_done > 0 and \
train_future.done():
return 0
else:
return self.config["evaluation_num_workers"]
evaluation_metrics = self.evaluate(
episodes_left_fn=episodes_left_fn)
else:
evaluation_metrics = self.evaluate()
# Collect the training results from the future.
step_results = train_future.result()
# Sequential: train (already done above), then eval.
else:
evaluation_metrics = self.evaluate()
# Add evaluation results to train results.
assert isinstance(evaluation_metrics, dict), \
"Trainer.evaluate() needs to return a dict."
step_results.update(evaluation_metrics)
# Check `env_task_fn` for possible update of the env's task.
if self.config["env_task_fn"] is not None:
if not callable(self.config["env_task_fn"]):
raise ValueError(
"`env_task_fn` must be None or a callable taking "
"[train_results, env, env_ctx] as args!")
def fn(env, env_context, task_fn):
new_task = task_fn(step_results, env, env_context)
cur_task = env.get_task()
if cur_task != new_task:
env.set_task(new_task)
fn = functools.partial(fn, task_fn=self.config["env_task_fn"])
self.workers.foreach_env_with_context(fn)
return step_results
try_recover_from_step_attempt(self)
Try to identify and remove any unhealthy workers.
This method is called after an unexpected remote error is encountered from a worker during the call to self.step_attempt() (within self.step()). It issues check requests to all current workers and removes any that respond with error. If no healthy workers remain, an error is raised. Otherwise, tries to re-build the execution plan with the remaining (healthy) workers.
Source code in ray/rllib/agents/trainer.py
def try_recover_from_step_attempt(self) -> None:
"""Try to identify and remove any unhealthy workers.
This method is called after an unexpected remote error is encountered
from a worker during the call to `self.step_attempt()` (within
`self.step()`). It issues check requests to all current workers and
removes any that respond with error. If no healthy workers remain,
an error is raised. Otherwise, tries to re-build the execution plan
with the remaining (healthy) workers.
"""
workers = getattr(self, "workers", None)
if not isinstance(workers, WorkerSet):
return
logger.info("Health checking all workers...")
checks = []
for ev in workers.remote_workers():
_, obj_ref = ev.sample_with_count.remote()
checks.append(obj_ref)
healthy_workers = []
for i, obj_ref in enumerate(checks):
w = workers.remote_workers()[i]
try:
ray.get(obj_ref)
healthy_workers.append(w)
logger.info("Worker {} looks healthy".format(i + 1))
except RayError:
logger.exception("Removing unhealthy worker {}".format(i + 1))
try:
w.__ray_terminate__.remote()
except Exception:
logger.exception("Error terminating unhealthy worker")
if len(healthy_workers) < 1:
raise RuntimeError(
"Not enough healthy workers remain to continue.")
logger.warning("Recreating execution plan after failure.")
workers.reset(healthy_workers)
if self.train_exec_impl is not None:
if callable(self.execution_plan):
self.train_exec_impl = self.execution_plan(
workers, self.config, **self._kwargs_for_execution_plan())
validate_env(env, env_context)
staticmethod
Env validator function for this Trainer class.
Override this in child classes to define custom validation behavior.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
env | Any | The (sub-)environment to validate. This is normally a single sub-environment (e.g. a gym.Env) within a vectorized setup. | required |
env_context | EnvContext | The EnvContext to configure the environment. | required |
Source code in ray/rllib/agents/trainer.py
@ExperimentalAPI
@staticmethod
def validate_env(env: EnvType, env_context: EnvContext) -> None:
"""Env validator function for this Trainer class.
Override this in child classes to define custom validation
behavior.
Args:
env: The (sub-)environment to validate. This is normally a
single sub-environment (e.g. a gym.Env) within a vectorized
setup.
env_context: The EnvContext to configure the environment.
Raises:
Exception in case something is wrong with the given environment.
"""
pass
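For illustration, a hedged sketch of overriding validate_env in a custom Trainer sub-class (the Box-space check is purely illustrative, and the rest of the sub-class is omitted):

import gym
from ray.rllib.agents.trainer import Trainer
from ray.rllib.env.env_context import EnvContext
from ray.rllib.utils.typing import EnvType

class MyTrainer(Trainer):
    @staticmethod
    def validate_env(env: EnvType, env_context: EnvContext) -> None:
        # Reject sub-environments whose observation space is not a Box.
        if not isinstance(env.observation_space, gym.spaces.Box):
            raise ValueError("MyTrainer only supports Box observation spaces!")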
ray.rllib.agents.callbacks.DefaultCallbacks
Abstract base class for RLlib callbacks (similar to Keras callbacks).
These callbacks can be used for custom metrics and custom postprocessing.
By default, all of these callbacks are no-ops. To configure custom training callbacks, subclass DefaultCallbacks and then set {"callbacks": YourCallbacksClass} in the trainer config.
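A hedged sketch of such a subclass (the metric name is illustrative):

from ray.rllib.agents.callbacks import DefaultCallbacks

class MyCallbacks(DefaultCallbacks):
    def on_episode_end(self, *, worker, base_env, policies, episode, **kwargs):
        # Store a custom per-episode metric; it will show up (mean/min/max)
        # under `custom_metrics` in the training results.
        episode.custom_metrics["episode_len"] = episode.length

# Plug the class (not an instance) into the trainer config.
config = {"callbacks": MyCallbacks}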
on_episode_end(self, *, worker, base_env, policies, episode, **kwargs)
Runs when an episode is done.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
worker | RolloutWorker | Reference to the current rollout worker. | required |
base_env | BaseEnv | BaseEnv running the episode. The underlying sub environment objects can be retrieved by calling `base_env.get_sub_environments()`. | required |
policies | Dict[str, ray.rllib.policy.policy.Policy] | Mapping of policy id to policy objects. In single agent mode there will only be a single "default_policy". | required |
episode | Episode | Episode object which contains episode state. You can use the `episode.user_data` dict to store temporary data, and `episode.custom_metrics` to store custom metrics for the episode. | required |
kwargs | | Forward compatibility placeholder. | {} |
Source code in ray/rllib/agents/callbacks.py
def on_episode_end(self, *, worker: "RolloutWorker", base_env: BaseEnv,
                   policies: Dict[PolicyID, Policy], episode: Episode,
                   **kwargs) -> None:
    """Runs when an episode is done.

    Args:
        worker: Reference to the current rollout worker.
        base_env: BaseEnv running the episode. The underlying
            sub environment objects can be retrieved by calling
            `base_env.get_sub_environments()`.
        policies: Mapping of policy id to policy
            objects. In single agent mode there will only be a single
            "default_policy".
        episode: Episode object which contains episode
            state. You can use the `episode.user_data` dict to store
            temporary data, and `episode.custom_metrics` to store custom
            metrics for the episode.
        kwargs: Forward compatibility placeholder.
    """
    if self.legacy_callbacks.get("on_episode_end"):
        self.legacy_callbacks["on_episode_end"]({
            "env": base_env,
            "policy": policies,
            "episode": episode,
        })
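A short sketch of a typical use: turning per-step values collected during the episode into a single custom metric. The `pole_angles` key is an assumption; it is populated in the on_episode_step example further below.
import numpy as np
from ray.rllib.agents.callbacks import DefaultCallbacks


class EpisodeEndCallbacks(DefaultCallbacks):
    def on_episode_end(self, *, worker, base_env, policies, episode, **kwargs):
        # Summarize values accumulated in `episode.user_data` during the
        # episode and expose them as a custom metric.
        angles = episode.user_data.get("pole_angles", [])
        if angles:
            episode.custom_metrics["mean_pole_angle"] = float(np.mean(angles))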
on_episode_start(self, *, worker, base_env, policies, episode, **kwargs)
Callback run on the rollout worker before each episode starts.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
worker | RolloutWorker | Reference to the current rollout worker. | required |
base_env | BaseEnv | BaseEnv running the episode. The underlying sub environment objects can be retrieved by calling `base_env.get_sub_environments()`. | required |
policies | Dict[str, ray.rllib.policy.policy.Policy] | Mapping of policy id to policy objects. In single agent mode there will only be a single "default" policy. | required |
episode | Episode | Episode object which contains the episode's state. You can use the `episode.user_data` dict to store temporary data, and `episode.custom_metrics` to store custom metrics for the episode. | required |
kwargs | | Forward compatibility placeholder. | {} |
Source code in ray/rllib/agents/callbacks.py
def on_episode_start(self, *, worker: "RolloutWorker", base_env: BaseEnv,
                     policies: Dict[PolicyID, Policy], episode: Episode,
                     **kwargs) -> None:
    """Callback run on the rollout worker before each episode starts.

    Args:
        worker: Reference to the current rollout worker.
        base_env: BaseEnv running the episode. The underlying
            sub environment objects can be retrieved by calling
            `base_env.get_sub_environments()`.
        policies: Mapping of policy id to policy objects. In single
            agent mode there will only be a single "default" policy.
        episode: Episode object which contains the episode's
            state. You can use the `episode.user_data` dict to store
            temporary data, and `episode.custom_metrics` to store custom
            metrics for the episode.
        kwargs: Forward compatibility placeholder.
    """
    if self.legacy_callbacks.get("on_episode_start"):
        self.legacy_callbacks["on_episode_start"]({
            "env": base_env,
            "policy": policies,
            "episode": episode,
        })
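A minimal sketch: initializing per-episode scratch space that later callbacks can fill in (the key names are illustrative):
from ray.rllib.agents.callbacks import DefaultCallbacks


class EpisodeStartCallbacks(DefaultCallbacks):
    def on_episode_start(self, *, worker, base_env, policies, episode, **kwargs):
        # Runs before the first step of every episode: set up containers
        # that `on_episode_step` / `on_episode_end` will use.
        episode.user_data["pole_angles"] = []
        episode.hist_data["pole_angles"] = []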
on_episode_step(self, *, worker, base_env, policies=None, episode, **kwargs)
Runs on each episode step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
worker | RolloutWorker | Reference to the current rollout worker. | required |
base_env | BaseEnv | BaseEnv running the episode. The underlying sub environment objects can be retrieved by calling `base_env.get_sub_environments()`. | required |
policies | Optional[Dict[str, ray.rllib.policy.policy.Policy]] | Mapping of policy id to policy objects. In single agent mode there will only be a single "default_policy". | None |
episode | Episode | Episode object which contains episode state. You can use the `episode.user_data` dict to store temporary data, and `episode.custom_metrics` to store custom metrics for the episode. | required |
kwargs | | Forward compatibility placeholder. | {} |
Source code in ray/rllib/agents/callbacks.py
def on_episode_step(self,
                    *,
                    worker: "RolloutWorker",
                    base_env: BaseEnv,
                    policies: Optional[Dict[PolicyID, Policy]] = None,
                    episode: Episode,
                    **kwargs) -> None:
    """Runs on each episode step.

    Args:
        worker: Reference to the current rollout worker.
        base_env: BaseEnv running the episode. The underlying
            sub environment objects can be retrieved by calling
            `base_env.get_sub_environments()`.
        policies: Mapping of policy id
            to policy objects. In single agent mode there will only be a
            single "default_policy".
        episode: Episode object which contains episode
            state. You can use the `episode.user_data` dict to store
            temporary data, and `episode.custom_metrics` to store custom
            metrics for the episode.
        kwargs: Forward compatibility placeholder.
    """
    if self.legacy_callbacks.get("on_episode_step"):
        self.legacy_callbacks["on_episode_step"]({
            "env": base_env,
            "episode": episode
        })
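A minimal sketch, assuming a CartPole-like environment whose observation's third entry is the pole angle (adapt the indexing to your own observations):
from ray.rllib.agents.callbacks import DefaultCallbacks


class EpisodeStepCallbacks(DefaultCallbacks):
    def on_episode_step(self, *, worker, base_env, policies=None, episode,
                        **kwargs):
        # Record one value per environment step into the episode's
        # user_data (initialized in `on_episode_start`).
        obs = episode.last_observation_for()
        if obs is not None:
            episode.user_data.setdefault("pole_angles", []).append(abs(obs[2]))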
on_learn_on_batch(self, *, policy, train_batch, result, **kwargs)
Called at the beginning of Policy.learn_on_batch().
Note: This is called before 0-padding via pad_batch_to_sequences_of_same_size.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
policy | Policy | Reference to the current Policy object. | required |
train_batch | SampleBatch | SampleBatch to be trained on. You can mutate this object to modify the samples generated. | required |
result | dict | A results dict to add custom metrics to. | required |
kwargs | | Forward compatibility placeholder. | {} |
Source code in ray/rllib/agents/callbacks.py
def on_learn_on_batch(self, *, policy: Policy, train_batch: SampleBatch,
                      result: dict, **kwargs) -> None:
    """Called at the beginning of Policy.learn_on_batch().

    Note: This is called before 0-padding via
    `pad_batch_to_sequences_of_same_size`.

    Args:
        policy: Reference to the current Policy object.
        train_batch: SampleBatch to be trained on. You can
            mutate this object to modify the samples generated.
        result: A results dict to add custom metrics to.
        kwargs: Forward compatibility placeholder.
    """
    pass
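A minimal sketch: inspecting the train batch and attaching a custom metric to the learner results (the metric name is illustrative):
import numpy as np
from ray.rllib.agents.callbacks import DefaultCallbacks


class LearnOnBatchCallbacks(DefaultCallbacks):
    def on_learn_on_batch(self, *, policy, train_batch, result, **kwargs):
        # Count how many non-zero rewards this train batch contains and
        # report it alongside the usual learner stats.
        result["num_nonzero_rewards"] = int(
            np.count_nonzero(train_batch["rewards"]))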
on_postprocess_trajectory(self, *, worker, episode, agent_id, policy_id, policies, postprocessed_batch, original_batches, **kwargs)
Called immediately after a policy's postprocess_fn is called.
You can use this callback to do additional postprocessing for a policy, including looking at the trajectory data of other agents in multi-agent settings.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
worker | RolloutWorker | Reference to the current rollout worker. | required |
episode | Episode | Episode object. | required |
agent_id | Any | Id of the current agent. | required |
policy_id | str | Id of the current policy for the agent. | required |
policies | Dict[str, ray.rllib.policy.policy.Policy] | Mapping of policy id to policy objects. In single agent mode there will only be a single "default_policy". | required |
postprocessed_batch | SampleBatch | The postprocessed sample batch for this agent. You can mutate this object to apply your own trajectory postprocessing. | required |
original_batches | Dict[Any, ray.rllib.policy.sample_batch.SampleBatch] | Mapping of agents to their unpostprocessed trajectory data. You should not mutate this object. | required |
kwargs | | Forward compatibility placeholder. | {} |
Source code in ray/rllib/agents/callbacks.py
def on_postprocess_trajectory(
        self, *, worker: "RolloutWorker", episode: Episode,
        agent_id: AgentID, policy_id: PolicyID,
        policies: Dict[PolicyID, Policy], postprocessed_batch: SampleBatch,
        original_batches: Dict[AgentID, SampleBatch], **kwargs) -> None:
    """Called immediately after a policy's postprocess_fn is called.

    You can use this callback to do additional postprocessing for a policy,
    including looking at the trajectory data of other agents in multi-agent
    settings.

    Args:
        worker: Reference to the current rollout worker.
        episode: Episode object.
        agent_id: Id of the current agent.
        policy_id: Id of the current policy for the agent.
        policies: Mapping of policy id to policy objects. In single
            agent mode there will only be a single "default_policy".
        postprocessed_batch: The postprocessed sample batch
            for this agent. You can mutate this object to apply your own
            trajectory postprocessing.
        original_batches: Mapping of agents to their unpostprocessed
            trajectory data. You should not mutate this object.
        kwargs: Forward compatibility placeholder.
    """
    if self.legacy_callbacks.get("on_postprocess_traj"):
        self.legacy_callbacks["on_postprocess_traj"]({
            "episode": episode,
            "agent_id": agent_id,
            "pre_batch": original_batches[agent_id],
            "post_batch": postprocessed_batch,
            "all_pre_batches": original_batches,
        })
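A minimal sketch: counting how many trajectories were postprocessed per episode (the metric name is illustrative):
from ray.rllib.agents.callbacks import DefaultCallbacks


class PostprocessCallbacks(DefaultCallbacks):
    def on_postprocess_trajectory(self, *, worker, episode, agent_id,
                                  policy_id, policies, postprocessed_batch,
                                  original_batches, **kwargs):
        # Increment a per-episode counter; it ends up in custom_metrics.
        episode.custom_metrics["num_postprocessed_batches"] = (
            episode.custom_metrics.get("num_postprocessed_batches", 0) + 1)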
on_sample_end(self, *, worker, samples, **kwargs)
Called at the end of RolloutWorker.sample().
Parameters:
Name | Type | Description | Default |
---|---|---|---|
worker | RolloutWorker | Reference to the current rollout worker. | required |
samples | SampleBatch | Batch to be returned. You can mutate this object to modify the samples generated. | required |
kwargs | | Forward compatibility placeholder. | {} |
Source code in ray/rllib/agents/callbacks.py
def on_sample_end(self, *, worker: "RolloutWorker", samples: SampleBatch,
                  **kwargs) -> None:
    """Called at the end of RolloutWorker.sample().

    Args:
        worker: Reference to the current rollout worker.
        samples: Batch to be returned. You can mutate this
            object to modify the samples generated.
        kwargs: Forward compatibility placeholder.
    """
    if self.legacy_callbacks.get("on_sample_end"):
        self.legacy_callbacks["on_sample_end"]({
            "worker": worker,
            "samples": samples,
        })
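A minimal sketch: logging the size of every batch a rollout worker returns:
from ray.rllib.agents.callbacks import DefaultCallbacks


class SampleEndCallbacks(DefaultCallbacks):
    def on_sample_end(self, *, worker, samples, **kwargs):
        # `samples` is about to be returned from RolloutWorker.sample();
        # it may be inspected or mutated in place here.
        print("Worker returned a batch of {} timesteps".format(samples.count))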
on_train_result(self, *, trainer, result, **kwargs)
Called at the end of Trainable.train().
Parameters:
Name | Type | Description | Default |
---|---|---|---|
trainer | Trainer | Current trainer instance. | required |
result | dict | Dict of results returned from trainer.train() call. You can mutate this object to add additional metrics. | required |
kwargs | | Forward compatibility placeholder. | {} |
Source code in ray/rllib/agents/callbacks.py
def on_train_result(self, *, trainer: "Trainer", result: dict,
                    **kwargs) -> None:
    """Called at the end of Trainable.train().

    Args:
        trainer: Current trainer instance.
        result: Dict of results returned from trainer.train() call.
            You can mutate this object to add additional metrics.
        kwargs: Forward compatibility placeholder.
    """
    if self.legacy_callbacks.get("on_train_result"):
        self.legacy_callbacks["on_train_result"]({
            "trainer": trainer,
            "result": result,
        })
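A minimal curriculum-style sketch, assuming your environments implement a set_task() method (e.g. they follow the TaskSettableEnv pattern) and that a mean reward of 150 is a sensible switching point for your problem:
from ray.rllib.agents.callbacks import DefaultCallbacks


class CurriculumCallbacks(DefaultCallbacks):
    def on_train_result(self, *, trainer, result, **kwargs):
        # Pick the next task based on the latest training results ...
        task = 2 if result["episode_reward_mean"] > 150 else 1
        # ... push it to every (local and remote) sub-environment ...
        trainer.workers.foreach_worker(
            lambda ev: ev.foreach_env(lambda env: env.set_task(task)))
        # ... and record it so it shows up in the training output.
        result["current_task"] = task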