test modify swiftformer to temporal input
This commit is contained in:
60
test_model.py
Normal file
60
test_model.py
Normal file
@@ -0,0 +1,60 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for SwiftFormerTemporal model
|
||||
"""
|
||||
import torch
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add current directory to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from models.swiftformer_temporal import SwiftFormerTemporal_XS
|
||||
|
||||
def test_model():
|
||||
print("Testing SwiftFormerTemporal model...")
|
||||
|
||||
# Create model
|
||||
model = SwiftFormerTemporal_XS(num_frames=3, use_representation_head=True)
|
||||
print(f'Model created: {model.__class__.__name__}')
|
||||
print(f'Number of parameters: {sum(p.numel() for p in model.parameters()):,}')
|
||||
|
||||
# Test forward pass
|
||||
batch_size = 2
|
||||
num_frames = 3
|
||||
height = width = 224
|
||||
x = torch.randn(batch_size, 3 * num_frames, height, width)
|
||||
|
||||
print(f'\nInput shape: {x.shape}')
|
||||
|
||||
with torch.no_grad():
|
||||
pred_frame, representation = model(x)
|
||||
|
||||
print(f'Predicted frame shape: {pred_frame.shape}')
|
||||
print(f'Representation shape: {representation.shape if representation is not None else "None"}')
|
||||
|
||||
# Check output ranges
|
||||
print(f'\nPredicted frame range: [{pred_frame.min():.3f}, {pred_frame.max():.3f}]')
|
||||
|
||||
# Test loss function
|
||||
from util.frame_losses import MultiTaskLoss
|
||||
criterion = MultiTaskLoss()
|
||||
target = torch.randn_like(pred_frame)
|
||||
temporal_indices = torch.tensor([3, 3], dtype=torch.long)
|
||||
|
||||
loss, loss_dict = criterion(pred_frame, target, representation, temporal_indices)
|
||||
print(f'\nLoss test:')
|
||||
for k, v in loss_dict.items():
|
||||
print(f' {k}: {v:.4f}')
|
||||
|
||||
print('\nAll tests passed!')
|
||||
return True
|
||||
|
||||
if __name__ == '__main__':
|
||||
try:
|
||||
test_model()
|
||||
except Exception as e:
|
||||
print(f'Test failed with error: {e}')
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
Reference in New Issue
Block a user