diff --git a/models/swiftformer.py b/models/swiftformer.py index 1b74936..b545557 100644 --- a/models/swiftformer.py +++ b/models/swiftformer.py @@ -437,7 +437,7 @@ class SwiftFormer(nn.Module): if not self.training: cls_out = (cls_out[0] + cls_out[1]) / 2 else: - cls_out = self.head(x.mean(-2)) + cls_out = self.head(x.flatten(2).mean(-1)) # For image classification return cls_out