Compare commits

...

19 Commits

Author SHA1 Message Date
543beefa2a 删除残差路径和shortcut,镜像问题仍存在 2026-01-16 15:21:47 +08:00
a92a0b29e9 更新模型结构,大步长反卷积后移,启用BN和tanh 2026-01-15 21:12:27 +08:00
df703638da 清理代码,删除跳连接部分 2026-01-11 13:25:34 +08:00
c5502cc87c 修改梯度裁剪的恶性bug,当前能进行训练,但是无论是否使用跳连接,预测帧总是输出对称的的效果,mse收敛到0.10 2026-01-11 10:50:11 +08:00
12de74f130 完善了跳连接,在上decode块后增加特征精炼层,未测效果 2026-01-09 18:23:45 +08:00
500c2eb18f 更新归一化方式,当前直接映射,不利用均值标准差进行标准化 2026-01-08 16:10:24 +08:00
f7601e9170 初步可跑通,但loss计算有问题,不收敛 2026-01-08 09:43:23 +08:00
efd76bccd2 update .gitignore 2026-01-07 15:54:52 +08:00
4888619f9d iniit .gitignore 2026-01-07 15:54:20 +08:00
7e9564ef20 test modify swiftformer to temporal input 2026-01-07 11:03:33 +08:00
Abdelrahman Shaker
4aa6cd6752 Create LICENSE 2025-07-18 16:04:30 +04:00
Abdelrahman Shaker
898d23ca89 Update README.md 2024-01-12 17:00:03 +04:00
Abdelrahman Shaker
3daedbd499 Merge pull request #15 from escorciav/main
Update README.md
2024-01-12 16:41:43 +04:00
Victor Escorcia
28ce806f55 Update README.md
Community drive contributions: SwiftFormer meets Android. Qualcomm S8G2
DSP/HTP hardware, via Qualcomm tooling (QNN). Details in #14. Work done
by @3scorciav . Refer to his fork for details.
2024-01-12 10:27:15 +00:00
Abdelrahman Shaker
9b7df0d145 Merge pull request #12 from ThomasCai/main
Fix the issue when the distillation type is set to none.
2023-11-30 15:41:26 +04:00
caitianren
0ddadad723 Fix this bug when setting distillation-type to none 2023-11-29 20:15:00 +08:00
Abdelrahman Shaker
cd1f854e59 Update README.md 2023-10-02 21:54:23 +02:00
Abdelrahman Shaker
5c9b4ceece Update README.md 2023-08-17 21:23:06 +04:00
Abdelrahman Shaker
7d5ca0c25b Update README.md 2023-08-10 18:54:53 +04:00
11 changed files with 2118 additions and 23 deletions

4
.gitignore vendored Normal file
View File

@@ -0,0 +1,4 @@
.vscode/
__pycache__/
venv/
runs/

