feat: 增加标注相关功能

This commit is contained in:
2025-10-26 14:25:30 +08:00
parent c8dfec6cf4
commit 856669de69
4 changed files with 795 additions and 42 deletions

View File

@@ -14,6 +14,9 @@ from datetime import datetime
import zipfile
import tempfile
# 导入视觉处理相关的模块
from llm_req import VisionAPIClient, DetectionResult
from cap_trigger import ImageClient
@@ -21,6 +24,8 @@ from cap_trigger import ImageClient
# --- 配置 ---
SAVE_PATH_LEFT = "./static/received/left"
SAVE_PATH_RIGHT = "./static/received/right"
SAVE_PATH_LEFT_MARKED = "./static/received/left_marked" # 标注图片保存路径
SAVE_PATH_RIGHT_MARKED = "./static/received/right_marked" # 标注图片保存路径
FLASK_HOST = "0.0.0.0"
FLASK_PORT = 5000
MAX_LIVE_FRAMES = 2 # 保留最新的几帧用于实时显示
@@ -36,12 +41,14 @@ def init_db():
"""初始化 SQLite 数据库和表"""
conn = sqlite3.connect(DATABASE_PATH)
cursor = conn.cursor()
# 创建表,存储图片信息
# 创建表,存储图片信息,添加标注图片字段
cursor.execute('''
CREATE TABLE IF NOT EXISTS images (
id INTEGER PRIMARY KEY AUTOINCREMENT,
left_filename TEXT NOT NULL,
right_filename TEXT NOT NULL,
left_marked_filename TEXT,
right_marked_filename TEXT,
timestamp REAL NOT NULL,
metadata TEXT,
comment TEXT,
@@ -72,6 +79,51 @@ image_client = ImageClient("tcp://127.0.0.1:54321", client_id="local")
# 初始化数据库
init_db()
# --- 辅助函数 ---
def draw_detections_on_image(image: np.ndarray, detections: list) -> np.ndarray:
"""在图像上绘制检测框"""
# 复制原图以避免修改原始图像
marked_image = image.copy()
# 定义颜色映射
color_map = {
1: (0, 255, 0), # 绿色 - 弹药箱
2: (255, 0, 0), # 蓝色 - 士兵
3: (0, 0, 255), # 红色 - 枪支
4: (255, 255, 0) # 青色 - 数字牌
}
# 获取图像尺寸
h, w = image.shape[:2]
# 绘制每个检测框
for detection in detections:
# 获取检测信息
obj_id = detection.get("id", 0)
label = detection.get("label", "")
bbox = detection.get("bbox", [])
if len(bbox) == 4:
# 将归一化的坐标转换为实际像素坐标
x_min = int(bbox[0] * w / 999)
y_min = int(bbox[1] * h / 999)
x_max = int(bbox[2] * w / 999)
y_max = int(bbox[3] * h / 999)
# 获取颜色
color = color_map.get(obj_id, (255, 255, 255)) # 默认白色
# 绘制边界框
cv2.rectangle(marked_image, (x_min, y_min), (x_max, y_max), color, 2)
# 添加标签
label_text = f"{label} ({obj_id})"
cv2.putText(marked_image, label_text, (x_min, y_min - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
return marked_image
# --- Flask 路由 ---
@app.route('/')
@@ -89,8 +141,8 @@ def get_images_api():
"""API: 获取图片列表 (JSON 格式)"""
conn = sqlite3.connect(DATABASE_PATH)
cursor = conn.cursor()
# 按时间倒序排列
cursor.execute("SELECT id, left_filename, right_filename, timestamp, metadata, comment, created_at FROM images ORDER BY timestamp DESC")
# 按时间倒序排列,包含标注图片字段
cursor.execute("SELECT id, left_filename, right_filename, left_marked_filename, right_marked_filename, timestamp, metadata, comment, created_at FROM images ORDER BY timestamp DESC")
rows = cursor.fetchall()
conn.close()
@@ -100,10 +152,12 @@ def get_images_api():
"id": row[0],
"left_filename": row[1],
"right_filename": row[2],
"timestamp": row[3],
"metadata": row[4],
"comment": row[5] or "", # 如果没有comment则显示空字符串
"created_at": row[6]
"left_marked_filename": row[3] or "", # 如果没有标注图片则显示空字符串
"right_marked_filename": row[4] or "", # 如果没有标注图片则显示空字符串
"timestamp": row[5],
"metadata": row[6],
"comment": row[7] or "", # 如果没有 comment 则显示空字符串
"created_at": row[8]
})
return jsonify(images)
@@ -116,22 +170,25 @@ def delete_image_api():
conn = sqlite3.connect(DATABASE_PATH)
cursor = conn.cursor()
# 查询文件名
cursor.execute("SELECT left_filename, right_filename FROM images WHERE id = ?", (image_id,))
# 查询文件名,包含标注图片
cursor.execute("SELECT left_filename, right_filename, left_marked_filename, right_marked_filename FROM images WHERE id = ?", (image_id,))
row = cursor.fetchone()
if not row:
conn.close()
return jsonify({"error": "Image not found"}), 404
left_filename, right_filename = row
left_filename, right_filename, left_marked_filename, right_marked_filename = row
# 删除数据库记录
cursor.execute("DELETE FROM images WHERE id = ?", (image_id,))
conn.commit()
conn.close()
# 删除对应的文件
# 删除对应的文件,包括标注图片
left_path = os.path.join(SAVE_PATH_LEFT, left_filename)
right_path = os.path.join(SAVE_PATH_RIGHT, right_filename)
left_marked_path = os.path.join(SAVE_PATH_LEFT_MARKED, left_marked_filename) if left_marked_filename else None
right_marked_path = os.path.join(SAVE_PATH_RIGHT_MARKED, right_marked_filename) if right_marked_filename else None
try:
if os.path.exists(left_path):
os.remove(left_path)
@@ -139,6 +196,12 @@ def delete_image_api():
if os.path.exists(right_path):
os.remove(right_path)
logger.info(f"Deleted file: {right_path}")
if left_marked_path and os.path.exists(left_marked_path):
os.remove(left_marked_path)
logger.info(f"Deleted file: {left_marked_path}")
if right_marked_path and os.path.exists(right_marked_path):
os.remove(right_marked_path)
logger.info(f"Deleted file: {right_marked_path}")
except OSError as e:
logger.error(f"Error deleting files: {e}")
# 即使删除文件失败,数据库记录也已删除,返回成功
@@ -148,7 +211,7 @@ def delete_image_api():
@app.route('/api/images/export', methods=['POST'])
def export_images_api():
"""API: 打包导出选中的图片"""
"""API: 打包导出选中的图片,优先导出标注图片"""
selected_ids = request.json.get('ids', [])
if not selected_ids:
return jsonify({"error": "No image IDs selected"}), 400
@@ -156,7 +219,8 @@ def export_images_api():
conn = sqlite3.connect(DATABASE_PATH)
cursor = conn.cursor()
placeholders = ','.join('?' * len(selected_ids))
cursor.execute(f"SELECT left_filename, right_filename FROM images WHERE id IN ({placeholders})", selected_ids)
# 查询包含标注图片的文件名
cursor.execute(f"SELECT left_filename, right_filename, left_marked_filename, right_marked_filename FROM images WHERE id IN ({placeholders})", selected_ids)
rows = cursor.fetchall()
conn.close()
@@ -169,13 +233,20 @@ def export_images_api():
try:
with zipfile.ZipFile(temp_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
for left_fn, right_fn in rows:
left_path = os.path.join(SAVE_PATH_LEFT, left_fn)
right_path = os.path.join(SAVE_PATH_RIGHT, right_fn)
if os.path.exists(left_path):
zipf.write(left_path, os.path.join('left', left_fn))
if os.path.exists(right_path):
zipf.write(right_path, os.path.join('right', right_fn))
for left_fn, right_fn, left_marked_fn, right_marked_fn in rows:
# 优先使用标注图片,如果没有则使用原图
left_export_fn = left_marked_fn if left_marked_fn else left_fn
right_export_fn = right_marked_fn if right_marked_fn else right_fn
# 确定文件路径
left_export_path = os.path.join(SAVE_PATH_LEFT_MARKED if left_marked_fn else SAVE_PATH_LEFT, left_export_fn)
right_export_path = os.path.join(SAVE_PATH_RIGHT_MARKED if right_marked_fn else SAVE_PATH_RIGHT, right_export_fn)
# 添加到 ZIP 文件
if os.path.exists(left_export_path):
zipf.write(left_export_path, os.path.join('left', left_export_fn))
if os.path.exists(right_export_path):
zipf.write(right_export_path, os.path.join('right', right_export_fn))
logger.info(f"Exported {len(rows)} image pairs to {temp_zip_path}")
# 返回 ZIP 文件给客户端
@@ -190,13 +261,13 @@ def export_images_api():
@app.route('/upload', methods=['POST'])
def upload_images():
"""接收左右摄像头图片,保存并推送更新"""
"""接收左右摄像头图片,保存并推送更新,同时生成标注图片"""
try:
# 从 multipart/form-data 中获取文件
left_file = request.files.get('left_image')
right_file = request.files.get('right_image')
metadata_str = request.form.get('metadata') # 如果需要处理元数据
comment = request.form.get('comment', '') # 获取comment字段
comment = request.form.get('comment', '') # 获取 comment 字段
if not left_file or not right_file:
logger.warning("Received request without required image files.")
@@ -230,26 +301,67 @@ def upload_images():
# 生成文件名
left_filename = f"left_{timestamp_str_safe}.jpg"
right_filename = f"right_{timestamp_str_safe}.jpg"
left_marked_filename = f"left_marked_{timestamp_str_safe}.jpg" # 标注图片文件名
right_marked_filename = f"right_marked_{timestamp_str_safe}.jpg" # 标注图片文件名
# 保存图到本地
# 保存图到本地
left_path = os.path.join(SAVE_PATH_LEFT, left_filename)
right_path = os.path.join(SAVE_PATH_RIGHT, right_filename)
# 确保目录存在
os.makedirs(SAVE_PATH_LEFT, exist_ok=True)
os.makedirs(SAVE_PATH_RIGHT, exist_ok=True)
os.makedirs(SAVE_PATH_LEFT_MARKED, exist_ok=True) # 创建标注图片目录
os.makedirs(SAVE_PATH_RIGHT_MARKED, exist_ok=True) # 创建标注图片目录
cv2.imwrite(left_path, img_left)
cv2.imwrite(right_path, img_right)
logger.info(f"Saved images: {left_path}, {right_path}")
logger.info(f"Saved original images: {left_path}, {right_path}")
# 将图片信息写入数据库
# 使用 VisionAPIClient 处理图片并生成标注图片
left_marked_path = None
right_marked_path = None
try:
with VisionAPIClient() as client:
# 处理左图
left_task_id = client.submit_task(image_id=1, image=img_left)
# 处理右图
right_task_id = client.submit_task(image_id=2, image=img_right)
# 等待任务完成
client.task_queue.join()
# 获取处理结果
left_result = client.get_result(left_task_id)
right_result = client.get_result(right_task_id)
# 生成标注图片
if left_result and left_result.success:
marked_left_img = draw_detections_on_image(img_left, left_result.detections)
left_marked_path = os.path.join(SAVE_PATH_LEFT_MARKED, left_marked_filename)
cv2.imwrite(left_marked_path, marked_left_img)
logger.info(f"Saved marked left image: {left_marked_path}")
if right_result and right_result.success:
marked_right_img = draw_detections_on_image(img_right, right_result.detections)
right_marked_path = os.path.join(SAVE_PATH_RIGHT_MARKED, right_marked_filename)
cv2.imwrite(right_marked_path, marked_right_img)
logger.info(f"Saved marked right image: {right_marked_path}")
except Exception as e:
logger.error(f"Error processing images with VisionAPIClient: {e}")
# 即使处理失败,也继续保存原图
# 将图片信息写入数据库,包含标注图片字段
conn = sqlite3.connect(DATABASE_PATH)
cursor = conn.cursor()
cursor.execute('''
INSERT INTO images (left_filename, right_filename, timestamp, metadata, comment)
VALUES (?, ?, ?, ?, ?)
''', (left_filename, right_filename, float(timestamp_str), json.dumps(metadata), comment))
INSERT INTO images (left_filename, right_filename, left_marked_filename, right_marked_filename, timestamp, metadata, comment)
VALUES (?, ?, ?, ?, ?, ?, ?)
''', (left_filename, right_filename,
left_marked_filename if left_marked_path else None,
right_marked_filename if right_marked_path else None,
float(timestamp_str), json.dumps(metadata), comment))
conn.commit()
image_id = cursor.lastrowid # 获取新插入记录的 ID
conn.close()
@@ -282,7 +394,7 @@ def upload_images():
@app.route('/api/images/comment', methods=['PUT'])
def update_image_comment():
"""API: 更新图片的comment"""
"""API: 更新图片的 comment"""
data = request.json
image_id = data.get('id')
comment = data.get('comment', '')
@@ -292,7 +404,7 @@ def update_image_comment():
conn = sqlite3.connect(DATABASE_PATH)
cursor = conn.cursor()
# 更新comment字段
# 更新 comment 字段
cursor.execute("UPDATE images SET comment = ? WHERE id = ?", (comment, image_id))
conn.commit()
conn.close()