policy package

ray.rllib.policy.policy.Policy

Policy base class: Calculates actions, losses, and holds NN models.

Policy is the abstract superclass for all DL-framework specific sub-classes (e.g. TFPolicy or TorchPolicy). It exposes APIs to

1) Compute actions from observation (and possibly other) inputs.
2) Manage the Policy's NN model(s), like exporting and loading their weights.
3) Postprocess a given trajectory from the environment or other input via the postprocess_trajectory method.
4) Compute losses from a train batch.
5) Perform updates from a train batch on the NN-models (this normally includes loss calculations) either a) in one monolithic step (learn_on_batch) or b) via batch pre-loading, then n steps of actual loss computations and updates (load_batch_into_buffer + learn_on_loaded_batch).

Note: It is not recommended to sub-class Policy directly, but rather use one of the following two convenience methods: rllib.policy.policy_template::build_policy_class (PyTorch) or rllib.policy.tf_policy_template::build_tf_policy_class (TF).

__init__(self, observation_space, action_space, config) special

Initializes a Policy instance.

Parameters:

Name Type Description Default
observation_space Space

Observation space of the policy.

required
action_space Space

Action space of the policy.

required
config dict

A complete Trainer/Policy config dict. For the default config keys and values, see rllib/trainer/trainer.py.

required
Source code in ray/rllib/policy/policy.py
@DeveloperAPI
def __init__(self, observation_space: gym.Space, action_space: gym.Space,
             config: TrainerConfigDict):
    """Initializes a Policy instance.

    Args:
        observation_space: Observation space of the policy.
        action_space: Action space of the policy.
        config: A complete Trainer/Policy config dict. For the default
            config keys and values, see rllib/trainer/trainer.py.
    """
    self.observation_space: gym.Space = observation_space
    self.action_space: gym.Space = action_space
    # The base struct of the action space.
    # E.g. action-space = gym.spaces.Dict({"a": Discrete(2)}) ->
    # action_space_struct = {"a": Discrete(2)}
    self.action_space_struct = get_base_struct_from_space(action_space)

    self.config: TrainerConfigDict = config
    self.framework = self.config.get("framework")
    # Create the callbacks object to use for handling custom callbacks.
    if self.config.get("callbacks"):
        self.callbacks: "DefaultCallbacks" = self.config.get("callbacks")()
    else:
        from ray.rllib.agents.callbacks import DefaultCallbacks
        self.callbacks: "DefaultCallbacks" = DefaultCallbacks()

    # The global timestep, broadcast down from time to time from the
    # local worker to all remote workers.
    self.global_timestep: int = 0

    # The action distribution class to use for action sampling, if any.
    # Child classes may set this.
    self.dist_class: Optional[Type] = None

    # Maximal view requirements dict for `learn_on_batch()` and
    # `compute_actions` calls.
    # View requirements will be automatically filtered out later based
    # on the postprocessing and loss functions to ensure optimal data
    # collection and transfer performance.
    view_reqs = self._get_default_view_requirements()
    if not hasattr(self, "view_requirements"):
        self.view_requirements = view_reqs
    else:
        for k, v in view_reqs.items():
            if k not in self.view_requirements:
                self.view_requirements[k] = v
    # Whether the Model's initial state (method) has been added
    # automatically based on the given view requirements of the model.
    self._model_init_state_automatically_added = False
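
The view-requirements merge at the end of `__init__` keeps any entry that already exists on the instance and only fills in missing defaults. A minimal sketch of that rule with plain dicts follows; the keys and string values are placeholders chosen for illustration, not real ViewRequirement objects.

```python
# Hypothetical stand-ins: real entries would be ViewRequirement objects.
defaults = {"obs": "default-obs", "actions": "default-actions"}

# Pretend a subclass customized "obs" before the base __init__ ran.
view_requirements = {"obs": "custom-obs"}

# Same rule as in __init__: only add keys that are not already present.
for k, v in defaults.items():
    if k not in view_requirements:
        view_requirements[k] = v

print(view_requirements)
```

The customized entry survives and only the missing key is taken from the defaults.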

apply_gradients(self, gradients)

Applies the (previously) computed gradients.

Either this in combination with compute_gradients() or learn_on_batch() must be implemented by subclasses.

Parameters:

Name Type Description Default
gradients Union[List[Tuple[Any, Any]], List[Any]]

The already calculated gradients to apply to this Policy.

required
Source code in ray/rllib/policy/policy.py
@DeveloperAPI
def apply_gradients(self, gradients: ModelGradients) -> None:
    """Applies the (previously) computed gradients.

    Either this in combination with `compute_gradients()` or
    `learn_on_batch()` must be implemented by subclasses.

    Args:
        gradients: The already calculated gradients to apply to this
            Policy.
    """
    raise NotImplementedError

compute_actions(self, obs_batch, state_batches=None, prev_action_batch=None, prev_reward_batch=None, info_batch=None, episodes=None, explore=None, timestep=None, **kwargs)

Computes actions for the current policy.

Parameters:

Name Type Description Default
obs_batch Union[List[Union[Any, dict, tuple]], Any, dict, tuple]

Batch of observations.

required
state_batches Optional[List[Any]]

List of RNN state input batches, if any.

None
prev_action_batch Union[List[Union[Any, dict, tuple]], Any, dict, tuple]

Batch of previous action values.

None
prev_reward_batch Union[List[Union[Any, dict, tuple]], Any, dict, tuple]

Batch of previous rewards.

None
info_batch Optional[Dict[str, list]]

Batch of info objects.

None
episodes Optional[List[Episode]]

List of Episode objects, one for each obs in obs_batch. This provides access to all of the internal episode state, which may be useful for model-based or multi-agent algorithms.

None
explore Optional[bool]

Whether to pick an exploitation or exploration action. Set to None (default) for using the value of self.config["explore"].

None
timestep Optional[int]

The current (sampling) time step.

None

Returns:

Type Description
actions (TensorType)

Batch of output actions, with shape like [BATCH_SIZE, ACTION_SHAPE].
state_outs (List[TensorType]): List of RNN state output batches, if any, each with shape [BATCH_SIZE, STATE_SIZE].
info (Dict[str, TensorType]): Dictionary of extra feature batches, if any, with shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.

Source code in ray/rllib/policy/policy.py
@abstractmethod
@DeveloperAPI
def compute_actions(
        self,
        obs_batch: Union[List[TensorStructType], TensorStructType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Union[List[TensorStructType],
                                 TensorStructType] = None,
        prev_reward_batch: Union[List[TensorStructType],
                                 TensorStructType] = None,
        info_batch: Optional[Dict[str, list]] = None,
        episodes: Optional[List["Episode"]] = None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        **kwargs) -> \
        Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
    """Computes actions for the current policy.

    Args:
        obs_batch: Batch of observations.
        state_batches: List of RNN state input batches, if any.
        prev_action_batch: Batch of previous action values.
        prev_reward_batch: Batch of previous rewards.
        info_batch: Batch of info objects.
        episodes: List of Episode objects, one for each obs in
            obs_batch. This provides access to all of the internal
            episode state, which may be useful for model-based or
            multi-agent algorithms.
        explore: Whether to pick an exploitation or exploration action.
            Set to None (default) for using the value of
            `self.config["explore"]`.
        timestep: The current (sampling) time step.

    Keyword Args:
        kwargs: Forward compatibility placeholder.

    Returns:
        actions (TensorType): Batch of output actions, with shape like
            [BATCH_SIZE, ACTION_SHAPE].
        state_outs (List[TensorType]): List of RNN state output
            batches, if any, each with shape [BATCH_SIZE, STATE_SIZE].
        info (Dict[str, TensorType]): Dictionary of extra feature
            batches, if any, with shape like
            {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
    """
    raise NotImplementedError

compute_actions_from_input_dict(self, input_dict, explore=None, timestep=None, episodes=None, **kwargs)

Computes actions from collected samples (across multiple-agents).

Takes an input dict (usually a SampleBatch) as its main data input. This allows for using this method in case a more complex input pattern (view requirements) is needed, for example when the Model requires the last n observations, the last m actions/rewards, or a combination of any of these.

Parameters:

Name Type Description Default
input_dict Union[ray.rllib.policy.sample_batch.SampleBatch, Dict[str, Union[Any, dict, tuple]]]

A SampleBatch or input dict containing the Tensors to compute actions. input_dict already abides by the Policy's as well as the Model's view requirements and can thus be passed to the Model as-is.

required
explore bool

Whether to pick an exploitation or exploration action (default: None -> use self.config["explore"]).

None
timestep Optional[int]

The current (sampling) time step.

None
episodes Optional[List[Episode]]

This provides access to all of the internal episodes' state, which may be useful for model-based or multi-agent algorithms.

None

Returns:

Type Description
actions

Batch of output actions, with shape like [BATCH_SIZE, ACTION_SHAPE].
state_outs: List of RNN state output batches, if any, each with shape [BATCH_SIZE, STATE_SIZE].
info: Dictionary of extra feature batches, if any, with shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.

Source code in ray/rllib/policy/policy.py
@DeveloperAPI
def compute_actions_from_input_dict(
        self,
        input_dict: Union[SampleBatch, Dict[str, TensorStructType]],
        explore: bool = None,
        timestep: Optional[int] = None,
        episodes: Optional[List["Episode"]] = None,
        **kwargs) -> \
        Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
    """Computes actions from collected samples (across multiple-agents).

    Takes an input dict (usually a SampleBatch) as its main data input.
    This allows for using this method in case a more complex input pattern
    (view requirements) is needed, for example when the Model requires the
    last n observations, the last m actions/rewards, or a combination
    of any of these.

    Args:
        input_dict: A SampleBatch or input dict containing the Tensors
            to compute actions. `input_dict` already abides by the
            Policy's as well as the Model's view requirements and can
            thus be passed to the Model as-is.
        explore: Whether to pick an exploitation or exploration
            action (default: None -> use self.config["explore"]).
        timestep: The current (sampling) time step.
        episodes: This provides access to all of the internal episodes'
            state, which may be useful for model-based or multi-agent
            algorithms.

    Keyword Args:
        kwargs: Forward compatibility placeholder.

    Returns:
        actions: Batch of output actions, with shape like
            [BATCH_SIZE, ACTION_SHAPE].
        state_outs: List of RNN state output
            batches, if any, each with shape [BATCH_SIZE, STATE_SIZE].
        info: Dictionary of extra feature batches, if any, with shape like
            {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
    """
    # Default implementation just passes obs, prev-a/r, and states on to
    # `self.compute_actions()`.
    state_batches = [
        s for k, s in input_dict.items() if k[:9] == "state_in_"
    ]
    return self.compute_actions(
        input_dict[SampleBatch.OBS],
        state_batches,
        prev_action_batch=input_dict.get(SampleBatch.PREV_ACTIONS),
        prev_reward_batch=input_dict.get(SampleBatch.PREV_REWARDS),
        info_batch=input_dict.get(SampleBatch.INFOS),
        explore=explore,
        timestep=timestep,
        episodes=episodes,
        **kwargs,
    )
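
The default implementation above collects RNN state inputs by key prefix and reads the optional inputs with `dict.get()`. A minimal sketch of that extraction on a plain dict is shown here; the string keys stand in for the SampleBatch constants and the nested lists stand in for tensors, both of which are assumptions made for illustration.

```python
# Hypothetical input dict; real values would be tensors or arrays.
input_dict = {
    "obs": [[0.1, 0.2]],
    "state_in_0": [[0.0]],
    "state_in_1": [[1.0]],
    "prev_actions": [0],
}

# Same filter as the default implementation: keep values whose key
# starts with "state_in_" (9 characters), in dict iteration order.
state_batches = [s for k, s in input_dict.items() if k[:9] == "state_in_"]

# dict.get() yields None for optional keys that are absent.
prev_reward_batch = input_dict.get("prev_rewards")

print(state_batches, prev_reward_batch)
```

Because the filter relies on dict iteration order, the state entries should be inserted in index order (state_in_0, state_in_1, ...).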

compute_gradients(self, postprocessed_batch)

Computes gradients given a batch of experiences.

Either this in combination with apply_gradients() or learn_on_batch() must be implemented by subclasses.

Parameters:

Name Type Description Default
postprocessed_batch SampleBatch

The SampleBatch object to use for calculating gradients.

required

Returns:

Type Description
grads

List of gradient output values.
grad_info: Extra policy-specific info values.

Source code in ray/rllib/policy/policy.py
@DeveloperAPI
def compute_gradients(self, postprocessed_batch: SampleBatch) -> \
        Tuple[ModelGradients, Dict[str, TensorType]]:
    """Computes gradients given a batch of experiences.

    Either this in combination with `apply_gradients()` or
    `learn_on_batch()` must be implemented by subclasses.

    Args:
        postprocessed_batch: The SampleBatch object to use
            for calculating gradients.

    Returns:
        grads: List of gradient output values.
        grad_info: Extra policy-specific info values.
    """
    raise NotImplementedError

compute_log_likelihoods(self, actions, obs_batch, state_batches=None, prev_action_batch=None, prev_reward_batch=None, actions_normalized=True)

Computes the log-prob/likelihood for a given action and observation.

The log-likelihood is calculated using this Policy's action distribution class (self.dist_class).

Parameters:

Name Type Description Default
actions Union[List[Any], Any]

Batch of actions, for which to retrieve the log-probs/likelihoods (given all other inputs: obs, states, ..).

required
obs_batch Union[List[Any], Any]

Batch of observations.

required
state_batches Optional[List[Any]]

List of RNN state input batches, if any.

None
prev_action_batch Union[List[Any], Any]

Batch of previous action values.

None
prev_reward_batch Union[List[Any], Any]

Batch of previous rewards.

None
actions_normalized bool

Are the given actions already normalized (between -1.0 and 1.0) or not? If not and normalize_actions=True, the given actions need to be normalized first, before calculating log likelihoods.

True

Returns:

Type Description
Batch of log probs/likelihoods, with shape [BATCH_SIZE].

Source code in ray/rllib/policy/policy.py
@DeveloperAPI
def compute_log_likelihoods(
        self,
        actions: Union[List[TensorType], TensorType],
        obs_batch: Union[List[TensorType], TensorType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Optional[Union[List[TensorType],
                                          TensorType]] = None,
        prev_reward_batch: Optional[Union[List[TensorType],
                                          TensorType]] = None,
        actions_normalized: bool = True,
) -> TensorType:
    """Computes the log-prob/likelihood for a given action and observation.

    The log-likelihood is calculated using this Policy's action
    distribution class (self.dist_class).

    Args:
        actions: Batch of actions, for which to retrieve the
            log-probs/likelihoods (given all other inputs: obs,
            states, ..).
        obs_batch: Batch of observations.
        state_batches: List of RNN state input batches, if any.
        prev_action_batch: Batch of previous action values.
        prev_reward_batch: Batch of previous rewards.
        actions_normalized: Are the given `actions` already normalized
            (between -1.0 and 1.0) or not? If not and
            `normalize_actions=True`, we need to normalize the given
            actions first, before calculating log likelihoods.

    Returns:
        Batch of log probs/likelihoods, with shape: [BATCH_SIZE].
    """
    raise NotImplementedError
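
The base class leaves `compute_log_likelihoods` unimplemented. To make the [BATCH_SIZE] return shape concrete, here is a sketch of a log-likelihood computation for a categorical (softmax) action distribution in plain Python. The function name `categorical_logp` and the logits are hypothetical, and this is not RLlib's implementation.

```python
import math


def categorical_logp(logits, action):
    """Log-probability of `action` under softmax(logits)."""
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return logits[action] - log_z


# One row of logits per observation and one action per row.
logits_batch = [[0.0, 0.0], [2.0, 0.0]]
actions = [1, 0]

# One log-likelihood per batch entry, i.e. shape [BATCH_SIZE].
log_likelihoods = [
    categorical_logp(row, a) for row, a in zip(logits_batch, actions)
]
print(log_likelihoods)
```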

compute_single_action(self, obs=None, state=None, *, prev_action=None, prev_reward=None, info=None, input_dict=None, episode=None, explore=None, timestep=None, **kwargs)

Computes and returns a single (B=1) action value.

Takes an input dict (usually a SampleBatch) as its main data input. This allows for using this method in case a more complex input pattern (view requirements) is needed, for example when the Model requires the last n observations, the last m actions/rewards, or a combination of any of these. Alternatively, in case no complex inputs are required, takes a single obs value (and possibly single state values, prev-action/reward values, etc.).

Parameters:

Name Type Description Default
obs Union[Any, dict, tuple]

Single observation.

None
state Optional[List[Any]]

List of RNN state inputs, if any.

None
prev_action Union[Any, dict, tuple]

Previous action value, if any.

None
prev_reward Union[Any, dict, tuple]

Previous reward, if any.

None
info dict

Info object, if any.

None
input_dict Optional[ray.rllib.policy.sample_batch.SampleBatch]

A SampleBatch or input dict containing the single (unbatched) Tensors to compute actions. If given, it'll be used instead of obs, state, prev_action|reward, and info.

None
episode Optional[Episode]

This provides access to all of the internal episode state, which may be useful for model-based or multi-agent algorithms.

None
explore Optional[bool]

Whether to pick an exploitation or exploration action (default: None -> use self.config["explore"]).

None
timestep Optional[int]

The current (sampling) time step.

None

Returns:

Type Description
Tuple[Union[Any, dict, tuple], List[Any], Dict[str, Any]]

Tuple consisting of the action, the list of RNN state outputs (if any), and a dictionary of extra features (if any).

Source code in ray/rllib/policy/policy.py
@DeveloperAPI
def compute_single_action(
        self,
        obs: Optional[TensorStructType] = None,
        state: Optional[List[TensorType]] = None,
        *,
        prev_action: Optional[TensorStructType] = None,
        prev_reward: Optional[TensorStructType] = None,
        info: dict = None,
        input_dict: Optional[SampleBatch] = None,
        episode: Optional["Episode"] = None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        # Kwargs placeholder for future compatibility.
        **kwargs) -> \
        Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
    """Computes and returns a single (B=1) action value.

    Takes an input dict (usually a SampleBatch) as its main data input.
    This allows for using this method in case a more complex input pattern
    (view requirements) is needed, for example when the Model requires the
    last n observations, the last m actions/rewards, or a combination
    of any of these.
    Alternatively, in case no complex inputs are required, takes a single
    `obs` value (and possibly single state values, prev-action/reward
    values, etc.).

    Args:
        obs: Single observation.
        state: List of RNN state inputs, if any.
        prev_action: Previous action value, if any.
        prev_reward: Previous reward, if any.
        info: Info object, if any.
        input_dict: A SampleBatch or input dict containing the
            single (unbatched) Tensors to compute actions. If given, it'll
            be used instead of `obs`, `state`, `prev_action|reward`, and
            `info`.
        episode: This provides access to all of the internal episode state,
            which may be useful for model-based or multi-agent algorithms.
        explore: Whether to pick an exploitation or
            exploration action
            (default: None -> use self.config["explore"]).
        timestep: The current (sampling) time step.

    Keyword Args:
        kwargs: Forward compatibility placeholder.

    Returns:
        Tuple consisting of the action, the list of RNN state outputs (if
        any), and a dictionary of extra features (if any).
    """
    # Build the input-dict used for the call to
    # `self.compute_actions_from_input_dict()`.
    if input_dict is None:
        input_dict = {SampleBatch.OBS: obs}
        if state is not None:
            for i, s in enumerate(state):
                input_dict[f"state_in_{i}"] = s
        if prev_action is not None:
            input_dict[SampleBatch.PREV_ACTIONS] = prev_action
        if prev_reward is not None:
            input_dict[SampleBatch.PREV_REWARDS] = prev_reward
        if info is not None:
            input_dict[SampleBatch.INFOS] = info

    # Batch all data in input dict.
    input_dict = tree.map_structure_with_path(
        lambda p, s: (s if p == "seq_lens" else s.unsqueeze(0) if
                      torch and isinstance(s, torch.Tensor) else
                      np.expand_dims(s, 0)),
        input_dict)

    episodes = None
    if episode is not None:
        episodes = [episode]

    out = self.compute_actions_from_input_dict(
        input_dict=SampleBatch(input_dict),
        episodes=episodes,
        explore=explore,
        timestep=timestep,
    )

    # Some policies don't return a tuple, but always just a single action.
    # E.g. ES and ARS.
    if not isinstance(out, tuple):
        single_action = out
        state_out = []
        info = {}
    # Normal case: Policy should return (action, state, info) tuple.
    else:
        batched_action, state_out, info = out
        single_action = unbatch(batched_action)
    assert len(single_action) == 1
    single_action = single_action[0]

    # Return action, internal state(s), infos.
    return single_action, [s[0] for s in state_out], \
        {k: v[0] for k, v in info.items()}
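
The method above follows a three-step pattern: add a batch dimension of size 1, call the batched method, then strip the batch dimension from every output. The sketch below reproduces that pattern with plain lists. `add_batch_dim` and `toy_compute_actions_from_input_dict` are hypothetical helpers standing in for `np.expand_dims(s, 0)` and a real policy.

```python
def add_batch_dim(value):
    # Stand-in for np.expand_dims(value, 0) or tensor.unsqueeze(0).
    return [value]


def toy_compute_actions_from_input_dict(input_dict):
    # Hypothetical batched policy: pick the index of the largest obs entry.
    obs_batch = input_dict["obs"]
    actions = [max(range(len(o)), key=o.__getitem__) for o in obs_batch]
    state_outs = []  # no RNN state in this toy
    info = {"max_value": [max(o) for o in obs_batch]}
    return actions, state_outs, info


# 1) Build the input dict from a single observation and batch it (B=1).
input_dict = {"obs": add_batch_dim([0.1, 0.7, 0.2])}

# 2) Call the batched method.
batched_action, state_out, info = toy_compute_actions_from_input_dict(
    input_dict)

# 3) Strip the batch dimension from every output again.
assert len(batched_action) == 1
single_action = batched_action[0]
single_state = [s[0] for s in state_out]
single_info = {k: v[0] for k, v in info.items()}

print(single_action, single_state, single_info)
```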

export_checkpoint(self, export_dir)

Export Policy checkpoint to local directory.

Parameters:

Name Type Description Default
export_dir str

Local writable directory.

required
Source code in ray/rllib/policy/policy.py
@DeveloperAPI
def export_checkpoint(self, export_dir: str) -> None:
    """Export Policy checkpoint to local directory.

    Args:
        export_dir: Local writable directory.
    """
    raise NotImplementedError

export_model(self, export_dir, onnx=None)

Exports the Policy's Model to local directory for serving.

Note: The file format will depend on the deep learning framework used. See the child classes of Policy and their export_model implementations for more details.

Parameters:

Name Type Description Default
export_dir str

Local writable directory.

required
onnx Optional[int]

If given, will export model in ONNX format. The value of this parameter sets the ONNX OpSet version to use.

None
Source code in ray/rllib/policy/policy.py
@DeveloperAPI
def export_model(self, export_dir: str,
                 onnx: Optional[int] = None) -> None:
    """Exports the Policy's Model to local directory for serving.

    Note: The file format will depend on the deep learning framework used.
    See the child classes of Policy and their `export_model`
    implementations for more details.

    Args:
        export_dir: Local writable directory.
        onnx: If given, will export model in ONNX format. The
            value of this parameter sets the ONNX OpSet version to use.
    """
    raise NotImplementedError

get_exploration_state(self)

Returns the state of this Policy's exploration component.

Returns:

Type Description
Dict[str, Any]

Serializable information on the self.exploration object.

Source code in ray/rllib/policy/policy.py
@DeveloperAPI
def get_exploration_state(self) -> Dict[str, TensorType]:
    """Returns the state of this Policy's exploration component.

    Returns:
        Serializable information on the `self.exploration` object.
    """
    return self.exploration.get_state()

get_initial_state(self)

Returns initial RNN state for the current policy.

Returns:

Type Description
List[TensorType]

Initial RNN state for the current policy.

Source code in ray/rllib/policy/policy.py
@DeveloperAPI
def get_initial_state(self) -> List[TensorType]:
    """Returns initial RNN state for the current policy.

    Returns:
        List[TensorType]: Initial RNN state for the current policy.
    """
    return []

get_num_samples_loaded_into_buffer(self, buffer_index=0)

Returns the number of currently loaded samples in the given buffer.

Parameters:

Name Type Description Default
buffer_index int

The index of the buffer (a MultiGPUTowerStack) to use on the devices. The number of buffers on each device depends on the value of the num_multi_gpu_tower_stacks config key.

0

Returns:

Type Description
int

The number of tuples loaded per device.

Source code in ray/rllib/policy/policy.py
@DeveloperAPI
def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int:
    """Returns the number of currently loaded samples in the given buffer.

    Args:
        buffer_index: The index of the buffer (a MultiGPUTowerStack)
            to use on the devices. The number of buffers on each device
            depends on the value of the `num_multi_gpu_tower_stacks` config
            key.

    Returns:
        The number of tuples loaded per device.
    """
    raise NotImplementedError

get_session(self)

Returns tf.Session object to use for computing actions or None.

Note: This method only applies to TFPolicy sub-classes. All other sub-classes should expect a None to be returned from this method.

Returns:

Type Description
Optional[tf1.Session]

The tf Session to use for computing actions and losses with this policy or None.

Source code in ray/rllib/policy/policy.py
@DeveloperAPI
def get_session(self) -> Optional["tf1.Session"]:
    """Returns tf.Session object to use for computing actions or None.

    Note: This method only applies to TFPolicy sub-classes. All other
    sub-classes should expect a None to be returned from this method.

    Returns:
        The tf Session to use for computing actions and losses with
            this policy or None.
    """
    return None

get_state(self)

Returns the entire current state of this Policy.

Note: Not to be confused with an RNN model's internal state. State includes the Model(s)' weights, optimizer weights, the exploration component's state, as well as global variables, such as sampling timesteps.

Returns:

Type Description
Union[Dict[str, Any], List[Any]]

Serialized local state.

Source code in ray/rllib/policy/policy.py
@DeveloperAPI
def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
    """Returns the entire current state of this Policy.

    Note: Not to be confused with an RNN model's internal state.
    State includes the Model(s)' weights, optimizer weights,
    the exploration component's state, as well as global variables, such
    as sampling timesteps.

    Returns:
        Serialized local state.
    """
    state = {
        # All the policy's weights.
        "weights": self.get_weights(),
        # The current global timestep.
        "global_timestep": self.global_timestep,
    }
    return state
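
Note that the base implementation shown above only returns the weights and the global timestep; the additional items named in the description (optimizer weights, exploration state) would have to be added by subclasses. A minimal sketch with a hypothetical `ToyPolicy` class (not an RLlib class) illustrates the shape of the returned dict:

```python
class ToyPolicy:
    """Hypothetical stand-in; not an RLlib class."""

    def __init__(self):
        self.global_timestep = 0
        self._weights = {"w": [0.5, -0.5]}

    def get_weights(self):
        # Return a copy so callers cannot mutate the live weights.
        return {k: list(v) for k, v in self._weights.items()}

    def get_state(self):
        # Mirrors the base implementation: weights plus global timestep.
        return {
            "weights": self.get_weights(),
            "global_timestep": self.global_timestep,
        }


policy = ToyPolicy()
policy.global_timestep = 100
state = policy.get_state()
print(state)
```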

get_weights(self)

Returns model weights.

Note: The return value of this method will reside under the "weights" key in the return value of Policy.get_state(). Model weights are only one part of a Policy's state. Other state information contains: optimizer variables, exploration state, and global state vars such as the sampling timestep.

Returns:

Type Description
dict

Serializable copy or view of model weights.

Source code in ray/rllib/policy/policy.py
@DeveloperAPI
def get_weights(self) -> ModelWeights:
    """Returns model weights.

    Note: The return value of this method will reside under the "weights"
    key in the return value of Policy.get_state(). Model weights are only
    one part of a Policy's state. Other state information contains:
    optimizer variables, exploration state, and global state vars such as
    the sampling timestep.

    Returns:
        Serializable copy or view of model weights.
    """
    raise NotImplementedError

import_model_from_h5(self, import_file)

Imports Policy from local file.

Parameters:

Name Type Description Default
import_file str

Local readable file.

required
Source code in ray/rllib/policy/policy.py
@DeveloperAPI
def import_model_from_h5(self, import_file: str) -> None:
    """Imports Policy from local file.

    Args:
        import_file (str): Local readable file.
    """
    raise NotImplementedError

is_recurrent(self)

Whether this Policy holds a recurrent Model.

Returns:

Type Description
bool

True if this Policy has an RNN-based Model.

Source code in ray/rllib/policy/policy.py
@DeveloperAPI
def is_recurrent(self) -> bool:
    """Whether this Policy holds a recurrent Model.

    Returns:
        True if this Policy has an RNN-based Model.
    """
    return False

learn_on_batch(self, samples)

Perform one learning update, given samples.

Either this method or the combination of compute_gradients and apply_gradients must be implemented by subclasses.

Parameters:

Name Type Description Default
samples SampleBatch

The SampleBatch object to learn from.

required

Returns:

Type Description
Dict[str, Any]

Dictionary of extra metadata from compute_gradients().

Examples:

>>> sample_batch = ev.sample()
>>> ev.learn_on_batch(sample_batch)
Source code in ray/rllib/policy/policy.py
@DeveloperAPI
def learn_on_batch(self, samples: SampleBatch) -> Dict[str, TensorType]:
    """Perform one learning update, given `samples`.

    Either this method or the combination of `compute_gradients` and
    `apply_gradients` must be implemented by subclasses.

    Args:
        samples: The SampleBatch object to learn from.

    Returns:
        Dictionary of extra metadata from `compute_gradients()`.

    Examples:
        >>> sample_batch = ev.sample()
        >>> ev.learn_on_batch(sample_batch)
    """
    # The default implementation is simply a fused `compute_gradients` plus
    # `apply_gradients` call.
    grads, grad_info = self.compute_gradients(samples)
    self.apply_gradients(grads)
    return grad_info
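
The fused pattern in the default implementation can be tried out on a toy problem. The sketch below assumes a hypothetical `ToyPolicy` with a single scalar weight, a mean-squared-error loss, and plain SGD; none of these choices come from RLlib, only the `compute_gradients` plus `apply_gradients` structure does.

```python
class ToyPolicy:
    """Hypothetical stand-in with one scalar weight; not an RLlib class."""

    def __init__(self, lr=0.1):
        self.w = 0.0
        self.lr = lr

    def compute_gradients(self, batch):
        # Gradient of mean((w * x - y)^2) with respect to w.
        n = len(batch["x"])
        errors = [self.w * x - y for x, y in zip(batch["x"], batch["y"])]
        grad = sum(2.0 * e * x for e, x in zip(errors, batch["x"])) / n
        loss = sum(e * e for e in errors) / n
        return [grad], {"loss": loss}

    def apply_gradients(self, gradients):
        # Plain SGD step.
        self.w -= self.lr * gradients[0]

    def learn_on_batch(self, samples):
        # Same fused pattern as the default implementation.
        grads, grad_info = self.compute_gradients(samples)
        self.apply_gradients(grads)
        return grad_info


policy = ToyPolicy()
batch = {"x": [1.0, 2.0], "y": [2.0, 4.0]}  # generated by y = 2x
first = policy.learn_on_batch(batch)
second = policy.learn_on_batch(batch)
print(first, second, policy.w)
```

The returned dict is whatever `compute_gradients` reports, so the loss in `second` reflects the weight after the first update.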

learn_on_loaded_batch(self, offset=0, buffer_index=0)

Runs a single step of SGD on already loaded data in a buffer.

Runs an SGD step over a slice of the pre-loaded batch, offset by the offset argument (useful for performing n minibatch SGD updates repeatedly on the same, already pre-loaded data).

Updates the model weights based on the averaged per-device gradients.

Parameters:

Name Type Description Default
offset int

Offset into the preloaded data. Used for pre-loading a train-batch once to a device, then iterating over (subsampling through) this batch n times doing minibatch SGD.

0
buffer_index int

The index of the buffer (a MultiGPUTowerStack) to take the already pre-loaded data from. The number of buffers on each device depends on the value of the num_multi_gpu_tower_stacks config key.

0

Returns:

Type Description

The outputs of extra_ops evaluated over the batch.

Source code in ray/rllib/policy/policy.py
@DeveloperAPI
def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
    """Runs a single step of SGD on already loaded data in a buffer.

    Runs an SGD step over a slice of the pre-loaded batch, offset by
    the `offset` argument (useful for performing n minibatch SGD
    updates repeatedly on the same, already pre-loaded data).

    Updates the model weights based on the averaged per-device gradients.

    Args:
        offset: Offset into the preloaded data. Used for pre-loading
            a train-batch once to a device, then iterating over
            (subsampling through) this batch n times doing minibatch SGD.
        buffer_index: The index of the buffer (a MultiGPUTowerStack)
            to take the already pre-loaded data from. The number of buffers
            on each device depends on the value of the
            `num_multi_gpu_tower_stacks` config key.

    Returns:
        The outputs of extra_ops evaluated over the batch.
    """
    raise NotImplementedError

load_batch_into_buffer(self, batch, buffer_index=0)

Bulk-loads the given SampleBatch into the devices' memories.

The data is split equally across all the Policy's devices. If the data is not evenly divisible by the batch size, excess data should be discarded.

Parameters:

Name Type Description Default
batch SampleBatch

The SampleBatch to load.

required
buffer_index int

The index of the buffer (a MultiGPUTowerStack) to use on the devices. The number of buffers on each device depends on the value of the num_multi_gpu_tower_stacks config key.

0

Returns:

Type Description
int

The number of tuples loaded per device.

Source code in ray/rllib/policy/policy.py
@DeveloperAPI
def load_batch_into_buffer(self, batch: SampleBatch,
                           buffer_index: int = 0) -> int:
    """Bulk-loads the given SampleBatch into the devices' memories.

    The data is split equally across all the Policy's devices.
    If the data is not evenly divisible by the batch size, excess data
    should be discarded.

    Args:
        batch: The SampleBatch to load.
        buffer_index: The index of the buffer (a MultiGPUTowerStack) to use
            on the devices. The number of buffers on each device depends
            on the value of the `num_multi_gpu_tower_stacks` config key.

    Returns:
        The number of tuples loaded per device.
    """
    raise NotImplementedError
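
The "split equally, discard the excess" contract can be illustrated with plain lists. This is only a sketch under stated assumptions: split_for_devices and per_device_batch_size are hypothetical names, and a concrete Policy may slice its data differently.

```python
from typing import List, Sequence


def split_for_devices(rows: Sequence[int], num_devices: int,
                      per_device_batch_size: int) -> List[List[int]]:
    """Splits `rows` equally across devices, discarding the excess."""
    # Every device must receive a whole number of minibatches.
    chunk = num_devices * per_device_batch_size
    usable = (len(rows) // chunk) * chunk
    per_device = usable // num_devices
    return [
        list(rows[d * per_device:(d + 1) * per_device])
        for d in range(num_devices)
    ]


# 10 rows, 2 devices, minibatches of 2: rows 8 and 9 are dropped.
slices = split_for_devices(list(range(10)), num_devices=2,
                           per_device_batch_size=2)
print(slices)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

In this sketch, the length of each inner list corresponds to the "number of tuples loaded per device" that the method returns.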

loss(self, model, dist_class, train_batch)

Loss function for this Policy.

Override this method in order to implement custom loss computations.

Parameters:

Name Type Description Default
model ModelV2

The model to calculate the loss(es).

required
dist_class ActionDistribution

The action distribution class to sample actions from the model's outputs.

required
train_batch SampleBatch

The input batch on which to calculate the loss.

required

Returns:

Type Description
Union[Any, List[Any]]

Either a single loss tensor or a list of loss tensors.

Source code in ray/rllib/policy/policy.py
@ExperimentalAPI
@OverrideToImplementCustomLogic
def loss(self, model: ModelV2, dist_class: ActionDistribution,
         train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Loss function for this Policy.

    Override this method in order to implement custom loss computations.

    Args:
        model: The model to calculate the loss(es).
        dist_class: The action distribution class to sample actions
            from the model's outputs.
        train_batch: The input batch on which to calculate the loss.

    Returns:
        Either a single loss tensor or a list of loss tensors.
    """
    raise NotImplementedError
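
As a rough illustration of what an overridden loss might compute, the sketch below evaluates a vanilla policy-gradient loss with plain Python floats. Everything in it is an assumption for illustration: a real implementation would operate on framework tensors produced by model and dist_class, and the batch keys "logits", "actions", and "advantages" are hypothetical.

```python
import math
from typing import Dict, List


def log_softmax(logits: List[float]) -> List[float]:
    # Numerically stable log-softmax over one row of action logits.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def policy_gradient_loss(train_batch: Dict[str, list]) -> float:
    """Returns -mean(log pi(a|s) * advantage) over the batch."""
    terms = []
    for logits, action, adv in zip(train_batch["logits"],
                                   train_batch["actions"],
                                   train_batch["advantages"]):
        logp = log_softmax(logits)[action]
        terms.append(-logp * adv)
    return sum(terms) / len(terms)


# Uniform logits over two actions: each log-prob equals -log(2).
toy_batch = {
    "logits": [[0.0, 0.0], [0.0, 0.0]],
    "actions": [0, 1],
    "advantages": [1.0, 1.0],
}
loss_value = policy_gradient_loss(toy_batch)
print(loss_value)
```

The function returns a single scalar, which matches the "single loss tensor" case; an algorithm with several optimizers would return a list with one such term per optimizer.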

num_state_tensors(self)

The number of internal states needed by the RNN-Model of the Policy.

Returns:

Type Description
int

The number of RNN internal states kept by this Policy's Model.

Source code in ray/rllib/policy/policy.py
@DeveloperAPI
def num_state_tensors(self) -> int:
    """The number of internal states needed by the RNN-Model of the Policy.

    Returns:
        int: The number of RNN internal states kept by this Policy's Model.
    """
    return 0

on_global_var_update(self, global_vars)

Called on an update to global vars.

Parameters:

Name Type Description Default
global_vars Dict[str, Any]

Global variables by str key, broadcast from the driver.

required
Source code in ray/rllib/policy/policy.py
@DeveloperAPI
def on_global_var_update(self, global_vars: Dict[str, TensorType]) -> None:
    """Called on an update to global vars.

    Args:
        global_vars: Global variables by str key, broadcast from the
            driver.
    """
    # Store the current global time step (sum over all policies' sample
    # steps).
    self.global_timestep = global_vars["timestep"]

postprocess_trajectory(self, sample_batch, other_agent_batches=None, episode=None)

Implements algorithm-specific trajectory postprocessing.

This will be called on each trajectory fragment computed during policy evaluation. Each fragment is guaranteed to be only from one episode. The given fragment may or may not contain the end of this episode, depending on the batch_mode=truncate_episodes|complete_episodes, rollout_fragment_length, and other settings.

Parameters:

Name Type Description Default
sample_batch SampleBatch

Batch of experiences for the policy, which will contain at most one episode trajectory.

required
other_agent_batches Optional[Dict[Any, Tuple[Policy, ray.rllib.policy.sample_batch.SampleBatch]]]

In a multi-agent env, this contains a mapping of agent ids to (policy, agent_batch) tuples containing the policy and experiences of the other agents.

None
episode Optional[Episode]

An optional multi-agent episode object to provide access to all of the internal episode state, which may be useful for model-based or multi-agent algorithms.

None

Returns:

Type Description
SampleBatch

The postprocessed sample batch.

Source code in ray/rllib/policy/policy.py
@DeveloperAPI
def postprocess_trajectory(
        self,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[AgentID, Tuple[
            "Policy", SampleBatch]]] = None,
        episode: Optional["Episode"] = None) -> SampleBatch:
    """Implements algorithm-specific trajectory postprocessing.

    This will be called on each trajectory fragment computed during policy
    evaluation. Each fragment is guaranteed to be only from one episode.
    The given fragment may or may not contain the end of this episode,
    depending on the `batch_mode=truncate_episodes|complete_episodes`,
    `rollout_fragment_length`, and other settings.

    Args:
        sample_batch: batch of experiences for the policy,
            which will contain at most one episode trajectory.
        other_agent_batches: In a multi-agent env, this contains a
            mapping of agent ids to (policy, agent_batch) tuples
            containing the policy and experiences of the other agents.
        episode: An optional multi-agent episode object to provide
            access to all of the internal episode state, which may
            be useful for model-based or multi-agent algorithms.

    Returns:
        The postprocessed sample batch.
    """
    # The default implementation just returns the same, unaltered batch.
    return sample_batch
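
A typical use of this hook is to append derived columns to the fragment. The sketch below is a hypothetical example that uses a plain dict of lists in place of a SampleBatch: it adds discounted returns under an assumed "returns" key and treats the fragment as ending the episode (no value bootstrap), which a truncated fragment would not satisfy.

```python
from typing import Dict, List


def add_discounted_returns(sample_batch: Dict[str, List[float]],
                           gamma: float = 0.5) -> Dict[str, List[float]]:
    """Returns a copy of the batch with a discounted-return column."""
    returns: List[float] = []
    running = 0.0
    # Walk the fragment backwards: G_t = r_t + gamma * G_{t+1}.
    for reward in reversed(sample_batch["rewards"]):
        running = reward + gamma * running
        returns.append(running)
    out = dict(sample_batch)
    out["returns"] = returns[::-1]
    return out


processed = add_discounted_returns({"rewards": [1.0, 1.0, 1.0]}, gamma=0.5)
print(processed["returns"])  # [1.75, 1.5, 1.0]
```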

set_state(self, state)

Restores the entire current state of this Policy from state.

Parameters:

Name Type Description Default
state Union[Dict[str, Any], List[Any]]

The new state to set this policy to. Can be obtained by calling self.get_state().

required
Source code in ray/rllib/policy/policy.py
@DeveloperAPI
def set_state(
        self,
        state: Union[Dict[str, TensorType], List[TensorType]],
) -> None:
    """Restores the entire current state of this Policy from `state`.

    Args:
        state: The new state to set this policy to. Can be
            obtained by calling `self.get_state()`.
    """
    self.set_weights(state["weights"])
    self.global_timestep = state["global_timestep"]
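
The relationship between weights and full state can be sketched with a toy class. ToyPolicy below is a hypothetical stand-in; its get_state() is assumed to return the same two keys that set_state() reads above ("weights" and "global_timestep"), and real subclasses may store additional entries.

```python
import copy
from typing import Any, Dict, List


class ToyPolicy:
    """Hypothetical stand-in showing a get_state / set_state round trip."""

    def __init__(self) -> None:
        self.weights: Dict[str, List[float]] = {"w": [0.0]}
        self.global_timestep = 0

    def get_weights(self) -> Dict[str, List[float]]:
        # Return a copy so callers cannot mutate the live weights.
        return copy.deepcopy(self.weights)

    def set_weights(self, weights: Dict[str, List[float]]) -> None:
        self.weights = copy.deepcopy(weights)

    def get_state(self) -> Dict[str, Any]:
        return {
            "weights": self.get_weights(),
            "global_timestep": self.global_timestep,
        }

    def set_state(self, state: Dict[str, Any]) -> None:
        self.set_weights(state["weights"])
        self.global_timestep = state["global_timestep"]


source = ToyPolicy()
source.set_weights({"w": [1.5, -2.0]})
source.global_timestep = 42

target = ToyPolicy()
target.set_state(source.get_state())
print(target.get_weights(), target.global_timestep)
```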

set_weights(self, weights)

Sets this Policy's model's weights.

Note: Model weights are only one part of a Policy's state. Other state information includes optimizer variables, exploration state, and global state variables such as the sampling timestep.

Parameters:

Name Type Description Default
weights dict

Serializable copy or view of model weights.

required
Source code in ray/rllib/policy/policy.py
@DeveloperAPI
def set_weights(self, weights: ModelWeights) -> None:
    """Sets this Policy's model's weights.

    Note: Model weights are only one part of a Policy's state. Other
    state information includes optimizer variables, exploration state,
    and global state variables such as the sampling timestep.

    Args:
        weights: Serializable copy or view of model weights.
    """
    raise NotImplementedError

ray.rllib.policy.torch_policy.TorchPolicy (Policy)

PyTorch specific Policy class to use with RLlib.

__init__(self, observation_space, action_space, config, *, model=None, loss=None, action_distribution_class=None, action_sampler_fn=None, action_distribution_fn=None, max_seq_len=20, get_batch_divisibility_req=None) special

Initializes a TorchPolicy instance.

Parameters:

Name Type Description Default
observation_space Space

Observation space of the policy.

required
action_space Space

Action space of the policy.

required
config dict

The Policy's config dict.

required
model Optional[ray.rllib.models.torch.torch_modelv2.TorchModelV2]

PyTorch policy module. Given observations as input, this module must return a list of outputs where the first item is action logits, and the rest can be any value.

None
loss Optional[Callable[[ray.rllib.policy.policy.Policy, ray.rllib.models.modelv2.ModelV2, Type[ray.rllib.models.torch.torch_action_dist.TorchDistributionWrapper], ray.rllib.policy.sample_batch.SampleBatch], Union[Any, List[Any]]]]

Callable that returns one or more (a list of) scalar loss terms.

None
action_distribution_class Optional[Type[ray.rllib.models.torch.torch_action_dist.TorchDistributionWrapper]]

Class for a torch action distribution.

None
action_sampler_fn Optional[Callable[[Any, List[Any]], Tuple[Any, Any]]]

A callable returning a sampled action and its log-likelihood given Policy, ModelV2, input_dict, state batches (optional), explore, and timestep. Provide action_sampler_fn if you would like to have full control over the action computation step, including the model forward pass, possible sampling from a distribution, and exploration logic. Note: If action_sampler_fn is given, action_distribution_fn must be None. If both action_sampler_fn and action_distribution_fn are None, RLlib will simply pass inputs through self.model to get distribution inputs, create the distribution object, sample from it, and apply some exploration logic to the results. The callable takes as inputs: Policy, ModelV2, input_dict (SampleBatch), state_batches (optional), explore, and timestep.

None
action_distribution_fn Optional[Callable[[ray.rllib.policy.policy.Policy, ray.rllib.models.modelv2.ModelV2, Any, Any, Any], Tuple[Any, Type[ray.rllib.models.torch.torch_action_dist.TorchDistributionWrapper], List[Any]]]]

A callable returning distribution inputs (parameters), a dist-class to generate an action distribution object from, and internal-state outputs (or an empty list if not applicable). Provide action_distribution_fn if you would like to only customize the model forward pass call. The resulting distribution parameters are then used by RLlib to create a distribution object, sample from it, and execute any exploration logic. Note: If action_distribution_fn is given, action_sampler_fn must be None. If both action_sampler_fn and action_distribution_fn are None, RLlib will simply pass inputs through self.model to get distribution inputs, create the distribution object, sample from it, and apply some exploration logic to the results. The callable takes as inputs: Policy, ModelV2, ModelInputDict, explore, timestep, is_training.

None
max_seq_len int

Max sequence length for LSTM training.

20
get_batch_divisibility_req Optional[Callable[[ray.rllib.policy.policy.Policy], int]]

Optional callable that returns the divisibility requirement for sample batches given the Policy.

None
Source code in ray/rllib/policy/torch_policy.py
@DeveloperAPI
def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict,
        *,
        model: Optional[TorchModelV2] = None,
        loss: Optional[Callable[[
            Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch
        ], Union[TensorType, List[TensorType]]]] = None,
        action_distribution_class: Optional[Type[
            TorchDistributionWrapper]] = None,
        action_sampler_fn: Optional[Callable[[
            TensorType, List[TensorType]
        ], Tuple[TensorType, TensorType]]] = None,
        action_distribution_fn: Optional[Callable[[
            Policy, ModelV2, TensorType, TensorType, TensorType
        ], Tuple[TensorType, Type[TorchDistributionWrapper], List[
            TensorType]]]] = None,
        max_seq_len: int = 20,
        get_batch_divisibility_req: Optional[Callable[[Policy],
                                                      int]] = None,
):
    """Initializes a TorchPolicy instance.

    Args:
        observation_space: Observation space of the policy.
        action_space: Action space of the policy.
        config: The Policy's config dict.
        model: PyTorch policy module. Given observations as
            input, this module must return a list of outputs where the
            first item is action logits, and the rest can be any value.
        loss: Callable that returns one or more (a list of) scalar loss
            terms.
        action_distribution_class: Class for a torch action distribution.
        action_sampler_fn: A callable returning a sampled action and its
            log-likelihood given Policy, ModelV2, input_dict, state batches
            (optional), explore, and timestep.
            Provide `action_sampler_fn` if you would like to have full
            control over the action computation step, including the
            model forward pass, possible sampling from a distribution,
            and exploration logic.
            Note: If `action_sampler_fn` is given, `action_distribution_fn`
            must be None. If both `action_sampler_fn` and
            `action_distribution_fn` are None, RLlib will simply pass
            inputs through `self.model` to get distribution inputs, create
            the distribution object, sample from it, and apply some
            exploration logic to the results.
            The callable takes as inputs: Policy, ModelV2, input_dict
            (SampleBatch), state_batches (optional), explore, and timestep.
        action_distribution_fn: A callable returning distribution inputs
            (parameters), a dist-class to generate an action distribution
            object from, and internal-state outputs (or an empty list if
            not applicable).
            Provide `action_distribution_fn` if you would like to only
            customize the model forward pass call. The resulting
            distribution parameters are then used by RLlib to create a
            distribution object, sample from it, and execute any
            exploration logic.
            Note: If `action_distribution_fn` is given, `action_sampler_fn`
            must be None. If both `action_sampler_fn` and
            `action_distribution_fn` are None, RLlib will simply pass
            inputs through `self.model` to get distribution inputs, create
            the distribution object, sample from it, and apply some
            exploration logic to the results.
            The callable takes as inputs: Policy, ModelV2, ModelInputDict,
            explore, timestep, is_training.
        max_seq_len: Max sequence length for LSTM training.
        get_batch_divisibility_req: Optional callable that returns the
            divisibility requirement for sample batches given the Policy.
    """
    self.framework = config["framework"] = "torch"
    super().__init__(observation_space, action_space, config)

    # Create multi-GPU model towers, if necessary.
    # - The central main model will be stored under self.model, residing
    #   on self.device (normally, a CPU).
    # - Each GPU will have a copy of that model under
    #   self.model_gpu_towers, matching the devices in self.devices.
    # - Parallelization is done by splitting the train batch and passing
    #   it through the model copies in parallel, then averaging over the
    #   resulting gradients, applying these averages on the main model and
    #   updating all towers' weights from the main model.
    # - In case of just one device (1 (fake or real) GPU or 1 CPU), no
    #   parallelization will be done.

    # If no Model is provided, build a default one here.
    if model is None:
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"], framework=self.framework)
        model = ModelCatalog.get_model_v2(
            obs_space=self.observation_space,
            action_space=self.action_space,
            num_outputs=logit_dim,
            model_config=self.config["model"],
            framework=self.framework)
        if action_distribution_class is None:
            action_distribution_class = dist_class

    # Get devices to build the graph on.
    worker_idx = self.config.get("worker_index", 0)
    if not config["_fake_gpus"] and \
            ray.worker._mode() == ray.worker.LOCAL_MODE:
        num_gpus = 0
    elif worker_idx == 0:
        num_gpus = config["num_gpus"]
    else:
        num_gpus = config["num_gpus_per_worker"]
    gpu_ids = list(range(torch.cuda.device_count()))

    # Place on one or more CPU(s) when either:
    # - Fake GPU mode.
    # - num_gpus=0 (either set by user or we are in local_mode=True).
    # - No GPUs available.
    if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
        logger.info("TorchPolicy (worker={}) running on {}.".format(
            worker_idx
            if worker_idx > 0 else "local", "{} fake-GPUs".format(num_gpus)
            if config["_fake_gpus"] else "CPU"))
        self.device = torch.device("cpu")
        self.devices = [
            self.device for _ in range(int(math.ceil(num_gpus)) or 1)
        ]
        self.model_gpu_towers = [
            model if i == 0 else copy.deepcopy(model)
            for i in range(int(math.ceil(num_gpus)) or 1)
        ]
        if hasattr(self, "target_model"):
            self.target_models = {
                m: self.target_model
                for m in self.model_gpu_towers
            }
        self.model = model
    # Place on one or more actual GPU(s), when:
    # - num_gpus > 0 (set by user) AND
    # - local_mode=False AND
    # - actual GPUs available AND
    # - non-fake GPU mode.
    else:
        logger.info("TorchPolicy (worker={}) running on {} GPU(s).".format(
            worker_idx if worker_idx > 0 else "local", num_gpus))
        # We are a remote worker (WORKER_MODE=1):
        # GPUs should be assigned to us by ray.
        if ray.worker._mode() == ray.worker.WORKER_MODE:
            gpu_ids = ray.get_gpu_ids()

        if len(gpu_ids) < num_gpus:
            raise ValueError(
                "TorchPolicy was not able to find enough GPU IDs! Found "
                f"{gpu_ids}, but num_gpus={num_gpus}.")

        self.devices = [
            torch.device("cuda:{}".format(i))
            for i, id_ in enumerate(gpu_ids) if i < num_gpus
        ]
        self.device = self.devices[0]
        ids = [id_ for i, id_ in enumerate(gpu_ids) if i < num_gpus]
        self.model_gpu_towers = []
        for i, _ in enumerate(ids):
            model_copy = copy.deepcopy(model)
            self.model_gpu_towers.append(model_copy.to(self.devices[i]))
        if hasattr(self, "target_model"):
            self.target_models = {
                m: copy.deepcopy(self.target_model).to(self.devices[i])
                for i, m in enumerate(self.model_gpu_towers)
            }
        self.model = self.model_gpu_towers[0]

    # Lock used for locking some methods on the object-level.
    # This prevents possible race conditions when calling the model
    # first, then its value function (e.g. in a loss function), in
    # between of which another model call is made (e.g. to compute an
    # action).
    self._lock = threading.RLock()

    self._state_inputs = self.model.get_initial_state()
    self._is_recurrent = len(self._state_inputs) > 0
    # Auto-update model's inference view requirements, if recurrent.
    self._update_model_view_requirements_from_init_state()
    # Combine view_requirements for Model and Policy.
    self.view_requirements.update(self.model.view_requirements)

    self.exploration = self._create_exploration()
    self.unwrapped_model = model  # used to support DistributedDataParallel
    # To ensure backward compatibility:
    # Old way: If `loss` provided here, use as-is (as a function).
    if loss is not None:
        self._loss = loss
    # New way: Convert the overridden `self.loss` into a plain function,
    # so it can be called the same way as `loss` would be, ensuring
    # backward compatibility.
    elif self.loss.__func__.__qualname__ != "Policy.loss":
        self._loss = self.loss.__func__
    # `loss` not provided nor overridden from Policy -> Set to None.
    else:
        self._loss = None
    self._optimizers = force_list(self.optimizer())
    # Store, which params (by index within the model's list of
    # parameters) should be updated per optimizer.
    # Maps optimizer idx to set of param indices.
    self.multi_gpu_param_groups: List[Set[int]] = []
    main_params = {p: i for i, p in enumerate(self.model.parameters())}
    for o in self._optimizers:
        param_indices = []
        for pg_idx, pg in enumerate(o.param_groups):
            for p in pg["params"]:
                param_indices.append(main_params[p])
        self.multi_gpu_param_groups.append(set(param_indices))

    # Create n sample-batch buffers (num_multi_gpu_tower_stacks), each
    # one with m towers (num_gpus).
    num_buffers = self.config.get("num_multi_gpu_tower_stacks", 1)
    self._loaded_batches = [[] for _ in range(num_buffers)]

    self.dist_class = action_distribution_class
    self.action_sampler_fn = action_sampler_fn
    self.action_distribution_fn = action_distribution_fn

    # If set, means we are using distributed allreduce during learning.
    self.distributed_world_size = None

    self.max_seq_len = max_seq_len
    self.batch_divisibility_req = get_batch_divisibility_req(self) if \
        callable(get_batch_divisibility_req) else \
        (get_batch_divisibility_req or 1)
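
The device-selection branch of this constructor can be restated as a small pure function. The sketch below is an illustration only: plan_devices is a hypothetical helper, the local_mode flag and visible_gpu_ids list stand in for the ray.worker and torch.cuda queries made in the source, and the returned strings stand in for torch.device objects.

```python
import math
from typing import Dict, List


def plan_devices(config: Dict, worker_index: int, local_mode: bool,
                 visible_gpu_ids: List[int]) -> List[str]:
    """Returns one device string per model tower."""
    fake = config["_fake_gpus"]
    if not fake and local_mode:
        num_gpus = 0
    elif worker_index == 0:
        num_gpus = config["num_gpus"]
    else:
        num_gpus = config["num_gpus_per_worker"]

    # CPU placement: fake-GPU mode, no GPUs requested, or none visible.
    if fake or num_gpus == 0 or not visible_gpu_ids:
        return ["cpu"] * (int(math.ceil(num_gpus)) or 1)

    if len(visible_gpu_ids) < num_gpus:
        raise ValueError(
            f"Found {visible_gpu_ids}, but num_gpus={num_gpus}.")
    return [
        "cuda:{}".format(i) for i, _ in enumerate(visible_gpu_ids)
        if i < num_gpus
    ]


cfg = {"_fake_gpus": False, "num_gpus": 2, "num_gpus_per_worker": 0}
print(plan_devices(cfg, 0, False, [0, 1, 2]))  # local worker, two GPUs
print(plan_devices(cfg, 1, False, [0, 1, 2]))  # remote worker, no GPUs
print(plan_devices(dict(cfg, _fake_gpus=True), 0, False, []))  # fake GPUs
```

Note that fake-GPU mode still yields one tower per requested GPU, all placed on the CPU, which is what makes it useful for testing multi-tower code paths on a machine without GPUs.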

apply_gradients(self, gradients)

Applies the (previously) computed gradients.

Subclasses must implement either this method together with compute_gradients(), or learn_on_batch().

Parameters:

Name Type Description Default
gradients Union[List[Tuple[Any, Any]], List[Any]]

The already calculated gradients to apply to this Policy.

required
Source code in ray/rllib/policy/torch_policy.py
@override(Policy)
@DeveloperAPI
def apply_gradients(self, gradients: ModelGradients) -> None:
    if gradients == _directStepOptimizerSingleton:
        for i, opt in enumerate(self._optimizers):
            opt.step()
    else:
        # TODO(sven): Not supported for multiple optimizers yet.
        assert len(self._optimizers) == 1
        for g, p in zip(gradients, self.model.parameters()):
            if g is not None:
                if torch.is_tensor(g):
                    p.grad = g.to(self.device)
                else:
                    p.grad = torch.from_numpy(g).to(self.device)

        self._optimizers[0].step()

compute_actions(self, obs_batch, state_batches=None, prev_action_batch=None, prev_reward_batch=None, info_batch=None, episodes=None, explore=None, timestep=None, **kwargs)

Computes actions for the current policy.

Parameters:

Name Type Description Default
obs_batch Union[List[Union[Any, dict, tuple]], Any, dict, tuple]

Batch of observations.

required
state_batches Optional[List[Any]]

List of RNN state input batches, if any.

None
prev_action_batch Union[List[Union[Any, dict, tuple]], Any, dict, tuple]

Batch of previous action values.

None
prev_reward_batch Union[List[Union[Any, dict, tuple]], Any, dict, tuple]

Batch of previous rewards.

None
info_batch Optional[Dict[str, list]]

Batch of info objects.

None
episodes Optional[List[Episode]]

List of Episode objects, one for each obs in obs_batch. This provides access to all of the internal episode state, which may be useful for model-based or multi-agent algorithms.

None
explore Optional[bool]

Whether to pick an exploitation or exploration action. Set to None (default) to use the value of self.config["explore"].

None
timestep Optional[int]

The current (sampling) time step.

None

Returns:

Type Description
Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]

A tuple (actions, state_outs, info). actions: Batch of output actions, with shape like [BATCH_SIZE, ACTION_SHAPE]. state_outs: List of RNN state output batches, if any, each with shape [BATCH_SIZE, STATE_SIZE]. info: Dictionary of extra feature batches, if any, with shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.

Source code in ray/rllib/policy/torch_policy.py
@override(Policy)
@DeveloperAPI
def compute_actions(
        self,
        obs_batch: Union[List[TensorStructType], TensorStructType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Union[List[TensorStructType],
                                 TensorStructType] = None,
        prev_reward_batch: Union[List[TensorStructType],
                                 TensorStructType] = None,
        info_batch: Optional[Dict[str, list]] = None,
        episodes: Optional[List["Episode"]] = None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        **kwargs) -> \
        Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:

    with torch.no_grad():
        seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
        input_dict = self._lazy_tensor_dict({
            SampleBatch.CUR_OBS: obs_batch,
            "is_training": False,
        })
        if prev_action_batch is not None:
            input_dict[SampleBatch.PREV_ACTIONS] = \
                np.asarray(prev_action_batch)
        if prev_reward_batch is not None:
            input_dict[SampleBatch.PREV_REWARDS] = \
                np.asarray(prev_reward_batch)
        state_batches = [
            convert_to_torch_tensor(s, self.device)
            for s in (state_batches or [])
        ]
        return self._compute_action_helper(input_dict, state_batches,
                                           seq_lens, explore, timestep)

compute_actions_from_input_dict(self, input_dict, explore=None, timestep=None, **kwargs)

Computes actions from collected samples (across multiple agents).

Takes an input dict (usually a SampleBatch) as its main data input. This allows for using this method in case a more complex input pattern (view requirements) is needed, for example when the Model requires the last n observations, the last m actions/rewards, or a combination of any of these.

Parameters:

Name Type Description Default
input_dict Dict[str, Any]

A SampleBatch or input dict containing the Tensors to compute actions. input_dict already conforms to both the Policy's and the Model's view requirements and can thus be passed to the Model as-is.

required
explore bool

Whether to pick an exploitation or exploration action (default: None -> use self.config["explore"]).

None
timestep Optional[int]

The current (sampling) time step.

None
episodes

This provides access to all of the internal episodes' state, which may be useful for model-based or multi-agent algorithms.

required

Returns:

Type Description
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]

A tuple (actions, state_outs, info). actions: Batch of output actions, with shape like [BATCH_SIZE, ACTION_SHAPE]. state_outs: List of RNN state output batches, if any, each with shape [BATCH_SIZE, STATE_SIZE]. info: Dictionary of extra feature batches, if any, with shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.

Source code in ray/rllib/policy/torch_policy.py
@override(Policy)
def compute_actions_from_input_dict(
        self,
        input_dict: Dict[str, TensorType],
        explore: bool = None,
        timestep: Optional[int] = None,
        **kwargs) -> \
        Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:

    with torch.no_grad():
        # Pass lazy (torch) tensor dict to Model as `input_dict`.
        input_dict = self._lazy_tensor_dict(input_dict)
        input_dict.set_training(True)
        # Pack internal state inputs into (separate) list.
        state_batches = [
            input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
        ]
        # Calculate RNN sequence lengths.
        seq_lens = np.array([1] * len(input_dict["obs"])) \
            if state_batches else None

        return self._compute_action_helper(input_dict, state_batches,
                                           seq_lens, explore, timestep)
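
The way RNN state inputs are pulled out of the input dict can be reproduced in isolation. The sketch below mirrors the key test used in the source (`"state_in" in k[:8]`) on a plain dict; split_state_inputs is a hypothetical helper, and Python lists replace the tensors and the NumPy array used for seq_lens.

```python
from typing import Dict, List, Optional, Tuple


def split_state_inputs(
        input_dict: Dict[str, list]) -> Tuple[List[list], Optional[List[int]]]:
    """Collects `state_in_*` entries and builds per-row sequence lengths."""
    # Keys whose first eight characters are "state_in" hold RNN states.
    state_batches = [
        input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
    ]
    # During action computation every row is a sequence of length 1.
    seq_lens = [1] * len(input_dict["obs"]) if state_batches else None
    return state_batches, seq_lens


recurrent_input = {
    "obs": [[0.1], [0.2]],
    "state_in_0": [[0.0], [0.0]],
    "state_in_1": [[1.0], [1.0]],
}
states, seq_lens = split_state_inputs(recurrent_input)
print(states, seq_lens)

# Without state inputs, no sequence lengths are produced.
print(split_state_inputs({"obs": [[0.1]]}))
```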

export_checkpoint(self, export_dir)

Export Policy checkpoint to local directory.

Parameters:

Name Type Description Default
export_dir str

Local writable directory.

required
Source code in ray/rllib/policy/torch_policy.py
@override(Policy)
def export_checkpoint(self, export_dir: str) -> None:
    raise NotImplementedError

export_model(self, export_dir, onnx=None)

Exports the Policy's Model to local directory for serving.

Creates a TorchScript model and saves it.

Parameters:

Name Type Description Default
export_dir str

Local writable directory or filename.

required
onnx Optional[int]

If given, will export model in ONNX format. The value of this parameter sets the ONNX OpSet version to use.

None
Source code in ray/rllib/policy/torch_policy.py
@override(Policy)
@DeveloperAPI
def export_model(self, export_dir: str,
                 onnx: Optional[int] = None) -> None:
    """Exports the Policy's Model to local directory for serving.

    Creates a TorchScript model and saves it.

    Args:
        export_dir: Local writable directory or filename.
        onnx: If given, will export model in ONNX format. The
            value of this parameter sets the ONNX OpSet version to use.
    """
    self._lazy_tensor_dict(self._dummy_batch)
    # Provide dummy state inputs if not an RNN (torch cannot jit with
    # returned empty internal states list).
    if "state_in_0" not in self._dummy_batch:
        self._dummy_batch["state_in_0"] = \
            self._dummy_batch[SampleBatch.SEQ_LENS] = np.array([1.0])

    state_ins = []
    i = 0
    while "state_in_{}".format(i) in self._dummy_batch:
        state_ins.append(self._dummy_batch["state_in_{}".format(i)])
        i += 1
    dummy_inputs = {
        k: self._dummy_batch[k]
        for k in self._dummy_batch.keys() if k != "is_training"
    }

    if not os.path.exists(export_dir):
        os.makedirs(export_dir)

    seq_lens = self._dummy_batch[SampleBatch.SEQ_LENS]
    if onnx:
        file_name = os.path.join(export_dir, "model.onnx")
        torch.onnx.export(
            self.model, (dummy_inputs, state_ins, seq_lens),
            file_name,
            export_params=True,
            opset_version=onnx,
            do_constant_folding=True,
            input_names=list(dummy_inputs.keys()) +
            ["state_ins", SampleBatch.SEQ_LENS],
            output_names=["output", "state_outs"],
            dynamic_axes={
                k: {
                    0: "batch_size"
                }
                for k in list(dummy_inputs.keys()) +
                ["state_ins", SampleBatch.SEQ_LENS]
            })
    else:
        traced = torch.jit.trace(self.model,
                                 (dummy_inputs, state_ins, seq_lens))
        file_name = os.path.join(export_dir, "model.pt")
        traced.save(file_name)

extra_action_out(self, input_dict, state_batches, model, action_dist)

Returns dict of extra info to include in experience batch.

Parameters:

Name Type Description Default
input_dict Dict[str, Any]

Dict of model input tensors.

required
state_batches List[Any]

List of state tensors.

required
model TorchModelV2

Reference to the model object.

required
action_dist TorchDistributionWrapper

Torch action dist object to get log-probs (e.g. for already sampled actions).

required

Returns:

Type Description
Dict[str, Any]

Extra outputs to return in a compute_actions_from_input_dict() call (3rd return value).

Source code in ray/rllib/policy/torch_policy.py
@DeveloperAPI
def extra_action_out(
        self, input_dict: Dict[str, TensorType],
        state_batches: List[TensorType], model: TorchModelV2,
        action_dist: TorchDistributionWrapper) -> Dict[str, TensorType]:
    """Returns dict of extra info to include in experience batch.

    Args:
        input_dict: Dict of model input tensors.
        state_batches: List of state tensors.
        model: Reference to the model object.
        action_dist: Torch action dist object
            to get log-probs (e.g. for already sampled actions).

    Returns:
        Extra outputs to return in a `compute_actions_from_input_dict()`
        call (3rd return value).
    """
    return {}

extra_compute_grad_fetches(self)

Extra values to fetch and return from compute_gradients().

Returns:

Type Description
Dict[str, Any]

Extra fetch dict to be added to the fetch dict of the compute_gradients call.

Source code in ray/rllib/policy/torch_policy.py
@DeveloperAPI
def extra_compute_grad_fetches(self) -> Dict[str, Any]:
    """Extra values to fetch and return from compute_gradients().

    Returns:
        Extra fetch dict to be added to the fetch dict of the
        `compute_gradients` call.
    """
    return {LEARNER_STATS_KEY: {}}  # e.g, stats, td error, etc.

extra_grad_info(self, train_batch)

Return dict of extra grad info.

Parameters:

Name Type Description Default
train_batch SampleBatch

The training batch for which to produce extra grad info.

required

Returns:

Type Description
Dict[str, Any]

The info dict carrying grad info per str key.

Source code in ray/rllib/policy/torch_policy.py
@DeveloperAPI
def extra_grad_info(self,
                    train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Return dict of extra grad info.

    Args:
        train_batch: The training batch for which to produce
            extra grad info.

    Returns:
        The info dict carrying grad info per str key.
    """
    return {}

extra_grad_process(self, optimizer, loss)

Called after each optimizer.zero_grad() + loss.backward() call.

Called for each self._optimizers/loss-value pair. Allows for gradient processing before optimizer.step() is called. E.g. for gradient clipping.

Parameters:

Name Type Description Default
optimizer torch.optim.Optimizer

A torch optimizer object.

required
loss Any

The loss tensor associated with the optimizer.

required

Returns:

Type Description
Dict[str, Any]

A dict with information on the gradient processing step.

Source code in ray/rllib/policy/torch_policy.py
@DeveloperAPI
def extra_grad_process(self, optimizer: "torch.optim.Optimizer",
                       loss: TensorType) -> Dict[str, TensorType]:
    """Called after each optimizer.zero_grad() + loss.backward() call.

    Called for each self._optimizers/loss-value pair.
    Allows for gradient processing before optimizer.step() is called.
    E.g. for gradient clipping.

    Args:
        optimizer: A torch optimizer object.
        loss: The loss tensor associated with the optimizer.

    Returns:
        A dict with information on the gradient processing step.
    """
    return {}

get_initial_state(self)

Returns initial RNN state for the current policy.

Returns:

Type Description
List[TensorType]

Initial RNN state for the current policy.

Source code in ray/rllib/policy/torch_policy.py
@override(Policy)
@DeveloperAPI
def get_initial_state(self) -> List[TensorType]:
    return [
        s.detach().cpu().numpy() for s in self.model.get_initial_state()
    ]

get_num_samples_loaded_into_buffer(self, buffer_index=0)

Returns the number of currently loaded samples in the given buffer.

Parameters:

Name Type Description Default
buffer_index int

The index of the buffer (a MultiGPUTowerStack) to use on the devices. The number of buffers on each device depends on the value of the num_multi_gpu_tower_stacks config key.

0

Returns:

Type Description
int

The number of tuples loaded per device.

Source code in ray/rllib/policy/torch_policy.py
@override(Policy)
@DeveloperAPI
def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int:
    if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
        assert buffer_index == 0
    return len(self._loaded_batches[buffer_index])

get_state(self)

Returns the entire current state of this Policy.

Note: Not to be confused with an RNN model's internal state. State includes the Model(s)' weights, optimizer weights, the exploration component's state, as well as global variables, such as sampling timesteps.

Returns:

Type Description
Union[Dict[str, Any], List[Any]]

Serialized local state.

Source code in ray/rllib/policy/torch_policy.py
@override(Policy)
@DeveloperAPI
def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
    state = super().get_state()
    state["_optimizer_variables"] = []
    for i, o in enumerate(self._optimizers):
        optim_state_dict = convert_to_numpy(o.state_dict())
        state["_optimizer_variables"].append(optim_state_dict)
    # Add exploration state.
    state["_exploration_state"] = \
        self.exploration.get_state()
    return state

get_tower_stats(self, stats_name)

Returns list of per-tower stats, copied to this Policy's device.

Parameters:

Name Type Description Default
stats_name str

The name of the stats to average over (this str must exist as a key inside each tower's tower_stats dict).

required

Returns:

Type Description
List[Union[Any, dict, tuple]]

The list of stats tensor (structs) of all towers, copied to this Policy's device.

Exceptions:

Type Description
AssertionError

If the stats_name cannot be found in any one of the towers' tower_stats dicts.

Source code in ray/rllib/policy/torch_policy.py
@DeveloperAPI
def get_tower_stats(self, stats_name: str) -> List[TensorStructType]:
    """Returns list of per-tower stats, copied to this Policy's device.

    Args:
        stats_name: The name of the stats to average over (this str
            must exist as a key inside each tower's `tower_stats` dict).

    Returns:
        The list of stats tensor (structs) of all towers, copied to this
        Policy's device.

    Raises:
        AssertionError: If the `stats_name` cannot be found in any one
        of the tower's `tower_stats` dicts.
    """
    data = []
    for tower in self.model_gpu_towers:
        if stats_name in tower.tower_stats:
            data.append(
                tree.map_structure(lambda s: s.to(self.device),
                                   tower.tower_stats[stats_name]))
    assert len(data) > 0, \
        f"Stats `{stats_name}` not found in any of the towers (you have " \
        f"{len(self.model_gpu_towers)} towers in total)! Make " \
        "sure you call the loss function on at least one of the towers."
    return data
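
The collection logic above can be illustrated without PyTorch. The following is a minimal sketch, not RLlib code: `collect_tower_stats` and the `SimpleNamespace` towers are hypothetical stand-ins, and plain floats replace the tensors that the real method copies to the Policy's device. It shows that towers lacking the key are skipped and that the assertion only fires when no tower has it.

```python
from types import SimpleNamespace


def collect_tower_stats(towers, stats_name):
    """Hypothetical stand-in mirroring the lookup done by get_tower_stats()."""
    # Keep only towers that actually recorded the requested stat.
    data = [t.tower_stats[stats_name] for t in towers
            if stats_name in t.tower_stats]
    assert len(data) > 0, f"Stats `{stats_name}` not found in any of the towers"
    return data


# Three fake towers; the last one never ran the loss function.
towers = [
    SimpleNamespace(tower_stats={"total_loss": 0.5}),
    SimpleNamespace(tower_stats={"total_loss": 1.5}),
    SimpleNamespace(tower_stats={}),
]

losses = collect_tower_stats(towers, "total_loss")
# A typical use is to average the per-tower values for reporting.
mean_loss = sum(losses) / len(losses)
```

With real tensors, the averaging step would presumably use a stack-and-mean over the returned list instead of Python arithmetic.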

get_weights(self)

Returns model weights.

Note: The return value of this method will reside under the "weights" key in the return value of Policy.get_state(). Model weights are only one part of a Policy's state. Other state information includes: optimizer variables, exploration state, and global state vars such as the sampling timestep.

Returns:

Type Description
dict

Serializable copy or view of model weights.

Source code in ray/rllib/policy/torch_policy.py
@override(Policy)
@DeveloperAPI
def get_weights(self) -> ModelWeights:
    return {
        k: v.cpu().detach().numpy()
        for k, v in self.model.state_dict().items()
    }

import_model_from_h5(self, import_file)

Imports weights into torch model.

Source code in ray/rllib/policy/torch_policy.py
@override(Policy)
@DeveloperAPI
def import_model_from_h5(self, import_file: str) -> None:
    """Imports weights into torch model."""
    return self.model.import_from_h5(import_file)

is_recurrent(self)

Whether this Policy holds a recurrent Model.

Returns:

Type Description
bool

True if this Policy has an RNN-based Model.

Source code in ray/rllib/policy/torch_policy.py
@override(Policy)
@DeveloperAPI
def is_recurrent(self) -> bool:
    return self._is_recurrent

learn_on_loaded_batch(self, offset=0, buffer_index=0)

Runs a single step of SGD on already loaded data in a buffer.

Runs an SGD step over a slice of the pre-loaded batch, offset by the offset argument (useful for performing n minibatch SGD updates repeatedly on the same, already pre-loaded data).

Updates the model weights based on the averaged per-device gradients.

Parameters:

Name Type Description Default
offset int

Offset into the preloaded data. Used for pre-loading a train-batch once to a device, then iterating over (subsampling through) this batch n times doing minibatch SGD.

0
buffer_index int

The index of the buffer (a MultiGPUTowerStack) to take the already pre-loaded data from. The number of buffers on each device depends on the value of the num_multi_gpu_tower_stacks config key.

0

Returns:

Type Description

The outputs of extra_ops evaluated over the batch.

Source code in ray/rllib/policy/torch_policy.py
@override(Policy)
@DeveloperAPI
def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
    if not self._loaded_batches[buffer_index]:
        raise ValueError(
            "Must call Policy.load_batch_into_buffer() before "
            "Policy.learn_on_loaded_batch()!")

    # Get the correct slice of the already loaded batch to use,
    # based on offset and batch size.
    device_batch_size = \
        self.config.get(
            "sgd_minibatch_size", self.config["train_batch_size"]) // \
        len(self.devices)

    # Set Model to train mode.
    if self.model_gpu_towers:
        for t in self.model_gpu_towers:
            t.train()

    # Shortcut for 1 CPU only: Batch should already be stored in
    # `self._loaded_batches`.
    if len(self.devices) == 1 and self.devices[0].type == "cpu":
        assert buffer_index == 0
        if device_batch_size >= len(self._loaded_batches[0][0]):
            batch = self._loaded_batches[0][0]
        else:
            batch = self._loaded_batches[0][0][offset:offset +
                                               device_batch_size]
        return self.learn_on_batch(batch)

    if len(self.devices) > 1:
        # Copy weights of main model (tower-0) to all other towers.
        state_dict = self.model.state_dict()
        # Just making sure tower-0 is really the same as self.model.
        assert self.model_gpu_towers[0] is self.model
        for tower in self.model_gpu_towers[1:]:
            tower.load_state_dict(state_dict)

    if device_batch_size >= sum(
            len(s) for s in self._loaded_batches[buffer_index]):
        device_batches = self._loaded_batches[buffer_index]
    else:
        device_batches = [
            b[offset:offset + device_batch_size]
            for b in self._loaded_batches[buffer_index]
        ]

    # Do the (maybe parallelized) gradient calculation step.
    tower_outputs = self._multi_gpu_parallel_grad_calc(device_batches)

    # Mean-reduce gradients over GPU-towers (do this on CPU: self.device).
    all_grads = []
    for i in range(len(tower_outputs[0][0])):
        if tower_outputs[0][0][i] is not None:
            all_grads.append(
                torch.mean(
                    torch.stack(
                        [t[0][i].to(self.device) for t in tower_outputs]),
                    dim=0))
        else:
            all_grads.append(None)
    # Set main model's grads to mean-reduced values.
    for i, p in enumerate(self.model.parameters()):
        p.grad = all_grads[i]

    self.apply_gradients(_directStepOptimizerSingleton)

    batch_fetches = {}
    for i, batch in enumerate(device_batches):
        batch_fetches[f"tower_{i}"] = {
            LEARNER_STATS_KEY: self.extra_grad_info(batch)
        }

    batch_fetches.update(self.extra_compute_grad_fetches())

    return batch_fetches
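
The way `offset` selects a minibatch from the pre-loaded data can be sketched in plain Python. This is an illustration under stated assumptions, not RLlib code: `device_batch_size` and `minibatch_slices` are hypothetical helpers, and a list of integers stands in for the per-device batch. The per-device size follows the same formula as the source above (`sgd_minibatch_size`, falling back to `train_batch_size`, divided by the number of devices).

```python
def device_batch_size(config, num_devices):
    """Per-device minibatch size, using the same formula as the source above."""
    return config.get(
        "sgd_minibatch_size", config["train_batch_size"]) // num_devices


def minibatch_slices(loaded_per_device, config, num_devices):
    """Yield (offset, slice) pairs that a caller could step through."""
    size = device_batch_size(config, num_devices)
    for offset in range(0, len(loaded_per_device), size):
        # In real code, each offset would be passed on as
        # policy.learn_on_loaded_batch(offset=offset, buffer_index=0).
        yield offset, loaded_per_device[offset:offset + size]


config = {"train_batch_size": 16, "sgd_minibatch_size": 8}
samples = list(range(8))  # 8 samples already loaded on each of 2 devices
steps = list(minibatch_slices(samples, config, num_devices=2))
```

Here each device processes 4 samples per SGD step, so two offsets (0 and 4) cover the loaded data once.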

load_batch_into_buffer(self, batch, buffer_index=0)

Bulk-loads the given SampleBatch into the devices' memories.

The data is split equally across all the Policy's devices. If the data is not evenly divisible by the batch size, excess data should be discarded.

Parameters:

Name Type Description Default
batch SampleBatch

The SampleBatch to load.

required
buffer_index int

The index of the buffer (a MultiGPUTowerStack) to use on the devices. The number of buffers on each device depends on the value of the num_multi_gpu_tower_stacks config key.

0

Returns:

Type Description
int

The number of tuples loaded per device.

Source code in ray/rllib/policy/torch_policy.py
@override(Policy)
@DeveloperAPI
def load_batch_into_buffer(
        self,
        batch: SampleBatch,
        buffer_index: int = 0,
) -> int:
    # Set the is_training flag of the batch.
    batch.set_training(True)

    # Shortcut for 1 CPU only: Store batch in `self._loaded_batches`.
    if len(self.devices) == 1 and self.devices[0].type == "cpu":
        assert buffer_index == 0
        pad_batch_to_sequences_of_same_size(
            batch=batch,
            max_seq_len=self.max_seq_len,
            shuffle=False,
            batch_divisibility_req=self.batch_divisibility_req,
            view_requirements=self.view_requirements,
        )
        self._lazy_tensor_dict(batch)
        self._loaded_batches[0] = [batch]
        return len(batch)

    # Batch (len=28, seq-lens=[4, 7, 4, 10, 3]):
    # 0123 0123456 0123 0123456789ABC

    # 1) split into n per-GPU sub batches (n=2).
    # [0123 0123456] [012] [3 0123456789 ABC]
    # (len=14, 14 seq-lens=[4, 7, 3] [1, 10, 3])
    slices = batch.timeslices(num_slices=len(self.devices))

    # 2) zero-padding (max-seq-len=10).
    # - [0123000000 0123456000 0120000000]
    # - [3000000000 0123456789 ABC0000000]
    for slice in slices:
        pad_batch_to_sequences_of_same_size(
            batch=slice,
            max_seq_len=self.max_seq_len,
            shuffle=False,
            batch_divisibility_req=self.batch_divisibility_req,
            view_requirements=self.view_requirements,
        )

    # 3) Load splits into the given buffer (consisting of n GPUs).
    slices = [
        slice.to_device(self.devices[i]) for i, slice in enumerate(slices)
    ]
    self._loaded_batches[buffer_index] = slices

    # Return loaded samples per-device.
    return len(slices[0])
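
The zero-padding step described in the source comments can be reproduced with a small, self-contained sketch. `pad_sequences` is a hypothetical helper written for illustration only; it is not the implementation of `pad_batch_to_sequences_of_same_size` and assumes every sequence length is at most `max_seq_len`. The input matches the first per-GPU slice from the comments (seq-lens `[4, 7, 3]`, max-seq-len 10).

```python
def pad_sequences(flat, seq_lens, max_seq_len, pad_value=0):
    """Right-pad each sequence in a flat list to max_seq_len (illustrative)."""
    padded = []
    start = 0
    for n in seq_lens:
        seq = flat[start:start + n]
        # Append the sequence followed by as many pad values as needed.
        padded.extend(seq + [pad_value] * (max_seq_len - n))
        start += n
    return padded


# First per-GPU slice from the comments above: [0123 0123456 012].
flat = list("0123") + list("0123456") + list("012")
out = pad_sequences(flat, [4, 7, 3], max_seq_len=10, pad_value="0")

# Group the padded result back into rows of max_seq_len for readability.
chunks = ["".join(out[i:i + 10]) for i in range(0, len(out), 10)]
```

The resulting rows correspond to the `[0123000000 0123456000 0120000000]` layout shown in the comments.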

num_state_tensors(self)

The number of internal states needed by the RNN-Model of the Policy.

Returns:

Type Description
int

The number of RNN internal states kept by this Policy's Model.

Source code in ray/rllib/policy/torch_policy.py
@override(Policy)
@DeveloperAPI
def num_state_tensors(self) -> int:
    return len(self.model.get_initial_state())

optimizer(self)

Customizes the local PyTorch optimizer(s) to use.

Returns:

Type Description
Union[List[torch.optim.Optimizer], torch.optim.Optimizer]

The local PyTorch optimizer(s) to use for this Policy.

Source code in ray/rllib/policy/torch_policy.py
@DeveloperAPI
def optimizer(
        self
) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
    """Custom the local PyTorch optimizer(s) to use.

    Returns:
        The local PyTorch optimizer(s) to use for this Policy.
    """
    if hasattr(self, "config"):
        optimizers = [
            torch.optim.Adam(
                self.model.parameters(), lr=self.config["lr"])
        ]
    else:
        optimizers = [torch.optim.Adam(self.model.parameters())]
    if getattr(self, "exploration", None):
        optimizers = self.exploration.get_exploration_optimizer(optimizers)
    return optimizers

set_state(self, state)

Restores the entire current state of this Policy from state.

Parameters:

Name Type Description Default
state dict

The new state to set this policy to. Can be obtained by calling self.get_state().

required
Source code in ray/rllib/policy/torch_policy.py
@override(Policy)
@DeveloperAPI
def set_state(self, state: dict) -> None:
    # Set optimizer vars first.
    optimizer_vars = state.get("_optimizer_variables", None)
    if optimizer_vars:
        assert len(optimizer_vars) == len(self._optimizers)
        for o, s in zip(self._optimizers, optimizer_vars):
            optim_state_dict = convert_to_torch_tensor(
                s, device=self.device)
            o.load_state_dict(optim_state_dict)
    # Set exploration's state.
    if hasattr(self, "exploration") and "_exploration_state" in state:
        self.exploration.set_state(state=state["_exploration_state"])
    # Then the Policy's (NN) weights.
    super().set_state(state)

set_weights(self, weights)

Sets this Policy's model's weights.

Note: Model weights are only one part of a Policy's state. Other state information includes: optimizer variables, exploration state, and global state vars such as the sampling timestep.

Parameters:

Name Type Description Default
weights dict

Serializable copy or view of model weights.

required
Source code in ray/rllib/policy/torch_policy.py
@override(Policy)
@DeveloperAPI
def set_weights(self, weights: ModelWeights) -> None:
    weights = convert_to_torch_tensor(weights, device=self.device)
    self.model.load_state_dict(weights)

ray.rllib.policy.tf_policy.TFPolicy (Policy)

An agent policy and loss implemented in TensorFlow.

Do not sub-class this class directly (neither should you sub-class DynamicTFPolicy), but rather use rllib.policy.tf_policy_template.build_tf_policy to generate your custom tf (graph-mode or eager) Policy classes.

Extending this class enables RLlib to perform TensorFlow-specific optimizations on the policy, e.g., parallelization across GPUs or fusing multiple graphs together in the multi-agent setting.

Input tensors are typically shaped like [BATCH_SIZE, ...].

Examples:

>>> policy = TFPolicySubclass(
    sess, obs_input, sampled_action, loss, loss_inputs)
>>> print(policy.compute_actions([1, 0, 2]))
(array([0, 1, 1]), [], {})
>>> print(policy.postprocess_trajectory(SampleBatch({...})))
SampleBatch({"action": ..., "advantages": ..., ...})

__init__(self, observation_space, action_space, config, sess, obs_input, sampled_action, loss, loss_inputs, model=None, sampled_action_logp=None, action_input=None, log_likelihood=None, dist_inputs=None, dist_class=None, state_inputs=None, state_outputs=None, prev_action_input=None, prev_reward_input=None, seq_lens=None, max_seq_len=20, batch_divisibility_req=1, update_ops=None, explore=None, timestep=None) special

Initializes a Policy object.

Parameters:

Name Type Description Default
observation_space Space

Observation space of the policy.

required
action_space Space

Action space of the policy.

required
config dict

Policy-specific configuration data.

required
sess tf1.Session

The TensorFlow session to use.

required
obs_input Any

Input placeholder for observations, of shape [BATCH_SIZE, obs...].

required
sampled_action Any

Tensor for sampling an action, of shape [BATCH_SIZE, action...].

required
loss Union[Any, List[Any]]

Scalar policy loss output tensor or a list thereof (in case there is more than one loss).

required
loss_inputs List[Tuple[str, Any]]

A (name, placeholder) tuple for each loss input argument. Each placeholder name must correspond to a SampleBatch column key returned by postprocess_trajectory(), and has shape [BATCH_SIZE, data...]. These keys will be read from postprocessed sample batches and fed into the specified placeholders during loss computation.

required
model Optional[ray.rllib.models.modelv2.ModelV2]

The optional ModelV2 to use for calculating actions and losses. If not None, TFPolicy will provide functionality for getting variables, calling the model's custom loss (if provided), and importing weights into the model.

None
sampled_action_logp Optional[Any]

Log probability of the sampled action.

None
action_input Optional[Any]

Input placeholder for actions for logp/log-likelihood calculations.

None
log_likelihood Optional[Any]

Tensor to calculate the log_likelihood (given action_input and obs_input).

None
dist_class Optional[type]

An optional ActionDistribution class to use for generating a dist object from distribution inputs.

None
dist_inputs Optional[Any]

Tensor to calculate the distribution inputs/parameters.

None
state_inputs Optional[List[Any]]

List of RNN state input Tensors.

None
state_outputs Optional[List[Any]]

List of RNN state output Tensors.

None
prev_action_input Optional[Any]

Placeholder for previous actions.

None
prev_reward_input Optional[Any]

Placeholder for previous rewards.

None
seq_lens Optional[Any]

Placeholder for RNN sequence lengths, of shape [NUM_SEQUENCES]. Note that NUM_SEQUENCES << BATCH_SIZE. See policy/rnn_sequencing.py for more information.

None
max_seq_len int

Max sequence length for LSTM training.

20
batch_divisibility_req int

Pad all agent experience batches to multiples of this value. This only has an effect if not using an LSTM model.

1
update_ops List[Any]

Override the batchnorm update ops to run when applying gradients. Otherwise, all update ops found in the current variable scope are run.

None
explore Optional[Any]

Placeholder for the explore parameter passed into the call to Exploration.get_exploration_action. Explicitly set this to False to skip creating an Exploration component.

None
timestep Optional[Any]

Placeholder for the global sampling timestep.

None
Source code in ray/rllib/policy/tf_policy.py
@DeveloperAPI
def __init__(self,
             observation_space: gym.spaces.Space,
             action_space: gym.spaces.Space,
             config: TrainerConfigDict,
             sess: "tf1.Session",
             obs_input: TensorType,
             sampled_action: TensorType,
             loss: Union[TensorType, List[TensorType]],
             loss_inputs: List[Tuple[str, TensorType]],
             model: Optional[ModelV2] = None,
             sampled_action_logp: Optional[TensorType] = None,
             action_input: Optional[TensorType] = None,
             log_likelihood: Optional[TensorType] = None,
             dist_inputs: Optional[TensorType] = None,
             dist_class: Optional[type] = None,
             state_inputs: Optional[List[TensorType]] = None,
             state_outputs: Optional[List[TensorType]] = None,
             prev_action_input: Optional[TensorType] = None,
             prev_reward_input: Optional[TensorType] = None,
             seq_lens: Optional[TensorType] = None,
             max_seq_len: int = 20,
             batch_divisibility_req: int = 1,
             update_ops: List[TensorType] = None,
             explore: Optional[TensorType] = None,
             timestep: Optional[TensorType] = None):
    """Initializes a Policy object.

    Args:
        observation_space: Observation space of the policy.
        action_space: Action space of the policy.
        config: Policy-specific configuration data.
        sess: The TensorFlow session to use.
        obs_input: Input placeholder for observations, of shape
            [BATCH_SIZE, obs...].
        sampled_action: Tensor for sampling an action, of shape
            [BATCH_SIZE, action...]
        loss: Scalar policy loss output tensor or a list thereof
            (in case there is more than one loss).
        loss_inputs: A (name, placeholder) tuple for each loss input
            argument. Each placeholder name must
            correspond to a SampleBatch column key returned by
            postprocess_trajectory(), and has shape [BATCH_SIZE, data...].
            These keys will be read from postprocessed sample batches and
            fed into the specified placeholders during loss computation.
        model: The optional ModelV2 to use for calculating actions and
            losses. If not None, TFPolicy will provide functionality for
            getting variables, calling the model's custom loss (if
            provided), and importing weights into the model.
        sampled_action_logp: log probability of the sampled action.
        action_input: Input placeholder for actions for
            logp/log-likelihood calculations.
        log_likelihood: Tensor to calculate the log_likelihood (given
            action_input and obs_input).
        dist_class: An optional ActionDistribution class to use for
            generating a dist object from distribution inputs.
        dist_inputs: Tensor to calculate the distribution
            inputs/parameters.
        state_inputs: List of RNN state input Tensors.
        state_outputs: List of RNN state output Tensors.
        prev_action_input: placeholder for previous actions.
        prev_reward_input: placeholder for previous rewards.
        seq_lens: Placeholder for RNN sequence lengths, of shape
            [NUM_SEQUENCES].
            Note that NUM_SEQUENCES << BATCH_SIZE. See
            policy/rnn_sequencing.py for more information.
        max_seq_len: Max sequence length for LSTM training.
        batch_divisibility_req: pad all agent experiences batches to
            multiples of this value. This only has an effect if not using
            a LSTM model.
        update_ops: override the batchnorm update ops
            to run when applying gradients. Otherwise we run all update
            ops found in the current variable scope.
        explore: Placeholder for `explore` parameter into call to
            Exploration.get_exploration_action. Explicitly set this to
            False for not creating any Exploration component.
        timestep: Placeholder for the global sampling timestep.
    """
    self.framework = "tf"
    super().__init__(observation_space, action_space, config)

    # Get devices to build the graph on.
    worker_idx = self.config.get("worker_index", 0)
    if not config["_fake_gpus"] and \
            ray.worker._mode() == ray.worker.LOCAL_MODE:
        num_gpus = 0
    elif worker_idx == 0:
        num_gpus = config["num_gpus"]
    else:
        num_gpus = config["num_gpus_per_worker"]
    gpu_ids = get_gpu_devices()

    # Place on one or more CPU(s) when either:
    # - Fake GPU mode.
    # - num_gpus=0 (either set by user or we are in local_mode=True).
    # - no GPUs available.
    if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
        logger.info("TFPolicy (worker={}) running on {}.".format(
            worker_idx
            if worker_idx > 0 else "local", f"{num_gpus} fake-GPUs"
            if config["_fake_gpus"] else "CPU"))
        self.devices = [
            "/cpu:0" for _ in range(int(math.ceil(num_gpus)) or 1)
        ]
    # Place on one or more actual GPU(s), when:
    # - num_gpus > 0 (set by user) AND
    # - local_mode=False AND
    # - actual GPUs available AND
    # - non-fake GPU mode.
    else:
        logger.info("TFPolicy (worker={}) running on {} GPU(s).".format(
            worker_idx if worker_idx > 0 else "local", num_gpus))

        # We are a remote worker (WORKER_MODE=1):
        # GPUs should be assigned to us by ray.
        if ray.worker._mode() == ray.worker.WORKER_MODE:
            gpu_ids = ray.get_gpu_ids()

        if len(gpu_ids) < num_gpus:
            raise ValueError(
                "TFPolicy was not able to find enough GPU IDs! Found "
                f"{gpu_ids}, but num_gpus={num_gpus}.")

        self.devices = [
            f"/gpu:{i}" for i, _ in enumerate(gpu_ids) if i < num_gpus
        ]

    # Disable env-info placeholder.
    if SampleBatch.INFOS in self.view_requirements:
        self.view_requirements[SampleBatch.INFOS].used_for_training = False
        self.view_requirements[
            SampleBatch.INFOS].used_for_compute_actions = False

    assert model is None or isinstance(model, (ModelV2, tf.keras.Model)), \
        "Model classes for TFPolicy other than `ModelV2|tf.keras.Model` " \
        "not allowed! You passed in {}.".format(model)
    self.model = model
    # Auto-update model's inference view requirements, if recurrent.
    if self.model is not None:
        self._update_model_view_requirements_from_init_state()

    # If `explore` is explicitly set to False, don't create an exploration
    # component.
    self.exploration = self._create_exploration() if explore is not False \
        else None

    self._sess = sess
    self._obs_input = obs_input
    self._prev_action_input = prev_action_input
    self._prev_reward_input = prev_reward_input
    self._sampled_action = sampled_action
    self._is_training = self._get_is_training_placeholder()
    self._is_exploring = explore if explore is not None else \
        tf1.placeholder_with_default(True, (), name="is_exploring")
    self._sampled_action_logp = sampled_action_logp
    self._sampled_action_prob = (tf.math.exp(self._sampled_action_logp)
                                 if self._sampled_action_logp is not None
                                 else None)
    self._action_input = action_input  # For logp calculations.
    self._dist_inputs = dist_inputs
    self.dist_class = dist_class

    self._state_inputs = state_inputs or []
    self._state_outputs = state_outputs or []
    self._seq_lens = seq_lens
    self._max_seq_len = max_seq_len

    if self._state_inputs and self._seq_lens is None:
        raise ValueError(
            "seq_lens tensor must be given if state inputs are defined")

    self._batch_divisibility_req = batch_divisibility_req
    self._update_ops = update_ops
    self._apply_op = None
    self._stats_fetches = {}
    self._timestep = timestep if timestep is not None else \
        tf1.placeholder_with_default(
            tf.zeros((), dtype=tf.int64), (), name="timestep")

    self._optimizers: List[LocalOptimizer] = []
    # Backward compatibility and for some code shared with tf-eager Policy.
    self._optimizer = None

    self._grads_and_vars: Union[ModelGradients, List[ModelGradients]] = []
    self._grads: Union[ModelGradients, List[ModelGradients]] = []
    # Policy tf-variables (weights), whose values to get/set via
    # get_weights/set_weights.
    self._variables = None
    # Local optimizer(s)' tf-variables (e.g. state vars for Adam).
    # Will be stored alongside `self._variables` when checkpointing.
    self._optimizer_variables: \
        Optional[ray.experimental.tf_utils.TensorFlowVariables] = None

    # The loss tf-op(s). Number of losses must match number of optimizers.
    self._losses = []
    # Backward compatibility (in case custom child TFPolicies access this
    # property).
    self._loss = None
    # A batch dict passed into loss function as input.
    self._loss_input_dict = {}
    losses = force_list(loss)
    if len(losses) > 0:
        self._initialize_loss(losses, loss_inputs)

    # The log-likelihood calculator op.
    self._log_likelihood = log_likelihood
    if self._log_likelihood is None and self._dist_inputs is not None and \
            self.dist_class is not None:
        self._log_likelihood = self.dist_class(
            self._dist_inputs, self.model).logp(self._action_input)
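
The device-placement branch of the constructor can be summarized in a short, framework-free sketch. `select_devices` is a hypothetical helper written for illustration; it copies the two list comprehensions from the source above but deliberately omits the local-mode and remote-worker handling (`ray.worker._mode()`, `ray.get_gpu_ids()`), so `num_gpus` and `gpu_ids` are passed in directly.

```python
import math


def select_devices(num_gpus, gpu_ids, fake_gpus=False):
    """Return TF device strings, following the branching shown above."""
    # CPU placement: fake-GPU mode, no GPUs requested, or none available.
    if fake_gpus or num_gpus == 0 or not gpu_ids:
        return ["/cpu:0" for _ in range(int(math.ceil(num_gpus)) or 1)]
    # GPU placement: fail early if fewer GPUs are visible than requested.
    if len(gpu_ids) < num_gpus:
        raise ValueError(
            f"Not enough GPU IDs! Found {gpu_ids}, but num_gpus={num_gpus}.")
    return [f"/gpu:{i}" for i, _ in enumerate(gpu_ids) if i < num_gpus]


cpu_only = select_devices(num_gpus=0, gpu_ids=[])
fake = select_devices(num_gpus=2, gpu_ids=[0, 1], fake_gpus=True)
two_gpus = select_devices(num_gpus=2, gpu_ids=[0, 1, 2])
fractional = select_devices(num_gpus=0.5, gpu_ids=[0])
```

Note that in fake-GPU mode one `/cpu:0` entry is created per requested GPU, and a fractional `num_gpus` still maps to a single `/gpu:0` device.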

apply_gradients(self, gradients)

Applies the (previously) computed gradients.

Subclasses must implement either this method together with compute_gradients(), or learn_on_batch().

Parameters:

Name Type Description Default
gradients Union[List[Tuple[Any, Any]], List[Any]]

The already calculated gradients to apply to this Policy.

required
Source code in ray/rllib/policy/tf_policy.py
@override(Policy)
@DeveloperAPI
def apply_gradients(self, gradients: ModelGradients) -> None:
    assert self.loss_initialized()
    builder = TFRunBuilder(self.get_session(), "apply_gradients")
    fetches = self._build_apply_gradients(builder, gradients)
    builder.get(fetches)

build_apply_op(self, optimizer, grads_and_vars)

Override this for a custom gradient apply computation behavior.

Parameters:

Name Type Description Default
optimizer Union[LocalOptimizer, List[LocalOptimizer]]

The local tf optimizer to use for applying the grads and vars.

required
grads_and_vars Union[ModelGradients, List[ModelGradients]]

List of (gradient, variable) tuples, pairing each gradient value with its corresponding tf.Variable.

required

Returns:

Type Description
tf.Operation

The tf op that applies all computed gradients (grads_and_vars) to the model(s) via the given optimizer(s).

Source code in ray/rllib/policy/tf_policy.py
@DeveloperAPI
def build_apply_op(
        self,
        optimizer: Union[LocalOptimizer, List[LocalOptimizer]],
        grads_and_vars: Union[ModelGradients, List[ModelGradients]],
) -> "tf.Operation":
    """Override this for a custom gradient apply computation behavior.

    Args:
        optimizer (Union[LocalOptimizer, List[LocalOptimizer]]): The local
            tf optimizer to use for applying the grads and vars.
        grads_and_vars (Union[ModelGradients, List[ModelGradients]]): List
            of tuples with grad values and the grad-value's corresponding
            tf.variable in it.

    Returns:
        tf.Operation: The tf op that applies all computed gradients
            (`grads_and_vars`) to the model(s) via the given optimizer(s).
    """
    optimizers = force_list(optimizer)

    # We have more than one optimizer and loss term.
    if self.config["_tf_policy_handles_more_than_one_loss"]:
        ops = []
        for i, optim in enumerate(optimizers):
            # Specify global_step (e.g. for TD3 which needs to count the
            # num updates that have happened).
            ops.append(
                optim.apply_gradients(
                    grads_and_vars[i],
                    global_step=tf1.train.get_or_create_global_step()))
        return tf.group(ops)
    # We have only one optimizer and one loss term.
    else:
        return optimizers[0].apply_gradients(
            grads_and_vars,
            global_step=tf1.train.get_or_create_global_step())
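
The single-loss versus multi-loss dispatch above can be tried out without TensorFlow. The following is a minimal sketch, not RLlib code: `ToySGD`, `apply_all`, and the dict-backed "variables" are hypothetical stand-ins, and the local `force_list` only imitates the helper of the same name used above.

```python
from typing import Dict, List, Tuple

Var = Dict[str, float]                   # stand-in for a tf.Variable
GradsAndVars = List[Tuple[float, Var]]   # (gradient, variable) pairs


class ToySGD:
    """Hypothetical optimizer: plain gradient descent on dict-backed variables."""

    def __init__(self, lr: float):
        self.lr = lr

    def apply_gradients(self, grads_and_vars: GradsAndVars) -> int:
        for grad, var in grads_and_vars:
            var["value"] -= self.lr * grad
        return len(grads_and_vars)  # number of variables updated


def force_list(x):
    """Wrap a single item into a list; return lists and tuples as lists."""
    return list(x) if isinstance(x, (list, tuple)) else [x]


def apply_all(optimizer, grads_and_vars, more_than_one_loss: bool) -> List[int]:
    optimizers = force_list(optimizer)
    if more_than_one_loss:
        # One (grad, var) list per optimizer, matched by index.
        return [opt.apply_gradients(grads_and_vars[i])
                for i, opt in enumerate(optimizers)]
    # Single optimizer: grads_and_vars is one flat list.
    return [optimizers[0].apply_gradients(grads_and_vars)]


actor, critic = {"value": 1.0}, {"value": 2.0}
updated = apply_all(
    [ToySGD(lr=0.1), ToySGD(lr=0.5)],
    [[(1.0, actor)], [(2.0, critic)]],
    more_than_one_loss=True,
)
```

In the multi-loss case, each optimizer only touches the variables in its own list, which is why the number of optimizers has to match the number of gradient lists.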

compute_actions(self, obs_batch, state_batches=None, prev_action_batch=None, prev_reward_batch=None, info_batch=None, episodes=None, explore=None, timestep=None, **kwargs)

Computes actions for the current policy.

Parameters:

Name Type Description Default
obs_batch Union[List[Any], Any]

Batch of observations.

required
state_batches Optional[List[Any]]

List of RNN state input batches, if any.

None
prev_action_batch Union[List[Any], Any]

Batch of previous action values.

None
prev_reward_batch Union[List[Any], Any]

Batch of previous rewards.

None
info_batch Optional[Dict[str, list]]

Batch of info objects.

None
episodes Optional[List[Episode]]

List of Episode objects, one for each obs in obs_batch. This provides access to all of the internal episode state, which may be useful for model-based or multi-agent algorithms.

None
explore Optional[bool]

Whether to pick an exploitation or exploration action. Set to None (default) for using the value of self.config["explore"].

None
timestep Optional[int]

The current (sampling) time step.

None

Returns:

Type Description
actions (TensorType)

Batch of output actions, with shape like [BATCH_SIZE, ACTION_SHAPE].

state_outs (List[TensorType])

List of RNN state output batches, if any, each with shape [BATCH_SIZE, STATE_SIZE].

info (List[dict])

Dictionary of extra feature batches, if any, with shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.

Source code in ray/rllib/policy/tf_policy.py
@override(Policy)
def compute_actions(
        self,
        obs_batch: Union[List[TensorType], TensorType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Union[List[TensorType], TensorType] = None,
        prev_reward_batch: Union[List[TensorType], TensorType] = None,
        info_batch: Optional[Dict[str, list]] = None,
        episodes: Optional[List["Episode"]] = None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        **kwargs):

    explore = explore if explore is not None else self.config["explore"]
    timestep = timestep if timestep is not None else self.global_timestep

    builder = TFRunBuilder(self.get_session(), "compute_actions")

    input_dict = {SampleBatch.OBS: obs_batch, "is_training": False}
    if state_batches:
        for i, s in enumerate(state_batches):
            input_dict[f"state_in_{i}"] = s
    if prev_action_batch is not None:
        input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
    if prev_reward_batch is not None:
        input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch

    to_fetch = self._build_compute_actions(
        builder, input_dict=input_dict, explore=explore, timestep=timestep)

    # Execute session run to get action (and other fetches).
    fetched = builder.get(to_fetch)

    # Update our global timestep by the batch size.
    self.global_timestep += \
        len(obs_batch) if isinstance(obs_batch, list) \
        else tree.flatten(obs_batch)[0].shape[0]

    return fetched
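
How the optional arguments end up in the input dict, and how the global timestep advances by the batch size, can be sketched in plain Python. This is an illustration under assumptions: `build_input_dict` is a hypothetical helper, and the string keys stand in for the `SampleBatch.OBS`, `SampleBatch.PREV_ACTIONS`, and `SampleBatch.PREV_REWARDS` constants.

```python
from typing import Any, Dict, List, Optional


def build_input_dict(
    obs_batch: List[Any],
    state_batches: Optional[List[Any]] = None,
    prev_action_batch: Optional[List[Any]] = None,
    prev_reward_batch: Optional[List[Any]] = None,
) -> Dict[str, Any]:
    # The string keys are assumptions made for this sketch.
    input_dict: Dict[str, Any] = {"obs": obs_batch, "is_training": False}
    # One "state_in_<i>" entry per RNN state input batch.
    for i, state in enumerate(state_batches or []):
        input_dict[f"state_in_{i}"] = state
    # Previous actions/rewards are added only if the caller passed them.
    if prev_action_batch is not None:
        input_dict["prev_actions"] = prev_action_batch
    if prev_reward_batch is not None:
        input_dict["prev_rewards"] = prev_reward_batch
    return input_dict


global_timestep = 0
obs = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # batch of 3 observations
inputs = build_input_dict(
    obs,
    state_batches=[[0.0] * 3],
    prev_reward_batch=[0.0, 1.0, 0.0],
)
global_timestep += len(obs)  # advance by the batch size
```

Because `prev_action_batch` was not passed, no `"prev_actions"` key appears in `inputs`.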

compute_actions_from_input_dict(self, input_dict, explore=None, timestep=None, episodes=None, **kwargs)

Computes actions from collected samples (across multiple-agents).

Takes an input dict (usually a SampleBatch) as its main data input. This allows for using this method in case a more complex input pattern (view requirements) is needed, for example when the Model requires the last n observations, the last m actions/rewards, or a combination of any of these.

Parameters:

Name Type Description Default
input_dict Union[ray.rllib.policy.sample_batch.SampleBatch, Dict[str, Any]]

A SampleBatch or input dict containing the Tensors to compute actions. input_dict already abides to the Policy's as well as the Model's view requirements and can thus be passed to the Model as-is.

required
explore bool

Whether to pick an exploitation or exploration action (default: None -> use self.config["explore"]).

None
timestep Optional[int]

The current (sampling) time step.

None
episodes Optional[List[Episode]]

This provides access to all of the internal episodes' state, which may be useful for model-based or multi-agent algorithms.

None

Returns:

Type Description
actions

Batch of output actions, with shape like [BATCH_SIZE, ACTION_SHAPE].

state_outs

List of RNN state output batches, if any, each with shape [BATCH_SIZE, STATE_SIZE].

info

Dictionary of extra feature batches, if any, with shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.

Source code in ray/rllib/policy/tf_policy.py
@override(Policy)
def compute_actions_from_input_dict(
        self,
        input_dict: Union[SampleBatch, Dict[str, TensorType]],
        explore: bool = None,
        timestep: Optional[int] = None,
        episodes: Optional[List["Episode"]] = None,
        **kwargs) -> \
        Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:

    explore = explore if explore is not None else self.config["explore"]
    timestep = timestep if timestep is not None else self.global_timestep

    # Switch off is_training flag in our batch.
    input_dict["is_training"] = False

    builder = TFRunBuilder(self.get_session(),
                           "compute_actions_from_input_dict")
    obs_batch = input_dict[SampleBatch.OBS]
    to_fetch = self._build_compute_actions(
        builder, input_dict=input_dict, explore=explore, timestep=timestep)

    # Execute session run to get action (and other fetches).
    fetched = builder.get(to_fetch)

    # Update our global timestep by the batch size.
    self.global_timestep += len(obs_batch) if isinstance(obs_batch, list) \
        else len(input_dict) if isinstance(input_dict, SampleBatch) \
        else obs_batch.shape[0]

    return fetched
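
To make the "last n observations" input pattern concrete, here is a hand-rolled sketch of what such an input dict could contain. RLlib derives these inputs from the view requirements itself. The `last_n` helper and the `"prev_n_obs"` key are hypothetical and exist only for this illustration.

```python
from typing import List


def last_n(values: List[float], t: int, n: int, pad: float = 0.0) -> List[float]:
    """Return values[t-n+1 .. t], left-padded with `pad` before the episode start."""
    window = []
    for i in range(t - n + 1, t + 1):
        window.append(values[i] if i >= 0 else pad)
    return window


episode_obs = [1.0, 2.0, 3.0, 4.0]

# Hypothetical key name; one row of 3 stacked observations per timestep.
input_dict = {
    "prev_n_obs": [last_n(episode_obs, t, n=3) for t in range(len(episode_obs))],
    "is_training": False,
}
```

An input dict shaped like this already contains everything the model needs, so it can be passed on as-is.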

compute_gradients(self, postprocessed_batch)

Computes gradients given a batch of experiences.

Either this in combination with apply_gradients() or learn_on_batch() must be implemented by subclasses.

Parameters:

Name Type Description Default
postprocessed_batch SampleBatch

The SampleBatch object to use for calculating gradients.

required

Returns:

Type Description
grads

List of gradient output values.

grad_info

Extra policy-specific info values.

Source code in ray/rllib/policy/tf_policy.py
@override(Policy)
@DeveloperAPI
def compute_gradients(
        self,
        postprocessed_batch: SampleBatch) -> \
        Tuple[ModelGradients, Dict[str, TensorType]]:
    assert self.loss_initialized()
    # Switch on is_training flag in our batch.
    postprocessed_batch.set_training(True)
    builder = TFRunBuilder(self.get_session(), "compute_gradients")
    fetches = self._build_compute_gradients(builder, postprocessed_batch)
    return builder.get(fetches)
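
The relationship between the two-step update (compute_gradients() followed by apply_gradients()) and the one-step learn_on_batch() can be sketched with a toy model. `ToyPolicy` below is a hypothetical one-weight regressor and not an RLlib class. It only mirrors the method names and the (grads, info) return shape described above.

```python
from typing import Dict, List, Tuple

Batch = Dict[str, List[float]]


class ToyPolicy:
    """Hypothetical one-weight policy fitting y = w * x with squared error."""

    def __init__(self, w: float = 0.0, lr: float = 0.1):
        self.w, self.lr = w, lr

    def compute_gradients(self, batch: Batch) -> Tuple[List[float], Dict[str, float]]:
        xs, ys = batch["x"], batch["y"]
        n = len(xs)
        loss = sum((self.w * x - y) ** 2 for x, y in zip(xs, ys)) / n
        grad = sum(2.0 * (self.w * x - y) * x for x, y in zip(xs, ys)) / n
        return [grad], {"loss": loss}

    def apply_gradients(self, gradients: List[float]) -> None:
        self.w -= self.lr * gradients[0]

    def learn_on_batch(self, batch: Batch) -> Dict[str, float]:
        # One monolithic step: compute, then apply.
        grads, info = self.compute_gradients(batch)
        self.apply_gradients(grads)
        return info


batch = {"x": [1.0, 2.0], "y": [2.0, 4.0]}
two_step, one_step = ToyPolicy(), ToyPolicy()

grads, grad_info = two_step.compute_gradients(batch)
two_step.apply_gradients(grads)

stats = one_step.learn_on_batch(batch)
```

Both policies end up with the same weight. The split form is useful when gradients are computed in one place and applied in another.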

compute_log_likelihoods(self, actions, obs_batch, state_batches=None, prev_action_batch=None, prev_reward_batch=None, actions_normalized=True)

Computes the log-prob/likelihood for a given action and observation.

The log-likelihood is calculated using this Policy's action distribution class (self.dist_class).

Parameters:

Name Type Description Default
actions Union[List[Any], Any]

Batch of actions, for which to retrieve the log-probs/likelihoods (given all other inputs: obs, states, ..).

required
obs_batch Union[List[Any], Any]

Batch of observations.

required
state_batches Optional[List[Any]]

List of RNN state input batches, if any.

None
prev_action_batch Union[List[Any], Any]

Batch of previous action values.

None
prev_reward_batch Union[List[Any], Any]

Batch of previous rewards.

None
actions_normalized bool

Whether the given actions are already normalized (between -1.0 and 1.0). If they are not and normalize_actions=True, the given actions are normalized first, before the log-likelihoods are calculated.

True

Returns:

Type Description
TensorType

Batch of log probs/likelihoods, with shape [BATCH_SIZE].

Source code in ray/rllib/policy/tf_policy.py
@override(Policy)
def compute_log_likelihoods(
        self,
        actions: Union[List[TensorType], TensorType],
        obs_batch: Union[List[TensorType], TensorType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Optional[Union[List[TensorType],
                                          TensorType]] = None,
        prev_reward_batch: Optional[Union[List[TensorType],
                                          TensorType]] = None,
        actions_normalized: bool = True,
) -> TensorType:

    if self._log_likelihood is None:
        raise ValueError("Cannot compute log-prob/likelihood w/o a "
                         "self._log_likelihood op!")

    # Exploration hook before each forward pass.
    self.exploration.before_compute_actions(
        explore=False, tf_sess=self.get_session())

    builder = TFRunBuilder(self.get_session(), "compute_log_likelihoods")

    # Normalize actions if necessary.
    if actions_normalized is False and self.config["normalize_actions"]:
        actions = normalize_action(actions, self.action_space_struct)

    # Feed actions (for which we want logp values) into graph.
    builder.add_feed_dict({self._action_input: actions})
    # Feed observations.
    builder.add_feed_dict({self._obs_input: obs_batch})
    # Internal states.
    state_batches = state_batches or []
    if len(self._state_inputs) != len(state_batches):
        raise ValueError(
            "Must pass in RNN state batches for placeholders {}, got {}".
            format(self._state_inputs, state_batches))
    builder.add_feed_dict(
        {k: v
         for k, v in zip(self._state_inputs, state_batches)})
    if state_batches:
        builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))})
    # Prev-a and r.
    if self._prev_action_input is not None and \
       prev_action_batch is not None:
        builder.add_feed_dict({self._prev_action_input: prev_action_batch})
    if self._prev_reward_input is not None and \
       prev_reward_batch is not None:
        builder.add_feed_dict({self._prev_reward_input: prev_reward_batch})
    # Fetch the log_likelihoods output and return.
    fetches = builder.add_fetches([self._log_likelihood])
    return builder.get(fetches)[0]
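
What the fetched log-likelihood op computes can be illustrated for the simplest case. The sketch below assumes a categorical (softmax) action distribution whose inputs are logits, and a one-dimensional bounded action for the normalization step. Both helper functions are hypothetical and do not call RLlib's distribution classes or its normalize_action utility.

```python
import math
from typing import List


def categorical_logp(logits: List[float], action: int) -> float:
    """log p(action) for a softmax distribution parameterized by `logits`."""
    m = max(logits)  # subtract the max for numerical stability
    log_norm = m + math.log(sum(math.exp(x - m) for x in logits))
    return logits[action] - log_norm


def normalize_to_unit_range(action: float, low: float, high: float) -> float:
    """Map an action from [low, high] to [-1.0, 1.0]."""
    return (action - low) * 2.0 / (high - low) - 1.0


# One row of distribution inputs (logits) per observation in the batch.
dist_inputs = [[0.0, 0.0], [math.log(3.0), 0.0]]
actions = [1, 0]
log_likelihoods = [
    categorical_logp(row, a) for row, a in zip(dist_inputs, actions)
]
```

The result has one entry per batch item, matching the [BATCH_SIZE] output shape documented above.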

copy(self, existing_inputs)

Creates a copy of self using existing input placeholders.

Optional: Only required to work with the multi-GPU optimizer.

Parameters:

Name Type Description Default
existing_inputs List[Tuple[str, tf1.placeholder]]

List of (name, placeholder) tuples, pairing names (str) with the tf1.placeholders to re-use (share) with the returned copy of self.

required

Returns:

Type Description
TFPolicy

A copy of self.

Source code in ray/rllib/policy/tf_policy.py
@DeveloperAPI
def copy(self,
         existing_inputs: List[Tuple[str, "tf1.placeholder"]]) -> \
        "TFPolicy":
    """Creates a copy of self using existing input placeholders.

    Optional: Only required to work with the multi-GPU optimizer.

    Args:
        existing_inputs (List[Tuple[str, tf1.placeholder]]): Dict mapping
            names (str) to tf1.placeholders to re-use (share) with the
            returned copy of self.

    Returns:
        TFPolicy: A copy of self.
    """
    raise NotImplementedError

export_checkpoint(self, export_dir, filename_prefix='model')

Export tensorflow checkpoint to export_dir.

Source code in ray/rllib/policy/tf_policy.py
@override(Policy)
@DeveloperAPI
def export_checkpoint(self,
                      export_dir: str,
                      filename_prefix: str = "model") -> None:
    """Export tensorflow checkpoint to export_dir."""
    try:
        os.makedirs(export_dir)
    except OSError as e:
        # ignore error if export dir already exists
        if e.errno != errno.EEXIST:
            raise
    save_path = os.path.join(export_dir, filename_prefix)
    with self.get_session().graph.as_default():
        saver = tf1.train.Saver()
        saver.save(self.get_session(), save_path)
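
The directory handling above (create the export directory, tolerate one that already exists) can also be written with the `exist_ok` flag of `os.makedirs`. The following sketch covers only that step and the construction of the save path. `ensure_dir` is a hypothetical helper, and no TensorFlow saver is involved.

```python
import os
import tempfile


def ensure_dir(path: str) -> str:
    # Behaves like catching OSError and re-raising unless errno == EEXIST,
    # for the case where `path` already exists as a directory.
    os.makedirs(path, exist_ok=True)
    return path


base = tempfile.mkdtemp()
export_dir = os.path.join(base, "checkpoints")

ensure_dir(export_dir)
ensure_dir(export_dir)  # second call is a no-op instead of an error

# The checkpoint path is export_dir joined with the filename prefix.
save_path = os.path.join(export_dir, "model")
```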

export_model(self, export_dir, onnx=None)

Export tensorflow graph to export_dir for serving.

Source code in ray/rllib/policy/tf_policy.py
@override(Policy)
@DeveloperAPI
def export_model(self, export_dir: str,
                 onnx: Optional[int] = None) -> None:
    """Export tensorflow graph to export_dir for serving."""
    if onnx:
        try:
            import tf2onnx
        except ImportError as e:
            raise RuntimeError(
                "Converting a TensorFlow model to ONNX requires "
                "`tf2onnx` to be installed. Install with "
                "`pip install tf2onnx`.") from e

        with self.get_session().graph.as_default():
            signature_def_map = self._build_signature_def()

            sd = signature_def_map[tf1.saved_model.signature_constants.
                                   DEFAULT_SERVING_SIGNATURE_DEF_KEY]
            inputs = [v.name for k, v in sd.inputs.items()]
            outputs = [v.name for k, v in sd.outputs.items()]

            from tf2onnx import tf_loader
            frozen_graph_def = tf_loader.freeze_session(
                self._sess, input_names=inputs, output_names=outputs)

        with tf1.Session(graph=tf.Graph()) as session:
            tf.import_graph_def(frozen_graph_def, name="")

            g = tf2onnx.tfonnx.process_tf_graph(
                session.graph,
                input_names=inputs,
                output_names=outputs,
                inputs_as_nchw=inputs)

            model_proto = g.make_model("onnx_model")
            tf2onnx.utils.save_onnx_model(
                export_dir,
                "saved_model",
                feed_dict={},
                model_proto=model_proto)
    else:
        with self.get_session().graph.as_default():
            signature_def_map = self._build_signature_def()
            builder = tf1.saved_model.builder.SavedModelBuilder(export_dir)
            builder.add_meta_graph_and_variables(
                self.get_session(),
                [tf1.saved_model.tag_constants.SERVING],
                signature_def_map=signature_def_map,
                saver=tf1.summary.FileWriter(export_dir).add_graph(
                    graph=self.get_session().graph))
            builder.save()

extra_compute_action_feed_dict(self)

Extra dict to pass to the compute actions session run.

Returns:

Type Description
Dict[TensorType, TensorType]

A feed dict to be added to the feed_dict passed to the compute_actions session.run() call.

Source code in ray/rllib/policy/tf_policy.py
@DeveloperAPI
def extra_compute_action_feed_dict(self) -> Dict[TensorType, TensorType]:
    """Extra dict to pass to the compute actions session run.

    Returns:
        Dict[TensorType, TensorType]: A feed dict to be added to the
            feed_dict passed to the compute_actions session.run() call.
    """
    return {}

extra_compute_action_fetches(self)

Extra values to fetch and return from compute_actions().

By default we return action probability/log-likelihood info and action distribution inputs (if present).

Returns:

Type Description
Dict[str, TensorType]

An extra fetch-dict to be passed to and returned from the compute_actions() call.

Source code in ray/rllib/policy/tf_policy.py
@DeveloperAPI
def extra_compute_action_fetches(self) -> Dict[str, TensorType]:
    """Extra values to fetch and return from compute_actions().

    By default we return action probability/log-likelihood info
    and action distribution inputs (if present).

    Returns:
         Dict[str, TensorType]: An extra fetch-dict to be passed to and
            returned from the compute_actions() call.
    """
    extra_fetches = {}
    # Action-logp and action-prob.
    if self._sampled_action_logp is not None:
        extra_fetches[SampleBatch.ACTION_PROB] = self._sampled_action_prob
        extra_fetches[SampleBatch.ACTION_LOGP] = self._sampled_action_logp
    # Action-dist inputs.
    if self._dist_inputs is not None:
        extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = self._dist_inputs
    return extra_fetches
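
A subclass would typically extend the default fetches rather than replace them. The sketch below imitates that pattern with hypothetical stand-in classes. The string keys are assumed values for the `SampleBatch` constants used above, and the `"vf_preds"` entry is a made-up extra output.

```python
from typing import Any, Dict


class BasePolicySketch:
    """Hypothetical stand-in that mimics the default fetch-dict logic."""

    def __init__(self, logp=None, prob=None, dist_inputs=None):
        self._sampled_action_logp = logp
        self._sampled_action_prob = prob
        self._dist_inputs = dist_inputs

    def extra_compute_action_fetches(self) -> Dict[str, Any]:
        fetches: Dict[str, Any] = {}
        # Action-logp and action-prob are added together.
        if self._sampled_action_logp is not None:
            fetches["action_prob"] = self._sampled_action_prob
            fetches["action_logp"] = self._sampled_action_logp
        # Action-dist inputs are added only if present.
        if self._dist_inputs is not None:
            fetches["action_dist_inputs"] = self._dist_inputs
        return fetches


class ValuePolicySketch(BasePolicySketch):
    """Adds a (made-up) value estimate on top of the default fetches."""

    def extra_compute_action_fetches(self) -> Dict[str, Any]:
        fetches = super().extra_compute_action_fetches()
        fetches["vf_preds"] = [0.5]  # placeholder value, not a real tensor
        return fetches


fetches = ValuePolicySketch(logp=[-0.7], prob=[0.5]).extra_compute_action_fetches()
```

Calling `super()` first keeps the default log-likelihood entries, so code that expects them keeps working.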

extra_compute_grad_feed_dict(self)

Extra dict to pass to the compute gradients session run.

Returns:

Type Description
Dict[TensorType, TensorType]

Extra feed_dict to be passed to the compute_gradients Session.run() call.

Source code in ray/rllib/policy/tf_policy.py
@DeveloperAPI
def extra_compute_grad_feed_dict(self) -> Dict[TensorType, TensorType]:
    """Extra dict to pass to the compute gradients session run.

    Returns:
        Dict[TensorType, TensorType]: Extra feed_dict to be passed to the
            compute_gradients Session.run() call.
    """
    return {}  # e.g, kl_coeff

extra_compute_grad_fetches(self)

Extra values to fetch and return from compute_gradients().

Returns:

Type Description
Dict[str, any]

Extra fetch dict to be added to the fetch dict of the compute_gradients Session.run() call.

Source code in ray/rllib/policy/tf_policy.py
@DeveloperAPI
def extra_compute_grad_fetches(self) -> Dict[str, any]:
    """Extra values to fetch and return from compute_gradients().

    Returns:
        Dict[str, any]: Extra fetch dict to be added to the fetch dict
            of the compute_gradients Session.run() call.
    """
    return {LEARNER_STATS_KEY: {}}  # e.g, stats, td error, etc.

get_exploration_state(self)

Returns the state of this Policy's exploration component.

Returns:

Type Description
Dict[str, Any]

Serializable information on the self.exploration object.

Source code in ray/rllib/policy/tf_policy.py
@override(Policy)
@DeveloperAPI
def get_exploration_state(self) -> Dict[str, TensorType]:
    return self.exploration.get_state(sess=self.get_session())

get_placeholder(self, name)

Returns the given action or loss input placeholder by name.

If the loss has not been initialized and a loss input placeholder is requested, an error is raised.

Parameters:

Name Type Description Default
name str

The name of the placeholder to return. One of SampleBatch.CUR_OBS, SampleBatch.PREV_ACTIONS, SampleBatch.PREV_REWARDS, or a valid key from self._loss_input_dict.

required

Returns:

Type Description
tf1.placeholder

The placeholder under the given str key.

Source code in ray/rllib/policy/tf_policy.py
def get_placeholder(self, name) -> "tf1.placeholder":
    """Returns the given action or loss input placeholder by name.

    If the loss has not been initialized and a loss input placeholder is
    requested, an error is raised.

    Args:
        name (str): The name of the placeholder to return. One of
            SampleBatch.CUR_OBS|PREV_ACTION/REWARD or a valid key from
            `self._loss_input_dict`.

    Returns:
        tf1.placeholder: The placeholder under the given str key.
    """
    if name == SampleBatch.CUR_OBS:
        return self._obs_input
    elif name == SampleBatch.PREV_ACTIONS:
        return self._prev_action_input
    elif name == SampleBatch.PREV_REWARDS:
        return self._prev_reward_input

    assert self._loss_input_dict, \
        "You need to populate `self._loss_input_dict` before " \
        "`get_placeholder()` can be called"
    return self._loss_input_dict[name]

get_session(self)

Returns a reference to the TF session for this policy.

Source code in ray/rllib/policy/tf_policy.py
@override(Policy)
def get_session(self) -> Optional["tf1.Session"]:
    """Returns a reference to the TF session for this policy."""
    return self._sess

get_state(self)

Returns the entire current state of this Policy.

Note: Not to be confused with an RNN model's internal state. State includes the Model(s)' weights, optimizer weights, the exploration component's state, as well as global variables, such as sampling timesteps.

Returns:

Type Description
Union[Dict[str, Any], List[Any]]

Serialized local state.

Source code in ray/rllib/policy/tf_policy.py
@override(Policy)
@DeveloperAPI
def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
    # For tf Policies, return Policy weights and optimizer var values.
    state = super().get_state()
    if len(self._optimizer_variables.variables) > 0:
        state["_optimizer_variables"] = \
            self.get_session().run(self._optimizer_variables.variables)
    # Add exploration state.
    state["_exploration_state"] = \
        self.exploration.get_state(self.get_session())
    return state

get_weights(self)

Returns model weights.

Note: The return value of this method will reside under the "weights" key in the return value of Policy.get_state(). Model weights are only one part of a Policy's state. Other state information includes optimizer variables, exploration state, and global state variables such as the sampling timestep.

Returns:

Type Description
Union[Dict[str, Any], List[Any]]

Serializable copy or view of model weights.

Source code in ray/rllib/policy/tf_policy.py
@override(Policy)
@DeveloperAPI
def get_weights(self) -> Union[Dict[str, TensorType], List[TensorType]]:
    return self._variables.get_weights()

gradients(self, optimizer, loss)

Override this for a custom gradient computation behavior.

Parameters:

Name Type Description Default
optimizer Union[LocalOptimizer, List[LocalOptimizer]]

A single LocalOptimizer or a list thereof to use for gradient calculations. If more than one optimizer is given, the number of optimizers must match the number of losses provided.

required
loss Union[TensorType, List[TensorType]]

A single loss term or a list thereof to use for gradient calculations. If more than one loss given, the number of loss terms must match the number of optimizers provided.

required

Returns:

Type Description
Union[List[ModelGradients], List[List[ModelGradients]]]

List of ModelGradients (grads and vars OR just grads) OR List of List of ModelGradients in case we have more than one optimizer/loss.

Source code in ray/rllib/policy/tf_policy.py
@DeveloperAPI
def gradients(
        self,
        optimizer: Union[LocalOptimizer, List[LocalOptimizer]],
        loss: Union[TensorType, List[TensorType]],
) -> Union[List[ModelGradients], List[List[ModelGradients]]]:
    """Override this for a custom gradient computation behavior.

    Args:
        optimizer (Union[LocalOptimizer, List[LocalOptimizer]]): A single
            LocalOptimizer or a list thereof to use for gradient
            calculations. If more than one optimizer given, the number of
            optimizers must match the number of losses provided.
        loss (Union[TensorType, List[TensorType]]): A single loss term
            or a list thereof to use for gradient calculations.
            If more than one loss given, the number of loss terms must
            match the number of optimizers provided.

    Returns:
        Union[List[ModelGradients], List[List[ModelGradients]]]: List of
            ModelGradients (grads and vars OR just grads) OR List of List
            of ModelGradients in case we have more than one
            optimizer/loss.
    """
    optimizers = force_list(optimizer)
    losses = force_list(loss)

    # We have more than one optimizer and loss term.
    if self.config["_tf_policy_handles_more_than_one_loss"]:
        grads = []
        for optim, loss_ in zip(optimizers, losses):
            grads.append(optim.compute_gradients(loss_))
        return grads
    # We have only one optimizer and one loss term.
    else:
        return optimizers[0].compute_gradients(losses[0])
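
The pairing of optimizers with loss terms can be sketched without TensorFlow. This illustration rests on several assumptions: `FiniteDiffOptimizer` is a hypothetical class that owns a single scalar variable, losses are plain Python callables rather than tensors, and gradients are estimated by central differences.

```python
from typing import Callable, List, Tuple


class FiniteDiffOptimizer:
    """Hypothetical optimizer that owns one scalar variable."""

    def __init__(self, name: str, value: float):
        self.name, self.value = name, value

    def compute_gradients(
        self, loss: Callable[[float], float]
    ) -> List[Tuple[float, str]]:
        eps = 1e-6
        grad = (loss(self.value + eps) - loss(self.value - eps)) / (2 * eps)
        return [(grad, self.name)]  # (gradient, variable) pairs


def gradients(optimizers, losses, more_than_one_loss: bool):
    if more_than_one_loss:
        # One list of (grad, var) pairs per optimizer/loss pair.
        return [opt.compute_gradients(loss)
                for opt, loss in zip(optimizers, losses)]
    # Single optimizer and single loss: one flat list.
    return optimizers[0].compute_gradients(losses[0])


opts = [FiniteDiffOptimizer("actor_w", 1.0), FiniteDiffOptimizer("critic_w", 3.0)]
losses = [lambda w: w ** 2, lambda w: (w - 1.0) ** 2]

per_loss = gradients(opts, losses, more_than_one_loss=True)
single = gradients(opts[:1], losses[:1], more_than_one_loss=False)
```

The multi-loss result is a list of lists, while the single-loss result is one flat list. This matches the two return shapes in the type annotation.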

import_model_from_h5(self, import_file)

Imports weights into tf model.

Source code in ray/rllib/policy/tf_policy.py
@override(Policy)
@DeveloperAPI
def import_model_from_h5(self, import_file: str) -> None:
    """Imports weights into tf model."""
    if self.model is None:
        raise NotImplementedError("No `self.model` to import into!")

    # Make sure the session is the right one (see issue #7046).
    with self.get_session().graph.as_default():
        with self.get_session().as_default():
            return self.model.import_from_h5(import_file)

is_recurrent(self)

Whether this Policy holds a recurrent Model.

Returns:

Type Description
bool

True if this Policy has an RNN-based Model.

Source code in ray/rllib/policy/tf_policy.py
@override(Policy)
@DeveloperAPI
def is_recurrent(self) -> bool:
    return len(self._state_inputs) > 0

learn_on_batch(self, postprocessed_batch)

Perform one learning update, given samples.

Either this method or the combination of compute_gradients and apply_gradients must be implemented by subclasses.

Parameters:

Name Type Description Default
postprocessed_batch SampleBatch

The SampleBatch object to learn from.

required

Returns:

Type Description
Dict[str, Any]

Dictionary of extra metadata from compute_gradients().

Examples:

>>> sample_batch = ev.sample()
>>> ev.learn_on_batch(sample_batch)
Source code in ray/rllib/policy/tf_policy.py
@override(Policy)
@DeveloperAPI
def learn_on_batch(
        self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
    assert self.loss_initialized()

    # Switch on is_training flag in our batch.
    postprocessed_batch.set_training(True)

    builder = TFRunBuilder(self.get_session(), "learn_on_batch")

    # Callback handling.
    learn_stats = {}
    self.callbacks.on_learn_on_batch(
        policy=self, train_batch=postprocessed_batch, result=learn_stats)

    fetches = self._build_learn_on_batch(builder, postprocessed_batch)
    stats = builder.get(fetches)
    stats.update({"custom_metrics": learn_stats})
    return stats
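
The callback handling above relies on the callback mutating the `result` dict it is given. That dict is then attached to the returned stats under "custom_metrics". The sketch below is a minimal stand-in: `MeanRewardCallback` and `learn_on_batch_sketch` are hypothetical, and the "learner_stats" entry is a placeholder for whatever the session run would fetch.

```python
from typing import Any, Dict


class MeanRewardCallback:
    """Hypothetical callback that writes into the mutable `result` dict."""

    def on_learn_on_batch(
        self, *, policy: Any, train_batch: Dict[str, list], result: Dict[str, Any]
    ) -> None:
        rewards = train_batch["rewards"]
        result["mean_reward"] = sum(rewards) / len(rewards)


def learn_on_batch_sketch(callbacks, train_batch) -> Dict[str, Any]:
    learn_stats: Dict[str, Any] = {}
    # The callback fills `learn_stats` in place.
    callbacks.on_learn_on_batch(
        policy=None, train_batch=train_batch, result=learn_stats)
    stats = {"learner_stats": {"loss": 0.25}}  # placeholder for fetched stats
    stats.update({"custom_metrics": learn_stats})
    return stats


stats = learn_on_batch_sketch(MeanRewardCallback(), {"rewards": [1.0, 0.0, 2.0]})
```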

loss_initialized(self)

Returns whether the loss term(s) have been initialized.

Source code in ray/rllib/policy/tf_policy.py
def loss_initialized(self) -> bool:
    """Returns whether the loss term(s) have been initialized."""
    return len(self._losses) > 0

num_state_tensors(self)

The number of internal states needed by the RNN-Model of the Policy.

Returns:

Type Description
int

The number of RNN internal states kept by this Policy's Model.

Source code in ray/rllib/policy/tf_policy.py
@override(Policy)
@DeveloperAPI
def num_state_tensors(self) -> int:
    return len(self._state_inputs)

optimizer(self)

TF optimizer to use for policy optimization.

Returns:

Type Description
tf.keras.optimizers.Optimizer

The local optimizer to use for this Policy's Model.

Source code in ray/rllib/policy/tf_policy.py
@DeveloperAPI
def optimizer(self) -> "tf.keras.optimizers.Optimizer":
    """TF optimizer to use for policy optimization.

    Returns:
        tf.keras.optimizers.Optimizer: The local optimizer to use for this
            Policy's Model.
    """
    if hasattr(self, "config") and "lr" in self.config:
        return tf1.train.AdamOptimizer(learning_rate=self.config["lr"])
    else:
        return tf1.train.AdamOptimizer()

set_state(self, state)

Restores the entire current state of this Policy from state.

Parameters:

Name Type Description Default
state dict

The new state to set this policy to. Can be obtained by calling self.get_state().

required
Source code in ray/rllib/policy/tf_policy.py
@override(Policy)
@DeveloperAPI
def set_state(self, state: dict) -> None:
    # Set optimizer vars first.
    optimizer_vars = state.get("_optimizer_variables", None)
    if optimizer_vars is not None:
        self._optimizer_variables.set_weights(optimizer_vars)
    # Set exploration's state.
    if hasattr(self, "exploration") and "_exploration_state" in state:
        self.exploration.set_state(
            state=state["_exploration_state"], sess=self.get_session())

    # Set the Policy's (NN) weights.
    super().set_state(state)
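
A get_state()/set_state() round trip can be sketched with a plain Python object. `StatefulPolicySketch` is hypothetical. The "weights" and "global_timestep" keys are assumptions about what the base class stores, while "_optimizer_variables" and "_exploration_state" follow the keys used in the code above.

```python
import copy
from typing import Any, Dict


class StatefulPolicySketch:
    """Hypothetical policy holding weights, optimizer, exploration and timestep state."""

    def __init__(self):
        self.weights = {"w": [0.0, 0.0]}
        self.optimizer_vars = [0.0]
        self.exploration_state = {"epsilon": 1.0}
        self.global_timestep = 0

    def get_weights(self) -> Dict[str, Any]:
        return copy.deepcopy(self.weights)

    def get_state(self) -> Dict[str, Any]:
        return {
            "weights": self.get_weights(),
            "global_timestep": self.global_timestep,
            "_optimizer_variables": list(self.optimizer_vars),
            "_exploration_state": dict(self.exploration_state),
        }

    def set_state(self, state: Dict[str, Any]) -> None:
        # Optional parts are restored only if present in `state`.
        if state.get("_optimizer_variables") is not None:
            self.optimizer_vars = list(state["_optimizer_variables"])
        if "_exploration_state" in state:
            self.exploration_state = dict(state["_exploration_state"])
        self.weights = copy.deepcopy(state["weights"])
        self.global_timestep = state["global_timestep"]


trained = StatefulPolicySketch()
trained.weights["w"] = [0.3, -0.1]
trained.exploration_state["epsilon"] = 0.05
trained.global_timestep = 1000

restored = StatefulPolicySketch()
restored.set_state(trained.get_state())
```

After the round trip, `restored` holds copies of everything in `trained`, not just the model weights. This is the difference between get_state() and get_weights().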

set_weights(self, weights)

Sets this Policy's model's weights.

Note: Model weights are only one part of a Policy's state. Other state information includes optimizer variables, exploration state, and global state variables such as the sampling timestep.

Parameters:

Name Type Description Default
weights

Serializable copy or view of model weights.

required
Source code in ray/rllib/policy/tf_policy.py
@override(Policy)
@DeveloperAPI
def set_weights(self, weights) -> None:
    return self._variables.set_weights(weights)

variables(self)

Return the list of all savable variables for this policy.

Source code in ray/rllib/policy/tf_policy.py
def variables(self):
    """Return the list of all savable variables for this policy."""
    if self.model is None:
        raise NotImplementedError("No `self.model` to get variables for!")
    elif isinstance(self.model, tf.keras.Model):
        return self.model.variables
    else:
        return self.model.variables()

ray.rllib.policy.policy_template.build_policy_class(name, framework, *, loss_fn, get_default_config=None, stats_fn=None, postprocess_fn=None, extra_action_out_fn=None, extra_grad_process_fn=None, extra_learn_fetches_fn=None, optimizer_fn=None, validate_spaces=None, before_init=None, before_loss_init=None, after_init=None, _after_loss_init=None, action_sampler_fn=None, action_distribution_fn=None, make_model=None, make_model_and_action_dist=None, compute_gradients_fn=None, apply_gradients_fn=None, mixins=None, get_batch_divisibility_req=None)

Helper function for creating a new Policy class at runtime.

Supports frameworks JAX and PyTorch.

Parameters:

Name Type Description Default
name str

name of the policy (e.g., "PPOTorchPolicy")

required
framework str

Either "jax" or "torch".

required
loss_fn Optional[Callable[[Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch], Union[TensorType, List[TensorType]]]]

Callable that returns a loss tensor.

required
get_default_config Optional[Callable[[None], TrainerConfigDict]]

Optional callable that returns the default config to merge with any overrides. If None, uses only(!) the user-provided PartialTrainerConfigDict as dict for this Policy.

None
postprocess_fn Optional[Callable[[Policy, SampleBatch, Optional[Dict[Any, SampleBatch]], Optional["Episode"]], SampleBatch]]

Optional callable for post-processing experience batches (called after the super's postprocess_trajectory method).

None
stats_fn Optional[Callable[[Policy, SampleBatch], Dict[str, TensorType]]]

Optional callable that returns a dict of values given the policy and training batch. If None, will use TorchPolicy.extra_grad_info() instead. The stats dict is used for logging (e.g. in TensorBoard).

None
extra_action_out_fn Optional[Callable[[Policy, Dict[str, TensorType], List[TensorType], ModelV2, TorchDistributionWrapper]], Dict[str, TensorType]]]

Optional callable that returns a dict of extra values to include in experiences. If None, no extra computations will be performed.

None
extra_grad_process_fn Optional[Callable[[Policy, "torch.optim.Optimizer", TensorType], Dict[str, TensorType]]]

Optional callable that is called after gradients are computed and returns a processing info dict. If None, will call the TorchPolicy.extra_grad_process() method instead.

None
extra_learn_fetches_fn Optional[Callable[[Policy], Dict[str, TensorType]]]

Optional callable that returns a dict of extra tensors from the policy after loss evaluation. If None, will call the TorchPolicy.extra_compute_grad_fetches() method instead.

None
optimizer_fn Optional[Callable[[Policy, TrainerConfigDict], "torch.optim.Optimizer"]]

Optional callable that returns a torch optimizer given the policy and config. If None, will call the TorchPolicy.optimizer() method instead (which returns a torch Adam optimizer).

None
validate_spaces Optional[Callable[[Policy, gym.Space, gym.Space, TrainerConfigDict], None]]

Optional callable that takes the Policy, observation_space, action_space, and config to check for correctness. If None, no spaces checking will be done.

None
before_init Optional[Callable[[Policy, gym.Space, gym.Space, TrainerConfigDict], None]]

Optional callable to run at the beginning of Policy.__init__ that takes the same arguments as the Policy constructor. If None, this step will be skipped.

None
before_loss_init Optional[Callable[[Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict], None]]

Optional callable to run prior to loss init. If None, this step will be skipped.

None
after_init Optional[Callable[[Policy, gym.Space, gym.Space, TrainerConfigDict], None]]

DEPRECATED: Use before_loss_init instead.

None
_after_loss_init Optional[Callable[[Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict], None]]

Optional callable to run after the loss init. If None, this step will be skipped. This will be deprecated at some point and renamed into after_init to match build_tf_policy() behavior.

None
action_sampler_fn Optional[Callable[[TensorType, List[TensorType]], Tuple[TensorType, TensorType]]]

Optional callable returning a sampled action and its log-likelihood given some (obs and state) inputs. If None, will either use action_distribution_fn or compute actions by calling self.model, then sampling from the resulting parameterized action distribution.

None
action_distribution_fn Optional[Callable[[Policy, ModelV2, TensorType, TensorType, TensorType], Tuple[TensorType, Type[TorchDistributionWrapper], List[TensorType]]]]

A callable that takes the Policy, Model, the observation batch, an explore-flag, a timestep, and an is_training flag and returns a tuple of a) distribution inputs (parameters), b) a dist-class to generate an action distribution object from, and c) internal-state outputs (empty list if not applicable). If None, will either use action_sampler_fn or compute actions by calling self.model, then sampling from the parameterized action distribution.

None
make_model Optional[Callable[[Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict], ModelV2]]

Optional callable that takes the same arguments as Policy.__init__ and returns a model instance. The distribution class will be determined automatically. Note: Only one of make_model or make_model_and_action_dist should be provided. If both are None, a default Model will be created.

None
make_model_and_action_dist Optional[Callable[[Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict], Tuple[ModelV2, Type[TorchDistributionWrapper]]]]

Optional callable that takes the same arguments as Policy.__init__ and returns a tuple of model instance and torch action distribution class. Note: Only one of make_model or make_model_and_action_dist should be provided. If both are None, a default Model will be created.

None
compute_gradients_fn Optional[Callable[ [Policy, SampleBatch], Tuple[ModelGradients, dict]]]

Optional callable that takes the sampled batch and computes the gradients w.r.t. the loss function. If None, will call the TorchPolicy.compute_gradients() method instead.

None
apply_gradients_fn Optional[Callable[[Policy, "torch.optim.Optimizer"], None]]

Optional callable that takes a grads list and applies these to the Model's parameters. If None, will call the TorchPolicy.apply_gradients() method instead.

None
mixins Optional[List[type]]

Optional list of any class mixins for the returned policy class. These mixins will be applied in order and will have higher precedence than the TorchPolicy class.

None
get_batch_divisibility_req Optional[Callable[[Policy], int]]

Optional callable that returns the divisibility requirement for sample batches. If None, will assume a value of 1.

None

Returns:

Type Description
Type[TorchPolicy]

TorchPolicy child class constructed from the specified args.
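
Many of the parameters above follow the same pattern: an optional callable that, when None, falls back to a method of the parent class, plus a default config that user-provided keys override. The following is a minimal sketch of that pattern in plain Python. It does not import RLlib, and ToyBasePolicy, build_toy_policy_class, and the string "optimizers" are hypothetical stand-ins chosen only for illustration.

```python
from typing import Callable, Optional


class ToyBasePolicy:
    """Hypothetical stand-in for a framework-specific base class."""

    def __init__(self, config: dict):
        self.config = config

    def optimizer(self) -> str:
        # Default behavior, used when no `optimizer_fn` is supplied.
        return "adam(lr={})".format(self.config["lr"])


def build_toy_policy_class(
        name: str,
        *,
        get_default_config: Optional[Callable[[], dict]] = None,
        optimizer_fn: Optional[Callable[[ToyBasePolicy, dict], str]] = None,
) -> type:
    """Builds a ToyBasePolicy subclass from optional callables."""

    class policy_cls(ToyBasePolicy):
        def __init__(self, config: dict):
            # User-provided keys override the defaults. Without a
            # default-config callable, only the user config is used.
            if get_default_config:
                config = dict(get_default_config(), **config)
            super().__init__(config)

        def optimizer(self) -> str:
            # Use the custom callable if given, else the parent's method.
            if optimizer_fn:
                return optimizer_fn(self, self.config)
            return super().optimizer()

    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls


DefaultPolicy = build_toy_policy_class(
    "DefaultPolicy", get_default_config=lambda: {"lr": 0.001, "gamma": 0.99})
SgdPolicy = build_toy_policy_class(
    "SgdPolicy",
    get_default_config=lambda: {"lr": 0.001, "gamma": 0.99},
    optimizer_fn=lambda policy, config: "sgd(lr={})".format(config["lr"]))

default_policy = DefaultPolicy({"lr": 0.01})
sgd_policy = SgdPolicy({})
```

Here default_policy keeps the default gamma but uses the overridden lr, and sgd_policy routes optimizer() through the custom callable instead of the parent's method.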

Source code in ray/rllib/policy/policy_template.py
@DeveloperAPI
def build_policy_class(
        name: str,
        framework: str,
        *,
        loss_fn: Optional[Callable[[
            Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch
        ], Union[TensorType, List[TensorType]]]],
        get_default_config: Optional[Callable[[], TrainerConfigDict]] = None,
        stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[
            str, TensorType]]] = None,
        postprocess_fn: Optional[Callable[[
            Policy, SampleBatch, Optional[Dict[Any, SampleBatch]], Optional[
                "Episode"]
        ], SampleBatch]] = None,
        extra_action_out_fn: Optional[Callable[[
            Policy, Dict[str, TensorType], List[TensorType], ModelV2,
            TorchDistributionWrapper
        ], Dict[str, TensorType]]] = None,
        extra_grad_process_fn: Optional[Callable[[
            Policy, "torch.optim.Optimizer", TensorType
        ], Dict[str, TensorType]]] = None,
        # TODO: (sven) Replace "fetches" with "process".
        extra_learn_fetches_fn: Optional[Callable[[Policy], Dict[
            str, TensorType]]] = None,
        optimizer_fn: Optional[Callable[[Policy, TrainerConfigDict],
                                        "torch.optim.Optimizer"]] = None,
        validate_spaces: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        before_init: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        before_loss_init: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], None]] = None,
        after_init: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        _after_loss_init: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], None]] = None,
        action_sampler_fn: Optional[Callable[[TensorType, List[
            TensorType]], Tuple[TensorType, TensorType]]] = None,
        action_distribution_fn: Optional[Callable[[
            Policy, ModelV2, TensorType, TensorType, TensorType
        ], Tuple[TensorType, type, List[TensorType]]]] = None,
        make_model: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], ModelV2]] = None,
        make_model_and_action_dist: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], Tuple[ModelV2, Type[TorchDistributionWrapper]]]] = None,
        compute_gradients_fn: Optional[Callable[[Policy, SampleBatch], Tuple[
            ModelGradients, dict]]] = None,
        apply_gradients_fn: Optional[Callable[
            [Policy, "torch.optim.Optimizer"], None]] = None,
        mixins: Optional[List[type]] = None,
        get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None
) -> Type[TorchPolicy]:
    """Helper function for creating a new Policy class at runtime.

    Supports frameworks JAX and PyTorch.

    Args:
        name (str): name of the policy (e.g., "PPOTorchPolicy")
        framework (str): Either "jax" or "torch".
        loss_fn (Optional[Callable[[Policy, ModelV2,
            Type[TorchDistributionWrapper], SampleBatch], Union[TensorType,
            List[TensorType]]]]): Callable that returns a loss tensor.
        get_default_config (Optional[Callable[[None], TrainerConfigDict]]):
            Optional callable that returns the default config to merge with any
            overrides. If None, uses only(!) the user-provided
            PartialTrainerConfigDict as dict for this Policy.
        postprocess_fn (Optional[Callable[[Policy, SampleBatch,
            Optional[Dict[Any, SampleBatch]], Optional["Episode"]],
            SampleBatch]]): Optional callable for post-processing experience
            batches (called after the super's `postprocess_trajectory` method).
        stats_fn (Optional[Callable[[Policy, SampleBatch],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            values given the policy and training batch. If None,
            will use `TorchPolicy.extra_grad_info()` instead. The stats dict is
            used for logging (e.g. in TensorBoard).
        extra_action_out_fn (Optional[Callable[[Policy, Dict[str, TensorType],
            List[TensorType], ModelV2, TorchDistributionWrapper]], Dict[str,
            TensorType]]]): Optional callable that returns a dict of extra
            values to include in experiences. If None, no extra computations
            will be performed.
        extra_grad_process_fn (Optional[Callable[[Policy,
            "torch.optim.Optimizer", TensorType], Dict[str, TensorType]]]):
            Optional callable that is called after gradients are computed and
            returns a processing info dict. If None, will call the
            `TorchPolicy.extra_grad_process()` method instead.
        # TODO: (sven) dissolve naming mismatch between "learn" and "compute.."
        extra_learn_fetches_fn (Optional[Callable[[Policy],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            extra tensors from the policy after loss evaluation. If None,
            will call the `TorchPolicy.extra_compute_grad_fetches()` method
            instead.
        optimizer_fn (Optional[Callable[[Policy, TrainerConfigDict],
            "torch.optim.Optimizer"]]): Optional callable that returns a
            torch optimizer given the policy and config. If None, will call
            the `TorchPolicy.optimizer()` method instead (which returns a
            torch Adam optimizer).
        validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable that takes the
            Policy, observation_space, action_space, and config to check for
            correctness. If None, no spaces checking will be done.
        before_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the
            beginning of `Policy.__init__` that takes the same arguments as
            the Policy constructor. If None, this step will be skipped.
        before_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], None]]): Optional callable to
            run prior to loss init. If None, this step will be skipped.
        after_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): DEPRECATED: Use `before_loss_init`
            instead.
        _after_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], None]]): Optional callable to
            run after the loss init. If None, this step will be skipped.
            This will be deprecated at some point and renamed into `after_init`
            to match `build_tf_policy()` behavior.
        action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]],
            Tuple[TensorType, TensorType]]]): Optional callable returning a
            sampled action and its log-likelihood given some (obs and state)
            inputs. If None, will either use `action_distribution_fn` or
            compute actions by calling self.model, then sampling from the
            resulting parameterized action distribution.
        action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType,
            TensorType, TensorType], Tuple[TensorType,
            Type[TorchDistributionWrapper], List[TensorType]]]]): A callable
            that takes the Policy, Model, the observation batch, an
            explore-flag, a timestep, and an is_training flag and returns a
            tuple of a) distribution inputs (parameters), b) a dist-class to
            generate an action distribution object from, and c) internal-state
            outputs (empty list if not applicable). If None, will either use
            `action_sampler_fn` or compute actions by calling self.model,
            then sampling from the parameterized action distribution.
        make_model (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], ModelV2]]): Optional callable
            that takes the same arguments as Policy.__init__ and returns a
            model instance. The distribution class will be determined
            automatically. Note: Only one of `make_model` or
            `make_model_and_action_dist` should be provided. If both are None,
            a default Model will be created.
        make_model_and_action_dist (Optional[Callable[[Policy,
            gym.spaces.Space, gym.spaces.Space, TrainerConfigDict],
            Tuple[ModelV2, Type[TorchDistributionWrapper]]]]): Optional
            callable that takes the same arguments as Policy.__init__ and
            returns a tuple of model instance and torch action distribution
            class.
            Note: Only one of `make_model` or `make_model_and_action_dist`
            should be provided. If both are None, a default Model will be
            created.
        compute_gradients_fn (Optional[Callable[
            [Policy, SampleBatch], Tuple[ModelGradients, dict]]]): Optional
            callable that takes the sampled batch and computes the gradients
            w.r.t. the loss function.
            If None, will call the `TorchPolicy.compute_gradients()` method
            instead.
        apply_gradients_fn (Optional[Callable[[Policy,
            "torch.optim.Optimizer"], None]]): Optional callable that
            takes a grads list and applies these to the Model's parameters.
            If None, will call the `TorchPolicy.apply_gradients()` method
            instead.
        mixins (Optional[List[type]]): Optional list of any class mixins for
            the returned policy class. These mixins will be applied in order
            and will have higher precedence than the TorchPolicy class.
        get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
            Optional callable that returns the divisibility requirement for
            sample batches. If None, will assume a value of 1.

    Returns:
        Type[TorchPolicy]: TorchPolicy child class constructed from the
            specified args.
    """

    original_kwargs = locals().copy()
    parent_cls = TorchPolicy
    base = add_mixins(parent_cls, mixins)

    class policy_cls(base):
        def __init__(self, obs_space, action_space, config):
            # Set up the config from possible default-config fn and given
            # config arg.
            if get_default_config:
                config = dict(get_default_config(), **config)
            self.config = config

            # Set the DL framework for this Policy.
            self.framework = self.config["framework"] = framework

            # Validate observation- and action-spaces.
            if validate_spaces:
                validate_spaces(self, obs_space, action_space, self.config)

            # Do some pre-initialization steps.
            if before_init:
                before_init(self, obs_space, action_space, self.config)

            # Model is customized (use default action dist class).
            if make_model:
                assert make_model_and_action_dist is None, \
                    "Either `make_model` or `make_model_and_action_dist`" \
                    " must be None!"
                self.model = make_model(self, obs_space, action_space, config)
                dist_class, _ = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework=framework)
            # Model and action dist class are customized.
            elif make_model_and_action_dist:
                self.model, dist_class = make_model_and_action_dist(
                    self, obs_space, action_space, config)
            # Use default model and default action dist.
            else:
                dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework=framework)
                self.model = ModelCatalog.get_model_v2(
                    obs_space=obs_space,
                    action_space=action_space,
                    num_outputs=logit_dim,
                    model_config=self.config["model"],
                    framework=framework)

            # Make sure, we passed in a correct Model factory.
            model_cls = TorchModelV2 if framework == "torch" else JAXModelV2
            assert isinstance(self.model, model_cls), \
                "ERROR: Generated Model must be a TorchModelV2 object!"

            # Call the framework-specific Policy constructor.
            self.parent_cls = parent_cls
            self.parent_cls.__init__(
                self,
                observation_space=obs_space,
                action_space=action_space,
                config=config,
                model=self.model,
                loss=None if self.config["in_evaluation"] else loss_fn,
                action_distribution_class=dist_class,
                action_sampler_fn=action_sampler_fn,
                action_distribution_fn=action_distribution_fn,
                max_seq_len=config["model"]["max_seq_len"],
                get_batch_divisibility_req=get_batch_divisibility_req,
            )

            # Merge Model's view requirements into Policy's.
            self.view_requirements.update(self.model.view_requirements)

            _before_loss_init = before_loss_init or after_init
            if _before_loss_init:
                _before_loss_init(self, self.observation_space,
                                  self.action_space, config)

            # Perform test runs through postprocessing- and loss functions.
            self._initialize_loss_from_dummy_batch(
                auto_remove_unneeded_view_reqs=True,
                stats_fn=None if self.config["in_evaluation"] else stats_fn,
            )

            if _after_loss_init:
                _after_loss_init(self, obs_space, action_space, config)

            # Got to reset global_timestep again after this fake run-through.
            self.global_timestep = 0

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            # Do all post-processing always with no_grad().
            # Not using this here will introduce a memory leak
            # in torch (issue #6962).
            with self._no_grad_context():
                # Call super's postprocess_trajectory first.
                sample_batch = super().postprocess_trajectory(
                    sample_batch, other_agent_batches, episode)
                if postprocess_fn:
                    return postprocess_fn(self, sample_batch,
                                          other_agent_batches, episode)

                return sample_batch

        @override(parent_cls)
        def extra_grad_process(self, optimizer, loss):
            """Called after optimizer.zero_grad() and loss.backward() calls.

            Allows for gradient processing before optimizer.step() is called.
            E.g. for gradient clipping.
            """
            if extra_grad_process_fn:
                return extra_grad_process_fn(self, optimizer, loss)
            else:
                return parent_cls.extra_grad_process(self, optimizer, loss)

        @override(parent_cls)
        def extra_compute_grad_fetches(self):
            if extra_learn_fetches_fn:
                fetches = convert_to_non_torch_type(
                    extra_learn_fetches_fn(self))
                # Auto-add empty learner stats dict if needed.
                return dict({LEARNER_STATS_KEY: {}}, **fetches)
            else:
                return parent_cls.extra_compute_grad_fetches(self)

        @override(parent_cls)
        def compute_gradients(self, batch):
            if compute_gradients_fn:
                return compute_gradients_fn(self, batch)
            else:
                return parent_cls.compute_gradients(self, batch)

        @override(parent_cls)
        def apply_gradients(self, gradients):
            if apply_gradients_fn:
                apply_gradients_fn(self, gradients)
            else:
                parent_cls.apply_gradients(self, gradients)

        @override(parent_cls)
        def extra_action_out(self, input_dict, state_batches, model,
                             action_dist):
            with self._no_grad_context():
                if extra_action_out_fn:
                    stats_dict = extra_action_out_fn(
                        self, input_dict, state_batches, model, action_dist)
                else:
                    stats_dict = parent_cls.extra_action_out(
                        self, input_dict, state_batches, model, action_dist)
                return self._convert_to_non_torch_type(stats_dict)

        @override(parent_cls)
        def optimizer(self):
            if optimizer_fn:
                optimizers = optimizer_fn(self, self.config)
            else:
                optimizers = parent_cls.optimizer(self)
            return optimizers

        @override(parent_cls)
        def extra_grad_info(self, train_batch):
            with self._no_grad_context():
                if stats_fn:
                    stats_dict = stats_fn(self, train_batch)
                else:
                    stats_dict = self.parent_cls.extra_grad_info(
                        self, train_batch)
                return self._convert_to_non_torch_type(stats_dict)

        def _no_grad_context(self):
            if self.framework == "torch":
                return torch.no_grad()
            return NullContextManager()

        def _convert_to_non_torch_type(self, data):
            if self.framework == "torch":
                return convert_to_non_torch_type(data)
            return data

    def with_updates(**overrides):
        """Creates a Torch|JAXPolicy cls based on settings of another one.

        Keyword Args:
            **overrides: The settings (passed into `build_policy_class`) that
                should be different from the class that this method is called
                on.

        Returns:
            type: A new Torch|JAXPolicy sub-class.

        Examples:
        >> MySpecialDQNPolicyClass = DQNTorchPolicy.with_updates(
        ..    name="MySpecialDQNPolicyClass",
        ..    loss_fn=[some_new_loss_function],
        .. )
        """
        return build_policy_class(**dict(original_kwargs, **overrides))

    policy_cls.with_updates = staticmethod(with_updates)
    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls
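
The with_updates and mixins mechanics in the source above can be illustrated without RLlib. The following is a minimal sketch under stated assumptions, not RLlib's implementation: add_mixins_sketch, ToyPolicy, LoudMixin, and build_toy_class are hypothetical names, and add_mixins_sketch is only one possible way to give mixins precedence over the base class.

```python
from typing import List, Optional


def add_mixins_sketch(base: type, mixins: Optional[List[type]]) -> type:
    """Returns a class in which every mixin precedes `base` in the MRO."""
    mixins = list(mixins or [])
    if not mixins:
        return base
    # Listing the mixins first gives them precedence over `base`,
    # and earlier mixins precedence over later ones.
    return type("Mixed" + base.__name__, tuple(mixins) + (base,), {})


class ToyPolicy:
    def describe(self) -> str:
        return "base"


class LoudMixin:
    def describe(self) -> str:
        # Resolved before ToyPolicy.describe because of the MRO order.
        return "loud " + super().describe()


def build_toy_class(name: str, *, greeting: str = "hello",
                    mixins: Optional[List[type]] = None) -> type:
    # Snapshot the arguments before any other local is defined, so the
    # dict holds exactly the keyword arguments of this call.
    original_kwargs = locals().copy()
    base = add_mixins_sketch(ToyPolicy, mixins)

    class toy_cls(base):
        def greet(self) -> str:
            return greeting

    def with_updates(**overrides):
        # Re-run the factory with the original settings plus overrides.
        return build_toy_class(**dict(original_kwargs, **overrides))

    toy_cls.with_updates = staticmethod(with_updates)
    toy_cls.__name__ = name
    toy_cls.__qualname__ = name
    return toy_cls


Plain = build_toy_class("Plain", mixins=[LoudMixin])
Custom = Plain.with_updates(name="Custom", greeting="hi")
```

Custom inherits the mixins setting from Plain because with_updates replays the captured keyword arguments, and only name and greeting differ.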

ray.rllib.policy.tf_policy_template.build_tf_policy(name, *, loss_fn, get_default_config=None, postprocess_fn=None, stats_fn=None, optimizer_fn=None, compute_gradients_fn=None, apply_gradients_fn=None, grad_stats_fn=None, extra_action_out_fn=None, extra_learn_fetches_fn=None, validate_spaces=None, before_init=None, before_loss_init=None, after_init=None, make_model=None, action_sampler_fn=None, action_distribution_fn=None, mixins=None, get_batch_divisibility_req=None, obs_include_prev_action_reward=-1, extra_action_fetches_fn=None, gradients_fn=None)

Helper function for creating a dynamic tf policy at runtime.

Functions will be run in this order to initialize the policy:

1. Placeholder setup: postprocess_fn
2. Loss init: loss_fn, stats_fn
3. Optimizer init: optimizer_fn, gradients_fn, apply_gradients_fn, grad_stats_fn

This means that you can e.g., depend on any policy attributes created in the running of loss_fn in later functions such as stats_fn.

In eager mode, the following functions will be run repeatedly on each eager execution: loss_fn, stats_fn, gradients_fn, apply_gradients_fn, and grad_stats_fn.

This means that these functions should not define any variables internally, otherwise they will fail in eager mode execution. Variables should only be created in make_model (if defined).
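
The ordering and the "no variables inside repeated functions" rule described above can be sketched in plain Python. This is an illustrative toy, not TensorFlow or RLlib code: ToyPolicy, its learn_on_batch method, and the dict used as "weights" are hypothetical, and the repeated calls only mimic what eager execution would do.

```python
class ToyPolicy:
    """Hypothetical policy that runs its hooks in a fixed order."""

    def __init__(self, make_model, loss_fn, stats_fn):
        self.variables_created = 0
        # Variables are created exactly once, inside make_model.
        self.weights = make_model(self)
        self._loss_fn = loss_fn
        self._stats_fn = stats_fn

    def learn_on_batch(self, batch):
        # Called repeatedly, as in eager mode: hooks only read variables.
        loss = self._loss_fn(self, batch)    # the loss is computed first
        stats = self._stats_fn(self, batch)  # stats may use loss attributes
        return loss, stats


def make_model(policy):
    policy.variables_created += 1
    return {"w": 0.5}


def loss_fn(policy, batch):
    errors = [(policy.weights["w"] * x - x) ** 2 for x in batch]
    # Store an intermediate result as a policy attribute.
    policy.mean_error = sum(errors) / len(errors)
    return policy.mean_error


def stats_fn(policy, batch):
    # Safe: loss_fn has already run and set `mean_error`.
    return {"mean_error": policy.mean_error, "batch_size": len(batch)}


policy = ToyPolicy(make_model, loss_fn, stats_fn)
first = policy.learn_on_batch([2.0, 4.0])
second = policy.learn_on_batch([2.0, 4.0])
```

Because loss_fn and stats_fn create no state of their own beyond plain attributes, calling learn_on_batch twice gives the same result and variables_created stays at 1.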

Parameters:

Name Type Description Default
name str

Name of the policy (e.g., "PPOTFPolicy").

required
loss_fn Callable[[ Policy, ModelV2, Type[TFActionDistribution], SampleBatch], Union[TensorType, List[TensorType]]]

Callable for calculating a loss tensor.

required
get_default_config Optional[Callable[[None], TrainerConfigDict]]

Optional callable that returns the default config to merge with any overrides. If None, uses only(!) the user-provided PartialTrainerConfigDict as dict for this Policy.

None
postprocess_fn Optional[Callable[[Policy, SampleBatch, Optional[Dict[AgentID, SampleBatch]], Optional["Episode"]], SampleBatch]]

Optional callable for post-processing experience batches (called after the parent class' postprocess_trajectory method).

None
stats_fn Optional[Callable[[Policy, SampleBatch], Dict[str, TensorType]]]

Optional callable that returns a dict of TF tensors to fetch given the policy and batch input tensors. If None, will not compute any stats.

None
optimizer_fn Optional[Callable[[Policy, TrainerConfigDict], "tf.keras.optimizers.Optimizer"]]

Optional callable that returns a tf.Optimizer given the policy and config. If None, will call the base class' optimizer() method instead (which returns a tf1.train.AdamOptimizer).

None
compute_gradients_fn Optional[Callable[[Policy, "tf.keras.optimizers.Optimizer", TensorType], ModelGradients]]

Optional callable that returns a list of gradients. If None, this defaults to optimizer.compute_gradients([loss]).

None
apply_gradients_fn Optional[Callable[[Policy, "tf.keras.optimizers.Optimizer", ModelGradients], "tf.Operation"]]

Optional callable that returns an apply gradients op given policy, tf-optimizer, and grads_and_vars. If None, will call the base class' build_apply_op() method instead.

None
grad_stats_fn Optional[Callable[[Policy, SampleBatch, ModelGradients], Dict[str, TensorType]]]

Optional callable that returns a dict of TF fetches given the policy, batch input, and gradient tensors. If None, will not collect any gradient stats.

None
extra_action_out_fn Optional[Callable[[Policy], Dict[str, TensorType]]]

Optional callable that returns a dict of TF fetches given the policy object. If None, will not perform any extra fetches.

None
extra_learn_fetches_fn Optional[Callable[[Policy], Dict[str, TensorType]]]

Optional callable that returns a dict of extra values to fetch and return when learning on a batch. If None, will call the base class' extra_compute_grad_fetches() method instead.

None
validate_spaces Optional[Callable[[Policy, gym.Space, gym.Space, TrainerConfigDict], None]]

Optional callable that takes the Policy, observation_space, action_space, and config to check the spaces for correctness. If None, no spaces checking will be done.

None
before_init Optional[Callable[[Policy, gym.Space, gym.Space, TrainerConfigDict], None]]

Optional callable to run at the beginning of policy init that takes the same arguments as the policy constructor. If None, this step will be skipped.

None
before_loss_init Optional[Callable[[Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict], None]]

Optional callable to run prior to loss init. If None, this step will be skipped.

None
after_init Optional[Callable[[Policy, gym.Space, gym.Space, TrainerConfigDict], None]]

Optional callable to run at the end of policy init. If None, this step will be skipped.

None
make_model Optional[Callable[[Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict], ModelV2]]

Optional callable that returns a ModelV2 object. All policy variables should be created in this function. If None, a default ModelV2 object will be created.

None
action_sampler_fn Optional[Callable[[TensorType, List[TensorType]], Tuple[TensorType, TensorType]]]

A callable returning a sampled action and its log-likelihood given observation and state inputs. If None, will either use action_distribution_fn or compute actions by calling self.model, then sampling from the resulting parameterized action distribution.

None
action_distribution_fn Optional[Callable[[Policy, ModelV2, TensorType, TensorType, TensorType], Tuple[TensorType, type, List[TensorType]]]]

Optional callable returning distribution inputs (parameters), a dist-class to generate an action distribution object from, and internal-state outputs (or an empty list if not applicable). If None, will either use action_sampler_fn or compute actions by calling self.model, then sampling from the resulting parameterized action distribution.

None
mixins Optional[List[type]]

Optional list of any class mixins for the returned policy class. These mixins will be applied in order and will have higher precedence than the DynamicTFPolicy class.

None
get_batch_divisibility_req Optional[Callable[[Policy], int]]

Optional callable that returns the divisibility requirement for sample batches. If None, will assume a value of 1.

None

Returns:

Type Description
Type[DynamicTFPolicy]

A child class of DynamicTFPolicy based on the specified args.

Source code in ray/rllib/policy/tf_policy_template.py
@DeveloperAPI
def build_tf_policy(
        name: str,
        *,
        loss_fn: Callable[[
            Policy, ModelV2, Type[TFActionDistribution], SampleBatch
        ], Union[TensorType, List[TensorType]]],
        get_default_config: Optional[Callable[[None],
                                              TrainerConfigDict]] = None,
        postprocess_fn: Optional[Callable[[
            Policy, SampleBatch, Optional[Dict[AgentID, SampleBatch]],
            Optional["Episode"]
        ], SampleBatch]] = None,
        stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[
            str, TensorType]]] = None,
        optimizer_fn: Optional[Callable[[
            Policy, TrainerConfigDict
        ], "tf.keras.optimizers.Optimizer"]] = None,
        compute_gradients_fn: Optional[Callable[[
            Policy, "tf.keras.optimizers.Optimizer", TensorType
        ], ModelGradients]] = None,
        apply_gradients_fn: Optional[Callable[[
            Policy, "tf.keras.optimizers.Optimizer", ModelGradients
        ], "tf.Operation"]] = None,
        grad_stats_fn: Optional[Callable[[Policy, SampleBatch, ModelGradients],
                                         Dict[str, TensorType]]] = None,
        extra_action_out_fn: Optional[Callable[[Policy], Dict[
            str, TensorType]]] = None,
        extra_learn_fetches_fn: Optional[Callable[[Policy], Dict[
            str, TensorType]]] = None,
        validate_spaces: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        before_init: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        before_loss_init: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], None]] = None,
        after_init: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        make_model: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], ModelV2]] = None,
        action_sampler_fn: Optional[Callable[[TensorType, List[
            TensorType]], Tuple[TensorType, TensorType]]] = None,
        action_distribution_fn: Optional[Callable[[
            Policy, ModelV2, TensorType, TensorType, TensorType
        ], Tuple[TensorType, type, List[TensorType]]]] = None,
        mixins: Optional[List[type]] = None,
        get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
        # Deprecated args.
        obs_include_prev_action_reward=DEPRECATED_VALUE,
        extra_action_fetches_fn=None,  # Use `extra_action_out_fn`.
        gradients_fn=None,  # Use `compute_gradients_fn`.
) -> Type[DynamicTFPolicy]:
    """Helper function for creating a dynamic tf policy at runtime.

    Functions will be run in this order to initialize the policy:
        1. Placeholder setup: postprocess_fn
        2. Loss init: loss_fn, stats_fn
        3. Optimizer init: optimizer_fn, compute_gradients_fn,
                           apply_gradients_fn, grad_stats_fn

    This means that later functions such as `stats_fn` can depend on any
    policy attributes created while running `loss_fn`.

    In eager mode, the following functions will be run repeatedly on each
    eager execution: loss_fn, stats_fn, compute_gradients_fn,
    apply_gradients_fn, and grad_stats_fn.

    This means that these functions should not define any variables internally,
    otherwise they will fail in eager mode execution. Variables should only
    be created in make_model (if defined).

    Args:
        name (str): Name of the policy (e.g., "PPOTFPolicy").
        loss_fn (Callable[[
            Policy, ModelV2, Type[TFActionDistribution], SampleBatch],
            Union[TensorType, List[TensorType]]]): Callable for calculating a
            loss tensor.
        get_default_config (Optional[Callable[[None], TrainerConfigDict]]):
            Optional callable that returns the default config to merge with any
            overrides. If None, uses only(!) the user-provided
            PartialTrainerConfigDict as dict for this Policy.
        postprocess_fn (Optional[Callable[[Policy, SampleBatch,
            Optional[Dict[AgentID, SampleBatch]], Episode], None]]):
            Optional callable for post-processing experience batches (called
            after the parent class' `postprocess_trajectory` method).
        stats_fn (Optional[Callable[[Policy, SampleBatch],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            TF tensors to fetch given the policy and batch input tensors. If
            None, will not compute any stats.
        optimizer_fn (Optional[Callable[[Policy, TrainerConfigDict],
            "tf.keras.optimizers.Optimizer"]]): Optional callable that returns
            a tf.Optimizer given the policy and config. If None, will call
            the base class' `optimizer()` method instead (which returns a
            tf1.train.AdamOptimizer).
        compute_gradients_fn (Optional[Callable[[Policy,
            "tf.keras.optimizers.Optimizer", TensorType], ModelGradients]]):
            Optional callable that returns a list of gradients. If None,
            this defaults to optimizer.compute_gradients([loss]).
        apply_gradients_fn (Optional[Callable[[Policy,
            "tf.keras.optimizers.Optimizer", ModelGradients],
            "tf.Operation"]]): Optional callable that returns an apply
            gradients op given policy, tf-optimizer, and grads_and_vars. If
            None, will call the base class' `build_apply_op()` method instead.
        grad_stats_fn (Optional[Callable[[Policy, SampleBatch, ModelGradients],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            TF fetches given the policy, batch input, and gradient tensors. If
            None, will not collect any gradient stats.
        extra_action_out_fn (Optional[Callable[[Policy],
            Dict[str, TensorType]]]): Optional callable that returns
            a dict of TF fetches given the policy object. If None, will not
            perform any extra fetches.
        extra_learn_fetches_fn (Optional[Callable[[Policy],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            extra values to fetch and return when learning on a batch. If None,
            will call the base class' `extra_compute_grad_fetches()` method
            instead.
        validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable that takes the
            Policy, observation_space, action_space, and config to check
            the spaces for correctness. If None, no spaces checking will be
            done.
        before_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the
            beginning of policy init that takes the same arguments as the
            policy constructor. If None, this step will be skipped.
        before_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], None]]): Optional callable to
            run prior to loss init. If None, this step will be skipped.
        after_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the end of
            policy init. If None, this step will be skipped.
        make_model (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], ModelV2]]): Optional callable
            that returns a ModelV2 object.
            All policy variables should be created in this function. If None,
            a default ModelV2 object will be created.
        action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]],
            Tuple[TensorType, TensorType]]]): A callable returning a sampled
            action and its log-likelihood given observation and state inputs.
            If None, will either use `action_distribution_fn` or
            compute actions by calling self.model, then sampling from the
            resulting action distribution.
        action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType,
            TensorType, TensorType],
            Tuple[TensorType, type, List[TensorType]]]]): Optional callable
            returning distribution inputs (parameters), a dist-class to
            generate an action distribution object from, and internal-state
            outputs (or an empty list if not applicable). If None, will either
            use `action_sampler_fn` or compute actions by calling self.model,
            then sampling from the resulting action distribution.
        mixins (Optional[List[type]]): Optional list of any class mixins for
            the returned policy class. These mixins will be applied in order
            and will have higher precedence than the DynamicTFPolicy class.
        get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
            Optional callable that returns the divisibility requirement for
            sample batches. If None, will assume a value of 1.

    Returns:
        Type[DynamicTFPolicy]: A child class of DynamicTFPolicy based on the
            specified args.
    """
    original_kwargs = locals().copy()
    base = add_mixins(DynamicTFPolicy, mixins)

    if obs_include_prev_action_reward != DEPRECATED_VALUE:
        deprecation_warning(old="obs_include_prev_action_reward", error=False)

    if extra_action_fetches_fn is not None:
        deprecation_warning(
            old="extra_action_fetches_fn",
            new="extra_action_out_fn",
            error=False)
        extra_action_out_fn = extra_action_fetches_fn

    if gradients_fn is not None:
        deprecation_warning(
            old="gradients_fn", new="compute_gradients_fn", error=False)
        compute_gradients_fn = gradients_fn

    class policy_cls(base):
        def __init__(self,
                     obs_space,
                     action_space,
                     config,
                     existing_model=None,
                     existing_inputs=None):
            if get_default_config:
                config = dict(get_default_config(), **config)

            if validate_spaces:
                validate_spaces(self, obs_space, action_space, config)

            if before_init:
                before_init(self, obs_space, action_space, config)

            def before_loss_init_wrapper(policy, obs_space, action_space,
                                         config):
                if before_loss_init:
                    before_loss_init(policy, obs_space, action_space, config)

                if extra_action_out_fn is None or policy._is_tower:
                    extra_action_fetches = {}
                else:
                    extra_action_fetches = extra_action_out_fn(policy)

                if hasattr(policy, "_extra_action_fetches"):
                    policy._extra_action_fetches.update(extra_action_fetches)
                else:
                    policy._extra_action_fetches = extra_action_fetches

            DynamicTFPolicy.__init__(
                self,
                obs_space=obs_space,
                action_space=action_space,
                config=config,
                loss_fn=loss_fn,
                stats_fn=stats_fn,
                grad_stats_fn=grad_stats_fn,
                before_loss_init=before_loss_init_wrapper,
                make_model=make_model,
                action_sampler_fn=action_sampler_fn,
                action_distribution_fn=action_distribution_fn,
                existing_inputs=existing_inputs,
                existing_model=existing_model,
                get_batch_divisibility_req=get_batch_divisibility_req,
            )

            if after_init:
                after_init(self, obs_space, action_space, config)

            # Reset global_timestep again after this fake run-through.
            self.global_timestep = 0

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            # Call super's postprocess_trajectory first.
            sample_batch = Policy.postprocess_trajectory(self, sample_batch)
            if postprocess_fn:
                return postprocess_fn(self, sample_batch, other_agent_batches,
                                      episode)
            return sample_batch

        @override(TFPolicy)
        def optimizer(self):
            if optimizer_fn:
                optimizers = optimizer_fn(self, self.config)
            else:
                optimizers = base.optimizer(self)
            optimizers = force_list(optimizers)
            if getattr(self, "exploration", None):
                optimizers = self.exploration.get_exploration_optimizer(
                    optimizers)

            # No optimizers produced -> Return None.
            if not optimizers:
                return None
            # New API: Allow more than one optimizer to be returned.
            # -> Return list.
            elif self.config["_tf_policy_handles_more_than_one_loss"]:
                return optimizers
            # Old API: Return a single LocalOptimizer.
            else:
                return optimizers[0]

        @override(TFPolicy)
        def gradients(self, optimizer, loss):
            optimizers = force_list(optimizer)
            losses = force_list(loss)

            if compute_gradients_fn:
                # New API: Allow more than one optimizer -> Return a list of
                # lists of gradients.
                if self.config["_tf_policy_handles_more_than_one_loss"]:
                    return compute_gradients_fn(self, optimizers, losses)
                # Old API: Return a single List of gradients.
                else:
                    return compute_gradients_fn(self, optimizers[0], losses[0])
            else:
                return base.gradients(self, optimizers, losses)

        @override(TFPolicy)
        def build_apply_op(self, optimizer, grads_and_vars):
            if apply_gradients_fn:
                return apply_gradients_fn(self, optimizer, grads_and_vars)
            else:
                return base.build_apply_op(self, optimizer, grads_and_vars)

        @override(TFPolicy)
        def extra_compute_action_fetches(self):
            return dict(
                base.extra_compute_action_fetches(self),
                **self._extra_action_fetches)

        @override(TFPolicy)
        def extra_compute_grad_fetches(self):
            if extra_learn_fetches_fn:
                # TODO: (sven) in torch, extra_learn_fetches do not exist.
                #  Hence, things like td_error are returned by the stats_fn
                #  and end up under the LEARNER_STATS_KEY. We should
                #  change tf to do this as well. However, this will conflict
                #  with the handling of LEARNER_STATS_KEY inside the multi-GPU
                #  train op.
                # Auto-add empty learner stats dict if needed.
                return dict({
                    LEARNER_STATS_KEY: {}
                }, **extra_learn_fetches_fn(self))
            else:
                return base.extra_compute_grad_fetches(self)

    def with_updates(**overrides):
        """Allows creating a TFPolicy cls based on settings of another one.

        Keyword Args:
            **overrides: The settings (passed into `build_tf_policy`) that
                should be different from the class that this method is called
                on.

        Returns:
            type: A new TFPolicy sub-class.

        Examples:
        >> MySpecialDQNPolicyClass = DQNTFPolicy.with_updates(
        ..    name="MySpecialDQNPolicyClass",
        ..    loss_fn=some_new_loss_function,
        .. )
        """
        return build_tf_policy(**dict(original_kwargs, **overrides))

    def as_eager():
        return eager_tf_policy.build_eager_tf_policy(**original_kwargs)

    policy_cls.with_updates = staticmethod(with_updates)
    policy_cls.as_eager = staticmethod(as_eager)
    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls
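
The builder above relies on three small Python techniques: snapshotting its own arguments with `locals().copy()` so that `with_updates` can replay them, merging default and user configs with `dict(defaults, **overrides)`, and placing mixins ahead of the base class so they take precedence. The following is a minimal, framework-free sketch of that pattern. `BasePolicy`, `build_policy`, `mean_loss`, and `max_loss` are hypothetical names invented for this illustration and are not part of RLlib; the mixin handling is a simplification of what `add_mixins` may do.

```python
from typing import Callable, Dict, List, Optional, Type


class BasePolicy:
    """Toy stand-in for DynamicTFPolicy (hypothetical, not the RLlib class)."""

    def __init__(self, config: Dict):
        self.config = config

    def loss(self, batch: List[float]) -> float:
        raise NotImplementedError


def build_policy(
        name: str,
        loss_fn: Callable[["BasePolicy", List[float]], float],
        get_default_config: Optional[Callable[[], Dict]] = None,
        mixins: Optional[List[type]] = None,
) -> Type[BasePolicy]:
    # Snapshot the builder arguments before any other local is defined,
    # so that `with_updates` can replay them later.
    original_kwargs = locals().copy()

    # Mixins are listed before the base class, so they win in the MRO.
    base = type("Base", tuple(mixins or []) + (BasePolicy,), {})

    class policy_cls(base):
        def __init__(self, config: Dict):
            if get_default_config:
                # User-provided keys override the defaults.
                config = dict(get_default_config(), **config)
            BasePolicy.__init__(self, config)

        def loss(self, batch):
            return loss_fn(self, batch)

    def with_updates(**overrides):
        # Re-run the builder with the original arguments plus overrides.
        return build_policy(**dict(original_kwargs, **overrides))

    policy_cls.with_updates = staticmethod(with_updates)
    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls


def mean_loss(policy, batch):
    return policy.config["scale"] * sum(batch) / len(batch)


def max_loss(policy, batch):
    return policy.config["scale"] * max(batch)


MeanPolicy = build_policy(
    name="MeanPolicy",
    loss_fn=mean_loss,
    get_default_config=lambda: {"scale": 1.0, "lr": 0.001},
)
# Derive a second class that differs only in its name and loss function.
MaxPolicy = MeanPolicy.with_updates(name="MaxPolicy", loss_fn=max_loss)

mean_policy = MeanPolicy({"scale": 2.0})
max_policy = MaxPolicy({})
```

Because `with_updates` replays the full argument snapshot, the derived class inherits `get_default_config` from the original call without restating it, and only the overridden keywords need to use the builder's real parameter names (here `loss_fn`).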