201
LICENSE Normal file
View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -6,7 +6,7 @@
Mohamed Bin Zayed University of Artificial Intelligence<sup>1</sup>, University of California Merced<sup>2</sup>, Google Research<sup>3</sup>, Linkoping University<sup>4</sup> Mohamed Bin Zayed University of Artificial Intelligence<sup>1</sup>, University of California Merced<sup>2</sup>, Google Research<sup>3</sup>, Linkoping University<sup>4</sup>
<!-- [![Website](https://img.shields.io/badge/Project-Website-87CEEB)](site_url) --> <!-- [![Website](https://img.shields.io/badge/Project-Website-87CEEB)](site_url) -->
[![paper](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2303.15446) [![paper](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://openaccess.thecvf.com/content/ICCV2023/papers/Shaker_SwiftFormer_Efficient_Additive_Attention_for_Transformer-based_Real-time_Mobile_Vision_Applications_ICCV_2023_paper.pdf)
<!-- [![video](https://img.shields.io/badge/Video-Presentation-F9D371)](youtube_link) --> <!-- [![video](https://img.shields.io/badge/Video-Presentation-F9D371)](youtube_link) -->
<!-- [![slides](https://img.shields.io/badge/Presentation-Slides-B762C1)](presentation) --> <!-- [![slides](https://img.shields.io/badge/Presentation-Slides-B762C1)](presentation) -->
@@ -64,6 +64,28 @@ Self-attention has become a defacto choice for capturing global context in vario
The latency reported in SwiftFormer for iPhone 14 (iOS 16) uses the benchmark tool from [XCode 14](https://developer.apple.com/videos/play/wwdc2022/10027/). The latency reported in SwiftFormer for iPhone 14 (iOS 16) uses the benchmark tool from [XCode 14](https://developer.apple.com/videos/play/wwdc2022/10027/).
### SwiftFormer meets Android
Community-driven results with [Samsung Galaxy S23 Ultra, with Qualcomm Snapdragon 8 Gen 2](https://www.qualcomm.com/snapdragon/device-finder/samsung-galaxy-s23-ultra):
1. [Export](https://github.com/escorciav/SwiftFormer/blob/main-v/export.py) & profiler results of [`SwiftFormer_L1`](./models/swiftformer.py):
| QNN | 2.16 | 2.17 | 2.18 |
| -------------- | -----| ----- | ------ |
| Latency (msec) | 2.63 | 2.26 | 2.43 |
2. [Export](https://github.com/escorciav/SwiftFormer/blob/main-v/export_block.py) & profiler results of SwiftFormerEncoder block:
| QNN | 2.16 | 2.17 | 2.18 |
| -------------- | -----| ----- | ------ |
| Latency (msec) | 2.17 | 1.69 | 1.7 |
Refer to the script above for details of the input & block parameters.
_Interested in reproducing the results above?_
Refer to [Issue #14](https://github.com/Amshaker/SwiftFormer/issues/14) for details about [exporting & profiling.](https://github.com/Amshaker/SwiftFormer/issues/14#issuecomment-1883351728)
## ImageNet ## ImageNet
### Prerequisites ### Prerequisites
@@ -78,7 +100,7 @@ pip install timm
pip install coremltools==5.2.0 pip install coremltools==5.2.0
``` ```
### Data preparation ### Data Preparation
Download and extract ImageNet train and val images from http://image-net.org. The training and validation data are expected to be in the `train` folder and `val` folder respectively: Download and extract ImageNet train and val images from http://image-net.org. The training and validation data are expected to be in the `train` folder and `val` folder respectively:
``` ```
@@ -87,7 +109,7 @@ Download and extract ImageNet train and val images from http://image-net.org. Th
|-- val |-- val
``` ```
### Single machine multi-GPU training ### Single-machine multi-GPU training
We provide training script for all models in `dist_train.sh` using PyTorch distributed data parallel (DDP). We provide training script for all models in `dist_train.sh` using PyTorch distributed data parallel (DDP).
@@ -107,7 +129,7 @@ On a Slurm-managed cluster, multi-node training can be launched as
sbatch slurm_train.sh /path/to/imagenet SwiftFormer_XS sbatch slurm_train.sh /path/to/imagenet SwiftFormer_XS
``` ```
Note: specify slurm specific paramters in `slurm_train.sh` script. Note: specify slurm specific parameters in `slurm_train.sh` script.
### Testing ### Testing
@@ -121,20 +143,22 @@ sh dist_test.sh SwiftFormer_XS 8 weights/SwiftFormer_XS_ckpt.pth
## Citation ## Citation
if you use our work, please consider citing us: if you use our work, please consider citing us:
```BibTeX ```BibTeX
@article{Shaker2023SwiftFormer, @InProceedings{Shaker_2023_ICCV,
title={SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications},
author = {Shaker, Abdelrahman and Maaz, Muhammad and Rasheed, Hanoona and Khan, Salman and Yang, Ming-Hsuan and Khan, Fahad Shahbaz}, author = {Shaker, Abdelrahman and Maaz, Muhammad and Rasheed, Hanoona and Khan, Salman and Yang, Ming-Hsuan and Khan, Fahad Shahbaz},
journal={arXiv:2303.15446}, title = {SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications},
year={2023} booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
year = {2023},
} }
``` ```
## Contact: ## Contact:
If you have any question, please create an issue on this repository or contact at abdelrahman.youssief@mbzuai.ac.ae. If you have any questions, please create an issue on this repository or contact at abdelrahman.youssief@mbzuai.ac.ae.
## Acknowledgement ## Acknowledgement
Our code base is based on [LeViT](https://github.com/facebookresearch/LeViT) and [EfficientFormer](https://github.com/snap-research/EfficientFormer) repositories. We thank authors for their open-source implementation. Our code base is based on [LeViT](https://github.com/facebookresearch/LeViT) and [EfficientFormer](https://github.com/snap-research/EfficientFormer) repositories. We thank the authors for their open-source implementation.
I'd like to express my sincere appreciation to [Victor Escorcia](https://github.com/escorciav) for measuring & reporting the latency of SwiftFormer on Android (Samsung Galaxy S23 Ultra, with Qualcomm Snapdragon 8 Gen 2). Check [SwiftFormer Meets Android](https://github.com/escorciav/SwiftFormer) for more details!
## Our Related Works ## Our Related Works

58
dist_temporal_train.sh Executable file
View File

@@ -0,0 +1,58 @@
#!/usr/bin/env bash
# Distributed training script for SwiftFormerTemporal
# Usage: ./dist_temporal_train.sh <DATA_PATH> <NUM_GPUS> [OPTIONS]
DATA_PATH=$1
NUM_GPUS=$2
# Shift arguments to pass remaining options to python script
shift 2
# Default parameters
MODEL=${MODEL:-"SwiftFormerTemporal_XS"}
BATCH_SIZE=${BATCH_SIZE:-128}
EPOCHS=${EPOCHS:-100}
# LR=${LR:-1e-3}
LR=${LR:-0.01}
OUTPUT_DIR=${OUTPUT_DIR:-"./temporal_output"}
echo "Starting distributed training with $NUM_GPUS GPUs"
echo "Data path: $DATA_PATH"
echo "Model: $MODEL"
echo "Batch size: $BATCH_SIZE"
echo "Epochs: $EPOCHS"
echo "Output dir: $OUTPUT_DIR"
# Check if torch.distributed.launch or torchrun should be used
# For newer PyTorch versions (>=1.9), torchrun is recommended
PYTHON_VERSION=$(python -c "import torch; print(torch.__version__)")
echo "PyTorch version: $PYTHON_VERSION"
# Use torchrun for newer PyTorch versions
if [[ "$PYTHON_VERSION" =~ ^2\. ]] || [[ "$PYTHON_VERSION" =~ ^1\.1[0-9]\. ]]; then
echo "Using torchrun (PyTorch >=1.10)"
torchrun --nproc_per_node=$NUM_GPUS --master_port=12345 main_temporal.py \
--data-path "$DATA_PATH" \
--model "$MODEL" \
--batch-size $BATCH_SIZE \
--epochs $EPOCHS \
--lr $LR \
--output-dir "$OUTPUT_DIR" \
"$@"
else
echo "Using torch.distributed.launch"
python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS --master_port=12345 --use_env main_temporal.py \
--data-path "$DATA_PATH" \
--model "$MODEL" \
--batch-size $BATCH_SIZE \
--epochs $EPOCHS \
--lr $LR \
--output-dir "$OUTPUT_DIR" \
"$@"
fi
# For single-node multi-GPU training with specific options:
# --world-size 1 --rank 0 --dist-url 'tcp://localhost:12345'
echo "Training completed. Check logs in $OUTPUT_DIR"

484
evaluate_temporal.py Normal file
View File

@@ -0,0 +1,484 @@
"""
评估脚本 for SwiftFormerTemporal frame prediction
输出预测图注意反归一化以及对应指标mse&ssim&psnr
"""
import argparse
import os
import torch
import torch.nn as nn
import pickle
import numpy as np
import random
from pathlib import Path
import json
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
from util.video_dataset import VideoFrameDataset
from models.swiftformer_temporal import (
SwiftFormerTemporal_XS, SwiftFormerTemporal_S,
SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
)
# 导入SSIM和PSNR计算
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import warnings
warnings.filterwarnings('ignore')
def denormalize(tensor):
"""
将[-1, 1]范围的张量反归一化到[0, 255]范围
Args:
tensor: 形状为[B, C, H, W]或[C, H, W],值在[-1, 1]
Returns:
反归一化后的张量,值在[0, 255]
"""
# clip 到 [-1, 1] 范围
tensor = tensor.clamp(-1, 1)
# [-1, 1] -> [0, 1]
tensor = (tensor + 1) / 2
# [0, 1] -> [0, 255]
tensor = tensor * 255
return tensor.clamp(0, 255)
def minmax_denormalize(tensor):
tensor_min = tensor.min()
tensor_max = tensor.max()
tensor = (tensor - tensor_min) / (tensor_max - tensor_min)
# tensor = tensor*2-1
tensor = tensor*255
return tensor.clamp(0, 255)
def calculate_metrics(pred, target, debug=False):
"""
计算MSE, SSIM, PSNR指标
Args:
pred: 预测图像,形状[H, W],值在[0, 255]
target: 目标图像,形状[H, W],值在[0, 255]
debug: 是否输出调试信息
Returns:
mse, ssim_value, psnr_value
"""
# 转换为numpy数组
pred_np = pred.cpu().numpy() if torch.is_tensor(pred) else pred
target_np = target.cpu().numpy() if torch.is_tensor(target) else target
# 确保是2D数组
if pred_np.ndim == 3:
pred_np = pred_np.squeeze(0)
if target_np.ndim == 3:
target_np = target_np.squeeze(0)
# if debug:
# print(f"[DEBUG] pred_np range: [{pred_np.min():.2f}, {pred_np.max():.2f}], mean: {pred_np.mean():.2f}")
# print(f"[DEBUG] target_np range: [{target_np.min():.2f}, {target_np.max():.2f}], mean: {target_np.mean():.2f}")
# print(f"[DEBUG] pred_np sample values (first 5): {pred_np.ravel()[:5]}")
mse = np.mean((pred_np - target_np) ** 2)
data_range = 255.0
ssim_value = ssim(pred_np, target_np, data_range=data_range)
psnr_value = psnr(target_np, pred_np, data_range=data_range)
return mse, ssim_value, psnr_value
def save_comparison_figure(input_frames, target_frame, pred_frame, save_path,
input_frame_indices=None, target_frame_index=None):
"""
保存对比图:输入帧、目标帧、预测帧
Args:
input_frames: 输入帧列表,每个形状为[H, W],值在[0, 255]
target_frame: 目标帧,形状[H, W],值在[0, 255]
pred_frame: 预测帧,形状[H, W],值在[0, 255]
save_path: 保存路径
input_frame_indices: 输入帧的索引列表(可选)
target_frame_index: 目标帧索引(可选)
"""
num_input = len(input_frames)
fig, axes = plt.subplots(1, num_input + 2, figsize=(4*(num_input+2), 4))
# 绘制输入帧
for i in range(num_input):
ax = axes[i]
ax.imshow(input_frames[i], cmap='gray')
if input_frame_indices is not None:
ax.set_title(f'Input Frame {input_frame_indices[i]}')
else:
ax.set_title(f'Input {i+1}')
ax.axis('off')
# 绘制目标帧
ax = axes[num_input]
ax.imshow(target_frame, cmap='gray')
if target_frame_index is not None:
ax.set_title(f'Target Frame {target_frame_index}')
else:
ax.set_title('Target')
ax.axis('off')
# 绘制预测帧
ax = axes[num_input + 1]
ax.imshow(pred_frame, cmap='gray')
ax.set_title('Predicted')
ax.axis('off')
#debug print
print(target_frame)
print(pred_frame)
plt.tight_layout()
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.close()
def evaluate_model(model, data_loader, device, args):
"""
评估模型并计算指标
Args:
model: 训练好的模型
data_loader: 数据加载器
device: 设备
args: 命令行参数
Returns:
metrics_dict: 包含所有指标的字典
sample_results: 示例结果用于可视化
"""
model.eval()
# model.train() # 临时使用训练模式
# 初始化指标累加器
total_mse = 0.0
total_ssim = 0.0
total_psnr = 0.0
total_samples = 0
# 存储示例结果用于可视化(使用蓄水池抽样随机选择)
sample_results = []
max_samples_to_save = args.num_samples_to_save
max_samples = args.max_samples
# 用于蓄水池抽样的计数器已处理的样本数不包括因max_samples限制而跳过的样本
sample_count = 0
with torch.no_grad():
for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(data_loader):
input_frames = input_frames.to(device, non_blocking=True)
target_frames = target_frames.to(device, non_blocking=True)
# 前向传播
pred_frames = model(input_frames)
# 反归一化用于指标计算
# pred_denorm = minmax_denormalize(pred_frames) # [B, 1, H, W]
pred_denorm = denormalize(pred_frames)
target_denorm = denormalize(target_frames) # [B, 1, H, W]
batch_size = input_frames.size(0)
# 计算每个样本的指标
for i in range(batch_size):
# 检查是否达到最大样本数限制
if max_samples is not None and total_samples >= max_samples:
break
pred_i = pred_denorm[i] # [1, H, W]
target_i = target_denorm[i] # [1, H, W]
# 对第一个样本启用调试
debug_mode = (batch_idx == 0 and i == 0 and total_samples == 0)
# if debug_mode:
# print(f"[DEBUG] Raw pred_frames range: [{pred_frames.min():.4f}, {pred_frames.max():.4f}], mean: {pred_frames.mean():.4f}")
# print(f"[DEBUG] Raw target_frames range: [{target_frames.min():.4f}, {target_frames.max():.4f}], mean: {target_frames.mean():.4f}")
# print(f"[DEBUG] Pred_denorm range: [{pred_denorm.min():.2f}, {pred_denorm.max():.2f}], mean: {pred_denorm.mean():.2f}")
# print(f"[DEBUG] Target_denorm range: [{target_denorm.min():.2f}, {target_denorm.max():.2f}], mean: {target_denorm.mean():.2f}")
mse, ssim_value, psnr_value = calculate_metrics(pred_i, target_i, debug=False)
total_mse += mse
total_ssim += ssim_value
total_psnr += psnr_value
total_samples += 1
sample_count += 1
# 构建样本数据字典
input_denorm = denormalize(input_frames[i]) # [num_frames, H, W]
# 分离输入帧
input_frames_list = []
for j in range(args.num_frames):
input_frame_j = input_denorm[j].squeeze(0) # [H, W]
input_frames_list.append(input_frame_j.cpu().numpy())
sample_data = {
'input_frames': input_frames_list,
'target_frame': target_i.squeeze(0).cpu().numpy(),
'pred_frame': pred_i.squeeze(0).cpu().numpy(),
'metrics': {
'mse': mse,
'ssim': ssim_value,
'psnr': psnr_value
},
'batch_idx': batch_idx,
'sample_idx': i
}
# 蓄水池抽样 (Reservoir Sampling)
if sample_count <= max_samples_to_save:
# 蓄水池未满,直接加入
sample_results.append(sample_data)
else:
# 以 max_samples_to_save / sample_count 的概率替换蓄水池中的一个随机位置
r = random.randint(0, sample_count - 1)
if r < max_samples_to_save:
sample_results[r] = sample_data
# 检查是否达到最大样本数限制
if max_samples is not None and total_samples >= max_samples:
print(f"达到最大样本数限制: {max_samples}")
break
# 进度打印
if (batch_idx + 1) % 10 == 0:
print(f'Processed {batch_idx + 1} batches, {total_samples} samples')
# 计算平均指标
if total_samples > 0:
avg_mse = float(total_mse / total_samples)
avg_ssim = float(total_ssim / total_samples)
avg_psnr = float(total_psnr / total_samples)
else:
avg_mse = avg_ssim = avg_psnr = 0.0
metrics_dict = {
'mse': avg_mse,
'ssim': avg_ssim,
'psnr': avg_psnr,
'num_samples': total_samples
}
return metrics_dict, sample_results
def main(args):
print("评估参数:", args)
device = torch.device(args.device)
# 设置随机种子
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
cudnn.benchmark = True
# 构建数据集
print("构建数据集...")
dataset_val = VideoFrameDataset(
root_dir=args.data_path,
num_frames=args.num_frames,
frame_size=args.frame_size,
is_train=False,
max_interval=args.max_interval
)
data_loader_val = torch.utils.data.DataLoader(
dataset_val,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
shuffle=False,
drop_last=False
)
# 创建模型
print(f"创建模型: {args.model}")
model_kwargs = {
'num_frames': args.num_frames,
}
if args.model == 'SwiftFormerTemporal_XS':
model = SwiftFormerTemporal_XS(**model_kwargs)
elif args.model == 'SwiftFormerTemporal_S':
model = SwiftFormerTemporal_S(**model_kwargs)
elif args.model == 'SwiftFormerTemporal_L1':
model = SwiftFormerTemporal_L1(**model_kwargs)
elif args.model == 'SwiftFormerTemporal_L3':
model = SwiftFormerTemporal_L3(**model_kwargs)
else:
raise ValueError(f"未知模型: {args.model}")
model.to(device)
# 加载检查点
if args.resume:
print(f"加载检查点: {args.resume}")
try:
# 尝试使用weights_only=False加载PyTorch 2.6+需要)
checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
except (pickle.UnpicklingError, TypeError) as e:
print(f"使用weights_only=False加载失败: {e}")
print("尝试使用torch.serialization.add_safe_globals...")
# 处理状态字典(可能包含'module.'前缀)
if 'model' in checkpoint:
state_dict = checkpoint['model']
elif 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
# 移除'module.'前缀(如果存在)
if hasattr(model, 'module'):
model.module.load_state_dict(state_dict)
else:
# 如果状态字典有'module.'前缀但模型没有,需要移除前缀
if any(key.startswith('module.') for key in state_dict.keys()):
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith('module.'):
new_state_dict[k[7:]] = v
else:
new_state_dict[k] = v
state_dict = new_state_dict
model.load_state_dict(state_dict)
print(f"检查点加载成功epoch: {checkpoint.get('epoch', 'unknown')}")
else:
print("警告: 未提供检查点路径,使用随机初始化的模型")
# 创建输出目录
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# 评估模型
print("开始评估...")
metrics, sample_results = evaluate_model(model, data_loader_val, device, args)
# 打印指标
print("\n" + "="*50)
print("评估结果:")
print(f"MSE: {metrics['mse']:.6f}")
print(f"SSIM: {metrics['ssim']:.6f}")
print(f"PSNR: {metrics['psnr']:.6f} dB")
print(f"样本数量: {metrics['num_samples']}")
print("="*50)
# 保存指标到JSON文件
metrics_file = output_dir / 'evaluation_metrics.json'
with open(metrics_file, 'w') as f:
json.dump(metrics, f, indent=4)
print(f"指标已保存到: {metrics_file}")
# 保存示例可视化
if sample_results:
print(f"\n保存 {len(sample_results)} 个示例可视化...")
samples_dir = output_dir / 'sample_predictions'
samples_dir.mkdir(exist_ok=True)
for i, sample in enumerate(sample_results):
save_path = samples_dir / f'sample_{i:03d}.png'
# 生成输入帧索引(假设连续)
input_frame_indices = list(range(1, args.num_frames + 1))
target_frame_index = args.num_frames + 1
save_comparison_figure(
sample['input_frames'],
sample['target_frame'],
sample['pred_frame'],
save_path,
input_frame_indices=input_frame_indices,
target_frame_index=target_frame_index
)
# 保存该样本的指标
sample_metrics_file = samples_dir / f'sample_{i:03d}_metrics.txt'
with open(sample_metrics_file, 'w') as f:
f.write(f"Sample {i} (batch {sample['batch_idx']}, idx {sample['sample_idx']})\n")
f.write(f"MSE: {sample['metrics']['mse']:.6f}\n")
f.write(f"SSIM: {sample['metrics']['ssim']:.6f}\n")
f.write(f"PSNR: {sample['metrics']['psnr']:.6f} dB\n")
print(f"示例可视化已保存到: {samples_dir}")
# 生成汇总报告
report_file = output_dir / 'evaluation_report.txt'
with open(report_file, 'w') as f:
f.write("SwiftFormerTemporal 帧预测评估报告\n")
f.write("="*50 + "\n")
f.write(f"模型: {args.model}\n")
f.write(f"检查点: {args.resume}\n")
f.write(f"数据集: {args.data_path}\n")
f.write(f"输入帧数: {args.num_frames}\n")
f.write(f"帧大小: {args.frame_size}\n")
f.write(f"批次大小: {args.batch_size}\n")
f.write(f"样本总数: {metrics['num_samples']}\n\n")
f.write("评估指标:\n")
f.write(f" MSE: {metrics['mse']:.6f}\n")
f.write(f" SSIM: {metrics['ssim']:.6f}\n")
f.write(f" PSNR: {metrics['psnr']:.6f} dB\n")
print(f"评估报告已保存到: {report_file}")
print("\n评估完成!")
def get_args_parser():
parser = argparse.ArgumentParser(
'SwiftFormerTemporal 评估脚本', add_help=False)
# 数据集参数
parser.add_argument('--data-path', default='./videos', type=str,
help='视频数据集路径')
parser.add_argument('--num-frames', default=3, type=int,
help='输入帧数 (T)')
parser.add_argument('--frame-size', default=224, type=int,
help='输入帧大小')
parser.add_argument('--max-interval', default=4, type=int,
help='连续帧之间的最大间隔')
# 模型参数
parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
help='要评估的模型名称')
# 评估参数
parser.add_argument('--batch-size', default=16, type=int,
help='评估批次大小')
parser.add_argument('--num-samples-to-save', default=10, type=int,
help='保存可视化的样本数量')
parser.add_argument('--max-samples', default=None, type=int,
help='最大评估样本数None表示全部')
# 系统参数
parser.add_argument('--output-dir', default='./evaluation_results',
help='保存结果的路径')
parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu',
help='使用的设备')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='', help='检查点路径')
parser.add_argument('--num-workers', default=4, type=int)
parser.add_argument('--pin-mem', action='store_true',
help='在DataLoader中固定CPU内存')
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem')
parser.set_defaults(pin_mem=True)
return parser
if __name__ == '__main__':
parser = argparse.ArgumentParser(
'SwiftFormerTemporal 评估', parents=[get_args_parser()])
args = parser.parse_args()
# 确保输出目录存在
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)

550
main_temporal.py Normal file
View File

@@ -0,0 +1,550 @@
"""
Main training script for SwiftFormerTemporal frame prediction
"""
import argparse
import datetime
import numpy as np
import time
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import json
import os
from pathlib import Path
from timm.scheduler import create_scheduler
from timm.optim import create_optimizer
from timm.utils import NativeScaler, get_state_dict, ModelEma
from util import *
from models import *
from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
from util.video_dataset import VideoFrameDataset
# from util.frame_losses import MultiTaskLoss
# Try to import TensorBoard
try:
from torch.utils.tensorboard import SummaryWriter
TENSORBOARD_AVAILABLE = True
except ImportError:
TENSORBOARD_AVAILABLE = False
def get_args_parser():
parser = argparse.ArgumentParser(
'SwiftFormerTemporal training script', add_help=False)
# Dataset parameters
parser.add_argument('--data-path', default='./videos', type=str,
help='Path to video dataset')
parser.add_argument('--dataset-type', default='video', choices=['video', 'synthetic'],
type=str, help='Dataset type')
parser.add_argument('--num-frames', default=3, type=int,
help='Number of input frames (T)')
parser.add_argument('--frame-size', default=224, type=int,
help='Input frame size')
parser.add_argument('--max-interval', default=10, type=int,
help='Maximum interval between consecutive frames')
# Model parameters
parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
help='Name of model to train')
# Training parameters
parser.add_argument('--batch-size', default=32, type=int)
parser.add_argument('--epochs', default=100, type=int)
# Optimizer parameters (required by timm's create_optimizer)
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "adamw"')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--clip-grad', type=float, default=0.01, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--clip-mode', type=str, default='agc',
help='Gradient clipping mode. One of ("norm", "value", "agc")')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.05,
help='weight decay (default: 0.05)')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
help='learning rate (default: 1e-3)')
# Learning rate schedule parameters (required by timm's create_scheduler)
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "cosine"')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--warmup-lr', type=float, default=1e-3, metavar='LR',
help='warmup learning rate (default: 1e-6)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
# Loss parameters
parser.add_argument('--frame-weight', type=float, default=1.0,
help='Weight for frame prediction loss')
parser.add_argument('--contrastive-weight', type=float, default=0.1,
help='Weight for contrastive loss')
# parser.add_argument('--l1-weight', type=float, default=1.0,
# help='Weight for L1 loss')
# parser.add_argument('--ssim-weight', type=float, default=0.1,
# help='Weight for SSIM loss')
parser.add_argument('--no-contrastive', action='store_true',
help='Disable contrastive loss')
parser.add_argument('--no-ssim', action='store_true',
help='Disable SSIM loss')
# System parameters
parser.add_argument('--output-dir', default='./output',
help='path where to save, empty for no saving')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--eval', action='store_true',
help='Perform evaluation only')
parser.add_argument('--num-workers', default=16, type=int)
parser.add_argument('--pin-mem', action='store_true',
help='Pin CPU memory in DataLoader')
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem')
parser.set_defaults(pin_mem=True)
# Distributed training
parser.add_argument('--world-size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--dist-url', default='env://',
help='url used to set up distributed training')
# TensorBoard logging
parser.add_argument('--tensorboard-logdir', default='./runs',
type=str, help='TensorBoard log directory')
parser.add_argument('--log-images', action='store_true',
help='Log sample images to TensorBoard')
parser.add_argument('--image-log-freq', default=100, type=int,
help='Frequency of logging images (in iterations)')
return parser
def build_dataset(is_train, args):
"""Build video frame dataset"""
dataset = VideoFrameDataset(
root_dir=args.data_path,
num_frames=args.num_frames,
frame_size=args.frame_size,
is_train=is_train,
max_interval=args.max_interval
)
return dataset
def main(args):
utils.init_distributed_mode(args)
print(args)
device = torch.device(args.device)
# Fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
cudnn.benchmark = True
# Build datasets
dataset_train = build_dataset(is_train=True, args=args)
dataset_val = build_dataset(is_train=False, args=args)
# Create samplers
if args.distributed:
sampler_train = torch.utils.data.DistributedSampler(dataset_train)
sampler_val = torch.utils.data.DistributedSampler(dataset_val, shuffle=False)
else:
sampler_train = torch.utils.data.RandomSampler(dataset_train)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=True,
)
data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=False
)
# Create model
print(f"Creating model: {args.model}")
model_kwargs = {
'num_frames': args.num_frames,
}
if args.model == 'SwiftFormerTemporal_XS':
model = SwiftFormerTemporal_XS(**model_kwargs)
elif args.model == 'SwiftFormerTemporal_S':
model = SwiftFormerTemporal_S(**model_kwargs)
elif args.model == 'SwiftFormerTemporal_L1':
model = SwiftFormerTemporal_L1(**model_kwargs)
elif args.model == 'SwiftFormerTemporal_L3':
model = SwiftFormerTemporal_L3(**model_kwargs)
else:
raise ValueError(f"Unknown model: {args.model}")
model.to(device)
# Model EMA
model_ema = None
if hasattr(args, 'model_ema') and args.model_ema:
model_ema = ModelEma(
model,
decay=args.model_ema_decay if hasattr(args, 'model_ema_decay') else 0.9999,
device='cpu' if hasattr(args, 'model_ema_force_cpu') and args.model_ema_force_cpu else '',
resume='')
# Distributed training
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Number of parameters: {n_parameters}')
# Create optimizer
optimizer = create_optimizer(args, model_without_ddp)
# Create loss scaler
loss_scaler = NativeScaler()
# Create scheduler
lr_scheduler, _ = create_scheduler(args, optimizer)
# Create loss function - simple MSE for Y channel prediction
class MSELossWrapper(nn.Module):
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
def forward(self, pred_frame, target_frame, temporal_indices=None):
loss = self.mse(pred_frame, target_frame)
loss_dict = {'mse': loss}
return loss, loss_dict
criterion = MSELossWrapper()
# Resume from checkpoint
output_dir = Path(args.output_dir)
if args.resume:
if args.resume.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
model_without_ddp.load_state_dict(checkpoint['model'])
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1
if model_ema is not None:
utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
if 'scaler' in checkpoint:
loss_scaler.load_state_dict(checkpoint['scaler'])
# Initialize TensorBoard writer
writer = None
if TENSORBOARD_AVAILABLE and utils.is_main_process():
from datetime import datetime
# Create log directory with timestamp
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_dir = os.path.join(args.tensorboard_logdir, f"exp_{timestamp}")
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir=log_dir)
print(f"TensorBoard logs will be saved to: {log_dir}")
print(f"To view logs, run: tensorboard --logdir={log_dir}")
elif not TENSORBOARD_AVAILABLE and utils.is_main_process():
print("Warning: TensorBoard not available. Install tensorboard or tensorboardX.")
print("Training will continue without TensorBoard logging.")
if args.eval:
test_stats = evaluate(data_loader_val, model, criterion, device)
print(f"Test stats: {test_stats}")
return
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
# Global step counter for TensorBoard
global_step = 0
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
train_stats, global_step = train_one_epoch(
model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler, args.clip_grad, args.clip_mode,
model_ema=model_ema, writer=writer,
global_step=global_step, args=args
)
lr_scheduler.step(epoch)
# Save checkpoint
if args.output_dir and (epoch % 1 == 0 or epoch == args.epochs - 1):
checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth'
utils.save_on_master({
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch,
'model_ema': get_state_dict(model_ema) if model_ema else None,
'scaler': loss_scaler.state_dict(),
'args': args,
}, checkpoint_path)
# Evaluate
if epoch % 5 == 0 or epoch == args.epochs - 1:
test_stats = evaluate(data_loader_val, model, criterion, device, writer=writer, epoch=epoch)
print(f"Epoch {epoch}: Test stats: {test_stats}")
# Log stats to text file
log_stats = {
**{f'train_{k}': v for k, v in train_stats.items()},
**{f'test_{k}': v for k, v in test_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters
}
if args.output_dir and utils.is_main_process():
with (output_dir / "log.txt").open("a") as f:
f.write(json.dumps(log_stats) + "\n")
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f'Training time {total_time_str}')
# Close TensorBoard writer
if writer is not None:
writer.close()
print(f"TensorBoard logs saved to: {writer.log_dir}")
def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
clip_grad=0.01, clip_mode='norm', model_ema=None, writer=None,
global_step=0, args=None, **kwargs):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = f'Epoch: [{epoch}]'
print_freq = 10
# 添加诊断指标
metric_logger.add_meter('pred_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('pred_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('grad_norm', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(
metric_logger.log_every(data_loader, print_freq, header)):
input_frames = input_frames.to(device, non_blocking=True)
target_frames = target_frames.to(device, non_blocking=True)
temporal_indices = temporal_indices.to(device, non_blocking=True)
# Forward pass
with torch.amp.autocast(device_type='cuda'):
pred_frames = model(input_frames)
loss, loss_dict = criterion(
pred_frames, target_frames,
temporal_indices
)
loss_value = loss.item()
if not torch.isfinite(torch.tensor(loss_value)):
print(f"Loss is {loss_value}, stopping training")
raise ValueError(f"Loss is {loss_value}")
optimizer.zero_grad()
loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
parameters=model.parameters())
torch.cuda.synchronize()
if model_ema is not None:
model_ema.update(model)
# 计算诊断指标
pred_mean = pred_frames.mean().item()
pred_std = pred_frames.std().item()
# 计算梯度范数
total_grad_norm = 0.0
for param in model.parameters():
if param.grad is not None:
total_grad_norm += param.grad.norm().item()
# 记录诊断指标
metric_logger.update(pred_mean=pred_mean)
metric_logger.update(pred_std=pred_std)
metric_logger.update(grad_norm=total_grad_norm)
# # 每50个批次打印一次BatchNorm统计
if batch_idx % 50 == 0:
print(f"[诊断] 批次 {batch_idx}: 预测均值={pred_mean:.4f}, 预测标准差={pred_std:.4f}, 梯度范数={total_grad_norm:.4f}")
# # 检查一个BatchNorm层的运行统计
# for name, module in model.named_modules():
# if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
# print(f"[诊断] {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
# break
# Log to TensorBoard
if writer is not None:
# Log scalar metrics every iteration
writer.add_scalar('train/loss', loss_value, global_step)
writer.add_scalar('train/lr', optimizer.param_groups[0]["lr"], global_step)
# Log individual loss components
for k, v in loss_dict.items():
if torch.is_tensor(v):
writer.add_scalar(f'train/{k}', v.item(), global_step)
else:
writer.add_scalar(f'train/{k}', v, global_step)
# Log diagnostic metrics
writer.add_scalar('train/pred_mean', pred_mean, global_step)
writer.add_scalar('train/pred_std', pred_std, global_step)
writer.add_scalar('train/grad_norm', total_grad_norm, global_step)
# Log images periodically
if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0:
with torch.no_grad():
# Take first sample from batch for visualization
pred_vis = model(input_frames[:1])
# Convert to appropriate format for TensorBoard
# Assuming frames are in [B, C, H, W] format
writer.add_images('train/input', input_frames[:1], global_step)
writer.add_images('train/target', target_frames[:1], global_step)
writer.add_images('train/predicted', pred_vis[:1], global_step)
# Update metrics
metric_logger.update(loss=loss_value)
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
for k, v in loss_dict.items():
metric_logger.update(**{k: v.item() if torch.is_tensor(v) else v})
global_step += 1
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
# Log epoch-level metrics
if writer is not None:
for k, meter in metric_logger.meters.items():
writer.add_scalar(f'train_epoch/{k}', meter.global_avg, epoch)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, global_step
@torch.no_grad()
def evaluate(data_loader, model, criterion, device, writer=None, epoch=0):
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
# 添加诊断指标
metric_logger.add_meter('pred_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('pred_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('target_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('target_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(metric_logger.log_every(data_loader, 10, header)):
input_frames = input_frames.to(device, non_blocking=True)
target_frames = target_frames.to(device, non_blocking=True)
temporal_indices = temporal_indices.to(device, non_blocking=True)
# Compute output
with torch.amp.autocast(device_type='cuda'):
pred_frames = model(input_frames)
loss, loss_dict = criterion(
pred_frames, target_frames,
temporal_indices
)
# 计算诊断指标
pred_mean = pred_frames.mean().item()
pred_std = pred_frames.std().item()
target_mean = target_frames.mean().item()
target_std = target_frames.std().item()
# 更新诊断指标
metric_logger.update(pred_mean=pred_mean)
metric_logger.update(pred_std=pred_std)
metric_logger.update(target_mean=target_mean)
metric_logger.update(target_std=target_std)
# # 第一个批次打印详细诊断信息
# if batch_idx == 0:
# print(f"[评估诊断] 批次 0:")
# print(f" 预测范围: [{pred_frames.min().item():.4f}, {pred_frames.max().item():.4f}]")
# print(f" 预测均值: {pred_mean:.4f}, 预测标准差: {pred_std:.4f}")
# print(f" 目标范围: [{target_frames.min().item():.4f}, {target_frames.max().item():.4f}]")
# print(f" 目标均值: {target_mean:.4f}, 目标标准差: {target_std:.4f}")
# # 检查BatchNorm运行统计
# for name, module in model.named_modules():
# if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
# print(f" {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
# if module.running_var[0].item() < 1e-6:
# print(f" 警告: BatchNorm运行方差接近零!")
# break
# Update metrics
metric_logger.update(loss=loss.item())
for k, v in loss_dict.items():
metric_logger.update(**{k: v.item() if torch.is_tensor(v) else v})
metric_logger.synchronize_between_processes()
print('* Test stats:', metric_logger)
# Log validation metrics to TensorBoard
if writer is not None:
for k, meter in metric_logger.meters.items():
writer.add_scalar(f'val/{k}', meter.global_avg, epoch)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
if __name__ == '__main__':
parser = argparse.ArgumentParser(
'SwiftFormerTemporal training script', parents=[get_args_parser()])
args = parser.parse_args()
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)

View File

@@ -1 +1,7 @@
from .swiftformer import SwiftFormer_XS, SwiftFormer_S, SwiftFormer_L1, SwiftFormer_L3 from .swiftformer import SwiftFormer_XS, SwiftFormer_S, SwiftFormer_L1, SwiftFormer_L3
from .swiftformer_temporal import (
SwiftFormerTemporal_XS,
SwiftFormerTemporal_S,
SwiftFormerTemporal_L1,
SwiftFormerTemporal_L3
)

View File

@@ -6,9 +6,9 @@ import copy
import torch import torch
import torch.nn as nn import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath, trunc_normal_ from timm.layers import DropPath, trunc_normal_
from timm.models.registry import register_model from timm.models import register_model
from timm.models.layers.helpers import to_2tuple from timm.layers import to_2tuple
import einops import einops
SwiftFormer_width = { SwiftFormer_width = {
@@ -437,7 +437,7 @@ class SwiftFormer(nn.Module):
if not self.training: if not self.training:
cls_out = (cls_out[0] + cls_out[1]) / 2 cls_out = (cls_out[0] + cls_out[1]) / 2
else: else:
cls_out = self.head(x.mean(-2)) cls_out = self.head(x.flatten(2).mean(-1))
# For image classification # For image classification
return cls_out return cls_out

View File

@@ -0,0 +1,232 @@
"""
SwiftFormerTemporal: Temporal extension of SwiftFormer for frame prediction
"""
import torch
import torch.nn as nn
from .swiftformer import (
SwiftFormer, SwiftFormer_depth, SwiftFormer_width,
stem, Embedding, Stage
)
from timm.layers import DropPath, trunc_normal_
class DecoderBlock(nn.Module):
"""Upsampling block for frame prediction decoder without residual connections"""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
super().__init__()
# 主路径:反卷积 + 两个卷积层
self.conv_transpose = nn.ConvTranspose2d(
in_channels, out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
bias=False # 禁用bias因为使用BN
)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv1 = nn.Conv2d(out_channels, out_channels,
kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels,
kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(out_channels)
# 使用ReLU激活函数
self.activation = nn.ReLU(inplace=True)
# 初始化权重
self._init_weights()
def _init_weights(self):
# 初始化反卷积层
nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='relu')
# 初始化卷积层
nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')
nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='relu')
# 初始化BN层使用默认初始化
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x):
# 主路径
x = self.conv_transpose(x)
x = self.bn1(x)
x = self.activation(x)
x = self.conv1(x)
x = self.bn2(x)
x = self.activation(x)
x = self.conv2(x)
x = self.bn3(x)
x = self.activation(x)
return x
class FramePredictionDecoder(nn.Module):
"""Improved decoder for frame prediction"""
def __init__(self, embed_dims, output_channels=1):
super().__init__()
# Define decoder dimensions independently (no skip connections)
start_dim = embed_dims[-1]
decoder_dims = [start_dim // (2 ** i) for i in range(4)] # e.g., [220, 110, 55, 27] for XS
self.blocks = nn.ModuleList()
# 第一个blockstride=2 (decoder_dims[0] -> decoder_dims[1])
self.blocks.append(DecoderBlock(
decoder_dims[0], decoder_dims[1],
kernel_size=3, stride=2, padding=1, output_padding=1
))
# 第二个blockstride=2 (decoder_dims[1] -> decoder_dims[2])
self.blocks.append(DecoderBlock(
decoder_dims[1], decoder_dims[2],
kernel_size=3, stride=2, padding=1, output_padding=1
))
# 第三个blockstride=2 (decoder_dims[2] -> decoder_dims[3])
self.blocks.append(DecoderBlock(
decoder_dims[2], decoder_dims[3],
kernel_size=3, stride=2, padding=1, output_padding=1
))
# 第四个blockstride=4 (decoder_dims[3] -> 64),放在倒数第二的位置
self.blocks.append(DecoderBlock(
decoder_dims[3], 64,
kernel_size=3, stride=4, padding=1, output_padding=3 # stride=4放在这里
))
self.final_block = nn.Sequential(
nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True),
nn.Tanh()
)
def forward(self, x):
"""
Args:
x: input tensor of shape [B, embed_dims[-1], H/32, W/32]
"""
# 不使用skip connections
for i in range(4):
x = self.blocks[i](x)
# 最终输出层:只进行特征精炼,不上采样
x = self.final_block(x)
return x
class SwiftFormerTemporal(nn.Module):
"""
SwiftFormer with temporal input for frame prediction.
Input: [B, num_frames, H, W] (Y channel only)
Output: predicted frame [B, 1, H, W] and optional representation
"""
def __init__(self,
model_name='XS',
num_frames=3,
use_decoder=True,
**kwargs):
super().__init__()
# Get model configuration
layers = SwiftFormer_depth[model_name]
embed_dims = SwiftFormer_width[model_name]
# Store configuration
self.num_frames = num_frames
self.use_decoder = use_decoder
# Modify stem to accept multiple frames (only Y channel)
in_channels = num_frames
self.patch_embed = stem(in_channels, embed_dims[0])
# Build encoder network (same as SwiftFormer)
network = []
for i in range(len(layers)):
stage = Stage(embed_dims[i], i, layers, mlp_ratio=4,
act_layer=nn.GELU,
drop_rate=0., drop_path_rate=0.,
use_layer_scale=True,
layer_scale_init_value=1e-5,
vit_num=1)
network.append(stage)
if i >= len(layers) - 1:
break
if embed_dims[i] != embed_dims[i + 1]:
network.append(
Embedding(
patch_size=3, stride=2, padding=1,
in_chans=embed_dims[i], embed_dim=embed_dims[i + 1]
)
)
self.network = nn.ModuleList(network)
self.norm = nn.BatchNorm2d(embed_dims[-1])
# Frame prediction decoder
if use_decoder:
self.decoder = FramePredictionDecoder(
embed_dims,
output_channels=1
)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
# 使用Kaiming初始化适合ReLU
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.ConvTranspose2d):
# 反卷积层使用特定的初始化
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward_tokens(self, x):
for block in self.network:
x = block(x)
return x
def forward(self, x):
"""
Args:
x: input frames of shape [B, num_frames, H, W]
Returns:
pred_frame: predicted frame [B, 1, H, W] (or None)
"""
# Encode
x = self.patch_embed(x)
x = self.forward_tokens(x)
x = self.norm(x)
# Decode to frame
pred_frame = None
if self.use_decoder:
pred_frame = self.decoder(x)
return pred_frame
# Factory functions for different model sizes
def SwiftFormerTemporal_XS(num_frames=3, **kwargs):
return SwiftFormerTemporal('XS', num_frames=num_frames, **kwargs)
def SwiftFormerTemporal_S(num_frames=3, **kwargs):
return SwiftFormerTemporal('S', num_frames=num_frames, **kwargs)
def SwiftFormerTemporal_L1(num_frames=3, **kwargs):
return SwiftFormerTemporal('l1', num_frames=num_frames, **kwargs)
def SwiftFormerTemporal_L3(num_frames=3, **kwargs):
return SwiftFormerTemporal('l3', num_frames=num_frames, **kwargs)

233
util/video_dataset.py Normal file
View File

@@ -0,0 +1,233 @@
"""
Video frame dataset for temporal self-supervised learning
"""
import os
import random
from pathlib import Path
from typing import Optional, Tuple, List
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import numpy as np
class VideoFrameDataset(Dataset):
"""
Dataset for loading consecutive frames from videos for frame prediction.
Assumes directory structure:
dataset_root/
video1/
frame_0001.jpg
frame_0002.jpg
...
video2/
...
"""
def __init__(self,
root_dir: str,
num_frames: int = 3,
frame_size: int = 224,
is_train: bool = True,
max_interval: int = 1,
transform=None):
"""
Args:
root_dir: Root directory containing video folders
num_frames: Number of input frames (T)
frame_size: Size to resize frames to
is_train: Whether this is training set (affects augmentation)
max_interval: Maximum interval between consecutive frames
transform: Optional custom transform
"""
self.root_dir = Path(root_dir)
self.num_frames = num_frames
self.frame_size = frame_size
self.is_train = is_train
self.max_interval = max_interval
# if num_frames < 1:
# raise ValueError("num_frames must be >= 1")
# if frame_size < 1:
# raise ValueError("frame_size must be >= 1")
# if max_interval < 1:
# raise ValueError("max_interval must be >= 1")
# Collect all video folders and their frame files
self.video_folders = []
self.video_frame_files = [] # list of list of Path objects
for item in self.root_dir.iterdir():
if item.is_dir():
self.video_folders.append(item)
# Get all frame files
frame_files = sorted([f for f in item.iterdir()
if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
self.video_frame_files.append(frame_files)
if len(self.video_folders) == 0:
raise ValueError(f"No video folders found in {root_dir}")
# Build frame index: list of (video_idx, start_frame_idx)
self.frame_indices = []
for video_idx, frame_files in enumerate(self.video_frame_files):
# Minimum frames needed considering max interval
min_frames_needed = num_frames * max_interval + 1
if len(frame_files) < min_frames_needed:
continue # Skip videos with insufficient frames
# Add all possible starting positions
# Ensure that for any interval up to max_interval, all frames are within bounds
max_start = len(frame_files) - num_frames * max_interval
for start_idx in range(max_start):
self.frame_indices.append((video_idx, start_idx))
if len(self.frame_indices) == 0:
raise ValueError("No valid frame sequences found in dataset")
# Default transforms
if transform is None:
self.transform = self._default_transform()
else:
self.transform = transform
# Simple normalization to [-1, 1] range (不使用ImageNet标准化)
# Convert pixel values [0, 255] to [-1, 1]
# This matches the model's tanh output range
self.normalize = None # We'll handle normalization manually
# print(f"[数据集初始化] 使用简单归一化: 像素值[0,255] -> [-1,1]")
def _default_transform(self):
"""Default transform with augmentation for training"""
if self.is_train:
return transforms.Compose([
transforms.RandomResizedCrop(self.frame_size, scale=(0.8, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
])
else:
return transforms.Compose([
transforms.Resize(int(self.frame_size * 1.14)),
transforms.CenterCrop(self.frame_size),
])
def _load_frame(self, video_idx: int, frame_idx: int) -> Image.Image:
"""Load a single frame as PIL Image"""
frame_files = self.video_frame_files[video_idx]
if frame_idx < 0 or frame_idx >= len(frame_files):
raise IndexError(
f"Frame index {frame_idx} out of range for video {video_idx} "
f"(0-{len(frame_files)-1})"
)
frame_path = frame_files[frame_idx]
return Image.open(frame_path).convert('RGB')
def __len__(self) -> int:
return len(self.frame_indices)
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Returns:
input_frames: [num_frames, H, W] concatenated input frames (Y channel only)
target_frame: [1, H, W] target frame to predict (Y channel only)
temporal_idx: temporal index of target frame (for contrastive loss)
"""
video_idx, start_idx = self.frame_indices[idx]
# Determine frame interval (for temporal augmentation)
interval = random.randint(1, self.max_interval) if self.is_train else 1
# Load input frames
input_frames = []
for i in range(self.num_frames):
frame_idx = start_idx + i * interval
frame = self._load_frame(video_idx, frame_idx)
# Apply transform (same for all frames in sequence)
if self.transform:
frame = self.transform(frame)
input_frames.append(frame)
# Load target frame (next frame after input sequence)
target_idx = start_idx + self.num_frames * interval
target_frame = self._load_frame(video_idx, target_idx)
if self.transform:
target_frame = self.transform(target_frame)
# Convert to tensors and convert to grayscale (Y channel)
input_tensors = []
for frame in input_frames:
tensor = transforms.ToTensor()(frame) # [3, H, W], range [0, 1]
# Convert RGB to grayscale using weighted sum
# Y = 0.2989 * R + 0.5870 * G + 0.1140 * B (same as PIL)
gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0) # [1, H, W], range [0, 1]
# Normalize from [0, 1] to [-1, 1]
gray = gray * 2 - 1 # [0,1] -> [-1,1]
input_tensors.append(gray)
target_tensor = transforms.ToTensor()(target_frame) # [3, H, W], range [0, 1]
target_gray = (0.2989 * target_tensor[0] + 0.5870 * target_tensor[1] + 0.1140 * target_tensor[2]).unsqueeze(0)
# Normalize from [0, 1] to [-1, 1]
target_gray = target_gray * 2 - 1 # [0,1] -> [-1,1]
# Concatenate input frames along channel dimension
input_concatenated = torch.cat(input_tensors, dim=0) # [num_frames, H, W]
# Temporal index (for contrastive loss)
temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
return input_concatenated, target_gray, temporal_idx
# class SyntheticVideoDataset(Dataset):
# """
# Synthetic dataset for testing - generates random frames
# """
# def __init__(self,
# num_samples: int = 1000,
# num_frames: int = 3,
# frame_size: int = 224,
# is_train: bool = True):
# self.num_samples = num_samples
# self.num_frames = num_frames
# self.frame_size = frame_size
# self.is_train = is_train
# # Normalization for Y channel (single channel)
# y_mean = (0.485 + 0.456 + 0.406) / 3.0
# y_std = (0.229 + 0.224 + 0.225) / 3.0
# self.normalize = transforms.Normalize(
# mean=[y_mean],
# std=[y_std]
# )
# def __len__(self):
# return self.num_samples
# def __getitem__(self, idx):
# # Generate random "frames" (noise with temporal correlation)
# input_frames = []
# prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1
# for i in range(self.num_frames):
# # Add some temporal correlation
# frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
# frame = torch.clamp(frame, -1, 1)
# input_frames.append(self.normalize(frame))
# prev_frame = frame
# # Target frame (next in sequence)
# target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
# target_frame = torch.clamp(target_frame, -1, 1)
# target_tensor = self.normalize(target_frame)
# # Concatenate inputs
# input_concatenated = torch.cat(input_frames, dim=0)
# # Temporal index
# temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
# return input_concatenated, target_tensor, temporal_idx

303
video_preprocessor.py Normal file
View File

@@ -0,0 +1,303 @@
#!/usr/bin/env python3
"""
视频预处理脚本 - 将MP4视频转换为224x224帧图像
支持多线程并发处理、进度条显示和中断恢复功能
"""
import os
import sys
import json
import argparse
import subprocess
import threading
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import time
from typing import List, Dict, Optional
class VideoPreprocessor:
"""视频预处理器,支持多线程和中断恢复"""
def __init__(self,
input_dir: str,
output_dir: str,
frame_size: int = 224,
fps: int = 30,
num_workers: int = 4,
quality: int = 2,
resume: bool = True):
"""
初始化预处理器
Args:
input_dir: 输入视频目录
output_dir: 输出帧目录
frame_size: 帧大小(正方形)
fps: 提取帧率
num_workers: 并发工作线程数
quality: JPEG质量 (1-31, 数值越小质量越高)
resume: 是否启用中断恢复
"""
self.input_dir = Path(input_dir)
self.output_dir = Path(output_dir)
self.frame_size = frame_size
self.fps = fps
self.num_workers = num_workers
self.quality = quality
self.resume = resume
# 状态文件路径
self.state_file = self.output_dir / ".preprocessing_state.json"
# 创建输出目录
self.output_dir.mkdir(parents=True, exist_ok=True)
# 初始化状态
self.state = self._load_state()
# 收集所有视频文件
self.video_files = self._collect_video_files()
def _load_state(self) -> Dict:
"""加载处理状态"""
if self.resume and self.state_file.exists():
try:
with open(self.state_file, 'r') as f:
return json.load(f)
except (json.JSONDecodeError, IOError):
print(f"警告: 无法读取状态文件,将重新开始处理")
return {
"completed": [],
"failed": [],
"total_processed": 0,
"start_time": None,
"last_update": None
}
def _save_state(self):
"""保存处理状态"""
self.state["last_update"] = time.time()
try:
with open(self.state_file, 'w') as f:
json.dump(self.state, f, indent=2)
except IOError as e:
print(f"警告: 无法保存状态文件: {e}")
def _collect_video_files(self) -> List[Path]:
"""收集所有需要处理的视频文件"""
video_files = []
for file_path in self.input_dir.glob("*.mp4"):
if file_path.name not in self.state["completed"]:
video_files.append(file_path)
return sorted(video_files)
def _parse_video_name(self, video_path: Path) -> Dict[str, str]:
"""解析视频文件名使用完整文件名作为ID"""
name_without_ext = video_path.stem
# 直接使用完整文件名作为ID确保每个mp4文件有独立的输出目录
return {
"video_id": name_without_ext,
"start_frame": "unknown",
"end_frame": "unknown",
"full_name": name_without_ext
}
def _extract_frames(self, video_path: Path) -> bool:
"""提取单个视频的帧"""
try:
# 解析视频名称
video_info = self._parse_video_name(video_path)
output_subdir = self.output_dir / video_info["video_id"]
output_subdir.mkdir(exist_ok=True)
# 构建FFmpeg命令
output_pattern = output_subdir / "frame_%04d.jpg"
cmd = [
"ffmpeg",
"-i", str(video_path),
"-vf", f"fps={self.fps},scale={self.frame_size}:{self.frame_size}",
"-q:v", str(self.quality),
"-y", # 覆盖输出文件
str(output_pattern)
]
# 执行FFmpeg命令
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=300 # 5分钟超时
)
if result.returncode != 0:
print(f"FFmpeg错误处理 {video_path.name}: {result.stderr}")
return False
# 验证输出帧数量
output_frames = list(output_subdir.glob("frame_*.jpg"))
if len(output_frames) == 0:
print(f"警告: {video_path.name} 没有生成任何帧")
return False
return True
except subprocess.TimeoutExpired:
print(f"超时处理 {video_path.name}")
return False
except Exception as e:
print(f"处理 {video_path.name} 时发生错误: {e}")
return False
def _process_video(self, video_path: Path) -> tuple[bool, str]:
"""处理单个视频文件"""
video_name = video_path.name
try:
success = self._extract_frames(video_path)
if success:
self.state["completed"].append(video_name)
if video_name in self.state["failed"]:
self.state["failed"].remove(video_name)
return True, video_name
else:
self.state["failed"].append(video_name)
return False, video_name
except Exception as e:
print(f"处理 {video_name} 时发生异常: {e}")
self.state["failed"].append(video_name)
return False, video_name
def process_all_videos(self):
"""处理所有视频文件"""
if not self.video_files:
print("没有找到需要处理的视频文件")
return
print(f"找到 {len(self.video_files)} 个待处理视频文件")
print(f"输出目录: {self.output_dir}")
print(f"帧大小: {self.frame_size}x{self.frame_size}")
print(f"帧率: {self.fps} fps")
print(f"并发线程数: {self.num_workers}")
if self.state["completed"]:
print(f"跳过 {len(self.state['completed'])} 个已处理的视频")
# 记录开始时间
if self.state["start_time"] is None:
self.state["start_time"] = time.time()
# 创建进度条
with tqdm(total=len(self.video_files), desc="处理视频", unit="") as pbar:
with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
# 提交所有任务
future_to_video = {
executor.submit(self._process_video, video_path): video_path
for video_path in self.video_files
}
# 处理完成的任务
for future in as_completed(future_to_video):
video_path = future_to_video[future]
try:
success, video_name = future.result()
if success:
pbar.set_postfix({"状态": "成功", "文件": video_name[:20]})
else:
pbar.set_postfix({"状态": "失败", "文件": video_name[:20]})
except Exception as e:
print(f"处理 {video_path.name} 时发生异常: {e}")
pbar.set_postfix({"状态": "异常", "文件": video_path.name[:20]})
pbar.update(1)
self.state["total_processed"] += 1
# 定期保存状态
if self.state["total_processed"] % 5 == 0:
self._save_state()
# 最终保存状态
self._save_state()
# 打印处理结果
self._print_summary()
def _print_summary(self):
"""打印处理摘要"""
print("\n" + "="*50)
print("处理完成摘要:")
print(f"总处理视频数: {len(self.state['completed'])}")
print(f"失败视频数: {len(self.state['failed'])}")
if self.state["failed"]:
print("\n失败的视频:")
for video_name in self.state["failed"]:
print(f" - {video_name}")
if self.state["start_time"]:
elapsed_time = time.time() - self.state["start_time"]
print(f"\n总耗时: {elapsed_time:.2f}")
if self.state["total_processed"] > 0:
avg_time = elapsed_time / self.state["total_processed"]
print(f"平均每个视频: {avg_time:.2f}")
print("="*50)
def main():
"""主函数"""
parser = argparse.ArgumentParser(description="视频预处理脚本")
parser.add_argument("--input_dir", type=str, default="/home/hexone/Workplace/ws_asmo/vhead/sekai-real-drone/sekai-real-drone", help="输入视频目录")
parser.add_argument("--output_dir", type=str, default="/home/hexone/Workplace/ws_asmo/vhead/sekai-real-drone/processed", help="输出帧目录")
parser.add_argument("--size", type=int, default=224, help="帧大小 (默认: 224)")
parser.add_argument("--fps", type=int, default=10, help="提取帧率 (默认: 30)")
parser.add_argument("--workers", type=int, default=32, help="并发线程数 (默认: 4)")
parser.add_argument("--quality", type=int, default=2, help="JPEG质量 1-31 (默认: 2)")
parser.add_argument("--no-resume", action="store_true", help="不启用中断恢复")
args = parser.parse_args()
# 检查输入目录
if not Path(args.input_dir).exists():
print(f"错误: 输入目录不存在: {args.input_dir}")
sys.exit(1)
# 检查FFmpeg是否可用
try:
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
except (subprocess.CalledProcessError, FileNotFoundError):
print("错误: FFmpeg未安装或不在PATH中")
sys.exit(1)
# 创建预处理器并开始处理
preprocessor = VideoPreprocessor(
input_dir=args.input_dir,
output_dir=args.output_dir,
frame_size=args.size,
fps=args.fps,
num_workers=args.workers,
quality=args.quality,
resume=not args.no_resume
)
try:
preprocessor.process_all_videos()
except KeyboardInterrupt:
print("\n\n用户中断处理,状态已保存")
preprocessor._save_state()
print("可以使用相同命令恢复处理")
except Exception as e:
print(f"\n处理过程中发生错误: {e}")
preprocessor._save_state()
sys.exit(1)
if __name__ == "__main__":
main()