norm.py 278 B

12345678910111213
  1. import torch
  2. def freeze_batch_norm(model):
  3. for m in model.modules():
  4. if isinstance(m, torch.nn.BatchNorm2d):
  5. m.eval()
  6. def unfreeze_batch_norm(model):
  7. for m in model.modules():
  8. if isinstance(m, torch.nn.BatchNorm2d):
  9. m.train()