stable_baselines3/common/buffers.py
0 removals
622 lines
0 additions
622 lines
import warnings
import warnings
from abc import ABC, abstractmethod
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
import numpy as np
import numpy as np
import torch as th
import torch as th
from gymnasium import spaces
from gymnasium import spaces
from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
from stable_baselines3.common.type_aliases import (
from stable_baselines3.common.type_aliases import (
DictReplayBufferSamples,
DictReplayBufferSamples,
DictRolloutBufferSamples,
DictRolloutBufferSamples,
ReplayBufferSamples,
ReplayBufferSamples,
RolloutBufferSamples,
RolloutBufferSamples,
)
)
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3.common.vec_env import VecNormalize
try:
try:
# Check memory used by replay buffer when possible
# Check memory used by replay buffer when possible
import psutil
import psutil
except ImportError:
except ImportError:
psutil = None
psutil = None
class BaseBuffer(ABC):
    """
    Base class that represents a buffer (rollout or replay).

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device to which the values will be converted
    :param n_envs: Number of parallel environments
    """

    observation_space: spaces.Space
    obs_shape: Tuple[int, ...]

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "auto",
        n_envs: int = 1,
    ):
        super().__init__()
        self.buffer_size = buffer_size
        self.observation_space = observation_space
        self.action_space = action_space
        self.obs_shape = get_obs_shape(observation_space)  # type: ignore[assignment]

        self.action_dim = get_action_dim(action_space)
        # Write cursor and "wrapped around at least once" flag
        self.pos = 0
        self.full = False
        self.device = get_device(device)
        self.n_envs = n_envs

    @staticmethod
    def swap_and_flatten(arr: np.ndarray) -> np.ndarray:
        """
        Swap and then flatten axes 0 (buffer_size) and 1 (n_envs)
        to convert shape from [n_steps, n_envs, ...] (when ... is the shape of the features)
        to [n_steps * n_envs, ...] (which maintains the order).

        :param arr: Array with at least two dimensions
        :return: The array with its first two axes swapped and merged;
            a 2D input gets a trailing axis of size 1.
        """
        shape = arr.shape
        if len(shape) < 3:
            # Keep a feature axis so that 2D inputs come out as (n_steps * n_envs, 1)
            shape = (*shape, 1)
        return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:])

    def size(self) -> int:
        """
        :return: The current size of the buffer
        """
        if self.full:
            return self.buffer_size
        return self.pos

    def add(self, *args, **kwargs) -> None:
        """
        Add elements to the buffer.

        Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def extend(self, *args, **kwargs) -> None:
        """
        Add a new batch of transitions to the buffer.

        The positional arguments are iterated in lockstep and each tuple
        of items is passed to :meth:`add`.
        Note: keyword arguments are accepted but are not forwarded to :meth:`add`.
        """
        # Do a for loop along the batch axis
        for data in zip(*args):
            self.add(*data)

    def reset(self) -> None:
        """
        Reset the buffer.
        """
        self.pos = 0
        self.full = False

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None):
        """
        Sample ``batch_size`` indices uniformly (with replacement) among the
        filled positions and return the corresponding samples.

        :param batch_size: Number of elements to sample
        :param env: associated gym VecEnv
            to normalize the observations/rewards when sampling
        :return: The samples built by ``_get_samples``
        :raises ValueError: If the buffer does not contain any element yet
        """
        upper_bound = self.buffer_size if self.full else self.pos
        if upper_bound <= 0:
            # numpy would raise a cryptic "low >= high" ValueError here
            raise ValueError("Cannot sample from an empty buffer: add at least one element first.")
        batch_inds = np.random.randint(0, upper_bound, size=batch_size)
        return self._get_samples(batch_inds, env=env)

    @abstractmethod
    def _get_samples(
        self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
    ) -> Union[ReplayBufferSamples, RolloutBufferSamples]:
        """
        Build the samples for the given buffer indices.

        :param batch_inds: Indices of the elements to retrieve
        :param env: Optional environment wrapper used for normalization
        :return: The samples for the requested indices
        """
        raise NotImplementedError()

    def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
        """
        Convert a numpy array to a PyTorch tensor.
        Note: it copies the data by default

        :param array: Array to convert
        :param copy: Whether to copy or not the data (may be useful to avoid changing things
            by reference). This argument is inoperative if the device is not the CPU.
        :return: A tensor placed on ``self.device``
        """
        if copy:
            return th.tensor(array, device=self.device)
        return th.as_tensor(array, device=self.device)

    @staticmethod
    def _normalize_obs(
        obs: Union[np.ndarray, Dict[str, np.ndarray]],
        env: Optional[VecNormalize] = None,
    ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        """
        Normalize the observation with ``env.normalize_obs`` when an env is given,
        otherwise return it unchanged.
        """
        if env is not None:
            return env.normalize_obs(obs)
        return obs

    @staticmethod
    def _normalize_reward(reward: np.ndarray, env: Optional[VecNormalize] = None) -> np.ndarray:
        """
        Normalize the reward with ``env.normalize_reward`` (cast to float32) when an env
        is given, otherwise return it unchanged.
        """
        if env is not None:
            return env.normalize_reward(reward).astype(np.float32)
        return reward
class ReplayBuffer(BaseBuffer):
    """
    Replay buffer used in off-policy algorithms like SAC/TD3.

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param n_envs: Number of parallel environments
    :param optimize_memory_usage: Enable a memory efficient variant
        of the replay buffer which reduces by almost a factor two the memory used,
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
        and https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274
        Cannot be used in combination with handle_timeout_termination.
    :param handle_timeout_termination: Handle timeout termination (due to timelimit)
        separately and treat the task as infinite horizon task.
        https://github.com/DLR-RM/stable-baselines3/issues/284
    """

    observations: np.ndarray
    next_observations: np.ndarray
    actions: np.ndarray
    rewards: np.ndarray
    dones: np.ndarray
    timeouts: np.ndarray

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "auto",
        n_envs: int = 1,
        optimize_memory_usage: bool = False,
        handle_timeout_termination: bool = True,
    ):
        super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        # Adjust buffer size: each slot stores one transition per environment
        self.buffer_size = max(buffer_size // n_envs, 1)

        # Check that the replay buffer can fit into the memory
        if psutil is not None:
            mem_available = psutil.virtual_memory().available

        # there is a bug if both optimize_memory_usage and handle_timeout_termination are true
        # see https://github.com/DLR-RM/stable-baselines3/issues/934
        if optimize_memory_usage and handle_timeout_termination:
            raise ValueError(
                "ReplayBuffer does not support optimize_memory_usage = True "
                "and handle_timeout_termination = True simultaneously."
            )
        self.optimize_memory_usage = optimize_memory_usage

        self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=observation_space.dtype)

        if not optimize_memory_usage:
            # When optimizing memory, `observations` contains also the next observation
            self.next_observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=observation_space.dtype)

        self.actions = np.zeros(
            (self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(action_space.dtype)
        )

        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        # Handle timeouts termination properly if needed
        # see https://github.com/DLR-RM/stable-baselines3/issues/284
        self.handle_timeout_termination = handle_timeout_termination
        self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

        if psutil is not None:
            # `timeouts` is always allocated, so it is part of the footprint too
            total_memory_usage: float = (
                self.observations.nbytes
                + self.actions.nbytes
                + self.rewards.nbytes
                + self.dones.nbytes
                + self.timeouts.nbytes
            )

            if not optimize_memory_usage:
                total_memory_usage += self.next_observations.nbytes

            if total_memory_usage > mem_available:
                # Convert to GB
                total_memory_usage /= 1e9
                mem_available /= 1e9
                warnings.warn(
                    "This system does not have apparently enough memory to store the complete "
                    f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
                )

    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: List[Dict[str, Any]],
    ) -> None:
        """
        Store one transition per environment at the current write position.

        :param obs: Observations before the step
        :param next_obs: Observations after the step
        :param action: Actions taken
        :param reward: Rewards received
        :param done: Episode termination flags
        :param infos: One info dict per environment; the ``"TimeLimit.truncated"``
            key is read when ``handle_timeout_termination`` is enabled.
        """
        # Reshape needed when using multiple envs with discrete observations
        # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
        if isinstance(self.observation_space, spaces.Discrete):
            obs = obs.reshape((self.n_envs, *self.obs_shape))
            next_obs = next_obs.reshape((self.n_envs, *self.obs_shape))

        # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
        action = action.reshape((self.n_envs, self.action_dim))

        # Copy to avoid modification by reference
        self.observations[self.pos] = np.array(obs)

        if self.optimize_memory_usage:
            # The next observation lives in the following slot of the same array
            self.observations[(self.pos + 1) % self.buffer_size] = np.array(next_obs)
        else:
            self.next_observations[self.pos] = np.array(next_obs)

        self.actions[self.pos] = np.array(action)
        self.rewards[self.pos] = np.array(reward)
        self.dones[self.pos] = np.array(done)

        if self.handle_timeout_termination:
            self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])

        # Advance the write cursor and wrap around once the end is reached
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True
            self.pos = 0

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        """
        Sample elements from the replay buffer.
        Custom sampling when using memory efficient variant,
        as we should not sample the element with index `self.pos`
        See https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274

        :param batch_size: Number of elements to sample
        :param env: associated gym VecEnv
            to normalize the observations/rewards when sampling
        :return: The sampled transitions
        :raises ValueError: If there is no valid transition to sample from
        """
        if not self.optimize_memory_usage:
            return super().sample(batch_size=batch_size, env=env)
        # Do not sample the element with index `self.pos` as the transitions is invalid
        # (we use only one array to store `obs` and `next_obs`)
        if self.full:
            if self.buffer_size < 2:
                # The only slot is `self.pos`, which is never a valid transition
                raise ValueError("Cannot sample from a memory-optimized replay buffer with fewer than two slots.")
            batch_inds = (np.random.randint(1, self.buffer_size, size=batch_size) + self.pos) % self.buffer_size
        else:
            if self.pos == 0:
                raise ValueError("Cannot sample from an empty buffer: add at least one transition first.")
            batch_inds = np.random.randint(0, self.pos, size=batch_size)
        return self._get_samples(batch_inds, env=env)

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        """
        Gather the transitions stored at ``batch_inds``, picking one random
        environment per index, and convert them to tensors.

        :param batch_inds: Indices along the buffer (time) axis
        :param env: Optional environment wrapper used to normalize observations and rewards
        :return: The sampled transitions
        """
        # Sample randomly the env idx
        env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))

        if self.optimize_memory_usage:
            next_obs = self._normalize_obs(self.observations[(batch_inds + 1) % self.buffer_size, env_indices, :], env)
        else:
            next_obs = self._normalize_obs(self.next_observations[batch_inds, env_indices, :], env)

        data = (
            self._normalize_obs(self.observations[batch_inds, env_indices, :], env),
            self.actions[batch_inds, env_indices, :],
            next_obs,
            # Only use dones that are not due to timeouts
            # deactivated by default (timeouts is initialized as an array of zeros)
            (self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape(-1, 1),
            self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env),
        )
        return ReplayBufferSamples(*tuple(map(self.to_torch, data)))

    @staticmethod
    def _maybe_cast_dtype(dtype: np.typing.DTypeLike) -> np.typing.DTypeLike:
        """
        Cast `np.float64` action datatype to `np.float32`,
        keep the others dtype unchanged.
        See GH#1572 for more information.

        :param dtype: The original action space dtype
        :return: ``np.float32`` if the dtype was float64,
            the original dtype otherwise.
        """
        if dtype == np.float64:
            return np.float32
        return dtype
class RolloutBuffer(BaseBuffer):
class RolloutBuffer(BaseBuffer):
"""
"""
Rollout buffer used in on-policy algorithms like A2C/PPO.
Rollout buffer used in on-policy algorithms like A2C/PPO.
It corresponds to ``buffer_size`` transitions collected
It corresponds to ``buffer_size`` transitions collected
using the current policy.
using the current policy.
This experience will be discarded after the policy update.
This experience will be discarded after the policy update.
In order to use PPO objective, we also store the current value of each state
In order to use PPO objective, we also store the current value of each state
and the log probability of each taken action.
and the log probability of each taken action.
The term rollout here refers to the model-free notion and should not
The term rollout here refers to the model-free notion and should not
be used with the concept of rollout used in model-based RL or planning.
be used with the concept of rollout used in model-based RL or planning.
Hence, it is only involved in policy and value function training but not action selection.
Hence, it is only involved in policy and value function training but not action selection.
:param buffer_size: Max number of element in the buffer
:param buffer_size: Max number of element in the buffer
:param observation_space: Observation space
:param observation_space: Observation space
:param action_space: Action space
:param action_space: Action space
:param device: PyTorch device
:param device: PyTorch device
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
Equivalent to classic advantage when set to 1.
Equivalent to classic advantage when set to 1.
:param gamma: Discount factor
:param gamma: Discount factor
:param n_envs: Number of parallel environments
:param n_envs: Number of parallel environments
"""
"""
observations: np.ndarray
observations: np.ndarray
actions: np.ndarray
actions: np.ndarray
rewards: np.ndarray
rewards: np.ndarray
advantages: np.ndarray
advantages: np.ndarray
returns: np.ndarray
returns: np.ndarray
episode_starts: np.ndarray
episode_starts: np.ndarray
log_probs: np.ndarray
log_probs: np.ndarray
values: np.ndarray
values: np.ndarray
def __init__(
    self,
    buffer_size: int,
    observation_space: spaces.Space,
    action_space: spaces.Space,
    device: Union[th.device, str] = "auto",
    gae_lambda: float = 1,
    gamma: float = 0.99,
    n_envs: int = 1,
):
    """
    Set up the common buffer state, remember the advantage-estimation
    hyper-parameters and allocate the storage through ``reset()``.

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """
    super().__init__(
        buffer_size,
        observation_space,
        action_space,
        device,
        n_envs=n_envs,
    )
    # Hyper-parameters read by compute_returns_and_advantage()
    self.gamma = gamma
    self.gae_lambda = gae_lambda
    # The storage arrays have not been flattened for get() yet
    self.generator_ready = False
    # Allocate the storage arrays
    self.reset()
def reset(self) -> None:
    """
    Re-allocate every storage array filled with zeros (float32), mark the
    data as not yet flattened and reset the write position.
    """
    step_env_shape = (self.buffer_size, self.n_envs)

    # Arrays carrying extra feature dimensions
    self.observations = np.zeros(step_env_shape + tuple(self.obs_shape), dtype=np.float32)
    self.actions = np.zeros(step_env_shape + (self.action_dim,), dtype=np.float32)

    # Arrays holding one scalar per (step, env)
    for scalar_field in ("rewards", "returns", "episode_starts", "values", "log_probs", "advantages"):
        setattr(self, scalar_field, np.zeros(step_env_shape, dtype=np.float32))

    self.generator_ready = False
    super().reset()
def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
    """
    Post-processing step: compute the lambda-return (TD(lambda) estimate)
    and GAE(lambda) advantage.

    Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
    to compute the advantage. To obtain Monte-Carlo advantage estimate (A(s) = R - V(S))
    where R is the sum of discounted reward with value bootstrap
    (because we don't always have full episode), set ``gae_lambda=1.0`` during initialization.

    The TD(lambda) estimator has also two special cases:
    - TD(1) is Monte-Carlo estimate (sum of discounted rewards)
    - TD(0) is one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1}))

    For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375.

    :param last_values: state value estimation for the last step (one for each env)
    :param dones: if the last step was a terminal step (one bool for each env).
    """
    # For the final stored step, the "next" quantities come from the caller
    # (tensor converted to a flat numpy array)
    next_values = last_values.clone().cpu().numpy().flatten()
    next_non_terminal = 1.0 - dones

    running_gae = 0
    for step in range(self.buffer_size - 1, -1, -1):
        td_error = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
        running_gae = td_error + self.gamma * self.gae_lambda * next_non_terminal * running_gae
        self.advantages[step] = running_gae
        # The step just processed becomes the "next" step of the previous one
        next_non_terminal = 1.0 - self.episode_starts[step]
        next_values = self.values[step]

    # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
    # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
    self.returns = self.advantages + self.values
def add(
    self,
    obs: np.ndarray,
    action: np.ndarray,
    reward: np.ndarray,
    episode_start: np.ndarray,
    value: th.Tensor,
    log_prob: th.Tensor,
) -> None:
    """
    Store one step of experience at the current write position.

    :param obs: Observation
    :param action: Action
    :param reward: Reward
    :param episode_start: Start of episode signal.
    :param value: estimated value of the current state
        following the current policy.
    :param log_prob: log probability of the action
        following the current policy.
    """
    # A 0-d tensor is reshaped to avoid an error when storing it
    if log_prob.ndim == 0:
        log_prob = log_prob.reshape(-1, 1)

    # With multiple envs and discrete observations numpy cannot broadcast
    # (n_discrete,) to (n_discrete, 1), hence the explicit reshape
    if isinstance(self.observation_space, spaces.Discrete):
        obs = obs.reshape((self.n_envs, *self.obs_shape))

    # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
    action = action.reshape((self.n_envs, self.action_dim))

    slot = self.pos
    self.observations[slot] = np.array(obs)
    self.actions[slot] = np.array(action)
    self.rewards[slot] = np.array(reward)
    self.episode_starts[slot] = np.array(episode_start)
    # Tensors are cloned and moved to the CPU before being stored as numpy data
    self.values[slot] = value.clone().cpu().numpy().flatten()
    self.log_probs[slot] = log_prob.clone().cpu().numpy()

    # Move the cursor; unlike a replay buffer there is no wrap-around here
    self.pos = slot + 1
    if self.pos == self.buffer_size:
        self.full = True
def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]:
    """
    Yield the stored rollout data in randomly shuffled minibatches.

    On the first call after a reset, the storage arrays are flattened in place
    from [n_steps, n_envs, ...] to [n_steps * n_envs, ...].

    :param batch_size: Number of elements per minibatch (the last one may be smaller).
        If ``None``, a single batch containing everything is yielded.
    :return: A generator over the minibatches built by ``_get_samples``
    :raises AssertionError: If the buffer is not full
    :raises ValueError: If ``batch_size`` is not strictly positive
    """
    assert self.full, "The rollout buffer must be full before its data can be retrieved"
    if batch_size is not None and batch_size <= 0:
        # A non-positive step would make the loop below never terminate
        raise ValueError(f"batch_size must be a strictly positive integer, got {batch_size}")

    n_samples = self.buffer_size * self.n_envs
    indices = np.random.permutation(n_samples)
    # Prepare the data
    if not self.generator_ready:
        _tensor_names = [
            "observations",
            "actions",
            "values",
            "log_probs",
            "advantages",
            "returns",
        ]

        for tensor in _tensor_names:
            self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
        # Flatten only once: later calls reuse the flattened arrays
        self.generator_ready = True

    # Return everything, don't create minibatches
    if batch_size is None:
        batch_size = n_samples

    start_idx = 0
    while start_idx < n_samples:
        yield self._get_samples(indices[start_idx : start_idx + batch_size])
        start_idx += batch_size
def _get_samples(
    self,
    batch_inds: np.ndarray,
    env: Optional[VecNormalize] = None,
) -> RolloutBufferSamples:
    """
    Build one minibatch from the stored arrays.

    Observations and actions are indexed as-is; values, log-probabilities,
    advantages and returns are additionally flattened to 1-d.
    Every resulting array is then passed through ``self.to_torch``.

    :param batch_inds: Indices of the samples to gather.
    :param env: Not used by this method.
    :return: The gathered fields, in the order observations, actions, values,
        log_probs, advantages, returns.
    """
    gathered = [self.observations[batch_inds], self.actions[batch_inds]]
    for stored in (self.values, self.log_probs, self.advantages, self.returns):
        gathered.append(stored[batch_inds].flatten())

    converted = [self.to_torch(array) for array in gathered]
    return RolloutBufferSamples(*converted)
class DictReplayBuffer(ReplayBuffer):
class DictReplayBuffer(ReplayBuffer):
"""
"""
Dict Replay buffer used in off-policy algorithms like SAC/TD3.
Dict Replay buffer used in off-policy algorithms like SAC/TD3.
Extends the ReplayBuffer to use dictionary observations
Extends the ReplayBuffer to use dictionary observations
:param buffer_size: Max number of element in the buffer
:param buffer_size: Max number of element in the buffer
:param observation_space: Observation space
:param observation_space: Observation space
:param action_space: Action space
:param action_space: Action space
:param device: PyTorch device
:param device: PyTorch device
:param n_envs: Number of parallel environments
:param n_envs: Number of parallel environments
:param optimize_memory_usage: Enable a memory efficient variant
:param optimize_memory_usage: Enable a memory efficient variant
Disabled for now (see https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702)
Disabled for now (see https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702)
:param handle_timeout_termination: Handle timeout termination (due to timelimit)
:param handle_timeout_termination: Handle timeout termination (due to timelimit)
separately and treat the task as infinite horizon task.
separately and treat the task as infinite horizon task.
https://github.com/DLR-RM/stable-baselines3/issues/284
https://github.com/DLR-RM/stable-baselines3/issues/284
"""
"""
observation_space: spaces.Dict
observation_space: spaces.Dict
obs_shape: Dict[str, Tuple[int, ...]] # type: ignore[assignment]
obs_shape: Dict[str, Tuple[int, ...]] # type: ignore[assignment]
observations: Dict[str, np.ndarray] # type: ignore[assignment]
observations: Dict[str, np.ndarray] # type: ignore[assignment]
next_observations: Dict[str, np.ndarray] # type: ignore[assignment]
next_observations: Dict[str, np.ndarray] # type: ignore[assignment]
def __init__(
    self,
    buffer_size: int,
    observation_space: spaces.Dict,
    action_space: spaces.Space,
    device: Union[th.device, str] = "auto",
    n_envs: int = 1,
    optimize_memory_usage: bool = False,
    handle_timeout_termination: bool = True,
):
    """
    Allocate the per-key observation storage and the action/reward/done/timeout arrays.

    :param buffer_size: Max number of element in the buffer
        (the stored size is ``max(buffer_size // n_envs, 1)`` steps per environment)
    :param observation_space: Observation space (must be a Dict space)
    :param action_space: Action space
    :param device: PyTorch device
    :param n_envs: Number of parallel environments
    :param optimize_memory_usage: Must be ``False``, the memory efficient variant
        is not supported by this class
    :param handle_timeout_termination: Stored on the instance as
        ``self.handle_timeout_termination``
    :raises AssertionError: If the observation shape is not a dict or if
        ``optimize_memory_usage`` is enabled.
    """
    # The lookup starts *after* ReplayBuffer in the MRO, so ReplayBuffer.__init__
    # is not run; all storage is allocated below instead.
    super(ReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

    assert isinstance(self.obs_shape, dict), "DictReplayBuffer must be used with Dict obs space only"
    self.buffer_size = max(buffer_size // n_envs, 1)

    # Check that the replay buffer can fit into the memory
    # (the available memory is read before the arrays are allocated)
    if psutil is not None:
        mem_available = psutil.virtual_memory().available

    assert not optimize_memory_usage, "DictReplayBuffer does not support optimize_memory_usage"
    # disabling as this adds quite a bit of complexity
    # https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702
    self.optimize_memory_usage = optimize_memory_usage

    self.observations = {
        key: np.zeros((self.buffer_size, self.n_envs, *_obs_shape), dtype=observation_space[key].dtype)
        for key, _obs_shape in self.obs_shape.items()
    }
    self.next_observations = {
        key: np.zeros((self.buffer_size, self.n_envs, *_obs_shape), dtype=observation_space[key].dtype)
        for key, _obs_shape in self.obs_shape.items()
    }

    self.actions = np.zeros(
        (self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(action_space.dtype)
    )
    self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
    self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

    # Handle timeouts termination properly if needed
    # see https://github.com/DLR-RM/stable-baselines3/issues/284
    self.handle_timeout_termination = handle_timeout_termination
    self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

    if psutil is not None:
        obs_nbytes = sum(array.nbytes for array in self.observations.values())

        total_memory_usage: float = obs_nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes
        if not optimize_memory_usage:
            # Count the arrays that actually hold the next observations
            total_memory_usage += sum(array.nbytes for array in self.next_observations.values())

        if total_memory_usage > mem_available:
            # Convert to GB
            total_memory_usage /= 1e9
            mem_available /= 1e9
            warnings.warn(
                "This system does not have apparently enough memory to store the complete "
                f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
            )
def add( # type: ignore[override]
def add( # type: ignore[override]
self,
self,
obs: Dict[str, np.ndarray],
obs: Dict[str, np.ndarray],
next_obs: Dict[str, np.ndarray],
next_obs: Dict[str, np.ndarray],
action: np.ndarray,
action: np.ndarray,
reward: np.ndarray,
reward: np.ndarray,
done: np.ndarray,
done: np.ndarray,
infos: List[Dict[str, Any]],
infos: List[Dict[str, Any]],
) -> None:
) -> None:
# Copy to avoid modification by reference
# Copy to avoid modification by reference
for key in self.observations.keys():
for key in self.observations.keys():
# Reshape needed when using multiple envs with discrete observations
# Reshape needed when using multiple envs with discrete observations
# as numpy canno
# as numpy canno