// main.cpp

#include <cstdlib>
#include <iostream>
#include <vector>

#include <torch/script.h>
#include <torch/torch.h>
#include <torchvision/vision.h>
  5. int main() {
  6. torch::DeviceType device_type;
  7. device_type = torch::kCPU;
  8. torch::jit::script::Module model;
  9. try {
  10. std::cout << "Loading model\n";
  11. // Deserialize the ScriptModule from a file using torch::jit::load().
  12. model = torch::jit::load("resnet18.pt");
  13. std::cout << "Model loaded\n";
  14. } catch (const torch::Error& e) {
  15. std::cout << "error loading the model\n";
  16. return -1;
  17. } catch (const std::exception& e) {
  18. std::cout << "Other error: " << e.what() << "\n";
  19. return -1;
  20. }
  21. // TorchScript models require a List[IValue] as input
  22. std::vector<torch::jit::IValue> inputs;
  23. // Create a random input tensor and run it through the model.
  24. inputs.push_back(torch::rand({1, 3, 10, 10}));
  25. auto out = model.forward(inputs);
  26. std::cout << out << "\n";
  27. if (torch::cuda::is_available()) {
  28. // Move model and inputs to GPU
  29. model.to(torch::kCUDA);
  30. // Add GPU inputs
  31. inputs.clear();
  32. torch::TensorOptions options = torch::TensorOptions{torch::kCUDA};
  33. inputs.push_back(torch::rand({1, 3, 10, 10}, options));
  34. auto gpu_out = model.forward(inputs);
  35. std::cout << gpu_out << "\n";
  36. }
  37. }