
utils Package API Reference

ray.rllib.utils.annotations.override(cls)

Annotation for documenting method overrides.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| cls | type | The superclass that provides the overridden method. If `cls` does not actually have the method, a `NameError` is raised. | required |

Source code in ray/rllib/utils/annotations.py
def override(cls):
    """Annotation for documenting method overrides.

    Args:
        cls (type): The superclass that provides the overridden method. If this
            cls does not actually have the method, an error is raised.
    """

    def check_override(method):
        if method.__name__ not in dir(cls):
            raise NameError("{} does not override any method of {}".format(
                method, cls))
        return method

    return check_override
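The following is a usage sketch of `override`. To keep the block self-contained, it re-declares a local copy of the decorator as listed above instead of importing it from `ray.rllib`. The `Base`, `Child` and `define_bad_child` names are hypothetical.

```python
def override(cls):
    """Local copy of the decorator listed above (not imported from RLlib)."""

    def check_override(method):
        # Runs once, while the class body that uses the decorator executes.
        if method.__name__ not in dir(cls):
            raise NameError("{} does not override any method of {}".format(
                method, cls))
        return method

    return check_override


class Base:
    def compute(self, x):
        return x


class Child(Base):
    @override(Base)
    def compute(self, x):
        # Valid: Base defines `compute`, so the method is returned unchanged.
        return 2 * x


def define_bad_child():
    """Defines a class whose decorated method name does not exist on Base."""

    class BadChild(Base):
        @override(Base)
        def compuet(self, x):  # typo: Base has no `compuet`
            return x

    return BadChild


try:
    define_bad_child()
except NameError as e:
    print("Caught:", e)
```

The check runs at class-definition time. A misspelled method name therefore fails as soon as the class is defined, instead of silently creating a new, never-called method.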

ray.rllib.utils.annotations.PublicAPI(obj)

Annotation for documenting public APIs.

Public APIs are classes and methods exposed to end users of RLlib. You can expect these APIs to remain stable across RLlib releases.

Subclasses that inherit from a @PublicAPI base class can be assumed to be part of the RLlib public API as well (e.g., all trainer classes are in the public API because Trainer is @PublicAPI).

In addition, you can assume all trainer configurations are part of their public API as well.

Source code in ray/rllib/utils/annotations.py
def PublicAPI(obj):
    """Annotation for documenting public APIs.

    Public APIs are classes and methods exposed to end users of RLlib. You
    can expect these APIs to remain stable across RLlib releases.

    Subclasses that inherit from a ``@PublicAPI`` base class can be
    assumed part of the RLlib public API as well (e.g., all trainer classes
    are in public API because Trainer is ``@PublicAPI``).

    In addition, you can assume all trainer configurations are part of their
    public API as well.
    """

    return obj

ray.rllib.utils.annotations.DeveloperAPI(obj)

Annotation for documenting developer APIs.

Developer APIs are classes and methods explicitly exposed to developers for the purpose of building custom algorithms or advanced training strategies on top of RLlib internals. You can generally expect these APIs to be stable apart from minor changes (but less stable than public APIs).

Subclasses that inherit from a @DeveloperAPI base class can be assumed to be part of the RLlib developer API as well.

Source code in ray/rllib/utils/annotations.py
def DeveloperAPI(obj):
    """Annotation for documenting developer APIs.

    Developer APIs are classes and methods explicitly exposed to developers
    for the purposes of building custom algorithms or advanced training
    strategies on top of RLlib internals. You can generally expect these APIs
    to be stable sans minor changes (but less stable than public APIs).

    Subclasses that inherit from a ``@DeveloperAPI`` base class can be
    assumed part of the RLlib developer API as well.
    """

    return obj

ray.rllib.utils.framework.try_import_tf(error=False)

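Tries to import TensorFlow and returns the tf1.x module, the tf module, and the installed major version. Returns `(None, None, None)` if TensorFlow cannot be imported or if the `RLLIB_TEST_NO_TF_IMPORT` environment variable is set.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| error | bool | Whether to raise an error if tf cannot be imported. | False |

Returns:

| Type | Description |
|------|-------------|
| Tuple | 1) The tf1.x module (either from `tf2.x.compat.v1` or tf1.x itself). 2) The tf module (resulting from `import tensorflow`), either tf1.x or 2.x. 3) The installed tf version as an int: 1 or 2. |

Exceptions:

| Type | Description |
|------|-------------|
| ImportError | If `error=True` and tf is not installed. |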

Source code in ray/rllib/utils/framework.py
def try_import_tf(error: bool = False):
    """Tries importing tf and returns the module (or None).

    Args:
        error: Whether to raise an error if tf cannot be imported.

    Returns:
        Tuple containing
        1) tf1.x module (either from tf2.x.compat.v1 OR as tf1.x).
        2) tf module (resulting from `import tensorflow`). Either tf1.x or
        2.x. 3) The actually installed tf version as int: 1 or 2.

    Raises:
        ImportError: If error=True and tf is not installed.
    """
    # Make sure these are reset after each test case
    # that uses them: del os.environ["RLLIB_TEST_NO_TF_IMPORT"]
    if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
        logger.warning("Not importing TensorFlow for test purposes")
        return None, None, None

    if "TF_CPP_MIN_LOG_LEVEL" not in os.environ:
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

    # Try to reuse already imported tf module. This will avoid going through
    # the initial import steps below and thereby switching off v2_behavior
    # (switching off v2 behavior twice breaks all-framework tests for eager).
    was_imported = False
    if "tensorflow" in sys.modules:
        tf_module = sys.modules["tensorflow"]
        was_imported = True

    else:
        try:
            import tensorflow as tf_module
        except ImportError:
            if error:
                raise ImportError(
                    "Could not import TensorFlow! RLlib requires you to "
                    "install at least one deep-learning framework: "
                    "`pip install [torch|tensorflow|jax]`.")
            return None, None, None

    # Try "reducing" tf to tf.compat.v1.
    try:
        tf1_module = tf_module.compat.v1
        tf1_module.logging.set_verbosity(tf1_module.logging.ERROR)
        if not was_imported:
            tf1_module.disable_v2_behavior()
            tf1_module.enable_resource_variables()
        tf1_module.logging.set_verbosity(tf1_module.logging.WARN)
    # No compat.v1 -> return tf as is.
    except AttributeError:
        tf1_module = tf_module

    if not hasattr(tf_module, "__version__"):
        version = 1  # sphinx doc gen
    else:
        version = 2 if "2." in tf_module.__version__[:2] else 1

    return tf1_module, tf_module, version
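The reuse-or-import flow and the major-version check above can be sketched without TensorFlow installed. `try_import_with_version` is a hypothetical helper that mirrors the structure of `try_import_tf` for an arbitrary module name. It deliberately skips the `compat.v1` and v2-behavior handling.

```python
import importlib
import sys
from types import ModuleType
from typing import Optional, Tuple


def try_import_with_version(
        name: str, error: bool = False
) -> Tuple[Optional[ModuleType], Optional[int]]:
    """Hypothetical helper mirroring try_import_tf's control flow."""
    # Reuse an already-imported module instead of importing it again.
    if name in sys.modules:
        module = sys.modules[name]
    else:
        try:
            module = importlib.import_module(name)
        except ImportError:
            if error:
                raise
            return None, None

    # Same rule as try_import_tf: a missing __version__ counts as 1,
    # otherwise 2 if the version string starts with "2.", else 1.
    if not hasattr(module, "__version__"):
        return module, 1
    return module, 2 if "2." in module.__version__[:2] else 1


# Register a fake module so the example runs without TensorFlow.
fake = ModuleType("fake_tensorflow")
fake.__version__ = "2.11.0"
sys.modules["fake_tensorflow"] = fake

mod, version = try_import_with_version("fake_tensorflow")
print(mod is fake, version)  # True 2
print(try_import_with_version("surely_not_installed_xyz"))  # (None, None)
```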

ray.rllib.utils.framework.try_import_tfp(error=False)

Tries importing tfp and returns the module (or None).

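Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| error | bool | Whether to raise an error if tfp cannot be imported. | False |

Returns:

| Type | Description |
|------|-------------|
| module | The tfp module, or None if it cannot be imported (or if the `RLLIB_TEST_NO_TF_IMPORT` environment variable is set). |

Exceptions:

| Type | Description |
|------|-------------|
| ImportError | If `error=True` and tfp is not installed. |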

Source code in ray/rllib/utils/framework.py
def try_import_tfp(error: bool = False):
    """Tries importing tfp and returns the module (or None).

    Args:
        error: Whether to raise an error if tfp cannot be imported.

    Returns:
        The tfp module.

    Raises:
        ImportError: If error=True and tfp is not installed.
    """
    if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
        logger.warning("Not importing TensorFlow Probability for test "
                       "purposes.")
        return None

    try:
        import tensorflow_probability as tfp
        return tfp
    except ImportError as e:
        if error:
            raise e
        return None

ray.rllib.utils.framework.try_import_torch(error=False)

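Tries to import PyTorch and returns the `torch` and `torch.nn` modules. If PyTorch cannot be imported (or the `RLLIB_TEST_NO_TORCH_IMPORT` environment variable is set), it returns the stand-ins produced by the internal `_torch_stubs()` helper instead.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| error | bool | Whether to raise an error if torch cannot be imported. | False |

Returns:

| Type | Description |
|------|-------------|
| Tuple | Tuple consisting of the `torch` and `torch.nn` modules. |

Exceptions:

| Type | Description |
|------|-------------|
| ImportError | If `error=True` and PyTorch is not installed. |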

Source code in ray/rllib/utils/framework.py
def try_import_torch(error: bool = False):
    """Tries importing torch and returns the module (or None).

    Args:
        error: Whether to raise an error if torch cannot be imported.

    Returns:
        Tuple consisting of the torch- AND torch.nn modules.

    Raises:
        ImportError: If error=True and PyTorch is not installed.
    """
    if "RLLIB_TEST_NO_TORCH_IMPORT" in os.environ:
        logger.warning("Not importing PyTorch for test purposes.")
        return _torch_stubs()

    try:
        import torch
        import torch.nn as nn
        return torch, nn
    except ImportError:
        if error:
            raise ImportError(
                "Could not import PyTorch! RLlib requires you to "
                "install at least one deep-learning framework: "
                "`pip install [torch|tensorflow|jax]`.")
        return _torch_stubs()

ray.rllib.utils.deprecation.deprecation_warning(old, new=None, *, help=None, error=None)

Logs a deprecation warning (via the `logger` object) or raises a deprecation error.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| old | str | A description of the "thing" that is to be deprecated. | required |
| new | Optional[str] | A description of the new "thing" that replaces it. | None |
| help | Optional[str] | An optional help text telling the user what to do instead of using `old`. Only used if `new` is not given. | None |
| error | Optional[Union[bool, Exception]] | Whether or which exception to raise. If True, raise ValueError. If False, just warn. If `error` is a subclass of Exception, raise that Exception. | None |

Exceptions:

| Type | Description |
|------|-------------|
| ValueError | If `error=True`. |
| Exception | Of type `error`, iff `error` is an Exception subclass. |

Source code in ray/rllib/utils/deprecation.py
def deprecation_warning(
        old: str,
        new: Optional[str] = None,
        *,
        help: Optional[str] = None,
        error: Optional[Union[bool, Exception]] = None) -> None:
    """Warns (via the `logger` object) or throws a deprecation warning/error.

    Args:
        old (str): A description of the "thing" that is to be deprecated.
        new (Optional[str]): A description of the new "thing" that replaces it.
        help (Optional[str]): An optional help text to tell the user what to
            do instead of using `old`.
        error (Optional[Union[bool, Exception]]): Whether or which exception to
            raise. If True, raise ValueError. If False, just warn.
            If error is-a subclass of Exception, raise that Exception.

    Raises:
        ValueError: If `error=True`.
        Exception: Of type `error`, iff error is-a Exception subclass.
    """
    msg = "`{}` has been deprecated.{}".format(
        old, (" Use `{}` instead.".format(new) if new else f" {help}"
              if help else ""))

    if error is True:
        raise ValueError(msg)
    elif error and issubclass(error, Exception):
        raise error(msg)
    else:
        logger.warning("DeprecationWarning: " + msg +
                       " This will raise an error in the future!")
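The message format and the error dispatch can be reproduced in a few lines. `deprecation_sketch` is a hypothetical stand-alone re-implementation that uses the standard `logging` module; it is not the RLlib function itself. It annotates `error` as an exception *class*, because the `issubclass` check only accepts classes. It also returns the message, which makes it easy to inspect.

```python
import logging
from typing import Optional, Type, Union

logger = logging.getLogger(__name__)


def deprecation_sketch(old: str,
                       new: Optional[str] = None,
                       *,
                       help: Optional[str] = None,
                       error: Union[bool, Type[Exception], None] = None) -> str:
    """Hypothetical re-implementation of deprecation_warning's logic."""
    # `new` takes precedence; `help` is only used when `new` is not given.
    if new:
        suffix = " Use `{}` instead.".format(new)
    elif help:
        suffix = " {}".format(help)
    else:
        suffix = ""
    msg = "`{}` has been deprecated.{}".format(old, suffix)

    if error is True:
        raise ValueError(msg)
    elif error and issubclass(error, Exception):
        raise error(msg)
    logger.warning("DeprecationWarning: " + msg +
                   " This will raise an error in the future!")
    return msg  # returned only so the sketch is easy to inspect


print(deprecation_sketch("old_fn()", "new_fn()"))
# `old_fn()` has been deprecated. Use `new_fn()` instead.
```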

ray.rllib.utils.filter_manager.FilterManager

Manages filters and coordination across remote evaluators that expose get_filters and sync_filters.

synchronize(local_filters, remotes, update_remote=True) staticmethod

Aggregates all filters from remote evaluators.

The local copy is updated and then broadcast to all remote evaluators.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| local_filters | dict | Filters to be synchronized. | required |
| remotes | list | Remote evaluators with filters. | required |
| update_remote | bool | Whether to push updates to remote filters. | True |

Source code in ray/rllib/utils/filter_manager.py
@staticmethod
@DeveloperAPI
def synchronize(local_filters, remotes, update_remote=True):
    """Aggregates all filters from remote evaluators.

    Local copy is updated and then broadcasted to all remote evaluators.

    Args:
        local_filters (dict): Filters to be synchronized.
        remotes (list): Remote evaluators with filters.
        update_remote (bool): Whether to push updates to remote filters.
    """
    remote_filters = ray.get(
        [r.get_filters.remote(flush_after=True) for r in remotes])
    for rf in remote_filters:
        for k in local_filters:
            local_filters[k].apply_changes(rf[k], with_buffer=False)
    if update_remote:
        copies = {k: v.as_serializable() for k, v in local_filters.items()}
        remote_copy = ray.put(copies)
        [r.sync_filters.remote(remote_copy) for r in remotes]
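The collect, merge and broadcast steps of `synchronize` can be followed without a Ray cluster. In the sketch below, `ray.get`/`ray.put` and `.remote()` calls are replaced by direct method calls. `CountFilter`, `LocalWorker` and `synchronize_local` are hypothetical stand-ins: a filter whose only state is how many inputs it has seen, and an in-process "evaluator".

```python
class CountFilter:
    """Hypothetical filter whose only state is how many inputs it has seen."""

    def __init__(self, count=0, buffer=0):
        self.count = count    # total inputs reflected in this filter
        self.buffer = buffer  # inputs seen since the last flush

    def apply_changes(self, other, with_buffer=False):
        # Fold in only what `other` saw since its last flush.
        self.count += other.buffer
        if with_buffer:
            self.buffer += other.buffer

    def as_serializable(self):
        return CountFilter(self.count, 0)

    def sync(self, other):
        self.count = other.count


class LocalWorker:
    """Hypothetical in-process stand-in for a remote evaluator."""

    def __init__(self, n_seen):
        self.filters = {"obs": CountFilter(count=n_seen, buffer=n_seen)}

    def get_filters(self, flush_after=False):
        snapshot = {k: CountFilter(f.count, f.buffer)
                    for k, f in self.filters.items()}
        if flush_after:
            for f in self.filters.values():
                f.buffer = 0
        return snapshot

    def sync_filters(self, new_filters):
        for k, f in self.filters.items():
            f.sync(new_filters[k])


def synchronize_local(local_filters, workers, update_remote=True):
    """Same steps as FilterManager.synchronize, minus ray.get/ray.put."""
    # 1) Collect each worker's filters, flushing their buffers.
    remote_filters = [w.get_filters(flush_after=True) for w in workers]
    # 2) Merge the buffered deltas into the local copy.
    for rf in remote_filters:
        for k in local_filters:
            local_filters[k].apply_changes(rf[k], with_buffer=False)
    # 3) Broadcast the merged state back to every worker.
    if update_remote:
        copies = {k: v.as_serializable() for k, v in local_filters.items()}
        for w in workers:
            w.sync_filters(copies)


local = {"obs": CountFilter()}
workers = [LocalWorker(3), LocalWorker(5)]
synchronize_local(local, workers)
print(local["obs"].count, [w.filters["obs"].count for w in workers])  # 8 [8, 8]
```

Flushing the buffers in step 1 matters: on the next call, only inputs seen since this synchronization are merged, so nothing is counted twice.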

ray.rllib.utils.filter.Filter

Processes input, possibly statefully.

apply_changes(self, other, *args, **kwargs)

Updates self with "new state" from other filter.

Source code in ray/rllib/utils/filter.py
def apply_changes(self, other, *args, **kwargs):
    """Updates self with "new state" from other filter."""
    raise NotImplementedError

clear_buffer(self)

Creates a copy of the current state and clears the accumulated state.

Source code in ray/rllib/utils/filter.py
def clear_buffer(self):
    """Creates copy of current state and clears accumulated state"""
    raise NotImplementedError

copy(self)

Creates a new object with same state as self.

Returns:

| Type | Description |
|------|-------------|
| Filter | A copy of self. |

Source code in ray/rllib/utils/filter.py
def copy(self):
    """Creates a new object with same state as self.

    Returns:
        A copy of self.
    """
    raise NotImplementedError

sync(self, other)

Copies all state from other filter to self.

Source code in ray/rllib/utils/filter.py
def sync(self, other):
    """Copies all state from other filter to self."""
    raise NotImplementedError
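A minimal concrete filter shows how the four methods fit together. `RunningMeanFilter` is hypothetical: it is not one of RLlib's built-in filters, and it does not inherit from `Filter`, so the block runs without Ray. It keeps a full state plus a buffer of changes since the last `clear_buffer()` call, and the buffer is what `apply_changes` merges. In this sketch `clear_buffer` only resets the buffer; copying is left to `copy()`.

```python
class RunningMeanFilter:
    """Hypothetical filter: subtracts the running mean of all inputs seen."""

    def __init__(self):
        self.n, self.total = 0, 0.0          # full state
        self.buf_n, self.buf_total = 0, 0.0  # changes since clear_buffer()

    def __call__(self, x):
        x = float(x)
        self.n += 1
        self.total += x
        self.buf_n += 1
        self.buf_total += x
        return x - self.total / self.n

    def apply_changes(self, other, with_buffer=False):
        # Merge the inputs `other` has seen since its buffer was last cleared.
        self.n += other.buf_n
        self.total += other.buf_total
        if with_buffer:
            self.buf_n += other.buf_n
            self.buf_total += other.buf_total

    def clear_buffer(self):
        self.buf_n, self.buf_total = 0, 0.0

    def copy(self):
        new = RunningMeanFilter()
        new.sync(self)
        return new

    def sync(self, other):
        # Copy all state, including the buffer.
        self.n, self.total = other.n, other.total
        self.buf_n, self.buf_total = other.buf_n, other.buf_total


f = RunningMeanFilter()
print(f(2.0), f(4.0))  # 0.0 1.0  (running means 2.0, then 3.0)
g = RunningMeanFilter()
g(10.0)
f.apply_changes(g)
print(f.n, f.total)    # 3 16.0
```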

ray.rllib.utils.numpy.sigmoid(x, derivative=False)

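Returns the sigmoid function applied to x. Alternatively, returns the derivative of the sigmoid function. In that case, x must already be a sigmoid output s = sigmoid(z), because the derivative is computed as s * (1 - s).

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x | ndarray | The input to the sigmoid function (or, if `derivative=True`, a sigmoid output). | required |
| derivative | bool | Whether to return the derivative or not. | False |

Returns:

| Type | Description |
|------|-------------|
| ndarray | The sigmoid function (or its derivative) applied to x. |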

Source code in ray/rllib/utils/numpy.py
def sigmoid(x: np.ndarray, derivative: bool = False) -> np.ndarray:
    """
    Returns the sigmoid function applied to x.
    Alternatively, can return the derivative of the sigmoid function.

    Args:
        x: The input to the sigmoid function.
        derivative: Whether to return the derivative or not.
            Default: False.

    Returns:
        The sigmoid function (or its derivative) applied to x.
    """
    if derivative:
        return x * (1 - x)
    else:
        return 1 / (1 + np.exp(-x))
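The sketch below checks the `derivative=True` branch against a central finite difference. It uses a local copy of the function above rather than an RLlib import. The last line shows the common mistake of passing the pre-activation z instead of sigmoid(z).

```python
import numpy as np


def sigmoid(x, derivative=False):
    """Local copy of the function above (not imported from RLlib)."""
    if derivative:
        # Expects x to already be a sigmoid output s = sigmoid(z):
        # d/dz sigmoid(z) = s * (1 - s).
        return x * (1 - x)
    return 1 / (1 + np.exp(-x))


z = np.array([-2.0, 0.0, 3.0])
s = sigmoid(z)
analytic = sigmoid(s, derivative=True)  # pass s, not z

# Central finite difference in z as an independent check.
h = 1e-5
numeric = (sigmoid(z + h) - sigmoid(z - h)) / (2 * h)
print(np.allclose(analytic, numeric))  # True

# Wrong usage: z = 0.0 passed directly gives 0.0, not sigmoid'(0) = 0.25.
print(sigmoid(0.0, derivative=True))  # 0.0
```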

ray.rllib.utils.numpy.softmax(x, axis=-1, epsilon=None)

Returns the softmax values for x.

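The exact formula used is $S(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}$, where $j$ runs over all elements of x along `axis`. Each output value is then floored at `epsilon`.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x | ndarray | The input to the softmax function. | required |
| axis | int | The axis along which to softmax. | -1 |
| epsilon | Optional[float] | Optional epsilon used as a minimum output value. If None, use `SMALL_NUMBER`. | None |

Returns:

| Type | Description |
|------|-------------|
| ndarray | The softmax over x. |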

Source code in ray/rllib/utils/numpy.py
def softmax(x: np.ndarray, axis: int = -1,
            epsilon: Optional[float] = None) -> np.ndarray:
    """Returns the softmax values for x.

    The exact formula used is:
    S(xi) = e^xi / SUMj(e^xj), where j goes over all elements in x.

    Args:
        x: The input to the softmax function.
        axis: The axis along which to softmax.
        epsilon: Optional epsilon as a minimum value. If None, use
            `SMALL_NUMBER`.

    Returns:
        The softmax over x.
    """
    epsilon = epsilon or SMALL_NUMBER
    # x_exp = np.maximum(np.exp(x), SMALL_NUMBER)
    x_exp = np.exp(x)
    # return x_exp /
    #   np.maximum(np.sum(x_exp, axis, keepdims=True), SMALL_NUMBER)
    return np.maximum(x_exp / np.sum(x_exp, axis, keepdims=True), epsilon)
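The listing above calls `np.exp(x)` directly, which overflows to `inf` for large inputs. A common remedy, which the RLlib function shown does not use, is to subtract the per-axis maximum first. The result is mathematically unchanged because the factor $e^{-\max_j x_j}$ cancels between numerator and denominator. `stable_softmax` is a hypothetical sketch that keeps the same `epsilon` floor. The value of `SMALL_NUMBER` here is an assumption standing in for RLlib's constant.

```python
import numpy as np

SMALL_NUMBER = 1e-6  # assumption: stand-in for RLlib's SMALL_NUMBER constant


def stable_softmax(x, axis=-1, epsilon=None):
    """Softmax with max-subtraction for numerical stability (sketch)."""
    epsilon = epsilon or SMALL_NUMBER
    # Shift so the largest entry along `axis` becomes 0 (no overflow in exp).
    shifted = x - np.max(x, axis=axis, keepdims=True)
    x_exp = np.exp(shifted)
    # Same epsilon floor as the listing above.
    return np.maximum(x_exp / np.sum(x_exp, axis=axis, keepdims=True), epsilon)


print(stable_softmax(np.array([1.0, 2.0, 3.0])))
print(stable_softmax(np.array([1000.0, 1000.0])))  # [0.5 0.5]
```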

ray.rllib.utils.numpy.relu(x, alpha=0.0)

Implementation of the leaky ReLU function.

y = x * alpha if x < 0 else x

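Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x | ndarray | The input values. | required |
| alpha | float | A scaling ("leak") factor to use for negative x. | 0.0 |

Returns:

| Type | Description |
|------|-------------|
| ndarray | The leaky ReLU output for x. |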

Source code in ray/rllib/utils/numpy.py
def relu(x: np.ndarray, alpha: float = 0.0) -> np.ndarray:
    """Implementation of the leaky ReLU function.

    y = x * alpha if x < 0 else x

    Args:
        x: The input values.
        alpha: A scaling ("leak") factor to use for negative x.

    Returns:
        The leaky ReLU output for x.
    """
    return np.maximum(x, x * alpha)
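The third positional argument of `np.maximum` is its `out` buffer. Passing the input again in that position therefore overwrites the input array in place. The hypothetical `leaky_relu` sketch below uses the two-argument form, so the input stays untouched. Note that `max(x, alpha * x)` equals the leaky ReLU only for `alpha <= 1`.

```python
import numpy as np


def leaky_relu(x, alpha=0.0):
    """y = x * alpha if x < 0 else x  (valid for alpha <= 1)."""
    # Two-argument np.maximum allocates a new array; x is not modified.
    return np.maximum(x, x * alpha)


x = np.array([-2.0, -0.5, 0.0, 3.0])
y = leaky_relu(x, alpha=0.1)
print(y)  # negative entries scaled by 0.1, non-negative entries unchanged
print(x)  # the input array is unchanged
```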

ray.rllib.utils.numpy.one_hot(x, depth=0, on_value=1.0, off_value=0.0)

One-hot utility function for numpy.

Thanks to qianyizhang: https://gist.github.com/qianyizhang/07ee1c15cad08afb03f5de69349efc30.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x | Union[Any, int] | The input to be one-hot encoded. | required |
| depth | int | The size of the one-hot (last) dimension. If 0, it is inferred as `max(x) + 1`. For boolean input, it is always 2. | 0 |
| on_value | float | The value to use for "on". | 1.0 |
| off_value | float | The value to use for "off". | 0.0 |

Returns:

| Type | Description |
|------|-------------|
| ndarray | The one-hot encoded equivalent of the input array. |

Source code in ray/rllib/utils/numpy.py
def one_hot(x: Union[TensorType, int],
            depth: int = 0,
            on_value: float = 1.0,
            off_value: float = 0.0) -> np.ndarray:
    """One-hot utility function for numpy.

    Thanks to qianyizhang:
    https://gist.github.com/qianyizhang/07ee1c15cad08afb03f5de69349efc30.

    Args:
        x: The input to be one-hot encoded.
        depth: The max. number to be one-hot encoded (size of last rank).
        on_value: The value to use for on. Default: 1.0.
        off_value: The value to use for off. Default: 0.0.

    Returns:
        The one-hot encoded equivalent of the input array.
    """

    # Handle simple ints properly.
    if isinstance(x, int):
        x = np.array(x, dtype=np.int32)
    # Handle torch arrays properly.
    elif torch and isinstance(x, torch.Tensor):
        x = x.numpy()

    # Handle bool arrays correctly.
    if x.dtype == np.bool_:
        x = x.astype(int)
        depth = 2

    # If depth is not given, try to infer it from the values in the array.
    if depth == 0:
        depth = np.max(x) + 1
    assert np.max(x) < depth, \
        "ERROR: The max. index of `x` ({}) is larger than depth ({})!".\
        format(np.max(x), depth)
    shape = x.shape

    # Python 2.7 compatibility, (*shape, depth) is not allowed.
    shape_list = list(shape[:])
    shape_list.append(depth)
    out = np.ones(shape_list) * off_value
    indices = []
    for i in range(x.ndim):
        tiles = [1] * x.ndim
        s = [1] * x.ndim
        s[i] = -1
        r = np.arange(shape[i]).reshape(s)
        if i > 0:
            tiles[i - 1] = shape[i - 1]
            r = np.tile(r, tiles)
        indices.append(r)
    indices.append(x)
    out[tuple(indices)] = on_value
    return out
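For comparison, the same encoding can be written without the per-dimension index loop by using `np.put_along_axis` (available since NumPy 1.15). `one_hot_sketch` is a hypothetical alternative, not RLlib's implementation. It keeps the bool handling and depth inference from the listing but omits the torch-tensor conversion.

```python
import numpy as np


def one_hot_sketch(x, depth=0, on_value=1.0, off_value=0.0):
    """Vectorized one-hot encoding along a new last axis (hypothetical)."""
    x = np.asarray(x)
    if x.dtype == np.bool_:
        x, depth = x.astype(int), 2        # same bool handling as one_hot
    if depth == 0:
        depth = int(np.max(x)) + 1         # infer depth from largest index
    assert np.max(x) < depth, "index out of range for depth"
    out = np.full(x.shape + (depth,), off_value, dtype=float)
    # Write on_value at index x[...] along the new last axis.
    np.put_along_axis(out, x[..., np.newaxis], on_value, axis=-1)
    return out


print(one_hot_sketch([0, 2, 1]))
```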

ray.rllib.utils.numpy.fc(x, weights, biases=None, framework=None)

Calculates FC (dense) layer outputs given weights/biases and input.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x | ndarray | The input to the dense layer. | required |
| weights | ndarray | The weights matrix. | required |
| biases | Optional[numpy.ndarray] | The biases vector. All 0s if None. | None |
| framework | Optional[str] | An optional framework hint (e.g., to figure out whether to transpose torch weight matrices). | None |

Returns:

| Type | Description |
|------|-------------|
| ndarray | The dense layer's output. |

Source code in ray/rllib/utils/numpy.py
def fc(x: np.ndarray,
       weights: np.ndarray,
       biases: Optional[np.ndarray] = None,
       framework: Optional[str] = None) -> np.ndarray:
    """Calculates FC (dense) layer outputs given weights/biases and input.

    Args:
        x: The input to the dense layer.
        weights: The weights matrix.
        biases: The biases vector. All 0s if None.
        framework: An optional framework hint (to figure out,
            e.g. whether to transpose torch weight matrices).

    Returns:
        The dense layer's output.
    """

    def map_(data, transpose=False):
        if torch:
            if isinstance(data, torch.Tensor):
                data = data.cpu().detach().numpy()
        if tf and tf.executing_eagerly():
            if isinstance(data, tf.Variable):
                data = data.numpy()
        if transpose:
            data = np.transpose(data)
        return data

    x = map_(x)
    # Torch stores matrices in transpose (faster for backprop).
    transpose = (framework == "torch" and (x.shape[1] != weights.shape[0]
                                           and x.shape[1] == weights.shape[1]))
    weights = map_(weights, transpose=transpose)
    biases = map_(biases)

    return np.matmul(x, weights) + (0.0 if biases is None else biases)
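The transpose heuristic in `fc` is easiest to see with concrete shapes. The sketch below uses plain NumPy (no RLlib import). It reproduces the shape check from the listing for a torch-style weight matrix, which is stored as `(out_features, in_features)`. In `fc` the check additionally requires `framework == "torch"`.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3))  # batch of 2, 3 input features
w = rng.standard_normal((3, 4))  # (in_features, out_features) layout
b = rng.standard_normal(4)

y = np.matmul(x, w) + b          # what fc computes for this layout
w_torch = w.T                    # torch-style (out_features, in_features)

# fc's rule: transpose when x.shape[1] does not match weights.shape[0]
# but does match weights.shape[1].
transpose = x.shape[1] != w_torch.shape[0] and x.shape[1] == w_torch.shape[1]
y_torch = np.matmul(x, w_torch.T if transpose else w_torch) + b
print(transpose, np.allclose(y, y_torch))  # True True
```

Because the rule only looks at shapes, a square torch weight matrix (in_features == out_features) is not transposed, even when `framework="torch"`.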

ray.rllib.utils.numpy.lstm(x, weights, biases=None, initial_internal_states=None, time_major=False, forget_bias=1.0)

Calculates LSTM layer output given weights/biases, states, and input.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x | | The inputs to the LSTM layer, including the time rank (0th if time-major, else 1st) and the batch rank (1st if time-major, else 0th). | required |
| weights | ndarray | The weights matrix. | required |
| biases | Optional[numpy.ndarray] | The biases vector. All 0s if None. | None |
| initial_internal_states | Optional[numpy.ndarray] | The initial internal states (c-state, h-state) to pass into the layer. All 0s if None. | None |
| time_major | bool | Whether the input is time-major. | False |
| forget_bias | float | Added to the forget gate's pre-activation (the input of its sigmoid). | 1.0 |

Returns:

| Type | Description |
|------|-------------|
| Tuple | 1) The LSTM layer's output and 2) a tuple of the last (c-state, h-state). |

Source code in ray/rllib/utils/numpy.py
def lstm(x,
         weights: np.ndarray,
         biases: Optional[np.ndarray] = None,
         initial_internal_states: Optional[np.ndarray] = None,
         time_major: bool = False,
         forget_bias: float = 1.0):
    """Calculates LSTM layer output given weights/biases, states, and input.

    Args:
        x: The inputs to the LSTM layer including time-rank
            (0th if time-major, else 1st) and the batch-rank
            (1st if time-major, else 0th).
        weights: The weights matrix.
        biases: The biases vector. All 0s if None.
        initial_internal_states: The initial internal
            states to pass into the layer. All 0s if None.
        time_major: Whether to use time-major or not. Default: False.
        forget_bias: Gets added to first sigmoid (forget gate) output.
            Default: 1.0.

    Returns:
        Tuple consisting of 1) The LSTM layer's output and
        2) Tuple: Last (c-state, h-state).
    """
    sequence_length = x.shape[0 if time_major else 1]
    batch_size = x.shape[1 if time_major else 0]
    units = weights.shape[1] // 4  # 4 internal layers (3x sigmoid, 1x tanh)

    if initial_internal_states is None:
        c_states = np.zeros(shape=(batch_size, units))
        h_states = np.zeros(shape=(batch_size, units))
    else:
        c_states = initial_internal_states[0]
        h_states = initial_internal_states[1]

    # Create a placeholder for all n-time step outputs.
    if time_major:
        unrolled_outputs = np.zeros(shape=(sequence_length, batch_size, units))
    else:
        unrolled_outputs = np.zeros(shape=(batch_size, sequence_length, units))

    # Push the batch through the LSTM cell once per time step and capture the
    # outputs plus the final h- and c-states.
    for t in range(sequence_length):
        input_matrix = x[t, :, :] if time_major else x[:, t, :]
        input_matrix = np.concatenate((input_matrix, h_states), axis=1)
        input_matmul_matrix = np.matmul(input_matrix, weights)
        if biases is not None:
            input_matmul_matrix = input_matmul_matrix + biases
        # Forget gate (3rd slot in tf output matrix). Add static forget bias.
        sigmoid_1 = sigmoid(input_matmul_matrix[:, units * 2:units * 3] +
                            forget_bias)
        c_states = np.multiply(c_states, sigmoid_1)
        # Add gate (1st and 2nd slots in tf output matrix).
        sigmoid_2 = sigmoid(input_matmul_matrix[:, 0:units])
        tanh_3 = np.tanh(input_matmul_matrix[:, units:units * 2])
        c_states = np.add(c_states, np.multiply(sigmoid_2, tanh_3))
        # Output gate (last slot in tf output matrix).
        sigmoid_4 = sigmoid(input_matmul_matrix[:, units * 3:units * 4])
        h_states = np.multiply(sigmoid_4, np.tanh(c_states))

        # Store this output time-slice.
        if time_major:
            unrolled_outputs[t, :, :] = h_states
        else:
            unrolled_outputs[:, t, :] = h_states

    return unrolled_outputs, (c_states, h_states)
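The sketch below makes the gate layout and the weight shape explicit with a single LSTM time step. It follows the slot order used in the listing: input gate, candidate, forget gate, output gate. `lstm_step` is a hypothetical helper written for illustration. `weights` must have shape `(input_dim + units, 4 * units)`, because the input is concatenated with the previous h-state before the matmul.

```python
import numpy as np


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


def lstm_step(x_t, c, h, weights, biases, forget_bias=1.0):
    """One time step with the same gate layout as the listing (sketch)."""
    units = weights.shape[1] // 4
    z = np.matmul(np.concatenate((x_t, h), axis=1), weights) + biases
    i = sigmoid(z[:, 0:units])                            # input gate (slot 1)
    j = np.tanh(z[:, units:2 * units])                    # candidate (slot 2)
    f = sigmoid(z[:, 2 * units:3 * units] + forget_bias)  # forget gate (slot 3)
    o = sigmoid(z[:, 3 * units:4 * units])                # output gate (slot 4)
    c = c * f + i * j
    h = o * np.tanh(c)
    return c, h


batch, in_dim, units = 2, 3, 5
weights = np.zeros((in_dim + units, 4 * units))  # rows: [inputs; previous h]
biases = np.zeros(4 * units)
c0 = np.ones((batch, units))
h0 = np.zeros((batch, units))
c1, h1 = lstm_step(np.ones((batch, in_dim)), c0, h0, weights, biases)

# With zero weights only forget_bias matters:
# c1 = 1 * sigmoid(1) and h1 = 0.5 * tanh(c1).
print(np.allclose(c1, sigmoid(1.0)))                   # True
print(np.allclose(h1, 0.5 * np.tanh(sigmoid(1.0))))    # True
```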

ray.rllib.utils.schedules.linear_schedule.LinearSchedule (PolynomialSchedule)

Linear interpolation between initial_p and final_p. Simply uses Polynomial with power=1.0.

$\text{final\_p} + (\text{initial\_p} - \text{final\_p}) \cdot (1 - t / t_{\max})$

ray.rllib.utils.schedules.piecewise_schedule.PiecewiseSchedule (Schedule)

__init__(self, endpoints, framework, interpolation=_linear_interpolation, outside_value=None) special

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| endpoints | List[Tuple[int,float]] | A list of tuples `(t, value)` such that the output is an interpolation (given by the `interpolation` callable) between two values. E.g., for t=400 and endpoints=[(0, 20.0), (500, 30.0)], output = 20.0 + 0.8 * (30.0 - 20.0) = 28.0. NOTE: All time values must be sorted in increasing order. | required |
| framework | Optional[str] | One of "tf", "torch", or None. | required |
| interpolation | callable | A function that takes the left value, the right value and an alpha interpolation parameter (0.0 = only left value, 1.0 = only right value), which is the fraction of the distance from the left endpoint to the right endpoint. | `_linear_interpolation` |
| outside_value | Optional[float] | If `t` in a call to `value` lies outside all intervals in `endpoints`, this value is returned. If None, an AssertionError is raised when an outside value is requested. | None |

Source code in ray/rllib/utils/schedules/piecewise_schedule.py
def __init__(self,
             endpoints,
             framework,
             interpolation=_linear_interpolation,
             outside_value=None):
    """
    Args:
        endpoints (List[Tuple[int,float]]): A list of tuples
            `(t, value)` such that the output
            is an interpolation (given by the `interpolation` callable)
            between two values.
            E.g.
            t=400 and endpoints=[(0, 20.0),(500, 30.0)]
            output=20.0 + 0.8 * (30.0 - 20.0) = 28.0
            NOTE: All the values for time must be sorted in an increasing
            order.

        interpolation (callable): A function that takes the left-value,
            the right-value and an alpha interpolation parameter
            (0.0=only left value, 1.0=only right value), which is the
            fraction of distance from left endpoint to right endpoint.

        outside_value (Optional[float]): If t in call to `value` is
            outside of all the intervals in `endpoints` this value is
            returned. If None then an AssertionError is raised when outside
            value is requested.
    """
    super().__init__(framework=framework)

    idxes = [e[0] for e in endpoints]
    assert idxes == sorted(idxes)
    self.interpolation = interpolation
    self.outside_value = outside_value
    self.endpoints = [(int(e[0]), float(e[1])) for e in endpoints]
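The worked example from the `endpoints` description (t=400 with endpoints [(0, 20.0), (500, 30.0)] gives 28.0) can be reproduced with a small stand-alone lookup. `piecewise_value` and `_linear` are hypothetical names. The sketch assumes that the schedule walks consecutive endpoint pairs and uses the half-open interval `l_t <= t < r_t`. That boundary convention is a guess, not confirmed RLlib behavior.

```python
def _linear(left, right, alpha):
    # alpha = 0.0 -> left value, alpha = 1.0 -> right value.
    return left + alpha * (right - left)


def piecewise_value(t, endpoints, interpolation=_linear, outside_value=None):
    """Hypothetical stand-alone piecewise schedule lookup."""
    for (l_t, l_v), (r_t, r_v) in zip(endpoints[:-1], endpoints[1:]):
        if l_t <= t < r_t:
            alpha = float(t - l_t) / (r_t - l_t)
            return interpolation(l_v, r_v, alpha)
    # Outside all intervals: fall back to outside_value, if given.
    assert outside_value is not None, "t is outside all endpoint intervals"
    return outside_value


endpoints = [(0, 20.0), (500, 30.0)]
print(piecewise_value(400, endpoints))  # 28.0
```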

ray.rllib.utils.schedules.polynomial_schedule.PolynomialSchedule (Schedule)

__init__(self, schedule_timesteps, final_p, framework, initial_p=1.0, power=2.0) special

Polynomial interpolation between initial_p and final_p over schedule_timesteps. After this many time steps, always final_p is returned.

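Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| schedule_timesteps | int | Number of time steps over which to anneal `initial_p` to `final_p`. | required |
| final_p | float | Final output value. | required |
| framework | Optional[str] | One of "tf", "torch", or None. | required |
| initial_p | float | Initial output value. | 1.0 |
| power | float | The exponent of the polynomial (LinearSchedule uses 1.0). | 2.0 |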

Source code in ray/rllib/utils/schedules/polynomial_schedule.py
def __init__(self,
             schedule_timesteps,
             final_p,
             framework,
             initial_p=1.0,
             power=2.0):
    """
    Polynomial interpolation between initial_p and final_p over
    schedule_timesteps. After this many time steps, always `final_p` is
    returned.

    Args:
        schedule_timesteps (int): Number of time steps for which to
            linearly anneal initial_p to final_p
        final_p (float): Final output value.
        initial_p (float): Initial output value.
        framework (Optional[str]): One of "tf", "torch", or None.
    """
    super().__init__(framework=framework)
    assert schedule_timesteps > 0
    self.schedule_timesteps = schedule_timesteps
    self.final_p = final_p
    self.initial_p = initial_p
    self.power = power
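The following is a plain-Python sketch of the interpolation. It assumes the formula given for LinearSchedule generalizes with an exponent, i.e. `final_p + (initial_p - final_p) * (1 - t / schedule_timesteps) ** power`, with `t` clamped at `schedule_timesteps` so that `final_p` is returned afterwards, as described above. `polynomial_value` is a hypothetical helper, not the class's own method.

```python
def polynomial_value(t, schedule_timesteps, final_p, initial_p=1.0, power=2.0):
    """Hypothetical polynomial schedule; power=1.0 gives a linear schedule."""
    # Clamp so that final_p is returned once schedule_timesteps is reached.
    t = min(t, schedule_timesteps)
    frac = 1.0 - t / schedule_timesteps
    return final_p + (initial_p - final_p) * frac ** power


print(polynomial_value(50, 100, final_p=0.0, power=1.0))  # 0.5
print(polynomial_value(50, 100, final_p=0.0))             # 0.25
print(polynomial_value(250, 100, final_p=0.1))            # 0.1
```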

ray.rllib.utils.schedules.exponential_schedule.ExponentialSchedule (Schedule)

__init__(self, schedule_timesteps, framework, initial_p=1.0, decay_rate=0.1) special

Exponential decay schedule starting at initial_p. Once schedule_timesteps time steps have passed, the output has decayed to initial_p * decay_rate.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| schedule_timesteps | int | Number of time steps over which `initial_p` decays to `initial_p * decay_rate`. | required |
| framework | Optional[str] | One of "tf", "torch", or None. | required |
| initial_p | float | Initial output value. | 1.0 |
| decay_rate | float | The fraction of the original value that remains once 100% of `schedule_timesteps` has passed. For values > 0.0, the smaller the decay rate, the stronger the decay; 1.0 means no decay at all. | 0.1 |

Source code in ray/rllib/utils/schedules/exponential_schedule.py
def __init__(self,
             schedule_timesteps,
             framework,
             initial_p=1.0,
             decay_rate=0.1):
    """
    Exponential decay schedule from initial_p to final_p over
    schedule_timesteps. After this many time steps always `final_p` is
    returned.

    Args:
        schedule_timesteps (int): Number of time steps for which to
            linearly anneal initial_p to final_p
        initial_p (float): Initial output value.
        decay_rate (float): The percentage of the original value after
            100% of the time has been reached (see formula above).
            >0.0: The smaller the decay-rate, the stronger the decay.
            1.0: No decay at all.
        framework (Optional[str]): One of "tf", "torch", or None.
    """
    super().__init__(framework=framework)
    assert schedule_timesteps > 0
    self.schedule_timesteps = schedule_timesteps
    self.initial_p = initial_p
    self.decay_rate = decay_rate
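The docstring describes `decay_rate` as the fraction of the original value left after 100% of `schedule_timesteps`, but the listing does not show the formula itself. One exponential curve consistent with that description is `initial_p * decay_rate ** (t / schedule_timesteps)`. The hypothetical `exponential_value` helper below assumes that form and is not taken from RLlib's source.

```python
def exponential_value(t, schedule_timesteps, initial_p=1.0, decay_rate=0.1):
    """Hypothetical: initial_p * decay_rate ** (t / schedule_timesteps)."""
    return initial_p * decay_rate ** (t / schedule_timesteps)


print(exponential_value(0, 100))    # 1.0
print(exponential_value(100, 100))  # 0.1
```

Under this assumed form, halfway through the schedule the value is `initial_p * sqrt(decay_rate)`, and with `decay_rate=1.0` it never changes.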

ray.rllib.utils.schedules.constant_schedule.ConstantSchedule (Schedule)

A Schedule where the value remains constant over time.

__init__(self, value, framework) special

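Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| value | float | The constant value to return, independently of time. | required |
| framework | Optional[str] | One of "tf", "torch", or None. | required |
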
Source code in ray/rllib/utils/schedules/constant_schedule.py
def __init__(self, value, framework):
    """
    Args:
        value (float): The constant value to return, independently of time.
    """
    super().__init__(framework=framework)
    self._v = value

ray.rllib.utils.test_utils.check(x, y, decimals=5, atol=None, rtol=None, false=False)

Checks two structures (dict, tuple, list, np.array, float, int, etc.) for (almost) numeric identity. All numbers in the two structures must match up to `decimals` digits after the decimal point. Uses assertions.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x | any | The value to be compared (to the expectation `y`). This may be a Tensor. | required |
| y | any | The expected value to be compared to `x`. This must not be a tf-Tensor, but may be a tfe/torch-Tensor. | required |
| decimals | int | The number of digits after the decimal point up to which all numeric values have to match. | 5 |
| atol | float | Absolute tolerance of the difference between x and y (overrides `decimals` if given). | None |
| rtol | float | Relative tolerance of the difference between x and y (overrides `decimals` if given). | None |
| false | bool | Whether to check that x and y are NOT the same. | False |

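The core of `check` is a recursion over containers that ends in a numeric comparison. The sketch below is a simplified, hypothetical re-implementation. It covers only dicts, lists/tuples and plain numbers, and it delegates the numeric comparison to `np.testing.assert_almost_equal` rather than RLlib's own logic. Tensors, `atol`/`rtol` and `false` are not handled.

```python
import numpy as np


def check_sketch(x, y, decimals=5):
    """Hypothetical recursive comparison in the spirit of check()."""
    if isinstance(x, dict):
        assert isinstance(y, dict) and set(x) == set(y), "dict keys differ"
        for key in x:
            check_sketch(x[key], y[key], decimals)
    elif isinstance(x, (tuple, list)):
        assert isinstance(y, (tuple, list)) and len(x) == len(y), \
            "length differs"
        for xi, yi in zip(x, y):
            check_sketch(xi, yi, decimals)
    else:
        # Raises AssertionError unless abs(y - x) < 1.5 * 10**(-decimals).
        np.testing.assert_almost_equal(x, y, decimal=decimals)


# Passes: 1.0 and 1.000001 agree to 5 decimals.
check_sketch({"a": [1.0, 2.0], "b": (3,)}, {"a": [1.000001, 2.0], "b": (3,)})
```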
Source code in ray/rllib/utils/test_utils.py
def check(x, y, decimals=5, atol=None, rtol=None, false=False):
    """
    Checks two structures (dict, tuple, list,
    np.array, float, int, etc.) for (almost) numeric identity.
    All numbers in the two structures have to match up to `decimals` digits
    after the floating point. Uses assertions.

    Args:
        x (any): The value to be compared (to the expectation: `y`). This
            may be a Tensor.
        y (any): The expected value to be compared to `x`. This must not
            be a tf-Tensor, but may be a tfe/torch-Tensor.
        decimals (int): The number of digits after the floating point up to
            which all numeric values have to match.
        atol (float): Absolute tolerance of the difference between x and y
            (overrides `decimals` if given).
        rtol (float): Relative tolerance of the difference between x and y
            (overrides `decimals` if given).
        false (bool): Whether to check that x and y are NOT the same.
    """
    # A dict type.
    if isinstance(x, dict):
        assert isinstance(y, dict), \
            "ERROR: If x is dict, y needs to be a dict as well!"
        y_keys = set(y.keys())
        for key, value in x.items():
            assert key in y, \
                "ERROR: y does not have x's key='{}'! y={}".format(key, y)
            check(
                value,
                y[key],
                decimals=decimals,
                atol=atol,
                rtol=rtol,
                false=false)
            y_keys.remove(key)
        assert not y_keys, \
            "ERROR: y contains keys ({}) that are not in x! y={}".\
            format(list(y_keys), y)
    # A tuple or list type.
    elif isinstance(x, (tuple, list)):
        assert isinstance(y, (tuple, list)),\
            "ERROR: If x is a tuple/list, y needs to be a tuple/list as well!"
        assert len(y) == len(x),\
            "ERROR: y does not have the same length as x ({} vs {})!".\
            format(len(y), len(x))
        for i, value in enumerate(x):
            check(
                value,
                y[i],
                decimals=decimals,
                atol=atol,
                rtol=rtol,
                false=false)
    # Boolean comparison.
    elif isinstance(x, (np.bool_, bool)):
        if false is True:
            assert bool(x) is not bool(y), \
                "ERROR: x ({}) is y ({})!".format(x, y)
        else:
            assert bool(x) is bool(y), \
                "ERROR: x ({}) is not y ({})!".format(x, y)
    # Nones or primitives.
    elif x is None or y is None or isinstance(x, (str, int)):
        if false is True:
            assert x != y, "ERROR: x ({}) is the same as y ({})!".format(x, y)
        else:
            assert x == y, \
                "ERROR: x ({}) is not the same as y ({})!".format(x, y)
    # String/byte comparisons.
    elif hasattr(x, "dtype") and \
            (x.dtype == object or str(x.dtype).startswith("<U")):
        try:
            np.testing.assert_array_equal(x, y)
            if false is True:
                assert False, \
                    "ERROR: x ({}) is the same as y ({})!".format(x, y)
        except AssertionError as e:
            if false is False:
                raise e
    # Everything else (assume numeric or tf/torch.Tensor).
    else:
        if tf1 is not None:
            # y should never be a Tensor (y=expected value).
            if isinstance(y, (tf1.Tensor, tf1.Variable)):
                # In eager mode, numpyize tensors.
                if tf.executing_eagerly():
                    y = y.numpy()
                else:
                    raise ValueError(
                        "`y` (expected value) must not be a Tensor. "
                        "Use numpy.ndarray instead")
            if isinstance(x, (tf1.Tensor, tf1.Variable)):
                # In eager mode, numpyize tensors.
                if tf1.executing_eagerly():
                    x = x.numpy()
                # Otherwise, use a new tf-session.
                else:
                    with tf1.Session() as sess:
                        x = sess.run(x)
                        return check(
                            x,
                            y,
                            decimals=decimals,
                            atol=atol,
                            rtol=rtol,
                            false=false)
        if torch is not None:
            if isinstance(x, torch.Tensor):
                x = x.detach().cpu().numpy()
            if isinstance(y, torch.Tensor):
                y = y.detach().cpu().numpy()

        # Using decimals.
        if atol is None and rtol is None:
            # Assert equality of both values.
            try:
                np.testing.assert_almost_equal(x, y, decimal=decimals)
            # Both values are not equal.
            except AssertionError as e:
                # Raise error in normal case.
                if false is False:
                    raise e
            # Both values are equal.
            else:
                # If false is set -> raise error (not expected to be equal).
                if false is True:
                    assert False, \
                        "ERROR: x ({}) is the same as y ({})!".format(x, y)

        # Using atol/rtol.
        else:
            # Provide defaults for either one of atol/rtol.
            if atol is None:
                atol = 0
            if rtol is None:
                rtol = 1e-7
            try:
                np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)
            except AssertionError as e:
                if false is False:
                    raise e
            else:
                if false is True:
                    assert False, \
                        "ERROR: x ({}) is the same as y ({})!".format(x, y)
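The numeric branch of check has two modes. With neither atol nor rtol set, it uses np.testing.assert_almost_equal, which (for scalars) requires abs(y - x) < 1.5 * 10**(-decimals). Otherwise it uses np.testing.assert_allclose, which requires abs(x - y) <= atol + rtol * abs(y), with atol defaulting to 0 and rtol to 1e-7. The hypothetical helper numeric_match below reproduces only that branch as a boolean (no dict/list recursion, no tensor conversion), so the difference between the two modes is easy to see.

```python
import numpy as np


def numeric_match(x, y, decimals=5, atol=None, rtol=None):
    """Hypothetical helper: True if x and y match under check()'s numeric rules."""
    try:
        if atol is None and rtol is None:
            # "decimals" mode: abs(y - x) < 1.5 * 10**(-decimals).
            np.testing.assert_almost_equal(x, y, decimal=decimals)
        else:
            # Same defaults as check(): atol=0, rtol=1e-7.
            np.testing.assert_allclose(
                x,
                y,
                atol=0 if atol is None else atol,
                rtol=1e-7 if rtol is None else rtol)
    except AssertionError:
        return False
    return True


# A difference of 1e-6 passes with decimals=5 ...
print(numeric_match(1.000001, 1.0))
# ... but fails once a tight tolerance switches to assert_allclose.
print(numeric_match(1.000001, 1.0, atol=1e-8))
```

With false=True, check inverts the outcome: it passes exactly when numeric_match would return False.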

ray.rllib.utils.test_utils.check_compute_single_action(trainer, include_state=False, include_prev_action_reward=False)

Tests different combinations of args for trainer.compute_single_action.

Parameters:

Name Type Description Default
trainer

The Trainer object to test.

required
include_state

Whether to include the initial state of the Policy's Model in the compute_single_action call.

False
include_prev_action_reward

Whether to include the prev-action and -reward in the compute_single_action call.

False

Exceptions:

Type Description
ValueError

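If an action returned by the Trainer's compute_single_action lies outside the env's action space (checked when clipping or unsquashing is on, or when the space is not a Box), or if an action expected to be in normalized space has an absolute value above 3.0.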

Source code in ray/rllib/utils/test_utils.py
def check_compute_single_action(trainer,
                                include_state=False,
                                include_prev_action_reward=False):
    """Tests different combinations of args for trainer.compute_single_action.

    Args:
        trainer: The Trainer object to test.
        include_state: Whether to include the initial state of the Policy's
            Model in the `compute_single_action` call.
        include_prev_action_reward: Whether to include the prev-action and
            -reward in the `compute_single_action` call.

    Raises:
        ValueError: If anything unexpected happens.
    """
    # Have to import this here to avoid circular dependency.
    from ray.rllib.policy.sample_batch import SampleBatch

    # Some Trainers may not abide by the standard API.
    try:
        pol = trainer.get_policy()
    except AttributeError:
        pol = trainer.policy
    # Get the policy's model.
    model = pol.model

    action_space = pol.action_space

    def _test(what, method_to_test, obs_space, full_fetch, explore, timestep,
              unsquash, clip):
        call_kwargs = {}
        if what is trainer:
            call_kwargs["full_fetch"] = full_fetch

        obs = obs_space.sample()
        if isinstance(obs_space, Box):
            obs = np.clip(obs, -1.0, 1.0)
        state_in = None
        if include_state:
            state_in = model.get_initial_state()
            if not state_in:
                state_in = []
                i = 0
                while f"state_in_{i}" in model.view_requirements:
                    state_in.append(model.view_requirements[f"state_in_{i}"]
                                    .space.sample())
                    i += 1
        action_in = action_space.sample() \
            if include_prev_action_reward else None
        reward_in = 1.0 if include_prev_action_reward else None

        if method_to_test == "input_dict":
            assert what is pol

            input_dict = {SampleBatch.OBS: obs}
            if include_prev_action_reward:
                input_dict[SampleBatch.PREV_ACTIONS] = action_in
                input_dict[SampleBatch.PREV_REWARDS] = reward_in
            if state_in:
                for i, s in enumerate(state_in):
                    input_dict[f"state_in_{i}"] = s
            input_dict_batched = SampleBatch(
                tree.map_structure(lambda s: np.expand_dims(s, 0), input_dict))
            action = pol.compute_actions_from_input_dict(
                input_dict=input_dict_batched,
                explore=explore,
                timestep=timestep,
                **call_kwargs)
            # Unbatch everything to be able to compare against single
            # action below.
            # ARS and ES return action batches as lists.
            if isinstance(action[0], list):
                action = (np.array(action[0]), action[1], action[2])
            action = tree.map_structure(lambda s: s[0], action)

            try:
                action2 = pol.compute_single_action(
                    input_dict=input_dict,
                    explore=explore,
                    timestep=timestep,
                    **call_kwargs)
                # Make sure these are the same, unless we have exploration
                # switched on (or noisy layers).
                if not explore and not pol.config.get("noisy"):
                    check(action, action2)
            except TypeError:
                pass
        else:
            action = what.compute_single_action(
                obs,
                state_in,
                prev_action=action_in,
                prev_reward=reward_in,
                explore=explore,
                timestep=timestep,
                unsquash_action=unsquash,
                clip_action=clip,
                **call_kwargs)

        state_out = None
        if state_in or full_fetch or what is pol:
            action, state_out, _ = action
        if state_out:
            for si, so in zip(state_in, state_out):
                check(list(si.shape), so.shape)

        # Test whether unsquash/clipping works on the Trainer's
        # compute_single_action method: Both flags should force the action
        # to be within the space's bounds.
        if method_to_test == "single" and what == trainer:
            if not action_space.contains(action) and \
                    (clip or unsquash or not isinstance(action_space, Box)):
                raise ValueError(
                    f"Returned action ({action}) of trainer/policy {what} "
                    f"not in Env's action_space {action_space}")
            # We are operating in normalized space: Expect only smaller action
            # values.
            if isinstance(action_space, Box) and not unsquash and \
                    what.config.get("normalize_actions") and \
                    np.any(np.abs(action) > 3.0):
                raise ValueError(
                    f"Returned action ({action}) of trainer/policy {what} "
                    "should be in normalized space, but seems too large/small "
                    "for that!")

    # Loop through: Policy vs Trainer; Different API methods to calculate
    # actions; unsquash option; clip option; full fetch or not.
    for what in [pol, trainer]:
        if what is trainer:
            # Get the obs-space from Workers.env (not Policy) due to possible
            # pre-processor up front.
            worker_set = getattr(trainer, "workers", None)
            # TODO: ES and ARS use `self._workers` instead of `self.workers` to
            #  store their rollout worker set. Change to `self.workers`.
            if worker_set is None:
                worker_set = getattr(trainer, "_workers", None)
            assert worker_set
            if isinstance(worker_set, list):
                obs_space = trainer.get_policy().observation_space
            else:
                obs_space = worker_set.local_worker().for_policy(
                    lambda p: p.observation_space)
            obs_space = getattr(obs_space, "original_space", obs_space)
        else:
            obs_space = pol.observation_space

        for method_to_test in ["single"] + \
                (["input_dict"] if what is pol else []):
            for explore in [True, False]:
                for full_fetch in ([False, True]
                                   if what is trainer else [False]):
                    timestep = random.randint(0, 100000)
                    for unsquash in [True, False]:
                        for clip in ([False] if unsquash else [True, False]):
                            _test(what, method_to_test, obs_space, full_fetch,
                                  explore, timestep, unsquash, clip)
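The nested loops above define a test matrix. The hypothetical helper action_test_matrix below enumerates the same combinations without needing a Trainer, which makes the count easy to verify: 12 calls for the Trainer (one method, two explore values, two full_fetch values, three unsquash/clip pairs) and 12 for the Policy (two methods, two explore values, one full_fetch value, three unsquash/clip pairs). The random timestep drawn per (method, explore, full_fetch) combination is omitted.

```python
import itertools
from typing import List, Tuple


def action_test_matrix(is_trainer: bool) -> List[Tuple[str, bool, bool, bool, bool]]:
    """Hypothetical helper: (method, explore, full_fetch, unsquash, clip) tuples
    in the order the loops in check_compute_single_action visit them."""
    methods = ["single"] + ([] if is_trainer else ["input_dict"])
    full_fetches = [False, True] if is_trainer else [False]
    combos = []
    for method, explore, full_fetch in itertools.product(
            methods, [True, False], full_fetches):
        for unsquash in [True, False]:
            # clip is only varied when unsquash is off, so the pair
            # (unsquash=True, clip=True) is never tested.
            for clip in ([False] if unsquash else [True, False]):
                combos.append((method, explore, full_fetch, unsquash, clip))
    return combos


print(len(action_test_matrix(True)), len(action_test_matrix(False)))
```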

ray.rllib.utils.test_utils.check_train_results(train_results)

Checks that the dict returned by Trainer.train() has the expected structure.

Parameters:

Name Type Description Default
train_results

The train results dict to check.

required

Exceptions:

Type Description
AssertionError

If train_results doesn't have the proper structure or data in it.

Source code in ray/rllib/utils/test_utils.py
def check_train_results(train_results):
    """Checks proper structure of a Trainer.train() returned dict.

    Args:
        train_results: The train results dict to check.

    Raises:
        AssertionError: If `train_results` doesn't have the proper structure or
            data in it.
    """
    # Import these here to avoid circular dependencies.
    from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
    from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \
        LEARNER_STATS_KEY
    from ray.rllib.utils.multi_agent import check_multi_agent

    # Assert that some keys are where we would expect them.
    for key in [
            "agent_timesteps_total",
            "config",
            "custom_metrics",
            "episode_len_mean",
            "episode_reward_max",
            "episode_reward_mean",
            "episode_reward_min",
            "episodes_total",
            "hist_stats",
            "info",
            "iterations_since_restore",
            "num_healthy_workers",
            "perf",
            "policy_reward_max",
            "policy_reward_mean",
            "policy_reward_min",
            "sampler_perf",
            "time_since_restore",
            "time_this_iter_s",
            "timesteps_since_restore",
            "timesteps_total",
            "timers",
            "time_total_s",
            "training_iteration",
    ]:
        assert key in train_results, \
            f"'{key}' not found in `train_results` ({train_results})!"

    _, is_multi_agent = check_multi_agent(train_results["config"])

    # Check in particular the "info" dict.
    info = train_results["info"]
    assert LEARNER_INFO in info, \
        f"'learner' not in train_results['info'] ({info})!"
    assert "num_steps_trained" in info,\
        f"'num_steps_trained' not in train_results['info'] ({info})!"

    learner_info = info[LEARNER_INFO]

    # Make sure we have a default_policy key if we are not in a
    # multi-agent setup.
    if not is_multi_agent:
        # APEX algos sometimes have an empty learner info dict (no metrics
        # collected yet).
        assert len(learner_info) == 0 or DEFAULT_POLICY_ID in learner_info, \
            f"'{DEFAULT_POLICY_ID}' not found in " \
            f"train_results['info']['learner'] ({learner_info})!"

    for pid, policy_stats in learner_info.items():
        if pid == "batch_count":
            continue
        # Expect td-errors to be per batch-item.
        if "td_error" in policy_stats:
            configured_b = train_results["config"]["train_batch_size"]
            actual_b = policy_stats["td_error"].shape[0]
            # R2D2 case.
            if (configured_b - actual_b) / actual_b > 0.1:
                assert configured_b / (
                    train_results["config"]["model"]["max_seq_len"] +
                    train_results["config"]["burn_in"]) == actual_b

        # Make sure each policy has the LEARNER_STATS_KEY under it.
        assert LEARNER_STATS_KEY in policy_stats
        learner_stats = policy_stats[LEARNER_STATS_KEY]
        for key, value in learner_stats.items():
            # Min- and max-stats should be single values.
            if key.startswith("min_") or key.startswith("max_"):
                assert np.isscalar(
                    value), f"'{key}' value not a scalar ({value})!"

    return train_results
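The checks on the "info" sub-dict are the most structural part of check_train_results. The sketch below mirrors just those checks (omitting the td_error batch-size check) under a hypothetical name, check_info_sketch. It assumes the RLlib constants resolve to the strings "learner" (LEARNER_INFO), "learner_stats" (LEARNER_STATS_KEY), and "default_policy" (DEFAULT_POLICY_ID); the sketch hard-codes them instead of importing RLlib.

```python
import numpy as np

# Assumed values of the RLlib constants imported in check_train_results.
LEARNER_INFO = "learner"
LEARNER_STATS_KEY = "learner_stats"
DEFAULT_POLICY_ID = "default_policy"


def check_info_sketch(info: dict, is_multi_agent: bool) -> None:
    """Hypothetical mirror of the `info`-dict assertions above."""
    assert LEARNER_INFO in info, f"'{LEARNER_INFO}' missing from info"
    assert "num_steps_trained" in info, "'num_steps_trained' missing from info"
    learner_info = info[LEARNER_INFO]
    if not is_multi_agent:
        # An empty learner dict is tolerated (no metrics collected yet).
        assert not learner_info or DEFAULT_POLICY_ID in learner_info
    for pid, policy_stats in learner_info.items():
        if pid == "batch_count":
            continue
        assert LEARNER_STATS_KEY in policy_stats
        for key, value in policy_stats[LEARNER_STATS_KEY].items():
            # Min- and max-stats must be single values.
            if key.startswith(("min_", "max_")):
                assert np.isscalar(value), f"'{key}' value not a scalar ({value})!"


# Minimal single-agent "info" dict that passes.
check_info_sketch(
    {"learner": {"default_policy": {"learner_stats": {"max_q": 1.0}}},
     "num_steps_trained": 128},
    is_multi_agent=False)
```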

ray.rllib.utils.test_utils.framework_iterator(config=None, frameworks=('tf2', 'tf', 'tfe', 'torch'), session=False, with_eager_tracing=False, time_iterations=None)

A generator that loops through the given frameworks for testing.

For each framework, it sets config["framework"] and enters the matching eager (tf2, tfe) or non-eager (tf) context.

Parameters:

Name Type Description Default
config Optional[dict]

An optional config dict to alter in place depending on the iteration.

None
frameworks Sequence[str]

A list/tuple of the frameworks to be tested (a single string is also accepted). Allowed values are "tf2", "tf", "tfe", "torch", "jax", and None.

('tf2', 'tf', 'tfe', 'torch')
session bool

If True, yield (fw, sess) tuples instead of plain framework strings. For fw="tf", sess is an entered tf1.Session() and the graph-level random seed is set to 42 to make the test deterministic; for all other frameworks, sess is None.

False
with_eager_tracing bool

If True, run each tf2/tfe iteration twice, with config["eager_tracing"] set to True and then to False.

False
time_iterations Optional[dict]

If provided, record in this dict the time in seconds that each iteration takes, keyed by framework (with a "+tracing" suffix for eager-tracing iterations).

None

Yields:

If session is False: the current framework string (tf2, tf, tfe, or torch). If session is True: a tuple of the current framework string and the tf1.Session (if fw="tf", otherwise None).

Source code in ray/rllib/utils/test_utils.py
def framework_iterator(
        config: Optional[PartialTrainerConfigDict] = None,
        frameworks: Sequence[str] = ("tf2", "tf", "tfe", "torch"),
        session: bool = False,
        with_eager_tracing: bool = False,
        time_iterations: Optional[dict] = None,
) -> Union[str, Tuple[str, Optional["tf1.Session"]]]:
    """A generator that allows for looping through n frameworks for testing.

    Provides the correct config entries ("framework") as well
    as the correct eager/non-eager contexts for tfe/tf.

    Args:
        config: An optional config dict to alter in place depending on the
            iteration.
        frameworks: A list/tuple of the frameworks to be tested.
            Allowed are: "tf2", "tf", "tfe", "torch", and None.
        session: If True and only in the tf-case: Enter a tf.Session()
            and yield that as second return value (otherwise yield (fw, None)).
            Also sets a seed (42) on the session to make the test
            deterministic.
        with_eager_tracing: Include `eager_tracing=True` in the returned
            configs, when framework=[tfe|tf2].
        time_iterations: If provided, will write to the given dict (by
            framework key) the times in seconds that each (framework's)
            iteration takes.

    Yields:
        If `session` is False: The current framework [tf2|tf|tfe|torch] used.
        If `session` is True: A tuple consisting of the current framework
        string and the tf1.Session (if fw="tf", otherwise None).
    """
    config = config or {}
    frameworks = [frameworks] if isinstance(frameworks, str) else \
        list(frameworks)

    # Both tf2 and tfe present -> remove "tfe" or "tf2" depending on version.
    if "tf2" in frameworks and "tfe" in frameworks:
        frameworks.remove("tfe" if tfv == 2 else "tf2")

    for fw in frameworks:
        # Skip non-installed frameworks.
        if fw == "torch" and not torch:
            logger.warning(
                "framework_iterator skipping torch (not installed)!")
            continue
        if fw != "torch" and not tf:
            logger.warning("framework_iterator skipping {} (tf not "
                           "installed)!".format(fw))
            continue
        elif fw == "tfe" and not eager_mode:
            logger.warning("framework_iterator skipping tf-eager (could not "
                           "import `eager_mode` from tensorflow.python)!")
            continue
        elif fw == "tf2" and tfv != 2:
            logger.warning(
                "framework_iterator skipping tf2.x (tf version is < 2.0)!")
            continue
        elif fw == "jax" and not jax:
            logger.warning("framework_iterator skipping JAX (not installed)!")
            continue
        assert fw in ["tf2", "tf", "tfe", "torch", "jax", None]

        # Do we need a test session?
        sess = None
        if fw == "tf" and session is True:
            sess = tf1.Session()
            sess.__enter__()
            tf1.set_random_seed(42)

        config["framework"] = fw

        eager_ctx = None
        # Enable eager mode for tf2 and tfe.
        if fw in ["tf2", "tfe"]:
            eager_ctx = eager_mode()
            eager_ctx.__enter__()
            assert tf1.executing_eagerly()
        # Make sure, eager mode is off.
        elif fw == "tf":
            assert not tf1.executing_eagerly()

        # Additionally loop through eager_tracing=True + False, if necessary.
        if fw in ["tf2", "tfe"] and with_eager_tracing:
            for tracing in [True, False]:
                config["eager_tracing"] = tracing
                print(f"framework={fw} (eager-tracing={tracing})")
                time_started = time.time()
                yield fw if session is False else (fw, sess)
                if time_iterations is not None:
                    time_total = time.time() - time_started
                    time_iterations[fw + ("+tracing" if tracing else "")] = \
                        time_total
                    print(f".. took {time_total}sec")
                config["eager_tracing"] = False
        # Yield current framework + tf-session (if necessary).
        else:
            print(f"framework={fw}")
            time_started = time.time()
            yield fw if session is False else (fw, sess)
            if time_iterations is not None:
                time_total = time.time() - time_started
                time_iterations[fw] = time_total
                print(f".. took {time_total}sec")

        # Exit any context we may have entered.
        if eager_ctx:
            eager_ctx.__exit__(None, None, None)
        elif sess:
            sess.__exit__(None, None, None)
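Because framework_iterator is a generator, the code after each yield (timing, context exit) runs only when the test loop asks for the next framework. The framework-free sketch below, under the hypothetical name framework_loop, models only the in-place config mutation, the optional eager-tracing double pass, and the time_iterations bookkeeping. It imports neither TensorFlow nor PyTorch and enters no eager context or session.

```python
import time
from typing import Dict, Iterator, Optional, Sequence, Union


def framework_loop(
        config: dict,
        frameworks: Union[str, Sequence[str]] = ("tf", "torch"),
        with_eager_tracing: bool = False,
        time_iterations: Optional[Dict[str, float]] = None,
) -> Iterator[str]:
    """Hypothetical sketch of framework_iterator's looping/timing logic."""
    frameworks = [frameworks] if isinstance(frameworks, str) \
        else list(frameworks)
    for fw in frameworks:
        config["framework"] = fw  # The caller's dict is altered in place.
        tracings = [True, False] if (
            with_eager_tracing and fw in ("tf2", "tfe")) else [None]
        for tracing in tracings:
            if tracing is not None:
                config["eager_tracing"] = tracing
            start = time.time()
            yield fw  # The test body runs here.
            if time_iterations is not None:
                key = fw + ("+tracing" if tracing else "")
                time_iterations[key] = time.time() - start
            if tracing is not None:
                config["eager_tracing"] = False


config, times = {}, {}
for fw in framework_loop(config, ("tf2", "torch"), with_eager_tracing=True,
                         time_iterations=times):
    print(fw, config)
print(sorted(times))
```

The last framework's time is recorded only because the for loop resumes the generator once more before it finishes; breaking out of the loop early would skip that bookkeeping (and, in the real function, the context exit).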

ray.util.ml_utils.dict.merge_dicts(d1, d2)

Returns a deep merge of d1 and d2, in which values from d2 take precedence.

Parameters:

Name Type Description Default
d1 dict

The base dict. It is deep-copied first, so it is not modified.

required
d2 dict

The dict whose values override those of d1. New keys are allowed at every nesting level.

required

Returns:

Type Description
dict

A new dict that is d1 and d2 deep merged.

Source code in ray/util/ml_utils/dict.py
def merge_dicts(d1: dict, d2: dict) -> dict:
    """
    Args:
        d1 (dict): Dict 1.
        d2 (dict): Dict 2.

    Returns:
         dict: A new dict that is d1 and d2 deep merged.
    """
    merged = copy.deepcopy(d1)
    deep_update(merged, d2, True, [])
    return merged
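merge_dicts deep-copies d1 and then calls deep_update with new_keys_allowed=True, so nested dicts are merged key by key and any non-dict value from d2 replaces the one in d1. The dependency-free sketch below (merge_dicts_sketch and _merge_into are hypothetical names) reproduces that behavior, including one side effect visible in the source: values taken from d2 are assigned, not copied, so the result shares them with d2.

```python
import copy


def _merge_into(dst: dict, src: dict) -> None:
    # Recurse only when both sides hold a dict; otherwise src wins.
    for k, v in src.items():
        if isinstance(dst.get(k), dict) and isinstance(v, dict):
            _merge_into(dst[k], v)
        else:
            dst[k] = v


def merge_dicts_sketch(d1: dict, d2: dict) -> dict:
    merged = copy.deepcopy(d1)  # d1 itself stays untouched.
    _merge_into(merged, d2)
    return merged


d1 = {"lr": 0.1, "model": {"fcnet_hiddens": [256, 256], "free_log_std": False}}
d2 = {"model": {"fcnet_hiddens": [64]}, "gamma": 0.9}
print(merge_dicts_sketch(d1, d2))
```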

ray.util.ml_utils.dict.deep_update(original, new_dict, new_keys_allowed=False, allow_new_subkey_list=None, override_all_if_type_changes=None)

Updates original dict with values from new_dict recursively.

If new_dict introduces a key that is not in original and new_keys_allowed is False, an Exception is raised. For sub-dicts whose top-level key is in allow_new_subkey_list, new subkeys may be introduced.

Parameters:

Name Type Description Default
original dict

Dictionary with default values.

required
new_dict dict

Dictionary with values to be updated.

required
new_keys_allowed bool

Whether new keys are allowed.

False
allow_new_subkey_list Optional[List[str]]

List of keys that correspond to dict values where new subkeys can be introduced. This is only at the top level.

None
override_all_if_type_changes Optional[List[str]]

List of top-level keys with dict values for which the entire value (dict) is simply overridden if the "type" key in that value dict changes.

None
Source code in ray/util/ml_utils/dict.py
def deep_update(
        original: dict,
        new_dict: dict,
        new_keys_allowed: bool = False,
        allow_new_subkey_list: Optional[List[str]] = None,
        override_all_if_type_changes: Optional[List[str]] = None) -> dict:
    """Updates original dict with values from new_dict recursively.

    If new key is introduced in new_dict, then if new_keys_allowed is not
    True, an error will be thrown. Further, for sub-dicts, if the key is
    in the allow_new_subkey_list, then new subkeys can be introduced.

    Args:
        original (dict): Dictionary with default values.
        new_dict (dict): Dictionary with values to be updated
        new_keys_allowed (bool): Whether new keys are allowed.
        allow_new_subkey_list (Optional[List[str]]): List of keys that
            correspond to dict values where new subkeys can be introduced.
            This is only at the top level.
        override_all_if_type_changes(Optional[List[str]]): List of top level
            keys with value=dict, for which we always simply override the
            entire value (dict), iff the "type" key in that value dict changes.
    """
    allow_new_subkey_list = allow_new_subkey_list or []
    override_all_if_type_changes = override_all_if_type_changes or []

    for k, value in new_dict.items():
        if k not in original and not new_keys_allowed:
            raise Exception("Unknown config parameter `{}` ".format(k))

        # Both original value and new one are dicts.
        if isinstance(original.get(k), dict) and isinstance(value, dict):
            # Check new type vs old one. If different, override entire value.
            if k in override_all_if_type_changes and \
                "type" in value and "type" in original[k] and \
                    value["type"] != original[k]["type"]:
                original[k] = value
            # Allowed key -> ok to add new subkeys.
            elif k in allow_new_subkey_list:
                deep_update(original[k], value, True)
            # Non-allowed key.
            else:
                deep_update(original[k], value, new_keys_allowed)
        # Original value not a dict OR new value not a dict:
        # Override entire value.
        else:
            original[k] = value
    return original
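The three switches of deep_update are easiest to compare side by side. The sketch below restates the function above under the hypothetical name deep_update_sketch so it runs without Ray, then exercises it with illustrative config keys (the key names are examples only). Note that a failed update can leave original partially modified, because keys are applied one at a time before the unknown key is reached.

```python
from typing import List, Optional


def deep_update_sketch(original: dict, new_dict: dict,
                       new_keys_allowed: bool = False,
                       allow_new_subkey_list: Optional[List[str]] = None,
                       override_all_if_type_changes: Optional[List[str]] = None
                       ) -> dict:
    allow_new_subkey_list = allow_new_subkey_list or []
    override_all_if_type_changes = override_all_if_type_changes or []
    for k, value in new_dict.items():
        if k not in original and not new_keys_allowed:
            raise Exception(f"Unknown config parameter `{k}` ")
        old = original.get(k)
        if isinstance(old, dict) and isinstance(value, dict):
            if (k in override_all_if_type_changes and "type" in value
                    and "type" in old and value["type"] != old["type"]):
                original[k] = value  # "type" changed: replace wholesale.
            elif k in allow_new_subkey_list:
                deep_update_sketch(old, value, True)  # New subkeys allowed.
            else:
                deep_update_sketch(old, value, new_keys_allowed)
        else:
            original[k] = value  # Non-dict on either side: override.
    return original


# 1) Unknown top-level keys are rejected by default.
defaults = {"lr": 0.1, "model": {"dim": 84}}
try:
    deep_update_sketch(defaults, {"lr_typo": 1})
except Exception as e:
    print("rejected:", e)

# 2) New subkeys are accepted only under allow_new_subkey_list.
deep_update_sketch(defaults, {"model": {"custom": 1}},
                   allow_new_subkey_list=["model"])

# 3) A changed "type" replaces the whole sub-dict.
cfg = {"exploration_config": {"type": "EpsilonGreedy", "initial_epsilon": 1.0}}
deep_update_sketch(cfg, {"exploration_config": {"type": "SoftQ",
                                                "temperature": 1.0}},
                   override_all_if_type_changes=["exploration_config"])
print(defaults, cfg)
```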

ray.rllib.utils.add_mixins(base, mixins, reversed=False)

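Returns a new class with mixins applied in priority order. With reversed=False (the default), the mixins take precedence over base in the method resolution order (MRO), and earlier mixins in the list take precedence over later ones. With reversed=True, base takes precedence, followed by the mixins in reverse list order.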

Source code in ray/rllib/utils/__init__.py
def add_mixins(base, mixins, reversed=False):
    """Returns a new class with mixins applied in priority order."""

    mixins = list(mixins or [])

    while mixins:
        if reversed:

            class new_base(base, mixins.pop()):
                pass

        else:

            class new_base(mixins.pop(), base):
                pass

        base = new_base

    return base
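The resulting priority order is clearest from the MRO. The sketch below copies add_mixins under the hypothetical name add_mixins_sketch (with the parameter renamed to reversed_order so it does not shadow the builtin reversed) and applies it to three toy classes that all define who().

```python
def add_mixins_sketch(base, mixins, reversed_order=False):
    """Copy of add_mixins above, runnable without Ray."""
    mixins = list(mixins or [])
    while mixins:
        if reversed_order:
            class new_base(base, mixins.pop()):
                pass
        else:
            class new_base(mixins.pop(), base):
                pass
        base = new_base
    return base


class Base:
    def who(self):
        return "Base"


class A:
    def who(self):
        return "A"


class B:
    def who(self):
        return "B"


def mro_names(cls):
    # Hide the generated intermediate classes, which are all named "new_base".
    return [c.__name__ for c in cls.__mro__ if c.__name__ != "new_base"]


Default = add_mixins_sketch(Base, [A, B])
Reversed = add_mixins_sketch(Base, [A, B], reversed_order=True)
print(mro_names(Default), Default().who())
print(mro_names(Reversed), Reversed().who())
```

Because mixins are popped from the end of the list, each newly created class wraps the previous one, which is why the first mixin ends up outermost in the default mode.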

ray.rllib.utils.force_list(elements=None, to_tuple=False)

Makes sure elements is returned as a list, whether elements is a single item, already a list, or a tuple.

Parameters:

Name Type Description Default
elements Optional[any]

The inputs as single item, list, or tuple to be converted into a list/tuple. If None, returns empty list/tuple.

None
to_tuple bool

Whether to use tuple (instead of list).

False

Returns:

Type Description
Union[list,tuple]

All given elements in a list/tuple depending on to_tuple's value. If elements is None, returns an empty list/tuple.

Source code in ray/rllib/utils/__init__.py
def force_list(elements=None, to_tuple=False):
    """
    Makes sure `elements` is returned as a list, whether `elements` is a single
    item, already a list, or a tuple.

    Args:
        elements (Optional[any]): The inputs as single item, list, or tuple to
            be converted into a list/tuple. If None, returns empty list/tuple.
        to_tuple (bool): Whether to use tuple (instead of list).

    Returns:
        Union[list,tuple]: All given elements in a list/tuple depending on
            `to_tuple`'s value. If elements is None,
            returns an empty list/tuple.
    """
    ctor = list
    if to_tuple is True:
        ctor = tuple
    return ctor() if elements is None else ctor(elements) \
        if type(elements) in [list, tuple] else ctor([elements])
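The chained conditional in force_list is compact but easy to misread. The sketch below restates it as an explicit if/else chain under the hypothetical name force_list_sketch. The force_tuple_sketch line assumes force_tuple is simply force_list with to_tuple=True, built here with functools.partial.

```python
from functools import partial


def force_list_sketch(elements=None, to_tuple=False):
    """Restatement of force_list above as an explicit if/else chain."""
    ctor = tuple if to_tuple is True else list
    if elements is None:
        return ctor()
    # Only exact list/tuple types are unpacked; any other value (including
    # strings, dicts, and sets) is wrapped as a single element.
    if type(elements) in [list, tuple]:
        return ctor(elements)
    return ctor([elements])


# Assumption: force_tuple behaves like force_list with to_tuple=True.
force_tuple_sketch = partial(force_list_sketch, to_tuple=True)

print(force_list_sketch(3), force_list_sketch("abc"), force_tuple_sketch([1, 2]))
```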

ray.rllib.utils.force_tuple

Same as force_list, but returns a tuple (force_list with to_tuple=True).
