mlagents/trainers/torch_entities/distributions.py

1 removal, 1 addition (249 lines)
import abc
from typing import List
from mlagents.torch_utils import torch, nn
import numpy as np
import math
from mlagents.trainers.torch_entities.layers import linear_layer, Initialization

EPSILON = 1e-7  # Small value to avoid divide by zero


class DistInstance(nn.Module, abc.ABC):
    @abc.abstractmethod
    def sample(self) -> torch.Tensor:
        """
        Return a sample from this distribution.
        """
        pass

    @abc.abstractmethod
    def deterministic_sample(self) -> torch.Tensor:
        """
        Return the most probable sample from this distribution.
        """
        pass

    @abc.abstractmethod
    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Returns the log probabilities of a particular value.
        :param value: A value sampled from the distribution.
        :returns: Log probabilities of the given value.
        """
        pass

    @abc.abstractmethod
    def entropy(self) -> torch.Tensor:
        """
        Returns the entropy of this distribution.
        """
        pass

    @abc.abstractmethod
    def exported_model_output(self) -> torch.Tensor:
        """
        Returns the tensor to be exported to ONNX for the distribution.
        """
        pass


class DiscreteDistInstance(DistInstance):
    @abc.abstractmethod
    def all_log_prob(self) -> torch.Tensor:
        """
        Returns the log probabilities of all actions represented by this distribution.
        """
        pass


class GaussianDistInstance(DistInstance):
    def __init__(self, mean, std):
        super().__init__()
        self.mean = mean
        self.std = std

    def sample(self):
        sample = self.mean + torch.randn_like(self.mean) * self.std
        return sample

    def deterministic_sample(self):
        return self.mean

    def log_prob(self, value):
        var = self.std**2
        log_scale = torch.log(self.std + EPSILON)
        return (
            -((value - self.mean) ** 2) / (2 * var + EPSILON)
            - log_scale
            - math.log(math.sqrt(2 * math.pi))
        )

    def pdf(self, value):
        log_prob = self.log_prob(value)
        return torch.exp(log_prob)

    def entropy(self):
        # The entropy of a diagonal Gaussian is the sum of the per-dimension
        # entropies, so the diff replaces torch.mean with torch.sum here.
-        return torch.mean(
+        return torch.sum(
            0.5 * torch.log(2 * math.pi * math.e * self.std**2 + EPSILON),
            dim=1,
            keepdim=True,
        )

    def exported_model_output(self):
        return self.sample()
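
# A minimal sanity-check sketch: log_prob above implements the diagonal-Gaussian
# log-density, so it should agree with torch.distributions.Normal up to the
# EPSILON smoothing terms. The helper name, shapes, and tolerance are
# illustrative assumptions.
def _check_gaussian_log_prob():
    mean, std = torch.zeros(4, 2), torch.ones(4, 2)
    dist = GaussianDistInstance(mean, std)
    ref = torch.distributions.Normal(mean, std)
    sample = dist.sample()
    # Both return per-dimension log-probabilities of shape (4, 2).
    assert torch.allclose(dist.log_prob(sample), ref.log_prob(sample), atol=1e-4)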




class TanhGaussianDistInstance(GaussianDistInstance):
    def __init__(self, mean, std):
        super().__init__(mean, std)
        self.transform = torch.distributions.transforms.TanhTransform(cache_size=1)

    def sample(self):
        unsquashed_sample = super().sample()
        squashed = self.transform(unsquashed_sample)
        return squashed

    def _inverse_tanh(self, value):
        capped_value = torch.clamp(value, -1 + EPSILON, 1 - EPSILON)
        return 0.5 * torch.log((1 + capped_value) / (1 - capped_value) + EPSILON)

    def log_prob(self, value):
        unsquashed = self.transform.inv(value)
        return super().log_prob(unsquashed) - self.transform.log_abs_det_jacobian(
            unsquashed, value
        )
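
# Sketch of the change-of-variables identity behind log_prob above: for
# y = tanh(x), log p(y) = log p(x) - log|d tanh(x)/dx|, which is what
# TanhTransform.log_abs_det_jacobian computes. The reference distribution and
# tolerance below are illustrative assumptions.
def _check_tanh_log_prob():
    mean, std = torch.zeros(3, 2), torch.ones(3, 2)
    dist = TanhGaussianDistInstance(mean, std)
    squashed = dist.sample()  # values lie in (-1, 1)
    ref = torch.distributions.TransformedDistribution(
        torch.distributions.Normal(mean, std),
        [torch.distributions.transforms.TanhTransform()],
    )
    assert torch.allclose(dist.log_prob(squashed), ref.log_prob(squashed), atol=1e-3)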




class CategoricalDistInstance(DiscreteDistInstance):
    def __init__(self, logits):
        super().__init__()
        self.logits = logits
        self.probs = torch.softmax(self.logits, dim=-1)

    def sample(self):
        return torch.multinomial(self.probs, 1)

    def deterministic_sample(self):
        return torch.argmax(self.probs, dim=1, keepdim=True)

    def pdf(self, value):
        # This function is equivalent to torch.diag(self.probs.T[value.flatten().long()]),
        # but torch.diag is not supported by ONNX export.
        idx = torch.arange(start=0, end=len(value)).unsqueeze(-1)
        return torch.gather(
            self.probs.permute(1, 0)[value.flatten().long()], -1, idx
        ).squeeze(-1)

    def log_prob(self, value):
        return torch.log(self.pdf(value) + EPSILON)

    def all_log_prob(self):
        return torch.log(self.probs + EPSILON)

    def entropy(self):
        return -torch.sum(
            self.probs * torch.log(self.probs + EPSILON), dim=-1
        ).unsqueeze(-1)

    def exported_model_output(self):
        return self.sample()
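
# Sketch of what the gather-based pdf above computes: for each batch row i it
# selects probs[i, value[i]], the probability of the chosen action, without
# using torch.diag. The logits and helper name are illustrative assumptions.
def _check_categorical_pdf():
    logits = torch.tensor([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]])
    dist = CategoricalDistInstance(logits)
    value = torch.tensor([[2], [0]])  # one chosen action index per row
    expected = torch.stack([dist.probs[0, 2], dist.probs[1, 0]])
    assert torch.allclose(dist.pdf(value), expected)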




class GaussianDistribution(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_outputs: int,
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
        super().__init__()
        self.conditional_sigma = conditional_sigma
        self.mu = linear_layer(
            hidden_size,
            num_outputs,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=0.2,
            bias_init=Initialization.Zero,
        )
        self.tanh_squash = tanh_squash
        if conditional_sigma:
            self.log_sigma = linear_layer(
                hidden_size,
                num_outputs,
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=0.2,
                bias_init=Initialization.Zero,
            )
        else:
            self.log_sigma = nn.Parameter(
                torch.zeros(1, num_outputs, requires_grad=True)
            )

    def forward(self, inputs: torch.Tensor) -> DistInstance:
        mu = self.mu(inputs)
        if self.conditional_sigma:
            log_sigma = torch.clamp(self.log_sigma(inputs), min=-20, max=2)
        else:
            # Expand so that entropy matches batch size. Note that we're using
            # mu*0 here to get the batch size implicitly, since Sentis
            # throws an error on runtime broadcasting for an unknown reason. We
            # use this to replace torch.expand() because it is not supported in
            # the verified version of Sentis (1.2.0-exp.2).
            log_sigma = mu * 0 + self.log_sigma
        if self.tanh_squash:
            return TanhGaussianDistInstance(mu, torch.exp(log_sigma))
        else:
            return GaussianDistInstance(mu, torch.exp(log_sigma))
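
# Minimal usage sketch for the head above: map a batch of hidden states to a
# Gaussian action distribution, then sample and score actions. The sizes and
# helper name are illustrative assumptions.
def _demo_gaussian_head():
    head = GaussianDistribution(hidden_size=16, num_outputs=2)
    hidden = torch.randn(5, 16)
    dist = head(hidden)
    actions = dist.sample()             # shape (5, 2)
    log_probs = dist.log_prob(actions)  # per-dimension, shape (5, 2)
    entropies = dist.entropy()          # shape (5, 1) after the sum over dims
    return actions, log_probs, entropies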




class MultiCategoricalDistribution(nn.Module):
    def __init__(self, hidden_size: int, act_sizes: List[int]):
        super().__init__()
        self.act_sizes = act_sizes
        self.branches = self._create_policy_branches(hidden_size)

    def _create_policy_branches(self, hidden_size: int) -> nn.ModuleList:
        branches = []
        for size in self.act_sizes:
            branch_output_layer = linear_layer(
                hidden_size,
                size,
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=0.1,
                bias_init=Initialization.Zero,
            )
            branches.append(branch_output_layer)
        return nn.ModuleList(branches)

    def _mask_branch(
        self, logits: torch.Tensor, allow_mask: torch.Tensor
    ) -> torch.Tensor:
        # Zero out masked logits, then subtract a large value. Technique mentioned here:
        # https://arxiv.org/abs/2006.14171. Our implementation is ONNX- and Sentis-friendly.
        block_mask = -1.0 * allow_mask + 1.0
        # We do -1 * tensor + constant instead of constant - tensor because it seems
        # Sentis might swap the inputs of a "Sub" operation.
        logits = logits * allow_mask - 1e8 * block_mask
        return logits

    def _split_masks(self, masks: torch.Tensor) -> List[torch.Tensor]:
        split_masks = []
        for idx, _ in enumerate(self.act_sizes):
            start = int(np.sum(self.act_sizes[:idx]))
            end = int(np.sum(self.act_sizes[: idx + 1]))
            split_masks.append(masks[:, start:end])
        return split_masks

    def forward(self, inputs: torch.Tensor, masks: torch.Tensor) -> List[DistInstance]:
        # TODO: support multiple branches in the mask code.
        branch_distributions = []
        masks = self._split_masks(masks)
        for idx, branch in enumerate(self.branches):
            logits = branch(inputs)
            norm_logits = self._mask_branch(logits, masks[idx])
            distribution = CategoricalDistInstance(norm_logits)
            branch_distributions.append(distribution)
        return branch_distributions
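
# Minimal usage sketch for the multi-branch head above: two branches of sizes
# 3 and 2 share one concatenated mask, and a masked-out action ends up with
# (near-)zero probability. Sizes and the helper name are illustrative assumptions.
def _demo_multi_categorical_head():
    head = MultiCategoricalDistribution(hidden_size=16, act_sizes=[3, 2])
    hidden = torch.randn(4, 16)
    masks = torch.ones(4, 5)  # one column per action across both branches
    masks[:, 0] = 0.0  # disallow action 0 of the first branch for every agent
    dists = head(hidden, masks)
    # The blocked action's logit is pushed to roughly -1e8, so its softmax
    # probability underflows to (near-)zero.
    assert torch.all(dists[0].probs[:, 0] < 1e-6)
    return [d.sample() for d in dists]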