DSv1 to DSv2

DSv1: 60 lines (18 removals) → DSv2: 58 lines (16 additions)

 def _exec_send_grads(self, buffer_id):
     if self.wall_clock_breakdown():
         self.timers('pipe_send_grad').start()
 
     inputs = self.pipe_buffers['inputs'][buffer_id]
 
     # Partition the gradient
     if self.is_grad_partitioned:
-        part = PartitionedTensor(tensor=inputs[0].grad,
+        if isinstance(inputs, tuple):
+            first_input = inputs[0]
+            assert all([torch.is_tensor(elt) for elt in inputs[1:]])
+            inputs_grad_tail = [
+                elt.grad for elt in inputs[1:] if elt.grad is not None
+            ]
+        elif torch.is_tensor(inputs):
+            first_input = inputs
+            inputs_grad_tail = []
+        else:
+            raise ValueError("expecting a tensor or a tuple of tensors")
+        assert torch.is_tensor(first_input)
+        part = PartitionedTensor(tensor=first_input.grad,
                                  group=self.grid.get_slice_parallel_group())
-        # Clear the large output data, but save the computation graph
+        # Inject the partitioned tensor into the output before sending
 
-        inputs = tuple([part.to_meta(), part.data(), inputs[1]])
+        inputs = (part.to_meta(), part.data(), *inputs_grad_tail)
 
     # XXX Terrible hack
     # Drop the attention mask from the input buffer here. It does not have
     # a grad that needs to be communicated. We free the buffer immediately
     # after, so no need to restore it. The receiver also has a hack that skips
     # the recv. This is because NCCL does not let us send torch.BoolTensor :-(.
-    if self.module.__class__.__name__ == 'GPT2ModelPipe':
+    if self.has_attention_mask or self.has_bool_tensors:
         inputs = list(inputs)
         inputs.pop()
         inputs = tuple(inputs)
 
     if isinstance(inputs, torch.Tensor):
         assert inputs.grad is not None
         p2p.send(inputs.grad, self.prev_stage)
     else:
         # XXX terrible hacky branch
         if self.is_grad_partitioned:
             # First two sends are partitioned gradient
             p2p.send(inputs[0], self.prev_stage)
             p2p.send(inputs[1], self.prev_stage)
         else:
             for idx, buffer in enumerate(inputs):
                 # Skip tensors that will not produce a grad
                 if not buffer.is_floating_point():
                     assert buffer.grad is None
                     continue
                 assert buffer.grad is not None
                 p2p.send(buffer.grad, self.prev_stage)
 
     # We can free up the input buffer now
     self.pipe_buffers['inputs'][buffer_id] = None
 
     if self.wall_clock_breakdown():
         self.timers('pipe_send_grad').stop()
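
To make the packing change easier to see, here is a minimal sketch in plain PyTorch, outside DeepSpeed. `fake_partition` is a hypothetical stand-in for `PartitionedTensor` (it just returns a meta tensor plus the flattened gradient); everything else mirrors the two versions of the `is_grad_partitioned` branch above. DSv1 assumed `inputs` was a tuple and re-packed `inputs[1]` itself after the partitioned first gradient, whereas DSv2 accepts a bare tensor or a tuple and appends the gradients of every trailing tensor that has one.

# Minimal, self-contained sketch of the packing difference (assumption:
# fake_partition is a hypothetical stand-in for PartitionedTensor's
# to_meta()/data() pair; it is not part of DeepSpeed).
import torch

def fake_partition(grad):
    # Return (meta describing the gradient, flattened payload).
    return torch.tensor([grad.numel()]), grad.flatten()

def pack_grads_v1(inputs):
    # DSv1: assumes a tuple and keeps inputs[1] itself (the tensor,
    # not its grad) after the partitioned gradient of inputs[0].
    meta, data = fake_partition(inputs[0].grad)
    return tuple([meta, data, inputs[1]])

def pack_grads_v2(inputs):
    # DSv2: accepts a tensor or a tuple and appends the grads of all
    # trailing tensors that actually have one.
    if isinstance(inputs, tuple):
        first_input = inputs[0]
        tail = [elt.grad for elt in inputs[1:] if elt.grad is not None]
    elif torch.is_tensor(inputs):
        first_input, tail = inputs, []
    else:
        raise ValueError("expecting a tensor or a tuple of tensors")
    meta, data = fake_partition(first_input.grad)
    return (meta, data, *tail)

x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)
(x * y).sum().backward()
print(pack_grads_v1((x, y))[2] is y)                   # True: the tensor itself
print(torch.equal(pack_grads_v2((x, y))[2], y.grad))   # True: its gradient

The other functional change is the guard around the attention-mask hack: DSv1 keyed on the model class name (`GPT2ModelPipe`), while DSv2 keys on the engine flags `has_attention_mask` / `has_bool_tensors`, so the workaround for NCCL's lack of `torch.BoolTensor` support is no longer tied to one model class.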