trace_model.py (338 B, 13 lines)
  1. import os.path as osp
  2. import torch
  3. import torchvision
  4. HERE = osp.dirname(osp.abspath(__file__))
  5. ASSETS = osp.dirname(osp.dirname(HERE))
  6. model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None, weights_backbone=None)
  7. model.eval()
  8. traced_model = torch.jit.script(model)
  9. traced_model.save("fasterrcnn_resnet50_fpn.pt")