models
Package API Reference
ray.rllib.models.action_dist.ActionDistribution
The policy action distribution of an agent.
Attributes:
Name | Type | Description |
---|---|---|
inputs | Tensors | input vector to compute samples from. |
model | ModelV2 | reference to model producing the inputs. |
__init__(self, inputs, model)
special
Initializes an ActionDist object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
inputs | Tensors | input vector to compute samples from. | required |
model | ModelV2 | reference to model producing the inputs. This is mainly useful if you want to use model variables to compute action outputs (i.e., for auto-regressive action distributions, see examples/autoregressive_action_dist.py). | required |
Source code in ray/rllib/models/action_dist.py
@DeveloperAPI
def __init__(self, inputs: List[TensorType], model: ModelV2):
"""Initializes an ActionDist object.
Args:
inputs (Tensors): input vector to compute samples from.
model (ModelV2): reference to model producing the inputs. This
is mainly useful if you want to use model variables to compute
action outputs (i.e., for auto-regressive action distributions,
see examples/autoregressive_action_dist.py).
"""
self.inputs = inputs
self.model = model
deterministic_sample(self)
Get the deterministic "sampling" output from the distribution. This is usually the max-likelihood output, i.e., the mean for Normal, the argmax for Categorical, etc.
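A minimal sketch (not from the original reference) of stochastic vs. deterministic sampling, using RLlib's built-in TorchCategorical; the logits tensor below is a made-up stand-in for real model output:

import torch
from ray.rllib.models.torch.torch_action_dist import TorchCategorical

logits = torch.tensor([[0.5, 2.0, -1.0]])    # hypothetical [batch=1, num_actions=3] model output
dist = TorchCategorical(logits, model=None)

stochastic_action = dist.sample()             # draws from softmax(logits)
greedy_action = dist.deterministic_sample()   # argmax of the logits
log_prob = dist.logp(greedy_action)           # log-likelihood of that action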
entropy(self)
kl(self, other)
logp(self, x)
multi_entropy(self)
The entropy of the action distribution.
This differs from entropy() in that it can return an array for MultiDiscrete. TODO(ekl) consider removing this.
multi_kl(self, other)
The KL-divergence between two action distributions.
This differs from kl() in that it can return an array for MultiDiscrete. TODO(ekl) consider removing this.
required_model_output_shape(action_space, model_config)
staticmethod
Returns the required shape of an input parameter tensor for a particular action space and an optional dict of distribution-specific options.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
action_space | gym.Space | The action space this distribution will be used for, whose shape attributes will be used to determine the required shape of the input parameter tensor. | required |
model_config | dict | Model's config dict (as defined in catalog.py). | required |
Returns:
Type | Description |
---|---|
model_output_shape (int or np.ndarray of ints) | Size of the required input vector (minus leading batch dimension). |
Source code in ray/rllib/models/action_dist.py
@DeveloperAPI
@staticmethod
def required_model_output_shape(
action_space: gym.Space,
model_config: ModelConfigDict) -> Union[int, np.ndarray]:
"""Returns the required shape of an input parameter tensor for a
particular action space and an optional dict of distribution-specific
options.
Args:
action_space (gym.Space): The action space this distribution will
be used for, whose shape attributes will be used to determine
the required shape of the input parameter tensor.
model_config (dict): Model's config dict (as defined in catalog.py)
Returns:
model_output_shape (int or np.ndarray of ints): size of the
required input vector (minus leading batch dimension).
"""
raise NotImplementedError
sample(self)
ray.rllib.models.catalog.ModelCatalog
Registry of models, preprocessors, and action distributions for envs.
Examples:
>>> dist_class, dist_dim = ModelCatalog.get_action_dist(
... env.action_space, {})
>>> model = ModelCatalog.get_model_v2(
... obs_space, action_space, num_outputs, options)
>>> dist = dist_class(model.outputs, model)
>>> action = dist.sample()
get_action_dist(action_space, config, dist_type=None, framework='tf', **kwargs)
staticmethod
Returns a distribution class and size for the given action space.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
action_space | Space | Action space of the target gym env. | required |
config | Optional[dict] | Optional model config. | required |
dist_type | Optional[Union[str, Type[ActionDistribution]]] | Identifier of the action distribution (str), interpreted as a hint, or the actual ActionDistribution class to use. | None |
framework | str | One of "tf2", "tf", "tfe", "torch", or "jax". | 'tf' |
kwargs | dict | Optional kwargs to pass on to the Distribution's constructor. | {} |
Returns:
Type | Description |
---|---|
Tuple | A tuple of (dist_class, dist_dim): the ActionDistribution class to use and the size of the input vector the distribution expects. |
Source code in ray/rllib/models/catalog.py
@staticmethod
@DeveloperAPI
def get_action_dist(
action_space: gym.Space,
config: ModelConfigDict,
dist_type: Optional[Union[str, Type[ActionDistribution]]] = None,
framework: str = "tf",
**kwargs) -> (type, int):
"""Returns a distribution class and size for the given action space.
Args:
action_space (Space): Action space of the target gym env.
config (Optional[dict]): Optional model config.
dist_type (Optional[Union[str, Type[ActionDistribution]]]):
Identifier of the action distribution (str) interpreted as a
hint or the actual ActionDistribution class to use.
framework (str): One of "tf2", "tf", "tfe", "torch", or "jax".
kwargs (dict): Optional kwargs to pass on to the Distribution's
constructor.
Returns:
Tuple:
- dist_class (ActionDistribution): Python class of the
distribution.
- dist_dim (int): The size of the input vector to the
distribution.
"""
dist_cls = None
config = config or MODEL_DEFAULTS
# Custom distribution given.
if config.get("custom_action_dist"):
custom_action_config = config.copy()
action_dist_name = custom_action_config.pop("custom_action_dist")
logger.debug(
"Using custom action distribution {}".format(action_dist_name))
dist_cls = _global_registry.get(RLLIB_ACTION_DIST,
action_dist_name)
return ModelCatalog._get_multi_action_distribution(
dist_cls, action_space, custom_action_config, framework)
# Dist_type is given directly as a class.
elif type(dist_type) is type and \
issubclass(dist_type, ActionDistribution) and \
dist_type not in (
MultiActionDistribution, TorchMultiActionDistribution):
dist_cls = dist_type
# Box space -> DiagGaussian OR Deterministic.
elif isinstance(action_space, Box):
if action_space.dtype.name.startswith("int"):
low_ = np.min(action_space.low)
high_ = np.max(action_space.high)
dist_cls = TorchMultiCategorical if framework == "torch" \
else MultiCategorical
num_cats = int(np.product(action_space.shape))
return partial(
dist_cls,
input_lens=[high_ - low_ + 1 for _ in range(num_cats)],
action_space=action_space), num_cats * (high_ - low_ + 1)
else:
if len(action_space.shape) > 1:
raise UnsupportedSpaceException(
"Action space has multiple dimensions "
"{}. ".format(action_space.shape) +
"Consider reshaping this into a single dimension, "
"using a custom action distribution, "
"using a Tuple action space, or the multi-agent API.")
# TODO(sven): Check for bounds and return SquashedNormal, etc..
if dist_type is None:
dist_cls = TorchDiagGaussian if framework == "torch" \
else DiagGaussian
elif dist_type == "deterministic":
dist_cls = TorchDeterministic if framework == "torch" \
else Deterministic
# Discrete Space -> Categorical.
elif isinstance(action_space, Discrete):
dist_cls = TorchCategorical if framework == "torch" else \
JAXCategorical if framework == "jax" else Categorical
# Tuple/Dict Spaces -> MultiAction.
elif dist_type in (MultiActionDistribution,
TorchMultiActionDistribution) or \
isinstance(action_space, (Tuple, Dict)):
return ModelCatalog._get_multi_action_distribution(
(MultiActionDistribution
if framework == "tf" else TorchMultiActionDistribution),
action_space, config, framework)
# Simplex -> Dirichlet.
elif isinstance(action_space, Simplex):
if framework == "torch":
# TODO(sven): implement
raise NotImplementedError(
"Simplex action spaces not supported for torch.")
dist_cls = Dirichlet
# MultiDiscrete -> MultiCategorical.
elif isinstance(action_space, MultiDiscrete):
dist_cls = TorchMultiCategorical if framework == "torch" else \
MultiCategorical
return partial(dist_cls, input_lens=action_space.nvec), \
int(sum(action_space.nvec))
# Unknown type -> Error.
else:
raise NotImplementedError("Unsupported args: {} {}".format(
action_space, dist_type))
return dist_cls, dist_cls.required_model_output_shape(
action_space, config)
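As a usage sketch (not part of the original docs), fetching the distribution class and its required input size for a simple Discrete action space; passing an empty config falls back to MODEL_DEFAULTS:

import gym
from ray.rllib.models import ModelCatalog

dist_class, dist_dim = ModelCatalog.get_action_dist(
    gym.spaces.Discrete(4), config={}, framework="torch")
# dist_class -> TorchCategorical, dist_dim -> 4 (one logit per discrete action).
# The model should therefore output tensors of shape [BATCH, dist_dim], which are
# then wrapped via: dist = dist_class(model_out, model); action = dist.sample()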
get_action_placeholder(action_space, name='action')
staticmethod
Returns an action placeholder consistent with the action space
Parameters:
Name | Type | Description | Default |
---|---|---|---|
action_space | Space | Action space of the target gym env. | required |
name | str | An optional string to name the placeholder by. Default: "action". | 'action' |
Returns:
Type | Description |
---|---|
action_placeholder (Tensor) | A placeholder for the actions. |
Source code in ray/rllib/models/catalog.py
@staticmethod
@DeveloperAPI
def get_action_placeholder(action_space: gym.Space,
name: str = "action") -> TensorType:
"""Returns an action placeholder consistent with the action space
Args:
action_space (Space): Action space of the target gym env.
name (str): An optional string to name the placeholder by.
Default: "action".
Returns:
action_placeholder (Tensor): A placeholder for the actions
"""
dtype, shape = ModelCatalog.get_action_shape(
action_space, framework="tf")
return tf1.placeholder(dtype, shape=shape, name=name)
get_action_shape(action_space, framework='tf')
staticmethod
Returns action tensor dtype and shape for the action space.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
action_space | Space | Action space of the target gym env. | required |
framework | str | The framework identifier. One of "tf" or "torch". | 'tf' |
Returns:
Type | Description |
---|---|
(dtype, shape) | Dtype and shape of the actions tensor. |
Source code in ray/rllib/models/catalog.py
@staticmethod
@DeveloperAPI
def get_action_shape(action_space: gym.Space,
framework: str = "tf") -> (np.dtype, List[int]):
"""Returns action tensor dtype and shape for the action space.
Args:
action_space (Space): Action space of the target gym env.
framework (str): The framework identifier. One of "tf" or "torch".
Returns:
(dtype, shape): Dtype and shape of the actions tensor.
"""
dl_lib = torch if framework == "torch" else tf
if isinstance(action_space, Discrete):
return action_space.dtype, (None, )
elif isinstance(action_space, (Box, Simplex)):
if np.issubdtype(action_space.dtype, np.floating):
return dl_lib.float32, (None, ) + action_space.shape
elif np.issubdtype(action_space.dtype, np.integer):
return dl_lib.int32, (None, ) + action_space.shape
else:
raise ValueError(
"RLlib doesn't support non int or float box spaces")
elif isinstance(action_space, MultiDiscrete):
return action_space.dtype, (None, ) + action_space.shape
elif isinstance(action_space, (Tuple, Dict)):
flat_action_space = flatten_space(action_space)
size = 0
all_discrete = True
for i in range(len(flat_action_space)):
if isinstance(flat_action_space[i], Discrete):
size += 1
else:
all_discrete = False
size += np.product(flat_action_space[i].shape)
size = int(size)
return dl_lib.int32 if all_discrete else dl_lib.float32, \
(None, size)
else:
raise NotImplementedError(
"Action space {} not supported".format(action_space))
get_model_v2(obs_space, action_space, num_outputs, model_config, framework='tf', name='default_model', model_interface=None, default_model=None, **model_kwargs)
staticmethod
Returns a suitable model compatible with given spaces and output.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
obs_space | Space | Observation space of the target gym env. This may have an `original_space` attribute that specifies how to unflatten the tensor into a ragged tensor. | required |
action_space | Space | Action space of the target gym env. | required |
num_outputs | int | The size of the output vector of the model. | required |
model_config | ModelConfigDict | The "model" sub-config dict within the Trainer's config dict. | required |
framework | str | One of "tf2", "tf", "tfe", "torch", or "jax". | 'tf' |
name | str | Name (scope) for the model. | 'default_model' |
model_interface | cls | Interface required for the model. | None |
default_model | cls | Override the default class for the model. This only has an effect when not using a custom model. | None |
model_kwargs | dict | Args to pass to the ModelV2 constructor. | {} |
Returns:
Type | Description |
---|---|
model (ModelV2) | Model to use for the policy. |
Source code in ray/rllib/models/catalog.py
@staticmethod
@DeveloperAPI
def get_model_v2(obs_space: gym.Space,
action_space: gym.Space,
num_outputs: int,
model_config: ModelConfigDict,
framework: str = "tf",
name: str = "default_model",
model_interface: type = None,
default_model: type = None,
**model_kwargs) -> ModelV2:
"""Returns a suitable model compatible with given spaces and output.
Args:
obs_space (Space): Observation space of the target gym env. This
may have an `original_space` attribute that specifies how to
unflatten the tensor into a ragged tensor.
action_space (Space): Action space of the target gym env.
num_outputs (int): The size of the output vector of the model.
model_config (ModelConfigDict): The "model" sub-config dict
within the Trainer's config dict.
framework (str): One of "tf2", "tf", "tfe", "torch", or "jax".
name (str): Name (scope) for the model.
model_interface (cls): Interface required for the model
default_model (cls): Override the default class for the model. This
only has an effect when not using a custom model
model_kwargs (dict): args to pass to the ModelV2 constructor
Returns:
model (ModelV2): Model to use for the policy.
"""
# Validate the given config dict.
ModelCatalog._validate_config(config=model_config, framework=framework)
if model_config.get("custom_model"):
# Allow model kwargs to be overridden / augmented by
# custom_model_config.
customized_model_kwargs = dict(
model_kwargs, **model_config.get("custom_model_config", {}))
if isinstance(model_config["custom_model"], type):
model_cls = model_config["custom_model"]
else:
model_cls = _global_registry.get(RLLIB_MODEL,
model_config["custom_model"])
# Only allow ModelV2 or native keras Models.
if not issubclass(model_cls, ModelV2):
if framework not in ["tf", "tf2", "tfe"] or \
not issubclass(model_cls, tf.keras.Model):
raise ValueError(
"`model_cls` must be a ModelV2 sub-class, but is"
" {}!".format(model_cls))
logger.info("Wrapping {} as {}".format(model_cls, model_interface))
model_cls = ModelCatalog._wrap_if_needed(model_cls,
model_interface)
if framework in ["tf2", "tf", "tfe"]:
# Try wrapping custom model with LSTM/attention, if required.
if model_config.get("use_lstm") or \
model_config.get("use_attention"):
from ray.rllib.models.tf.attention_net import \
AttentionWrapper, Keras_AttentionWrapper
from ray.rllib.models.tf.recurrent_net import \
LSTMWrapper, Keras_LSTMWrapper
wrapped_cls = model_cls
# Wrapped (custom) model is itself a keras Model ->
# wrap with keras LSTM/GTrXL (attention) wrappers.
if issubclass(wrapped_cls, tf.keras.Model):
model_cls = Keras_LSTMWrapper if \
model_config.get("use_lstm") else \
Keras_AttentionWrapper
model_config["wrapped_cls"] = wrapped_cls
# Wrapped (custom) model is ModelV2 ->
# wrap with ModelV2 LSTM/GTrXL (attention) wrappers.
else:
forward = wrapped_cls.forward
model_cls = ModelCatalog._wrap_if_needed(
wrapped_cls, LSTMWrapper if
model_config.get("use_lstm") else AttentionWrapper)
model_cls._wrapped_forward = forward
# Obsolete: Track and warn if vars were created but not
# registered. Only still do this, if users do register their
# variables. If not (which they shouldn't), don't check here.
created = set()
def track_var_creation(next_creator, **kw):
v = next_creator(**kw)
created.add(v)
return v
with tf.variable_creator_scope(track_var_creation):
if issubclass(model_cls, tf.keras.Model):
instance = model_cls(
input_space=obs_space,
action_space=action_space,
num_outputs=num_outputs,
name=name,
**customized_model_kwargs,
)
else:
# Try calling with kwargs first (custom ModelV2 should
# accept these as kwargs, not get them from
# config["custom_model_config"] anymore).
try:
instance = model_cls(
obs_space,
action_space,
num_outputs,
model_config,
name,
**customized_model_kwargs,
)
except TypeError as e:
# Keyword error: Try old way w/o kwargs.
if "__init__() got an unexpected " in e.args[0]:
instance = model_cls(
obs_space,
action_space,
num_outputs,
model_config,
name,
**model_kwargs,
)
logger.warning(
"Custom ModelV2 should accept all custom "
"options as **kwargs, instead of expecting"
" them in config['custom_model_config']!")
# Other error -> re-raise.
else:
raise e
# User still registered TFModelV2's variables: Check, whether
# ok.
registered = []
if not isinstance(instance, tf.keras.Model):
registered = set(instance.var_list)
if len(registered) > 0:
not_registered = set()
for var in created:
if var not in registered:
not_registered.add(var)
if not_registered:
raise ValueError(
"It looks like you are still using "
"`{}.register_variables()` to register your "
"model's weights. This is no longer required, but "
"if you are still calling this method at least "
"once, you must make sure to register all created "
"variables properly. The missing variables are {},"
" and you only registered {}. "
"Did you forget to call `register_variables()` on "
"some of the variables in question?".format(
instance, not_registered, registered))
elif framework == "torch":
# Try wrapping custom model with LSTM/attention, if required.
if model_config.get("use_lstm") or \
model_config.get("use_attention"):
from ray.rllib.models.torch.attention_net import \
AttentionWrapper
from ray.rllib.models.torch.recurrent_net import \
LSTMWrapper
wrapped_cls = model_cls
forward = wrapped_cls.forward
model_cls = ModelCatalog._wrap_if_needed(
wrapped_cls, LSTMWrapper
if model_config.get("use_lstm") else AttentionWrapper)
model_cls._wrapped_forward = forward
# PyTorch automatically tracks nn.Modules inside the parent
# nn.Module's constructor.
# Try calling with kwargs first (custom ModelV2 should
# accept these as kwargs, not get them from
# config["custom_model_config"] anymore).
try:
instance = model_cls(obs_space, action_space, num_outputs,
model_config, name,
**customized_model_kwargs)
except TypeError as e:
# Keyword error: Try old way w/o kwargs.
if "__init__() got an unexpected " in e.args[0]:
instance = model_cls(obs_space, action_space,
num_outputs, model_config, name,
**model_kwargs)
logger.warning(
"Custom ModelV2 should accept all custom "
"options as **kwargs, instead of expecting"
" them in config['custom_model_config']!")
# Other error -> re-raise.
else:
raise e
else:
raise NotImplementedError(
"`framework` must be 'tf2|tf|tfe|torch', but is "
"{}!".format(framework))
return instance
# Find a default TFModelV2 and wrap with model_interface.
if framework in ["tf", "tfe", "tf2"]:
v2_class = None
# Try to get a default v2 model.
if not model_config.get("custom_model"):
v2_class = default_model or ModelCatalog._get_v2_model_class(
obs_space, model_config, framework=framework)
if not v2_class:
raise ValueError("ModelV2 class could not be determined!")
if model_config.get("use_lstm") or \
model_config.get("use_attention"):
from ray.rllib.models.tf.attention_net import \
AttentionWrapper, Keras_AttentionWrapper
from ray.rllib.models.tf.recurrent_net import LSTMWrapper, \
Keras_LSTMWrapper
wrapped_cls = v2_class
if model_config.get("use_lstm"):
if issubclass(wrapped_cls, tf.keras.Model):
v2_class = Keras_LSTMWrapper
model_config["wrapped_cls"] = wrapped_cls
else:
v2_class = ModelCatalog._wrap_if_needed(
wrapped_cls, LSTMWrapper)
v2_class._wrapped_forward = wrapped_cls.forward
else:
if issubclass(wrapped_cls, tf.keras.Model):
v2_class = Keras_AttentionWrapper
model_config["wrapped_cls"] = wrapped_cls
else:
v2_class = ModelCatalog._wrap_if_needed(
wrapped_cls, AttentionWrapper)
v2_class._wrapped_forward = wrapped_cls.forward
# Wrap in the requested interface.
wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
if issubclass(wrapper, tf.keras.Model):
model = wrapper(
input_space=obs_space,
action_space=action_space,
num_outputs=num_outputs,
name=name,
**dict(model_kwargs, **model_config),
)
return model
return wrapper(obs_space, action_space, num_outputs, model_config,
name, **model_kwargs)
# Find a default TorchModelV2 and wrap with model_interface.
elif framework == "torch":
# Try to get a default v2 model.
if not model_config.get("custom_model"):
v2_class = default_model or ModelCatalog._get_v2_model_class(
obs_space, model_config, framework=framework)
if not v2_class:
raise ValueError("ModelV2 class could not be determined!")
if model_config.get("use_lstm") or \
model_config.get("use_attention"):
from ray.rllib.models.torch.attention_net import \
AttentionWrapper
from ray.rllib.models.torch.recurrent_net import LSTMWrapper
wrapped_cls = v2_class
forward = wrapped_cls.forward
if model_config.get("use_lstm"):
v2_class = ModelCatalog._wrap_if_needed(
wrapped_cls, LSTMWrapper)
else:
v2_class = ModelCatalog._wrap_if_needed(
wrapped_cls, AttentionWrapper)
v2_class._wrapped_forward = forward
# Wrap in the requested interface.
wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
return wrapper(obs_space, action_space, num_outputs, model_config,
name, **model_kwargs)
# Find a default JAXModelV2 and wrap with model_interface.
elif framework == "jax":
v2_class = \
default_model or ModelCatalog._get_v2_model_class(
obs_space, model_config, framework=framework)
# Wrap in the requested interface.
wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
return wrapper(obs_space, action_space, num_outputs, model_config,
name, **model_kwargs)
else:
raise NotImplementedError(
"`framework` must be 'tf2|tf|tfe|torch', but is "
"{}!".format(framework))
get_preprocessor(env, options=None)
staticmethod
Returns a suitable preprocessor for the given env.
This is a wrapper for get_preprocessor_for_space().
Source code in ray/rllib/models/catalog.py
@staticmethod
@DeveloperAPI
def get_preprocessor(env: gym.Env,
options: Optional[dict] = None) -> Preprocessor:
"""Returns a suitable preprocessor for the given env.
This is a wrapper for get_preprocessor_for_space().
"""
return ModelCatalog.get_preprocessor_for_space(env.observation_space,
options)
get_preprocessor_for_space(observation_space, options=None)
staticmethod
Returns a suitable preprocessor for the given observation space.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
observation_space | Space | The input observation space. | required |
options | dict | Options to pass to the preprocessor. | None |
Returns:
Type | Description |
---|---|
preprocessor (Preprocessor) | Preprocessor for the observations. |
Source code in ray/rllib/models/catalog.py
@staticmethod
@DeveloperAPI
def get_preprocessor_for_space(observation_space: gym.Space,
options: dict = None) -> Preprocessor:
"""Returns a suitable preprocessor for the given observation space.
Args:
observation_space (Space): The input observation space.
options (dict): Options to pass to the preprocessor.
Returns:
preprocessor (Preprocessor): Preprocessor for the observations.
"""
options = options or MODEL_DEFAULTS
for k in options.keys():
if k not in MODEL_DEFAULTS:
raise Exception("Unknown config key `{}`, all keys: {}".format(
k, list(MODEL_DEFAULTS)))
if options.get("custom_preprocessor"):
preprocessor = options["custom_preprocessor"]
logger.info("Using custom preprocessor {}".format(preprocessor))
logger.warning(
"DeprecationWarning: Custom preprocessors are deprecated, "
"since they sometimes conflict with the built-in "
"preprocessors for handling complex observation spaces. "
"Please use wrapper classes around your environment "
"instead of preprocessors.")
prep = _global_registry.get(RLLIB_PREPROCESSOR, preprocessor)(
observation_space, options)
else:
cls = get_preprocessor(observation_space)
prep = cls(observation_space, options)
if prep is not None:
logger.debug("Created preprocessor {}: {} -> {}".format(
prep, observation_space, prep.shape))
return prep
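A hedged example with an arbitrary Dict observation space; the built-in preprocessor flattens it (one-hot encoding the Discrete component):

import gym
from ray.rllib.models import ModelCatalog

obs_space = gym.spaces.Dict({
    "pos": gym.spaces.Box(-1.0, 1.0, shape=(2,)),
    "goal": gym.spaces.Discrete(3),
})
prep = ModelCatalog.get_preprocessor_for_space(obs_space)
flat_obs = prep.transform(obs_space.sample())
# flat_obs is a flat np.ndarray of shape prep.shape (here (5,): 3 one-hot + 2 floats).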
register_custom_action_dist(action_dist_name, action_dist_class)
staticmethod
Register a custom action distribution class by name.
The action distribution can later be used by specifying {"custom_action_dist": action_dist_name} in the model config.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
action_dist_name | str | Name to register the action distribution under. | required |
action_dist_class | type | Python class of the action distribution. | required |
Source code in ray/rllib/models/catalog.py
@staticmethod
@PublicAPI
def register_custom_action_dist(action_dist_name: str,
action_dist_class: type) -> None:
"""Register a custom action distribution class by name.
The model can be later used by specifying
{"custom_action_dist": action_dist_name} in the model config.
Args:
model_name (str): Name to register the action distribution under.
model_class (type): Python class of the action distribution.
"""
_global_registry.register(RLLIB_ACTION_DIST, action_dist_name,
action_dist_class)
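A hypothetical registration sketch; BoltzmannCategorical and the "boltzmann_cat" key are made up for illustration, and the subclass simply reuses TorchCategorical with a fixed temperature:

from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_action_dist import TorchCategorical

class BoltzmannCategorical(TorchCategorical):
    # Samples at a fixed, higher temperature than the standard Categorical.
    def __init__(self, inputs, model=None):
        super().__init__(inputs, model, temperature=2.0)

ModelCatalog.register_custom_action_dist("boltzmann_cat", BoltzmannCategorical)
# Trainer config sketch: {"model": {"custom_action_dist": "boltzmann_cat"}}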
register_custom_model(model_name, model_class)
staticmethod
Register a custom model class by name.
The model can later be used by specifying {"custom_model": model_name} in the model config.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_name | str | Name to register the model under. | required |
model_class | type | Python class of the model. | required |
Source code in ray/rllib/models/catalog.py
@staticmethod
@PublicAPI
def register_custom_model(model_name: str, model_class: type) -> None:
"""Register a custom model class by name.
The model can be later used by specifying {"custom_model": model_name}
in the model config.
Args:
model_name (str): Name to register the model under.
model_class (type): Python class of the model.
"""
if tf is not None:
if issubclass(model_class, tf.keras.Model):
deprecation_warning(old="register_custom_model", error=False)
_global_registry.register(RLLIB_MODEL, model_name, model_class)
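A hypothetical end-to-end sketch of a custom Torch model; TinyTorchModel and the "tiny_model" key are invented here, but the class follows the standard TorchModelV2 pattern (forward() stores features, value_function() reuses them):

import numpy as np
import torch
import torch.nn as nn
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2

class TinyTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        in_size = int(np.product(obs_space.shape))
        self.hidden = nn.Linear(in_size, 64)
        self.logits = nn.Linear(64, num_outputs)
        self.value_branch = nn.Linear(64, 1)
        self._features = None

    def forward(self, input_dict, state, seq_lens):
        # input_dict["obs_flat"] is the flattened observation batch.
        self._features = torch.relu(self.hidden(input_dict["obs_flat"]))
        return self.logits(self._features), state

    def value_function(self):
        # Reuses the features computed in the most recent forward() call.
        return self.value_branch(self._features).squeeze(1)

ModelCatalog.register_custom_model("tiny_model", TinyTorchModel)
# Trainer config sketch: {"model": {"custom_model": "tiny_model"}}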
ray.rllib.models.modelv2.ModelV2
Defines an abstract neural network model for use with RLlib.
Custom models should extend either TFModelV2 or TorchModelV2 instead of this class directly.
Data flow: obs -> forward() -> model_out; value_function() -> V(s)
__init__(self, obs_space, action_space, num_outputs, model_config, name, framework)
special
Initializes a ModelV2 object.
This method should create any variables used by the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
obs_space | gym.spaces.Space | Observation space of the target gym env. This may have an `original_space` attribute that specifies how to unflatten the tensor into a ragged tensor. | required |
action_space | gym.spaces.Space | Action space of the target gym env. | required |
num_outputs | int | Number of output units of the model. | required |
model_config | ModelConfigDict | Config for the model, documented in ModelCatalog. | required |
name | str | Name (scope) for the model. | required |
framework | str | Either "tf" or "torch". | required |
Source code in ray/rllib/models/modelv2.py
def __init__(self, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space, num_outputs: int,
model_config: ModelConfigDict, name: str, framework: str):
"""Initializes a ModelV2 object.
This method should create any variables used by the model.
Args:
obs_space (gym.spaces.Space): Observation space of the target gym
env. This may have an `original_space` attribute that
specifies how to unflatten the tensor into a ragged tensor.
action_space (gym.spaces.Space): Action space of the target gym
env.
num_outputs (int): Number of output units of the model.
model_config (ModelConfigDict): Config for the model, documented
in ModelCatalog.
name (str): Name (scope) for the model.
framework (str): Either "tf" or "torch".
"""
self.obs_space: gym.spaces.Space = obs_space
self.action_space: gym.spaces.Space = action_space
self.num_outputs: int = num_outputs
self.model_config: ModelConfigDict = model_config
self.name: str = name or "default_model"
self.framework: str = framework
self._last_output = None
self.time_major = self.model_config.get("_time_major")
# Basic view requirement for all models: Use the observation as input.
self.view_requirements = {
SampleBatch.OBS: ViewRequirement(shift=0, space=self.obs_space),
}
context(self)
custom_loss(self, policy_loss, loss_inputs)
Override to customize the loss function used to optimize this model.
This can be used to incorporate self-supervised losses (by defining a loss over existing input and output tensors of this model), and supervised losses (by defining losses over a variable-sharing copy of this model's layers).
You can find a runnable example in examples/custom_loss.py.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
policy_loss | Union[List[Tensor], Tensor] | List of or single policy loss(es) from the policy. | required |
loss_inputs | dict | Map of input placeholders for rollout data. | required |
Returns:
Type | Description |
---|---|
Union[List[Tensor], Tensor] | List of or scalar tensor for the customized loss(es) for this model. |
Source code in ray/rllib/models/modelv2.py
@PublicAPI
def custom_loss(self, policy_loss: TensorType,
loss_inputs: Dict[str, TensorType]) -> TensorType:
"""Override to customize the loss function used to optimize this model.
This can be used to incorporate self-supervised losses (by defining
a loss over existing input and output tensors of this model), and
supervised losses (by defining losses over a variable-sharing copy of
this model's layers).
You can find an runnable example in examples/custom_loss.py.
Args:
policy_loss (Union[List[Tensor],Tensor]): List of or single policy
loss(es) from the policy.
loss_inputs (dict): map of input placeholders for rollout data.
Returns:
Union[List[Tensor],Tensor]: List of or scalar tensor for the
customized loss(es) for this model.
"""
return policy_loss
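A hypothetical override building on the TinyTorchModel sketched earlier (under register_custom_model); it adds a small L2 penalty on the hidden-layer weights to every policy loss term:

import torch

class TinyTorchModelWithAuxLoss(TinyTorchModel):
    def custom_loss(self, policy_loss, loss_inputs):
        l2_penalty = 1e-4 * torch.sum(self.hidden.weight ** 2)
        # policy_loss may be a single tensor or a list of per-tower losses.
        if isinstance(policy_loss, (list, tuple)):
            return [loss + l2_penalty for loss in policy_loss]
        return policy_loss + l2_penalty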
forward(self, input_dict, state, seq_lens)
Call the model with the given input tensors and state.
Any complex observations (dicts, tuples, etc.) will be unpacked by __call__ before being passed to forward(). To access the flattened observation tensor, refer to input_dict["obs_flat"].
This method can be called any number of times. In eager execution, each call to forward() will eagerly evaluate the model. In symbolic execution, each call to forward creates a computation graph that operates over the variables of this model (i.e., shares weights).
Custom models should override this instead of __call__.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_dict | dict | Dictionary of input tensors, including "obs", "obs_flat", "prev_action", "prev_reward", "is_training", "eps_id", "agent_id", "infos", and "t". | required |
state | list | List of state tensors with sizes matching those returned by get_initial_state + the batch dimension. | required |
seq_lens | Tensor | 1D tensor holding input sequence lengths. | required |
Returns:
Type | Description |
---|---|
(outputs, state) | The model output tensor of size [BATCH, num_outputs], and the new RNN state. |
Examples:
>>> def forward(self, input_dict, state, seq_lens):
>>> model_out, self._value_out = self.base_model(
... input_dict["obs"])
>>> return model_out, state
Source code in ray/rllib/models/modelv2.py
@PublicAPI
def forward(self, input_dict: Dict[str, TensorType],
state: List[TensorType],
seq_lens: TensorType) -> (TensorType, List[TensorType]):
"""Call the model with the given input tensors and state.
Any complex observations (dicts, tuples, etc.) will be unpacked by
__call__ before being passed to forward(). To access the flattened
observation tensor, refer to input_dict["obs_flat"].
This method can be called any number of times. In eager execution,
each call to forward() will eagerly evaluate the model. In symbolic
execution, each call to forward creates a computation graph that
operates over the variables of this model (i.e., shares weights).
Custom models should override this instead of __call__.
Args:
input_dict (dict): dictionary of input tensors, including "obs",
"obs_flat", "prev_action", "prev_reward", "is_training",
"eps_id", "agent_id", "infos", and "t".
state (list): list of state tensors with sizes matching those
returned by get_initial_state + the batch dimension
seq_lens (Tensor): 1d tensor holding input sequence lengths
Returns:
(outputs, state): The model output tensor of size
[BATCH, num_outputs], and the new RNN state.
Examples:
>>> def forward(self, input_dict, state, seq_lens):
>>> model_out, self._value_out = self.base_model(
... input_dict["obs"])
>>> return model_out, state
"""
raise NotImplementedError
get_initial_state(self)
Get the initial recurrent state values for the model.
Returns:
Type | Description |
---|---|
List[np.ndarray] | List of np.array objects containing the initial hidden state of an RNN, if applicable. |
Examples:
>>> def get_initial_state(self):
>>> return [
>>> np.zeros(self.cell_size, np.float32),
>>> np.zeros(self.cell_size, np.float32),
>>> ]
Source code in ray/rllib/models/modelv2.py
@PublicAPI
def get_initial_state(self) -> List[np.ndarray]:
"""Get the initial recurrent state values for the model.
Returns:
List[np.ndarray]: List of np.array objects containing the initial
hidden state of an RNN, if applicable.
Examples:
>>> def get_initial_state(self):
>>> return [
>>> np.zeros(self.cell_size, np.float32),
>>> np.zeros(self.cell_size, np.float32),
>>> ]
"""
return []
import_from_h5(self, h5_file)
Imports weights from an h5 file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
h5_file | str | The h5 file name to import weights from. | required |
Examples:
>>> trainer = MyTrainer()
>>> trainer.import_policy_model_from_h5("/tmp/weights.h5")
>>> for _ in range(10):
>>> trainer.train()
Source code in ray/rllib/models/modelv2.py
def import_from_h5(self, h5_file: str) -> None:
"""Imports weights from an h5 file.
Args:
h5_file (str): The h5 file name to import weights from.
Example:
>>> trainer = MyTrainer()
>>> trainer.import_policy_model_from_h5("/tmp/weights.h5")
>>> for _ in range(10):
>>> trainer.train()
"""
raise NotImplementedError
is_time_major(self)
If True, data for calling this ModelV2 must be in time-major format.
Returns:
Type | Description |
---|---|
bool | Whether this ModelV2 requires a time-major (TxBx...) data format. |
last_output(self)
metrics(self)
Override to return custom metrics from your model.
The stats will be reported as part of the learner stats, i.e., info.learner.[policy_id, e.g. "default_policy"].model.key1=metric1
Returns:
Type | Description |
---|---|
Dict[str, TensorType] | The custom metrics for this model. |
Source code in ray/rllib/models/modelv2.py
@PublicAPI
def metrics(self) -> Dict[str, TensorType]:
"""Override to return custom metrics from your model.
The stats will be reported as part of the learner stats, i.e.,
info.learner.[policy_id, e.g. "default_policy"].model.key1=metric1
Returns:
Dict[str, TensorType]: The custom metrics for this model.
"""
return {}
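A hypothetical override, again extending the TinyTorchModel sketch from above; the returned value shows up under info.learner.<policy_id>.model.hidden_mean in the training results:

import torch

class TinyTorchModelWithMetrics(TinyTorchModel):
    def metrics(self):
        # Report the mean activation of the last hidden layer as a learner stat.
        return {"hidden_mean": torch.mean(self._features)}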
trainable_variables(self, as_dict=False)
Returns the list of trainable variables for this model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
as_dict | bool | Whether variables should be returned as dict-values (using descriptive keys). | False |
Returns:
Type | Description |
---|---|
Union[List[any], Dict[str, any]] | The list (or dict if `as_dict` is True) of all trainable (tf) / requires_grad (torch) variables of this ModelV2. |
Source code in ray/rllib/models/modelv2.py
@PublicAPI
def trainable_variables(
self, as_dict: bool = False
) -> Union[List[TensorType], Dict[str, TensorType]]:
"""Returns the list of trainable variables for this model.
Args:
as_dict(bool): Whether variables should be returned as dict-values
(using descriptive keys).
Returns:
Union[List[any],Dict[str,any]]: The list (or dict if `as_dict` is
True) of all trainable (tf)/requires_grad (torch) variables
of this ModelV2.
"""
raise NotImplementedError
value_function(self)
Returns the value function output for the most recent forward pass.
Note that a `forward` call has to be performed first before this method can return anything; calling this method therefore does not cause an extra forward pass through the network.
Returns:
Type | Description |
---|---|
Any | Value estimate tensor of shape [BATCH]. |
Source code in ray/rllib/models/modelv2.py
@PublicAPI
def value_function(self) -> TensorType:
"""Returns the value function output for the most recent forward pass.
Note that a `forward` call has to be performed first, before this
methods can return anything and thus that calling this method does not
cause an extra forward pass through the network.
Returns:
value estimate tensor of shape [BATCH].
"""
raise NotImplementedError
variables(self, as_dict=False)
Returns the list (or a dict) of variables for this model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
as_dict | bool | Whether variables should be returned as dict-values (using descriptive str keys). | False |
Returns:
Type | Description |
---|---|
Union[List[any], Dict[str, any]] | The list (or dict if `as_dict` is True) of all variables of this ModelV2. |
Source code in ray/rllib/models/modelv2.py
@PublicAPI
def variables(self, as_dict: bool = False
) -> Union[List[TensorType], Dict[str, TensorType]]:
"""Returns the list (or a dict) of variables for this model.
Args:
as_dict(bool): Whether variables should be returned as dict-values
(using descriptive str keys).
Returns:
Union[List[any],Dict[str,any]]: The list (or dict if `as_dict` is
True) of all variables of this ModelV2.
"""
raise NotImplementedError
ray.rllib.models.preprocessors.Preprocessor
Defines an abstract observation preprocessor function.
Attributes:
Name | Type | Description |
---|---|---|
shape | List[int] | Shape of the preprocessed output. |
check_shape(self, observation)
Checks the shape of the given observation.
Source code in ray/rllib/models/preprocessors.py
def check_shape(self, observation: Any) -> None:
"""Checks the shape of the given observation."""
if self._i % OBS_VALIDATION_INTERVAL == 0:
# Convert lists to np.ndarrays.
if type(observation) is list and isinstance(
self._obs_space, gym.spaces.Box):
observation = np.array(observation).astype(np.float32)
if not self._obs_space.contains(observation):
observation = convert_element_to_space_type(
observation, self._obs_for_type_matching)
try:
if not self._obs_space.contains(observation):
raise ValueError(
"Observation ({} dtype={}) outside given space ({})!",
observation, observation.dtype if isinstance(
self._obs_space,
gym.spaces.Box) else None, self._obs_space)
except AttributeError:
raise ValueError(
"Observation for a Box/MultiBinary/MultiDiscrete space "
"should be an np.array, not a Python list.", observation)
self._i += 1
transform(self, observation)
write(self, observation, array, offset)
Alternative to transform for more efficient flattening.