model_diff_help1
4 removals
Words removed | 4 |
Total words | 118 |
Words removed (%) | 3.39 |
26 lines
10 additions
Words added | 14 |
Total words | 128 |
Words added (%) | 10.94 |
31 lines
class Model(nn.Module):
    """Estimate the fraction of "moving" pixels between two frames.

    Each input is reduced to one value per pixel by averaging over dim 1.
    Pixels whose absolute difference between the two averages exceeds
    ``threshold`` are marked as moving, and the fraction of such pixels is
    returned.
    """

    def __init__(self, threshold: float = 30.0):
        """Create the module.

        Args:
            threshold: Strict lower bound on the per-pixel absolute
                difference for a pixel to count as moving. The default of
                30.0 matches the original hard-coded value.
                NOTE(review): 30.0 presumably assumes 0-255 pixel values;
                confirm against the inference-time preprocessing.
        """
        super().__init__()
        self.threshold = threshold

    def forward(self, img1, img2):
        """Return the moving-pixel ratio between ``img1`` and ``img2``.

        Args:
            img1: First frame as a floating-point tensor. Presumably
                (batch, channels, height, width), as in the export call's
                dummy inputs; TODO confirm for real callers.
            img2: Second frame, same shape as ``img1``.

        Returns:
            float32 tensor of shape (1, 1, 1, 1) holding moving pixels
            divided by total pixels. For batch > 1 both counts are pooled
            over the batch, so the value is the batch-averaged ratio.
        """
        # Average over dim 1; keepdim preserves the 4-D rank for the mask.
        mean1 = torch.mean(img1, dim=1, keepdim=True)
        mean2 = torch.mean(img2, dim=1, keepdim=True)
        # sqrt(x ** 2) == |x|; abs is exact and maps to a single ONNX Abs op.
        diff = torch.abs(mean1 - mean2).float()
        # Binary mask: 1.0 where the difference is above threshold, else 0.0.
        mask = (diff > self.threshold).float()
        # The mean of a 0/1 mask is moving pixels / total pixels. Dividing by
        # mask.shape[0] * mask.shape[1] (batch * 1 here) would yield a count.
        return torch.mean(mask).view(1, 1, 1, 1)
# Instantiate the module and trace it to ONNX with two random dummy inputs of
# shape (1, 3, 720, 720), pinning opset 16. Writes "model_diff.onnx".
model = Model()
torch.onnx.export(
    model,
    (
        torch.randn(1, 3, 720, 720),
        torch.randn(1, 3, 720, 720),
    ),
    "model_diff.onnx",
    opset_version=16,
)