feat: 增加标注相关功能

This commit is contained in:
2025-10-26 14:25:30 +08:00
parent c8dfec6cf4
commit 856669de69
4 changed files with 795 additions and 42 deletions

View File

@@ -14,6 +14,9 @@ from datetime import datetime
import zipfile
import tempfile
# 导入视觉处理相关的模块
from llm_req import VisionAPIClient, DetectionResult
from cap_trigger import ImageClient
@@ -21,6 +24,8 @@ from cap_trigger import ImageClient
# --- 配置 ---
SAVE_PATH_LEFT = "./static/received/left"
SAVE_PATH_RIGHT = "./static/received/right"
SAVE_PATH_LEFT_MARKED = "./static/received/left_marked" # 标注图片保存路径
SAVE_PATH_RIGHT_MARKED = "./static/received/right_marked" # 标注图片保存路径
FLASK_HOST = "0.0.0.0"
FLASK_PORT = 5000
MAX_LIVE_FRAMES = 2 # 保留最新的几帧用于实时显示
@@ -36,12 +41,14 @@ def init_db():
"""初始化 SQLite 数据库和表"""
conn = sqlite3.connect(DATABASE_PATH)
cursor = conn.cursor()
# 创建表,存储图片信息
# 创建表,存储图片信息,添加标注图片字段
cursor.execute('''
CREATE TABLE IF NOT EXISTS images (
id INTEGER PRIMARY KEY AUTOINCREMENT,
left_filename TEXT NOT NULL,
right_filename TEXT NOT NULL,
left_marked_filename TEXT,
right_marked_filename TEXT,
timestamp REAL NOT NULL,
metadata TEXT,
comment TEXT,
@@ -72,6 +79,51 @@ image_client = ImageClient("tcp://127.0.0.1:54321", client_id="local")
# 初始化数据库
init_db()
# --- 辅助函数 ---
def draw_detections_on_image(image: np.ndarray, detections: list) -> np.ndarray:
    """Return a copy of ``image`` with a labelled box drawn for each detection.

    Args:
        image: Image array; its first two ``shape`` entries are used as
            (height, width).
        detections: Dicts that may carry ``"id"``, ``"label"`` and ``"bbox"``.
            ``bbox`` is ``[x_min, y_min, x_max, y_max]`` on a 0-999 scale and
            is rescaled to pixels here. Entries whose ``bbox`` is missing or
            does not hold exactly four values are skipped.

    Returns:
        The annotated copy; ``image`` itself is left untouched.
    """
    # Draw on a copy so the caller's original image stays unmodified.
    marked_image = image.copy()

    # Colour per category id (tuples are in OpenCV's BGR channel order).
    color_map = {
        1: (0, 255, 0),    # green - ammunition box
        2: (255, 0, 0),    # blue  - soldier
        3: (0, 0, 255),    # red   - gun
        4: (255, 255, 0),  # cyan  - number plate
    }

    h, w = image.shape[:2]

    def to_pixel(value, size):
        # Rescale from the 0-999 range and clamp, so an out-of-range value
        # cannot push the box or its label outside the canvas.
        return max(0, min(int(value * size / 999), size - 1))

    for detection in detections:
        obj_id = detection.get("id", 0)
        label = detection.get("label", "")
        bbox = detection.get("bbox") or []  # tolerate an explicit None
        if len(bbox) != 4:
            continue

        # Sort each axis so the label anchor is always the top-left corner,
        # even if the two corners arrive swapped.
        x_min, x_max = sorted((to_pixel(bbox[0], w), to_pixel(bbox[2], w)))
        y_min, y_max = sorted((to_pixel(bbox[1], h), to_pixel(bbox[3], h)))

        color = color_map.get(obj_id, (255, 255, 255))  # white for unknown ids
        cv2.rectangle(marked_image, (x_min, y_min), (x_max, y_max), color, 2)

        label_text = f"{label} ({obj_id})"
        font = cv2.FONT_HERSHEY_SIMPLEX
        (_, text_h), _ = cv2.getTextSize(label_text, font, 0.5, 1)
        # Prefer the spot above the box; when the box touches the top edge
        # there is no room there, so move the label just inside the box.
        text_y = y_min - 10
        if text_y < text_h:
            text_y = min(y_min + text_h + 10, h - 1)
        cv2.putText(marked_image, label_text, (x_min, text_y),
                    font, 0.5, color, 1)

    return marked_image
# --- Flask 路由 ---
@app.route('/')
@@ -89,8 +141,8 @@ def get_images_api():
"""API: 获取图片列表 (JSON 格式)"""
conn = sqlite3.connect(DATABASE_PATH)
cursor = conn.cursor()
# 按时间倒序排列
cursor.execute("SELECT id, left_filename, right_filename, timestamp, metadata, comment, created_at FROM images ORDER BY timestamp DESC")
# 按时间倒序排列,包含标注图片字段
cursor.execute("SELECT id, left_filename, right_filename, left_marked_filename, right_marked_filename, timestamp, metadata, comment, created_at FROM images ORDER BY timestamp DESC")
rows = cursor.fetchall()
conn.close()
@@ -100,10 +152,12 @@ def get_images_api():
"id": row[0],
"left_filename": row[1],
"right_filename": row[2],
"timestamp": row[3],
"metadata": row[4],
"comment": row[5] or "", # 如果没有comment则显示空字符串
"created_at": row[6]
"left_marked_filename": row[3] or "", # 如果没有标注图片则显示空字符串
"right_marked_filename": row[4] or "", # 如果没有标注图片则显示空字符串
"timestamp": row[5],
"metadata": row[6],
"comment": row[7] or "", # 如果没有 comment 则显示空字符串
"created_at": row[8]
})
return jsonify(images)
@@ -116,22 +170,25 @@ def delete_image_api():
conn = sqlite3.connect(DATABASE_PATH)
cursor = conn.cursor()
# 查询文件名
cursor.execute("SELECT left_filename, right_filename FROM images WHERE id = ?", (image_id,))
# 查询文件名,包含标注图片
cursor.execute("SELECT left_filename, right_filename, left_marked_filename, right_marked_filename FROM images WHERE id = ?", (image_id,))
row = cursor.fetchone()
if not row:
conn.close()
return jsonify({"error": "Image not found"}), 404
left_filename, right_filename = row
left_filename, right_filename, left_marked_filename, right_marked_filename = row
# 删除数据库记录
cursor.execute("DELETE FROM images WHERE id = ?", (image_id,))
conn.commit()
conn.close()
# 删除对应的文件
# 删除对应的文件,包括标注图片
left_path = os.path.join(SAVE_PATH_LEFT, left_filename)
right_path = os.path.join(SAVE_PATH_RIGHT, right_filename)
left_marked_path = os.path.join(SAVE_PATH_LEFT_MARKED, left_marked_filename) if left_marked_filename else None
right_marked_path = os.path.join(SAVE_PATH_RIGHT_MARKED, right_marked_filename) if right_marked_filename else None
try:
if os.path.exists(left_path):
os.remove(left_path)
@@ -139,6 +196,12 @@ def delete_image_api():
if os.path.exists(right_path):
os.remove(right_path)
logger.info(f"Deleted file: {right_path}")
if left_marked_path and os.path.exists(left_marked_path):
os.remove(left_marked_path)
logger.info(f"Deleted file: {left_marked_path}")
if right_marked_path and os.path.exists(right_marked_path):
os.remove(right_marked_path)
logger.info(f"Deleted file: {right_marked_path}")
except OSError as e:
logger.error(f"Error deleting files: {e}")
# 即使删除文件失败,数据库记录也已删除,返回成功
@@ -148,7 +211,7 @@ def delete_image_api():
@app.route('/api/images/export', methods=['POST'])
def export_images_api():
"""API: 打包导出选中的图片"""
"""API: 打包导出选中的图片,优先导出标注图片"""
selected_ids = request.json.get('ids', [])
if not selected_ids:
return jsonify({"error": "No image IDs selected"}), 400
@@ -156,7 +219,8 @@ def export_images_api():
conn = sqlite3.connect(DATABASE_PATH)
cursor = conn.cursor()
placeholders = ','.join('?' * len(selected_ids))
cursor.execute(f"SELECT left_filename, right_filename FROM images WHERE id IN ({placeholders})", selected_ids)
# 查询包含标注图片的文件名
cursor.execute(f"SELECT left_filename, right_filename, left_marked_filename, right_marked_filename FROM images WHERE id IN ({placeholders})", selected_ids)
rows = cursor.fetchall()
conn.close()
@@ -169,13 +233,20 @@ def export_images_api():
try:
with zipfile.ZipFile(temp_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
for left_fn, right_fn in rows:
left_path = os.path.join(SAVE_PATH_LEFT, left_fn)
right_path = os.path.join(SAVE_PATH_RIGHT, right_fn)
if os.path.exists(left_path):
zipf.write(left_path, os.path.join('left', left_fn))
if os.path.exists(right_path):
zipf.write(right_path, os.path.join('right', right_fn))
for left_fn, right_fn, left_marked_fn, right_marked_fn in rows:
# 优先使用标注图片,如果没有则使用原图
left_export_fn = left_marked_fn if left_marked_fn else left_fn
right_export_fn = right_marked_fn if right_marked_fn else right_fn
# 确定文件路径
left_export_path = os.path.join(SAVE_PATH_LEFT_MARKED if left_marked_fn else SAVE_PATH_LEFT, left_export_fn)
right_export_path = os.path.join(SAVE_PATH_RIGHT_MARKED if right_marked_fn else SAVE_PATH_RIGHT, right_export_fn)
# 添加到 ZIP 文件
if os.path.exists(left_export_path):
zipf.write(left_export_path, os.path.join('left', left_export_fn))
if os.path.exists(right_export_path):
zipf.write(right_export_path, os.path.join('right', right_export_fn))
logger.info(f"Exported {len(rows)} image pairs to {temp_zip_path}")
# 返回 ZIP 文件给客户端
@@ -190,13 +261,13 @@ def export_images_api():
@app.route('/upload', methods=['POST'])
def upload_images():
"""接收左右摄像头图片,保存并推送更新"""
"""接收左右摄像头图片,保存并推送更新,同时生成标注图片"""
try:
# 从 multipart/form-data 中获取文件
left_file = request.files.get('left_image')
right_file = request.files.get('right_image')
metadata_str = request.form.get('metadata') # 如果需要处理元数据
comment = request.form.get('comment', '') # 获取comment字段
comment = request.form.get('comment', '') # 获取 comment 字段
if not left_file or not right_file:
logger.warning("Received request without required image files.")
@@ -230,26 +301,67 @@ def upload_images():
# 生成文件名
left_filename = f"left_{timestamp_str_safe}.jpg"
right_filename = f"right_{timestamp_str_safe}.jpg"
left_marked_filename = f"left_marked_{timestamp_str_safe}.jpg" # 标注图片文件名
right_marked_filename = f"right_marked_{timestamp_str_safe}.jpg" # 标注图片文件名
# 保存图到本地
# 保存图到本地
left_path = os.path.join(SAVE_PATH_LEFT, left_filename)
right_path = os.path.join(SAVE_PATH_RIGHT, right_filename)
# 确保目录存在
os.makedirs(SAVE_PATH_LEFT, exist_ok=True)
os.makedirs(SAVE_PATH_RIGHT, exist_ok=True)
os.makedirs(SAVE_PATH_LEFT_MARKED, exist_ok=True) # 创建标注图片目录
os.makedirs(SAVE_PATH_RIGHT_MARKED, exist_ok=True) # 创建标注图片目录
cv2.imwrite(left_path, img_left)
cv2.imwrite(right_path, img_right)
logger.info(f"Saved images: {left_path}, {right_path}")
logger.info(f"Saved original images: {left_path}, {right_path}")
# 将图片信息写入数据库
# 使用 VisionAPIClient 处理图片并生成标注图片
left_marked_path = None
right_marked_path = None
try:
with VisionAPIClient() as client:
# 处理左图
left_task_id = client.submit_task(image_id=1, image=img_left)
# 处理右图
right_task_id = client.submit_task(image_id=2, image=img_right)
# 等待任务完成
client.task_queue.join()
# 获取处理结果
left_result = client.get_result(left_task_id)
right_result = client.get_result(right_task_id)
# 生成标注图片
if left_result and left_result.success:
marked_left_img = draw_detections_on_image(img_left, left_result.detections)
left_marked_path = os.path.join(SAVE_PATH_LEFT_MARKED, left_marked_filename)
cv2.imwrite(left_marked_path, marked_left_img)
logger.info(f"Saved marked left image: {left_marked_path}")
if right_result and right_result.success:
marked_right_img = draw_detections_on_image(img_right, right_result.detections)
right_marked_path = os.path.join(SAVE_PATH_RIGHT_MARKED, right_marked_filename)
cv2.imwrite(right_marked_path, marked_right_img)
logger.info(f"Saved marked right image: {right_marked_path}")
except Exception as e:
logger.error(f"Error processing images with VisionAPIClient: {e}")
# 即使处理失败,也继续保存原图
# 将图片信息写入数据库,包含标注图片字段
conn = sqlite3.connect(DATABASE_PATH)
cursor = conn.cursor()
cursor.execute('''
INSERT INTO images (left_filename, right_filename, timestamp, metadata, comment)
VALUES (?, ?, ?, ?, ?)
''', (left_filename, right_filename, float(timestamp_str), json.dumps(metadata), comment))
INSERT INTO images (left_filename, right_filename, left_marked_filename, right_marked_filename, timestamp, metadata, comment)
VALUES (?, ?, ?, ?, ?, ?, ?)
''', (left_filename, right_filename,
left_marked_filename if left_marked_path else None,
right_marked_filename if right_marked_path else None,
float(timestamp_str), json.dumps(metadata), comment))
conn.commit()
image_id = cursor.lastrowid # 获取新插入记录的 ID
conn.close()
@@ -282,7 +394,7 @@ def upload_images():
@app.route('/api/images/comment', methods=['PUT'])
def update_image_comment():
"""API: 更新图片的comment"""
"""API: 更新图片的 comment"""
data = request.json
image_id = data.get('id')
comment = data.get('comment', '')
@@ -292,7 +404,7 @@ def update_image_comment():
conn = sqlite3.connect(DATABASE_PATH)
cursor = conn.cursor()
# 更新comment字段
# 更新 comment 字段
cursor.execute("UPDATE images SET comment = ? WHERE id = ?", (comment, image_id))
conn.commit()
conn.close()

