Merge pull request #12 from ThomasCai/main

Fix the issue when the distillation type is set to none.
This commit is contained in:
Abdelrahman Shaker
2023-11-30 15:41:26 +04:00
committed by GitHub

View File

@@ -437,7 +437,7 @@ class SwiftFormer(nn.Module):
if not self.training:
cls_out = (cls_out[0] + cls_out[1]) / 2
else:
cls_out = self.head(x.mean(-2))
cls_out = self.head(x.flatten(2).mean(-1))
# For image classification
return cls_out