Compare commits
19 Commits
28fd075488
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 543beefa2a | |||
| a92a0b29e9 | |||
| df703638da | |||
| c5502cc87c | |||
| 12de74f130 | |||
| 500c2eb18f | |||
| f7601e9170 | |||
| efd76bccd2 | |||
| 4888619f9d | |||
| 7e9564ef20 | |||
|
|
4aa6cd6752 | ||
|
|
898d23ca89 | ||
|
|
3daedbd499 | ||
|
|
28ce806f55 | ||
|
|
9b7df0d145 | ||
|
|
0ddadad723 | ||
|
|
cd1f854e59 | ||
|
|
5c9b4ceece | ||
|
|
7d5ca0c25b |
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
.vscode/
|
||||||
|
__pycache__/
|
||||||
|
venv/
|
||||||
|
runs/
|
||||||
201
LICENSE
Normal file
201
LICENSE
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright [yyyy] [name of copyright owner]
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
44
README.md
44
README.md
@@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
Mohamed Bin Zayed University of Artificial Intelligence<sup>1</sup>, University of California Merced<sup>2</sup>, Google Research<sup>3</sup>, Linkoping University<sup>4</sup>
|
Mohamed Bin Zayed University of Artificial Intelligence<sup>1</sup>, University of California Merced<sup>2</sup>, Google Research<sup>3</sup>, Linkoping University<sup>4</sup>
|
||||||
<!-- [](site_url) -->
|
<!-- [](site_url) -->
|
||||||
[](https://arxiv.org/abs/2303.15446)
|
[](https://openaccess.thecvf.com/content/ICCV2023/papers/Shaker_SwiftFormer_Efficient_Additive_Attention_for_Transformer-based_Real-time_Mobile_Vision_Applications_ICCV_2023_paper.pdf)
|
||||||
<!-- [](youtube_link) -->
|
<!-- [](youtube_link) -->
|
||||||
<!-- [](presentation) -->
|
<!-- [](presentation) -->
|
||||||
|
|
||||||
@@ -64,6 +64,28 @@ Self-attention has become a defacto choice for capturing global context in vario
|
|||||||
|
|
||||||
The latency reported in SwiftFormer for iPhone 14 (iOS 16) uses the benchmark tool from [XCode 14](https://developer.apple.com/videos/play/wwdc2022/10027/).
|
The latency reported in SwiftFormer for iPhone 14 (iOS 16) uses the benchmark tool from [XCode 14](https://developer.apple.com/videos/play/wwdc2022/10027/).
|
||||||
|
|
||||||
|
### SwiftFormer meets Android
|
||||||
|
|
||||||
|
Community-driven results with [Samsung Galaxy S23 Ultra, with Qualcomm Snapdragon 8 Gen 2](https://www.qualcomm.com/snapdragon/device-finder/samsung-galaxy-s23-ultra):
|
||||||
|
|
||||||
|
1. [Export](https://github.com/escorciav/SwiftFormer/blob/main-v/export.py) & profiler results of [`SwiftFormer_L1`](./models/swiftformer.py):
|
||||||
|
|
||||||
|
| QNN | 2.16 | 2.17 | 2.18 |
|
||||||
|
| -------------- | -----| ----- | ------ |
|
||||||
|
| Latency (msec) | 2.63 | 2.26 | 2.43 |
|
||||||
|
|
||||||
|
2. [Export](https://github.com/escorciav/SwiftFormer/blob/main-v/export_block.py) & profiler results of SwiftFormerEncoder block:
|
||||||
|
|
||||||
|
| QNN | 2.16 | 2.17 | 2.18 |
|
||||||
|
| -------------- | -----| ----- | ------ |
|
||||||
|
| Latency (msec) | 2.17 | 1.69 | 1.7 |
|
||||||
|
|
||||||
|
Refer to the script above for details of the input & block parameters.
|
||||||
|
|
||||||
|
❓ _Interested in reproducing the results above?_
|
||||||
|
|
||||||
|
Refer to [Issue #14](https://github.com/Amshaker/SwiftFormer/issues/14) for details about [exporting & profiling.](https://github.com/Amshaker/SwiftFormer/issues/14#issuecomment-1883351728)
|
||||||
|
|
||||||
## ImageNet
|
## ImageNet
|
||||||
|
|
||||||
### Prerequisites
|
### Prerequisites
|
||||||
@@ -78,7 +100,7 @@ pip install timm
|
|||||||
pip install coremltools==5.2.0
|
pip install coremltools==5.2.0
|
||||||
```
|
```
|
||||||
|
|
||||||
### Data preparation
|
### Data Preparation
|
||||||
|
|
||||||
Download and extract ImageNet train and val images from http://image-net.org. The training and validation data are expected to be in the `train` folder and `val` folder respectively:
|
Download and extract ImageNet train and val images from http://image-net.org. The training and validation data are expected to be in the `train` folder and `val` folder respectively:
|
||||||
```
|
```
|
||||||
@@ -87,7 +109,7 @@ Download and extract ImageNet train and val images from http://image-net.org. Th
|
|||||||
|-- val
|
|-- val
|
||||||
```
|
```
|
||||||
|
|
||||||
### Single machine multi-GPU training
|
### Single-machine multi-GPU training
|
||||||
|
|
||||||
We provide training script for all models in `dist_train.sh` using PyTorch distributed data parallel (DDP).
|
We provide training script for all models in `dist_train.sh` using PyTorch distributed data parallel (DDP).
|
||||||
|
|
||||||
@@ -107,7 +129,7 @@ On a Slurm-managed cluster, multi-node training can be launched as
|
|||||||
sbatch slurm_train.sh /path/to/imagenet SwiftFormer_XS
|
sbatch slurm_train.sh /path/to/imagenet SwiftFormer_XS
|
||||||
```
|
```
|
||||||
|
|
||||||
Note: specify slurm specific paramters in `slurm_train.sh` script.
|
Note: specify slurm specific parameters in `slurm_train.sh` script.
|
||||||
|
|
||||||
### Testing
|
### Testing
|
||||||
|
|
||||||
@@ -121,20 +143,22 @@ sh dist_test.sh SwiftFormer_XS 8 weights/SwiftFormer_XS_ckpt.pth
|
|||||||
## Citation
|
## Citation
|
||||||
if you use our work, please consider citing us:
|
if you use our work, please consider citing us:
|
||||||
```BibTeX
|
```BibTeX
|
||||||
@article{Shaker2023SwiftFormer,
|
@InProceedings{Shaker_2023_ICCV,
|
||||||
title={SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications},
|
|
||||||
author = {Shaker, Abdelrahman and Maaz, Muhammad and Rasheed, Hanoona and Khan, Salman and Yang, Ming-Hsuan and Khan, Fahad Shahbaz},
|
author = {Shaker, Abdelrahman and Maaz, Muhammad and Rasheed, Hanoona and Khan, Salman and Yang, Ming-Hsuan and Khan, Fahad Shahbaz},
|
||||||
journal={arXiv:2303.15446},
|
title = {SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications},
|
||||||
year={2023}
|
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
|
||||||
|
year = {2023},
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## Contact:
|
## Contact:
|
||||||
If you have any question, please create an issue on this repository or contact at abdelrahman.youssief@mbzuai.ac.ae.
|
If you have any questions, please create an issue on this repository or contact at abdelrahman.youssief@mbzuai.ac.ae.
|
||||||
|
|
||||||
|
|
||||||
## Acknowledgement
|
## Acknowledgement
|
||||||
Our code base is based on [LeViT](https://github.com/facebookresearch/LeViT) and [EfficientFormer](https://github.com/snap-research/EfficientFormer) repositories. We thank authors for their open-source implementation.
|
Our code base is based on [LeViT](https://github.com/facebookresearch/LeViT) and [EfficientFormer](https://github.com/snap-research/EfficientFormer) repositories. We thank the authors for their open-source implementation.
|
||||||
|
|
||||||
|
I'd like to express my sincere appreciation to [Victor Escorcia](https://github.com/escorciav) for measuring & reporting the latency of SwiftFormer on Android (Samsung Galaxy S23 Ultra, with Qualcomm Snapdragon 8 Gen 2). Check [SwiftFormer Meets Android](https://github.com/escorciav/SwiftFormer) for more details!
|
||||||
|
|
||||||
## Our Related Works
|
## Our Related Works
|
||||||
|
|
||||||
|
|||||||
58
dist_temporal_train.sh
Executable file
58
dist_temporal_train.sh
Executable file
@@ -0,0 +1,58 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
# Distributed training script for SwiftFormerTemporal
|
||||||
|
# Usage: ./dist_temporal_train.sh <DATA_PATH> <NUM_GPUS> [OPTIONS]
|
||||||
|
|
||||||
|
DATA_PATH=$1
|
||||||
|
NUM_GPUS=$2
|
||||||
|
|
||||||
|
# Shift arguments to pass remaining options to python script
|
||||||
|
shift 2
|
||||||
|
|
||||||
|
# Default parameters
|
||||||
|
MODEL=${MODEL:-"SwiftFormerTemporal_XS"}
|
||||||
|
BATCH_SIZE=${BATCH_SIZE:-128}
|
||||||
|
EPOCHS=${EPOCHS:-100}
|
||||||
|
# LR=${LR:-1e-3}
|
||||||
|
LR=${LR:-0.01}
|
||||||
|
OUTPUT_DIR=${OUTPUT_DIR:-"./temporal_output"}
|
||||||
|
|
||||||
|
echo "Starting distributed training with $NUM_GPUS GPUs"
|
||||||
|
echo "Data path: $DATA_PATH"
|
||||||
|
echo "Model: $MODEL"
|
||||||
|
echo "Batch size: $BATCH_SIZE"
|
||||||
|
echo "Epochs: $EPOCHS"
|
||||||
|
echo "Output dir: $OUTPUT_DIR"
|
||||||
|
|
||||||
|
# Check if torch.distributed.launch or torchrun should be used
|
||||||
|
# For newer PyTorch versions (>=1.9), torchrun is recommended
|
||||||
|
PYTHON_VERSION=$(python -c "import torch; print(torch.__version__)")
|
||||||
|
echo "PyTorch version: $PYTHON_VERSION"
|
||||||
|
|
||||||
|
# Use torchrun for newer PyTorch versions
|
||||||
|
if [[ "$PYTHON_VERSION" =~ ^2\. ]] || [[ "$PYTHON_VERSION" =~ ^1\.1[0-9]\. ]]; then
|
||||||
|
echo "Using torchrun (PyTorch >=1.10)"
|
||||||
|
torchrun --nproc_per_node=$NUM_GPUS --master_port=12345 main_temporal.py \
|
||||||
|
--data-path "$DATA_PATH" \
|
||||||
|
--model "$MODEL" \
|
||||||
|
--batch-size $BATCH_SIZE \
|
||||||
|
--epochs $EPOCHS \
|
||||||
|
--lr $LR \
|
||||||
|
--output-dir "$OUTPUT_DIR" \
|
||||||
|
"$@"
|
||||||
|
else
|
||||||
|
echo "Using torch.distributed.launch"
|
||||||
|
python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS --master_port=12345 --use_env main_temporal.py \
|
||||||
|
--data-path "$DATA_PATH" \
|
||||||
|
--model "$MODEL" \
|
||||||
|
--batch-size $BATCH_SIZE \
|
||||||
|
--epochs $EPOCHS \
|
||||||
|
--lr $LR \
|
||||||
|
--output-dir "$OUTPUT_DIR" \
|
||||||
|
"$@"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# For single-node multi-GPU training with specific options:
|
||||||
|
# --world-size 1 --rank 0 --dist-url 'tcp://localhost:12345'
|
||||||
|
|
||||||
|
echo "Training completed. Check logs in $OUTPUT_DIR"
|
||||||
484
evaluate_temporal.py
Normal file
484
evaluate_temporal.py
Normal file
@@ -0,0 +1,484 @@
|
|||||||
|
"""
|
||||||
|
评估脚本 for SwiftFormerTemporal frame prediction
|
||||||
|
输出预测图(注意反归一化)以及对应指标(mse&ssim&psnr)
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import pickle
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
import json
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from PIL import Image
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
import torch.backends.cudnn as cudnn
|
||||||
|
|
||||||
|
from util.video_dataset import VideoFrameDataset
|
||||||
|
from models.swiftformer_temporal import (
|
||||||
|
SwiftFormerTemporal_XS, SwiftFormerTemporal_S,
|
||||||
|
SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
|
||||||
|
)
|
||||||
|
|
||||||
|
# 导入SSIM和PSNR计算
|
||||||
|
from skimage.metrics import structural_similarity as ssim
|
||||||
|
from skimage.metrics import peak_signal_noise_ratio as psnr
|
||||||
|
import warnings
|
||||||
|
warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
|
|
||||||
|
def denormalize(tensor):
|
||||||
|
"""
|
||||||
|
将[-1, 1]范围的张量反归一化到[0, 255]范围
|
||||||
|
Args:
|
||||||
|
tensor: 形状为[B, C, H, W]或[C, H, W],值在[-1, 1]
|
||||||
|
Returns:
|
||||||
|
反归一化后的张量,值在[0, 255]
|
||||||
|
"""
|
||||||
|
# clip 到 [-1, 1] 范围
|
||||||
|
tensor = tensor.clamp(-1, 1)
|
||||||
|
|
||||||
|
# [-1, 1] -> [0, 1]
|
||||||
|
tensor = (tensor + 1) / 2
|
||||||
|
# [0, 1] -> [0, 255]
|
||||||
|
tensor = tensor * 255
|
||||||
|
return tensor.clamp(0, 255)
|
||||||
|
|
||||||
|
def minmax_denormalize(tensor):
|
||||||
|
tensor_min = tensor.min()
|
||||||
|
tensor_max = tensor.max()
|
||||||
|
tensor = (tensor - tensor_min) / (tensor_max - tensor_min)
|
||||||
|
# tensor = tensor*2-1
|
||||||
|
tensor = tensor*255
|
||||||
|
return tensor.clamp(0, 255)
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_metrics(pred, target, debug=False):
|
||||||
|
"""
|
||||||
|
计算MSE, SSIM, PSNR指标
|
||||||
|
Args:
|
||||||
|
pred: 预测图像,形状[H, W],值在[0, 255]
|
||||||
|
target: 目标图像,形状[H, W],值在[0, 255]
|
||||||
|
debug: 是否输出调试信息
|
||||||
|
Returns:
|
||||||
|
mse, ssim_value, psnr_value
|
||||||
|
"""
|
||||||
|
# 转换为numpy数组
|
||||||
|
pred_np = pred.cpu().numpy() if torch.is_tensor(pred) else pred
|
||||||
|
target_np = target.cpu().numpy() if torch.is_tensor(target) else target
|
||||||
|
|
||||||
|
# 确保是2D数组
|
||||||
|
if pred_np.ndim == 3:
|
||||||
|
pred_np = pred_np.squeeze(0)
|
||||||
|
if target_np.ndim == 3:
|
||||||
|
target_np = target_np.squeeze(0)
|
||||||
|
|
||||||
|
# if debug:
|
||||||
|
# print(f"[DEBUG] pred_np range: [{pred_np.min():.2f}, {pred_np.max():.2f}], mean: {pred_np.mean():.2f}")
|
||||||
|
# print(f"[DEBUG] target_np range: [{target_np.min():.2f}, {target_np.max():.2f}], mean: {target_np.mean():.2f}")
|
||||||
|
# print(f"[DEBUG] pred_np sample values (first 5): {pred_np.ravel()[:5]}")
|
||||||
|
|
||||||
|
mse = np.mean((pred_np - target_np) ** 2)
|
||||||
|
|
||||||
|
data_range = 255.0
|
||||||
|
ssim_value = ssim(pred_np, target_np, data_range=data_range)
|
||||||
|
|
||||||
|
psnr_value = psnr(target_np, pred_np, data_range=data_range)
|
||||||
|
|
||||||
|
return mse, ssim_value, psnr_value
|
||||||
|
|
||||||
|
|
||||||
|
def save_comparison_figure(input_frames, target_frame, pred_frame, save_path,
|
||||||
|
input_frame_indices=None, target_frame_index=None):
|
||||||
|
"""
|
||||||
|
保存对比图:输入帧、目标帧、预测帧
|
||||||
|
Args:
|
||||||
|
input_frames: 输入帧列表,每个形状为[H, W],值在[0, 255]
|
||||||
|
target_frame: 目标帧,形状[H, W],值在[0, 255]
|
||||||
|
pred_frame: 预测帧,形状[H, W],值在[0, 255]
|
||||||
|
save_path: 保存路径
|
||||||
|
input_frame_indices: 输入帧的索引列表(可选)
|
||||||
|
target_frame_index: 目标帧索引(可选)
|
||||||
|
"""
|
||||||
|
num_input = len(input_frames)
|
||||||
|
fig, axes = plt.subplots(1, num_input + 2, figsize=(4*(num_input+2), 4))
|
||||||
|
|
||||||
|
# 绘制输入帧
|
||||||
|
for i in range(num_input):
|
||||||
|
ax = axes[i]
|
||||||
|
ax.imshow(input_frames[i], cmap='gray')
|
||||||
|
if input_frame_indices is not None:
|
||||||
|
ax.set_title(f'Input Frame {input_frame_indices[i]}')
|
||||||
|
else:
|
||||||
|
ax.set_title(f'Input {i+1}')
|
||||||
|
ax.axis('off')
|
||||||
|
|
||||||
|
# 绘制目标帧
|
||||||
|
ax = axes[num_input]
|
||||||
|
ax.imshow(target_frame, cmap='gray')
|
||||||
|
if target_frame_index is not None:
|
||||||
|
ax.set_title(f'Target Frame {target_frame_index}')
|
||||||
|
else:
|
||||||
|
ax.set_title('Target')
|
||||||
|
ax.axis('off')
|
||||||
|
|
||||||
|
# 绘制预测帧
|
||||||
|
ax = axes[num_input + 1]
|
||||||
|
ax.imshow(pred_frame, cmap='gray')
|
||||||
|
ax.set_title('Predicted')
|
||||||
|
ax.axis('off')
|
||||||
|
|
||||||
|
#debug print
|
||||||
|
print(target_frame)
|
||||||
|
print(pred_frame)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_model(model, data_loader, device, args):
|
||||||
|
"""
|
||||||
|
评估模型并计算指标
|
||||||
|
Args:
|
||||||
|
model: 训练好的模型
|
||||||
|
data_loader: 数据加载器
|
||||||
|
device: 设备
|
||||||
|
args: 命令行参数
|
||||||
|
Returns:
|
||||||
|
metrics_dict: 包含所有指标的字典
|
||||||
|
sample_results: 示例结果用于可视化
|
||||||
|
"""
|
||||||
|
model.eval()
|
||||||
|
# model.train() # 临时使用训练模式
|
||||||
|
|
||||||
|
# 初始化指标累加器
|
||||||
|
total_mse = 0.0
|
||||||
|
total_ssim = 0.0
|
||||||
|
total_psnr = 0.0
|
||||||
|
total_samples = 0
|
||||||
|
|
||||||
|
# 存储示例结果用于可视化(使用蓄水池抽样随机选择)
|
||||||
|
sample_results = []
|
||||||
|
max_samples_to_save = args.num_samples_to_save
|
||||||
|
max_samples = args.max_samples
|
||||||
|
# 用于蓄水池抽样的计数器(已处理的样本数,不包括因max_samples限制而跳过的样本)
|
||||||
|
sample_count = 0
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(data_loader):
|
||||||
|
input_frames = input_frames.to(device, non_blocking=True)
|
||||||
|
target_frames = target_frames.to(device, non_blocking=True)
|
||||||
|
|
||||||
|
# 前向传播
|
||||||
|
pred_frames = model(input_frames)
|
||||||
|
|
||||||
|
# 反归一化用于指标计算
|
||||||
|
# pred_denorm = minmax_denormalize(pred_frames) # [B, 1, H, W]
|
||||||
|
pred_denorm = denormalize(pred_frames)
|
||||||
|
target_denorm = denormalize(target_frames) # [B, 1, H, W]
|
||||||
|
|
||||||
|
batch_size = input_frames.size(0)
|
||||||
|
|
||||||
|
# 计算每个样本的指标
|
||||||
|
for i in range(batch_size):
|
||||||
|
# 检查是否达到最大样本数限制
|
||||||
|
if max_samples is not None and total_samples >= max_samples:
|
||||||
|
break
|
||||||
|
|
||||||
|
pred_i = pred_denorm[i] # [1, H, W]
|
||||||
|
target_i = target_denorm[i] # [1, H, W]
|
||||||
|
|
||||||
|
# 对第一个样本启用调试
|
||||||
|
debug_mode = (batch_idx == 0 and i == 0 and total_samples == 0)
|
||||||
|
# if debug_mode:
|
||||||
|
# print(f"[DEBUG] Raw pred_frames range: [{pred_frames.min():.4f}, {pred_frames.max():.4f}], mean: {pred_frames.mean():.4f}")
|
||||||
|
# print(f"[DEBUG] Raw target_frames range: [{target_frames.min():.4f}, {target_frames.max():.4f}], mean: {target_frames.mean():.4f}")
|
||||||
|
# print(f"[DEBUG] Pred_denorm range: [{pred_denorm.min():.2f}, {pred_denorm.max():.2f}], mean: {pred_denorm.mean():.2f}")
|
||||||
|
# print(f"[DEBUG] Target_denorm range: [{target_denorm.min():.2f}, {target_denorm.max():.2f}], mean: {target_denorm.mean():.2f}")
|
||||||
|
|
||||||
|
mse, ssim_value, psnr_value = calculate_metrics(pred_i, target_i, debug=False)
|
||||||
|
|
||||||
|
total_mse += mse
|
||||||
|
total_ssim += ssim_value
|
||||||
|
total_psnr += psnr_value
|
||||||
|
total_samples += 1
|
||||||
|
sample_count += 1
|
||||||
|
|
||||||
|
# 构建样本数据字典
|
||||||
|
input_denorm = denormalize(input_frames[i]) # [num_frames, H, W]
|
||||||
|
# 分离输入帧
|
||||||
|
input_frames_list = []
|
||||||
|
for j in range(args.num_frames):
|
||||||
|
input_frame_j = input_denorm[j].squeeze(0) # [H, W]
|
||||||
|
input_frames_list.append(input_frame_j.cpu().numpy())
|
||||||
|
|
||||||
|
sample_data = {
|
||||||
|
'input_frames': input_frames_list,
|
||||||
|
'target_frame': target_i.squeeze(0).cpu().numpy(),
|
||||||
|
'pred_frame': pred_i.squeeze(0).cpu().numpy(),
|
||||||
|
'metrics': {
|
||||||
|
'mse': mse,
|
||||||
|
'ssim': ssim_value,
|
||||||
|
'psnr': psnr_value
|
||||||
|
},
|
||||||
|
'batch_idx': batch_idx,
|
||||||
|
'sample_idx': i
|
||||||
|
}
|
||||||
|
|
||||||
|
# 蓄水池抽样 (Reservoir Sampling)
|
||||||
|
if sample_count <= max_samples_to_save:
|
||||||
|
# 蓄水池未满,直接加入
|
||||||
|
sample_results.append(sample_data)
|
||||||
|
else:
|
||||||
|
# 以 max_samples_to_save / sample_count 的概率替换蓄水池中的一个随机位置
|
||||||
|
r = random.randint(0, sample_count - 1)
|
||||||
|
if r < max_samples_to_save:
|
||||||
|
sample_results[r] = sample_data
|
||||||
|
|
||||||
|
# 检查是否达到最大样本数限制
|
||||||
|
if max_samples is not None and total_samples >= max_samples:
|
||||||
|
print(f"达到最大样本数限制: {max_samples}")
|
||||||
|
break
|
||||||
|
|
||||||
|
# 进度打印
|
||||||
|
if (batch_idx + 1) % 10 == 0:
|
||||||
|
print(f'Processed {batch_idx + 1} batches, {total_samples} samples')
|
||||||
|
|
||||||
|
# 计算平均指标
|
||||||
|
if total_samples > 0:
|
||||||
|
avg_mse = float(total_mse / total_samples)
|
||||||
|
avg_ssim = float(total_ssim / total_samples)
|
||||||
|
avg_psnr = float(total_psnr / total_samples)
|
||||||
|
else:
|
||||||
|
avg_mse = avg_ssim = avg_psnr = 0.0
|
||||||
|
|
||||||
|
metrics_dict = {
|
||||||
|
'mse': avg_mse,
|
||||||
|
'ssim': avg_ssim,
|
||||||
|
'psnr': avg_psnr,
|
||||||
|
'num_samples': total_samples
|
||||||
|
}
|
||||||
|
|
||||||
|
return metrics_dict, sample_results
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
print("评估参数:", args)
|
||||||
|
|
||||||
|
device = torch.device(args.device)
|
||||||
|
|
||||||
|
# 设置随机种子
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
random.seed(args.seed)
|
||||||
|
|
||||||
|
cudnn.benchmark = True
|
||||||
|
|
||||||
|
# 构建数据集
|
||||||
|
print("构建数据集...")
|
||||||
|
dataset_val = VideoFrameDataset(
|
||||||
|
root_dir=args.data_path,
|
||||||
|
num_frames=args.num_frames,
|
||||||
|
frame_size=args.frame_size,
|
||||||
|
is_train=False,
|
||||||
|
max_interval=args.max_interval
|
||||||
|
)
|
||||||
|
|
||||||
|
data_loader_val = torch.utils.data.DataLoader(
|
||||||
|
dataset_val,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
num_workers=args.num_workers,
|
||||||
|
pin_memory=args.pin_mem,
|
||||||
|
shuffle=False,
|
||||||
|
drop_last=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建模型
|
||||||
|
print(f"创建模型: {args.model}")
|
||||||
|
model_kwargs = {
|
||||||
|
'num_frames': args.num_frames,
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.model == 'SwiftFormerTemporal_XS':
|
||||||
|
model = SwiftFormerTemporal_XS(**model_kwargs)
|
||||||
|
elif args.model == 'SwiftFormerTemporal_S':
|
||||||
|
model = SwiftFormerTemporal_S(**model_kwargs)
|
||||||
|
elif args.model == 'SwiftFormerTemporal_L1':
|
||||||
|
model = SwiftFormerTemporal_L1(**model_kwargs)
|
||||||
|
elif args.model == 'SwiftFormerTemporal_L3':
|
||||||
|
model = SwiftFormerTemporal_L3(**model_kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"未知模型: {args.model}")
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
# 加载检查点
|
||||||
|
if args.resume:
|
||||||
|
print(f"加载检查点: {args.resume}")
|
||||||
|
try:
|
||||||
|
# 尝试使用weights_only=False加载(PyTorch 2.6+需要)
|
||||||
|
checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
|
||||||
|
except (pickle.UnpicklingError, TypeError) as e:
|
||||||
|
print(f"使用weights_only=False加载失败: {e}")
|
||||||
|
print("尝试使用torch.serialization.add_safe_globals...")
|
||||||
|
|
||||||
|
# 处理状态字典(可能包含'module.'前缀)
|
||||||
|
if 'model' in checkpoint:
|
||||||
|
state_dict = checkpoint['model']
|
||||||
|
elif 'state_dict' in checkpoint:
|
||||||
|
state_dict = checkpoint['state_dict']
|
||||||
|
else:
|
||||||
|
state_dict = checkpoint
|
||||||
|
|
||||||
|
# 移除'module.'前缀(如果存在)
|
||||||
|
if hasattr(model, 'module'):
|
||||||
|
model.module.load_state_dict(state_dict)
|
||||||
|
else:
|
||||||
|
# 如果状态字典有'module.'前缀但模型没有,需要移除前缀
|
||||||
|
if any(key.startswith('module.') for key in state_dict.keys()):
|
||||||
|
from collections import OrderedDict
|
||||||
|
new_state_dict = OrderedDict()
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
if k.startswith('module.'):
|
||||||
|
new_state_dict[k[7:]] = v
|
||||||
|
else:
|
||||||
|
new_state_dict[k] = v
|
||||||
|
state_dict = new_state_dict
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
print(f"检查点加载成功,epoch: {checkpoint.get('epoch', 'unknown')}")
|
||||||
|
else:
|
||||||
|
print("警告: 未提供检查点路径,使用随机初始化的模型")
|
||||||
|
|
||||||
|
# 创建输出目录
|
||||||
|
output_dir = Path(args.output_dir)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# 评估模型
|
||||||
|
print("开始评估...")
|
||||||
|
metrics, sample_results = evaluate_model(model, data_loader_val, device, args)
|
||||||
|
|
||||||
|
# 打印指标
|
||||||
|
print("\n" + "="*50)
|
||||||
|
print("评估结果:")
|
||||||
|
print(f"MSE: {metrics['mse']:.6f}")
|
||||||
|
print(f"SSIM: {metrics['ssim']:.6f}")
|
||||||
|
print(f"PSNR: {metrics['psnr']:.6f} dB")
|
||||||
|
print(f"样本数量: {metrics['num_samples']}")
|
||||||
|
print("="*50)
|
||||||
|
|
||||||
|
# 保存指标到JSON文件
|
||||||
|
metrics_file = output_dir / 'evaluation_metrics.json'
|
||||||
|
with open(metrics_file, 'w') as f:
|
||||||
|
json.dump(metrics, f, indent=4)
|
||||||
|
print(f"指标已保存到: {metrics_file}")
|
||||||
|
|
||||||
|
# 保存示例可视化
|
||||||
|
if sample_results:
|
||||||
|
print(f"\n保存 {len(sample_results)} 个示例可视化...")
|
||||||
|
samples_dir = output_dir / 'sample_predictions'
|
||||||
|
samples_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
for i, sample in enumerate(sample_results):
|
||||||
|
save_path = samples_dir / f'sample_{i:03d}.png'
|
||||||
|
|
||||||
|
# 生成输入帧索引(假设连续)
|
||||||
|
input_frame_indices = list(range(1, args.num_frames + 1))
|
||||||
|
target_frame_index = args.num_frames + 1
|
||||||
|
|
||||||
|
save_comparison_figure(
|
||||||
|
sample['input_frames'],
|
||||||
|
sample['target_frame'],
|
||||||
|
sample['pred_frame'],
|
||||||
|
save_path,
|
||||||
|
input_frame_indices=input_frame_indices,
|
||||||
|
target_frame_index=target_frame_index
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存该样本的指标
|
||||||
|
sample_metrics_file = samples_dir / f'sample_{i:03d}_metrics.txt'
|
||||||
|
with open(sample_metrics_file, 'w') as f:
|
||||||
|
f.write(f"Sample {i} (batch {sample['batch_idx']}, idx {sample['sample_idx']})\n")
|
||||||
|
f.write(f"MSE: {sample['metrics']['mse']:.6f}\n")
|
||||||
|
f.write(f"SSIM: {sample['metrics']['ssim']:.6f}\n")
|
||||||
|
f.write(f"PSNR: {sample['metrics']['psnr']:.6f} dB\n")
|
||||||
|
|
||||||
|
print(f"示例可视化已保存到: {samples_dir}")
|
||||||
|
|
||||||
|
# 生成汇总报告
|
||||||
|
report_file = output_dir / 'evaluation_report.txt'
|
||||||
|
with open(report_file, 'w') as f:
|
||||||
|
f.write("SwiftFormerTemporal 帧预测评估报告\n")
|
||||||
|
f.write("="*50 + "\n")
|
||||||
|
f.write(f"模型: {args.model}\n")
|
||||||
|
f.write(f"检查点: {args.resume}\n")
|
||||||
|
f.write(f"数据集: {args.data_path}\n")
|
||||||
|
f.write(f"输入帧数: {args.num_frames}\n")
|
||||||
|
f.write(f"帧大小: {args.frame_size}\n")
|
||||||
|
f.write(f"批次大小: {args.batch_size}\n")
|
||||||
|
f.write(f"样本总数: {metrics['num_samples']}\n\n")
|
||||||
|
|
||||||
|
f.write("评估指标:\n")
|
||||||
|
f.write(f" MSE: {metrics['mse']:.6f}\n")
|
||||||
|
f.write(f" SSIM: {metrics['ssim']:.6f}\n")
|
||||||
|
f.write(f" PSNR: {metrics['psnr']:.6f} dB\n")
|
||||||
|
|
||||||
|
print(f"评估报告已保存到: {report_file}")
|
||||||
|
print("\n评估完成!")
|
||||||
|
|
||||||
|
|
||||||
|
def get_args_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
'SwiftFormerTemporal 评估脚本', add_help=False)
|
||||||
|
|
||||||
|
# 数据集参数
|
||||||
|
parser.add_argument('--data-path', default='./videos', type=str,
|
||||||
|
help='视频数据集路径')
|
||||||
|
parser.add_argument('--num-frames', default=3, type=int,
|
||||||
|
help='输入帧数 (T)')
|
||||||
|
parser.add_argument('--frame-size', default=224, type=int,
|
||||||
|
help='输入帧大小')
|
||||||
|
parser.add_argument('--max-interval', default=4, type=int,
|
||||||
|
help='连续帧之间的最大间隔')
|
||||||
|
|
||||||
|
# 模型参数
|
||||||
|
parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
|
||||||
|
help='要评估的模型名称')
|
||||||
|
|
||||||
|
# 评估参数
|
||||||
|
parser.add_argument('--batch-size', default=16, type=int,
|
||||||
|
help='评估批次大小')
|
||||||
|
parser.add_argument('--num-samples-to-save', default=10, type=int,
|
||||||
|
help='保存可视化的样本数量')
|
||||||
|
parser.add_argument('--max-samples', default=None, type=int,
|
||||||
|
help='最大评估样本数(None表示全部)')
|
||||||
|
|
||||||
|
# 系统参数
|
||||||
|
parser.add_argument('--output-dir', default='./evaluation_results',
|
||||||
|
help='保存结果的路径')
|
||||||
|
parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu',
|
||||||
|
help='使用的设备')
|
||||||
|
parser.add_argument('--seed', default=0, type=int)
|
||||||
|
parser.add_argument('--resume', default='', help='检查点路径')
|
||||||
|
parser.add_argument('--num-workers', default=4, type=int)
|
||||||
|
parser.add_argument('--pin-mem', action='store_true',
|
||||||
|
help='在DataLoader中固定CPU内存')
|
||||||
|
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem')
|
||||||
|
parser.set_defaults(pin_mem=True)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
'SwiftFormerTemporal 评估', parents=[get_args_parser()])
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# 确保输出目录存在
|
||||||
|
if args.output_dir:
|
||||||
|
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
main(args)
|
||||||
550
main_temporal.py
Normal file
550
main_temporal.py
Normal file
@@ -0,0 +1,550 @@
|
|||||||
|
"""
|
||||||
|
Main training script for SwiftFormerTemporal frame prediction
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import datetime
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.backends.cudnn as cudnn
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from timm.scheduler import create_scheduler
|
||||||
|
from timm.optim import create_optimizer
|
||||||
|
from timm.utils import NativeScaler, get_state_dict, ModelEma
|
||||||
|
|
||||||
|
from util import *
|
||||||
|
from models import *
|
||||||
|
from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
|
||||||
|
from util.video_dataset import VideoFrameDataset
|
||||||
|
# from util.frame_losses import MultiTaskLoss
|
||||||
|
|
||||||
|
# Try to import TensorBoard
|
||||||
|
try:
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
TENSORBOARD_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
TENSORBOARD_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
def get_args_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
'SwiftFormerTemporal training script', add_help=False)
|
||||||
|
|
||||||
|
# Dataset parameters
|
||||||
|
parser.add_argument('--data-path', default='./videos', type=str,
|
||||||
|
help='Path to video dataset')
|
||||||
|
parser.add_argument('--dataset-type', default='video', choices=['video', 'synthetic'],
|
||||||
|
type=str, help='Dataset type')
|
||||||
|
parser.add_argument('--num-frames', default=3, type=int,
|
||||||
|
help='Number of input frames (T)')
|
||||||
|
parser.add_argument('--frame-size', default=224, type=int,
|
||||||
|
help='Input frame size')
|
||||||
|
parser.add_argument('--max-interval', default=10, type=int,
|
||||||
|
help='Maximum interval between consecutive frames')
|
||||||
|
|
||||||
|
# Model parameters
|
||||||
|
parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
|
||||||
|
help='Name of model to train')
|
||||||
|
|
||||||
|
# Training parameters
|
||||||
|
parser.add_argument('--batch-size', default=32, type=int)
|
||||||
|
parser.add_argument('--epochs', default=100, type=int)
|
||||||
|
|
||||||
|
# Optimizer parameters (required by timm's create_optimizer)
|
||||||
|
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
|
||||||
|
help='Optimizer (default: "adamw"')
|
||||||
|
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
|
||||||
|
help='Optimizer Epsilon (default: 1e-8)')
|
||||||
|
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
|
||||||
|
help='Optimizer Betas (default: None, use opt default)')
|
||||||
|
parser.add_argument('--clip-grad', type=float, default=0.01, metavar='NORM',
|
||||||
|
help='Clip gradient norm (default: None, no clipping)')
|
||||||
|
parser.add_argument('--clip-mode', type=str, default='agc',
|
||||||
|
help='Gradient clipping mode. One of ("norm", "value", "agc")')
|
||||||
|
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
|
||||||
|
help='SGD momentum (default: 0.9)')
|
||||||
|
parser.add_argument('--weight-decay', type=float, default=0.05,
|
||||||
|
help='weight decay (default: 0.05)')
|
||||||
|
parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
|
||||||
|
help='learning rate (default: 1e-3)')
|
||||||
|
|
||||||
|
# Learning rate schedule parameters (required by timm's create_scheduler)
|
||||||
|
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
|
||||||
|
help='LR scheduler (default: "cosine"')
|
||||||
|
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
|
||||||
|
help='learning rate noise on/off epoch percentages')
|
||||||
|
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
|
||||||
|
help='learning rate noise limit percent (default: 0.67)')
|
||||||
|
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
|
||||||
|
help='learning rate noise std-dev (default: 1.0)')
|
||||||
|
parser.add_argument('--warmup-lr', type=float, default=1e-3, metavar='LR',
|
||||||
|
help='warmup learning rate (default: 1e-6)')
|
||||||
|
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
|
||||||
|
help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
|
||||||
|
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
|
||||||
|
help='epoch interval to decay LR')
|
||||||
|
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
|
||||||
|
help='epochs to warmup LR, if scheduler supports')
|
||||||
|
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
|
||||||
|
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
|
||||||
|
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
|
||||||
|
help='patience epochs for Plateau LR scheduler (default: 10')
|
||||||
|
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
|
||||||
|
help='LR decay rate (default: 0.1)')
|
||||||
|
|
||||||
|
# Loss parameters
|
||||||
|
parser.add_argument('--frame-weight', type=float, default=1.0,
|
||||||
|
help='Weight for frame prediction loss')
|
||||||
|
parser.add_argument('--contrastive-weight', type=float, default=0.1,
|
||||||
|
help='Weight for contrastive loss')
|
||||||
|
# parser.add_argument('--l1-weight', type=float, default=1.0,
|
||||||
|
# help='Weight for L1 loss')
|
||||||
|
# parser.add_argument('--ssim-weight', type=float, default=0.1,
|
||||||
|
# help='Weight for SSIM loss')
|
||||||
|
parser.add_argument('--no-contrastive', action='store_true',
|
||||||
|
help='Disable contrastive loss')
|
||||||
|
parser.add_argument('--no-ssim', action='store_true',
|
||||||
|
help='Disable SSIM loss')
|
||||||
|
|
||||||
|
# System parameters
|
||||||
|
parser.add_argument('--output-dir', default='./output',
|
||||||
|
help='path where to save, empty for no saving')
|
||||||
|
parser.add_argument('--device', default='cuda',
|
||||||
|
help='device to use for training / testing')
|
||||||
|
parser.add_argument('--seed', default=0, type=int)
|
||||||
|
parser.add_argument('--resume', default='', help='resume from checkpoint')
|
||||||
|
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
|
||||||
|
help='start epoch')
|
||||||
|
parser.add_argument('--eval', action='store_true',
|
||||||
|
help='Perform evaluation only')
|
||||||
|
parser.add_argument('--num-workers', default=16, type=int)
|
||||||
|
parser.add_argument('--pin-mem', action='store_true',
|
||||||
|
help='Pin CPU memory in DataLoader')
|
||||||
|
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem')
|
||||||
|
parser.set_defaults(pin_mem=True)
|
||||||
|
|
||||||
|
# Distributed training
|
||||||
|
parser.add_argument('--world-size', default=1, type=int,
|
||||||
|
help='number of distributed processes')
|
||||||
|
parser.add_argument('--dist-url', default='env://',
|
||||||
|
help='url used to set up distributed training')
|
||||||
|
|
||||||
|
# TensorBoard logging
|
||||||
|
parser.add_argument('--tensorboard-logdir', default='./runs',
|
||||||
|
type=str, help='TensorBoard log directory')
|
||||||
|
parser.add_argument('--log-images', action='store_true',
|
||||||
|
help='Log sample images to TensorBoard')
|
||||||
|
parser.add_argument('--image-log-freq', default=100, type=int,
|
||||||
|
help='Frequency of logging images (in iterations)')
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def build_dataset(is_train, args):
|
||||||
|
"""Build video frame dataset"""
|
||||||
|
dataset = VideoFrameDataset(
|
||||||
|
root_dir=args.data_path,
|
||||||
|
num_frames=args.num_frames,
|
||||||
|
frame_size=args.frame_size,
|
||||||
|
is_train=is_train,
|
||||||
|
max_interval=args.max_interval
|
||||||
|
)
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
utils.init_distributed_mode(args)
|
||||||
|
print(args)
|
||||||
|
|
||||||
|
device = torch.device(args.device)
|
||||||
|
|
||||||
|
# Fix the seed for reproducibility
|
||||||
|
seed = args.seed + utils.get_rank()
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
|
||||||
|
cudnn.benchmark = True
|
||||||
|
|
||||||
|
# Build datasets
|
||||||
|
dataset_train = build_dataset(is_train=True, args=args)
|
||||||
|
dataset_val = build_dataset(is_train=False, args=args)
|
||||||
|
|
||||||
|
# Create samplers
|
||||||
|
if args.distributed:
|
||||||
|
sampler_train = torch.utils.data.DistributedSampler(dataset_train)
|
||||||
|
sampler_val = torch.utils.data.DistributedSampler(dataset_val, shuffle=False)
|
||||||
|
else:
|
||||||
|
sampler_train = torch.utils.data.RandomSampler(dataset_train)
|
||||||
|
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
|
||||||
|
|
||||||
|
data_loader_train = torch.utils.data.DataLoader(
|
||||||
|
dataset_train, sampler=sampler_train,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
num_workers=args.num_workers,
|
||||||
|
pin_memory=args.pin_mem,
|
||||||
|
drop_last=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
data_loader_val = torch.utils.data.DataLoader(
|
||||||
|
dataset_val, sampler=sampler_val,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
num_workers=args.num_workers,
|
||||||
|
pin_memory=args.pin_mem,
|
||||||
|
drop_last=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create model
|
||||||
|
print(f"Creating model: {args.model}")
|
||||||
|
model_kwargs = {
|
||||||
|
'num_frames': args.num_frames,
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.model == 'SwiftFormerTemporal_XS':
|
||||||
|
model = SwiftFormerTemporal_XS(**model_kwargs)
|
||||||
|
elif args.model == 'SwiftFormerTemporal_S':
|
||||||
|
model = SwiftFormerTemporal_S(**model_kwargs)
|
||||||
|
elif args.model == 'SwiftFormerTemporal_L1':
|
||||||
|
model = SwiftFormerTemporal_L1(**model_kwargs)
|
||||||
|
elif args.model == 'SwiftFormerTemporal_L3':
|
||||||
|
model = SwiftFormerTemporal_L3(**model_kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown model: {args.model}")
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
# Model EMA
|
||||||
|
model_ema = None
|
||||||
|
if hasattr(args, 'model_ema') and args.model_ema:
|
||||||
|
model_ema = ModelEma(
|
||||||
|
model,
|
||||||
|
decay=args.model_ema_decay if hasattr(args, 'model_ema_decay') else 0.9999,
|
||||||
|
device='cpu' if hasattr(args, 'model_ema_force_cpu') and args.model_ema_force_cpu else '',
|
||||||
|
resume='')
|
||||||
|
|
||||||
|
# Distributed training
|
||||||
|
model_without_ddp = model
|
||||||
|
if args.distributed:
|
||||||
|
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
||||||
|
model_without_ddp = model.module
|
||||||
|
|
||||||
|
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
print(f'Number of parameters: {n_parameters}')
|
||||||
|
|
||||||
|
# Create optimizer
|
||||||
|
optimizer = create_optimizer(args, model_without_ddp)
|
||||||
|
|
||||||
|
# Create loss scaler
|
||||||
|
loss_scaler = NativeScaler()
|
||||||
|
|
||||||
|
# Create scheduler
|
||||||
|
lr_scheduler, _ = create_scheduler(args, optimizer)
|
||||||
|
|
||||||
|
# Create loss function - simple MSE for Y channel prediction
|
||||||
|
class MSELossWrapper(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.mse = nn.MSELoss()
|
||||||
|
|
||||||
|
def forward(self, pred_frame, target_frame, temporal_indices=None):
|
||||||
|
loss = self.mse(pred_frame, target_frame)
|
||||||
|
loss_dict = {'mse': loss}
|
||||||
|
return loss, loss_dict
|
||||||
|
|
||||||
|
criterion = MSELossWrapper()
|
||||||
|
|
||||||
|
# Resume from checkpoint
|
||||||
|
output_dir = Path(args.output_dir)
|
||||||
|
if args.resume:
|
||||||
|
if args.resume.startswith('https'):
|
||||||
|
checkpoint = torch.hub.load_state_dict_from_url(
|
||||||
|
args.resume, map_location='cpu', check_hash=True)
|
||||||
|
else:
|
||||||
|
checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
|
||||||
|
|
||||||
|
model_without_ddp.load_state_dict(checkpoint['model'])
|
||||||
|
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
|
||||||
|
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||||
|
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
|
||||||
|
args.start_epoch = checkpoint['epoch'] + 1
|
||||||
|
if model_ema is not None:
|
||||||
|
utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
|
||||||
|
if 'scaler' in checkpoint:
|
||||||
|
loss_scaler.load_state_dict(checkpoint['scaler'])
|
||||||
|
|
||||||
|
# Initialize TensorBoard writer
|
||||||
|
writer = None
|
||||||
|
if TENSORBOARD_AVAILABLE and utils.is_main_process():
|
||||||
|
from datetime import datetime
|
||||||
|
# Create log directory with timestamp
|
||||||
|
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||||
|
log_dir = os.path.join(args.tensorboard_logdir, f"exp_{timestamp}")
|
||||||
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
|
writer = SummaryWriter(log_dir=log_dir)
|
||||||
|
print(f"TensorBoard logs will be saved to: {log_dir}")
|
||||||
|
print(f"To view logs, run: tensorboard --logdir={log_dir}")
|
||||||
|
elif not TENSORBOARD_AVAILABLE and utils.is_main_process():
|
||||||
|
print("Warning: TensorBoard not available. Install tensorboard or tensorboardX.")
|
||||||
|
print("Training will continue without TensorBoard logging.")
|
||||||
|
|
||||||
|
if args.eval:
|
||||||
|
test_stats = evaluate(data_loader_val, model, criterion, device)
|
||||||
|
print(f"Test stats: {test_stats}")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"Start training for {args.epochs} epochs")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Global step counter for TensorBoard
|
||||||
|
global_step = 0
|
||||||
|
|
||||||
|
for epoch in range(args.start_epoch, args.epochs):
|
||||||
|
if args.distributed:
|
||||||
|
data_loader_train.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
|
train_stats, global_step = train_one_epoch(
|
||||||
|
model, criterion, data_loader_train,
|
||||||
|
optimizer, device, epoch, loss_scaler, args.clip_grad, args.clip_mode,
|
||||||
|
model_ema=model_ema, writer=writer,
|
||||||
|
global_step=global_step, args=args
|
||||||
|
)
|
||||||
|
|
||||||
|
lr_scheduler.step(epoch)
|
||||||
|
|
||||||
|
# Save checkpoint
|
||||||
|
if args.output_dir and (epoch % 1 == 0 or epoch == args.epochs - 1):
|
||||||
|
checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth'
|
||||||
|
utils.save_on_master({
|
||||||
|
'model': model_without_ddp.state_dict(),
|
||||||
|
'optimizer': optimizer.state_dict(),
|
||||||
|
'lr_scheduler': lr_scheduler.state_dict(),
|
||||||
|
'epoch': epoch,
|
||||||
|
'model_ema': get_state_dict(model_ema) if model_ema else None,
|
||||||
|
'scaler': loss_scaler.state_dict(),
|
||||||
|
'args': args,
|
||||||
|
}, checkpoint_path)
|
||||||
|
|
||||||
|
# Evaluate
|
||||||
|
if epoch % 5 == 0 or epoch == args.epochs - 1:
|
||||||
|
test_stats = evaluate(data_loader_val, model, criterion, device, writer=writer, epoch=epoch)
|
||||||
|
print(f"Epoch {epoch}: Test stats: {test_stats}")
|
||||||
|
|
||||||
|
# Log stats to text file
|
||||||
|
log_stats = {
|
||||||
|
**{f'train_{k}': v for k, v in train_stats.items()},
|
||||||
|
**{f'test_{k}': v for k, v in test_stats.items()},
|
||||||
|
'epoch': epoch,
|
||||||
|
'n_parameters': n_parameters
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.output_dir and utils.is_main_process():
|
||||||
|
with (output_dir / "log.txt").open("a") as f:
|
||||||
|
f.write(json.dumps(log_stats) + "\n")
|
||||||
|
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||||
|
print(f'Training time {total_time_str}')
|
||||||
|
|
||||||
|
# Close TensorBoard writer
|
||||||
|
if writer is not None:
|
||||||
|
writer.close()
|
||||||
|
print(f"TensorBoard logs saved to: {writer.log_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
|
||||||
|
clip_grad=0.01, clip_mode='norm', model_ema=None, writer=None,
|
||||||
|
global_step=0, args=None, **kwargs):
|
||||||
|
model.train()
|
||||||
|
metric_logger = utils.MetricLogger(delimiter=" ")
|
||||||
|
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
||||||
|
header = f'Epoch: [{epoch}]'
|
||||||
|
print_freq = 10
|
||||||
|
|
||||||
|
# 添加诊断指标
|
||||||
|
metric_logger.add_meter('pred_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
||||||
|
metric_logger.add_meter('pred_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
||||||
|
metric_logger.add_meter('grad_norm', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
||||||
|
|
||||||
|
for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(
|
||||||
|
metric_logger.log_every(data_loader, print_freq, header)):
|
||||||
|
|
||||||
|
input_frames = input_frames.to(device, non_blocking=True)
|
||||||
|
target_frames = target_frames.to(device, non_blocking=True)
|
||||||
|
temporal_indices = temporal_indices.to(device, non_blocking=True)
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
with torch.amp.autocast(device_type='cuda'):
|
||||||
|
pred_frames = model(input_frames)
|
||||||
|
loss, loss_dict = criterion(
|
||||||
|
pred_frames, target_frames,
|
||||||
|
temporal_indices
|
||||||
|
)
|
||||||
|
|
||||||
|
loss_value = loss.item()
|
||||||
|
if not torch.isfinite(torch.tensor(loss_value)):
|
||||||
|
print(f"Loss is {loss_value}, stopping training")
|
||||||
|
raise ValueError(f"Loss is {loss_value}")
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
|
||||||
|
parameters=model.parameters())
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
if model_ema is not None:
|
||||||
|
model_ema.update(model)
|
||||||
|
|
||||||
|
# 计算诊断指标
|
||||||
|
pred_mean = pred_frames.mean().item()
|
||||||
|
pred_std = pred_frames.std().item()
|
||||||
|
|
||||||
|
# 计算梯度范数
|
||||||
|
total_grad_norm = 0.0
|
||||||
|
for param in model.parameters():
|
||||||
|
if param.grad is not None:
|
||||||
|
total_grad_norm += param.grad.norm().item()
|
||||||
|
|
||||||
|
# 记录诊断指标
|
||||||
|
metric_logger.update(pred_mean=pred_mean)
|
||||||
|
metric_logger.update(pred_std=pred_std)
|
||||||
|
metric_logger.update(grad_norm=total_grad_norm)
|
||||||
|
|
||||||
|
# # 每50个批次打印一次BatchNorm统计
|
||||||
|
if batch_idx % 50 == 0:
|
||||||
|
print(f"[诊断] 批次 {batch_idx}: 预测均值={pred_mean:.4f}, 预测标准差={pred_std:.4f}, 梯度范数={total_grad_norm:.4f}")
|
||||||
|
# # 检查一个BatchNorm层的运行统计
|
||||||
|
# for name, module in model.named_modules():
|
||||||
|
# if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
|
||||||
|
# print(f"[诊断] {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
|
||||||
|
# break
|
||||||
|
|
||||||
|
# Log to TensorBoard
|
||||||
|
if writer is not None:
|
||||||
|
# Log scalar metrics every iteration
|
||||||
|
writer.add_scalar('train/loss', loss_value, global_step)
|
||||||
|
writer.add_scalar('train/lr', optimizer.param_groups[0]["lr"], global_step)
|
||||||
|
|
||||||
|
# Log individual loss components
|
||||||
|
for k, v in loss_dict.items():
|
||||||
|
if torch.is_tensor(v):
|
||||||
|
writer.add_scalar(f'train/{k}', v.item(), global_step)
|
||||||
|
else:
|
||||||
|
writer.add_scalar(f'train/{k}', v, global_step)
|
||||||
|
|
||||||
|
# Log diagnostic metrics
|
||||||
|
writer.add_scalar('train/pred_mean', pred_mean, global_step)
|
||||||
|
writer.add_scalar('train/pred_std', pred_std, global_step)
|
||||||
|
writer.add_scalar('train/grad_norm', total_grad_norm, global_step)
|
||||||
|
|
||||||
|
# Log images periodically
|
||||||
|
if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0:
|
||||||
|
with torch.no_grad():
|
||||||
|
# Take first sample from batch for visualization
|
||||||
|
pred_vis = model(input_frames[:1])
|
||||||
|
# Convert to appropriate format for TensorBoard
|
||||||
|
# Assuming frames are in [B, C, H, W] format
|
||||||
|
writer.add_images('train/input', input_frames[:1], global_step)
|
||||||
|
writer.add_images('train/target', target_frames[:1], global_step)
|
||||||
|
writer.add_images('train/predicted', pred_vis[:1], global_step)
|
||||||
|
|
||||||
|
# Update metrics
|
||||||
|
metric_logger.update(loss=loss_value)
|
||||||
|
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
|
||||||
|
for k, v in loss_dict.items():
|
||||||
|
metric_logger.update(**{k: v.item() if torch.is_tensor(v) else v})
|
||||||
|
|
||||||
|
global_step += 1
|
||||||
|
|
||||||
|
metric_logger.synchronize_between_processes()
|
||||||
|
print("Averaged stats:", metric_logger)
|
||||||
|
|
||||||
|
# Log epoch-level metrics
|
||||||
|
if writer is not None:
|
||||||
|
for k, meter in metric_logger.meters.items():
|
||||||
|
writer.add_scalar(f'train_epoch/{k}', meter.global_avg, epoch)
|
||||||
|
|
||||||
|
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, global_step
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def evaluate(data_loader, model, criterion, device, writer=None, epoch=0):
|
||||||
|
model.eval()
|
||||||
|
metric_logger = utils.MetricLogger(delimiter=" ")
|
||||||
|
header = 'Test:'
|
||||||
|
|
||||||
|
# 添加诊断指标
|
||||||
|
metric_logger.add_meter('pred_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
||||||
|
metric_logger.add_meter('pred_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
||||||
|
metric_logger.add_meter('target_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
||||||
|
metric_logger.add_meter('target_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
||||||
|
|
||||||
|
for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(metric_logger.log_every(data_loader, 10, header)):
|
||||||
|
input_frames = input_frames.to(device, non_blocking=True)
|
||||||
|
target_frames = target_frames.to(device, non_blocking=True)
|
||||||
|
temporal_indices = temporal_indices.to(device, non_blocking=True)
|
||||||
|
|
||||||
|
# Compute output
|
||||||
|
with torch.amp.autocast(device_type='cuda'):
|
||||||
|
pred_frames = model(input_frames)
|
||||||
|
loss, loss_dict = criterion(
|
||||||
|
pred_frames, target_frames,
|
||||||
|
temporal_indices
|
||||||
|
)
|
||||||
|
|
||||||
|
# 计算诊断指标
|
||||||
|
pred_mean = pred_frames.mean().item()
|
||||||
|
pred_std = pred_frames.std().item()
|
||||||
|
target_mean = target_frames.mean().item()
|
||||||
|
target_std = target_frames.std().item()
|
||||||
|
|
||||||
|
# 更新诊断指标
|
||||||
|
metric_logger.update(pred_mean=pred_mean)
|
||||||
|
metric_logger.update(pred_std=pred_std)
|
||||||
|
metric_logger.update(target_mean=target_mean)
|
||||||
|
metric_logger.update(target_std=target_std)
|
||||||
|
|
||||||
|
# # 第一个批次打印详细诊断信息
|
||||||
|
# if batch_idx == 0:
|
||||||
|
# print(f"[评估诊断] 批次 0:")
|
||||||
|
# print(f" 预测范围: [{pred_frames.min().item():.4f}, {pred_frames.max().item():.4f}]")
|
||||||
|
# print(f" 预测均值: {pred_mean:.4f}, 预测标准差: {pred_std:.4f}")
|
||||||
|
# print(f" 目标范围: [{target_frames.min().item():.4f}, {target_frames.max().item():.4f}]")
|
||||||
|
# print(f" 目标均值: {target_mean:.4f}, 目标标准差: {target_std:.4f}")
|
||||||
|
|
||||||
|
# # 检查BatchNorm运行统计
|
||||||
|
# for name, module in model.named_modules():
|
||||||
|
# if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
|
||||||
|
# print(f" {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
|
||||||
|
# if module.running_var[0].item() < 1e-6:
|
||||||
|
# print(f" 警告: BatchNorm运行方差接近零!")
|
||||||
|
# break
|
||||||
|
|
||||||
|
# Update metrics
|
||||||
|
metric_logger.update(loss=loss.item())
|
||||||
|
for k, v in loss_dict.items():
|
||||||
|
metric_logger.update(**{k: v.item() if torch.is_tensor(v) else v})
|
||||||
|
|
||||||
|
metric_logger.synchronize_between_processes()
|
||||||
|
print('* Test stats:', metric_logger)
|
||||||
|
|
||||||
|
# Log validation metrics to TensorBoard
|
||||||
|
if writer is not None:
|
||||||
|
for k, meter in metric_logger.meters.items():
|
||||||
|
writer.add_scalar(f'val/{k}', meter.global_avg, epoch)
|
||||||
|
|
||||||
|
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
'SwiftFormerTemporal training script', parents=[get_args_parser()])
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.output_dir:
|
||||||
|
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
main(args)
|
||||||
@@ -1 +1,7 @@
|
|||||||
from .swiftformer import SwiftFormer_XS, SwiftFormer_S, SwiftFormer_L1, SwiftFormer_L3
|
from .swiftformer import SwiftFormer_XS, SwiftFormer_S, SwiftFormer_L1, SwiftFormer_L3
|
||||||
|
from .swiftformer_temporal import (
|
||||||
|
SwiftFormerTemporal_XS,
|
||||||
|
SwiftFormerTemporal_S,
|
||||||
|
SwiftFormerTemporal_L1,
|
||||||
|
SwiftFormerTemporal_L3
|
||||||
|
)
|
||||||
|
|||||||
@@ -6,9 +6,9 @@ import copy
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from timm.models.layers import DropPath, trunc_normal_
|
from timm.layers import DropPath, trunc_normal_
|
||||||
from timm.models.registry import register_model
|
from timm.models import register_model
|
||||||
from timm.models.layers.helpers import to_2tuple
|
from timm.layers import to_2tuple
|
||||||
import einops
|
import einops
|
||||||
|
|
||||||
SwiftFormer_width = {
|
SwiftFormer_width = {
|
||||||
@@ -437,7 +437,7 @@ class SwiftFormer(nn.Module):
|
|||||||
if not self.training:
|
if not self.training:
|
||||||
cls_out = (cls_out[0] + cls_out[1]) / 2
|
cls_out = (cls_out[0] + cls_out[1]) / 2
|
||||||
else:
|
else:
|
||||||
cls_out = self.head(x.mean(-2))
|
cls_out = self.head(x.flatten(2).mean(-1))
|
||||||
# For image classification
|
# For image classification
|
||||||
return cls_out
|
return cls_out
|
||||||
|
|
||||||
|
|||||||
232
models/swiftformer_temporal.py
Normal file
232
models/swiftformer_temporal.py
Normal file
@@ -0,0 +1,232 @@
|
|||||||
|
"""
|
||||||
|
SwiftFormerTemporal: Temporal extension of SwiftFormer for frame prediction
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from .swiftformer import (
|
||||||
|
SwiftFormer, SwiftFormer_depth, SwiftFormer_width,
|
||||||
|
stem, Embedding, Stage
|
||||||
|
)
|
||||||
|
from timm.layers import DropPath, trunc_normal_
|
||||||
|
|
||||||
|
|
||||||
|
class DecoderBlock(nn.Module):
|
||||||
|
"""Upsampling block for frame prediction decoder without residual connections"""
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
|
||||||
|
super().__init__()
|
||||||
|
# 主路径:反卷积 + 两个卷积层
|
||||||
|
self.conv_transpose = nn.ConvTranspose2d(
|
||||||
|
in_channels, out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
output_padding=output_padding,
|
||||||
|
bias=False # 禁用bias,因为使用BN
|
||||||
|
)
|
||||||
|
self.bn1 = nn.BatchNorm2d(out_channels)
|
||||||
|
self.conv1 = nn.Conv2d(out_channels, out_channels,
|
||||||
|
kernel_size=3, padding=1, bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm2d(out_channels)
|
||||||
|
self.conv2 = nn.Conv2d(out_channels, out_channels,
|
||||||
|
kernel_size=3, padding=1, bias=False)
|
||||||
|
self.bn3 = nn.BatchNorm2d(out_channels)
|
||||||
|
|
||||||
|
# 使用ReLU激活函数
|
||||||
|
self.activation = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
# 初始化权重
|
||||||
|
self._init_weights()
|
||||||
|
|
||||||
|
def _init_weights(self):
|
||||||
|
# 初始化反卷积层
|
||||||
|
nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='relu')
|
||||||
|
|
||||||
|
# 初始化卷积层
|
||||||
|
nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')
|
||||||
|
nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='relu')
|
||||||
|
|
||||||
|
# 初始化BN层(使用默认初始化)
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.BatchNorm2d):
|
||||||
|
nn.init.constant_(m.weight, 1)
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# 主路径
|
||||||
|
x = self.conv_transpose(x)
|
||||||
|
x = self.bn1(x)
|
||||||
|
x = self.activation(x)
|
||||||
|
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.bn2(x)
|
||||||
|
x = self.activation(x)
|
||||||
|
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = self.bn3(x)
|
||||||
|
x = self.activation(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class FramePredictionDecoder(nn.Module):
|
||||||
|
"""Improved decoder for frame prediction"""
|
||||||
|
def __init__(self, embed_dims, output_channels=1):
|
||||||
|
super().__init__()
|
||||||
|
# Define decoder dimensions independently (no skip connections)
|
||||||
|
start_dim = embed_dims[-1]
|
||||||
|
decoder_dims = [start_dim // (2 ** i) for i in range(4)] # e.g., [220, 110, 55, 27] for XS
|
||||||
|
|
||||||
|
self.blocks = nn.ModuleList()
|
||||||
|
|
||||||
|
# 第一个block:stride=2 (decoder_dims[0] -> decoder_dims[1])
|
||||||
|
self.blocks.append(DecoderBlock(
|
||||||
|
decoder_dims[0], decoder_dims[1],
|
||||||
|
kernel_size=3, stride=2, padding=1, output_padding=1
|
||||||
|
))
|
||||||
|
# 第二个block:stride=2 (decoder_dims[1] -> decoder_dims[2])
|
||||||
|
self.blocks.append(DecoderBlock(
|
||||||
|
decoder_dims[1], decoder_dims[2],
|
||||||
|
kernel_size=3, stride=2, padding=1, output_padding=1
|
||||||
|
))
|
||||||
|
# 第三个block:stride=2 (decoder_dims[2] -> decoder_dims[3])
|
||||||
|
self.blocks.append(DecoderBlock(
|
||||||
|
decoder_dims[2], decoder_dims[3],
|
||||||
|
kernel_size=3, stride=2, padding=1, output_padding=1
|
||||||
|
))
|
||||||
|
# 第四个block:stride=4 (decoder_dims[3] -> 64),放在倒数第二的位置
|
||||||
|
self.blocks.append(DecoderBlock(
|
||||||
|
decoder_dims[3], 64,
|
||||||
|
kernel_size=3, stride=4, padding=1, output_padding=3 # stride=4放在这里
|
||||||
|
))
|
||||||
|
|
||||||
|
self.final_block = nn.Sequential(
|
||||||
|
nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True),
|
||||||
|
nn.Tanh()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: input tensor of shape [B, embed_dims[-1], H/32, W/32]
|
||||||
|
"""
|
||||||
|
# 不使用skip connections
|
||||||
|
for i in range(4):
|
||||||
|
x = self.blocks[i](x)
|
||||||
|
|
||||||
|
# 最终输出层:只进行特征精炼,不上采样
|
||||||
|
x = self.final_block(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SwiftFormerTemporal(nn.Module):
|
||||||
|
"""
|
||||||
|
SwiftFormer with temporal input for frame prediction.
|
||||||
|
Input: [B, num_frames, H, W] (Y channel only)
|
||||||
|
Output: predicted frame [B, 1, H, W] and optional representation
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
model_name='XS',
|
||||||
|
num_frames=3,
|
||||||
|
use_decoder=True,
|
||||||
|
**kwargs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Get model configuration
|
||||||
|
layers = SwiftFormer_depth[model_name]
|
||||||
|
embed_dims = SwiftFormer_width[model_name]
|
||||||
|
|
||||||
|
# Store configuration
|
||||||
|
self.num_frames = num_frames
|
||||||
|
self.use_decoder = use_decoder
|
||||||
|
|
||||||
|
# Modify stem to accept multiple frames (only Y channel)
|
||||||
|
in_channels = num_frames
|
||||||
|
self.patch_embed = stem(in_channels, embed_dims[0])
|
||||||
|
|
||||||
|
# Build encoder network (same as SwiftFormer)
|
||||||
|
network = []
|
||||||
|
for i in range(len(layers)):
|
||||||
|
stage = Stage(embed_dims[i], i, layers, mlp_ratio=4,
|
||||||
|
act_layer=nn.GELU,
|
||||||
|
drop_rate=0., drop_path_rate=0.,
|
||||||
|
use_layer_scale=True,
|
||||||
|
layer_scale_init_value=1e-5,
|
||||||
|
vit_num=1)
|
||||||
|
network.append(stage)
|
||||||
|
if i >= len(layers) - 1:
|
||||||
|
break
|
||||||
|
if embed_dims[i] != embed_dims[i + 1]:
|
||||||
|
network.append(
|
||||||
|
Embedding(
|
||||||
|
patch_size=3, stride=2, padding=1,
|
||||||
|
in_chans=embed_dims[i], embed_dim=embed_dims[i + 1]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.network = nn.ModuleList(network)
|
||||||
|
self.norm = nn.BatchNorm2d(embed_dims[-1])
|
||||||
|
|
||||||
|
# Frame prediction decoder
|
||||||
|
if use_decoder:
|
||||||
|
self.decoder = FramePredictionDecoder(
|
||||||
|
embed_dims,
|
||||||
|
output_channels=1
|
||||||
|
)
|
||||||
|
|
||||||
|
self.apply(self._init_weights)
|
||||||
|
|
||||||
|
def _init_weights(self, m):
|
||||||
|
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
||||||
|
# 使用Kaiming初始化,适合ReLU
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||||
|
if m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.ConvTranspose2d):
|
||||||
|
# 反卷积层使用特定的初始化
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||||
|
if m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
nn.init.constant_(m.weight, 1.0)
|
||||||
|
|
||||||
|
def forward_tokens(self, x):
|
||||||
|
for block in self.network:
|
||||||
|
x = block(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: input frames of shape [B, num_frames, H, W]
|
||||||
|
Returns:
|
||||||
|
pred_frame: predicted frame [B, 1, H, W] (or None)
|
||||||
|
"""
|
||||||
|
# Encode
|
||||||
|
x = self.patch_embed(x)
|
||||||
|
x = self.forward_tokens(x)
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
# Decode to frame
|
||||||
|
pred_frame = None
|
||||||
|
if self.use_decoder:
|
||||||
|
pred_frame = self.decoder(x)
|
||||||
|
|
||||||
|
return pred_frame
|
||||||
|
|
||||||
|
|
||||||
|
# Factory functions for different model sizes
|
||||||
|
def SwiftFormerTemporal_XS(num_frames=3, **kwargs):
|
||||||
|
return SwiftFormerTemporal('XS', num_frames=num_frames, **kwargs)
|
||||||
|
|
||||||
|
def SwiftFormerTemporal_S(num_frames=3, **kwargs):
|
||||||
|
return SwiftFormerTemporal('S', num_frames=num_frames, **kwargs)
|
||||||
|
|
||||||
|
def SwiftFormerTemporal_L1(num_frames=3, **kwargs):
|
||||||
|
return SwiftFormerTemporal('l1', num_frames=num_frames, **kwargs)
|
||||||
|
|
||||||
|
def SwiftFormerTemporal_L3(num_frames=3, **kwargs):
|
||||||
|
return SwiftFormerTemporal('l3', num_frames=num_frames, **kwargs)
|
||||||
233
util/video_dataset.py
Normal file
233
util/video_dataset.py
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
"""
|
||||||
|
Video frame dataset for temporal self-supervised learning
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Tuple, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision import transforms
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class VideoFrameDataset(Dataset):
|
||||||
|
"""
|
||||||
|
Dataset for loading consecutive frames from videos for frame prediction.
|
||||||
|
|
||||||
|
Assumes directory structure:
|
||||||
|
dataset_root/
|
||||||
|
video1/
|
||||||
|
frame_0001.jpg
|
||||||
|
frame_0002.jpg
|
||||||
|
...
|
||||||
|
video2/
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
root_dir: str,
|
||||||
|
num_frames: int = 3,
|
||||||
|
frame_size: int = 224,
|
||||||
|
is_train: bool = True,
|
||||||
|
max_interval: int = 1,
|
||||||
|
transform=None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
root_dir: Root directory containing video folders
|
||||||
|
num_frames: Number of input frames (T)
|
||||||
|
frame_size: Size to resize frames to
|
||||||
|
is_train: Whether this is training set (affects augmentation)
|
||||||
|
max_interval: Maximum interval between consecutive frames
|
||||||
|
transform: Optional custom transform
|
||||||
|
"""
|
||||||
|
self.root_dir = Path(root_dir)
|
||||||
|
self.num_frames = num_frames
|
||||||
|
self.frame_size = frame_size
|
||||||
|
self.is_train = is_train
|
||||||
|
self.max_interval = max_interval
|
||||||
|
|
||||||
|
# if num_frames < 1:
|
||||||
|
# raise ValueError("num_frames must be >= 1")
|
||||||
|
# if frame_size < 1:
|
||||||
|
# raise ValueError("frame_size must be >= 1")
|
||||||
|
# if max_interval < 1:
|
||||||
|
# raise ValueError("max_interval must be >= 1")
|
||||||
|
|
||||||
|
# Collect all video folders and their frame files
|
||||||
|
self.video_folders = []
|
||||||
|
self.video_frame_files = [] # list of list of Path objects
|
||||||
|
for item in self.root_dir.iterdir():
|
||||||
|
if item.is_dir():
|
||||||
|
self.video_folders.append(item)
|
||||||
|
# Get all frame files
|
||||||
|
frame_files = sorted([f for f in item.iterdir()
|
||||||
|
if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
|
||||||
|
self.video_frame_files.append(frame_files)
|
||||||
|
|
||||||
|
if len(self.video_folders) == 0:
|
||||||
|
raise ValueError(f"No video folders found in {root_dir}")
|
||||||
|
|
||||||
|
# Build frame index: list of (video_idx, start_frame_idx)
|
||||||
|
self.frame_indices = []
|
||||||
|
for video_idx, frame_files in enumerate(self.video_frame_files):
|
||||||
|
# Minimum frames needed considering max interval
|
||||||
|
min_frames_needed = num_frames * max_interval + 1
|
||||||
|
if len(frame_files) < min_frames_needed:
|
||||||
|
continue # Skip videos with insufficient frames
|
||||||
|
|
||||||
|
# Add all possible starting positions
|
||||||
|
# Ensure that for any interval up to max_interval, all frames are within bounds
|
||||||
|
max_start = len(frame_files) - num_frames * max_interval
|
||||||
|
for start_idx in range(max_start):
|
||||||
|
self.frame_indices.append((video_idx, start_idx))
|
||||||
|
|
||||||
|
if len(self.frame_indices) == 0:
|
||||||
|
raise ValueError("No valid frame sequences found in dataset")
|
||||||
|
|
||||||
|
# Default transforms
|
||||||
|
if transform is None:
|
||||||
|
self.transform = self._default_transform()
|
||||||
|
else:
|
||||||
|
self.transform = transform
|
||||||
|
|
||||||
|
# Simple normalization to [-1, 1] range (不使用ImageNet标准化)
|
||||||
|
# Convert pixel values [0, 255] to [-1, 1]
|
||||||
|
# This matches the model's tanh output range
|
||||||
|
self.normalize = None # We'll handle normalization manually
|
||||||
|
|
||||||
|
# print(f"[数据集初始化] 使用简单归一化: 像素值[0,255] -> [-1,1]")
|
||||||
|
|
||||||
|
def _default_transform(self):
|
||||||
|
"""Default transform with augmentation for training"""
|
||||||
|
if self.is_train:
|
||||||
|
return transforms.Compose([
|
||||||
|
transforms.RandomResizedCrop(self.frame_size, scale=(0.8, 1.0)),
|
||||||
|
transforms.RandomHorizontalFlip(),
|
||||||
|
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
return transforms.Compose([
|
||||||
|
transforms.Resize(int(self.frame_size * 1.14)),
|
||||||
|
transforms.CenterCrop(self.frame_size),
|
||||||
|
])
|
||||||
|
|
||||||
|
def _load_frame(self, video_idx: int, frame_idx: int) -> Image.Image:
|
||||||
|
"""Load a single frame as PIL Image"""
|
||||||
|
frame_files = self.video_frame_files[video_idx]
|
||||||
|
if frame_idx < 0 or frame_idx >= len(frame_files):
|
||||||
|
raise IndexError(
|
||||||
|
f"Frame index {frame_idx} out of range for video {video_idx} "
|
||||||
|
f"(0-{len(frame_files)-1})"
|
||||||
|
)
|
||||||
|
frame_path = frame_files[frame_idx]
|
||||||
|
return Image.open(frame_path).convert('RGB')
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.frame_indices)
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
input_frames: [num_frames, H, W] concatenated input frames (Y channel only)
|
||||||
|
target_frame: [1, H, W] target frame to predict (Y channel only)
|
||||||
|
temporal_idx: temporal index of target frame (for contrastive loss)
|
||||||
|
"""
|
||||||
|
video_idx, start_idx = self.frame_indices[idx]
|
||||||
|
|
||||||
|
# Determine frame interval (for temporal augmentation)
|
||||||
|
interval = random.randint(1, self.max_interval) if self.is_train else 1
|
||||||
|
|
||||||
|
# Load input frames
|
||||||
|
input_frames = []
|
||||||
|
for i in range(self.num_frames):
|
||||||
|
frame_idx = start_idx + i * interval
|
||||||
|
frame = self._load_frame(video_idx, frame_idx)
|
||||||
|
|
||||||
|
# Apply transform (same for all frames in sequence)
|
||||||
|
if self.transform:
|
||||||
|
frame = self.transform(frame)
|
||||||
|
|
||||||
|
input_frames.append(frame)
|
||||||
|
|
||||||
|
# Load target frame (next frame after input sequence)
|
||||||
|
target_idx = start_idx + self.num_frames * interval
|
||||||
|
target_frame = self._load_frame(video_idx, target_idx)
|
||||||
|
if self.transform:
|
||||||
|
target_frame = self.transform(target_frame)
|
||||||
|
|
||||||
|
# Convert to tensors and convert to grayscale (Y channel)
|
||||||
|
input_tensors = []
|
||||||
|
for frame in input_frames:
|
||||||
|
tensor = transforms.ToTensor()(frame) # [3, H, W], range [0, 1]
|
||||||
|
# Convert RGB to grayscale using weighted sum
|
||||||
|
# Y = 0.2989 * R + 0.5870 * G + 0.1140 * B (same as PIL)
|
||||||
|
gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0) # [1, H, W], range [0, 1]
|
||||||
|
# Normalize from [0, 1] to [-1, 1]
|
||||||
|
gray = gray * 2 - 1 # [0,1] -> [-1,1]
|
||||||
|
input_tensors.append(gray)
|
||||||
|
|
||||||
|
target_tensor = transforms.ToTensor()(target_frame) # [3, H, W], range [0, 1]
|
||||||
|
target_gray = (0.2989 * target_tensor[0] + 0.5870 * target_tensor[1] + 0.1140 * target_tensor[2]).unsqueeze(0)
|
||||||
|
# Normalize from [0, 1] to [-1, 1]
|
||||||
|
target_gray = target_gray * 2 - 1 # [0,1] -> [-1,1]
|
||||||
|
|
||||||
|
# Concatenate input frames along channel dimension
|
||||||
|
input_concatenated = torch.cat(input_tensors, dim=0) # [num_frames, H, W]
|
||||||
|
|
||||||
|
# Temporal index (for contrastive loss)
|
||||||
|
temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
|
||||||
|
|
||||||
|
return input_concatenated, target_gray, temporal_idx
|
||||||
|
|
||||||
|
|
||||||
|
# class SyntheticVideoDataset(Dataset):
|
||||||
|
# """
|
||||||
|
# Synthetic dataset for testing - generates random frames
|
||||||
|
# """
|
||||||
|
# def __init__(self,
|
||||||
|
# num_samples: int = 1000,
|
||||||
|
# num_frames: int = 3,
|
||||||
|
# frame_size: int = 224,
|
||||||
|
# is_train: bool = True):
|
||||||
|
# self.num_samples = num_samples
|
||||||
|
# self.num_frames = num_frames
|
||||||
|
# self.frame_size = frame_size
|
||||||
|
# self.is_train = is_train
|
||||||
|
|
||||||
|
# # Normalization for Y channel (single channel)
|
||||||
|
# y_mean = (0.485 + 0.456 + 0.406) / 3.0
|
||||||
|
# y_std = (0.229 + 0.224 + 0.225) / 3.0
|
||||||
|
# self.normalize = transforms.Normalize(
|
||||||
|
# mean=[y_mean],
|
||||||
|
# std=[y_std]
|
||||||
|
# )
|
||||||
|
|
||||||
|
# def __len__(self):
|
||||||
|
# return self.num_samples
|
||||||
|
|
||||||
|
# def __getitem__(self, idx):
|
||||||
|
# # Generate random "frames" (noise with temporal correlation)
|
||||||
|
# input_frames = []
|
||||||
|
# prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1
|
||||||
|
|
||||||
|
# for i in range(self.num_frames):
|
||||||
|
# # Add some temporal correlation
|
||||||
|
# frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
|
||||||
|
# frame = torch.clamp(frame, -1, 1)
|
||||||
|
# input_frames.append(self.normalize(frame))
|
||||||
|
# prev_frame = frame
|
||||||
|
|
||||||
|
# # Target frame (next in sequence)
|
||||||
|
# target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
|
||||||
|
# target_frame = torch.clamp(target_frame, -1, 1)
|
||||||
|
# target_tensor = self.normalize(target_frame)
|
||||||
|
|
||||||
|
# # Concatenate inputs
|
||||||
|
# input_concatenated = torch.cat(input_frames, dim=0)
|
||||||
|
|
||||||
|
# # Temporal index
|
||||||
|
# temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
|
||||||
|
|
||||||
|
# return input_concatenated, target_tensor, temporal_idx
|
||||||
303
video_preprocessor.py
Normal file
303
video_preprocessor.py
Normal file
@@ -0,0 +1,303 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
视频预处理脚本 - 将MP4视频转换为224x224帧图像
|
||||||
|
支持多线程并发处理、进度条显示和中断恢复功能
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
import subprocess
|
||||||
|
import threading
|
||||||
|
from pathlib import Path
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from tqdm import tqdm
|
||||||
|
import time
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class VideoPreprocessor:
|
||||||
|
"""视频预处理器,支持多线程和中断恢复"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
input_dir: str,
|
||||||
|
output_dir: str,
|
||||||
|
frame_size: int = 224,
|
||||||
|
fps: int = 30,
|
||||||
|
num_workers: int = 4,
|
||||||
|
quality: int = 2,
|
||||||
|
resume: bool = True):
|
||||||
|
"""
|
||||||
|
初始化预处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dir: 输入视频目录
|
||||||
|
output_dir: 输出帧目录
|
||||||
|
frame_size: 帧大小(正方形)
|
||||||
|
fps: 提取帧率
|
||||||
|
num_workers: 并发工作线程数
|
||||||
|
quality: JPEG质量 (1-31, 数值越小质量越高)
|
||||||
|
resume: 是否启用中断恢复
|
||||||
|
"""
|
||||||
|
self.input_dir = Path(input_dir)
|
||||||
|
self.output_dir = Path(output_dir)
|
||||||
|
self.frame_size = frame_size
|
||||||
|
self.fps = fps
|
||||||
|
self.num_workers = num_workers
|
||||||
|
self.quality = quality
|
||||||
|
self.resume = resume
|
||||||
|
|
||||||
|
# 状态文件路径
|
||||||
|
self.state_file = self.output_dir / ".preprocessing_state.json"
|
||||||
|
|
||||||
|
# 创建输出目录
|
||||||
|
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# 初始化状态
|
||||||
|
self.state = self._load_state()
|
||||||
|
|
||||||
|
# 收集所有视频文件
|
||||||
|
self.video_files = self._collect_video_files()
|
||||||
|
|
||||||
|
def _load_state(self) -> Dict:
|
||||||
|
"""加载处理状态"""
|
||||||
|
if self.resume and self.state_file.exists():
|
||||||
|
try:
|
||||||
|
with open(self.state_file, 'r') as f:
|
||||||
|
return json.load(f)
|
||||||
|
except (json.JSONDecodeError, IOError):
|
||||||
|
print(f"警告: 无法读取状态文件,将重新开始处理")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"completed": [],
|
||||||
|
"failed": [],
|
||||||
|
"total_processed": 0,
|
||||||
|
"start_time": None,
|
||||||
|
"last_update": None
|
||||||
|
}
|
||||||
|
|
||||||
|
def _save_state(self):
|
||||||
|
"""保存处理状态"""
|
||||||
|
self.state["last_update"] = time.time()
|
||||||
|
try:
|
||||||
|
with open(self.state_file, 'w') as f:
|
||||||
|
json.dump(self.state, f, indent=2)
|
||||||
|
except IOError as e:
|
||||||
|
print(f"警告: 无法保存状态文件: {e}")
|
||||||
|
|
||||||
|
def _collect_video_files(self) -> List[Path]:
|
||||||
|
"""收集所有需要处理的视频文件"""
|
||||||
|
video_files = []
|
||||||
|
for file_path in self.input_dir.glob("*.mp4"):
|
||||||
|
if file_path.name not in self.state["completed"]:
|
||||||
|
video_files.append(file_path)
|
||||||
|
|
||||||
|
return sorted(video_files)
|
||||||
|
|
||||||
|
def _parse_video_name(self, video_path: Path) -> Dict[str, str]:
|
||||||
|
"""解析视频文件名,使用完整文件名作为ID"""
|
||||||
|
name_without_ext = video_path.stem
|
||||||
|
|
||||||
|
# 直接使用完整文件名作为ID,确保每个mp4文件有独立的输出目录
|
||||||
|
return {
|
||||||
|
"video_id": name_without_ext,
|
||||||
|
"start_frame": "unknown",
|
||||||
|
"end_frame": "unknown",
|
||||||
|
"full_name": name_without_ext
|
||||||
|
}
|
||||||
|
|
||||||
|
def _extract_frames(self, video_path: Path) -> bool:
|
||||||
|
"""提取单个视频的帧"""
|
||||||
|
try:
|
||||||
|
# 解析视频名称
|
||||||
|
video_info = self._parse_video_name(video_path)
|
||||||
|
output_subdir = self.output_dir / video_info["video_id"]
|
||||||
|
output_subdir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# 构建FFmpeg命令
|
||||||
|
output_pattern = output_subdir / "frame_%04d.jpg"
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
"ffmpeg",
|
||||||
|
"-i", str(video_path),
|
||||||
|
"-vf", f"fps={self.fps},scale={self.frame_size}:{self.frame_size}",
|
||||||
|
"-q:v", str(self.quality),
|
||||||
|
"-y", # 覆盖输出文件
|
||||||
|
str(output_pattern)
|
||||||
|
]
|
||||||
|
|
||||||
|
# 执行FFmpeg命令
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=300 # 5分钟超时
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
print(f"FFmpeg错误处理 {video_path.name}: {result.stderr}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 验证输出帧数量
|
||||||
|
output_frames = list(output_subdir.glob("frame_*.jpg"))
|
||||||
|
if len(output_frames) == 0:
|
||||||
|
print(f"警告: {video_path.name} 没有生成任何帧")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
print(f"超时处理 {video_path.name}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
print(f"处理 {video_path.name} 时发生错误: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _process_video(self, video_path: Path) -> tuple[bool, str]:
|
||||||
|
"""处理单个视频文件"""
|
||||||
|
video_name = video_path.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
success = self._extract_frames(video_path)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
self.state["completed"].append(video_name)
|
||||||
|
if video_name in self.state["failed"]:
|
||||||
|
self.state["failed"].remove(video_name)
|
||||||
|
return True, video_name
|
||||||
|
else:
|
||||||
|
self.state["failed"].append(video_name)
|
||||||
|
return False, video_name
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"处理 {video_name} 时发生异常: {e}")
|
||||||
|
self.state["failed"].append(video_name)
|
||||||
|
return False, video_name
|
||||||
|
|
||||||
|
def process_all_videos(self):
|
||||||
|
"""处理所有视频文件"""
|
||||||
|
if not self.video_files:
|
||||||
|
print("没有找到需要处理的视频文件")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"找到 {len(self.video_files)} 个待处理视频文件")
|
||||||
|
print(f"输出目录: {self.output_dir}")
|
||||||
|
print(f"帧大小: {self.frame_size}x{self.frame_size}")
|
||||||
|
print(f"帧率: {self.fps} fps")
|
||||||
|
print(f"并发线程数: {self.num_workers}")
|
||||||
|
|
||||||
|
if self.state["completed"]:
|
||||||
|
print(f"跳过 {len(self.state['completed'])} 个已处理的视频")
|
||||||
|
|
||||||
|
# 记录开始时间
|
||||||
|
if self.state["start_time"] is None:
|
||||||
|
self.state["start_time"] = time.time()
|
||||||
|
|
||||||
|
# 创建进度条
|
||||||
|
with tqdm(total=len(self.video_files), desc="处理视频", unit="个") as pbar:
|
||||||
|
with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
|
||||||
|
# 提交所有任务
|
||||||
|
future_to_video = {
|
||||||
|
executor.submit(self._process_video, video_path): video_path
|
||||||
|
for video_path in self.video_files
|
||||||
|
}
|
||||||
|
|
||||||
|
# 处理完成的任务
|
||||||
|
for future in as_completed(future_to_video):
|
||||||
|
video_path = future_to_video[future]
|
||||||
|
try:
|
||||||
|
success, video_name = future.result()
|
||||||
|
if success:
|
||||||
|
pbar.set_postfix({"状态": "成功", "文件": video_name[:20]})
|
||||||
|
else:
|
||||||
|
pbar.set_postfix({"状态": "失败", "文件": video_name[:20]})
|
||||||
|
except Exception as e:
|
||||||
|
print(f"处理 {video_path.name} 时发生异常: {e}")
|
||||||
|
pbar.set_postfix({"状态": "异常", "文件": video_path.name[:20]})
|
||||||
|
|
||||||
|
pbar.update(1)
|
||||||
|
self.state["total_processed"] += 1
|
||||||
|
|
||||||
|
# 定期保存状态
|
||||||
|
if self.state["total_processed"] % 5 == 0:
|
||||||
|
self._save_state()
|
||||||
|
|
||||||
|
# 最终保存状态
|
||||||
|
self._save_state()
|
||||||
|
|
||||||
|
# 打印处理结果
|
||||||
|
self._print_summary()
|
||||||
|
|
||||||
|
def _print_summary(self):
|
||||||
|
"""打印处理摘要"""
|
||||||
|
print("\n" + "="*50)
|
||||||
|
print("处理完成摘要:")
|
||||||
|
print(f"总处理视频数: {len(self.state['completed'])}")
|
||||||
|
print(f"失败视频数: {len(self.state['failed'])}")
|
||||||
|
|
||||||
|
if self.state["failed"]:
|
||||||
|
print("\n失败的视频:")
|
||||||
|
for video_name in self.state["failed"]:
|
||||||
|
print(f" - {video_name}")
|
||||||
|
|
||||||
|
if self.state["start_time"]:
|
||||||
|
elapsed_time = time.time() - self.state["start_time"]
|
||||||
|
print(f"\n总耗时: {elapsed_time:.2f} 秒")
|
||||||
|
if self.state["total_processed"] > 0:
|
||||||
|
avg_time = elapsed_time / self.state["total_processed"]
|
||||||
|
print(f"平均每个视频: {avg_time:.2f} 秒")
|
||||||
|
|
||||||
|
print("="*50)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""主函数"""
|
||||||
|
parser = argparse.ArgumentParser(description="视频预处理脚本")
|
||||||
|
parser.add_argument("--input_dir", type=str, default="/home/hexone/Workplace/ws_asmo/vhead/sekai-real-drone/sekai-real-drone", help="输入视频目录")
|
||||||
|
parser.add_argument("--output_dir", type=str, default="/home/hexone/Workplace/ws_asmo/vhead/sekai-real-drone/processed", help="输出帧目录")
|
||||||
|
parser.add_argument("--size", type=int, default=224, help="帧大小 (默认: 224)")
|
||||||
|
parser.add_argument("--fps", type=int, default=10, help="提取帧率 (默认: 30)")
|
||||||
|
parser.add_argument("--workers", type=int, default=32, help="并发线程数 (默认: 4)")
|
||||||
|
parser.add_argument("--quality", type=int, default=2, help="JPEG质量 1-31 (默认: 2)")
|
||||||
|
parser.add_argument("--no-resume", action="store_true", help="不启用中断恢复")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# 检查输入目录
|
||||||
|
if not Path(args.input_dir).exists():
|
||||||
|
print(f"错误: 输入目录不存在: {args.input_dir}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# 检查FFmpeg是否可用
|
||||||
|
try:
|
||||||
|
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
|
||||||
|
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||||
|
print("错误: FFmpeg未安装或不在PATH中")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# 创建预处理器并开始处理
|
||||||
|
preprocessor = VideoPreprocessor(
|
||||||
|
input_dir=args.input_dir,
|
||||||
|
output_dir=args.output_dir,
|
||||||
|
frame_size=args.size,
|
||||||
|
fps=args.fps,
|
||||||
|
num_workers=args.workers,
|
||||||
|
quality=args.quality,
|
||||||
|
resume=not args.no_resume
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
preprocessor.process_all_videos()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n\n用户中断处理,状态已保存")
|
||||||
|
preprocessor._save_state()
|
||||||
|
print("可以使用相同命令恢复处理")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n处理过程中发生错误: {e}")
|
||||||
|
preprocessor._save_state()
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user