preprocess-bench.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import argparse
  2. import os
  3. from timeit import default_timer as timer
  4. import torch
  5. import torch.utils.data
  6. import torchvision
  7. import torchvision.datasets as datasets
  8. import torchvision.transforms as transforms
  9. from torch.utils.model_zoo import tqdm
  10. parser = argparse.ArgumentParser(description="PyTorch ImageNet Training")
  11. parser.add_argument("--data", metavar="PATH", required=True, help="path to dataset")
  12. parser.add_argument(
  13. "--nThreads", "-j", default=2, type=int, metavar="N", help="number of data loading threads (default: 2)"
  14. )
  15. parser.add_argument(
  16. "--batchSize", "-b", default=256, type=int, metavar="N", help="mini-batch size (1 = pure stochastic) Default: 256"
  17. )
  18. parser.add_argument("--accimage", action="store_true", help="use accimage")
  19. if __name__ == "__main__":
  20. args = parser.parse_args()
  21. if args.accimage:
  22. torchvision.set_image_backend("accimage")
  23. print(f"Using {torchvision.get_image_backend()}")
  24. # Data loading code
  25. transform = transforms.Compose(
  26. [
  27. transforms.RandomSizedCrop(224),
  28. transforms.RandomHorizontalFlip(),
  29. transforms.PILToTensor(),
  30. transforms.ConvertImageDtype(torch.float),
  31. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  32. ]
  33. )
  34. traindir = os.path.join(args.data, "train")
  35. valdir = os.path.join(args.data, "val")
  36. train = datasets.ImageFolder(traindir, transform)
  37. val = datasets.ImageFolder(valdir, transform)
  38. train_loader = torch.utils.data.DataLoader(
  39. train, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads
  40. )
  41. train_iter = iter(train_loader)
  42. start_time = timer()
  43. batch_count = 20 * args.nThreads
  44. with tqdm(total=batch_count) as pbar:
  45. for _ in tqdm(range(batch_count)):
  46. pbar.update(1)
  47. batch = next(train_iter)
  48. end_time = timer()
  49. print(
  50. "Performance: {dataset:.0f} minutes/dataset, {batch:.1f} ms/batch,"
  51. " {image:.2f} ms/image {rate:.0f} images/sec".format(
  52. dataset=(end_time - start_time) * (float(len(train_loader)) / batch_count / 60.0),
  53. batch=(end_time - start_time) / float(batch_count) * 1.0e3,
  54. image=(end_time - start_time) / (batch_count * args.batchSize) * 1.0e3,
  55. rate=(batch_count * args.batchSize) / (end_time - start_time),
  56. )
  57. )