# mlagents/trainers/torch_entities/utils.py
from typing import List, Optional, Tuple, Dict
from mlagents.torch_utils import torch, nn
from mlagents.trainers.torch_entities.layers import LinearEncoder, Initialization
import numpy as np
from mlagents.trainers.torch_entities.encoders import (
    SimpleVisualEncoder,
    ResNetVisualEncoder,
    NatureVisualEncoder,
    SmallVisualEncoder,
    FullyConnectedVisualEncoder,
    VectorInput,
)
from mlagents.trainers.settings import EncoderType, ScheduleType
from mlagents.trainers.torch_entities.attention import (
    EntityEmbedding,
    ResidualSelfAttention,
)
from mlagents.trainers.exception import UnityTrainerException
from mlagents_envs.base_env import ObservationSpec, DimensionProperty


class ModelUtils:
    # Minimum supported side for each encoder type. If refactoring an encoder, please
    # adjust these also.
    MIN_RESOLUTION_FOR_ENCODER = {
        EncoderType.FULLY_CONNECTED: 1,
        EncoderType.MATCH3: 5,
        EncoderType.SIMPLE: 20,
        EncoderType.NATURE_CNN: 36,
        EncoderType.RESNET: 15,
    }

    VALID_VISUAL_PROP = frozenset(
        [
            (
                DimensionProperty.NONE,
                DimensionProperty.TRANSLATIONAL_EQUIVARIANCE,
                DimensionProperty.TRANSLATIONAL_EQUIVARIANCE,
            ),
            (DimensionProperty.UNSPECIFIED,) * 3,
        ]
    )

    VALID_VECTOR_PROP = frozenset(
        [(DimensionProperty.NONE,), (DimensionProperty.UNSPECIFIED,)]
    )

    VALID_VAR_LEN_PROP = frozenset(
        [(DimensionProperty.VARIABLE_SIZE, DimensionProperty.NONE)]
    )

    @staticmethod
    def update_learning_rate(optim: torch.optim.Optimizer, lr: float) -> None:
        """
        Apply a learning rate to a torch optimizer.
        :param optim: Optimizer
        :param lr: Learning rate
        """
        for param_group in optim.param_groups:
            param_group["lr"] = lr

    class DecayedValue:
        def __init__(
            self,
            schedule: ScheduleType,
            initial_value: float,
            min_value: float,
            max_step: int,
        ):
            """
            Object that represents the value of a parameter that should be decayed,
            assuming it is a function of global_step.
            :param schedule: Type of learning rate schedule.
            :param initial_value: Initial value before decay.
            :param min_value: Decay value to this value by max_step.
            :param max_step: The final step count where the return value should equal min_value.
            """
            self.schedule = schedule
            self.initial_value = initial_value
            self.min_value = min_value
            self.max_step = max_step

        def get_value(self, global_step: int) -> float:
            """
            Get the value at a given global step.
            :param global_step: Step count.
            :returns: Decayed value at this global step.
            """
            if self.schedule == ScheduleType.CONSTANT:
                return self.initial_value
            elif self.schedule == ScheduleType.LINEAR:
                return ModelUtils.polynomial_decay(
                    self.initial_value, self.min_value, self.max_step, global_step
                )
            else:
                raise UnityTrainerException(f"The schedule {self.schedule} is invalid.")
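
    # Illustrative sketch (not part of the trainer API): wiring DecayedValue into a
    # training loop together with update_learning_rate. The optimizer and step counts
    # below are hypothetical placeholders.
    #
    #     lr_schedule = ModelUtils.DecayedValue(
    #         ScheduleType.LINEAR, initial_value=3.0e-4, min_value=1.0e-5, max_step=500_000
    #     )
    #     for step in range(500_000):
    #         ModelUtils.update_learning_rate(optimizer, lr_schedule.get_value(step))
    #         ...  # one gradient update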

    @staticmethod
    def polynomial_decay(
        initial_value: float,
        min_value: float,
        max_step: int,
        global_step: int,
        power: float = 1.0,
    ) -> float:
        """
        Get a decayed value based on a polynomial schedule, with respect to the current global step.
        :param initial_value: Initial value before decay.
        :param min_value: Decay value to this value by max_step.
        :param max_step: The final step count where the return value should equal min_value.
        :param global_step: The current step count.
        :param power: Power of polynomial decay. 1.0 (default) is a linear decay.
        :return: The current decayed value.
        """
        global_step = min(global_step, max_step)
        decayed_value = (initial_value - min_value) * (
            1 - float(global_step) / max_step
        ) ** (power) + min_value
        return decayed_value
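
    # Worked example of the decay formula above (illustrative numbers): with
    # initial_value=3.0e-4, min_value=1.0e-5, max_step=10_000 and power=1.0 (linear),
    # the value at global_step=5_000 is
    #     (3.0e-4 - 1.0e-5) * (1 - 5_000 / 10_000) ** 1.0 + 1.0e-5 = 1.55e-4,
    # i.e. halfway between the initial and minimum values. At global_step >= max_step
    # the value stays at min_value.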

    @staticmethod
    def get_encoder_for_type(encoder_type: EncoderType) -> nn.Module:
        ENCODER_FUNCTION_BY_TYPE = {
            EncoderType.SIMPLE: SimpleVisualEncoder,
            EncoderType.NATURE_CNN: NatureVisualEncoder,
            EncoderType.RESNET: ResNetVisualEncoder,
            EncoderType.MATCH3: SmallVisualEncoder,
            EncoderType.FULLY_CONNECTED: FullyConnectedVisualEncoder,
        }
        return ENCODER_FUNCTION_BY_TYPE.get(encoder_type)

    @staticmethod
    def _check_resolution_for_encoder(
        height: int, width: int, vis_encoder_type: EncoderType
    ) -> None:
        min_res = ModelUtils.MIN_RESOLUTION_FOR_ENCODER[vis_encoder_type]
        if height < min_res or width < min_res:
            raise UnityTrainerException(
                f"Visual observation resolution ({width}x{height}) is too small for "
                f"the provided EncoderType ({vis_encoder_type.value}). The min dimension is {min_res}."
            )

    @staticmethod
    def get_encoder_for_obs(
        obs_spec: ObservationSpec,
        normalize: bool,
        h_size: int,
        attention_embedding_size: int,
        vis_encode_type: EncoderType,
    ) -> Tuple[nn.Module, int]:
        """
        Returns the appropriate encoder and its embedding size for the given observation spec.
        :param obs_spec: ObservationSpec describing the observation's shape and dimension properties.
        :param normalize: Normalize all vector inputs.
        :param h_size: Number of hidden units per layer excluding attention layers.
        :param attention_embedding_size: Number of hidden units per attention layer.
        :param vis_encode_type: Type of visual encoder to use.
        """
        shape = obs_spec.shape
        dim_prop = obs_spec.dimension_property
        # VISUAL
        if dim_prop in ModelUtils.VALID_VISUAL_PROP:
            visual_encoder_class = ModelUtils.get_encoder_for_type(vis_encode_type)
            ModelUtils._check_resolution_for_encoder(
                shape[1], shape[2], vis_encode_type
            )
            return (visual_encoder_class(shape[1], shape[2], shape[0], h_size), h_size)
        # VECTOR
        if dim_prop in ModelUtils.VALID_VECTOR_PROP:
            return (VectorInput(shape[0], normalize), shape[0])
        # VARIABLE LENGTH
        if dim_prop in ModelUtils.VALID_VAR_LEN_PROP:
            return (
                EntityEmbedding(
                    entity_size=shape[1],
                    entity_num_max_elements=shape[0],
                    embedding_size=attention_embedding_size,
                ),
                0,
            )
        # OTHER
        raise UnityTrainerException(f"Unsupported Sensor with specs {obs_spec}")
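
    # Sketch of what this dispatch produces (hypothetical observation shapes):
    #  - a (3, 84, 84) visual observation whose dimension properties match VALID_VISUAL_PROP
    #    yields (SimpleVisualEncoder(84, 84, 3, h_size), h_size) when vis_encode_type is
    #    EncoderType.SIMPLE;
    #  - a (12,) vector observation yields (VectorInput(12, normalize), 12);
    #  - a (20, 6) variable-length observation (up to 20 entities of size 6) yields
    #    (EntityEmbedding(entity_size=6, entity_num_max_elements=20,
    #    embedding_size=attention_embedding_size), 0) -- the 0 marks it for attention
    #    processing downstream.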

    @staticmethod
    def create_input_processors(
        observation_specs: List[ObservationSpec],
        h_size: int,
        vis_encode_type: EncoderType,
        attention_embedding_size: int,
        normalize: bool = False,
    ) -> Tuple[nn.ModuleList, List[int]]:
        """
        Creates visual and vector encoders, along with their normalizers.
        :param observation_specs: List of ObservationSpec that represent the observation dimensions.
        :param h_size: Number of hidden units per layer excluding attention layers.
        :param vis_encode_type: Type of visual encoder to use.
        :param attention_embedding_size: Number of hidden units per attention layer.
        :param normalize: Normalize all vector inputs.
        :return: Tuple of:
         - ModuleList of the encoders
         - A list of embedding sizes (0 if the input must be processed with a variable length
           observation encoder)
        """
        encoders: List[nn.Module] = []
        embedding_sizes: List[int] = []
        for obs_spec in observation_specs:
            encoder, embedding_size = ModelUtils.get_encoder_for_obs(
                obs_spec, normalize, h_size, attention_embedding_size, vis_encode_type
            )
            encoders.append(encoder)
            embedding_sizes.append(embedding_size)

        x_self_size = sum(embedding_sizes)  # The size of the "self" embedding
        if x_self_size > 0:
            for enc in encoders:
                if isinstance(enc, EntityEmbedding):
                    enc.add_self_embedding(attention_embedding_size)
        return (nn.ModuleList(encoders), embedding_sizes)
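
    # Usage sketch (hypothetical specs): for observation_specs describing one (12,) vector
    # observation and one (20, 6) variable-length observation, this returns a ModuleList of
    # [VectorInput, EntityEmbedding] and embedding_sizes == [12, 0]. Because the summed
    # "self" size (12) is non-zero, add_self_embedding(attention_embedding_size) is called
    # on the EntityEmbedding so the encoded self observation can later be combined with
    # each entity by the attention modules.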

    @staticmethod
    def list_to_tensor(
        ndarray_list: List[np.ndarray], dtype: Optional[torch.dtype] = torch.float32
    ) -> torch.Tensor:
        """
        Converts a list of numpy arrays into a tensor. MUCH faster than
        calling as_tensor on the list directly.
        """
        return torch.as_tensor(np.asanyarray(ndarray_list), dtype=dtype)

    @staticmethod
    def list_to_tensor_list(
        ndarray_list: List[np.ndarray], dtype: Optional[torch.dtype] = torch.float32
    ) -> List[torch.Tensor]:
        """
        Converts a list of numpy arrays into a list of tensors. MUCH faster than
        calling as_tensor on the list directly.
        """
        return [
            torch.as_tensor(np.asanyarray(_arr), dtype=dtype) for _arr in ndarray_list
        ]
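
    # Why the np.asanyarray round-trip above: calling torch.as_tensor on a Python list of
    # numpy arrays falls back to a slow element-by-element conversion, while stacking into
    # one contiguous numpy array first lets torch ingest a single buffer. Illustrative call:
    #     batch = [np.zeros(8, dtype=np.float32) for _ in range(256)]
    #     tensor = ModelUtils.list_to_tensor(batch)        # shape (256, 8), float32
    #     tensors = ModelUtils.list_to_tensor_list(batch)  # 256 tensors of shape (8,)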

    @staticmethod
    def to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """
        Converts a Torch Tensor to a numpy array. If the Tensor is on the GPU, it will
        be brought to the CPU.
        """
        return tensor.detach().cpu().numpy()

    @staticmethod
    def break_into_branches(
        concatenated_logits: torch.Tensor, action_size: List[int]
    ) -> List[torch.Tensor]:
        """
        Takes a concatenated set of logits that represent multiple discrete action branches
        and breaks it up into one Tensor per branch.
        :param concatenated_logits: Tensor that represents the concatenated action branches
        :param action_size: List of ints containing the number of possible actions for each branch.
        :return: A List of Tensors containing one tensor per branch.
        """
        action_idx = [0] + list(np.cumsum(action_size))
        branched_logits = [
            concatenated_logits[:, action_idx[i] : action_idx[i + 1]]
            for i in range(len(action_size))
        ]
        return branched_logits
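
    # Worked example (illustrative sizes): with concatenated_logits of shape (batch, 5) and
    # action_size == [3, 2], action_idx becomes [0, 3, 5], so the result is
    # [logits[:, 0:3], logits[:, 3:5]] -- one tensor of shape (batch, 3) and one of (batch, 2).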

    @staticmethod
    def actions_to_onehot(
        discrete_actions: torch.Tensor, action_size: List[int]
    ) -> List[torch.Tensor]:
        """
        Takes a tensor of discrete actions and turns it into a List of one-hot encodings,
        one for each action branch.
        :param discrete_actions: Actions in integer form.
        :param action_size: List of branch sizes. Should be of same size as discrete_actions'
        last dimension.
        :return: List of one-hot tensors, one representing each branch.
        """
        onehot_branches = [
            torch.nn.functional.one_hot(_act.T, action_size[i]).float()
            for i, _act in enumerate(discrete_actions.long().T)
        ]
        return onehot_branches
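
    # Worked example (illustrative values): for discrete_actions == tensor([[1, 0]]) and
    # action_size == [3, 2], the result is
    # [tensor([[0., 1., 0.]]), tensor([[1., 0.]])] -- one one-hot tensor per branch,
    # each of shape (batch, branch_size).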

    @staticmethod
    def dynamic_partition(
        data: torch.Tensor, partitions: torch.Tensor, num_partitions: int
    ) -> List[torch.Tensor]:
        """
        Torch implementation of dynamic_partition:
        https://www.tensorflow.org/api_docs/python/tf/dynamic_partition
        Splits the data Tensor input into num_partitions Tensors according to the indices in
        partitions.
        :param data: The Tensor data that will be split into partitions.
        :param partitions: An indices tensor that determines the partition each element of
        data belongs to.
        :param num_partitions: The number of partitions to output (one more than the maximum
        index in the partitions argument).
        :return: A list of Tensor partitions (their indices correspond to their partition index).
        """
        res: List[torch.Tensor] = []
        for i in range(num_partitions):
            res += [data[(partitions == i).nonzero().squeeze(1)]]
        return res
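
    # Worked example (illustrative values): with data == tensor([10., 20., 30., 40.]),
    # partitions == tensor([0, 1, 0, 1]) and num_partitions == 2, the result is
    # [tensor([10., 30.]), tensor([20., 40.])].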

    @staticmethod
    def masked_mean(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
        """
        Returns the mean of the tensor, ignoring the entries where the mask is zero.
        Used for masking out loss functions.
        :param tensor: Tensor which needs mean computation.
        :param masks: Boolean tensor of masks with same dimension as tensor.
        """
        if tensor.ndim == 0:
            return (tensor * masks).sum() / torch.clamp(
                (torch.ones_like(tensor) * masks).float().sum(), min=1.0
            )
        else:
            return (
                tensor.permute(*torch.arange(tensor.ndim - 1, -1, -1)) * masks
            ).sum() / torch.clamp(
                (
                    torch.ones_like(
                        tensor.permute(*torch.arange(tensor.ndim - 1, -1, -1))
                    )
                    * masks
                )
                .float()
                .sum(),
                min=1.0,
            )
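
    # Worked example (illustrative values): masked_mean(tensor([1., 2., 3., 4.]),
    # masks=tensor([True, True, False, False])) keeps only the first two entries and
    # returns (1 + 2) / 2 = 1.5; the clamp keeps the denominator at least 1 so an
    # all-zero mask cannot divide by zero.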

    @staticmethod
    def soft_update(source: nn.Module, target: nn.Module, tau: float) -> None:
        """
        Performs an in-place polyak update of the target module based on the source,
        by a ratio of tau. Note that source and target modules must have the same
        parameters, where:
            target = tau * source + (1-tau) * target
        :param source: Source module whose parameters will be used.
        :param target: Target module whose parameters will be updated.
        :param tau: Percentage of source parameters to use in average. Setting tau to
            1 will copy the source parameters to the target.
        """
        with torch.no_grad():
            for source_param, target_param in zip(
                source.parameters(), target.parameters()
            ):
                target_param.data.mul_(1.0 - tau)
                torch.add(
                    target_param.data,
                    source_param.data,
                    alpha=tau,
                    out=target_param.data,
                )
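
    # Usage sketch (hypothetical modules): a SAC-style target network update with tau=0.005
    # nudges each target parameter 0.5% of the way toward the source on every call, while
    # tau=1.0 is a hard copy:
    #     ModelUtils.soft_update(q_network, target_q_network, tau=0.005)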

    @staticmethod
    def create_residual_self_attention(
        input_processors: nn.ModuleList, embedding_sizes: List[int], hidden_size: int
    ) -> Tuple[Optional[ResidualSelfAttention], Optional[LinearEncoder]]:
        """
        Creates an RSA if there are variable length observations found in the input processors.
        :param input_processors: A ModuleList of input processors as returned by the function
            create_input_processors().
        :param embedding_sizes: A List of embedding sizes as returned by create_input_processors().
        :param hidden_size: The hidden size to use for the RSA.
        :returns: A Tuple of the RSA itself and a self encoder. Returns None for both if no
            variable length inputs are detected.
        """
        rsa, x_self_encoder = None, None
        entity_num_max: int = 0
        var_processors = [p for p in input_processors if isinstance(p, EntityEmbedding)]
        for processor in var_processors:
            entity_max: int = processor.entity_num_max_elements
            # Only adds entity max if it was known at construction
            if entity_max > 0:
                entity_num_max += entity_max
        if len(var_processors) > 0:
            if sum(embedding_sizes):
                x_self_encoder = LinearEncoder(
                    sum(embedding_sizes),
                    1,
                    hidden_size,
                    kernel_init=Initialization.Normal,
                    kernel_gain=(0.125 / hidden_size) ** 0.5,
                )
            rsa = ResidualSelfAttention(hidden_size, entity_num_max)
        return rsa, x_self_encoder
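
    # Usage sketch: for the processors/embedding sizes from the create_input_processors
    # example above ([VectorInput, EntityEmbedding], [12, 0]), this returns a
    # ResidualSelfAttention(hidden_size, 20) plus a LinearEncoder that projects the
    # 12-dim "self" embedding to hidden_size; with no EntityEmbedding present it
    # returns (None, None).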

    @staticmethod
    def trust_region_value_loss(
        values: Dict[str, torch.Tensor],
        old_values: Dict[str, torch.Tensor],
        returns: Dict[str, torch.Tensor],
        epsilon: float,
        loss_masks: torch.Tensor,
    ) -> torch.Tensor:
        """
        Evaluates value loss, clipping to stay within a trust region of old value estimates.
        Used for PPO and POCA.
        :param values: Value output of the current network.
        :param old_values: Value stored with experiences in buffer.
        :param returns: Computed returns.
        :param epsilon: Clipping value for value estimate.
        :param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
        """
        value_losses = []
        for name, head in values.items():
            old_val_tensor = old_values[name]
            returns_tensor = returns[name]
            clipped_value_estimate = old_val_tensor + torch.clamp(
                head - old_val_tensor, -1 * epsilon, epsilon
            )
            v_opt_a = (returns_tensor - head) ** 2
            v_opt_b = (returns_tensor - clipped_value_estimate) ** 2
            value_loss = ModelUtils.masked_mean(torch.max(v_opt_a, v_opt_b), loss_masks)
            value_losses.append(value_loss)
        value_loss = torch.mean(torch.stack(value_losses))
        return value_loss
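
    # Worked illustration of the clipping above (hypothetical numbers): with epsilon=0.2,
    # an old value estimate of 1.0 and a new estimate of 1.5, the clipped estimate is
    # 1.0 + clamp(0.5, -0.2, 0.2) = 1.2; the loss then takes the larger of the squared
    # errors of 1.5 and 1.2 against the return, so large value updates cannot reduce
    # the loss beyond the trust region.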

    @staticmethod
    def trust_region_policy_loss(
        advantages: torch.Tensor,
        log_probs: torch.Tensor,
        old_log_probs: torch.Tensor,
        loss_masks: torch.Tensor,
        epsilon: float,
    ) -> torch.Tensor:
        """
        Evaluate policy loss clipped to stay within a trust region. Used for PPO and POCA.
        :param advantages: Computed advantages.
        :param log_probs: Current policy log probabilities.
        :param old_log_probs: Past policy log probabilities.
        :param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
        :param epsilon: Clipping value for the policy ratio.
        """
        advantage = advantages.unsqueeze(-1)
        r_theta = torch.exp(log_probs - old_log_probs)
        p_opt_a = r_theta * advantage
        p_opt_b = torch.clamp(r_theta, 1.0 - epsilon, 1.0 + epsilon) * advantage
        policy_loss = -1 * ModelUtils.masked_mean(
            torch.min(p_opt_a, p_opt_b), loss_masks
        )
        return policy_loss
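
    # Worked illustration of the ratio clip above (hypothetical numbers): with epsilon=0.2,
    # r_theta = exp(log_probs - old_log_probs). If r_theta = 1.5 and the advantage is
    # positive, p_opt_b = 1.2 * advantage is the smaller term, so the gradient stops
    # encouraging moves further outside the [0.8, 1.2] trust region; for a negative
    # advantage the min instead keeps the unclipped (more pessimistic) term.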