
Implementation of Invalid Action Masking

Diff summary: 465 lines (8 removals) → 478 lines (22 additions)
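
This diff adds invalid action masking to a CleanRL-style PPO agent for MultiDiscrete action spaces (gym_microrts). The listing below is the resulting script. The changes thread the masks through the agent: an invalid_action_masks buffer is filled at every step from env.unit_location_mask and env.target_unit_location_mask, get_action gains an invalid_action_masks argument and builds CategoricalMasked distributions when masks are provided, the stored masks are passed again during the policy update and the KL rollback check, and per-reward-function episode statistics (raw_rewards over env.rfs) are logged.

The core trick, shown here as a minimal standalone sketch (the tensors are illustrative and not taken from the environment), is to replace the logits of invalid actions with a large negative number before building the categorical distribution:

import torch
from torch.distributions.categorical import Categorical

logits = torch.tensor([[1.0, 2.0, 0.5, -0.3]])
mask = torch.tensor([[True, False, True, False]])  # False marks an invalid action
masked_logits = torch.where(mask, logits, torch.tensor(-1e8))
dist = Categorical(logits=masked_logits)
print(dist.probs)  # masked actions get ~0 probability, so sample() never picks them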
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

from cleanrl.common import preprocess_obs_space, preprocess_ac_space
import argparse
import numpy as np
import gym
import gym_microrts
from gym.wrappers import TimeLimit, Monitor
import pybullet_envs
from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete, Space
import time
import random
import os


# taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py
class RunningMeanStd(object):
    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, 'float64')
        self.var = np.ones(shape, 'float64')
        self.count = epsilon

    def update(self, x):
        batch_mean = np.mean([x], axis=0)
        batch_var = np.var([x], axis=0)
        batch_count = 1
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)

def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
    delta = batch_mean - mean
    tot_count = count + batch_count

    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count

class NormalizedEnv(gym.core.Wrapper):
    def __init__(self, env, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8):
        super(NormalizedEnv, self).__init__(env)
        self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
        self.ret_rms = RunningMeanStd(shape=(1,)) if ret else None
        self.clipob = clipob
        self.cliprew = cliprew
        self.ret = np.zeros(())
        self.gamma = gamma
        self.epsilon = epsilon

    def step(self, action):
        obs, rews, news, infos = self.env.step(action)
        infos['real_reward'] = rews
        # print("before", self.ret)
        self.ret = self.ret * self.gamma + rews
        # print("after", self.ret)
        obs = self._obfilt(obs)
        if self.ret_rms:
            self.ret_rms.update(np.array([self.ret].copy()))
            rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew)
        self.ret = self.ret * (1-float(news))
        return obs, rews, news, infos

    def _obfilt(self, obs):
        if self.ob_rms:
            self.ob_rms.update(obs)
            obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob)
            return obs
        else:
            return obs

    def reset(self):
        self.ret = np.zeros(())
        obs = self.env.reset()
        return self._obfilt(obs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PPO agent')
    # Common arguments
    parser.add_argument('--exp-name', type=str, default=os.path.basename(__file__).rstrip(".py"),
                        help='the name of this experiment')
    parser.add_argument('--gym-id', type=str, default="MicrortsMining10x10F9-v0",
                        help='the id of the gym environment')
    parser.add_argument('--seed', type=int, default=1,
                        help='seed of the experiment')
    parser.add_argument('--episode-length', type=int, default=0,
                        help='the maximum length of each episode')
    parser.add_argument('--total-timesteps', type=int, default=100000,
                        help='total timesteps of the experiments')
    parser.add_argument('--no-torch-deterministic', action='store_false', dest="torch_deterministic", default=True,
                        help='if toggled, `torch.backends.cudnn.deterministic=False`')
    parser.add_argument('--no-cuda', action='store_false', dest="cuda", default=True,
                        help='if toggled, cuda will not be enabled by default')
    parser.add_argument('--prod-mode', action='store_true', default=False,
                        help='run the script in production mode and use wandb to log outputs')
    parser.add_argument('--capture-video', action='store_true', default=False,
                        help='whether to capture videos of the agent performances (check out the `videos` folder)')
    parser.add_argument('--wandb-project-name', type=str, default="cleanRL",
                        help="the wandb's project name")
    parser.add_argument('--wandb-entity', type=str, default=None,
                        help="the entity (team) of wandb's project")

    # Algorithm specific arguments
    parser.add_argument('--batch-size', type=int, default=2048,
                        help='the batch size of ppo')
    parser.add_argument('--minibatch-size', type=int, default=256,
                        help='the mini batch size of ppo')
    parser.add_argument('--gamma', type=float, default=0.99,
                        help='the discount factor gamma')
    parser.add_argument('--gae-lambda', type=float, default=0.97,
                        help='the lambda for the generalized advantage estimation')
    parser.add_argument('--ent-coef', type=float, default=0.01,
                        help="coefficient of the entropy")
    parser.add_argument('--max-grad-norm', type=float, default=0.5,
                        help='the maximum norm for the gradient clipping')
    parser.add_argument('--clip-coef', type=float, default=0.2,
                        help="the surrogate clipping coefficient")
    parser.add_argument('--update-epochs', type=int, default=10,
                        help="the K epochs to update the policy")
    parser.add_argument('--kle-stop', action='store_true', default=False,
                        help='If toggled, the policy updates will be early stopped w.r.t. target-kl')
    parser.add_argument('--kle-rollback', action='store_true', default=False,
                        help='If toggled, the policy updates will roll back to the previous policy if KL exceeds target-kl')
    parser.add_argument('--target-kl', type=float, default=0.015,
                        help='the target KL divergence threshold used by --kle-stop and --kle-rollback')
    parser.add_argument('--gae', action='store_true', default=True,
                        help='Use GAE for advantage computation')
    parser.add_argument('--policy-lr', type=float, default=3e-4,
                        help="the learning rate of the policy optimizer")
    parser.add_argument('--value-lr', type=float, default=3e-4,
                        help="the learning rate of the critic optimizer")
    parser.add_argument('--norm-obs', action='store_true', default=True,
                        help="Toggles observation normalization")
    parser.add_argument('--norm-returns', action='store_true', default=False,
                        help="Toggles returns normalization")
    parser.add_argument('--norm-adv', action='store_true', default=True,
                        help="Toggles advantages normalization")
    parser.add_argument('--obs-clip', type=float, default=10.0,
                        help="Value for observation clipping, as per the paper")
    parser.add_argument('--rew-clip', type=float, default=10.0,
                        help="Value for reward clipping, as per the paper")
    parser.add_argument('--anneal-lr', action='store_true', default=True,
                        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument('--weights-init', default="orthogonal", choices=["xavier", 'orthogonal'],
                        help='Selects the scheme to be used for weights initialization')
    parser.add_argument('--clip-vloss', action="store_true", default=True,
                        help='Toggles whether or not to use a clipped loss for the value function, as per the paper.')
    parser.add_argument('--pol-layer-norm', action='store_true', default=False,
                        help='Enables layer normalization in the policy network')
    args = parser.parse_args()
    if not args.seed:
        args.seed = int(time.time())


    args.features_turned_on = sum([args.kle_stop, args.kle_rollback, args.gae, args.norm_obs, args.norm_returns, args.norm_adv, args.anneal_lr, args.clip_vloss, args.pol_layer_norm])
    # TRY NOT TO MODIFY: setup the environment
    experiment_name = f"{args.gym_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    writer = SummaryWriter(f"runs/{experiment_name}")
    writer.add_text('hyperparameters', "|param|value|\n|-|-|\n%s" % (
        '\n'.join([f"|{key}|{value}|" for key, value in vars(args).items()])))

    if args.prod_mode:
        import wandb
        wandb.init(project=args.wandb_project_name, entity=args.wandb_entity, tensorboard=True, config=vars(args), name=experiment_name, monitor_gym=True)
        writer = SummaryWriter(f"/tmp/{experiment_name}")
        wandb.save(os.path.abspath(__file__))

    # TRY NOT TO MODIFY: seeding
    device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')
    env = gym.make(args.gym_id)
    # respect the default timelimit
    assert isinstance(env.action_space, MultiDiscrete), "only MultiDiscrete action space is supported"
    assert isinstance(env, TimeLimit) or int(args.episode_length), "the gym env does not have a built-in TimeLimit, please specify one with --episode-length"
    if isinstance(env, TimeLimit):
        if int(args.episode_length):
            env._max_episode_steps = int(args.episode_length)
        args.episode_length = env._max_episode_steps
    else:
        env = TimeLimit(env, int(args.episode_length))
    env = NormalizedEnv(env.env, ob=args.norm_obs, ret=args.norm_returns, clipob=args.obs_clip, cliprew=args.rew_clip, gamma=args.gamma)
    env = TimeLimit(env, int(args.episode_length))
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic
    env.seed(args.seed)
    env.action_space.seed(args.seed)
    env.observation_space.seed(args.seed)
    if args.capture_video:
        env = Monitor(env, f'videos/{experiment_name}')


    # ALGO LOGIC: initialize agent here:
    class CategoricalMasked(Categorical):
        def __init__(self, probs=None, logits=None, validate_args=None, masks=[]):
            self.masks = masks
            if len(self.masks) == 0:
                super(CategoricalMasked, self).__init__(probs, logits, validate_args)
            else:
                self.masks = masks.type(torch.BoolTensor).to(device)
                logits = torch.where(self.masks, logits, torch.tensor(-1e+8).to(device))
                super(CategoricalMasked, self).__init__(probs, logits, validate_args)

        def entropy(self):
            if len(self.masks) == 0:
                return super(CategoricalMasked, self).entropy()
            p_log_p = self.logits * self.probs
            p_log_p = torch.where(self.masks, p_log_p, torch.tensor(0.).to(device))
            return -p_log_p.sum(-1)
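    # How CategoricalMasked behaves: when masks are given, invalid actions receive a logit of -1e8,
    # so after the softmax their probabilities are effectively zero and sample() never picks them;
    # entropy() likewise zeroes the p * log(p) terms of masked actions, so the entropy bonus only
    # reflects the valid actions.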


    class Policy(nn.Module):
        def __init__(self):
            super(Policy, self).__init__()
            self.features = nn.Sequential(
                nn.Conv2d(27, 16, kernel_size=3),
                nn.MaxPool2d(1),
                nn.ReLU(),
                nn.Conv2d(16, 32, kernel_size=3),
                nn.MaxPool2d(1),
                nn.ReLU())
            self.fc = nn.Sequential(
                nn.Linear(32*6*6, 128),
                nn.ReLU(),
                nn.Linear(128, env.action_space.nvec.sum())
            )

        def forward(self, x):
            x = torch.Tensor(np.moveaxis(x, -1, 1)).to(device)
            x = self.features(x)
            x = x.view(x.size(0), -1)
            x = self.fc(x)
            return x

        def get_action(self, x, action=None, invalid_action_masks=None):
            logits = self.forward(x)
            split_logits = torch.split(logits, env.action_space.nvec.tolist(), dim=1)
            if invalid_action_masks is not None:
                split_invalid_action_masks = torch.split(invalid_action_masks, env.action_space.nvec.tolist(), dim=1)
                multi_categoricals = [CategoricalMasked(logits=logits, masks=iam) for (logits, iam) in zip(split_logits, split_invalid_action_masks)]
            else:
                multi_categoricals = [Categorical(logits=logits) for logits in split_logits]
            if action is None:
                action = torch.stack([categorical.sample() for categorical in multi_categoricals])
            logprob = torch.stack([categorical.log_prob(a) for a, categorical in zip(action, multi_categoricals)])
            return action, logprob, [], multi_categoricals
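    # Note on the MultiDiscrete head: the flat policy output has env.action_space.nvec.sum() logits.
    # get_action splits them according to nvec into one categorical per action component and splits
    # invalid_action_masks the same way, so each component is masked independently; in this script only
    # the first component (unit location) and the last component (target unit location) receive
    # non-trivial masks from the environment.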


    class Value(nn.Module):
        def __init__(self):
            super(Value, self).__init__()
            self.features = nn.Sequential(
                nn.Conv2d(27, 16, kernel_size=3),
                nn.MaxPool2d(1),
                nn.ReLU(),
                nn.Conv2d(16, 32, kernel_size=3),
                nn.MaxPool2d(1),
                nn.ReLU())
            self.fc = nn.Sequential(
                nn.Linear(32*6*6, 128),
                nn.ReLU(),
                nn.Linear(128, 1)
            )

        def forward(self, x):
            x = torch.Tensor(np.moveaxis(x, -1, 1)).to(device)
            x = self.features(x)
            x = x.view(x.size(0), -1)
            x = self.fc(x)
            return x

    def discount_cumsum(x, dones, gamma):
        """
        computes discounted cumulative sums of vectors that reset at dones
        input:
            vector x,  vector dones,
            [x0,       [0,
             x1,        0,
             x2,        1,
             x3,        0,
             x4]        0]
        output:
            [x0 + discount * x1 + discount^2 * x2,
             x1 + discount * x2,
             x2,
             x3 + discount * x4,
             x4]
        """
        discount_cumsum = np.zeros_like(x)
        discount_cumsum[-1] = x[-1]
        for t in reversed(range(x.shape[0]-1)):
            discount_cumsum[t] = x[t] + gamma * discount_cumsum[t+1] * (1-dones[t])
        return discount_cumsum


    pg = Policy().to(device)
    vf = Value().to(device)

    # MODIFIED: Separate optimizer and learning rates
    pg_optimizer = optim.Adam(list(pg.parameters()), lr=args.policy_lr)
    v_optimizer = optim.Adam(list(vf.parameters()), lr=args.value_lr)

    # MODIFIED: Initializing learning rate anneal scheduler when needed
    if args.anneal_lr:
        anneal_fn = lambda f: max(0, 1-f / args.total_timesteps)
        pg_lr_scheduler = optim.lr_scheduler.LambdaLR(pg_optimizer, lr_lambda=anneal_fn)
        vf_lr_scheduler = optim.lr_scheduler.LambdaLR(v_optimizer, lr_lambda=anneal_fn)

    loss_fn = nn.MSELoss()

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    while global_step < args.total_timesteps:
        if args.capture_video:
            env.stats_recorder.done = True
        next_obs = np.array(env.reset())

        # ALGO Logic: Storage for epoch data
        obs = np.empty((args.batch_size,) + env.observation_space.shape)

        actions = np.empty((args.batch_size,) + env.action_space.shape)
        logprobs = torch.zeros((env.action_space.nvec.shape[0], args.batch_size,)).to(device)

        rewards = np.zeros((args.batch_size,))
        raw_rewards = np.zeros((len(env.rfs), args.batch_size,))
        real_rewards = []
        returns = np.zeros((args.batch_size,))

        dones = np.zeros((args.batch_size,))
        values = torch.zeros((args.batch_size,)).to(device)

        invalid_action_masks = torch.zeros((args.batch_size, env.action_space.nvec.sum()))
        # TRY NOT TO MODIFY: prepare the execution of the game.
        for step in range(args.batch_size):
            env.render()
            global_step += 1
            obs[step] = next_obs.copy()

            # ALGO LOGIC: put action logic here
            invalid_action_mask = torch.ones(env.action_space.nvec.sum())
            invalid_action_mask[0:env.action_space.nvec[0]] = torch.tensor(env.unit_location_mask)
            invalid_action_mask[-env.action_space.nvec[-1]:] = torch.tensor(env.target_unit_location_mask)
            invalid_action_masks[step] = invalid_action_mask
            with torch.no_grad():
                values[step] = vf.forward(obs[step:step+1])
                action, logproba, _, probs = pg.get_action(obs[step:step+1], invalid_action_masks=invalid_action_masks[step:step+1])
            actions[step] = action[:,0].data.cpu().numpy()
            logprobs[:,[step]] = logproba

            # TRY NOT TO MODIFY: execute the game and log data.
            next_obs, rewards[step], dones[step], info = env.step(action[:,0].data.cpu().numpy())
            raw_rewards[:,step] = info["rewards"]
            real_rewards += [info['real_reward']]
            next_obs = np.array(next_obs)

            # Annealing the rate if instructed to do so.
            if args.anneal_lr:
                pg_lr_scheduler.step()
                vf_lr_scheduler.step()

            if dones[step]:
                # Computing the discounted returns:
                writer.add_scalar("charts/episode_reward", np.sum(real_rewards), global_step)
                print(f"global_step={global_step}, episode_reward={np.sum(real_rewards)}")
                for i in range(len(env.rfs)):
                    writer.add_scalar(f"charts/episode_reward/{str(env.rfs[i])}", raw_rewards.sum(1)[i], global_step)
                real_rewards = []
                next_obs = np.array(env.reset())


        # bootstrap reward if not done (we reached the batch limit)
        last_value = 0
        if not dones[step]:
            last_value = vf.forward(next_obs.reshape((1,)+next_obs.shape))[0].detach().cpu().numpy()[0]
        bootstrapped_rewards = np.append(rewards, last_value)

        # calculate the returns and advantages
        if args.gae:
            bootstrapped_values = np.append(values.detach().cpu().numpy(), last_value)
            deltas = bootstrapped_rewards[:-1] + args.gamma * bootstrapped_values[1:] * (1-dones) - bootstrapped_values[:-1]
            advantages = discount_cumsum(deltas, dones, args.gamma * args.gae_lambda)
            advantages = torch.Tensor(advantages).to(device)
            returns = advantages + values
        else:
            returns = discount_cumsum(bootstrapped_rewards, dones, args.gamma)[:-1]
            advantages = returns - values.detach().cpu().numpy()
            advantages = torch.Tensor(advantages).to(device)
            returns = torch.Tensor(returns).to(device)

        # Advantage normalization
        if args.norm_adv:
            EPS = 1e-10
            advantages = (advantages - advantages.mean()) / (advantages.std() + EPS)

        # Optimizing the policy network
        entropys = []
        target_pg = Policy().to(device)
        inds = np.arange(args.batch_size,)
        for i_epoch_pi in range(args.update_epochs):
            np.random.shuffle(inds)
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                minibatch_ind = inds[start:end]
                target_pg.load_state_dict(pg.state_dict())
                _, newlogproba, _, _ = pg.get_action(
                    obs[minibatch_ind],
                    torch.LongTensor(actions[minibatch_ind].astype(np.int)).to(device).T,
                    invalid_action_masks[minibatch_ind])
                ratio = (newlogproba - logprobs[:,minibatch_ind]).exp()

                # Policy loss as in OpenAI SpinUp
                clip_adv = torch.where(advantages[minibatch_ind] > 0,
                                       (1.+args.clip_coef) * advantages[minibatch_ind],
                                       (1.-args.clip_coef) * advantages[minibatch_ind]).to(device)

                # Entropy computation with resampled actions
                entropy = -(newlogproba.exp() * newlogproba).mean()
                entropys.append(entropy.item())

                policy_loss = -torch.min(ratio * advantages[minibatch_ind], clip_adv) + args.ent_coef * entropy
                policy_loss = policy_loss.mean()
                pg_optimizer.zero_grad()
                policy_loss.backward()
                nn.utils.clip_grad_norm_(pg.parameters(), args.max_grad_norm)
                pg_optimizer.step()

                approx_kl = (logprobs[:,minibatch_ind] - newlogproba).mean()
                # Optimizing the value network
                new_values = vf.forward(obs[minibatch_ind]).view(-1)

                # Value loss clipping
                if args.clip_vloss:
                    v_loss_unclipped = ((new_values - returns[minibatch_ind]) ** 2)
                    v_clipped = values[minibatch_ind] + torch.clamp(new_values - values[minibatch_ind], -args.clip_coef, args.clip_coef)
                    v_loss_clipped = (v_clipped - returns[minibatch_ind])**2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = torch.mean((returns[minibatch_ind] - new_values).pow(2))

                v_optimizer.zero_grad()
                v_loss.backward()
                nn.utils.clip_grad_norm_(vf.parameters(), args.max_grad_norm)
                v_optimizer.step()

            if args.kle_stop:
                if approx_kl > args.target_kl:
                    break
            if args.kle_rollback:
                if (logprobs[:,minibatch_ind] -
                    pg.get_action(
                        obs[minibatch_ind],
                        torch.LongTensor(actions[minibatch_ind].astype(np.int)).to(device).T,
                        invalid_action_masks[minibatch_ind])[1]).mean() > args.target_kl:
                    pg.load_state_dict(target_pg.state_dict())
                    break

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
        writer.add_scalar("charts/policy_learning_rate", pg_optimizer.param_groups[0]['lr'], global_step)
        writer.add_scalar("charts/value_learning_rate", v_optimizer.param_groups[0]['lr'], global_step)
        writer.add_scalar("losses/policy_loss", policy_loss.item(), global_step)
        writer.add_scalar("losses/entropy", np.mean(entropys), global_step)
        writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
        if args.kle_stop or args.kle_rollback:
            writer.add_scalar("debug/pg_stop_iter", i_epoch_pi, global_step)

    env.close()
    writer.close()
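
The script is run directly; for example (the file name here is illustrative, the flags are those defined in the argument parser above):

python ppo_invalid_action_masking.py --gym-id MicrortsMining10x10F9-v0 --total-timesteps 100000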