View File

@@ -4,7 +4,7 @@ import uuid
import time
class ImageClient:
def __init__(self, server_address="tcp://10.42.70.1:54321", client_id=None):
def __init__(self, server_address="tcp://:54321", client_id=None):
self.server_address = server_address
self.client_id = client_id or f"client_{uuid.uuid4().hex[:8]}"
self.socket = None
@@ -43,7 +43,7 @@ class ImageClient:
"""发送同步请求到服务器"""
socket = None
try:
# 每次都创建一个新的socket实例以避免复用已关闭的socket
# 每次都创建一个新的 socket 实例以避免复用已关闭的 socket
socket = pynng.Req0()
socket.dial(self.server_address, block=True) # 使用阻塞模式确保连接建立
client_id_bytes = self.client_id.encode('utf-8')
@@ -54,7 +54,7 @@ class ImageClient:
print(f"Client {self.client_id} error: {e}")
return None
finally:
# 确保socket在使用后被正确关闭
# 确保 socket 在使用后被正确关闭
if socket:
try:
socket.close()

610
llm_req.py Normal file
View File

@@ -0,0 +1,610 @@
import asyncio
import base64
import json
import os
import queue
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum
from typing import List, Dict, Any, Optional, Callable

import cv2
import numpy as np
from openai import OpenAI
# Configuration
# SECURITY: the API key is read from the environment instead of being
# committed to the repository. A key that has already been pushed in source
# control should be treated as leaked and rotated. When the variable is unset
# the key is an empty string and API requests will fail authentication.
API_KEY = os.environ.get("DASHSCOPE_API_KEY", "")
BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
MODEL_NAME = "qwen-vl-plus"

