PPO: handle truncation with EnvPool's API (attempt 4)

"""
Handle truncation
python -i ppo_atari_envpool_xla_jax_truncation.py --env-id Breakout-v5 --num-envs 1 --num-steps 8 --num-minibatches 2 --update-epochs 2

>>> storage.dones.flatten()
DeviceArray([0., 0., 0., 0., 0., 0., 1., 0.], dtype=float32)
>>> storage.truncations.flatten()
DeviceArray([0., 0., 0., 0., 0., 0., 1., 0.], dtype=float32)
>>> storage.rewards.flatten()
DeviceArray([0., 0., 1., 0., 0., 0., 0., 0.], dtype=float32)
>>> storage.values.flatten()
DeviceArray([ 0.00226192,  0.00071621,  0.00114149, -0.00414939,
             -0.00838596, -0.01181885, -0.01047847,  0.00127411], dtype=float32)

# bootstrap value
>>> jnp.where(storage.truncations, storage.rewards + storage.values, storage.rewards).flatten()
DeviceArray([ 0.        ,  0.        ,  1.        ,  0.        ,
              0.        ,  0.        , -0.01047847,  0.        ], dtype=float32)

"""


# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpool_xla_jaxpy
import argparse
import os
import random
import time
from distutils.util import strtobool
from typing import Sequence

os.environ[
    "XLA_PYTHON_CLIENT_MEM_FRACTION"
] = "0.7"  # see https://github.com/google/jax/discussions/6332#discussioncomment-1279991

import envpool
import flax
import flax.linen as nn
import gym
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
from torch.utils.tensorboard import SummaryWriter


def parse_args():
    # fmt: off
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
        help="the name of this experiment")
    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")
    parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, `torch.backends.cudnn.deterministic=False`")
    parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, cuda will be enabled by default")
    parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="if toggled, this experiment will be tracked with Weights and Biases")
    parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
        help="the wandb's project name")
    parser.add_argument("--wandb-entity", type=str, default=None,
        help="the entity (team) of wandb's project")
    parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to capture videos of the agent performances (check out `videos` folder)")

    # Algorithm specific arguments
    parser.add_argument("--env-id", type=str, default="Pong-v5",
        help="the id of the environment")
    parser.add_argument("--total-timesteps", type=int, default=10000000,
        help="total timesteps of the experiments")
    parser.add_argument("--learning-rate", type=float, default=2.5e-4,
        help="the learning rate of the optimizer")
    parser.add_argument("--num-envs", type=int, default=8,
        help="the number of parallel game environments")
    parser.add_argument("--num-steps", type=int, default=128,
        help="the number of steps to run in each environment per policy rollout")
    parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument("--gamma", type=float, default=0.99,
        help="the discount factor gamma")
    parser.add_argument("--gae-lambda", type=float, default=0.95,
        help="the lambda for the general advantage estimation")
    parser.add_argument("--num-minibatches", type=int, default=4,
        help="the number of mini-batches")
    parser.add_argument("--update-epochs", type=int, default=4,
        help="the K epochs to update the policy")
    parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggles advantages normalization")
    parser.add_argument("--clip-coef", type=float, default=0.1,
        help="the surrogate clipping coefficient")
    parser.add_argument("--ent-coef", type=float, default=0.01,
        help="coefficient of the entropy")
    parser.add_argument("--vf-coef", type=float, default=0.5,
        help="coefficient of the value function")
    parser.add_argument("--max-grad-norm", type=float, default=0.5,
        help="the maximum norm for the gradient clipping")
    parser.add_argument("--target-kl", type=float, default=None,
        help="the target KL divergence threshold")
    args = parser.parse_args()
    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    args.num_updates = args.total_timesteps // args.batch_size
    # fmt: on
    return args


class Network(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = jnp.transpose(x, (0, 2, 3, 1))
        x = x / (255.0)
        x = nn.Conv(
            32,
            kernel_size=(8, 8),
            strides=(4, 4),
            padding="VALID",
            kernel_init=orthogonal(np.sqrt(2)),
            bias_init=constant(0.0),
        )(x)
        x = nn.relu(x)
        x = nn.Conv(
            64,
            kernel_size=(4, 4),
            strides=(2, 2),
            padding="VALID",
            kernel_init=orthogonal(np.sqrt(2)),
            bias_init=constant(0.0),
        )(x)
        x = nn.relu(x)
        x = nn.Conv(
            64,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding="VALID",
            kernel_init=orthogonal(np.sqrt(2)),
            bias_init=constant(0.0),
        )(x)
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(512, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x)
        x = nn.relu(x)
        return x


class Critic(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1, kernel_init=orthogonal(1), bias_init=constant(0.0))(x)


class Actor(nn.Module):
    action_dim: Sequence[int]

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(x)


@flax.struct.dataclass
class AgentParams:
    network_params: flax.core.FrozenDict
    actor_params: flax.core.FrozenDict
    critic_params: flax.core.FrozenDict


@flax.struct.dataclass
class Storage:
    obs: jnp.array
    actions: jnp.array
    logprobs: jnp.array
    dones: jnp.array
    values: jnp.array
    advantages: jnp.array
    returns: jnp.array
    rewards: jnp.array
    truncations: jnp.array  # set from `info["TimeLimit.truncated"]`, recorded the same way as `dones`


@flax.struct.dataclass
class EpisodeStatistics:
    episode_returns: jnp.array
    episode_lengths: jnp.array
    returned_episode_returns: jnp.array
    returned_episode_lengths: jnp.array


if __name__ == "__main__":
    args = parse_args()
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    key = jax.random.PRNGKey(args.seed)
    key, network_key, actor_key, critic_key = jax.random.split(key, 4)

    # env setup
    envs = envpool.make(
        args.env_id,
        env_type="gym",
        num_envs=args.num_envs,
        episodic_life=True,
        reward_clip=True,
        seed=args.seed,
        max_episode_steps=100,  # deliberately short (presumably for this debugging run) so truncations occur quickly
    )
    envs.num_envs = args.num_envs
    envs.single_action_space = envs.action_space
    envs.single_observation_space = envs.observation_space
    envs.is_vector_env = True
    episode_stats = EpisodeStatistics(
        episode_returns=jnp.zeros(args.num_envs, dtype=jnp.float32),
        episode_lengths=jnp.zeros(args.num_envs, dtype=jnp.int32),
        returned_episode_returns=jnp.zeros(args.num_envs, dtype=jnp.float32),
        returned_episode_lengths=jnp.zeros(args.num_envs, dtype=jnp.int32),
    )
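    # `envs.xla()` exposes the environment step as a JAX-callable function (`step_env`) plus an
    # opaque `handle`, which is what lets the rollout below (env steps included) run under `jax.jit`.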
    handle, recv, send, step_env = envs.xla()

    def step_env_wrappeed(episode_stats, handle, action):
        # EnvPool's gym-style step returns a single `next_done`; the separate termination and
        # time-limit flags are exposed via `info["terminated"]` and `info["TimeLimit.truncated"]`.
        handle, (next_obs, reward, next_done, info) = step_env(handle, action)
        new_episode_return = episode_stats.episode_returns + info["reward"]
        new_episode_length = episode_stats.episode_lengths + 1
        episode_stats = episode_stats.replace(
            episode_returns=(new_episode_return) * (1 - info["terminated"]) * (1 - info["TimeLimit.truncated"]),
            episode_lengths=(new_episode_length) * (1 - info["terminated"]) * (1 - info["TimeLimit.truncated"]),
            # only update the `returned_episode_returns` if the episode is done
            returned_episode_returns=jnp.where(
                info["terminated"] + info["TimeLimit.truncated"], new_episode_return, episode_stats.returned_episode_returns
            ),
            returned_episode_lengths=jnp.where(
                info["terminated"] + info["TimeLimit.truncated"], new_episode_length, episode_stats.returned_episode_lengths
            ),
        )
        return episode_stats, handle, (next_obs, reward, next_done, info)

    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

    def linear_schedule(count):
        # anneal learning rate linearly after one training iteration which contains
        # (args.num_minibatches * args.update_epochs) gradient updates
        frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
        return args.learning_rate * frac
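    # (with the default 4 minibatches x 4 update epochs, `count` grows by 16 per policy iteration,
    # so the learning rate decays linearly towards 0 over `args.num_updates` iterations)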


    network = Network()
    actor = Actor(action_dim=envs.single_action_space.n)
    critic = Critic()
    network_params = network.init(network_key, np.array([envs.single_observation_space.sample()]))
    agent_state = TrainState.create(
        apply_fn=None,
        params=AgentParams(
            network_params,
            actor.init(actor_key, network.apply(network_params, np.array([envs.single_observation_space.sample()]))),
            critic.init(critic_key, network.apply(network_params, np.array([envs.single_observation_space.sample()]))),
        ),
        tx=optax.chain(
            optax.clip_by_global_norm(args.max_grad_norm),
            optax.inject_hyperparams(optax.adam)(
                learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
            ),
        ),
    )
    network.apply = jax.jit(network.apply)
    actor.apply = jax.jit(actor.apply)
    critic.apply = jax.jit(critic.apply)

    # ALGO Logic: Storage setup
    storage = Storage(
        obs=jnp.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape),
        actions=jnp.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape, dtype=jnp.int32),
        logprobs=jnp.zeros((args.num_steps, args.num_envs)),
        dones=jnp.zeros((args.num_steps, args.num_envs)),
        values=jnp.zeros((args.num_steps, args.num_envs)),
        advantages=jnp.zeros((args.num_steps, args.num_envs)),
        returns=jnp.zeros((args.num_steps, args.num_envs)),
        rewards=jnp.zeros((args.num_steps, args.num_envs)),
        truncations=jnp.zeros((args.num_steps, args.num_envs)),
    )

    @jax.jit
    def get_action_and_value(
        agent_state: TrainState,
        next_obs: np.ndarray,
        next_done: np.ndarray,
        next_truncated: np.ndarray,
        storage: Storage,
        step: int,
        key: jax.random.PRNGKey,
    ):
        """sample action, calculate value, logprob, entropy, and update storage"""
        hidden = network.apply(agent_state.params.network_params, next_obs)
        logits = actor.apply(agent_state.params.actor_params, hidden)
        # sample action: Gumbel-softmax trick
        # see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
        key, subkey = jax.random.split(key)
        u = jax.random.uniform(subkey, shape=logits.shape)
        action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
        logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
        value = critic.apply(agent_state.params.critic_params, hidden)
        storage = storage.replace(
            obs=storage.obs.at[step].set(next_obs),
            dones=storage.dones.at[step].set(next_done),
            truncations=storage.truncations.at[step].set(next_truncated),
            actions=storage.actions.at[step].set(action),
            logprobs=storage.logprobs.at[step].set(logprob),
            values=storage.values.at[step].set(value.squeeze()),
        )
        return storage, action, key

    @jax.jit
    def get_action_and_value2(
        params: flax.core.FrozenDict,
        x: np.ndarray,
        action: np.ndarray,
        mask: np.ndarray,
    ):
        """calculate value, logprob of supplied `action`, and entropy"""
        hidden = network.apply(params.network_params, x)
        logits = actor.apply(params.actor_params, hidden)
        logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
        # normalize the logits https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
        logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
        logits = logits.clip(min=jnp.finfo(logits.dtype).min)

        # mask out truncated states during the learning pass so that they don't affect the loss;
        # the mask is broadcast across the action dimension
        logits = jnp.where(mask.reshape((-1, 1)) * jnp.ones((1, logits.shape[-1])), jnp.zeros_like(logits) - 1e8, logits)
        p_log_p = logits * jax.nn.softmax(logits)
        entropy = -p_log_p.sum(-1)
        value = critic.apply(params.critic_params, hidden).squeeze()
        return logprob, entropy, value

    @jax.jit
    def compute_gae(
        agent_state: TrainState,
        next_obs: np.ndarray,
        next_done: np.ndarray,
        storage: Storage,
    ):
        storage = storage.replace(advantages=storage.advantages.at[:].set(0.0))
        next_value = critic.apply(
            agent_state.params.critic_params, network.apply(agent_state.params.network_params, next_obs)
        ).squeeze()
        lastgaelam = 0
        for t in reversed(range(args.num_steps)):
            if t == args.num_steps - 1:
                nextnonterminal = 1.0 - next_done
                nextvalues = next_value
            else:
                nextnonterminal = 1.0 - storage.dones[t + 1]
                nextvalues = storage.values[t + 1]
            delta = storage.rewards[t] + args.gamma * nextvalues * nextnonterminal - storage.values[t]
            lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
            storage = storage.replace(advantages=storage.advantages.at[t].set(lastgaelam))
        storage = storage.replace(returns=storage.advantages + storage.values)
        return storage
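
    # Note: GAE above stops bootstrapping wherever `dones` is set (in the docstring example a
    # truncated step also has `done == 1`); the additional value bootstrap for truncated steps
    # is applied to `storage.rewards` inside `update_ppo` below.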


    @jax.jit
    def update_ppo(
        agent_state: TrainState,
        storage: Storage,
        key: jax.random.PRNGKey,
    ):
        # handle truncated trajectories
        storage = storage.replace(rewards=jnp.where(storage.truncations, storage.rewards + storage.values, storage.rewards))
        b_obs = storage.obs.reshape((-1,) + envs.single_observation_space.shape)
        b_logprobs = storage.logprobs.reshape(-1)
        b_actions = storage.actions.reshape((-1,) + envs.single_action_space.shape)
        b_truncations = storage.truncations.reshape(-1)
        b_advantages = storage.advantages.reshape(-1)
        b_returns = storage.returns.reshape(-1)

        def ppo_loss(params, x, a, truncation_mask, logp, mb_advantages, mb_returns):
            newlogprob, entropy, newvalue = get_action_and_value2(params, x, a, truncation_mask)
            logratio = newlogprob - logp
            ratio = jnp.exp(logratio)
            approx_kl = ((ratio - 1) - logratio).mean()

            if args.norm_adv:
                mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

            # Policy loss
            pg_loss1 = -mb_advantages * ratio
            pg_loss2 = -mb_advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
            pg_loss = jnp.maximum(pg_loss1, pg_loss2).mean()

            # Value loss
            v_loss = 0.5 * ((newvalue - mb_returns) ** 2).mean()

            entropy_loss = entropy.mean()
            loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
            return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))

        ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
        for _ in range(args.update_epochs):
            key, subkey = jax.random.split(key)
            b_inds = jax.random.permutation(subkey, args.batch_size, independent=True)
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start:end]
                (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
                    agent_state.params,
                    b_obs[mb_inds],
                    b_actions[mb_inds],
                    b_truncations[mb_inds],
                    b_logprobs[mb_inds],
                    b_advantages[mb_inds],
                    b_returns[mb_inds],
                )
                agent_state = agent_state.apply_gradients(grads=grads)
        return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    start_time = time.time()
    next_obs = envs.reset()
    next_done = np.zeros(args.num_envs)
    next_truncated = np.zeros(args.num_envs)
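    # `next_truncated` follows the same convention as `next_done`: the `TimeLimit.truncated`
    # flag produced by one env step is written into `storage.truncations` at the next step's
    # index by `get_action_and_value`.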


    @jax.jit
    def rollout(agent_state, episode_stats, next_obs, next_done, next_truncated, storage, key, handle, global_step):
        for step in range(0, args.num_steps):
            global_step += 1 * args.num_envs
            storage, action, key = get_action_and_value(agent_state, next_obs, next_done, next_truncated, storage, step, key)

            # TRY NOT TO MODIFY: execute the game and log data.
            episode_stats, handle, (next_obs, reward, next_done, info) = step_env_wrappeed(episode_stats, handle, action)
            storage = storage.replace(
                rewards=storage.rewards.at[step].set(reward),
            )
            next_truncated = info["TimeLimit.truncated"]
        return agent_state, episode_stats, next_obs, next_done, next_truncated, storage, key, handle, global_step

    for update in range(1, args.num_updates + 1):
        update_time_start = time.time()
        agent_state, episode_stats, next_obs, next_done, next_truncated, storage, key, handle, global_step = rollout(
            agent_state, episode_stats, next_obs, next_done, next_truncated, storage, key, handle, global_step
        )
        storage = compute_gae(agent_state, next_obs, next_done, storage)

        agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key = update_ppo(
            agent_state,
            storage,
            key,
        )
        # debugging: once an update contains both a reward and a truncation, dump the buffers
        # that the truncation handling relies on, then stop the run
        if storage.rewards.sum() > 0 and storage.truncations.sum() > 0:
            print("storage.dones.flatten():\n", storage.dones.flatten())
            print("storage.truncations.flatten():\n", storage.truncations.flatten())
            print("storage.rewards.flatten():\n", storage.rewards.flatten())
            print("storage.values.flatten():\n", storage.values.flatten())
            print("NOTE: bootstrap value as below:")
            print("jnp.where(storage.truncations, storage.rewards + storage.values, storage.rewards).flatten():\n", jnp.where(storage.truncations, storage.rewards + storage.values, storage.rewards).flatten())
            raise

        avg_episodic_return = np.mean(jax.device_get(episode_stats.returned_episode_returns))
        print(f"global_step={global_step}, avg_episodic_return={avg_episodic_return}")

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
        writer.add_scalar(
            "charts/avg_episodic_length", np.mean(jax.device_get(episode_stats.returned_episode_lengths)), global_step
        )
        writer.add_scalar("charts/learning_rate", agent_state.opt_state[1].hyperparams["learning_rate"].item(), global_step)
        # writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
        # writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
        # writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
        # writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
        # writer.add_scalar("losses/loss", loss.item(), global_step)
        print("SPS:", int(global_step / (time.time() - start_time)))
        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
        writer.add_scalar(
            "charts/SPS_update", int(args.num_envs * args.num_steps / (time.time() - update_time_start)), global_step
        )

    envs.close()
    writer.close()