Implementation of Invalid Action Masking

Diff: 465 lines (8 removals) → 478 lines (22 additions)
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

from cleanrl.common import preprocess_obs_space, preprocess_ac_space
import argparse
import numpy as np
import gym
import gym_microrts
from gym.wrappers import TimeLimit, Monitor
import pybullet_envs
from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete, Space
import time
import random
import os


# taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py
class RunningMeanStd(object):
    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, 'float64')
        self.var = np.ones(shape, 'float64')
        self.count = epsilon

    def update(self, x):
        batch_mean = np.mean([x], axis=0)
        batch_var = np.var([x], axis=0)
        batch_count = 1
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)

def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
    delta = batch_mean - mean
    tot_count = count + batch_count

    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count
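
# A small worked example (not part of the original script) of the merge rule above,
# which is the standard parallel-variance formula for combining two sets of moments:
#   delta = mean_b - mean_a
#   mean  = mean_a + delta * n_b / (n_a + n_b)
#   M2    = var_a * n_a + var_b * n_b + delta^2 * n_a * n_b / (n_a + n_b)
#   var   = M2 / (n_a + n_b)
# e.g. merging {0, 2} (mean 1, var 1, n=2) with {4} (mean 4, var 0, n=1) gives
# mean = 1 + 3*1/3 = 2 and var = (1*2 + 0*1 + 9*2*1/3) / 3 = 8/3, which matches the
# population statistics of {0, 2, 4}.
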
class NormalizedEnv(gym.core.Wrapper):
    def __init__(self, env, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8):
        super(NormalizedEnv, self).__init__(env)
        self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
        self.ret_rms = RunningMeanStd(shape=(1,)) if ret else None
        self.clipob = clipob
        self.cliprew = cliprew
        self.ret = np.zeros(())
        self.gamma = gamma
        self.epsilon = epsilon

    def step(self, action):
        obs, rews, news, infos = self.env.step(action)
        infos['real_reward'] = rews
        # print("before", self.ret)
        self.ret = self.ret * self.gamma + rews
        # print("after", self.ret)
        obs = self._obfilt(obs)
        if self.ret_rms:
            self.ret_rms.update(np.array([self.ret].copy()))
            rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew)
        self.ret = self.ret * (1-float(news))
        return obs, rews, news, infos

    def _obfilt(self, obs):
        if self.ob_rms:
            self.ob_rms.update(obs)
            obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob)
            return obs
        else:
            return obs

    def reset(self):
        self.ret = np.zeros(())
        obs = self.env.reset()
        return self._obfilt(obs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PPO agent')
    # Common arguments
    parser.add_argument('--exp-name', type=str, default=os.path.basename(__file__).rstrip(".py"),
                        help='the name of this experiment')
    parser.add_argument('--gym-id', type=str, default="MicrortsMining10x10F9-v0",
                        help='the id of the gym environment')
    parser.add_argument('--seed', type=int, default=1,
                        help='seed of the experiment')
    parser.add_argument('--episode-length', type=int, default=0,
                        help='the maximum length of each episode')
    parser.add_argument('--total-timesteps', type=int, default=100000,
                        help='total timesteps of the experiments')
    parser.add_argument('--no-torch-deterministic', action='store_false', dest="torch_deterministic", default=True,
                        help='if toggled, `torch.backends.cudnn.deterministic=False`')
    parser.add_argument('--no-cuda', action='store_false', dest="cuda", default=True,
                        help='if toggled, cuda will not be enabled by default')
    parser.add_argument('--prod-mode', action='store_true', default=False,
                        help='run the script in production mode and use wandb to log outputs')
    parser.add_argument('--capture-video', action='store_true', default=False,
                        help='whether to capture videos of the agent performances (check out `videos` folder)')
    parser.add_argument('--wandb-project-name', type=str, default="cleanRL",
                        help="the wandb's project name")
    parser.add_argument('--wandb-entity', type=str, default=None,
                        help="the entity (team) of wandb's project")


    # Algorithm specific arguments
    parser.add_argument('--batch-size', type=int, default=2048,
                        help='the batch size of ppo')
    parser.add_argument('--minibatch-size', type=int, default=256,
                        help='the mini batch size of ppo')
    parser.add_argument('--gamma', type=float, default=0.99,
                        help='the discount factor gamma')
    parser.add_argument('--gae-lambda', type=float, default=0.97,
                        help='the lambda for the generalized advantage estimation')
    parser.add_argument('--ent-coef', type=float, default=0.01,
                        help="coefficient of the entropy")
    parser.add_argument('--max-grad-norm', type=float, default=0.5,
                        help='the maximum norm for the gradient clipping')
    parser.add_argument('--clip-coef', type=float, default=0.2,
                        help="the surrogate clipping coefficient")
    parser.add_argument('--update-epochs', type=int, default=10,
                        help="the K epochs to update the policy")
    parser.add_argument('--kle-stop', action='store_true', default=False,
                        help='If toggled, the policy updates will be early stopped w.r.t target-kl')
    parser.add_argument('--kle-rollback', action='store_true', default=False,
                        help='If toggled, the policy updates will roll back to the previous policy if KL exceeds target-kl')
    parser.add_argument('--target-kl', type=float, default=0.015,
                        help='the target KL threshold used by --kle-stop and --kle-rollback')
    parser.add_argument('--gae', action='store_true', default=True,
                        help='Use GAE for advantage computation')
    parser.add_argument('--policy-lr', type=float, default=3e-4,
                        help="the learning rate of the policy optimizer")
    parser.add_argument('--value-lr', type=float, default=3e-4,
                        help="the learning rate of the critic optimizer")
    parser.add_argument('--norm-obs', action='store_true', default=True,
                        help="Toggles observation normalization")
    parser.add_argument('--norm-returns', action='store_true', default=False,
                        help="Toggles returns normalization")
    parser.add_argument('--norm-adv', action='store_true', default=True,
                        help="Toggles advantages normalization")
    parser.add_argument('--obs-clip', type=float, default=10.0,
                        help="Value for observation clipping, as per the paper")
    parser.add_argument('--rew-clip', type=float, default=10.0,
                        help="Value for reward clipping, as per the paper")
    parser.add_argument('--anneal-lr', action='store_true', default=True,
                        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument('--weights-init', default="orthogonal", choices=["xavier", 'orthogonal'],
                        help='Selects the scheme to be used for weights initialization')
    parser.add_argument('--clip-vloss', action="store_true", default=True,
                        help='Toggles whether or not to use a clipped loss for the value function, as per the paper.')
    parser.add_argument('--pol-layer-norm', action='store_true', default=False,
                        help='Enables layer normalization in the policy network')
    args = parser.parse_args()
    if not args.seed:
        args.seed = int(time.time())


args.features_turned_on = sum([args.kle_stop, args.kle_rollback, args.gae, args.norm_obs, args.norm_returns, args.norm_adv, args.anneal_lr, args.clip_vloss, args.pol_layer_norm])
# TRY NOT TO MODIFY: setup the environment
experiment_name = f"{args.gym_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
writer = SummaryWriter(f"runs/{experiment_name}")
writer.add_text('hyperparameters', "|param|value|\n|-|-|\n%s" % (
    '\n'.join([f"|{key}|{value}|" for key, value in vars(args).items()])))

if args.prod_mode:
    import wandb
    wandb.init(project=args.wandb_project_name, entity=args.wandb_entity, tensorboard=True, config=vars(args), name=experiment_name, monitor_gym=True)
    writer = SummaryWriter(f"/tmp/{experiment_name}")
    wandb.save(os.path.abspath(__file__))

# TRY NOT TO MODIFY: seeding
device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')
env = gym.make(args.gym_id)
# respect the default timelimit
assert isinstance(env.action_space, MultiDiscrete), "only MultiDiscrete action space is supported"
assert isinstance(env, TimeLimit) or int(args.episode_length), "the gym env does not have a built in TimeLimit, please specify by using --episode-length"
if isinstance(env, TimeLimit):
    if int(args.episode_length):
        env._max_episode_steps = int(args.episode_length)
    args.episode_length = env._max_episode_steps
else:
    env = TimeLimit(env, int(args.episode_length))
env = NormalizedEnv(env.env, ob=args.norm_obs, ret=args.norm_returns, clipob=args.obs_clip, cliprew=args.rew_clip, gamma=args.gamma)
env = TimeLimit(env, int(args.episode_length))
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
env.seed(args.seed)
env.action_space.seed(args.seed)
env.observation_space.seed(args.seed)
if args.capture_video:
    env = Monitor(env, f'videos/{experiment_name}')


# ALGO LOGIC: initialize agent here:
class CategoricalMasked(Categorical):
    def __init__(self, probs=None, logits=None, validate_args=None, masks=[]):
        self.masks = masks
        if len(self.masks) == 0:
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)
        else:
            self.masks = masks.type(torch.BoolTensor).to(device)
            logits = torch.where(self.masks, logits, torch.tensor(-1e+8).to(device))
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)

    def entropy(self):
        if len(self.masks) == 0:
            return super(CategoricalMasked, self).entropy()
        p_log_p = self.logits * self.probs
        p_log_p = torch.where(self.masks, p_log_p, torch.tensor(0.).to(device))
        return -p_log_p.sum(-1)
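
# A minimal sketch (not part of the original diff) of how CategoricalMasked behaves,
# assuming a single 4-way action head where only actions 0 and 2 are valid:
#
#     logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
#     masks = torch.tensor([[1., 0., 1., 0.]])
#     dist = CategoricalMasked(logits=logits, masks=masks)
#     dist.probs      # invalid entries had their logits replaced with -1e8, so their probability is ~0
#     dist.sample()   # only ever returns 0 or 2
#     dist.entropy()  # -sum(p * log p) over the valid entries; masked-out terms are zeroed out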


class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(27, 16, kernel_size=3,),
            nn.MaxPool2d(1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.MaxPool2d(1),
            nn.ReLU())
        self.fc = nn.Sequential(
            nn.Linear(32*6*6, 128),
            nn.ReLU(),
            nn.Linear(128, env.action_space.nvec.sum())
        )

    def forward(self, x):
        x = torch.Tensor(np.moveaxis(x, -1, 1)).to(device)
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


    def get_action(self, x, action=None, invalid_action_masks=None):
        logits = self.forward(x)
        split_logits = torch.split(logits, env.action_space.nvec.tolist(), dim=1)
        if invalid_action_masks is not None:
            split_invalid_action_masks = torch.split(invalid_action_masks, env.action_space.nvec.tolist(), dim=1)
            multi_categoricals = [CategoricalMasked(logits=logits, masks=iam) for (logits, iam) in zip(split_logits, split_invalid_action_masks)]
        else:
            multi_categoricals = [Categorical(logits=logits) for logits in split_logits]
        if action is None:
            action = torch.stack([categorical.sample() for categorical in multi_categoricals])
        logprob = torch.stack([categorical.log_prob(a) for a, categorical in zip(action, multi_categoricals)])
        # entropy = torch.stack([categorical.entropy() for categorical in multi_categoricals])
        return action, logprob, [], multi_categoricals
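
    # Note on get_action above (explanatory comment, not part of the original diff): the
    # network emits one flat logits vector of length env.action_space.nvec.sum(), and
    # torch.split(..., env.action_space.nvec.tolist(), dim=1) cuts both the logits and the
    # flat invalid-action mask into one chunk per MultiDiscrete component. For a hypothetical
    # nvec of [3, 2], a flat mask [1, 1, 0, 0, 1] splits into [1, 1, 0] for the first head and
    # [0, 1] for the second, so each action component gets its own CategoricalMasked distribution.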


class Value(nn.Module):
    def __init__(self):
        super(Value, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(27, 16, kernel_size=3,),
            nn.MaxPool2d(1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.MaxPool2d(1),
            nn.ReLU())
        self.fc = nn.Sequential(
            nn.Linear(32*6*6, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        x = torch.Tensor(np.moveaxis(x, -1, 1)).to(device)
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


def discount_cumsum(x, dones, gamma):
    """
    computing discounted cumulative sums of vectors that reset with dones
    input:
        vector x,  vector dones,
        [x0,       [0,
         x1,        0,
         x2,        1,
         x3,        0,
         x4]        0]
    output:
        [x0 + discount * x1 + discount^2 * x2,
         x1 + discount * x2,
         x2,
         x3 + discount * x4,
         x4]
    """
    discount_cumsum = np.zeros_like(x)
    discount_cumsum[-1] = x[-1]
    for t in reversed(range(x.shape[0]-1)):
        discount_cumsum[t] = x[t] + gamma * discount_cumsum[t+1] * (1-dones[t])
    return discount_cumsum
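
# A small worked example (not part of the original script) of discount_cumsum above:
# with x = [1., 1., 1.], dones = [0, 1, 0] and gamma = 0.5, the result is [1.5, 1.0, 1.0].
# The done flag at index 1 stops the discounted sum from leaking across the episode
# boundary, so index 0 only accumulates the value at index 1.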


pg = Policy().to(device)
vf = Value().to(device)

# MODIFIED: Separate optimizers and learning rates
pg_optimizer = optim.Adam(list(pg.parameters()), lr=args.policy_lr)
v_optimizer = optim.Adam(list(vf.parameters()), lr=args.value_lr)

# MODIFIED: Initializing learning rate anneal scheduler when needed
if args.anneal_lr:
    anneal_fn = lambda f: max(0, 1-f / args.total_timesteps)
    pg_lr_scheduler = optim.lr_scheduler.LambdaLR(pg_optimizer, lr_lambda=anneal_fn)
    vf_lr_scheduler = optim.lr_scheduler.LambdaLR(v_optimizer, lr_lambda=anneal_fn)

loss_fn = nn.MSELoss()


# TRY NOT TO MODIFY: start the game
global_step = 0
while global_step < args.total_timesteps:
    if args.capture_video:
        env.stats_recorder.done = True
    next_obs = np.array(env.reset())

    # ALGO Logic: Storage for epoch data
    obs = np.empty((args.batch_size,) + env.observation_space.shape)

    actions = np.empty((args.batch_size,) + env.action_space.shape)
    logprobs = torch.zeros((env.action_space.nvec.shape[0], args.batch_size,)).to(device)


    rewards = np.zeros((args.batch_size,))
    raw_rewards = np.zeros((len(env.rfs), args.batch_size,))
    real_rewards = []
    returns = np.zeros((args.batch_size,))

    dones = np.zeros((args.batch_size,))
    values = torch.zeros((args.batch_size,)).to(device)

    invalid_action_masks = torch.zeros((args.batch_size, env.action_space.nvec.sum()))
    # TRY NOT TO MODIFY: prepare the execution of the game.
    for step in range(args.batch_size):
        env.render()
        global_step += 1
        obs[step] = next_obs.copy()

        # ALGO LOGIC: put action logic here
        invalid_action_mask = torch.ones(env.action_space.nvec.sum())
        invalid_action_mask[0:env.action_space.nvec[0]] = torch.tensor(env.unit_location_mask)
        invalid_action_mask[-env.action_space.nvec[-1]:] = torch.tensor(env.target_unit_location_mask)
        invalid_action_masks[step] = invalid_action_mask
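        # Explanatory note (not part of the original diff): the flat mask starts as all ones
        # (everything valid). Its first nvec[0] entries correspond to the source-unit component
        # and are overwritten with env.unit_location_mask; its last nvec[-1] entries correspond
        # to the target-unit component and are overwritten with env.target_unit_location_mask.
        # The action components in between are left unmasked here.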
        with torch.no_grad():
            values[step] = vf.forward(obs[step:step+1])
            action, logproba, _, probs = pg.get_action(obs[step:step+1], invalid_action_masks=invalid_action_masks[step:step+1])
        actions[step] = action[:,0].data.cpu().numpy()
        logprobs[:,[step]] = logproba


        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards[step], dones[step], info = env.step(action[:,0].data.cpu().numpy())
        raw_rewards[:,step] = info["rewards"]
        real_rewards += [info['real_reward']]
        next_obs = np.array(next_obs)

        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            pg_lr_scheduler.step()
            vf_lr_scheduler.step()


        if dones[step]:
            # Computing the discounted returns:
            writer.add_scalar("charts/episode_reward", np.sum(real_rewards), global_step)
            print(f"global_step={global_step}, episode_reward={np.sum(real_rewards)}")
            for i in range(len(env.rfs)):
                writer.add_scalar(f"charts/episode_reward/{str(env.rfs[i])}", raw_rewards.sum(1)[i], global_step)
            real_rewards = []
            next_obs = np.array(env.reset())


    # bootstrap the value if not done (the rollout reached the batch limit mid-episode)
    last_value = 0
    if not dones[step]:
        last_value = vf.forward(next_obs.reshape((1,)+next_obs.shape))[0].detach().cpu().numpy()[0]
    bootstrapped_rewards = np.append(rewards, last_value)

    # calculate the returns and advantages
    if args.gae:
        bootstrapped_values = np.append(values.detach().cpu().numpy(), last_value)
        deltas = bootstrapped_rewards[:-1] + args.gamma * bootstrapped_values[1:] * (1-dones) - bootstrapped_values[:-1]
        advantages = discount_cumsum(deltas, dones, args.gamma * args.gae_lambda)
        advantages = torch.Tensor(advantages).to(device)
        returns = advantages + values
    else:
        returns = discount_cumsum(bootstrapped_rewards, dones, args.gamma)[:-1]
        advantages = returns - values.detach().cpu().numpy()
        advantages = torch.Tensor(advantages).to(device)
        returns = torch.Tensor(returns).to(device)
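
    # Explanatory note (not part of the original diff): with --gae the code forms the TD
    # residuals delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t) and then runs
    # discount_cumsum over them with discount gamma * lambda, which is the GAE estimator
    # A_t = sum_k (gamma * lambda)^k * delta_{t+k}, reset at episode boundaries.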


    # Advantage normalization
    if args.norm_adv:
        EPS = 1e-10
        advantages = (advantages - advantages.mean()) / (advantages.std() + EPS)


    # Optimizing policy network
    entropys = []
    target_pg = Policy().to(device)
    inds = np.arange(args.batch_size,)
    for i_epoch_pi in range(args.update_epochs):
        np.random.shuffle(inds)
        for start in range(0, args.batch_size, args.minibatch_size):
            end = start + args.minibatch_size
            minibatch_ind = inds[start:end]
            target_pg.load_state_dict(pg.state_dict())
            _, newlogproba, _, _ = pg.get_action(
                obs[minibatch_ind],
                torch.LongTensor(actions[minibatch_ind].astype(np.int)).to(device).T,
                invalid_action_masks[minibatch_ind])
            ratio = (newlogproba - logprobs[:,minibatch_ind]).exp()


            # Policy loss as in OpenAI SpinUp
            clip_adv = torch.where(advantages[minibatch_ind] > 0,
                                   (1.+args.clip_coef) * advantages[minibatch_ind],
                                   (1.-args.clip_coef) * advantages[minibatch_ind]).to(device)

            # Entropy computation with resampled actions
            entropy = -(newlogproba.exp() * newlogproba).mean()
            entropys.append(entropy.item())

            policy_loss = -torch.min(ratio * advantages[minibatch_ind], clip_adv) + args.ent_coef * entropy
            policy_loss = policy_loss.mean()
            pg_optimizer.zero_grad()
            policy_loss.backward()
            nn.utils.clip_grad_norm_(pg.parameters(), args.max_grad_norm)
            pg_optimizer.step()
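
            # Explanatory note (not part of the original diff): clip_adv is the SpinUp-style
            # shortcut for the clipped surrogate. Since min(ratio * A, clip(ratio, 1-eps, 1+eps) * A)
            # equals min(ratio * A, (1+eps) * A) when A > 0 and min(ratio * A, (1-eps) * A) otherwise,
            # selecting (1±eps) * A by the sign of the advantage gives the same objective without
            # an explicit clamp on the ratio.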


            approx_kl = (logprobs[:,minibatch_ind] - newlogproba).mean()
            # Optimizing value network
            new_values = vf.forward(obs[minibatch_ind]).view(-1)

            # Value loss clipping
            if args.clip_vloss:
                v_loss_unclipped = ((new_values - returns[minibatch_ind]) ** 2)
                v_clipped = values[minibatch_ind] + torch.clamp(new_values - values[minibatch_ind], -args.clip_coef, args.clip_coef)
                v_loss_clipped = (v_clipped - returns[minibatch_ind])**2
                v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                v_loss = 0.5 * v_loss_max.mean()
            else:
                v_loss = torch.mean((returns[minibatch_ind] - new_values).pow(2))


            v_optimizer.zero_grad()
            v_loss.backward()
            nn.utils.clip_grad_norm_(vf.parameters(), args.max_grad_norm)
            v_optimizer.step()

        if args.kle_stop:
            if approx_kl > args.target_kl:
                break
        if args.kle_rollback:
            if (logprobs[:,minibatch_ind] -
                pg.get_action(
                    obs[minibatch_ind],
                    torch.LongTensor(actions[minibatch_ind].astype(np.int)).to(device).T,
                    invalid_action_masks[minibatch_ind])[1]).mean() > args.target_kl:
                pg.load_state_dict(target_pg.state_dict())
                break


    # TRY NOT TO MODIFY: record rewards for plotting purposes
    writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
    writer.add_scalar("charts/policy_learning_rate", pg_optimizer.param_groups[0]['lr'], global_step)
    writer.add_scalar("charts/value_learning_rate", v_optimizer.param_groups[0]['lr'], global_step)
    writer.add_scalar("losses/policy_loss", policy_loss.item(), global_step)
    writer.add_scalar("losses/entropy", np.mean(entropys), global_step)
    writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
    if args.kle_stop or args.kle_rollback:
        writer.add_scalar("debug/pg_stop_iter", i_epoch_pi, global_step)

env.close()
writer.close()