// test_custom_operators.cpp
  1. #include <gtest/gtest.h>
  2. #include <torch/script.h>
  3. #include <torch/torch.h>
  4. // FIXME: the include path differs from OSS due to the extra csrc
  5. #include <torchvision/csrc/ops/nms.h>
  6. TEST(test_custom_operators, nms) {
  7. // make sure that the torchvision ops are visible to the jit interpreter
  8. auto& ops = torch::jit::getAllOperatorsFor(torch::jit::Symbol::fromQualString("torchvision::nms"));
  9. ASSERT_EQ(ops.size(), 1);
  10. auto& op = ops.front();
  11. ASSERT_EQ(op->schema().name(), "torchvision::nms");
  12. torch::jit::Stack stack;
  13. at::Tensor boxes = at::rand({50, 4}), scores = at::rand({50});
  14. double thresh = 0.7;
  15. torch::jit::push(stack, boxes, scores, thresh);
  16. op->getOperation()(stack);
  17. at::Tensor output_jit;
  18. torch::jit::pop(stack, output_jit);
  19. at::Tensor output = vision::ops::nms(boxes, scores, thresh);
  20. ASSERT_TRUE(output_jit.allclose(output));
  21. }
  22. TEST(test_custom_operators, roi_align_visible) {
  23. // make sure that the torchvision ops are visible to the jit interpreter even if
  24. // not explicitly included
  25. auto& ops = torch::jit::getAllOperatorsFor(torch::jit::Symbol::fromQualString("torchvision::roi_align"));
  26. ASSERT_EQ(ops.size(), 1);
  27. auto& op = ops.front();
  28. ASSERT_EQ(op->schema().name(), "torchvision::roi_align");
  29. torch::jit::Stack stack;
  30. float roi_data[] = {
  31. 0., 0., 0., 5., 5.,
  32. 0., 5., 5., 10., 10.
  33. };
  34. at::Tensor input = at::rand({1, 2, 10, 10}), rois = at::from_blob(roi_data, {2, 5});
  35. double spatial_scale = 1.0;
  36. int64_t pooled_height = 3, pooled_width = 3, sampling_ratio = -1;
  37. bool aligned = true;
  38. torch::jit::push(stack, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned);
  39. op->getOperation()(stack);
  40. at::Tensor output_jit;
  41. torch::jit::pop(stack, output_jit);
  42. ASSERT_EQ(output_jit.sizes()[0], 2);
  43. ASSERT_EQ(output_jit.sizes()[1], 2);
  44. ASSERT_EQ(output_jit.sizes()[2], 3);
  45. ASSERT_EQ(output_jit.sizes()[3], 3);
  46. }