Source code for ray.rllib.utils.replay_buffers.base
from abc import ABCMeta, abstractmethod
import platform
from typing import Any, Dict, Optional
from ray.util.annotations import DeveloperAPI
@DeveloperAPI
class ReplayBufferInterface(metaclass=ABCMeta):
    """Abstract base class for all of RLlib's replay buffers.

    Mainly defines the `add()` and `sample()` methods that every buffer class
    must implement to be usable by an Algorithm.

    Buffers are free to decide all implementation details themselves, e.g.
    whether to store single timesteps, episodes, or episode fragments, or
    whether to return fixed batch sizes or per-call defined ones.
    """

    @abstractmethod
    @DeveloperAPI
    def __len__(self) -> int:
        """Returns the number of items currently stored in this buffer.

        What counts as one "item" (a single timestep, an episode, an episode
        fragment, ...) is up to the concrete buffer implementation.
        """

    @abstractmethod
    @DeveloperAPI
    def add(self, batch: Any, **kwargs) -> None:
        """Adds a batch of experiences or other data to this buffer.

        Args:
            batch: Batch or data to add.
            ``**kwargs``: Forward compatibility kwargs.
        """

    @abstractmethod
    @DeveloperAPI
    def sample(self, num_items: Optional[int] = None, **kwargs) -> Any:
        """Samples `num_items` items from this buffer.

        The exact shape of the returned data depends on the buffer's
        implementation.

        Args:
            num_items: Number of items to sample from this buffer. If None
                (the default), how many items are returned is left to the
                concrete implementation (e.g. a fixed, pre-configured batch
                size).
            ``**kwargs``: Forward compatibility kwargs.

        Returns:
            A batch of items.
        """

    @abstractmethod
    @DeveloperAPI
    def get_state(self) -> Dict[str, Any]:
        """Returns all local state in a dict.

        Returns:
            The serializable local state.
        """

    @abstractmethod
    @DeveloperAPI
    def set_state(self, state: Dict[str, Any]) -> None:
        """Restores all local state to the provided `state`.

        Args:
            state: The new state to set this buffer to. Can be obtained by
                calling `self.get_state()`.
        """

    @DeveloperAPI
    def get_host(self) -> str:
        """Returns the computer's network name.

        Returns:
            The computer's network name, or an empty string if the network
            name could not be determined (this is the behavior of
            `platform.node()`).
        """
        return platform.node()