# Category id -> label name expected in the model's JSON output.
CATEGORY_MAPPING = {
    1: "caisson",
    2: "soldier",
    3: "gun",
    4: "number",
}

# Category id -> drawing colour (tuples are in OpenCV's BGR channel order).
CATEGORY_COLORS = {
    1: (0, 255, 0),    # Green for caisson
    2: (0, 255, 255),  # Yellow for soldier
    3: (0, 0, 255),    # Red for gun
    4: (255, 0, 0),    # Blue for number
}
@dataclass
class DetectionResult:
    """Outcome of running object detection on one image.

    Instances are built by ``VisionAPIClient.process_single_image`` for both
    the success path and every failure path.
    """

    # Identifier supplied by the caller when the image was submitted.
    image_id: str
    # The image exactly as it was handed in.
    original_image: np.ndarray
    # Detection dicts taken from the API response ("id", "label", "bbox").
    detections: List[Dict[str, Any]]
    # Copy of the image with boxes drawn; stays None on failure.
    marked_image: Optional[np.ndarray] = None
    # True only when the API call and the drawing step both succeeded.
    success: bool = False
    # Human-readable failure reason; None on success.
    error_message: Optional[str] = None
    # time.time() value captured when processing started.
    timestamp: float = 0.0
class TaskStatus(Enum):
    """Lifecycle states recorded for a task in ``processing_tasks``."""

    PENDING = "pending"        # queued, not yet picked up by the worker
    PROCESSING = "processing"  # the worker is currently handling it
    COMPLETED = "completed"    # finished and the result reported success
    FAILED = "failed"          # finished and the result reported failure
