opt_state_diff: 171 lines, 4 removals, 4 additions

-(EmptyState(), (ScaleByAdamState(count=Array(0, dtype=int32), mu=React(
+(EmptyState(), (ScaleByAdamState(count=Array(130, dtype=int32), mu=React(
   max_iters=10,
   bottleneck=128,
   SEQLEN=16,
   embed_dim=128,
   input_proj=LinearProj(
     bias=f32[128],
     weight=f32[128,128],
     input_dim=128,
     output_dim=128,
     use_bias=True
   ),
   input_act=NewGELU(),
   out_head=LinearProj(
     bias=f32[2],
     weight=f32[128,2],
     input_dim=128,
     output_dim=2,
     use_bias=True
   ),
   embed_layer=Embedding(num_embeddings=8, embedding_size=128, weight=f32[8,128]),
   main_block=RecurrentModule(
     gelu=NewGELU(),
     reshape_layer=LinearProj(
       bias=f32[128],
       weight=f32[256,128],
       input_dim=256,
       output_dim=128,
       use_bias=True
     ),
-    key=u32[2],
+    key=None,
     attention_blocks=[
       AttentionBlock(
         activation=NewGELU(),
         attn_gate=LiteAttention(
           input_dim=256,
           weight=LinearProj(
             bias=f32[256],
             weight=f32[256,256],
             input_dim=256,
             output_dim=256,
             use_bias=False
           )
         ),
         ln1=LayerNorm(
           shape=256,
           eps=1e-05,
           use_weight=True,
           use_bias=True,
           weight=f32[256],
           bias=f32[256]
         ),
         ln2=LayerNorm(
           shape=256,
           eps=1e-05,
           use_weight=True,
           use_bias=True,
           weight=f32[256],
           bias=f32[256]
         ),
         mlp=MLP(
           layers=[
             LinearProj(
               bias=f32[256],
               weight=f32[256,256],
               input_dim=256,
               output_dim=256,
               use_bias=True
             ),
             Lambda(fn=NewGELU()),
             LinearProj(
               bias=f32[256],
               weight=f32[256,256],
               input_dim=256,
               output_dim=256,
               use_bias=True
             )
           ],
           dropout=Dropout(p=None, inference=None)
         )
       )
     ]
   ),
   id=Identity(),
   pos_enc=f32[16,128]
 ), nu=React(
   max_iters=10,
   bottleneck=128,
   SEQLEN=16,
   embed_dim=128,
   input_proj=LinearProj(
     bias=f32[128],
     weight=f32[128,128],
     input_dim=128,
     output_dim=128,
     use_bias=True
   ),
   input_act=NewGELU(),
   out_head=LinearProj(
     bias=f32[2],
     weight=f32[128,2],
     input_dim=128,
     output_dim=2,
     use_bias=True
   ),
   embed_layer=Embedding(num_embeddings=8, embedding_size=128, weight=f32[8,128]),
   main_block=RecurrentModule(
     gelu=NewGELU(),
     reshape_layer=LinearProj(
       bias=f32[128],
       weight=f32[256,128],
       input_dim=256,
       output_dim=128,
       use_bias=True
     ),
-    key=u32[2],
+    key=None,
     attention_blocks=[
       AttentionBlock(
         activation=NewGELU(),
         attn_gate=LiteAttention(
           input_dim=256,
           weight=LinearProj(
             bias=f32[256],
             weight=f32[256,256],
             input_dim=256,
             output_dim=256,
             use_bias=False
           )
         ),
         ln1=LayerNorm(
           shape=256,
           eps=1e-05,
           use_weight=True,
           use_bias=True,
           weight=f32[256],
           bias=f32[256]
         ),
         ln2=LayerNorm(
           shape=256,
           eps=1e-05,
           use_weight=True,
           use_bias=True,
           weight=f32[256],
           bias=f32[256]
         ),
         mlp=MLP(
           layers=[
             LinearProj(
               bias=f32[256],
               weight=f32[256,256],
               input_dim=256,
               output_dim=256,
               use_bias=True
             ),
             Lambda(fn=NewGELU()),
             LinearProj(
               bias=f32[256],
               weight=f32[256,256],
               input_dim=256,
               output_dim=256,
               use_bias=True
             )
           ],
           dropout=Dropout(p=None, inference=None)
         )
       )
     ]
   ),
   id=Identity(),
   pos_enc=f32[16,128]
 ), id=Identity? — no; see note below if structure differs
-)), EmptyState(), ScaleByScheduleState(count=Array(0, dtype=int32))))
+)), EmptyState(), ScaleByScheduleState(count=Array(130, dtype=int32))))
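The only leaves that differ between the two snapshots are the Adam and schedule step counters (0 to 130) and the two `key` leaves inside the `mu`/`nu` statistics (u32[2] to None). Below is a minimal sketch of how a before/after diff of an Optax optimizer state can be produced with Equinox and `difflib`; the `React` module from the dump is project-specific, so a stock `eqx.nn.MLP` stands in for it here, and the optimizer chain and its hyperparameters are assumptions chosen only to mirror the tuple structure (EmptyState, (ScaleByAdamState, EmptyState, ScaleByScheduleState)) shown above.

```python
import difflib

import equinox as eqx
import jax
import jax.numpy as jnp
import optax

# Stand-in model; the dump above uses the project-specific `React` module.
model = eqx.nn.MLP(in_size=8, out_size=2, width_size=128, depth=2,
                   key=jax.random.PRNGKey(0))
params = eqx.filter(model, eqx.is_inexact_array)

# An assumed Optax chain whose state has the same shape as the dump:
# (EmptyState, (ScaleByAdamState, EmptyState, ScaleByScheduleState)).
schedule = optax.warmup_cosine_decay_schedule(0.0, 1e-3, 100, 1000)
opt = optax.chain(optax.clip_by_global_norm(1.0),
                  optax.adamw(learning_rate=schedule))

state_before = opt.init(params)

# Run 130 dummy update steps; with zero gradients only the step counters
# change, which is enough to reproduce a counter diff like the one above.
state_after = state_before
for _ in range(130):
    grads = jax.tree_util.tree_map(jnp.zeros_like, params)
    _, state_after = opt.update(grads, state_after, params)

# Text-diff the pretty-printed pytrees, similar to the dump above.
before = eqx.tree_pformat(state_before).splitlines()
after = eqx.tree_pformat(state_after).splitlines()
print("\n".join(difflib.unified_diff(before, after, lineterm="")))
```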