smoke_test.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. """Run smoke tests"""
  2. import sys
  3. from pathlib import Path
  4. import torch
  5. import torchvision
  6. from torchvision.io import decode_jpeg, read_file, read_image
  7. from torchvision.models import resnet50, ResNet50_Weights
  8. SCRIPT_DIR = Path(__file__).parent
  9. def smoke_test_torchvision() -> None:
  10. print(
  11. "Is torchvision usable?",
  12. all(x is not None for x in [torch.ops.image.decode_png, torch.ops.torchvision.roi_align]),
  13. )
  14. def smoke_test_torchvision_read_decode() -> None:
  15. img_jpg = read_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
  16. if img_jpg.shape != (3, 606, 517):
  17. raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")
  18. img_png = read_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png"))
  19. if img_png.shape != (4, 471, 354):
  20. raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
  21. def smoke_test_torchvision_decode_jpeg(device: str = "cpu"):
  22. img_jpg_data = read_file(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
  23. img_jpg = decode_jpeg(img_jpg_data, device=device)
  24. if img_jpg.shape != (3, 606, 517):
  25. raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")
  26. def smoke_test_compile() -> None:
  27. try:
  28. model = resnet50().cuda()
  29. model = torch.compile(model)
  30. x = torch.randn(1, 3, 224, 224, device="cuda")
  31. out = model(x)
  32. print(f"torch.compile model output: {out.shape}")
  33. except RuntimeError:
  34. if sys.platform == "win32":
  35. print("Successfully caught torch.compile RuntimeError on win")
  36. elif sys.version_info >= (3, 11, 0):
  37. print("Successfully caught torch.compile RuntimeError on Python 3.11")
  38. else:
  39. raise
  40. def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
  41. img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device)
  42. # Step 1: Initialize model with the best available weights
  43. weights = ResNet50_Weights.DEFAULT
  44. model = resnet50(weights=weights).to(device)
  45. model.eval()
  46. # Step 2: Initialize the inference transforms
  47. preprocess = weights.transforms()
  48. # Step 3: Apply inference preprocessing transforms
  49. batch = preprocess(img).unsqueeze(0)
  50. # Step 4: Use the model and print the predicted category
  51. prediction = model(batch).squeeze(0).softmax(0)
  52. class_id = prediction.argmax().item()
  53. score = prediction[class_id].item()
  54. category_name = weights.meta["categories"][class_id]
  55. expected_category = "German shepherd"
  56. print(f"{category_name} ({device}): {100 * score:.1f}%")
  57. if category_name != expected_category:
  58. raise RuntimeError(f"Failed ResNet50 classify {category_name} Expected: {expected_category}")
  59. def main() -> None:
  60. print(f"torchvision: {torchvision.__version__}")
  61. print(f"torch.cuda.is_available: {torch.cuda.is_available()}")
  62. # Turn 1.11.0aHASH into 1.11 (major.minor only)
  63. version = ".".join(torchvision.__version__.split(".")[:2])
  64. if version >= "0.16":
  65. print(f"{torch.ops.image._jpeg_version() = }")
  66. assert torch.ops.image._is_compiled_against_turbo()
  67. smoke_test_torchvision()
  68. smoke_test_torchvision_read_decode()
  69. smoke_test_torchvision_resnet50_classify()
  70. smoke_test_torchvision_decode_jpeg()
  71. if torch.cuda.is_available():
  72. smoke_test_torchvision_decode_jpeg("cuda")
  73. smoke_test_torchvision_resnet50_classify("cuda")
  74. smoke_test_compile()
  75. if torch.backends.mps.is_available():
  76. smoke_test_torchvision_resnet50_classify("mps")
  77. if __name__ == "__main__":
  78. main()