@dataclass
class Task:
    """One unit of work placed on the client's task queue."""

    # Key used for status tracking and result lookup.
    task_id: str
    # Caller-supplied identifier of the image.
    image_id: str
    # Image to analyse (OpenCV Mat format).
    image: np.ndarray
    # Prompt text sent to the vision API together with the image.
    prompt: str
    # Optional function invoked with the DetectionResult once processing ends.
    callback: Optional[Callable[[DetectionResult], None]] = None
    # time.time() value captured at submission.
    timestamp: float = 0.0
class VisionAPIClient:
"""Vision API Client for asynchronous processing with OpenCV Mat input/output"""
def __init__(self, api_key: str = API_KEY, base_url: str = BASE_URL,
             model_name: str = MODEL_NAME, max_workers: int = 4):
    """Store connection settings, build the API client and set up task state.

    Args:
        api_key: Credential passed to the OpenAI-compatible client.
        base_url: Endpoint passed to the OpenAI-compatible client.
        model_name: Model requested on every completion call.
        max_workers: Size of the thread pool created here.
    """
    # Connection settings
    self.api_key = api_key
    self.base_url = base_url
    self.model_name = model_name
    self.client = OpenAI(api_key=api_key, base_url=base_url)

    # NOTE(review): no other method of this class references the executor,
    # and nothing shuts it down - confirm whether it is still needed.
    self.max_workers = max_workers
    self.executor = ThreadPoolExecutor(max_workers=max_workers)

    # Task management: queue feeding the worker thread, plus status and
    # result tables keyed by task id.
    self.task_queue = queue.Queue()
    self.processing_tasks = {}
    self.results_cache = {}
    self._worker_thread = None
    self._running = False

    # Prompt used whenever a task is submitted without its own prompt.
    self.default_prompt = """Please perform object detection on the image, identifying and localizing the following four types of targets:
- Category 1: Green ammunition box (caisson)
- Category 2: Dummy soldier wearing digital camouflage uniform (soldier)
- Category 3: Gun (gun)
- Category 4: Round blue number plate (number)
Please follow these requirements for the output:
1. The output must be in valid JSON format.
2. The JSON structure should contain a list named "detections".
3. Each element in the list represents a detected target, containing the following fields:
- "id" (integer): Target category ID (1, 2, 3, 4).
- "label" (string): Target category name ("caisson", "soldier", "gun", "number").
- "bbox" (list of int): Bounding box coordinates in format [x_min, y_min, x_max, y_max], where (x_min, y_min) is the top-left coordinate and (x_max, y_max) is the bottom-right coordinate. Coordinate values are integers normalized to the 0-999 range (0,0 represents top-left, 999,999 represents bottom-right).
4. If no targets are detected in the image, "detections" should be an empty list [].
5. Please output only JSON, no other explanatory text.
JSON output example (when targets are detected):
{
"detections": [
{
"id": 1,
"label": "caisson",
"bbox": [x1, y1, x2, y2] // x1, y1, x2, y2 are integers in the 0-999 range
},
{
"id": 2,
"label": "soldier",
"bbox": [x3, y3, x4, y4] // x3, y3, x4, y4 are integers in the 0-999 range
}
]
}
JSON output example (when no targets are detected):
{
"detections": []
}"""
def encode_cv_image(self, image: np.ndarray) -> str:
    """Encode an OpenCV image as a base64 JPEG string.

    Args:
        image: OpenCV image (Mat) format.

    Returns:
        The JPEG bytes of ``image``, base64-encoded and decoded as UTF-8 text.
    """
    # imencode returns (status, buffer); the status flag is not inspected.
    _status, jpeg_buffer = cv2.imencode('.jpg', image)
    encoded = base64.b64encode(jpeg_buffer)
    return encoded.decode('utf-8')
