feat: 增加BasicAuth

This commit is contained in:
2025-10-28 11:32:02 +08:00
parent 925c89d2a3
commit da974f68d3

View File

@@ -2,6 +2,7 @@ import cv2
import numpy as np import numpy as np
from flask import Flask, render_template, request, jsonify, send_file from flask import Flask, render_template, request, jsonify, send_file
from flask_socketio import SocketIO, emit from flask_socketio import SocketIO, emit
from flask_httpauth import HTTPBasicAuth
import io import io
import base64 import base64
import time import time
@@ -19,30 +20,46 @@ import tempfile
from cap_trigger import ImageClient from cap_trigger import ImageClient
# --- 配置 --- # --- 配置 ---
SAVE_PATH_LEFT = "./static/received/left" SAVE_PATH_LEFT = "./static/received/left"
SAVE_PATH_RIGHT = "./static/received/right" SAVE_PATH_RIGHT = "./static/received/right"
SAVE_PATH_LEFT_MARKED = "./static/received/left_marked" # 标注图片保存路径 SAVE_PATH_LEFT_MARKED = "./static/received/left_marked"
SAVE_PATH_RIGHT_MARKED = "./static/received/right_marked" # 标注图片保存路径 SAVE_PATH_RIGHT_MARKED = "./static/received/right_marked"
FLASK_HOST = "0.0.0.0" FLASK_HOST = "0.0.0.0"
FLASK_PORT = 5000 FLASK_PORT = 5000
MAX_LIVE_FRAMES = 2 # 保留最新的几帧用于实时显示 MAX_LIVE_FRAMES = 2
DATABASE_PATH = "received_images.db" # SQLite 数据库文件路径 DATABASE_PATH = "received_images.db"
# --- 配置 --- # --- 配置 ---
USERNAME = os.getenv('BASIC_AUTH_USERNAME', 'admin')
PASSWORD = os.getenv('BASIC_AUTH_PASSWORD', '19260817')
# 设置日志 # 设置日志
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
app = Flask(__name__)
app.config['SECRET_KEY'] = 'your-secret-key-change-this'
socketio = SocketIO(app, cors_allowed_origins="*")
auth = HTTPBasicAuth() # 创建 HTTPBasicAuth 实例
@auth.verify_password
def verify_password(username, password):
"""验证用户名和密码"""
# 比较提供的用户名和密码与配置的值
if username == USERNAME and password == PASSWORD:
logger.info(f"Successful BasicAuth login for user: {username}")
return username # 返回用户名表示认证成功
else:
logger.warning(f"Failed BasicAuth attempt for user: {username}")
return None # 返回 None 表示认证失败
# --- 数据库初始化 --- # --- 数据库初始化 ---
def init_db(): def init_db():
"""初始化 SQLite 数据库和表""" """初始化 SQLite 数据库和表"""
conn = sqlite3.connect(DATABASE_PATH) conn = sqlite3.connect(DATABASE_PATH)
cursor = conn.cursor() cursor = conn.cursor()
# 创建表,存储图片信息,添加标注图片字段
# --- Updated CREATE TABLE statement ---
cursor.execute(''' cursor.execute('''
CREATE TABLE IF NOT EXISTS images ( CREATE TABLE IF NOT EXISTS images (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -58,126 +75,83 @@ def init_db():
manual_detections_right TEXT, manual_detections_right TEXT,
is_manual_labeled_left INTEGER DEFAULT 0, is_manual_labeled_left INTEGER DEFAULT 0,
is_manual_labeled_right INTEGER DEFAULT 0, is_manual_labeled_right INTEGER DEFAULT 0,
left_position INTEGER DEFAULT 0, -- Added column left_position INTEGER DEFAULT 0,
right_position INTEGER DEFAULT 0 -- Added column right_position INTEGER DEFAULT 0
) )
''') ''')
# --- End Update ---
conn.commit() conn.commit()
conn.close() conn.close()
logger.info(f"Database {DATABASE_PATH} initialized.") logger.info(f"Database {DATABASE_PATH} initialized.")
# --- 全局变量 --- # --- 全局变量 ---
# 用于存储最新的左右帧,用于实时显示
latest_left_frame = None latest_left_frame = None
latest_right_frame = None latest_right_frame = None
latest_timestamp = None latest_timestamp = None
frame_lock = threading.Lock() # 保护全局帧变量 frame_lock = threading.Lock()
# --- Flask & SocketIO 应用 ---
app = Flask(__name__)
# 为生产环境配置 SECRET_KEY
app.config['SECRET_KEY'] = 'your-secret-key-change-this'
# 配置异步模式,如果需要异步处理可以调整
socketio = SocketIO(app, cors_allowed_origins="*") # 允许所有来源,生产环境请具体配置
# 初始化图像客户端
image_client = ImageClient("tcp://175.24.228.220:7701", client_id="local")
# 初始化数据库 # 初始化数据库
init_db() init_db()
image_client = ImageClient("tcp://175.24.228.220:7701", client_id="local")
def draw_detections_on_image(image: np.ndarray, detections: list, left_position: int = None, right_position: int = None) -> np.ndarray: def draw_detections_on_image(image: np.ndarray, detections: list, left_position: int = None, right_position: int = None) -> np.ndarray:
"""在图像上绘制检测框和位置信息""" """在图像上绘制检测框和位置信息"""
# 复制原图以避免修改原始图像
marked_image = image.copy() marked_image = image.copy()
color_map = {1: (0, 255, 0), 2: (255, 0, 0), 3: (0, 0, 255), 4: (255, 255, 0)}
# 定义颜色映射
color_map = {
1: (0, 255, 0), # 绿色 - 弹药箱
2: (255, 0, 0), # 蓝色 - 士兵
3: (0, 0, 255), # 红色 - 枪支
4: (255, 255, 0) # 青色 - 数字牌
}
# 获取图像尺寸
h, w = image.shape[:2] h, w = image.shape[:2]
position_text = ""
if left_position is not None: if left_position is not None:
position_text = f"POS/CAM_L: {left_position if left_position != 0 else 'NAN'}" position_text = f"POS/CAM_L: {left_position if left_position != 0 else 'NAN'}"
elif right_position is not None: elif right_position is not None:
position_text = f"POS/CAM_R: {right_position if right_position != 0 else 'NAN'}" position_text = f"POS/CAM_R: {right_position if right_position != 0 else 'NAN'}"
else:
position_text = ""
if position_text: if position_text:
# 设置字体参数
font = cv2.FONT_HERSHEY_SIMPLEX font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 1.0 font_scale = 1.0
font_thickness = 2 font_thickness = 2
text_size = cv2.getTextSize(position_text, font, font_scale, font_thickness)[0] text_size = cv2.getTextSize(position_text, font, font_scale, font_thickness)[0]
text_x, text_y = 10, 30
cv2.rectangle(marked_image, (text_x - 5, text_y - text_size[1] - 5), (text_x + text_size[0] + 5, text_y + 5), (255, 255, 255), -1)
cv2.putText(marked_image, position_text, (text_x, text_y), font, font_scale, (0, 0, 0), font_thickness)
# 计算文本背景位置
text_x = 10
text_y = 30
# 绘制白色背景矩形
cv2.rectangle(marked_image,
(text_x - 5, text_y - text_size[1] - 5),
(text_x + text_size[0] + 5, text_y + 5),
(255, 255, 255), -1)
# 绘制黑色文字
cv2.putText(marked_image, position_text, (text_x, text_y),
font, font_scale, (0, 0, 0), font_thickness)
# 绘制每个检测框
for detection in detections: for detection in detections:
# 获取检测信息
obj_id = detection.get("id", 0) obj_id = detection.get("id", 0)
label = detection.get("label", "") label = detection.get("label", "")
bbox = detection.get("bbox", []) bbox = detection.get("bbox", [])
if len(bbox) == 4: if len(bbox) == 4:
# 将归一化的坐标转换为实际像素坐标
x_min = int(bbox[0] * w / 999) x_min = int(bbox[0] * w / 999)
y_min = int(bbox[1] * h / 999) y_min = int(bbox[1] * h / 999)
x_max = int(bbox[2] * w / 999) x_max = int(bbox[2] * w / 999)
y_max = int(bbox[3] * h / 999) y_max = int(bbox[3] * h / 999)
color = color_map.get(obj_id, (255, 255, 255))
# 获取颜色
color = color_map.get(obj_id, (255, 255, 255)) # 默认白色
# 绘制边界框
cv2.rectangle(marked_image, (x_min, y_min), (x_max, y_max), color, 2) cv2.rectangle(marked_image, (x_min, y_min), (x_max, y_max), color, 2)
cv2.putText(marked_image, f"{label} ({obj_id})", (x_min, y_min + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
# 添加标签
label_text = f"{label} ({obj_id})"
cv2.putText(marked_image, label_text, (x_min, y_min + 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
return marked_image return marked_image
# --- Flask 路由 --- # --- Flask 路由 ---
# 对需要保护的路由添加 @auth.login_required 装饰器
@app.route('/') @app.route('/')
@auth.login_required # 需要认证才能访问主页
def index(): def index():
"""主页,加载实时图像页面""" """主页,加载实时图像页面"""
logger.info(f"User {auth.current_user()} accessed the main page.")
return render_template('index.html') return render_template('index.html')
@app.route('/list') # 新增路由用于显示图片列表 @app.route('/list') # 新增路由用于显示图片列表
@auth.login_required # 需要认证才能访问列表页
def list_images(): def list_images():
"""加载图片列表页面""" """加载图片列表页面"""
logger.info(f"User {auth.current_user()} accessed the image list page.")
return render_template('list_images.html') return render_template('list_images.html')
@app.route('/api/images', methods=['GET']) @app.route('/api/images', methods=['GET'])
@auth.login_required # 保护 API 端点
def get_images_api(): def get_images_api():
"""API: 获取图片列表 (JSON 格式)""" """API: 获取图片列表 (JSON 格式)"""
logger.info(f"User {auth.current_user()} requested image list API.")
conn = sqlite3.connect(DATABASE_PATH) conn = sqlite3.connect(DATABASE_PATH)
cursor = conn.cursor() cursor = conn.cursor()
# 按时间倒序排列,包含标注图片字段和人工标注字段
# --- Updated query to directly select the new columns ---
# No more try/except needed for these specific columns if init_db creates them
cursor.execute(""" cursor.execute("""
SELECT id, left_filename, right_filename, left_marked_filename, right_marked_filename, SELECT id, left_filename, right_filename, left_marked_filename, right_marked_filename,
timestamp, metadata, comment, created_at, manual_detections_left, manual_detections_right, timestamp, metadata, comment, created_at, manual_detections_left, manual_detections_right,
@@ -186,15 +160,11 @@ def get_images_api():
FROM images FROM images
ORDER BY timestamp DESC ORDER BY timestamp DESC
""") """)
# --- End Update ---
rows = cursor.fetchall() rows = cursor.fetchall()
conn.close() conn.close()
# print(rows)
images = [] images = []
for row in rows: for row in rows:
# Note: Adjusting indices based on the new SELECT order
images.append({ images.append({
"id": row[0], "id": row[0],
"left_filename": row[1], "left_filename": row[1],
@@ -205,28 +175,26 @@ def get_images_api():
"metadata": row[6], "metadata": row[6],
"comment": row[7] or "", "comment": row[7] or "",
"created_at": row[8], "created_at": row[8],
"manual_detections_left": row[9] or "[]", # Renamed for clarity "manual_detections_left": row[9] or "[]",
"manual_detections_right": row[10] or "[]", # Renamed for clarity "manual_detections_right": row[10] or "[]",
"is_manual_labeled_left": bool(row[11]) if row[11] is not None else False, # Renamed for clarity "is_manual_labeled_left": bool(row[11]) if row[11] is not None else False,
"is_manual_labeled_right": bool(row[12]) if row[12] is not None else False, # Renamed for clarity "is_manual_labeled_right": bool(row[12]) if row[12] is not None else False,
"left_position": row[13], # Updated index "left_position": row[13],
"right_position": row[14] # Updated index "right_position": row[14]
}) })
# print(images)
return jsonify(images) return jsonify(images)
@app.route('/api/images', methods=['DELETE']) @app.route('/api/images', methods=['DELETE'])
@auth.login_required
def delete_image_api(): def delete_image_api():
"""API: 删除单张图片记录及其文件""" """API: 删除单张图片记录及其文件"""
logger.info(f"User {auth.current_user()} requested to delete an image.")
image_id = request.json.get('id') image_id = request.json.get('id')
if not image_id: if not image_id:
return jsonify({"error": "Image ID is required"}), 400 return jsonify({"error": "Image ID is required"}), 400
conn = sqlite3.connect(DATABASE_PATH) conn = sqlite3.connect(DATABASE_PATH)
cursor = conn.cursor() cursor = conn.cursor()
# 查询文件名,包含标注图片
cursor.execute("SELECT left_filename, right_filename, left_marked_filename, right_marked_filename FROM images WHERE id = ?", (image_id,)) cursor.execute("SELECT left_filename, right_filename, left_marked_filename, right_marked_filename FROM images WHERE id = ?", (image_id,))
row = cursor.fetchone() row = cursor.fetchone()
if not row: if not row:
@@ -234,40 +202,31 @@ def delete_image_api():
return jsonify({"error": "Image not found"}), 404 return jsonify({"error": "Image not found"}), 404
left_filename, right_filename, left_marked_filename, right_marked_filename = row left_filename, right_filename, left_marked_filename, right_marked_filename = row
# 删除数据库记录
cursor.execute("DELETE FROM images WHERE id = ?", (image_id,)) cursor.execute("DELETE FROM images WHERE id = ?", (image_id,))
conn.commit() conn.commit()
conn.close() conn.close()
# 删除对应的文件,包括标注图片
left_path = os.path.join(SAVE_PATH_LEFT, left_filename) left_path = os.path.join(SAVE_PATH_LEFT, left_filename)
right_path = os.path.join(SAVE_PATH_RIGHT, right_filename) right_path = os.path.join(SAVE_PATH_RIGHT, right_filename)
left_marked_path = os.path.join(SAVE_PATH_LEFT_MARKED, left_marked_filename) if left_marked_filename else None left_marked_path = os.path.join(SAVE_PATH_LEFT_MARKED, left_marked_filename) if left_marked_filename else None
right_marked_path = os.path.join(SAVE_PATH_RIGHT_MARKED, right_marked_filename) if right_marked_filename else None right_marked_path = os.path.join(SAVE_PATH_RIGHT_MARKED, right_marked_filename) if right_marked_filename else None
try: try:
if os.path.exists(left_path): for path in [left_path, right_path, left_marked_path, right_marked_path]:
os.remove(left_path) if path and os.path.exists(path):
logger.info(f"Deleted file: {left_path}") os.remove(path)
if os.path.exists(right_path): logger.info(f"Deleted file: {path}")
os.remove(right_path)
logger.info(f"Deleted file: {right_path}")
if left_marked_path and os.path.exists(left_marked_path):
os.remove(left_marked_path)
logger.info(f"Deleted file: {left_marked_path}")
if right_marked_path and os.path.exists(right_marked_path):
os.remove(right_marked_path)
logger.info(f"Deleted file: {right_marked_path}")
except OSError as e: except OSError as e:
logger.error(f"Error deleting files: {e}") logger.error(f"Error deleting files: {e}")
# 即使删除文件失败,数据库记录也已删除,返回成功
pass pass
return jsonify({"message": f"Image {image_id} deleted successfully"}) return jsonify({"message": f"Image {image_id} deleted successfully"})
@app.route('/api/images/export', methods=['POST']) @app.route('/api/images/export', methods=['POST'])
@auth.login_required
def export_images_api(): def export_images_api():
"""API: 打包导出选中的图片,优先导出标注图片""" """API: 打包导出选中的图片,优先导出标注图片"""
logger.info(f"User {auth.current_user()} requested to export images.")
selected_ids = request.json.get('ids', []) selected_ids = request.json.get('ids', [])
if not selected_ids: if not selected_ids:
return jsonify({"error": "No image IDs selected"}), 400 return jsonify({"error": "No image IDs selected"}), 400
@@ -275,7 +234,6 @@ def export_images_api():
conn = sqlite3.connect(DATABASE_PATH) conn = sqlite3.connect(DATABASE_PATH)
cursor = conn.cursor() cursor = conn.cursor()
placeholders = ','.join('?' * len(selected_ids)) placeholders = ','.join('?' * len(selected_ids))
# 查询包含标注图片的文件名
cursor.execute(f"SELECT left_filename, right_filename, left_marked_filename, right_marked_filename FROM images WHERE id IN ({placeholders})", selected_ids) cursor.execute(f"SELECT left_filename, right_filename, left_marked_filename, right_marked_filename FROM images WHERE id IN ({placeholders})", selected_ids)
rows = cursor.fetchall() rows = cursor.fetchall()
conn.close() conn.close()
@@ -283,29 +241,22 @@ def export_images_api():
if not rows: if not rows:
return jsonify({"error": "No matching images found"}), 404 return jsonify({"error": "No matching images found"}), 404
# 创建临时 ZIP 文件
temp_zip_fd, temp_zip_path = tempfile.mkstemp(suffix='.zip') temp_zip_fd, temp_zip_path = tempfile.mkstemp(suffix='.zip')
os.close(temp_zip_fd) # 关闭文件描述符,让 zipfile 模块管理 os.close(temp_zip_fd)
try: try:
with zipfile.ZipFile(temp_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf: with zipfile.ZipFile(temp_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
for left_fn, right_fn, left_marked_fn, right_marked_fn in rows: for left_fn, right_fn, left_marked_fn, right_marked_fn in rows:
# 优先使用标注图片,如果没有则使用原图
left_export_fn = left_marked_fn if left_marked_fn else left_fn left_export_fn = left_marked_fn if left_marked_fn else left_fn
right_export_fn = right_marked_fn if right_marked_fn else right_fn right_export_fn = right_marked_fn if right_marked_fn else right_fn
# 确定文件路径
left_export_path = os.path.join(SAVE_PATH_LEFT_MARKED if left_marked_fn else SAVE_PATH_LEFT, left_export_fn) left_export_path = os.path.join(SAVE_PATH_LEFT_MARKED if left_marked_fn else SAVE_PATH_LEFT, left_export_fn)
right_export_path = os.path.join(SAVE_PATH_RIGHT_MARKED if right_marked_fn else SAVE_PATH_RIGHT, right_export_fn) right_export_path = os.path.join(SAVE_PATH_RIGHT_MARKED if right_marked_fn else SAVE_PATH_RIGHT, right_export_fn)
# 添加到 ZIP 文件
if os.path.exists(left_export_path): if os.path.exists(left_export_path):
zipf.write(left_export_path, os.path.join('left', left_export_fn)) zipf.write(left_export_path, os.path.join('left', left_export_fn))
if os.path.exists(right_export_path): if os.path.exists(right_export_path):
zipf.write(right_export_path, os.path.join('right', right_export_fn)) zipf.write(right_export_path, os.path.join('right', right_export_fn))
logger.info(f"Exported {len(rows)} image pairs to {temp_zip_path}") logger.info(f"Exported {len(rows)} image pairs to {temp_zip_path}")
# 返回 ZIP 文件给客户端
return send_file(temp_zip_path, as_attachment=True, download_name='exported_images.zip') return send_file(temp_zip_path, as_attachment=True, download_name='exported_images.zip')
except Exception as e: except Exception as e:
@@ -316,26 +267,21 @@ def export_images_api():
@app.route('/upload', methods=['POST']) @app.route('/upload', methods=['POST'])
def upload_images(): def upload_images():
"""接收左右摄像头图片,保存并推送更新,同时生成标注图片""" """接收左右摄像头图片,保存并推送更新,同时生成标注图片"""
try: try:
# 从 multipart/form-data 中获取文件
left_file = request.files.get('left_image') left_file = request.files.get('left_image')
right_file = request.files.get('right_image') right_file = request.files.get('right_image')
metadata_str = request.form.get('metadata') # 如果需要处理元数据 metadata_str = request.form.get('metadata')
comment = request.form.get('comment', '') # 获取 comment 字段 comment = request.form.get('comment', '')
if not left_file or not right_file: if not left_file or not right_file:
logger.warning("Received request without required image files.") logger.warning("Received request without required image files.")
return jsonify({"error": "Missing left_image or right_image"}), 400 return jsonify({"error": "Missing left_image or right_image"}), 400
# 读取图片数据 nparr_left = np.frombuffer(left_file.read(), np.uint8)
left_img_bytes = left_file.read() nparr_right = np.frombuffer(right_file.read(), np.uint8)
right_img_bytes = right_file.read()
# 解码图片用于后续处理 (如显示、保存)
nparr_left = np.frombuffer(left_img_bytes, np.uint8)
nparr_right = np.frombuffer(right_img_bytes, np.uint8)
img_left = cv2.imdecode(nparr_left, cv2.IMREAD_COLOR) img_left = cv2.imdecode(nparr_left, cv2.IMREAD_COLOR)
img_right = cv2.imdecode(nparr_right, cv2.IMREAD_COLOR) img_right = cv2.imdecode(nparr_right, cv2.IMREAD_COLOR)
@@ -343,7 +289,6 @@ def upload_images():
logger.error("Failed to decode received images.") logger.error("Failed to decode received images.")
return jsonify({"error": "Could not decode images"}), 400 return jsonify({"error": "Could not decode images"}), 400
# 解析元数据 (如果提供)
metadata = {} metadata = {}
if metadata_str: if metadata_str:
try: try:
@@ -352,60 +297,26 @@ def upload_images():
logger.warning(f"Could not parse metadata: {e}") logger.warning(f"Could not parse metadata: {e}")
timestamp_str = str(metadata.get("timestamp", str(int(time.time())))) timestamp_str = str(metadata.get("timestamp", str(int(time.time()))))
timestamp_str_safe = timestamp_str.replace(".", "_") # 避免文件名中的点号问题 timestamp_str_safe = timestamp_str.replace(".", "_")
# 生成文件名
left_filename = f"left_{timestamp_str_safe}.jpg" left_filename = f"left_{timestamp_str_safe}.jpg"
right_filename = f"right_{timestamp_str_safe}.jpg" right_filename = f"right_{timestamp_str_safe}.jpg"
left_marked_filename = f"left_marked_{timestamp_str_safe}.jpg" # 标注图片文件名 left_marked_filename = f"left_marked_{timestamp_str_safe}.jpg"
right_marked_filename = f"right_marked_{timestamp_str_safe}.jpg" # 标注图片文件名 right_marked_filename = f"right_marked_{timestamp_str_safe}.jpg"
# 保存原图到本地
left_path = os.path.join(SAVE_PATH_LEFT, left_filename) left_path = os.path.join(SAVE_PATH_LEFT, left_filename)
right_path = os.path.join(SAVE_PATH_RIGHT, right_filename) right_path = os.path.join(SAVE_PATH_RIGHT, right_filename)
# 确保目录存在
os.makedirs(SAVE_PATH_LEFT, exist_ok=True) os.makedirs(SAVE_PATH_LEFT, exist_ok=True)
os.makedirs(SAVE_PATH_RIGHT, exist_ok=True) os.makedirs(SAVE_PATH_RIGHT, exist_ok=True)
os.makedirs(SAVE_PATH_LEFT_MARKED, exist_ok=True) # 创建标注图片目录 os.makedirs(SAVE_PATH_LEFT_MARKED, exist_ok=True)
os.makedirs(SAVE_PATH_RIGHT_MARKED, exist_ok=True) # 创建标注图片目录 os.makedirs(SAVE_PATH_RIGHT_MARKED, exist_ok=True)
cv2.imwrite(left_path, img_left) cv2.imwrite(left_path, img_left)
cv2.imwrite(right_path, img_right) cv2.imwrite(right_path, img_right)
logger.info(f"Saved original images: {left_path}, {right_path}") logger.info(f"Saved original images: {left_path}, {right_path}")
# 使用 VisionAPIClient 处理图片并生成标注图片 # Debug: 直接原图覆盖
left_marked_path = None
right_marked_path = None
try:
# with VisionAPIClient() as client:
# # 处理左图
# left_task_id = client.submit_task(image_id=1, image=img_left)
# # 处理右图
# right_task_id = client.submit_task(image_id=2, image=img_right)
# # 等待任务完成
# client.task_queue.join()
# # 获取处理结果
# left_result = client.get_result(left_task_id)
# right_result = client.get_result(right_task_id)
# # 生成标注图片
# if left_result and left_result.success:
# marked_left_img = draw_detections_on_image(img_left, left_result.detections)
# left_marked_path = os.path.join(SAVE_PATH_LEFT_MARKED, left_marked_filename)
# cv2.imwrite(left_marked_path, marked_left_img)
# logger.info(f"Saved marked left image: {left_marked_path}")
# if right_result and right_result.success:
# marked_right_img = draw_detections_on_image(img_right, right_result.detections)
# right_marked_path = os.path.join(SAVE_PATH_RIGHT_MARKED, right_marked_filename)
# cv2.imwrite(right_marked_path, marked_right_img)
# logger.info(f"Saved marked right image: {right_marked_path}")
# Debug 直接原图覆盖
left_marked_path = os.path.join(SAVE_PATH_LEFT_MARKED, left_marked_filename) left_marked_path = os.path.join(SAVE_PATH_LEFT_MARKED, left_marked_filename)
cv2.imwrite(left_marked_path, img_left) cv2.imwrite(left_marked_path, img_left)
logger.info(f"Saved marked left image: {left_marked_path}") logger.info(f"Saved marked left image: {left_marked_path}")
@@ -414,11 +325,6 @@ def upload_images():
cv2.imwrite(right_marked_path, img_right) cv2.imwrite(right_marked_path, img_right)
logger.info(f"Saved marked right image: {right_marked_path}") logger.info(f"Saved marked right image: {right_marked_path}")
except Exception as e:
logger.error(f"Error processing images with VisionAPIClient: {e}")
# 即使处理失败,也继续保存原图
# 将图片信息写入数据库,包含标注图片字段
conn = sqlite3.connect(DATABASE_PATH) conn = sqlite3.connect(DATABASE_PATH)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(''' cursor.execute('''
@@ -429,17 +335,15 @@ def upload_images():
right_marked_filename if right_marked_path else None, right_marked_filename if right_marked_path else None,
float(timestamp_str), json.dumps(metadata), comment)) float(timestamp_str), json.dumps(metadata), comment))
conn.commit() conn.commit()
image_id = cursor.lastrowid # 获取新插入记录的 ID image_id = cursor.lastrowid
conn.close() conn.close()
logger.info(f"Recorded image pair (ID: {image_id}) in database.") logger.info(f"Recorded image pair (ID: {image_id}) in database.")
# 将 OpenCV 图像编码为 base64 字符串,用于 WebSocket 传输
_, left_encoded = cv2.imencode('.jpg', img_left) _, left_encoded = cv2.imencode('.jpg', img_left)
_, right_encoded = cv2.imencode('.jpg', img_right) _, right_encoded = cv2.imencode('.jpg', img_right)
left_b64 = base64.b64encode(left_encoded).decode('utf-8') left_b64 = base64.b64encode(left_encoded).decode('utf-8')
right_b64 = base64.b64encode(right_encoded).decode('utf-8') right_b64 = base64.b64encode(right_encoded).decode('utf-8')
# 更新用于实时显示的全局变量 (如果需要)
with frame_lock: with frame_lock:
global latest_left_frame, latest_right_frame, latest_timestamp global latest_left_frame, latest_right_frame, latest_timestamp
latest_left_frame = img_left latest_left_frame = img_left
@@ -459,32 +363,32 @@ def upload_images():
return jsonify({"error": str(e)}), 500 return jsonify({"error": str(e)}), 500
@app.route('/api/images/comment', methods=['PUT']) @app.route('/api/images/comment', methods=['PUT'])
@auth.login_required # 保护 API 端点
def update_image_comment(): def update_image_comment():
"""API: 更新图片的 comment""" """API: 更新图片的 comment"""
logger.info(f"User {auth.current_user()} requested to update image comment.")
data = request.json data = request.json
image_id = data.get('id') image_id = data.get('id')
comment = data.get('comment', '') comment = data.get('comment', '')
if not image_id: if not image_id:
return jsonify({"error": "Image ID is required"}), 400 return jsonify({"error": "Image ID is required"}), 400
conn = sqlite3.connect(DATABASE_PATH) conn = sqlite3.connect(DATABASE_PATH)
cursor = conn.cursor() cursor = conn.cursor()
# 更新 comment 字段
cursor.execute("UPDATE images SET comment = ? WHERE id = ?", (comment, image_id)) cursor.execute("UPDATE images SET comment = ? WHERE id = ?", (comment, image_id))
conn.commit() conn.commit()
conn.close() conn.close()
return jsonify({"message": f"Comment for image {image_id} updated successfully"}) return jsonify({"message": f"Comment for image {image_id} updated successfully"})
@app.route('/api/images/position', methods=['PUT']) @app.route('/api/images/position', methods=['PUT'])
@auth.login_required # 保护 API 端点
def update_image_position(): def update_image_position():
"""API: 更新图片的位置编号""" """API: 更新图片的位置编号"""
logger.info(f"User {auth.current_user()} requested to update image position.")
data = request.json data = request.json
image_id = data.get('id') image_id = data.get('id')
left_position = data.get('left_position', 0) left_position = data.get('left_position', 0)
right_position = data.get('right_position', 0) right_position = data.get('right_position', 0)
if not image_id: if not image_id:
return jsonify({"error": "Image ID is required"}), 400 return jsonify({"error": "Image ID is required"}), 400
@@ -492,32 +396,13 @@ def update_image_position():
conn = sqlite3.connect(DATABASE_PATH) conn = sqlite3.connect(DATABASE_PATH)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("UPDATE images SET left_position = ?, right_position = ? WHERE id = ?", (left_position, right_position, image_id))
# --- Removed ALTER TABLE logic as columns are now in init_db ---
# try:
# cursor.execute("""
# ALTER TABLE images ADD COLUMN left_position INTEGER DEFAULT 0
# """)
# except sqlite3.OperationalError:
# pass
#
# try:
# cursor.execute("""
# ALTER TABLE images ADD COLUMN right_position INTEGER DEFAULT 0
# """)
# except sqlite3.OperationalError:
# pass
# --- End Removal ---
# 更新位置字段
cursor.execute("UPDATE images SET left_position = ?, right_position = ? WHERE id = ?",
(left_position, right_position, image_id))
conn.commit() conn.commit()
conn.close() conn.close()
return jsonify({"message": f"Position for image {image_id} updated successfully"}) return jsonify({"message": f"Position for image {image_id} updated successfully"})
@app.route('/status') @app.route('/status')
# @auth.login_required # 通常状态检查接口不需要认证,保持开放
def status(): def status():
with frame_lock: with frame_lock:
has_frames = latest_left_frame is not None and latest_right_frame is not None has_frames = latest_left_frame is not None and latest_right_frame is not None
@@ -525,86 +410,44 @@ def status():
return jsonify({"has_frames": has_frames, "latest_timestamp": timestamp}) return jsonify({"has_frames": has_frames, "latest_timestamp": timestamp})
@app.route('/api/images/manual-detections', methods=['PUT']) @app.route('/api/images/manual-detections', methods=['PUT'])
@auth.login_required # 保护 API 端点
def update_manual_detections(): def update_manual_detections():
"""API: 更新图片的人工标注检测框结果,支持左右图像分别标注""" """API: 更新图片的人工标注检测框结果,支持左右图像分别标注"""
# logger.info(f"User {auth.current_user()} requested to update manual detections.")
data = request.json data = request.json
image_id = data.get('id') image_id = data.get('id')
side = data.get('side', 'left') # 获取 side 参数,默认为左侧 side = data.get('side', 'left')
detections = data.get('detections') detections = data.get('detections')
if not image_id or detections is None: if not image_id or detections is None:
return jsonify({"error": "Image ID and detections are required"}), 400 return jsonify({"error": "Image ID and detections are required"}), 400
# 验证检测数据格式
if not isinstance(detections, list): if not isinstance(detections, list):
return jsonify({"error": "Detections must be a list"}), 400 return jsonify({"error": "Detections must be a list"}), 400
for detection in detections: for detection in detections:
if not isinstance(detection, dict): if not isinstance(detection, dict):
return jsonify({"error": "Each detection must be a dictionary"}), 400 return jsonify({"error": "Each detection must be a dictionary"}), 400
required_keys = ['id', 'label', 'bbox'] required_keys = ['id', 'label', 'bbox']
for key in required_keys: for key in required_keys:
if key not in detection: if key not in detection:
return jsonify({"error": f"Missing required key '{key}' in detection"}), 400 return jsonify({"error": f"Missing required key '{key}' in detection"}), 400
# 验证 ID
if not isinstance(detection['id'], int) or detection['id'] not in [1, 2, 3, 4]: if not isinstance(detection['id'], int) or detection['id'] not in [1, 2, 3, 4]:
return jsonify({"error": f"Invalid ID in detection: {detection['id']}"}), 400 return jsonify({"error": f"Invalid ID in detection: {detection['id']}"}), 400
# 验证标签
valid_labels = ['caisson', 'soldier', 'gun', 'number'] valid_labels = ['caisson', 'soldier', 'gun', 'number']
if detection['label'] not in valid_labels: if detection['label'] not in valid_labels:
return jsonify({"error": f"Invalid label in detection: {detection['label']}"}), 400 return jsonify({"error": f"Invalid label in detection: {detection['label']}"}), 400
# 验证边界框
bbox = detection['bbox'] bbox = detection['bbox']
if not isinstance(bbox, list) or len(bbox) != 4: if not isinstance(bbox, list) or len(bbox) != 4:
return jsonify({"error": f"Invalid bbox format in detection"}), 400 return jsonify({"error": f"Invalid bbox format in detection"}), 400
for coord in bbox: for coord in bbox:
if not isinstance(coord, int) or not (0 <= coord <= 999): if not isinstance(coord, int) or not (0 <= coord <= 999):
return jsonify({"error": f"Invalid bbox coordinate: {coord}"}), 400 return jsonify({"error": f"Invalid bbox coordinate: {coord}"}), 400
conn = sqlite3.connect(DATABASE_PATH) conn = sqlite3.connect(DATABASE_PATH)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT id FROM images WHERE id = ?", (image_id,)) cursor.execute("SELECT id FROM images WHERE id = ?", (image_id,))
if not cursor.fetchone(): if not cursor.fetchone():
conn.close() conn.close()
return jsonify({"error": "Image not found"}), 404 return jsonify({"error": "Image not found"}), 404
# --- Removed ALTER TABLE logic as columns are now in init_db ---
# try:
# cursor.execute("""
# ALTER TABLE images ADD COLUMN manual_detections_left TEXT
# """)
# except sqlite3.OperationalError:
# pass
#
# try:
# cursor.execute("""
# ALTER TABLE images ADD COLUMN manual_detections_right TEXT
# """)
# except sqlite3.OperationalError:
# pass
#
# try:
# cursor.execute("""
# ALTER TABLE images ADD COLUMN is_manual_labeled_left INTEGER DEFAULT 0
# """)
# except sqlite3.OperationalError:
# pass
#
# try:
# cursor.execute("""
# ALTER TABLE images ADD COLUMN is_manual_labeled_right INTEGER DEFAULT 0
# """)
# except sqlite3.OperationalError:
# pass
# --- End Removal ---
# 根据 side 参数更新对应的人工标注结果
if side == 'left': if side == 'left':
cursor.execute(""" cursor.execute("""
UPDATE images UPDATE images
@@ -621,7 +464,6 @@ def update_manual_detections():
conn.commit() conn.commit()
conn.close() conn.close()
# 重新生成标注图片
try: try:
regenerate_marked_images(image_id, detections, side) regenerate_marked_images(image_id, detections, side)
return jsonify({"message": f"Manual detections for image {image_id} ({side}) updated successfully and marked images regenerated"}) return jsonify({"message": f"Manual detections for image {image_id} ({side}) updated successfully and marked images regenerated"})
@@ -639,17 +481,13 @@ def regenerate_marked_images(image_id, detections, side):
""", (image_id,)) """, (image_id,))
row = cursor.fetchone() row = cursor.fetchone()
conn.close() conn.close()
if not row: if not row:
raise Exception("Image not found") raise Exception("Image not found")
left_filename, right_filename, left_marked_filename, right_marked_filename, left_position, right_position = row left_filename, right_filename, left_marked_filename, right_marked_filename, left_position, right_position = row
# 根据指定的 side 重新生成对应的标注图片
if side == 'left' and left_marked_filename: if side == 'left' and left_marked_filename:
left_path = os.path.join(SAVE_PATH_LEFT, left_filename) left_path = os.path.join(SAVE_PATH_LEFT, left_filename)
left_marked_path = os.path.join(SAVE_PATH_LEFT_MARKED, left_marked_filename) left_marked_path = os.path.join(SAVE_PATH_LEFT_MARKED, left_marked_filename)
if os.path.exists(left_path): if os.path.exists(left_path):
img_left = cv2.imread(left_path) img_left = cv2.imread(left_path)
if img_left is not None: if img_left is not None:
@@ -659,7 +497,6 @@ def regenerate_marked_images(image_id, detections, side):
elif side == 'right' and right_marked_filename: elif side == 'right' and right_marked_filename:
right_path = os.path.join(SAVE_PATH_RIGHT, right_filename) right_path = os.path.join(SAVE_PATH_RIGHT, right_filename)
right_marked_path = os.path.join(SAVE_PATH_RIGHT_MARKED, right_marked_filename) right_marked_path = os.path.join(SAVE_PATH_RIGHT_MARKED, right_marked_filename)
if os.path.exists(right_path): if os.path.exists(right_path):
img_right = cv2.imread(right_path) img_right = cv2.imread(right_path)
if img_right is not None: if img_right is not None:
@@ -667,8 +504,10 @@ def regenerate_marked_images(image_id, detections, side):
cv2.imwrite(right_marked_path, marked_right_img) cv2.imwrite(right_marked_path, marked_right_img)
@app.route('/manual-annotation') @app.route('/manual-annotation')
@auth.login_required # 需要认证才能访问标注页
def manual_annotation(): def manual_annotation():
"""标注页面""" """标注页面"""
logger.info(f"User {auth.current_user()} accessed the manual annotation page.")
return render_template('manual_annotation.html') return render_template('manual_annotation.html')
@app.route('/view') @app.route('/view')
@@ -698,7 +537,8 @@ def capture_button(data):
except Exception as e: except Exception as e:
logger.error(f"Error sending request: {e}") logger.error(f"Error sending request: {e}")
if __name__ == '__main__': if __name__ == '__main__':
logger.info(f"Starting Flask-SocketIO server on {FLASK_HOST}:{FLASK_PORT}") logger.info(f"Starting Flask-SocketIO server on {FLASK_HOST}:{FLASK_PORT}")
os.environ['BASIC_AUTH_USERNAME'] = USERNAME
os.environ['BASIC_AUTH_PASSWORD'] = PASSWORD
socketio.run(app, host=FLASK_HOST, port=FLASK_PORT, debug=False, allow_unsafe_werkzeug=True) socketio.run(app, host=FLASK_HOST, port=FLASK_PORT, debug=False, allow_unsafe_werkzeug=True)