Diff summary: additive attention vs. vanilla (dot-product) attention
7 removals
Words removed | 18 |
Total words | 84 |
Words removed (%) | 21.43 |
20 lines
24 additions
Words added | 50 |
Total words | 116 |
Words added (%) | 43.10 |
28 lines
def _calculate_scores(self, query, key):
    """Calculates attention scores as a query-key dot product.

    Supports two scoring modes selected by `self.score_mode`:
      * `"dot"`: scaled dot-product scores, `query @ key^T`.
      * `"concat"`: additive (Bahdanau-style) scores,
        `w * sum(tanh(scale * (q + k)))` broadcast over the time axes.

    Args:
        query: Query tensor of shape `[batch_size, Tq, dim]`.
        key: Key tensor of shape `[batch_size, Tv, dim]`.

    Returns:
        Tensor of shape `[batch_size, Tq, Tv]`.
    """
    # NOTE(review): `self.score_mode`, `self.scale`, and
    # `self.concat_score_weight` are set elsewhere (presumably in
    # `__init__`/`build`); an unrecognized score_mode would leave
    # `scores` unbound here — assumed validated upstream, confirm.
    if self.score_mode == "dot":
        # Batched dot product over the feature dimension:
        # [batch, Tq, dim] x [batch, Tv, dim]^T -> [batch, Tq, Tv].
        scores = tf.matmul(query, key, transpose_b=True)
        if self.scale is not None:
            # Optional learned/fixed scaling of the raw logits.
            scores *= self.scale
    elif self.score_mode == "concat":
        # Reshape tensors to enable broadcasting across time axes.
        # Reshape into [batch_size, Tq, 1, dim].
        q_reshaped = tf.expand_dims(query, axis=-2)
        # Reshape into [batch_size, 1, Tv, dim].
        k_reshaped = tf.expand_dims(key, axis=-3)
        if self.scale is not None:
            # Scale inside the tanh nonlinearity, then reduce the
            # feature axis -> [batch, Tq, Tv].
            scores = self.concat_score_weight * tf.reduce_sum(
                tf.tanh(self.scale * (q_reshaped + k_reshaped)), axis=-1
            )
        else:
            # Unscaled additive scores.
            scores = self.concat_score_weight * tf.reduce_sum(
                tf.tanh(q_reshaped + k_reshaped), axis=-1
            )
    return scores