def validate_and_extract_json(self, response_text: str) -> Optional[Dict[str, Any]]:
    """Extract and parse the first JSON object embedded in an API response.

    The model may wrap its JSON in extra text (markdown fences, a sentence
    before or after). Parsing starts at the first ``{`` and stops where that
    JSON value ends, so trailing text - even text containing braces - does
    not break the parse.

    Args:
        response_text: Raw response text from the API.

    Returns:
        The parsed JSON object, or None when no object can be parsed.
    """
    start_idx = response_text.find('{')
    if start_idx == -1:
        print("No valid JSON structure found in response.")
        return None

    try:
        # raw_decode parses exactly one JSON value beginning at start_idx and
        # ignores whatever follows it.
        parsed_json, _end = json.JSONDecoder().raw_decode(response_text, start_idx)
        return parsed_json
    except json.JSONDecodeError as e:
        print(f"JSON parsing failed: {e}")
        # Show only the first 200 characters of the candidate text.
        print(f"Problematic JSON string: {response_text[start_idx:start_idx + 200]}...")
        return None
def validate_detections_format(self, data: Dict[str, Any]) -> bool:
    """Check that parsed API output has the expected "detections" layout.

    Every detection must be a dict with an integer "id" in 1-4, a "label"
    found among the CATEGORY_MAPPING values, and a four-element "bbox" list
    of numbers within 0-999. An optional "confidence" must be a number in
    0.0-1.0. The first violation is printed and ends the check.

    Args:
        data: Parsed JSON data.

    Returns:
        True if the format is valid, False otherwise.
    """
    if not (isinstance(data, dict) and "detections" in data):
        print("Missing 'detections' key in response.")
        return False

    items = data["detections"]
    if not isinstance(items, list):
        print("'detections' is not a list.")
        return False

    for i, item in enumerate(items):
        if not isinstance(item, dict):
            print(f"Detection item {i} is not a dictionary.")
            return False

        # Report the first missing required key, checked in a fixed order.
        absent = next((k for k in ("id", "label", "bbox") if k not in item), None)
        if absent is not None:
            print(f"Missing required key '{absent}' in detection {i}.")
            return False

        # Category id
        category = item["id"]
        if not (isinstance(category, int) and category in [1, 2, 3, 4]):
            print(f"Invalid ID in detection {i}: {category}")
            return False

        # Category label
        name = item["label"]
        if not (isinstance(name, str) and name in CATEGORY_MAPPING.values()):
            print(f"Invalid label in detection {i}: {name}")
            return False

        # Bounding box shape, then each coordinate
        box = item["bbox"]
        if not (isinstance(box, list) and len(box) == 4):
            print(f"Invalid bbox format in detection {i}: {box}")
            return False
        for value in box:
            if not (isinstance(value, (int, float)) and 0 <= value <= 999):
                print(f"Invalid bbox coordinate in detection {i}: {value}")
                return False

        # Confidence is optional but range-checked when present
        if "confidence" in item:
            score = item["confidence"]
            if not (isinstance(score, (int, float)) and 0.0 <= score <= 1.0):
                print(f"Invalid confidence in detection {i}: {score}")
                return False

    return True
def call_vision_api_sync(self, image: np.ndarray, prompt: str) -> Optional[Dict[str, Any]]:
    """Send one image and prompt to the vision API and return validated JSON.

    Every step, including image preparation, runs inside a single ``try`` so
    that any failure is reported as ``None`` rather than as an exception
    escaping to the caller.

    Args:
        image: OpenCV image (Mat) format.
        prompt: The prompt to send to the API.

    Returns:
        Parsed and validated JSON response if successful, None otherwise.
    """
    try:
        # The request always carries a 1000x1000 image; the aspect ratio is
        # not preserved by this resize.
        h, w = image.shape[:2]
        if h != 1000 or w != 1000:
            image = cv2.resize(image, (1000, 1000))

        # Embed the JPEG as a base64 data URL.
        image_base64 = self.encode_cv_image(image)
        image_url = f"data:image/jpeg;base64,{image_base64}"

        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[
                {
                    'role': 'user',
                    'content': [
                        {"type": "text", "text": prompt},
                        {"type": "image_url", "image_url": {"url": image_url}}
                    ]
                }
            ],
            stream=False  # single response instead of streaming
        )

        if not response.choices:
            print("No choices returned from API.")
            return None

        content = response.choices[0].message.content
        if not content:
            print("No content returned from API.")
            return None

        # Pull the JSON object out of the reply text.
        parsed_data = self.validate_and_extract_json(content)
        if parsed_data is None:
            return None

        # Reject replies that do not match the detections layout.
        if not self.validate_detections_format(parsed_data):
            print("Invalid detections format in response.")
            return None

        return parsed_data
    except Exception as e:
        print(f"API request failed: {e}")
        return None
