mlagents/trainers/torch_entities/utils.py
0 removals
453 lines
2 additions
454 lines
from typing import List, Optional, Tuple, Dict
from typing import List, Optional, Tuple, Dict
from mlagents.torch_utils import torch, nn
from mlagents.torch_utils import torch, nn
from mlagents.trainers.torch_entities.layers import LinearEncoder, Initialization
from mlagents.trainers.torch_entities.layers import LinearEncoder, Initialization
import numpy as np
import numpy as np
from mlagents.trainers.torch_entities.encoders import (
from mlagents.trainers.torch_entities.encoders import (
SimpleVisualEncoder,
SimpleVisualEncoder,
ResNetVisualEncoder,
ResNetVisualEncoder,
NatureVisualEncoder,
NatureVisualEncoder,
SmallVisualEncoder,
SmallVisualEncoder,
FullyConnectedVisualEncoder,
FullyConnectedVisualEncoder,
VectorInput,
VectorInput,
)
)
from mlagents.trainers.settings import EncoderType, ScheduleType
from mlagents.trainers.settings import EncoderType, ScheduleType
from mlagents.trainers.torch_entities.attention import (
from mlagents.trainers.torch_entities.attention import (
EntityEmbedding,
EntityEmbedding,
ResidualSelfAttention,
ResidualSelfAttention,
)
)
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.exception import UnityTrainerException
from mlagents_envs.base_env import ObservationSpec, DimensionProperty
from mlagents_envs.base_env import ObservationSpec, DimensionProperty
class ModelUtils:
    """
    Static helpers shared by the torch trainers: encoder construction, value-decay
    schedules, numpy/tensor conversion, masking utilities and the PPO/POCA
    trust-region losses.
    """

    # Minimum supported side for each encoder type. If refactoring an encoder, please
    # adjust these also.
    MIN_RESOLUTION_FOR_ENCODER = {
        EncoderType.FULLY_CONNECTED: 1,
        EncoderType.MATCH3: 5,
        EncoderType.SIMPLE: 20,
        EncoderType.NATURE_CNN: 36,
        EncoderType.RESNET: 15,
    }

    # Dimension-property tuples that get_encoder_for_obs treats as visual observations.
    VALID_VISUAL_PROP = frozenset(
        [
            (
                DimensionProperty.NONE,
                DimensionProperty.TRANSLATIONAL_EQUIVARIANCE,
                DimensionProperty.TRANSLATIONAL_EQUIVARIANCE,
            ),
            (DimensionProperty.UNSPECIFIED,) * 3,
        ]
    )

    # Dimension-property tuples that get_encoder_for_obs treats as flat vector observations.
    VALID_VECTOR_PROP = frozenset(
        [(DimensionProperty.NONE,), (DimensionProperty.UNSPECIFIED,)]
    )

    # Dimension-property tuples that get_encoder_for_obs treats as variable-length
    # (entity) observations.
    VALID_VAR_LEN_PROP = frozenset(
        [(DimensionProperty.VARIABLE_SIZE, DimensionProperty.NONE)]
    )

    @staticmethod
    def update_learning_rate(optim: torch.optim.Optimizer, lr: float) -> None:
        """
        Apply a learning rate to a torch optimizer.
        :param optim: Optimizer whose param groups are all updated in place.
        :param lr: Learning rate
        """
        for param_group in optim.param_groups:
            param_group["lr"] = lr

    class DecayedValue:
        def __init__(
            self,
            schedule: ScheduleType,
            initial_value: float,
            min_value: float,
            max_step: int,
        ):
            """
            Object that represents the value of a parameter that should be decayed,
            assuming it is a function of global_step. Query it with get_value().
            :param schedule: Type of schedule. get_value() supports CONSTANT and LINEAR.
            :param initial_value: Initial value before decay.
            :param min_value: Decay value to this value by max_step.
            :param max_step: The final step count where the return value should equal min_value.
            """
            self.schedule = schedule
            self.initial_value = initial_value
            self.min_value = min_value
            self.max_step = max_step

        def get_value(self, global_step: int) -> float:
            """
            Get the value at a given global step.
            :param global_step: Step count.
            :returns: Decayed value at this global step.
            :raises UnityTrainerException: If the schedule is neither CONSTANT nor LINEAR.
            """
            if self.schedule == ScheduleType.CONSTANT:
                return self.initial_value
            elif self.schedule == ScheduleType.LINEAR:
                return ModelUtils.polynomial_decay(
                    self.initial_value, self.min_value, self.max_step, global_step
                )
            else:
                raise UnityTrainerException(f"The schedule {self.schedule} is invalid.")

    @staticmethod
    def polynomial_decay(
        initial_value: float,
        min_value: float,
        max_step: int,
        global_step: int,
        power: float = 1.0,
    ) -> float:
        """
        Get a decayed value based on a polynomial schedule, with respect to the current global step.
        :param initial_value: Initial value before decay.
        :param min_value: Decay value to this value by max_step.
        :param max_step: The final step count where the return value should equal min_value.
            A non-positive max_step means there is no decay window, so min_value is returned.
        :param global_step: The current step count. Steps past max_step are clamped to max_step.
        :param power: Power of polynomial decay. 1.0 (default) is a linear decay.
        :return: The current decayed value.
        """
        if max_step <= 0:
            # Without this guard max_step == 0 raises ZeroDivisionError below. A
            # zero-length schedule has already "reached" max_step, so return min_value
            # (negative max_step already evaluated to min_value through the formula).
            return min_value
        global_step = min(global_step, max_step)
        decayed_value = (initial_value - min_value) * (
            1 - float(global_step) / max_step
        ) ** power + min_value
        return decayed_value

    @staticmethod
    def get_encoder_for_type(encoder_type: EncoderType) -> nn.Module:
        """
        Look up the visual encoder class (not an instance) for an EncoderType.
        :param encoder_type: The requested encoder type.
        :return: The encoder class, or None if the type has no registered encoder.
        """
        ENCODER_FUNCTION_BY_TYPE = {
            EncoderType.SIMPLE: SimpleVisualEncoder,
            EncoderType.NATURE_CNN: NatureVisualEncoder,
            EncoderType.RESNET: ResNetVisualEncoder,
            EncoderType.MATCH3: SmallVisualEncoder,
            EncoderType.FULLY_CONNECTED: FullyConnectedVisualEncoder,
        }
        return ENCODER_FUNCTION_BY_TYPE.get(encoder_type)

    @staticmethod
    def _check_resolution_for_encoder(
        height: int, width: int, vis_encoder_type: EncoderType
    ) -> None:
        """
        Verify that a visual observation is large enough for the chosen encoder.
        :param height: Observation height.
        :param width: Observation width.
        :param vis_encoder_type: Encoder type whose minimum side is looked up in
            MIN_RESOLUTION_FOR_ENCODER.
        :raises UnityTrainerException: If either side is below the encoder's minimum.
        """
        min_res = ModelUtils.MIN_RESOLUTION_FOR_ENCODER[vis_encoder_type]
        if height < min_res or width < min_res:
            raise UnityTrainerException(
                f"Visual observation resolution ({width}x{height}) is too small for "
                f"the provided EncoderType ({vis_encoder_type.value}). The min dimension is {min_res}"
            )

    @staticmethod
    def get_encoder_for_obs(
        obs_spec: ObservationSpec,
        normalize: bool,
        h_size: int,
        attention_embedding_size: int,
        vis_encode_type: EncoderType,
    ) -> Tuple[nn.Module, int]:
        """
        Returns the encoder and the size of the appropriate encoder.
        :param obs_spec: ObservationSpec whose shape and dimension_property select the encoder.
        :param normalize: Normalize all vector inputs.
        :param h_size: Number of hidden units per layer excluding attention layers.
        :param attention_embedding_size: Number of hidden units per attention layer.
        :param vis_encode_type: Type of visual encoder to use.
        :return: Tuple of (encoder module, embedding size). The size is 0 for
            variable-length observations, which are handled by attention instead.
        :raises UnityTrainerException: If the dimension properties match no known encoder.
        """
        shape = obs_spec.shape
        dim_prop = obs_spec.dimension_property
        # VISUAL: shape is indexed as (channels, height, width) below.
        if dim_prop in ModelUtils.VALID_VISUAL_PROP:
            visual_encoder_class = ModelUtils.get_encoder_for_type(vis_encode_type)
            ModelUtils._check_resolution_for_encoder(
                shape[1], shape[2], vis_encode_type
            )
            return (visual_encoder_class(shape[1], shape[2], shape[0], h_size), h_size)
        # VECTOR
        if dim_prop in ModelUtils.VALID_VECTOR_PROP:
            return (VectorInput(shape[0], normalize), shape[0])
        # VARIABLE LENGTH: shape is indexed as (max_entities, entity_size) below.
        if dim_prop in ModelUtils.VALID_VAR_LEN_PROP:
            return (
                EntityEmbedding(
                    entity_size=shape[1],
                    entity_num_max_elements=shape[0],
                    embedding_size=attention_embedding_size,
                ),
                0,
            )
        # OTHER
        raise UnityTrainerException(f"Unsupported Sensor with specs {obs_spec}")

    @staticmethod
    def create_input_processors(
        observation_specs: List[ObservationSpec],
        h_size: int,
        vis_encode_type: EncoderType,
        attention_embedding_size: int,
        normalize: bool = False,
    ) -> Tuple[nn.ModuleList, List[int]]:
        """
        Creates visual, vector and variable-length encoders, one per observation.
        :param observation_specs: List of ObservationSpec that represent the observation dimensions.
        :param h_size: Number of hidden units per layer excluding attention layers.
        :param vis_encode_type: Type of visual encoder to use.
        :param attention_embedding_size: Number of hidden units per attention layer.
        :param normalize: Normalize all vector inputs.
        :return: Tuple of :
         - ModuleList of the encoders
         - A list of embedding sizes (0 if the input requires to be processed with a variable length
         observation encoder)
        """
        encoders: List[nn.Module] = []
        embedding_sizes: List[int] = []
        for obs_spec in observation_specs:
            encoder, embedding_size = ModelUtils.get_encoder_for_obs(
                obs_spec, normalize, h_size, attention_embedding_size, vis_encode_type
            )
            encoders.append(encoder)
            embedding_sizes.append(embedding_size)

        x_self_size = sum(embedding_sizes)  # The size of the "self" embedding
        # Entity encoders only get a self-embedding when some fixed-size input exists.
        if x_self_size > 0:
            for enc in encoders:
                if isinstance(enc, EntityEmbedding):
                    enc.add_self_embedding(attention_embedding_size)
        return (nn.ModuleList(encoders), embedding_sizes)

    @staticmethod
    def list_to_tensor(
        ndarray_list: List[np.ndarray], dtype: Optional[torch.dtype] = torch.float32
    ) -> torch.Tensor:
        """
        Converts a list of numpy arrays into a tensor. MUCH faster than
        calling as_tensor on the list directly.
        """
        return torch.as_tensor(np.asanyarray(ndarray_list), dtype=dtype)

    @staticmethod
    def list_to_tensor_list(
        ndarray_list: List[np.ndarray], dtype: Optional[torch.dtype] = torch.float32
    ) -> List[torch.Tensor]:
        """
        Converts a list of numpy arrays into a list of tensors, one per array.
        """
        return [
            torch.as_tensor(np.asanyarray(_arr), dtype=dtype) for _arr in ndarray_list
        ]

    @staticmethod
    def to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """
        Converts a Torch Tensor to a numpy array. If the Tensor is on the GPU, it will
        be brought to the CPU.
        """
        return tensor.detach().cpu().numpy()

    @staticmethod
    def break_into_branches(
        concatenated_logits: torch.Tensor, action_size: List[int]
    ) -> List[torch.Tensor]:
        """
        Takes a concatenated set of logits that represent multiple discrete action branches
        and breaks it up into one Tensor per branch.
        :param concatenated_logits: Tensor that represents the concatenated action branches,
            sliced along dim 1.
        :param action_size: List of ints containing the number of possible actions for each branch.
        :return: A List of Tensors containing one tensor per branch.
        """
        # Branch i occupies columns [action_idx[i], action_idx[i + 1]).
        action_idx = [0] + list(np.cumsum(action_size))
        branched_logits = [
            concatenated_logits[:, action_idx[i] : action_idx[i + 1]]
            for i in range(len(action_size))
        ]
        return branched_logits

    @staticmethod
    def actions_to_onehot(
        discrete_actions: torch.Tensor, action_size: List[int]
    ) -> List[torch.Tensor]:
        """
        Takes a tensor of discrete actions and turns it into a List of onehot encoding for each
        action.
        :param discrete_actions: Actions in integer form.
        :param action_size: List of branch sizes. Should be of same size as discrete_actions'
        last dimension.
        :return: List of one-hot tensors, one representing each branch.
        """
        # Transposing makes iteration walk over branches (the last dimension).
        onehot_branches = [
            torch.nn.functional.one_hot(_act.T, action_size[i]).float()
            for i, _act in enumerate(discrete_actions.long().T)
        ]
        return onehot_branches

    @staticmethod
    def dynamic_partition(
        data: torch.Tensor, partitions: torch.Tensor, num_partitions: int
    ) -> List[torch.Tensor]:
        """
        Torch implementation of dynamic_partition :
        https://www.tensorflow.org/api_docs/python/tf/dynamic_partition
        Splits the data Tensor input into num_partitions Tensors according to the indices in
        partitions.
        :param data: The Tensor data that will be split into partitions.
        :param partitions: An indices tensor that determines in which partition each element
        of data will be in.
        :param num_partitions: The number of partitions to output. Corresponds to the
        maximum possible index in the partitions argument.
        :return: A list of Tensor partitions (Their indices correspond to their partition index).
        """
        res: List[torch.Tensor] = []
        for i in range(num_partitions):
            res += [data[(partitions == i).nonzero().squeeze(1)]]
        return res

    @staticmethod
    def masked_mean(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
        """
        Returns the mean of the tensor but ignoring the values specified by masks.
        Used for masking out loss functions.
        :param tensor: Tensor which needs mean computation.
        :param masks: Boolean tensor of masks. It is multiplied against tensor with its
            dimensions reversed, so it broadcasts over the leading axes of tensor
            (presumably a per-sample mask over the batch axis — confirm at call sites).
        :return: Sum of the kept elements divided by their count (count clamped to >= 1
            so an all-masked input yields 0 instead of NaN).
        """
        # Reverse all dimensions (the equivalent of a full .T) so masks aligns with
        # tensor's leading axes under trailing-dim broadcasting. A 0-d tensor has no
        # dimensions to reverse.
        if tensor.ndim == 0:
            transposed = tensor
        else:
            transposed = tensor.permute(*reversed(range(tensor.ndim)))
        masked_sum = (transposed * masks).sum()
        kept_count = (torch.ones_like(transposed) * masks).float().sum()
        return masked_sum / torch.clamp(kept_count, min=1.0)

    @staticmethod
    def soft_update(source: nn.Module, target: nn.Module, tau: float) -> None:
        """
        Performs an in-place polyak update of the target module based on the source,
        by a ratio of tau. Note that source and target modules must have the same
        parameters, where:
            target = tau * source + (1-tau) * target
        :param source: Source module whose parameters will be used.
        :param target: Target module whose parameters will be updated.
        :param tau: Percentage of source parameters to use in average. Setting tau to
            1 will copy the source parameters to the target.
        """
        with torch.no_grad():
            for source_param, target_param in zip(
                source.parameters(), target.parameters()
            ):
                # (1 - tau) * target, then + tau * source, both in place.
                target_param.data.mul_(1.0 - tau).add_(source_param.data, alpha=tau)

    @staticmethod
    def create_residual_self_attention(
        input_processors: nn.ModuleList, embedding_sizes: List[int], hidden_size: int
    ) -> Tuple[Optional[ResidualSelfAttention], Optional[LinearEncoder]]:
        """
        Creates an RSA if there are variable length observations found in the input processors.
        :param input_processors: A ModuleList of input processors as returned by the function
            create_input_processors().
        :param embedding_sizes: A List of embedding sizes as returned by create_input_processors().
        :param hidden_size: The hidden size to use for the RSA.
        :returns: A Tuple of the RSA itself and a self encoder. The RSA is None if no
            var len inputs are detected; the self encoder is additionally None when there
            are no fixed-size inputs (sum of embedding_sizes is 0).
        """
        rsa, x_self_encoder = None, None
        entity_num_max: int = 0
        var_processors = [p for p in input_processors if isinstance(p, EntityEmbedding)]
        for processor in var_processors:
            entity_max: int = processor.entity_num_max_elements
            # Only adds entity max if it was known at construction
            if entity_max > 0:
                entity_num_max += entity_max
        if len(var_processors) > 0:
            if sum(embedding_sizes):
                x_self_encoder = LinearEncoder(
                    sum(embedding_sizes),
                    1,
                    hidden_size,
                    kernel_init=Initialization.Normal,
                    kernel_gain=(0.125 / hidden_size) ** 0.5,
                )
            rsa = ResidualSelfAttention(hidden_size, entity_num_max)
        return rsa, x_self_encoder

    @staticmethod
    def trust_region_value_loss(
        values: Dict[str, torch.Tensor],
        old_values: Dict[str, torch.Tensor],
        returns: Dict[str, torch.Tensor],
        epsilon: float,
        loss_masks: torch.Tensor,
    ) -> torch.Tensor:
        """
        Evaluates value loss, clipping to stay within a trust region of old value estimates.
        Used for PPO and POCA.
        :param values: Value output of the current network, keyed by value head name.
        :param old_values: Value stored with experiences in buffer, same keys as values.
        :param returns: Computed returns, same keys as values.
        :param epsilon: Clipping value for value estimate.
        :param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
        :return: Mean over value heads of the masked, clipped value loss.
        """
        value_losses = []
        for name, head in values.items():
            old_val_tensor = old_values[name]
            returns_tensor = returns[name]
            # Keep the new estimate within +/- epsilon of the old one.
            clipped_value_estimate = old_val_tensor + torch.clamp(
                head - old_val_tensor, -epsilon, epsilon
            )
            v_opt_a = (returns_tensor - head) ** 2
            v_opt_b = (returns_tensor - clipped_value_estimate) ** 2
            # Pessimistic (element-wise max) of the unclipped and clipped errors is what
            # enforces the trust region; using v_opt_a alone would make the clipping dead code.
            value_loss = ModelUtils.masked_mean(torch.max(v_opt_a, v_opt_b), loss_masks)
            value_losses.append(value_loss)
        value_loss = torch.mean(torch.stack(value_losses))
        return value_loss

    @staticmethod
    def trust_region_policy_loss(
        advantages: torch.Tensor,
        log_probs: torch.Tensor,
        old_log_probs: torch.Tensor,
        loss_masks: torch.Tensor,
        epsilon: float,
    ) -> torch.Tensor:
        """
        Evaluate policy loss clipped to stay within a trust region. Used for PPO and POCA.
        :param advantages: Computed advantages.
        :param log_probs: Current policy probabilities
        :param old_log_probs: Past policy probabilities
        :param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
        :param epsilon: Clipping range for the probability ratio.
        :return: Negated masked mean of the clipped surrogate objective.
        """
        # Trailing singleton dim lets advantages broadcast against the log-prob columns.
        advantage = advantages.unsqueeze(-1)

        r_theta = torch.exp(log_probs - old_log_probs)
        p_opt_a = r_theta * advantage
        p_opt_b = torch.clamp(r_theta, 1.0 - epsilon, 1.0 + epsilon) * advantage
        policy_loss = -1 * ModelUtils.masked_mean(
            torch.min(p_opt_a, p_opt_b), loss_masks
        )
        return policy_loss