models Package API Reference

ray.rllib.models.action_dist.ActionDistribution

The policy action distribution of an agent.

Attributes:

Name | Type | Description
inputs | Tensors | Input vector to compute samples from.
model | ModelV2 | Reference to the model producing the inputs.

__init__(self, inputs, model) special

Initializes an ActionDist object.

Parameters:

Name | Type | Description | Default
inputs | Tensors | Input vector to compute samples from. | required
model | ModelV2 | Reference to the model producing the inputs. This is mainly useful if you want to use model variables to compute action outputs (e.g., for auto-regressive action distributions; see examples/autoregressive_action_dist.py). | required
Source code in ray/rllib/models/action_dist.py
@DeveloperAPI
def __init__(self, inputs: List[TensorType], model: ModelV2):
    """Initializes an ActionDist object.

    Args:
        inputs (Tensors): input vector to compute samples from.
        model (ModelV2): reference to model producing the inputs. This
            is mainly useful if you want to use model variables to compute
            action outputs (e.g., for auto-regressive action distributions,
            see examples/autoregressive_action_dist.py).
    """
    self.inputs = inputs
    self.model = model

deterministic_sample(self)

Get the deterministic "sampling" output from the distribution. This is usually the maximum-likelihood output, e.g. the mean for Normal or the argmax for Categorical.

Source code in ray/rllib/models/action_dist.py
@DeveloperAPI
def deterministic_sample(self) -> TensorType:
    """
    Get the deterministic "sampling" output from the distribution.
    This is usually the max likelihood output, e.g. mean for Normal, argmax
    for Categorical, etc.
    """
    raise NotImplementedError
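
As a rough illustration of what "deterministic sampling" means for the two cases named above, here is a minimal, framework-free sketch. The function names and the assumed input layout (means first, then log standard deviations) are illustrative assumptions, not RLlib code.

```python
from typing import List, Sequence


def categorical_mode(logits: Sequence[float]) -> int:
    """Index of the largest logit, i.e. the most likely category."""
    return max(range(len(logits)), key=lambda i: logits[i])


def diag_gaussian_mode(inputs: Sequence[float]) -> List[float]:
    """Mean of a diagonal Gaussian.

    Assumes (for this sketch only) that `inputs` is laid out as
    [means..., log_stds...], so the first half holds the means.
    """
    half = len(inputs) // 2
    return list(inputs[:half])


print(categorical_mode([0.1, 2.0, -1.0]))
print(diag_gaussian_mode([0.5, -0.2, 0.0, 0.0]))
```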

entropy(self)

The entropy of the action distribution.

Source code in ray/rllib/models/action_dist.py
@DeveloperAPI
def entropy(self) -> TensorType:
    """The entropy of the action distribution."""
    raise NotImplementedError
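
For intuition, the entropy of a categorical distribution parameterized by logits could be computed as follows. This is a plain-Python sketch with hypothetical helper names; a concrete subclass would normally do the same on batched tensors.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_entropy(logits: Sequence[float]) -> float:
    """Shannon entropy H(p) = -sum_i p_i * log(p_i), in nats."""
    return -sum(p * math.log(p) for p in softmax(logits) if p > 0.0)


uniform = categorical_entropy([0.0, 0.0, 0.0, 0.0])  # equals log(4)
peaked = categorical_entropy([10.0, 0.0, 0.0, 0.0])  # close to zero
print(uniform, peaked)
```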

kl(self, other)

The KL-divergence between two action distributions.

Source code in ray/rllib/models/action_dist.py
@DeveloperAPI
def kl(self, other: "ActionDistribution") -> TensorType:
    """The KL-divergence between two action distributions."""
    raise NotImplementedError
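
A sketch of the KL-divergence for two categorical distributions given as logits is shown below. It assumes the direction KL(self || other), which is the usual reading of `kl(self, other)`; the helper names are hypothetical.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Log-probabilities from logits via the log-sum-exp trick."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def categorical_kl(p_logits: Sequence[float],
                   q_logits: Sequence[float]) -> float:
    """KL(p || q) = sum_i p_i * (log p_i - log q_i)."""
    log_p = log_softmax(p_logits)
    log_q = log_softmax(q_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


a = [2.0, 0.0, 0.0]
b = [0.0, 0.0, 1.0]
print(categorical_kl(a, a), categorical_kl(a, b), categorical_kl(b, a))
```

Note that the divergence is zero for identical distributions and is not symmetric in its arguments.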

logp(self, x)

The log-likelihood of the action distribution.

Source code in ray/rllib/models/action_dist.py
@DeveloperAPI
def logp(self, x: TensorType) -> TensorType:
    """The log-likelihood of the action distribution."""
    raise NotImplementedError
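
For a categorical distribution, the log-likelihood of an action `x` is simply the log-softmax of the logits evaluated at index `x`. The following is a minimal sketch with a hypothetical function name, operating on a single unbatched logit vector.

```python
import math
from typing import Sequence


def categorical_logp(logits: Sequence[float], x: int) -> float:
    """log p(x) = logits[x] - logsumexp(logits)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return logits[x] - log_z


logits = [0.0, 1.0, 2.0]
total_prob = sum(math.exp(categorical_logp(logits, i)) for i in range(3))
print(categorical_logp(logits, 2), total_prob)
```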

multi_entropy(self)

The entropy of the action distribution.

This differs from entropy() in that it can return an array for MultiDiscrete. TODO(ekl) consider removing this.

Source code in ray/rllib/models/action_dist.py
def multi_entropy(self) -> TensorType:
    """The entropy of the action distribution.

    This differs from entropy() in that it can return an array for
    MultiDiscrete. TODO(ekl) consider removing this.
    """
    return self.entropy()

multi_kl(self, other)

The KL-divergence between two action distributions.

This differs from kl() in that it can return an array for MultiDiscrete. TODO(ekl) consider removing this.

Source code in ray/rllib/models/action_dist.py
def multi_kl(self, other: "ActionDistribution") -> TensorType:
    """The KL-divergence between two action distributions.

    This differs from kl() in that it can return an array for
    MultiDiscrete. TODO(ekl) consider removing this.
    """
    return self.kl(other)

required_model_output_shape(action_space, model_config) staticmethod

Returns the required shape of an input parameter tensor for a particular action space and an optional dict of distribution-specific options.

Parameters:

Name | Type | Description | Default
action_space | gym.Space | The action space this distribution will be used for, whose shape attributes will be used to determine the required shape of the input parameter tensor. | required
model_config | dict | Model's config dict (as defined in catalog.py). | required

Returns:

Type | Description
model_output_shape (int or np.ndarray of ints) | Size of the required input vector (minus leading batch dimension).

Source code in ray/rllib/models/action_dist.py
@DeveloperAPI
@staticmethod
def required_model_output_shape(
        action_space: gym.Space,
        model_config: ModelConfigDict) -> Union[int, np.ndarray]:
    """Returns the required shape of an input parameter tensor for a
    particular action space and an optional dict of distribution-specific
    options.

    Args:
        action_space (gym.Space): The action space this distribution will
            be used for, whose shape attributes will be used to determine
            the required shape of the input parameter tensor.
        model_config (dict): Model's config dict (as defined in catalog.py)

    Returns:
        model_output_shape (int or np.ndarray of ints): size of the
            required input vector (minus leading batch dimension).
    """
    raise NotImplementedError
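
To make the notion of "required input size" concrete, the sketch below computes it for two common parameterizations. The conventions used here (one logit per discrete action, and a mean plus a log standard deviation per continuous action dimension) are assumptions made for illustration and are not quoted from the source above.

```python
from math import prod
from typing import Tuple


def categorical_output_size(n: int) -> int:
    """One logit per discrete action."""
    return n


def diag_gaussian_output_size(shape: Tuple[int, ...]) -> int:
    """A mean and a log-std for every action dimension."""
    return 2 * prod(shape)


print(categorical_output_size(4))       # model must emit 4 values
print(diag_gaussian_output_size((3,)))  # model must emit 6 values
```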

sample(self)

Draw a sample from the action distribution.

Source code in ray/rllib/models/action_dist.py
@DeveloperAPI
def sample(self) -> TensorType:
    """Draw a sample from the action distribution."""
    raise NotImplementedError

sampled_action_logp(self)

Returns the log probability of the last sampled action.

Source code in ray/rllib/models/action_dist.py
@DeveloperAPI
def sampled_action_logp(self) -> TensorType:
    """Returns the log probability of the last sampled action."""
    raise NotImplementedError
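
The relationship between `sample()`, `logp()`, and `sampled_action_logp()` can be sketched with a tiny stand-in class. `TinyCategorical` is hypothetical and not part of RLlib; it only illustrates the idea of remembering the last sample so that its log probability can be returned later.

```python
import math
import random
from typing import List, Optional, Sequence


class TinyCategorical:
    """Hypothetical stand-in for a categorical action distribution."""

    def __init__(self, logits: Sequence[float], seed: int = 0):
        m = max(logits)
        exps = [math.exp(x - m) for x in logits]
        total = sum(exps)
        self.probs: List[float] = [e / total for e in exps]
        self._rng = random.Random(seed)
        self._last_sample: Optional[int] = None

    def sample(self) -> int:
        # Draw one category index, weighted by its probability.
        self._last_sample = self._rng.choices(
            range(len(self.probs)), weights=self.probs)[0]
        return self._last_sample

    def logp(self, x: int) -> float:
        return math.log(self.probs[x])

    def sampled_action_logp(self) -> float:
        if self._last_sample is None:
            raise ValueError("Call sample() before sampled_action_logp().")
        return self.logp(self._last_sample)


dist = TinyCategorical([0.0, 1.0, 2.0])
action = dist.sample()
print(action, dist.sampled_action_logp())
```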

ray.rllib.models.catalog.ModelCatalog

Registry of models, preprocessors, and action distributions for envs.

Examples:

>>> prep = ModelCatalog.get_preprocessor(env)
>>> observation = prep.transform(raw_observation)
>>> dist_class, dist_dim = ModelCatalog.get_action_dist(
...     env.action_space, {})
>>> model = ModelCatalog.get_model_v2(
...     obs_space, action_space, num_outputs, options)
>>> dist = dist_class(model.outputs, model)
>>> action = dist.sample()

get_action_dist(action_space, config, dist_type=None, framework='tf', **kwargs) staticmethod

Returns a distribution class and size for the given action space.

Parameters:

Name | Type | Description | Default
action_space | Space | Action space of the target gym env. | required
config | Optional[dict] | Optional model config. | required
dist_type | Optional[Union[str, Type[ActionDistribution]]] | Identifier of the action distribution (str) interpreted as a hint or the actual ActionDistribution class to use. | None
framework | str | One of "tf2", "tf", "tfe", "torch", or "jax". | 'tf'
kwargs | dict | Optional kwargs to pass on to the Distribution's constructor. | {}

Returns:

Type | Description
Tuple | A pair of:
  • dist_class (ActionDistribution): Python class of the distribution.
  • dist_dim (int): The size of the input vector to the distribution.
Source code in ray/rllib/models/catalog.py
@staticmethod
@DeveloperAPI
def get_action_dist(
        action_space: gym.Space,
        config: ModelConfigDict,
        dist_type: Optional[Union[str, Type[ActionDistribution]]] = None,
        framework: str = "tf",
        **kwargs) -> (type, int):
    """Returns a distribution class and size for the given action space.

    Args:
        action_space (Space): Action space of the target gym env.
        config (Optional[dict]): Optional model config.
        dist_type (Optional[Union[str, Type[ActionDistribution]]]):
            Identifier of the action distribution (str) interpreted as a
            hint or the actual ActionDistribution class to use.
        framework (str): One of "tf2", "tf", "tfe", "torch", or "jax".
        kwargs (dict): Optional kwargs to pass on to the Distribution's
            constructor.

    Returns:
        Tuple:
            - dist_class (ActionDistribution): Python class of the
                distribution.
            - dist_dim (int): The size of the input vector to the
                distribution.
    """

    dist_cls = None
    config = config or MODEL_DEFAULTS
    # Custom distribution given.
    if config.get("custom_action_dist"):
        custom_action_config = config.copy()
        action_dist_name = custom_action_config.pop("custom_action_dist")
        logger.debug(
            "Using custom action distribution {}".format(action_dist_name))
        dist_cls = _global_registry.get(RLLIB_ACTION_DIST,
                                        action_dist_name)
        return ModelCatalog._get_multi_action_distribution(
            dist_cls, action_space, custom_action_config, framework)

    # Dist_type is given directly as a class.
    elif type(dist_type) is type and \
            issubclass(dist_type, ActionDistribution) and \
            dist_type not in (
            MultiActionDistribution, TorchMultiActionDistribution):
        dist_cls = dist_type
    # Box space -> DiagGaussian OR Deterministic.
    elif isinstance(action_space, Box):
        if action_space.dtype.name.startswith("int"):
            low_ = np.min(action_space.low)
            high_ = np.max(action_space.high)
            dist_cls = TorchMultiCategorical if framework == "torch" \
                else MultiCategorical
            num_cats = int(np.product(action_space.shape))
            return partial(
                dist_cls,
                input_lens=[high_ - low_ + 1 for _ in range(num_cats)],
                action_space=action_space), num_cats * (high_ - low_ + 1)
        else:
            if len(action_space.shape) > 1:
                raise UnsupportedSpaceException(
                    "Action space has multiple dimensions "
                    "{}. ".format(action_space.shape) +
                    "Consider reshaping this into a single dimension, "
                    "using a custom action distribution, "
                    "using a Tuple action space, or the multi-agent API.")
            # TODO(sven): Check for bounds and return SquashedNormal, etc..
            if dist_type is None:
                dist_cls = TorchDiagGaussian if framework == "torch" \
                    else DiagGaussian
            elif dist_type == "deterministic":
                dist_cls = TorchDeterministic if framework == "torch" \
                    else Deterministic
    # Discrete Space -> Categorical.
    elif isinstance(action_space, Discrete):
        dist_cls = TorchCategorical if framework == "torch" else \
            JAXCategorical if framework == "jax" else Categorical
    # Tuple/Dict Spaces -> MultiAction.
    elif dist_type in (MultiActionDistribution,
                       TorchMultiActionDistribution) or \
            isinstance(action_space, (Tuple, Dict)):
        return ModelCatalog._get_multi_action_distribution(
            (MultiActionDistribution
             if framework == "tf" else TorchMultiActionDistribution),
            action_space, config, framework)
    # Simplex -> Dirichlet.
    elif isinstance(action_space, Simplex):
        if framework == "torch":
            # TODO(sven): implement
            raise NotImplementedError(
                "Simplex action spaces not supported for torch.")
        dist_cls = Dirichlet
    # MultiDiscrete -> MultiCategorical.
    elif isinstance(action_space, MultiDiscrete):
        dist_cls = TorchMultiCategorical if framework == "torch" else \
            MultiCategorical
        return partial(dist_cls, input_lens=action_space.nvec), \
            int(sum(action_space.nvec))
    # Unknown type -> Error.
    else:
        raise NotImplementedError("Unsupported args: {} {}".format(
            action_space, dist_type))

    return dist_cls, dist_cls.required_model_output_shape(
        action_space, config)
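
The integer-Box and MultiDiscrete branches above both reduce to a list of per-component category counts (`input_lens`) and a total input size. The sketch below mirrors that arithmetic in plain Python; the function names are hypothetical, and `low`/`high` stand for the minimum lower bound and maximum upper bound of the space.

```python
from typing import List, Sequence, Tuple


def int_box_layout(low: int, high: int,
                   shape: Sequence[int]) -> Tuple[List[int], int]:
    """One categorical with (high - low + 1) classes per Box element."""
    num_cats = 1
    for dim in shape:
        num_cats *= dim
    cats_per_element = high - low + 1
    input_lens = [cats_per_element] * num_cats
    return input_lens, num_cats * cats_per_element


def multi_discrete_layout(nvec: Sequence[int]) -> Tuple[List[int], int]:
    """One categorical per sub-action; total size is the sum of nvec."""
    return list(nvec), int(sum(nvec))


print(int_box_layout(0, 2, (2, 2)))
print(multi_discrete_layout([3, 5]))
```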

get_action_placeholder(action_space, name='action') staticmethod

Returns an action placeholder consistent with the action space.

Parameters:

Name | Type | Description | Default
action_space | Space | Action space of the target gym env. | required
name | str | An optional string to name the placeholder by. Default: "action". | 'action'

Returns:

Type | Description
action_placeholder (Tensor) | A placeholder for the actions.

Source code in ray/rllib/models/catalog.py
@staticmethod
@DeveloperAPI
def get_action_placeholder(action_space: gym.Space,
                           name: str = "action") -> TensorType:
    """Returns an action placeholder consistent with the action space

    Args:
        action_space (Space): Action space of the target gym env.
        name (str): An optional string to name the placeholder by.
            Default: "action".

    Returns:
        action_placeholder (Tensor): A placeholder for the actions
    """
    dtype, shape = ModelCatalog.get_action_shape(
        action_space, framework="tf")

    return tf1.placeholder(dtype, shape=shape, name=name)

get_action_shape(action_space, framework='tf') staticmethod

Returns action tensor dtype and shape for the action space.

Parameters:

Name | Type | Description | Default
action_space | Space | Action space of the target gym env. | required
framework | str | The framework identifier. One of "tf" or "torch". | 'tf'

Returns:

Type | Description
(dtype, shape) | Dtype and shape of the actions tensor.

Source code in ray/rllib/models/catalog.py
@staticmethod
@DeveloperAPI
def get_action_shape(action_space: gym.Space,
                     framework: str = "tf") -> (np.dtype, List[int]):
    """Returns action tensor dtype and shape for the action space.

    Args:
        action_space (Space): Action space of the target gym env.
        framework (str): The framework identifier. One of "tf" or "torch".

    Returns:
        (dtype, shape): Dtype and shape of the actions tensor.
    """
    dl_lib = torch if framework == "torch" else tf
    if isinstance(action_space, Discrete):
        return action_space.dtype, (None, )
    elif isinstance(action_space, (Box, Simplex)):
        if np.issubdtype(action_space.dtype, np.floating):
            return dl_lib.float32, (None, ) + action_space.shape
        elif np.issubdtype(action_space.dtype, np.integer):
            return dl_lib.int32, (None, ) + action_space.shape
        else:
            raise ValueError(
                "RLlib doesn't support non int or float box spaces")
    elif isinstance(action_space, MultiDiscrete):
        return action_space.dtype, (None, ) + action_space.shape
    elif isinstance(action_space, (Tuple, Dict)):
        flat_action_space = flatten_space(action_space)
        size = 0
        all_discrete = True
        for i in range(len(flat_action_space)):
            if isinstance(flat_action_space[i], Discrete):
                size += 1
            else:
                all_discrete = False
                size += np.product(flat_action_space[i].shape)
        size = int(size)
        return dl_lib.int32 if all_discrete else dl_lib.float32, \
            (None, size)
    else:
        raise NotImplementedError(
            "Action space {} not supported".format(action_space))

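The Tuple/Dict branch of get_action_shape above sums one slot per Discrete component and the product of the shape for every other component. A minimal sketch of that bookkeeping follows; each flattened component is represented by a hypothetical `(is_discrete, shape)` pair rather than a real gym space, and dtypes are plain strings.

```python
from math import prod
from typing import List, Optional, Tuple

# Hypothetical stand-in for one flattened sub-space: (is_discrete, shape).
FlatSpace = Tuple[bool, Tuple[int, ...]]


def flat_action_layout(
        components: List[FlatSpace]) -> Tuple[str, Tuple[Optional[int], int]]:
    """Return (dtype name, (batch dim, flat size)) for flattened spaces."""
    size = 0
    all_discrete = True
    for is_discrete, shape in components:
        if is_discrete:
            size += 1  # a Discrete action is a single integer
        else:
            all_discrete = False
            size += prod(shape)
    return ("int32" if all_discrete else "float32"), (None, size)


print(flat_action_layout([(True, ()), (False, (3,))]))
print(flat_action_layout([(True, ()), (True, ())]))
```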
get_model_v2(obs_space, action_space, num_outputs, model_config, framework='tf', name='default_model', model_interface=None, default_model=None, **model_kwargs) staticmethod

Returns a suitable model compatible with given spaces and output.

Parameters:

Name | Type | Description | Default
obs_space | Space | Observation space of the target gym env. This may have an original_space attribute that specifies how to unflatten the tensor into a ragged tensor. | required
action_space | Space | Action space of the target gym env. | required
num_outputs | int | The size of the output vector of the model. | required
model_config | ModelConfigDict | The "model" sub-config dict within the Trainer's config dict. | required
framework | str | One of "tf2", "tf", "tfe", "torch", or "jax". | 'tf'
name | str | Name (scope) for the model. | 'default_model'
model_interface | cls | Interface required for the model. | None
default_model | cls | Override the default class for the model. This only has an effect when not using a custom model. | None
model_kwargs | dict | Args to pass to the ModelV2 constructor. | {}

Returns:

Type | Description
model (ModelV2) | Model to use for the policy.

Source code in ray/rllib/models/catalog.py
@staticmethod
@DeveloperAPI
def get_model_v2(obs_space: gym.Space,
                 action_space: gym.Space,
                 num_outputs: int,
                 model_config: ModelConfigDict,
                 framework: str = "tf",
                 name: str = "default_model",
                 model_interface: type = None,
                 default_model: type = None,
                 **model_kwargs) -> ModelV2:
    """Returns a suitable model compatible with given spaces and output.

    Args:
        obs_space (Space): Observation space of the target gym env. This
            may have an `original_space` attribute that specifies how to
            unflatten the tensor into a ragged tensor.
        action_space (Space): Action space of the target gym env.
        num_outputs (int): The size of the output vector of the model.
        model_config (ModelConfigDict): The "model" sub-config dict
            within the Trainer's config dict.
        framework (str): One of "tf2", "tf", "tfe", "torch", or "jax".
        name (str): Name (scope) for the model.
        model_interface (cls): Interface required for the model
        default_model (cls): Override the default class for the model. This
            only has an effect when not using a custom model
        model_kwargs (dict): args to pass to the ModelV2 constructor

    Returns:
        model (ModelV2): Model to use for the policy.
    """

    # Validate the given config dict.
    ModelCatalog._validate_config(config=model_config, framework=framework)

    if model_config.get("custom_model"):
        # Allow model kwargs to be overridden / augmented by
        # custom_model_config.
        customized_model_kwargs = dict(
            model_kwargs, **model_config.get("custom_model_config", {}))

        if isinstance(model_config["custom_model"], type):
            model_cls = model_config["custom_model"]
        else:
            model_cls = _global_registry.get(RLLIB_MODEL,
                                             model_config["custom_model"])

        # Only allow ModelV2 or native keras Models.
        if not issubclass(model_cls, ModelV2):
            if framework not in ["tf", "tf2", "tfe"] or \
                    not issubclass(model_cls, tf.keras.Model):
                raise ValueError(
                    "`model_cls` must be a ModelV2 sub-class, but is"
                    " {}!".format(model_cls))

        logger.info("Wrapping {} as {}".format(model_cls, model_interface))
        model_cls = ModelCatalog._wrap_if_needed(model_cls,
                                                 model_interface)

        if framework in ["tf2", "tf", "tfe"]:
            # Try wrapping custom model with LSTM/attention, if required.
            if model_config.get("use_lstm") or \
                    model_config.get("use_attention"):
                from ray.rllib.models.tf.attention_net import \
                    AttentionWrapper, Keras_AttentionWrapper
                from ray.rllib.models.tf.recurrent_net import \
                    LSTMWrapper, Keras_LSTMWrapper

                wrapped_cls = model_cls
                # Wrapped (custom) model is itself a keras Model ->
                # wrap with keras LSTM/GTrXL (attention) wrappers.
                if issubclass(wrapped_cls, tf.keras.Model):
                    model_cls = Keras_LSTMWrapper if \
                        model_config.get("use_lstm") else \
                        Keras_AttentionWrapper
                    model_config["wrapped_cls"] = wrapped_cls
                # Wrapped (custom) model is ModelV2 ->
                # wrap with ModelV2 LSTM/GTrXL (attention) wrappers.
                else:
                    forward = wrapped_cls.forward
                    model_cls = ModelCatalog._wrap_if_needed(
                        wrapped_cls, LSTMWrapper if
                        model_config.get("use_lstm") else AttentionWrapper)
                    model_cls._wrapped_forward = forward

            # Obsolete: Track and warn if vars were created but not
            # registered. Only still do this, if users do register their
            # variables. If not (which they shouldn't), don't check here.
            created = set()

            def track_var_creation(next_creator, **kw):
                v = next_creator(**kw)
                created.add(v)
                return v

            with tf.variable_creator_scope(track_var_creation):
                if issubclass(model_cls, tf.keras.Model):
                    instance = model_cls(
                        input_space=obs_space,
                        action_space=action_space,
                        num_outputs=num_outputs,
                        name=name,
                        **customized_model_kwargs,
                    )
                else:
                    # Try calling with kwargs first (custom ModelV2 should
                    # accept these as kwargs, not get them from
                    # config["custom_model_config"] anymore).
                    try:
                        instance = model_cls(
                            obs_space,
                            action_space,
                            num_outputs,
                            model_config,
                            name,
                            **customized_model_kwargs,
                        )
                    except TypeError as e:
                        # Keyword error: Try old way w/o kwargs.
                        if "__init__() got an unexpected " in e.args[0]:
                            instance = model_cls(
                                obs_space,
                                action_space,
                                num_outputs,
                                model_config,
                                name,
                                **model_kwargs,
                            )
                            logger.warning(
                                "Custom ModelV2 should accept all custom "
                                "options as **kwargs, instead of expecting"
                                " them in config['custom_model_config']!")
                        # Other error -> re-raise.
                        else:
                            raise e

            # User still registered TFModelV2's variables: Check, whether
            # ok.
            registered = []
            if not isinstance(instance, tf.keras.Model):
                registered = set(instance.var_list)
            if len(registered) > 0:
                not_registered = set()
                for var in created:
                    if var not in registered:
                        not_registered.add(var)
                if not_registered:
                    raise ValueError(
                        "It looks like you are still using "
                        "`{}.register_variables()` to register your "
                        "model's weights. This is no longer required, but "
                        "if you are still calling this method at least "
                        "once, you must make sure to register all created "
                        "variables properly. The missing variables are {},"
                        " and you only registered {}. "
                        "Did you forget to call `register_variables()` on "
                        "some of the variables in question?".format(
                            instance, not_registered, registered))
        elif framework == "torch":
            # Try wrapping custom model with LSTM/attention, if required.
            if model_config.get("use_lstm") or \
                    model_config.get("use_attention"):
                from ray.rllib.models.torch.attention_net import \
                    AttentionWrapper
                from ray.rllib.models.torch.recurrent_net import \
                    LSTMWrapper

                wrapped_cls = model_cls
                forward = wrapped_cls.forward
                model_cls = ModelCatalog._wrap_if_needed(
                    wrapped_cls, LSTMWrapper
                    if model_config.get("use_lstm") else AttentionWrapper)
                model_cls._wrapped_forward = forward

            # PyTorch automatically tracks nn.Modules inside the parent
            # nn.Module's constructor.
            # Try calling with kwargs first (custom ModelV2 should
            # accept these as kwargs, not get them from
            # config["custom_model_config"] anymore).
            try:
                instance = model_cls(obs_space, action_space, num_outputs,
                                     model_config, name,
                                     **customized_model_kwargs)
            except TypeError as e:
                # Keyword error: Try old way w/o kwargs.
                if "__init__() got an unexpected " in e.args[0]:
                    instance = model_cls(obs_space, action_space,
                                         num_outputs, model_config, name,
                                         **model_kwargs)
                    logger.warning(
                        "Custom ModelV2 should accept all custom "
                        "options as **kwargs, instead of expecting"
                        " them in config['custom_model_config']!")
                # Other error -> re-raise.
                else:
                    raise e
        else:
            raise NotImplementedError(
                "`framework` must be 'tf2|tf|tfe|torch', but is "
                "{}!".format(framework))

        return instance

    # Find a default TFModelV2 and wrap with model_interface.
    if framework in ["tf", "tfe", "tf2"]:
        v2_class = None
        # Try to get a default v2 model.
        if not model_config.get("custom_model"):
            v2_class = default_model or ModelCatalog._get_v2_model_class(
                obs_space, model_config, framework=framework)

        if not v2_class:
            raise ValueError("ModelV2 class could not be determined!")

        if model_config.get("use_lstm") or \
                model_config.get("use_attention"):

            from ray.rllib.models.tf.attention_net import \
                AttentionWrapper, Keras_AttentionWrapper
            from ray.rllib.models.tf.recurrent_net import LSTMWrapper, \
                Keras_LSTMWrapper

            wrapped_cls = v2_class
            if model_config.get("use_lstm"):
                if issubclass(wrapped_cls, tf.keras.Model):
                    v2_class = Keras_LSTMWrapper
                    model_config["wrapped_cls"] = wrapped_cls
                else:
                    v2_class = ModelCatalog._wrap_if_needed(
                        wrapped_cls, LSTMWrapper)
                    v2_class._wrapped_forward = wrapped_cls.forward
            else:
                if issubclass(wrapped_cls, tf.keras.Model):
                    v2_class = Keras_AttentionWrapper
                    model_config["wrapped_cls"] = wrapped_cls
                else:
                    v2_class = ModelCatalog._wrap_if_needed(
                        wrapped_cls, AttentionWrapper)
                    v2_class._wrapped_forward = wrapped_cls.forward

        # Wrap in the requested interface.
        wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)

        if issubclass(wrapper, tf.keras.Model):
            model = wrapper(
                input_space=obs_space,
                action_space=action_space,
                num_outputs=num_outputs,
                name=name,
                **dict(model_kwargs, **model_config),
            )
            return model

        return wrapper(obs_space, action_space, num_outputs, model_config,
                       name, **model_kwargs)

    # Find a default TorchModelV2 and wrap with model_interface.
    elif framework == "torch":
        # Try to get a default v2 model.
        if not model_config.get("custom_model"):
            v2_class = default_model or ModelCatalog._get_v2_model_class(
                obs_space, model_config, framework=framework)

        if not v2_class:
            raise ValueError("ModelV2 class could not be determined!")

        if model_config.get("use_lstm") or \
                model_config.get("use_attention"):

            from ray.rllib.models.torch.attention_net import \
                AttentionWrapper
            from ray.rllib.models.torch.recurrent_net import LSTMWrapper

            wrapped_cls = v2_class
            forward = wrapped_cls.forward
            if model_config.get("use_lstm"):
                v2_class = ModelCatalog._wrap_if_needed(
                    wrapped_cls, LSTMWrapper)
            else:
                v2_class = ModelCatalog._wrap_if_needed(
                    wrapped_cls, AttentionWrapper)

            v2_class._wrapped_forward = forward

        # Wrap in the requested interface.
        wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
        return wrapper(obs_space, action_space, num_outputs, model_config,
                       name, **model_kwargs)

    # Find a default JAXModelV2 and wrap with model_interface.
    elif framework == "jax":
        v2_class = \
            default_model or ModelCatalog._get_v2_model_class(
                obs_space, model_config, framework=framework)
        # Wrap in the requested interface.
        wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
        return wrapper(obs_space, action_space, num_outputs, model_config,
                       name, **model_kwargs)
    else:
        raise NotImplementedError(
            "`framework` must be 'tf2|tf|tfe|torch', but is "
            "{}!".format(framework))

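The custom-model branch of get_model_v2 above merges `model_kwargs` with `custom_model_config` (the latter wins on key clashes) and falls back to the plain `model_kwargs` if the constructor rejects an unexpected keyword. The sketch below reproduces that pattern in isolation; `build_model`, `NewStyleModel`, and `OldStyleModel` are hypothetical names used only for illustration.

```python
from typing import Any, Dict, Tuple


def build_model(model_cls: type, args: Tuple[Any, ...],
                model_kwargs: Dict[str, Any],
                custom_model_config: Dict[str, Any]) -> Any:
    """Instantiate model_cls, preferring the merged keyword arguments."""
    # Later keys win: custom_model_config overrides model_kwargs.
    merged = dict(model_kwargs, **custom_model_config)
    try:
        return model_cls(*args, **merged)
    except TypeError as e:
        # Constructor rejected a keyword: retry without the custom config.
        if "__init__() got an unexpected " in e.args[0]:
            return model_cls(*args, **model_kwargs)
        raise


class NewStyleModel:  # hypothetical: accepts custom options as kwargs
    def __init__(self, name: str, **kwargs: Any):
        self.name, self.options = name, kwargs


class OldStyleModel:  # hypothetical: accepts no extra kwargs
    def __init__(self, name: str):
        self.name = name


new = build_model(NewStyleModel, ("m",), {"hidden": 32}, {"hidden": 64})
old = build_model(OldStyleModel, ("m",), {}, {"hidden": 64})
print(new.options, type(old).__name__)
```

Matching on the error message is fragile, which is presumably why the source pairs the fallback with a warning that asks authors to accept custom options as `**kwargs`.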
get_preprocessor(env, options=None) staticmethod

Returns a suitable preprocessor for the given env.

This is a wrapper for get_preprocessor_for_space().

Source code in ray/rllib/models/catalog.py
@staticmethod
@DeveloperAPI
def get_preprocessor(env: gym.Env,
                     options: Optional[dict] = None) -> Preprocessor:
    """Returns a suitable preprocessor for the given env.

    This is a wrapper for get_preprocessor_for_space().
    """

    return ModelCatalog.get_preprocessor_for_space(env.observation_space,
                                                   options)

get_preprocessor_for_space(observation_space, options=None) staticmethod

Returns a suitable preprocessor for the given observation space.

Parameters:

Name | Type | Description | Default
observation_space | Space | The input observation space. | required
options | dict | Options to pass to the preprocessor. | None

Returns:

Type | Description
preprocessor (Preprocessor) | Preprocessor for the observations.

Source code in ray/rllib/models/catalog.py
@staticmethod
@DeveloperAPI
def get_preprocessor_for_space(observation_space: gym.Space,
                               options: dict = None) -> Preprocessor:
    """Returns a suitable preprocessor for the given observation space.

    Args:
        observation_space (Space): The input observation space.
        options (dict): Options to pass to the preprocessor.

    Returns:
        preprocessor (Preprocessor): Preprocessor for the observations.
    """

    options = options or MODEL_DEFAULTS
    for k in options.keys():
        if k not in MODEL_DEFAULTS:
            raise Exception("Unknown config key `{}`, all keys: {}".format(
                k, list(MODEL_DEFAULTS)))

    if options.get("custom_preprocessor"):
        preprocessor = options["custom_preprocessor"]
        logger.info("Using custom preprocessor {}".format(preprocessor))
        logger.warning(
            "DeprecationWarning: Custom preprocessors are deprecated, "
            "since they sometimes conflict with the built-in "
            "preprocessors for handling complex observation spaces. "
            "Please use wrapper classes around your environment "
            "instead of preprocessors.")
        prep = _global_registry.get(RLLIB_PREPROCESSOR, preprocessor)(
            observation_space, options)
    else:
        cls = get_preprocessor(observation_space)
        prep = cls(observation_space, options)

    if prep is not None:
        logger.debug("Created preprocessor {}: {} -> {}".format(
            prep, observation_space, prep.shape))
    return prep
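
The option check at the top of get_preprocessor_for_space() can be tried in isolation. What follows is a minimal sketch, assuming a small hypothetical dict (`TOY_DEFAULTS`) in place of RLlib's MODEL_DEFAULTS; `validate_options` is an illustrative helper and not part of RLlib.

```python
# Hypothetical stand-in for MODEL_DEFAULTS; the real dict has many more keys.
TOY_DEFAULTS = {"custom_preprocessor": None, "dim": 84, "grayscale": False}


def validate_options(options=None, defaults=TOY_DEFAULTS):
    """Mirror the listing: fall back to the defaults, reject unknown keys."""
    options = options or defaults
    for k in options.keys():
        if k not in defaults:
            raise Exception("Unknown config key `{}`, all keys: {}".format(
                k, list(defaults)))
    return options


# A missing (or empty) options dict falls back to the defaults.
fallback = validate_options(None)

# A misspelled key is rejected before any preprocessor would be built.
try:
    validate_options({"grayscal": True})
except Exception as e:
    message = str(e)
```

Note that, as in the listing, a non-empty options dict is used as-is and is not merged with the defaults.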

register_custom_action_dist(action_dist_name, action_dist_class) staticmethod

Register a custom action distribution class by name.

The action distribution can later be used by specifying {"custom_action_dist": action_dist_name} in the model config.

Parameters:

Name Type Description Default
action_dist_name str

Name to register the action distribution under.

required
action_dist_class type

Python class of the action distribution.

required
Source code in ray/rllib/models/catalog.py
@staticmethod
@PublicAPI
def register_custom_action_dist(action_dist_name: str,
                                action_dist_class: type) -> None:
    """Register a custom action distribution class by name.

    The action distribution can later be used by specifying
    {"custom_action_dist": action_dist_name} in the model config.

    Args:
        action_dist_name (str): Name to register the action distribution
            under.
        action_dist_class (type): Python class of the action distribution.
    """
    _global_registry.register(RLLIB_ACTION_DIST, action_dist_name,
                              action_dist_class)

register_custom_model(model_name, model_class) staticmethod

Register a custom model class by name.

The model can later be used by specifying {"custom_model": model_name} in the model config.

Parameters:

Name Type Description Default
model_name str

Name to register the model under.

required
model_class type

Python class of the model.

required
Source code in ray/rllib/models/catalog.py
@staticmethod
@PublicAPI
def register_custom_model(model_name: str, model_class: type) -> None:
    """Register a custom model class by name.

    The model can be later used by specifying {"custom_model": model_name}
    in the model config.

    Args:
        model_name (str): Name to register the model under.
        model_class (type): Python class of the model.
    """
    if tf is not None:
        if issubclass(model_class, tf.keras.Model):
            deprecation_warning(old="register_custom_model", error=False)
    _global_registry.register(RLLIB_MODEL, model_name, model_class)
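
Both register_custom_model() and register_custom_action_dist() store a class under a (category, name) pair so that a config can later refer to it by name. The sketch below illustrates that pattern only; `ToyRegistry`, `TOY_MODEL`, and `MyModel` are hypothetical names, and RLlib's actual registry is not reproduced here.

```python
class ToyRegistry:
    """Hypothetical stand-in for a global registry (illustration only)."""

    def __init__(self):
        self._entries = {}

    def register(self, category, name, cls):
        # Store the class under its category and user-chosen name.
        self._entries[(category, name)] = cls

    def get(self, category, name):
        return self._entries[(category, name)]


TOY_MODEL = "toy_model"  # hypothetical category key


class MyModel:
    """Placeholder for a TFModelV2 or TorchModelV2 subclass."""


registry = ToyRegistry()
registry.register(TOY_MODEL, "my_model", MyModel)

# Later, a model config names the class instead of importing it.
model_config = {"custom_model": "my_model"}
model_cls = registry.get(TOY_MODEL, model_config["custom_model"])
```

Looking classes up by name keeps the model config a plain dict of strings and numbers.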

ray.rllib.models.modelv2.ModelV2

Defines an abstract neural network model for use with RLlib.

Custom models should extend either TFModelV2 or TorchModelV2 instead of this class directly.

Data flow:
obs -> forward() -> model_out
value_function() -> V(s)

__init__(self, obs_space, action_space, num_outputs, model_config, name, framework) special

Initializes a ModelV2 object.

This method should create any variables used by the model.

Parameters:

Name Type Description Default
obs_space gym.spaces.Space

Observation space of the target gym env. This may have an original_space attribute that specifies how to unflatten the tensor into a ragged tensor.

required
action_space gym.spaces.Space

Action space of the target gym env.

required
num_outputs int

Number of output units of the model.

required
model_config ModelConfigDict

Config for the model, documented in ModelCatalog.

required
name str

Name (scope) for the model.

required
framework str

Either "tf" or "torch".

required
Source code in ray/rllib/models/modelv2.py
def __init__(self, obs_space: gym.spaces.Space,
             action_space: gym.spaces.Space, num_outputs: int,
             model_config: ModelConfigDict, name: str, framework: str):
    """Initializes a ModelV2 object.

    This method should create any variables used by the model.

    Args:
        obs_space (gym.spaces.Space): Observation space of the target gym
            env. This may have an `original_space` attribute that
            specifies how to unflatten the tensor into a ragged tensor.
        action_space (gym.spaces.Space): Action space of the target gym
            env.
        num_outputs (int): Number of output units of the model.
        model_config (ModelConfigDict): Config for the model, documented
            in ModelCatalog.
        name (str): Name (scope) for the model.
        framework (str): Either "tf" or "torch".
    """

    self.obs_space: gym.spaces.Space = obs_space
    self.action_space: gym.spaces.Space = action_space
    self.num_outputs: int = num_outputs
    self.model_config: ModelConfigDict = model_config
    self.name: str = name or "default_model"
    self.framework: str = framework
    self._last_output = None
    self.time_major = self.model_config.get("_time_major")
    # Basic view requirement for all models: Use the observation as input.
    self.view_requirements = {
        SampleBatch.OBS: ViewRequirement(shift=0, space=self.obs_space),
    }

context(self)

Returns a contextmanager for the current forward pass.

Source code in ray/rllib/models/modelv2.py
@PublicAPI
def context(self) -> contextlib.AbstractContextManager:
    """Returns a contextmanager for the current forward pass."""
    return NullContextManager()

custom_loss(self, policy_loss, loss_inputs)

Override to customize the loss function used to optimize this model.

This can be used to incorporate self-supervised losses (by defining a loss over existing input and output tensors of this model), and supervised losses (by defining losses over a variable-sharing copy of this model's layers).

You can find a runnable example in examples/custom_loss.py.

Parameters:

Name Type Description Default
policy_loss Union[List[Tensor],Tensor]

List of or single policy loss(es) from the policy.

required
loss_inputs dict

map of input placeholders for rollout data.

required

Returns:

Type Description
Union[List[Tensor],Tensor]

List of or scalar tensor for the customized loss(es) for this model.

Source code in ray/rllib/models/modelv2.py
@PublicAPI
def custom_loss(self, policy_loss: TensorType,
                loss_inputs: Dict[str, TensorType]) -> TensorType:
    """Override to customize the loss function used to optimize this model.

    This can be used to incorporate self-supervised losses (by defining
    a loss over existing input and output tensors of this model), and
    supervised losses (by defining losses over a variable-sharing copy of
    this model's layers).

    You can find a runnable example in examples/custom_loss.py.

    Args:
        policy_loss (Union[List[Tensor],Tensor]): List of or single policy
            loss(es) from the policy.
        loss_inputs (dict): map of input placeholders for rollout data.

    Returns:
        Union[List[Tensor],Tensor]: List of or scalar tensor for the
            customized loss(es) for this model.
    """
    return policy_loss
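
As a rough illustration of the override, the sketch below adds a weighted auxiliary term to the policy loss and handles both the single-loss and list-of-losses cases. It is a toy under stated assumptions: plain floats stand in for tensors, and `ToyModel`, `aux_weight`, and `aux_loss` are hypothetical names, not RLlib API.

```python
class ToyModel:
    """Hypothetical stand-in for a ModelV2 subclass; floats replace tensors."""

    def __init__(self, aux_weight=0.1):
        self.aux_weight = aux_weight

    def aux_loss(self, loss_inputs):
        # Illustrative self-supervised term: mean squared observation value.
        obs = loss_inputs["obs"]
        return sum(o * o for o in obs) / len(obs)

    def custom_loss(self, policy_loss, loss_inputs):
        extra = self.aux_weight * self.aux_loss(loss_inputs)
        # policy_loss may be a single loss or a list of losses.
        if isinstance(policy_loss, list):
            return [loss + extra for loss in policy_loss]
        return policy_loss + extra


model = ToyModel(aux_weight=0.5)
single = model.custom_loss(1.0, {"obs": [1.0, 3.0]})
multi = model.custom_loss([1.0, 2.0], {"obs": [1.0, 3.0]})
```

The return value keeps the shape of the input: a single loss in gives a single loss out, and a list in gives a list out.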

forward(self, input_dict, state, seq_lens)

Call the model with the given input tensors and state.

Any complex observations (dicts, tuples, etc.) will be unpacked by __call__ before being passed to forward(). To access the flattened observation tensor, refer to input_dict["obs_flat"].

This method can be called any number of times. In eager execution, each call to forward() will eagerly evaluate the model. In symbolic execution, each call to forward creates a computation graph that operates over the variables of this model (i.e., shares weights).

Custom models should override this instead of __call__.

Parameters:

Name Type Description Default
input_dict dict

dictionary of input tensors, including "obs", "obs_flat", "prev_action", "prev_reward", "is_training", "eps_id", "agent_id", "infos", and "t".

required
state list

list of state tensors with sizes matching those returned by get_initial_state(), plus the batch dimension

required
seq_lens Tensor

1d tensor holding input sequence lengths

required

Returns:

Type Description
(outputs, state)

The model output tensor of size [BATCH, num_outputs], and the new RNN state.

Examples:

>>> def forward(self, input_dict, state, seq_lens):
>>>     model_out, self._value_out = self.base_model(
...         input_dict["obs"])
>>>     return model_out, state
Source code in ray/rllib/models/modelv2.py
@PublicAPI
def forward(self, input_dict: Dict[str, TensorType],
            state: List[TensorType],
            seq_lens: TensorType) -> (TensorType, List[TensorType]):
    """Call the model with the given input tensors and state.

    Any complex observations (dicts, tuples, etc.) will be unpacked by
    __call__ before being passed to forward(). To access the flattened
    observation tensor, refer to input_dict["obs_flat"].

    This method can be called any number of times. In eager execution,
    each call to forward() will eagerly evaluate the model. In symbolic
    execution, each call to forward creates a computation graph that
    operates over the variables of this model (i.e., shares weights).

    Custom models should override this instead of __call__.

    Args:
        input_dict (dict): dictionary of input tensors, including "obs",
            "obs_flat", "prev_action", "prev_reward", "is_training",
            "eps_id", "agent_id", "infos", and "t".
        state (list): list of state tensors with sizes matching those
            returned by get_initial_state + the batch dimension
        seq_lens (Tensor): 1d tensor holding input sequence lengths

    Returns:
        (outputs, state): The model output tensor of size
            [BATCH, num_outputs], and the new RNN state.

    Examples:
        >>> def forward(self, input_dict, state, seq_lens):
        >>>     model_out, self._value_out = self.base_model(
        ...         input_dict["obs"])
        >>>     return model_out, state
    """
    raise NotImplementedError

get_initial_state(self)

Get the initial recurrent state values for the model.

Returns:

Type Description
List[np.ndarray]

List of np.array objects containing the initial hidden state of an RNN, if applicable.

Examples:

>>> def get_initial_state(self):
>>>    return [
>>>        np.zeros(self.cell_size, np.float32),
>>>        np.zeros(self.cell_size, np.float32),
>>>    ]
Source code in ray/rllib/models/modelv2.py
@PublicAPI
def get_initial_state(self) -> List[np.ndarray]:
    """Get the initial recurrent state values for the model.

    Returns:
        List[np.ndarray]: List of np.array objects containing the initial
            hidden state of an RNN, if applicable.

    Examples:
        >>> def get_initial_state(self):
        >>>    return [
        >>>        np.zeros(self.cell_size, np.float32),
        >>>        np.zeros(self.cell_size, np.float32),
        >>>    ]
    """
    return []

import_from_h5(self, h5_file)

Imports weights from an h5 file.

Parameters:

Name Type Description Default
h5_file str

The h5 file name to import weights from.

required

Examples:

>>> trainer = MyTrainer()
>>> trainer.import_policy_model_from_h5("/tmp/weights.h5")
>>> for _ in range(10):
>>>     trainer.train()
Source code in ray/rllib/models/modelv2.py
def import_from_h5(self, h5_file: str) -> None:
    """Imports weights from an h5 file.

    Args:
        h5_file (str): The h5 file name to import weights from.

    Example:
        >>> trainer = MyTrainer()
        >>> trainer.import_policy_model_from_h5("/tmp/weights.h5")
        >>> for _ in range(10):
        >>>     trainer.train()
    """
    raise NotImplementedError

is_time_major(self)

If True, data for calling this ModelV2 must be in time-major format.

Returns:

Type Description
bool

Whether this ModelV2 requires a time-major (TxBx...) data format.

Source code in ray/rllib/models/modelv2.py
@PublicAPI
def is_time_major(self) -> bool:
    """If True, data for calling this ModelV2 must be in time-major format.

    Returns:
        bool: Whether this ModelV2 requires a time-major (TxBx...) data
            format.
    """
    return self.time_major is True
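
To make the two layouts concrete, the sketch below transposes batch-major data ([B, T, ...]) into time-major data ([T, B, ...]). It is an illustration only: nested lists stand in for tensors, all sequences are assumed to have equal (padded) length, and `to_time_major` is a hypothetical helper, not an RLlib function.

```python
def to_time_major(batch_major):
    """Turn a [B][T] nested list into [T][B] by swapping the two axes."""
    return [list(step) for step in zip(*batch_major)]


# Two sequences (B=2) of three timesteps each (T=3).
batch_major = [["a0", "a1", "a2"], ["b0", "b1", "b2"]]
time_major = to_time_major(batch_major)
```

In the time-major result, each outer entry holds one timestep across all sequences in the batch.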

last_output(self)

Returns the last output returned from calling the model.

Source code in ray/rllib/models/modelv2.py
@PublicAPI
def last_output(self) -> TensorType:
    """Returns the last output returned from calling the model."""
    return self._last_output

metrics(self)

Override to return custom metrics from your model.

The stats will be reported as part of the learner stats, i.e., info.learner.[policy_id, e.g. "default_policy"].model.key1=metric1

Returns:

Type Description
Dict[str, TensorType]

The custom metrics for this model.

Source code in ray/rllib/models/modelv2.py
@PublicAPI
def metrics(self) -> Dict[str, TensorType]:
    """Override to return custom metrics from your model.

    The stats will be reported as part of the learner stats, i.e.,
    info.learner.[policy_id, e.g. "default_policy"].model.key1=metric1

    Returns:
        Dict[str, TensorType]: The custom metrics for this model.
    """
    return {}

trainable_variables(self, as_dict=False)

Returns the list of trainable variables for this model.

Parameters:

Name Type Description Default
as_dict bool

Whether variables should be returned as dict-values (using descriptive keys).

False

Returns:

Type Description
Union[List[any],Dict[str,any]]

The list (or dict if as_dict is True) of all trainable (tf)/requires_grad (torch) variables of this ModelV2.

Source code in ray/rllib/models/modelv2.py
@PublicAPI
def trainable_variables(
        self, as_dict: bool = False
) -> Union[List[TensorType], Dict[str, TensorType]]:
    """Returns the list of trainable variables for this model.

    Args:
        as_dict(bool): Whether variables should be returned as dict-values
            (using descriptive keys).

    Returns:
        Union[List[any],Dict[str,any]]: The list (or dict if `as_dict` is
            True) of all trainable (tf)/requires_grad (torch) variables
            of this ModelV2.
    """
    raise NotImplementedError

value_function(self)

Returns the value function output for the most recent forward pass.

Note that a forward call has to be performed first, before this method can return anything. Calling this method therefore does not cause an extra forward pass through the network.

Returns:

Type Description
Any

value estimate tensor of shape [BATCH].

Source code in ray/rllib/models/modelv2.py
@PublicAPI
def value_function(self) -> TensorType:
    """Returns the value function output for the most recent forward pass.

    Note that a `forward` call has to be performed first, before this
    method can return anything. Calling this method therefore does not
    cause an extra forward pass through the network.

    Returns:
        value estimate tensor of shape [BATCH].
    """
    raise NotImplementedError
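
The contract between forward() and value_function() can be sketched with a toy class that caches the value estimate during the forward pass. Everything here is an assumption made for illustration: plain lists stand in for tensors, `ToyValueModel` and its arithmetic are hypothetical, and raising RuntimeError before the first forward pass is a choice of this sketch rather than documented RLlib behavior.

```python
class ToyValueModel:
    """Hypothetical model; plain lists stand in for tensors."""

    def __init__(self):
        self._value_out = None
        self.forward_calls = 0

    def forward(self, input_dict, state, seq_lens):
        self.forward_calls += 1
        obs = input_dict["obs"]  # shape [BATCH, obs_dim]
        model_out = [[2.0 * x for x in row] for row in obs]
        # Cache one value estimate per batch row for value_function().
        self._value_out = [sum(row) for row in obs]
        return model_out, state

    def value_function(self):
        if self._value_out is None:
            raise RuntimeError("call forward() before value_function()")
        return self._value_out  # shape [BATCH]


model = ToyValueModel()
out, state = model.forward({"obs": [[1.0, 2.0], [3.0, 4.0]]}, [], None)
values = model.value_function()
```

Because value_function() only reads the cached result, calling it leaves the forward-pass counter unchanged.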

variables(self, as_dict=False)

Returns the list (or a dict) of variables for this model.

Parameters:

Name Type Description Default
as_dict bool

Whether variables should be returned as dict-values (using descriptive str keys).

False

Returns:

Type Description
Union[List[any],Dict[str,any]]

The list (or dict if as_dict is True) of all variables of this ModelV2.

Source code in ray/rllib/models/modelv2.py
@PublicAPI
def variables(self, as_dict: bool = False
              ) -> Union[List[TensorType], Dict[str, TensorType]]:
    """Returns the list (or a dict) of variables for this model.

    Args:
        as_dict(bool): Whether variables should be returned as dict-values
            (using descriptive str keys).

    Returns:
        Union[List[any],Dict[str,any]]: The list (or dict if `as_dict` is
            True) of all variables of this ModelV2.
    """
    raise NotImplementedError

ray.rllib.models.preprocessors.Preprocessor

Defines an abstract observation preprocessor function.

Attributes:

Name Type Description
shape List[int]

Shape of the preprocessed output.

check_shape(self, observation)

Checks the shape of the given observation.

Source code in ray/rllib/models/preprocessors.py
def check_shape(self, observation: Any) -> None:
    """Checks the shape of the given observation."""
    if self._i % OBS_VALIDATION_INTERVAL == 0:
        # Convert lists to np.ndarrays.
        if type(observation) is list and isinstance(
                self._obs_space, gym.spaces.Box):
            observation = np.array(observation).astype(np.float32)
        if not self._obs_space.contains(observation):
            observation = convert_element_to_space_type(
                observation, self._obs_for_type_matching)
        try:
            if not self._obs_space.contains(observation):
                raise ValueError(
                    "Observation ({} dtype={}) outside given space ({})!",
                    observation, observation.dtype if isinstance(
                        self._obs_space,
                        gym.spaces.Box) else None, self._obs_space)
        except AttributeError:
            raise ValueError(
                "Observation for a Box/MultiBinary/MultiDiscrete space "
                "should be an np.array, not a Python list.", observation)
    self._i += 1
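
The listing validates only every OBS_VALIDATION_INTERVAL-th observation, which keeps the per-step cost low. The sketch below isolates that counter logic; the interval value of 100, the `IntervalChecker` class, and the `is_valid` callback are assumptions made for illustration.

```python
OBS_VALIDATION_INTERVAL = 100  # assumed value, for illustration only


class IntervalChecker:
    """Hypothetical helper that validates every N-th observation."""

    def __init__(self, is_valid):
        self._is_valid = is_valid
        self._i = 0
        self.checks_run = 0

    def check(self, observation):
        # Validate on calls 0, N, 2N, ... and skip all others.
        if self._i % OBS_VALIDATION_INTERVAL == 0:
            self.checks_run += 1
            if not self._is_valid(observation):
                raise ValueError(
                    "Observation ({}) outside given space!".format(
                        observation))
        self._i += 1


checker = IntervalChecker(is_valid=lambda obs: 0.0 <= obs <= 1.0)
for _ in range(250):
    checker.check(0.5)
```

One consequence of this design is that an invalid observation arriving between two checks passes unnoticed.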

transform(self, observation)

Returns the preprocessed observation.

Source code in ray/rllib/models/preprocessors.py
@PublicAPI
def transform(self, observation: TensorType) -> np.ndarray:
    """Returns the preprocessed observation."""
    raise NotImplementedError

write(self, observation, array, offset)

Alternative to transform for more efficient flattening.

Source code in ray/rllib/models/preprocessors.py
def write(self, observation: TensorType, array: np.ndarray,
          offset: int) -> None:
    """Alternative to transform for more efficient flattening."""
    array[offset:offset + self._size] = self.transform(observation)
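
The point of write() is that several component observations can be packed into one preallocated flat buffer, each at its own offset. A minimal sketch of that idea follows; plain lists stand in for np.ndarray, and `ToyFlatten` is a hypothetical preprocessor, not an RLlib class.

```python
class ToyFlatten:
    """Hypothetical preprocessor that flattens a 2D nested list."""

    def __init__(self, rows, cols):
        self._size = rows * cols

    def transform(self, observation):
        return [x for row in observation for x in row]

    def write(self, observation, array, offset):
        # Same slice assignment as in the listing above.
        array[offset:offset + self._size] = self.transform(observation)


# Pack two component observations into one preallocated flat buffer.
first, second = ToyFlatten(1, 2), ToyFlatten(2, 2)
buffer = [0.0] * (first._size + second._size)
first.write([[1.0, 2.0]], buffer, 0)
second.write([[3.0, 4.0], [5.0, 6.0]], buffer, first._size)
```

Each component writes into its own slice, so the buffer is allocated once rather than built by concatenating intermediate results.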