def draw_detections_on_image(self, image: np.ndarray, detections: List[Dict[str, Any]]) -> np.ndarray:
    """Draw bounding boxes and labels on a copy of the image (no confidence).

    Args:
        image: OpenCV image (Mat) format.
        detections: Detection dicts with "id", "label" and a "bbox" of
            [x_min, y_min, x_max, y_max] on the 0-999 scale.

    Returns:
        A new image with the detections drawn; ``image`` is not modified.
    """
    # Work on a copy to avoid modifying the original
    result_image = image.copy()
    img_h, img_w = result_image.shape[:2]

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    thickness = 2

    for detection in detections:
        x1_norm, y1_norm, x2_norm, y2_norm = detection["bbox"]

        # Convert 0-999 coordinates to pixels, then clamp to the image
        x1 = max(0, min(int((x1_norm / 999) * img_w), img_w - 1))
        y1 = max(0, min(int((y1_norm / 999) * img_h), img_h - 1))
        x2 = max(0, min(int((x2_norm / 999) * img_w), img_w - 1))
        y2 = max(0, min(int((y2_norm / 999) * img_h), img_h - 1))

        category_id = detection["id"]
        label_text = detection["label"]  # label only, no confidence
        color = CATEGORY_COLORS.get(category_id, (255, 255, 255))  # default white

        cv2.rectangle(result_image, (x1, y1), (x2, y2), color, 2)

        (text_width, text_height), _baseline = cv2.getTextSize(
            label_text, font, font_scale, thickness)

        # Put the label above the box when it fits. A box touching the top
        # edge leaves no room there, so the label goes just inside the box
        # instead of being drawn at negative y where it would be invisible.
        label_top = y1 - text_height - 10
        if label_top >= 0:
            bg_top, bg_bottom, text_y = label_top, y1, y1 - 5
        else:
            bg_top, bg_bottom = y1, y1 + text_height + 10
            text_y = y1 + text_height + 5

        # Filled background in the category colour, black text on top
        cv2.rectangle(result_image, (x1, bg_top), (x1 + text_width, bg_bottom), color, -1)
        cv2.putText(result_image, label_text, (x1, text_y), font, font_scale,
                    (0, 0, 0), thickness)

    return result_image
def process_single_image(self, image_id: str, image: np.ndarray, prompt: str = None) -> DetectionResult:
    """Run detection on one OpenCV image synchronously.

    Args:
        image_id: Unique identifier for the image.
        image: OpenCV image (Mat) format.
        prompt: Prompt for the vision API; the default prompt is used when
            this is None.

    Returns:
        A DetectionResult. ``success`` is True only when the API returned a
        valid response and the detections could be drawn; otherwise
        ``error_message`` holds the reason.
    """
    started = time.time()

    def _failure(message, found):
        # Every failure path prints the reason and returns the same shape.
        print(message)
        return DetectionResult(
            image_id=image_id,
            original_image=image,
            detections=found,
            success=False,
            error_message=message,
            timestamp=started
        )

    # Reject missing or empty images before spending an API call.
    if image is None or image.size == 0:
        return _failure(f"Invalid image for image_id: {image_id}", [])

    use_prompt = self.default_prompt if prompt is None else prompt

    print(f"Calling vision API for image {image_id}...")
    api_payload = self.call_vision_api_sync(image, use_prompt)
    if api_payload is None:
        return _failure("Failed to get valid response from API.", [])

    found = api_payload.get("detections", [])
    print(f"Found {len(found)} detections for image {image_id}.")

    try:
        annotated = self.draw_detections_on_image(image, found)
    except Exception as e:
        # The detections are kept so the caller can still inspect them.
        return _failure(f"Error drawing detections on image: {e}", found)

    return DetectionResult(
        image_id=image_id,
        original_image=image,
        detections=found,
        marked_image=annotated,
        success=True,
        timestamp=started
    )
def start_worker(self):
    """Launch the daemon thread that drains the task queue.

    Does nothing when the worker is already flagged as running.
    """
    if not self._running:
        self._running = True
        worker = threading.Thread(target=self._worker_loop, daemon=True)
        self._worker_thread = worker
        worker.start()
        print("Vision API worker started.")
def stop_worker(self):
    """Signal the worker loop to exit and wait briefly for its thread."""
    self._running = False
    worker = self._worker_thread
    if worker:
        # Bounded wait: give the thread up to 5 seconds to finish.
        worker.join(timeout=5.0)
    print("Vision API worker stopped.")
