presets.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import torch
  2. import transforms as T
  3. class OpticalFlowPresetEval(torch.nn.Module):
  4. def __init__(self):
  5. super().__init__()
  6. self.transforms = T.Compose(
  7. [
  8. T.PILToTensor(),
  9. T.ConvertImageDtype(torch.float32),
  10. T.Normalize(mean=0.5, std=0.5), # map [0, 1] into [-1, 1]
  11. T.ValidateModelInput(),
  12. ]
  13. )
  14. def forward(self, img1, img2, flow, valid):
  15. return self.transforms(img1, img2, flow, valid)
  16. class OpticalFlowPresetTrain(torch.nn.Module):
  17. def __init__(
  18. self,
  19. *,
  20. # RandomResizeAndCrop params
  21. crop_size,
  22. min_scale=-0.2,
  23. max_scale=0.5,
  24. stretch_prob=0.8,
  25. # AsymmetricColorJitter params
  26. brightness=0.4,
  27. contrast=0.4,
  28. saturation=0.4,
  29. hue=0.5 / 3.14,
  30. # Random[H,V]Flip params
  31. asymmetric_jitter_prob=0.2,
  32. do_flip=True,
  33. ):
  34. super().__init__()
  35. transforms = [
  36. T.PILToTensor(),
  37. T.AsymmetricColorJitter(
  38. brightness=brightness, contrast=contrast, saturation=saturation, hue=hue, p=asymmetric_jitter_prob
  39. ),
  40. T.RandomResizeAndCrop(
  41. crop_size=crop_size, min_scale=min_scale, max_scale=max_scale, stretch_prob=stretch_prob
  42. ),
  43. ]
  44. if do_flip:
  45. transforms += [T.RandomHorizontalFlip(p=0.5), T.RandomVerticalFlip(p=0.1)]
  46. transforms += [
  47. T.ConvertImageDtype(torch.float32),
  48. T.Normalize(mean=0.5, std=0.5), # map [0, 1] into [-1, 1]
  49. T.RandomErasing(max_erase=2),
  50. T.MakeValidFlowMask(),
  51. T.ValidateModelInput(),
  52. ]
  53. self.transforms = T.Compose(transforms)
  54. def forward(self, img1, img2, flow, valid):
  55. return self.transforms(img1, img2, flow, valid)