# trace_model.py (263 B)
  1. import os.path as osp
  2. import torch
  3. import torchvision
  4. HERE = osp.dirname(osp.abspath(__file__))
  5. ASSETS = osp.dirname(osp.dirname(HERE))
  6. model = torchvision.models.resnet18()
  7. model.eval()
  8. traced_model = torch.jit.script(model)
  9. traced_model.save("resnet18.pt")