def _worker_loop(self):
    """Background loop: take tasks off the queue and process them.

    Runs until ``self._running`` becomes False. ``task_done()`` is called
    for every task that was taken from the queue, including tasks whose
    processing raised, so ``task_queue.join()`` cannot block forever on a
    failed task.
    """
    while self._running:
        try:
            # Short timeout so the loop re-checks the running flag regularly.
            task = self.task_queue.get(timeout=1.0)
        except queue.Empty:
            continue

        try:
            self.processing_tasks[task.task_id] = TaskStatus.PROCESSING

            result = self.process_single_image(task.image_id, task.image, task.prompt)
            self.results_cache[task.task_id] = result
            self.processing_tasks[task.task_id] = (
                TaskStatus.COMPLETED if result.success else TaskStatus.FAILED
            )

            if task.callback:
                # A failing callback must not affect the task's recorded outcome.
                try:
                    task.callback(result)
                except Exception as e:
                    print(f"Callback execution failed for task {task.task_id}: {e}")
        except Exception as e:
            print(f"Worker error: {e}")
            # Do not leave the task stuck in PROCESSING after a crash.
            self.processing_tasks[task.task_id] = TaskStatus.FAILED
        finally:
            # Always balance the get() above.
            self.task_queue.task_done()
def submit_task(self, image_id: int, image: np.ndarray, prompt: str = None,
                callback: Callable[[DetectionResult], None] = None) -> str:
    """Queue an OpenCV image for background processing.

    Args:
        image_id: Unique identifier for the image.
        image: OpenCV image (Mat) format.
        prompt: Prompt for the vision API; the default prompt is used when
            this is None.
        callback: Function called with the DetectionResult once processing
            is complete (optional).

    Returns:
        Task ID for tracking the task.
    """
    # Microsecond timestamp plus the image id forms the task id.
    task_id = f"task_{int(time.time() * 1000000)}_{image_id}"

    task = Task(
        task_id=task_id,
        image_id=image_id,
        image=image,
        prompt=prompt if prompt is not None else self.default_prompt,
        callback=callback,
        timestamp=time.time()
    )

    # Record PENDING before the task becomes visible to the worker. Writing
    # it after put() could overwrite a PROCESSING/COMPLETED status that the
    # worker thread had already set.
    self.processing_tasks[task_id] = TaskStatus.PENDING
    self.task_queue.put(task)

    return task_id
def get_result(self, task_id: str) -> Optional[DetectionResult]:
    """Look up the stored result of a task.

    Args:
        task_id: The task ID to retrieve the result for.

    Returns:
        The cached DetectionResult, or None when nothing is stored under
        that ID.
    """
    cached = self.results_cache.get(task_id)
    return cached
def get_task_status(self, task_id: str) -> Optional[TaskStatus]:
    """Look up the recorded status of a task.

    Args:
        task_id: The task ID to check the status for.

    Returns:
        The TaskStatus recorded for the ID, or None when the ID is unknown.
    """
    status = self.processing_tasks.get(task_id)
    return status
def get_queue_size(self) -> int:
    """Return the queue's current ``qsize()`` value."""
    waiting = self.task_queue.qsize()
    return waiting
def get_processing_count(self) -> int:
    """Count the tasks whose recorded status is PROCESSING."""
    in_flight = 0
    for state in self.processing_tasks.values():
        if state == TaskStatus.PROCESSING:
            in_flight += 1
    return in_flight
def get_completed_count(self) -> int:
    """Count the tasks that have finished, whether COMPLETED or FAILED."""
    finished = 0
    for state in self.processing_tasks.values():
        if state in (TaskStatus.COMPLETED, TaskStatus.FAILED):
            finished += 1
    return finished
def clear_results(self):
    """Empty the results cache in place (the dict object itself is kept)."""
    cache = self.results_cache
    cache.clear()
def __enter__(self):
    """Start the background worker and hand back this client for ``with``."""
    self.start_worker()
    return self
def __exit__(self, exc_type, exc_val, exc_tb):
    """Stop the background worker when the ``with`` block ends.

    Returns None, so an exception raised inside the block is not suppressed.
    """
    self.stop_worker()
# Example usage
def example_callback(result: DetectionResult):
    """Example callback: print a one-line summary of a finished task."""
    if not result.success:
        print(f"Callback: Processing failed for image {result.image_id}: {result.error_message}")
        return

    print(f"Callback: Processing completed for image {result.image_id}, found {len(result.detections)} detections")
    # result.marked_image is the OpenCV Mat with detections drawn; it is
    # fetched here only to show where further processing would begin.
    marked_image = result.marked_image
