import torch
import torch.nn as nn


class ConvertBCHWtoCBHW(nn.Module):
    """Swap the batch and channel axes of a 4-D tensor.

    Reorders a tensor laid out as (B, C, H, W) into (C, B, H, W). The
    result comes from ``Tensor.permute``, so it is a view that shares
    storage with the input rather than a copy.
    """

    # Output axis i is taken from input axis _AXIS_ORDER[i]:
    # C (input 1) moves to the front, B (input 0) moves second,
    # and the trailing H and W axes stay where they are.
    _AXIS_ORDER = (1, 0, 2, 3)

    def forward(self, vid: torch.Tensor) -> torch.Tensor:
        """Return *vid* with its first two axes exchanged.

        Args:
            vid: Tensor with exactly four dimensions, interpreted as
                (B, C, H, W). ``permute`` raises if the rank differs.

        Returns:
            A (C, B, H, W) view of *vid*.
        """
        return vid.permute(*self._AXIS_ORDER)