def main():
    """Demo: run detection on one image from disk and display the result."""
    original_image = cv2.imread("/home/evan/Desktop/received/left/left_1761388243_7673044.jpg")  # Replace with your image source
    if original_image is None:
        print("Could not load image")
        return

    # The context manager starts the worker on entry and stops it on exit.
    with VisionAPIClient() as client:
        task_id = client.submit_task(
            image_id=1,
            image=original_image,
            callback=example_callback
        )
        print(f"Submitted task {task_id}")

        # Block until every queued task has been marked done.
        print("Waiting for task to complete...")
        client.task_queue.join()

        result = client.get_result(task_id)
        if not result:
            print("No result found")
            return
        if not result.success:
            print(f"Task failed: {result.error_message}")
            return

        print(f"Task completed successfully! Found {len(result.detections)} detections.")
        # result.marked_image is the OpenCV Mat with detections drawn
        marked_image = result.marked_image

        # Show both images until a key is pressed.
        cv2.imshow("Original Image", original_image)
        cv2.imshow("Marked Image", marked_image)
        print("Press any key to close windows...")
        cv2.waitKey(0)
        cv2.destroyAllWindows()
        # To save instead: cv2.imwrite("marked_image.jpg", marked_image)


if __name__ == "__main__":
    main()

View File

@@ -1,9 +1,11 @@
<!DOCTYPE html>
<html>
<html lang="en">
<head>
<title>Saved Images List</title>
<script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.7.2/socket.io.min.js "></script>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Saved Images</title>
<script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.0.0/socket.io.js"></script>
<style>
body {
font-family: Arial, sans-serif;
@@ -15,11 +17,12 @@
}
button {
padding: 8px 16px;
padding: 10px 15px;
margin-right: 10px;
background-color: #007bff;
color: white;
border: none;
border-radius: 4px;
cursor: pointer;
}
@@ -97,6 +100,8 @@
<th>ID</th>
<th>Left Image</th>
<th>Right Image</th>
<th>Left Marked Image</th> <!-- 新增标注图片列 -->
<th>Right Marked Image</th> <!-- 新增标注图片列 -->
<th>Timestamp</th>
<th>Comment</th>
<th>Actions</th>
@@ -138,6 +143,7 @@
row.insertCell(0).innerHTML = `<input type="checkbox" class="selectCheckbox" data-id="${image.id}">`;
row.insertCell(1).textContent = image.id;
// 原图
const leftCell = row.insertCell(2);
const leftImg = document.createElement('img');
// 修改这里:使用 Flask 静态文件路径
@@ -156,10 +162,35 @@
rightImg.onerror = function () { this.src = 'data:image/svg+xml;utf8,<svg xmlns="http://www.w3.org/2000/svg" width="100" height="75" viewBox="0 0 100 75"><rect width="100" height="75" fill="%23eee"/><text x="50" y="40" font-family="Arial" font-size="12" fill="%23999" text-anchor="middle">No Image</text></svg>'; };
rightCell.appendChild(rightImg);
row.insertCell(4).textContent = new Date(image.timestamp * 1000).toISOString();
// 标注图片
const leftMarkedCell = row.insertCell(4);
if (image.left_marked_filename) {
const leftMarkedImg = document.createElement('img');
leftMarkedImg.src = `/static/received/left_marked/${image.left_marked_filename}`;
leftMarkedImg.alt = "Left Marked Image";
leftMarkedImg.className = 'image-preview';
leftMarkedImg.onerror = function () { this.src = 'data:image/svg+xml;utf8,<svg xmlns="http://www.w3.org/2000/svg" width="100" height="75" viewBox="0 0 100 75"><rect width="100" height="75" fill="%23eee"/><text x="50" y="40" font-family="Arial" font-size="12" fill="%23999" text-anchor="middle">No Image</text></svg>'; };
leftMarkedCell.appendChild(leftMarkedImg);
} else {
leftMarkedCell.textContent = "N/A";
}
// 添加可编辑的comment单元格
const commentCell = row.insertCell(5);
const rightMarkedCell = row.insertCell(5);
if (image.right_marked_filename) {
const rightMarkedImg = document.createElement('img');
rightMarkedImg.src = `/static/received/right_marked/${image.right_marked_filename}`;
rightMarkedImg.alt = "Right Marked Image";
rightMarkedImg.className = 'image-preview';
rightMarkedImg.onerror = function () { this.src = 'data:image/svg+xml;utf8,<svg xmlns="http://www.w3.org/2000/svg" width="100" height="75" viewBox="0 0 100 75"><rect width="100" height="75" fill="%23eee"/><text x="50" y="40" font-family="Arial" font-size="12" fill="%23999" text-anchor="middle">No Image</text></svg>'; };
rightMarkedCell.appendChild(rightMarkedImg);
} else {
rightMarkedCell.textContent = "N/A";
}
row.insertCell(6).textContent = new Date(image.timestamp * 1000).toISOString();
// 添加可编辑的 comment 单元格
const commentCell = row.insertCell(7);
const commentInput = document.createElement('input');
commentInput.type = 'text';
commentInput.value = image.comment || '';
@@ -171,11 +202,11 @@
});
commentCell.appendChild(commentInput);
row.insertCell(6).innerHTML = `<button onclick="deleteImage(${image.id})">Delete</button>`;
row.insertCell(8).innerHTML = `<button onclick="deleteImage(${image.id})">Delete</button>`;
});
}
// 添加更新comment的函数
// 添加更新 comment 的函数
async function updateComment(id, comment) {
try {
const response = await fetch('/api/images/comment', {