feat: 增加标注相关功能
This commit is contained in:
166
cam_web.py
166
cam_web.py
@@ -14,6 +14,9 @@ from datetime import datetime
|
||||
import zipfile
|
||||
import tempfile
|
||||
|
||||
# 导入视觉处理相关的模块
|
||||
from llm_req import VisionAPIClient, DetectionResult
|
||||
|
||||
from cap_trigger import ImageClient
|
||||
|
||||
|
||||
@@ -21,6 +24,8 @@ from cap_trigger import ImageClient
|
||||
# --- 配置 ---
|
||||
SAVE_PATH_LEFT = "./static/received/left"
|
||||
SAVE_PATH_RIGHT = "./static/received/right"
|
||||
SAVE_PATH_LEFT_MARKED = "./static/received/left_marked" # 标注图片保存路径
|
||||
SAVE_PATH_RIGHT_MARKED = "./static/received/right_marked" # 标注图片保存路径
|
||||
FLASK_HOST = "0.0.0.0"
|
||||
FLASK_PORT = 5000
|
||||
MAX_LIVE_FRAMES = 2 # 保留最新的几帧用于实时显示
|
||||
@@ -36,12 +41,14 @@ def init_db():
|
||||
"""初始化 SQLite 数据库和表"""
|
||||
conn = sqlite3.connect(DATABASE_PATH)
|
||||
cursor = conn.cursor()
|
||||
# 创建表,存储图片信息
|
||||
# 创建表,存储图片信息,添加标注图片字段
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS images (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
left_filename TEXT NOT NULL,
|
||||
right_filename TEXT NOT NULL,
|
||||
left_marked_filename TEXT,
|
||||
right_marked_filename TEXT,
|
||||
timestamp REAL NOT NULL,
|
||||
metadata TEXT,
|
||||
comment TEXT,
|
||||
@@ -72,6 +79,51 @@ image_client = ImageClient("tcp://127.0.0.1:54321", client_id="local")
|
||||
# 初始化数据库
|
||||
init_db()
|
||||
|
||||
# --- 辅助函数 ---
|
||||
|
||||
def draw_detections_on_image(image: np.ndarray, detections: list) -> np.ndarray:
    """Return a copy of ``image`` with a box and caption for each detection.

    Args:
        image: Image array as used with OpenCV; it is copied, never modified.
        detections: Iterable of dicts with ``"id"`` (category id),
            ``"label"`` (caption text) and ``"bbox"``
            (``[x_min, y_min, x_max, y_max]`` on a 0-999 grid).

    Returns:
        A new array with the overlays drawn. Entries that are not dicts, or
        whose ``bbox`` is not four finite numbers, are skipped so that one
        malformed detection cannot discard the whole annotated image.
    """
    # Work on a copy so the caller's original image stays untouched.
    marked_image = image.copy()

    # Colour per category id (OpenCV channel order is B, G, R).
    color_map = {
        1: (0, 255, 0),    # green - ammunition box
        2: (255, 0, 0),    # blue - soldier
        3: (0, 0, 255),    # red - gun
        4: (255, 255, 0),  # cyan - number plate
    }

    h, w = image.shape[:2]

    for detection in detections:
        if not isinstance(detection, dict):
            continue

        obj_id = detection.get("id", 0)
        label = detection.get("label", "")

        # Unpacking enforces "exactly four values"; float() rejects
        # non-numeric entries.
        try:
            left, top, right, bottom = (float(v) for v in detection.get("bbox", []))
        except (TypeError, ValueError):
            continue
        if not all(np.isfinite(v) for v in (left, top, right, bottom)):
            continue

        # Scale from the 0-999 grid to pixels, clamp into the image and make
        # sure (x_min, y_min) really is the top-left corner.
        x_min, x_max = sorted(max(0, min(int(v * w / 999), w - 1)) for v in (left, right))
        y_min, y_max = sorted(max(0, min(int(v * h / 999), h - 1)) for v in (top, bottom))

        color = color_map.get(obj_id, (255, 255, 255))  # white for unknown ids

        cv2.rectangle(marked_image, (x_min, y_min), (x_max, y_max), color, 2)

        label_text = f"{label} ({obj_id})"
        font = cv2.FONT_HERSHEY_SIMPLEX
        (_, text_height), _ = cv2.getTextSize(label_text, font, 0.5, 1)

        # Default position is just above the box; when that would fall off
        # the top of the image, draw the caption inside the box instead.
        text_y = y_min - 10
        if text_y - text_height < 0:
            text_y = min(h - 1, y_min + text_height + 10)
        cv2.putText(marked_image, label_text, (x_min, text_y), font, 0.5, color, 1)

    return marked_image
|
||||
|
||||
# --- Flask 路由 ---
|
||||
|
||||
@app.route('/')
|
||||
@@ -89,8 +141,8 @@ def get_images_api():
|
||||
"""API: 获取图片列表 (JSON 格式)"""
|
||||
conn = sqlite3.connect(DATABASE_PATH)
|
||||
cursor = conn.cursor()
|
||||
# 按时间倒序排列
|
||||
cursor.execute("SELECT id, left_filename, right_filename, timestamp, metadata, comment, created_at FROM images ORDER BY timestamp DESC")
|
||||
# 按时间倒序排列,包含标注图片字段
|
||||
cursor.execute("SELECT id, left_filename, right_filename, left_marked_filename, right_marked_filename, timestamp, metadata, comment, created_at FROM images ORDER BY timestamp DESC")
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
|
||||
@@ -100,10 +152,12 @@ def get_images_api():
|
||||
"id": row[0],
|
||||
"left_filename": row[1],
|
||||
"right_filename": row[2],
|
||||
"timestamp": row[3],
|
||||
"metadata": row[4],
|
||||
"comment": row[5] or "", # 如果没有comment则显示空字符串
|
||||
"created_at": row[6]
|
||||
"left_marked_filename": row[3] or "", # 如果没有标注图片则显示空字符串
|
||||
"right_marked_filename": row[4] or "", # 如果没有标注图片则显示空字符串
|
||||
"timestamp": row[5],
|
||||
"metadata": row[6],
|
||||
"comment": row[7] or "", # 如果没有 comment 则显示空字符串
|
||||
"created_at": row[8]
|
||||
})
|
||||
return jsonify(images)
|
||||
|
||||
@@ -116,22 +170,25 @@ def delete_image_api():
|
||||
|
||||
conn = sqlite3.connect(DATABASE_PATH)
|
||||
cursor = conn.cursor()
|
||||
# 查询文件名
|
||||
cursor.execute("SELECT left_filename, right_filename FROM images WHERE id = ?", (image_id,))
|
||||
# 查询文件名,包含标注图片
|
||||
cursor.execute("SELECT left_filename, right_filename, left_marked_filename, right_marked_filename FROM images WHERE id = ?", (image_id,))
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
conn.close()
|
||||
return jsonify({"error": "Image not found"}), 404
|
||||
|
||||
left_filename, right_filename = row
|
||||
left_filename, right_filename, left_marked_filename, right_marked_filename = row
|
||||
# 删除数据库记录
|
||||
cursor.execute("DELETE FROM images WHERE id = ?", (image_id,))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
# 删除对应的文件
|
||||
# 删除对应的文件,包括标注图片
|
||||
left_path = os.path.join(SAVE_PATH_LEFT, left_filename)
|
||||
right_path = os.path.join(SAVE_PATH_RIGHT, right_filename)
|
||||
left_marked_path = os.path.join(SAVE_PATH_LEFT_MARKED, left_marked_filename) if left_marked_filename else None
|
||||
right_marked_path = os.path.join(SAVE_PATH_RIGHT_MARKED, right_marked_filename) if right_marked_filename else None
|
||||
|
||||
try:
|
||||
if os.path.exists(left_path):
|
||||
os.remove(left_path)
|
||||
@@ -139,6 +196,12 @@ def delete_image_api():
|
||||
if os.path.exists(right_path):
|
||||
os.remove(right_path)
|
||||
logger.info(f"Deleted file: {right_path}")
|
||||
if left_marked_path and os.path.exists(left_marked_path):
|
||||
os.remove(left_marked_path)
|
||||
logger.info(f"Deleted file: {left_marked_path}")
|
||||
if right_marked_path and os.path.exists(right_marked_path):
|
||||
os.remove(right_marked_path)
|
||||
logger.info(f"Deleted file: {right_marked_path}")
|
||||
except OSError as e:
|
||||
logger.error(f"Error deleting files: {e}")
|
||||
# 即使删除文件失败,数据库记录也已删除,返回成功
|
||||
@@ -148,7 +211,7 @@ def delete_image_api():
|
||||
|
||||
@app.route('/api/images/export', methods=['POST'])
|
||||
def export_images_api():
|
||||
"""API: 打包导出选中的图片"""
|
||||
"""API: 打包导出选中的图片,优先导出标注图片"""
|
||||
selected_ids = request.json.get('ids', [])
|
||||
if not selected_ids:
|
||||
return jsonify({"error": "No image IDs selected"}), 400
|
||||
@@ -156,7 +219,8 @@ def export_images_api():
|
||||
conn = sqlite3.connect(DATABASE_PATH)
|
||||
cursor = conn.cursor()
|
||||
placeholders = ','.join('?' * len(selected_ids))
|
||||
cursor.execute(f"SELECT left_filename, right_filename FROM images WHERE id IN ({placeholders})", selected_ids)
|
||||
# 查询包含标注图片的文件名
|
||||
cursor.execute(f"SELECT left_filename, right_filename, left_marked_filename, right_marked_filename FROM images WHERE id IN ({placeholders})", selected_ids)
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
|
||||
@@ -169,13 +233,20 @@ def export_images_api():
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(temp_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
|
||||
for left_fn, right_fn in rows:
|
||||
left_path = os.path.join(SAVE_PATH_LEFT, left_fn)
|
||||
right_path = os.path.join(SAVE_PATH_RIGHT, right_fn)
|
||||
if os.path.exists(left_path):
|
||||
zipf.write(left_path, os.path.join('left', left_fn))
|
||||
if os.path.exists(right_path):
|
||||
zipf.write(right_path, os.path.join('right', right_fn))
|
||||
for left_fn, right_fn, left_marked_fn, right_marked_fn in rows:
|
||||
# 优先使用标注图片,如果没有则使用原图
|
||||
left_export_fn = left_marked_fn if left_marked_fn else left_fn
|
||||
right_export_fn = right_marked_fn if right_marked_fn else right_fn
|
||||
|
||||
# 确定文件路径
|
||||
left_export_path = os.path.join(SAVE_PATH_LEFT_MARKED if left_marked_fn else SAVE_PATH_LEFT, left_export_fn)
|
||||
right_export_path = os.path.join(SAVE_PATH_RIGHT_MARKED if right_marked_fn else SAVE_PATH_RIGHT, right_export_fn)
|
||||
|
||||
# 添加到 ZIP 文件
|
||||
if os.path.exists(left_export_path):
|
||||
zipf.write(left_export_path, os.path.join('left', left_export_fn))
|
||||
if os.path.exists(right_export_path):
|
||||
zipf.write(right_export_path, os.path.join('right', right_export_fn))
|
||||
|
||||
logger.info(f"Exported {len(rows)} image pairs to {temp_zip_path}")
|
||||
# 返回 ZIP 文件给客户端
|
||||
@@ -190,7 +261,7 @@ def export_images_api():
|
||||
|
||||
@app.route('/upload', methods=['POST'])
|
||||
def upload_images():
|
||||
"""接收左右摄像头图片,保存并推送更新"""
|
||||
"""接收左右摄像头图片,保存并推送更新,同时生成标注图片"""
|
||||
try:
|
||||
# 从 multipart/form-data 中获取文件
|
||||
left_file = request.files.get('left_image')
|
||||
@@ -230,26 +301,67 @@ def upload_images():
|
||||
# 生成文件名
|
||||
left_filename = f"left_{timestamp_str_safe}.jpg"
|
||||
right_filename = f"right_{timestamp_str_safe}.jpg"
|
||||
left_marked_filename = f"left_marked_{timestamp_str_safe}.jpg" # 标注图片文件名
|
||||
right_marked_filename = f"right_marked_{timestamp_str_safe}.jpg" # 标注图片文件名
|
||||
|
||||
# 保存图片到本地
|
||||
# 保存原图到本地
|
||||
left_path = os.path.join(SAVE_PATH_LEFT, left_filename)
|
||||
right_path = os.path.join(SAVE_PATH_RIGHT, right_filename)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(SAVE_PATH_LEFT, exist_ok=True)
|
||||
os.makedirs(SAVE_PATH_RIGHT, exist_ok=True)
|
||||
os.makedirs(SAVE_PATH_LEFT_MARKED, exist_ok=True) # 创建标注图片目录
|
||||
os.makedirs(SAVE_PATH_RIGHT_MARKED, exist_ok=True) # 创建标注图片目录
|
||||
|
||||
cv2.imwrite(left_path, img_left)
|
||||
cv2.imwrite(right_path, img_right)
|
||||
logger.info(f"Saved images: {left_path}, {right_path}")
|
||||
logger.info(f"Saved original images: {left_path}, {right_path}")
|
||||
|
||||
# 将图片信息写入数据库
|
||||
# 使用 VisionAPIClient 处理图片并生成标注图片
|
||||
left_marked_path = None
|
||||
right_marked_path = None
|
||||
|
||||
try:
|
||||
with VisionAPIClient() as client:
|
||||
# 处理左图
|
||||
left_task_id = client.submit_task(image_id=1, image=img_left)
|
||||
# 处理右图
|
||||
right_task_id = client.submit_task(image_id=2, image=img_right)
|
||||
|
||||
# 等待任务完成
|
||||
client.task_queue.join()
|
||||
|
||||
# 获取处理结果
|
||||
left_result = client.get_result(left_task_id)
|
||||
right_result = client.get_result(right_task_id)
|
||||
|
||||
# 生成标注图片
|
||||
if left_result and left_result.success:
|
||||
marked_left_img = draw_detections_on_image(img_left, left_result.detections)
|
||||
left_marked_path = os.path.join(SAVE_PATH_LEFT_MARKED, left_marked_filename)
|
||||
cv2.imwrite(left_marked_path, marked_left_img)
|
||||
logger.info(f"Saved marked left image: {left_marked_path}")
|
||||
|
||||
if right_result and right_result.success:
|
||||
marked_right_img = draw_detections_on_image(img_right, right_result.detections)
|
||||
right_marked_path = os.path.join(SAVE_PATH_RIGHT_MARKED, right_marked_filename)
|
||||
cv2.imwrite(right_marked_path, marked_right_img)
|
||||
logger.info(f"Saved marked right image: {right_marked_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing images with VisionAPIClient: {e}")
|
||||
# 即使处理失败,也继续保存原图
|
||||
|
||||
# 将图片信息写入数据库,包含标注图片字段
|
||||
conn = sqlite3.connect(DATABASE_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
INSERT INTO images (left_filename, right_filename, timestamp, metadata, comment)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
''', (left_filename, right_filename, float(timestamp_str), json.dumps(metadata), comment))
|
||||
INSERT INTO images (left_filename, right_filename, left_marked_filename, right_marked_filename, timestamp, metadata, comment)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
''', (left_filename, right_filename,
|
||||
left_marked_filename if left_marked_path else None,
|
||||
right_marked_filename if right_marked_path else None,
|
||||
float(timestamp_str), json.dumps(metadata), comment))
|
||||
conn.commit()
|
||||
image_id = cursor.lastrowid # 获取新插入记录的 ID
|
||||
conn.close()
|
||||
|
||||
@@ -4,7 +4,7 @@ import uuid
|
||||
import time
|
||||
|
||||
class ImageClient:
|
||||
def __init__(self, server_address="tcp://10.42.70.1:54321", client_id=None):
|
||||
def __init__(self, server_address="tcp://:54321", client_id=None):
|
||||
self.server_address = server_address
|
||||
self.client_id = client_id or f"client_{uuid.uuid4().hex[:8]}"
|
||||
self.socket = None
|
||||
|
||||
610
llm_req.py
Normal file
610
llm_req.py
Normal file
@@ -0,0 +1,610 @@
|
||||
import asyncio
import base64
import json
import os
import queue
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum
from typing import List, Dict, Any, Optional, Callable

import cv2
import numpy as np
from openai import OpenAI
|
||||
|
||||
# Configuration
# SECURITY NOTE(review): an API key is committed in this file. Treat it as
# compromised: rotate it, supply the new key through DASHSCOPE_API_KEY, and
# then delete the literal fallback below. The fallback is kept only so that
# existing deployments without the environment variable keep working.
API_KEY = os.environ.get("DASHSCOPE_API_KEY", "sk-e3a0287ece6a41bb9b79b2c285f10197")
BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
MODEL_NAME = "qwen-vl-plus"
|
||||
|
||||
# Category id -> label name; ids start at 1 and follow the order below.
CATEGORY_MAPPING = dict(
    enumerate(("caisson", "soldier", "gun", "number"), start=1)
)

# Category id -> drawing colour, listed in the same id order as above.
CATEGORY_COLORS = dict(
    zip(
        CATEGORY_MAPPING,
        (
            (0, 255, 0),    # Green for caisson
            (0, 255, 255),  # Yellow for soldier
            (0, 0, 255),    # Red for gun
            (255, 0, 0),    # Blue for number
        ),
    )
)
|
||||
|
||||
@dataclass
class DetectionResult:
    """Detection result data class: the outcome of processing one image."""
    # Caller-supplied image identifier.
    # NOTE(review): annotated as str, while submit_task annotates its
    # image_id parameter as int - confirm which type is intended.
    image_id: str
    # The image that was submitted for detection.
    original_image: np.ndarray
    # Detection dicts taken from the API response ("id", "label", "bbox").
    detections: List[Dict[str, Any]]
    # Copy of the image with detections drawn; stays None on failure.
    marked_image: Optional[np.ndarray] = None
    # True only when the API call, validation and drawing all succeeded.
    success: bool = False
    # Human-readable failure reason; None on success.
    error_message: Optional[str] = None
    # time.time() value captured when processing of the image started.
    timestamp: float = 0.0
|
||||
|
||||
class TaskStatus(Enum):
    """Task status enumeration: lifecycle states of a queued detection task."""
    # Submitted, not yet picked up by the worker.
    PENDING = "pending"
    # Taken from the queue and currently being processed.
    PROCESSING = "processing"
    # Finished with a successful DetectionResult.
    COMPLETED = "completed"
    # Finished with an unsuccessful DetectionResult.
    FAILED = "failed"
|
||||
|
||||
@dataclass
class Task:
    """Task data class for queue: one unit of work for the background worker."""
    # Identifier returned by submit_task and used to look up status/results.
    task_id: str
    # Caller-supplied image identifier, copied into the DetectionResult.
    # NOTE(review): annotated as str, while submit_task annotates its
    # image_id parameter as int - confirm which type is intended.
    image_id: str
    image: np.ndarray  # OpenCV Mat format
    # Prompt text sent to the vision API together with the image.
    prompt: str
    # Optional function invoked with the DetectionResult once processing ends.
    callback: Optional[Callable[[DetectionResult], None]] = None
    # time.time() value captured when the task was submitted.
    timestamp: float = 0.0
|
||||
|
||||
class VisionAPIClient:
|
||||
"""Vision API Client for asynchronous processing with OpenCV Mat input/output"""
|
||||
|
||||
def __init__(self, api_key: str = API_KEY, base_url: str = BASE_URL,
             model_name: str = MODEL_NAME, max_workers: int = 4) -> None:
    """Create the API client and the (not yet started) task machinery.

    Args:
        api_key: Key passed to the OpenAI client.
        base_url: Endpoint passed to the OpenAI client.
        model_name: Model name sent with every completion request.
        max_workers: Size of the thread pool created below.
    """
    self.api_key = api_key
    self.base_url = base_url
    self.model_name = model_name
    self.max_workers = max_workers
    # NOTE(review): no method of this class visible in this file submits work
    # to this executor or shuts it down - confirm whether it is still needed.
    self.executor = ThreadPoolExecutor(max_workers=max_workers)
    self.client = OpenAI(api_key=api_key, base_url=base_url)

    # Task management
    self.task_queue = queue.Queue()   # Task objects waiting for the worker
    self.processing_tasks = {}        # task_id -> TaskStatus
    self.results_cache = {}           # task_id -> DetectionResult
    self._running = False             # loop flag read by _worker_loop
    self._worker_thread = None        # set by start_worker

    # Prompt used when a caller does not supply one.
    # NOTE(review): the example inside the prompt contains "//" comments,
    # which are not valid JSON although the prompt asks for valid JSON.
    self.default_prompt = """Please perform object detection on the image, identifying and localizing the following four types of targets:
- Category 1: Green ammunition box (caisson)
- Category 2: Dummy soldier wearing digital camouflage uniform (soldier)
- Category 3: Gun (gun)
- Category 4: Round blue number plate (number)

Please follow these requirements for the output:
1. The output must be in valid JSON format.
2. The JSON structure should contain a list named "detections".
3. Each element in the list represents a detected target, containing the following fields:
- "id" (integer): Target category ID (1, 2, 3, 4).
- "label" (string): Target category name ("caisson", "soldier", "gun", "number").
- "bbox" (list of int): Bounding box coordinates in format [x_min, y_min, x_max, y_max], where (x_min, y_min) is the top-left coordinate and (x_max, y_max) is the bottom-right coordinate. Coordinate values are integers normalized to the 0-999 range (0,0 represents top-left, 999,999 represents bottom-right).
4. If no targets are detected in the image, "detections" should be an empty list [].
5. Please output only JSON, no other explanatory text.

JSON output example (when targets are detected):
{
"detections": [
{
"id": 1,
"label": "caisson",
"bbox": [x1, y1, x2, y2] // x1, y1, x2, y2 are integers in the 0-999 range
},
{
"id": 2,
"label": "soldier",
"bbox": [x3, y3, x4, y4] // x3, y3, x4, y4 are integers in the 0-999 range
}
]
}

JSON output example (when no targets are detected):
{
"detections": []
}"""
|
||||
|
||||
def encode_cv_image(self, image: np.ndarray) -> str:
    """
    Encodes an OpenCV image (Mat) to a base64 string of its JPEG bytes.

    Args:
        image: OpenCV image (Mat) format

    Returns:
        Base64 encoded string of the JPEG-compressed image

    Raises:
        ValueError: If OpenCV reports that JPEG encoding failed.
    """
    # imencode returns a success flag together with the buffer; check it
    # instead of passing an unusable buffer on to base64.
    encoded_ok, buffer = cv2.imencode('.jpg', image)
    if not encoded_ok:
        raise ValueError("cv2.imencode failed to encode the image as JPEG")
    return base64.b64encode(buffer).decode('utf-8')
|
||||
|
||||
def validate_and_extract_json(self, response_text: str) -> Optional[Dict[str, Any]]:
    """
    Extracts the first JSON object embedded in an API response.

    Text before the first '{' (for example a Markdown code fence) and any
    text after the end of the object are ignored, so a well-formed object
    followed by a remark that itself contains '}' is still accepted.

    Args:
        response_text: Raw response text from API

    Returns:
        Parsed JSON dictionary if one could be decoded, None otherwise
    """
    if not isinstance(response_text, str):
        print("No valid JSON structure found in response.")
        return None

    start_idx = response_text.find('{')
    end_idx = response_text.rfind('}')

    if start_idx == -1 or end_idx == -1 or start_idx > end_idx:
        print("No valid JSON structure found in response.")
        return None

    try:
        # raw_decode stops at the end of the first complete JSON value
        # instead of requiring the rest of the text to be empty.
        parsed_json, _ = json.JSONDecoder().raw_decode(response_text, start_idx)
    except json.JSONDecodeError as e:
        json_str = response_text[start_idx:end_idx + 1]
        print(f"JSON parsing failed: {e}")
        print(f"Problematic JSON string: {json_str[:200]}...")  # Show first 200 chars
        return None

    return parsed_json
|
||||
|
||||
def validate_detections_format(self, data: Dict[str, Any]) -> bool:
    """
    Validates the structure and content of the detections JSON.

    Booleans are rejected wherever a number is required: ``bool`` is a
    subclass of ``int`` in Python, so ``True`` would otherwise be accepted
    as category id 1 or as a coordinate.

    Args:
        data: Parsed JSON data

    Returns:
        True if format is valid, False otherwise
    """
    if not isinstance(data, dict) or "detections" not in data:
        print("Missing 'detections' key in response.")
        return False

    detections = data["detections"]
    if not isinstance(detections, list):
        print("'detections' is not a list.")
        return False

    def is_number(value) -> bool:
        """Return True for int/float values that are not booleans."""
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    required_keys = ("id", "label", "bbox")

    for i, detection in enumerate(detections):
        if not isinstance(detection, dict):
            print(f"Detection item {i} is not a dictionary.")
            return False

        for key in required_keys:
            if key not in detection:
                print(f"Missing required key '{key}' in detection {i}.")
                return False

        # Validate ID
        category_id = detection["id"]
        if (isinstance(category_id, bool) or not isinstance(category_id, int)
                or category_id not in (1, 2, 3, 4)):
            print(f"Invalid ID in detection {i}: {detection['id']}")
            return False

        # Validate label
        if not isinstance(detection["label"], str) or detection["label"] not in CATEGORY_MAPPING.values():
            print(f"Invalid label in detection {i}: {detection['label']}")
            return False

        # Validate bbox
        bbox = detection["bbox"]
        if not isinstance(bbox, list) or len(bbox) != 4:
            print(f"Invalid bbox format in detection {i}: {bbox}")
            return False

        for coord in bbox:
            if not is_number(coord) or not (0 <= coord <= 999):
                print(f"Invalid bbox coordinate in detection {i}: {coord}")
                return False

        # Validate confidence if present
        if "confidence" in detection:
            conf = detection["confidence"]
            if not is_number(conf) or not (0.0 <= conf <= 1.0):
                print(f"Invalid confidence in detection {i}: {conf}")
                return False

    return True
|
||||
|
||||
def call_vision_api_sync(self, image: np.ndarray, prompt: str) -> Optional[Dict[str, Any]]:
    """
    Synchronous call to the vision API with OpenCV Mat input.

    The image is resized to 1000x1000, JPEG/base64 encoded and sent together
    with the prompt. Every failure - including a bad image that cannot be
    resized or encoded - is reported by returning None, so callers never
    have to handle exceptions from this method.

    Args:
        image: OpenCV image (Mat) format
        prompt: The prompt to send to the API

    Returns:
        Parsed JSON response if successful, None otherwise
    """
    try:
        # Resize image to 1000x1000 if needed
        h, w = image.shape[:2]
        if h != 1000 or w != 1000:
            image = cv2.resize(image, (1000, 1000))

        # Encode the image directly to base64
        image_base64 = self.encode_cv_image(image)
        image_url = f"data:image/jpeg;base64,{image_base64}"

        # Create the completion request
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[
                {
                    'role': 'user',
                    'content': [
                        {"type": "text", "text": prompt},
                        {"type": "image_url", "image_url": {"url": image_url}}
                    ]
                }
            ],
            stream=False  # Single response instead of streaming
        )

        # Extract content from response
        if not response.choices:
            print("No choices returned from API.")
            return None

        content = response.choices[0].message.content
        if not content:
            print("No content returned from API.")
            return None

        # Validate and parse JSON from response
        parsed_data = self.validate_and_extract_json(content)
        if parsed_data is None:
            return None

        # Validate the structure of the detections
        if not self.validate_detections_format(parsed_data):
            print("Invalid detections format in response.")
            return None

        return parsed_data

    except Exception as e:
        # Boundary of the API call: report and signal failure with None.
        print(f"API request failed: {e}")
        return None
|
||||
|
||||
def draw_detections_on_image(self, image: np.ndarray, detections: List[Dict[str, Any]]) -> np.ndarray:
    """
    Draws bounding boxes and labels on the OpenCV Mat image (without confidence).

    Args:
        image: OpenCV image (Mat) format; it is copied, never modified
        detections: List of detection dictionaries with "id", "label" and
            "bbox" ([x_min, y_min, x_max, y_max] on a 0-999 grid)

    Returns:
        Image with drawn detections as numpy array
    """
    # Work on a copy to avoid modifying the original
    result_image = image.copy()

    # Get image dimensions
    img_h, img_w = result_image.shape[:2]

    for detection in detections:
        # Get bounding box coordinates (normalized to 0-999 range)
        x1_norm, y1_norm, x2_norm, y2_norm = detection["bbox"]

        # Convert to pixels, clamp into the image, and order the corners so
        # that (x1, y1) is always the top-left one used to anchor the label.
        x1, x2 = sorted(
            max(0, min(int((v / 999) * img_w), img_w - 1)) for v in (x1_norm, x2_norm)
        )
        y1, y2 = sorted(
            max(0, min(int((v / 999) * img_h), img_h - 1)) for v in (y1_norm, y2_norm)
        )

        # Get color and label
        category_id = detection["id"]
        label_text = detection["label"]  # no confidence in the caption
        color = CATEGORY_COLORS.get(category_id, (255, 255, 255))  # Default white if not found

        # Draw bounding box
        cv2.rectangle(result_image, (x1, y1), (x2, y2), color, 2)

        # Calculate text size
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.6
        thickness = 2
        (text_width, text_height), _ = cv2.getTextSize(label_text, font, font_scale, thickness)

        # The label normally sits on top of the box. When the box touches the
        # top of the image there is no room there, so hang it inside the box.
        label_top = y1 - text_height - 10
        if label_top < 0:
            label_top = y1
        label_bottom = label_top + text_height + 10

        # Draw label background, then the text 5 px above its bottom edge
        cv2.rectangle(result_image, (x1, label_top), (x1 + text_width, label_bottom), color, -1)
        cv2.putText(result_image, label_text, (x1, label_bottom - 5), font, font_scale, (0, 0, 0), thickness)

    return result_image
|
||||
|
||||
def process_single_image(self, image_id: str, image: np.ndarray, prompt: str = None) -> DetectionResult:
    """
    Run detection on one OpenCV Mat image and report the outcome.

    Args:
        image_id: Unique identifier for the image
        image: OpenCV image (Mat) format
        prompt: Prompt for the vision API; the default prompt is used when None

    Returns:
        DetectionResult with success=True and a marked image when everything
        worked, otherwise success=False and an error message
    """
    started_at = time.time()

    def failed(message, found):
        # Print the reason and wrap it in an unsuccessful result.
        print(message)
        return DetectionResult(
            image_id=image_id,
            original_image=image,
            detections=found,
            success=False,
            error_message=message,
            timestamp=started_at
        )

    # Reject missing or empty images up front
    if image is None or image.size == 0:
        return failed(f"Invalid image for image_id: {image_id}", [])

    effective_prompt = self.default_prompt if prompt is None else prompt

    print(f"Calling vision API for image {image_id}...")
    api_payload = self.call_vision_api_sync(image, effective_prompt)
    if api_payload is None:
        return failed("Failed to get valid response from API.", [])

    found = api_payload.get("detections", [])
    print(f"Found {len(found)} detections for image {image_id}.")

    try:
        annotated = self.draw_detections_on_image(image, found)
    except Exception as e:
        return failed(f"Error drawing detections on image: {e}", found)

    return DetectionResult(
        image_id=image_id,
        original_image=image,
        detections=found,
        marked_image=annotated,
        success=True,
        timestamp=started_at
    )
|
||||
|
||||
def start_worker(self):
    """Launch the daemon thread that runs _worker_loop; no-op if already running."""
    if not self._running:
        self._running = True
        worker = threading.Thread(target=self._worker_loop, daemon=True)
        self._worker_thread = worker
        worker.start()
        print("Vision API worker started.")
|
||||
|
||||
def stop_worker(self):
    """Ask the worker loop to exit and wait up to five seconds for its thread."""
    self._running = False
    worker = self._worker_thread
    if worker:
        worker.join(timeout=5.0)  # bounded wait
    print("Vision API worker stopped.")
|
||||
|
||||
def _worker_loop(self):
    """Background worker loop for processing tasks from queue.

    Runs until ``self._running`` becomes False. Every task taken from the
    queue is acknowledged with ``task_done()`` exactly once, even when
    processing raises, so ``task_queue.join()`` callers cannot block forever.
    A task whose processing raised is recorded as FAILED with an unsuccessful
    DetectionResult carrying the error text.
    """
    while self._running:
        try:
            # Wake up once per second so a stop request is noticed.
            task = self.task_queue.get(timeout=1.0)
        except queue.Empty:
            continue  # Timeout, continue loop

        try:
            # Mark as processing
            self.processing_tasks[task.task_id] = TaskStatus.PROCESSING

            # Process the task
            result = self.process_single_image(task.image_id, task.image, task.prompt)

            # Store result
            self.results_cache[task.task_id] = result

            # Update task status
            self.processing_tasks[task.task_id] = (
                TaskStatus.COMPLETED if result.success else TaskStatus.FAILED
            )

            # Call callback if provided; a broken callback must not affect
            # the stored result.
            if task.callback:
                try:
                    task.callback(result)
                except Exception as e:
                    print(f"Callback execution failed for task {task.task_id}: {e}")

        except Exception as e:
            print(f"Worker error: {e}")
            self.processing_tasks[task.task_id] = TaskStatus.FAILED
            if task.task_id not in self.results_cache:
                self.results_cache[task.task_id] = DetectionResult(
                    image_id=task.image_id,
                    original_image=task.image,
                    detections=[],
                    success=False,
                    error_message=f"Worker error: {e}",
                    timestamp=time.time()
                )
        finally:
            # Always acknowledge the task, otherwise Queue.join() never returns.
            self.task_queue.task_done()
|
||||
|
||||
def submit_task(self, image_id: int, image: np.ndarray, prompt: Optional[str] = None,
                callback: Optional[Callable[[DetectionResult], None]] = None) -> str:
    """
    Submit a task to the processing queue with OpenCV Mat input.

    Args:
        image_id: Unique identifier for the image
        image: OpenCV image (Mat) format
        prompt: The prompt for the vision API; the default prompt is used when None
        callback: Callback function to be called when processing is complete (optional)

    Returns:
        Task ID for tracking the task
    """
    task_id = f"task_{int(time.time() * 1000000)}_{image_id}"  # timestamp in microseconds + image id
    task = Task(
        task_id=task_id,
        image_id=image_id,
        image=image,
        prompt=prompt if prompt is not None else self.default_prompt,
        callback=callback,
        timestamp=time.time()
    )

    # Record PENDING before enqueueing. The worker may dequeue the task
    # immediately and update the status; writing PENDING after the put could
    # overwrite that newer status.
    self.processing_tasks[task_id] = TaskStatus.PENDING
    self.task_queue.put(task)

    return task_id
|
||||
|
||||
def get_result(self, task_id: str) -> Optional[DetectionResult]:
    """
    Look up the stored result of a task.

    Args:
        task_id: The task ID to retrieve result for

    Returns:
        The cached DetectionResult, or None when nothing is stored for the ID
    """
    try:
        return self.results_cache[task_id]
    except KeyError:
        return None
|
||||
|
||||
def get_task_status(self, task_id: str) -> Optional[TaskStatus]:
    """
    Look up the recorded status of a task.

    Args:
        task_id: The task ID to check status for

    Returns:
        The recorded TaskStatus, or None when the ID is unknown
    """
    try:
        return self.processing_tasks[task_id]
    except KeyError:
        return None
|
||||
|
||||
def get_queue_size(self) -> int:
    """Return how many tasks the queue currently reports as waiting."""
    waiting = self.task_queue
    return waiting.qsize()
|
||||
|
||||
def get_processing_count(self) -> int:
    """Count the tasks whose recorded status is PROCESSING."""
    in_flight = 0
    for state in self.processing_tasks.values():
        if state == TaskStatus.PROCESSING:
            in_flight += 1
    return in_flight
|
||||
|
||||
def get_completed_count(self) -> int:
    """Count the tasks that have finished, whether COMPLETED or FAILED."""
    finished = 0
    for state in self.processing_tasks.values():
        if state == TaskStatus.COMPLETED or state == TaskStatus.FAILED:
            finished += 1
    return finished
|
||||
|
||||
def clear_results(self):
    """Empty the results cache in place (the dict object itself is kept)."""
    cache = self.results_cache
    cache.clear()
|
||||
|
||||
def __enter__(self):
    """Start the background worker and hand the client to the ``with`` body."""
    client = self
    client.start_worker()
    return client
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
    """Stop the background worker; exceptions from the ``with`` body propagate."""
    self.stop_worker()
    return None
|
||||
|
||||
|
||||
# Example usage
|
||||
def example_callback(result: DetectionResult):
    """Sample completion callback: print a one-line summary of the result."""
    if not result.success:
        print(f"Callback: Processing failed for image {result.image_id}: {result.error_message}")
        return

    print(f"Callback: Processing completed for image {result.image_id}, found {len(result.detections)} detections")
    # The annotated OpenCV Mat is available here for further processing.
    marked_image = result.marked_image
|
||||
|
||||
|
||||
def main():
    """Demo: run detection on one image file and show the annotated result."""
    original_image = cv2.imread(
        "/home/evan/Desktop/received/left/left_1761388243_7673044.jpg"
    )  # Replace with your image source
    if original_image is None:
        print("Could not load image")
        return

    # The context manager starts the worker on entry and stops it on exit,
    # including when one of the early returns below is taken.
    with VisionAPIClient() as client:
        task_id = client.submit_task(
            image_id=1,
            image=original_image,
            callback=example_callback
        )
        print(f"Submitted task {task_id}")

        print("Waiting for task to complete...")
        client.task_queue.join()  # blocks until every queued task is acknowledged

        result = client.get_result(task_id)
        if not result:
            print("No result found")
            return
        if not result.success:
            print(f"Task failed: {result.error_message}")
            return

        print(f"Task completed successfully! Found {len(result.detections)} detections.")
        # result.marked_image is the OpenCV Mat with detections drawn
        marked_image = result.marked_image

        # Display both images until a key is pressed
        cv2.imshow("Original Image", original_image)
        cv2.imshow("Marked Image", marked_image)
        print("Press any key to close windows...")
        cv2.waitKey(0)
        cv2.destroyAllWindows()


if __name__ == "__main__":
    main()
|
||||
@@ -1,9 +1,11 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<html lang="en">
|
||||
|
||||
<head>
|
||||
<title>Saved Images List</title>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.7.2/socket.io.min.js "></script>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Saved Images</title>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.0.0/socket.io.js"></script>
|
||||
<style>
|
||||
body {
|
||||
font-family: Arial, sans-serif;
|
||||
@@ -15,11 +17,12 @@
|
||||
}
|
||||
|
||||
button {
|
||||
padding: 8px 16px;
|
||||
padding: 10px 15px;
|
||||
margin-right: 10px;
|
||||
background-color: #007bff;
|
||||
color: white;
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
@@ -97,6 +100,8 @@
|
||||
<th>ID</th>
|
||||
<th>Left Image</th>
|
||||
<th>Right Image</th>
|
||||
<th>Left Marked Image</th> <!-- 新增标注图片列 -->
|
||||
<th>Right Marked Image</th> <!-- 新增标注图片列 -->
|
||||
<th>Timestamp</th>
|
||||
<th>Comment</th>
|
||||
<th>Actions</th>
|
||||
@@ -138,6 +143,7 @@
|
||||
row.insertCell(0).innerHTML = `<input type="checkbox" class="selectCheckbox" data-id="${image.id}">`;
|
||||
row.insertCell(1).textContent = image.id;
|
||||
|
||||
// 原图
|
||||
const leftCell = row.insertCell(2);
|
||||
const leftImg = document.createElement('img');
|
||||
// 修改这里:使用 Flask 静态文件路径
|
||||
@@ -156,10 +162,35 @@
|
||||
rightImg.onerror = function () { this.src = 'data:image/svg+xml;utf8,<svg xmlns="http://www.w3.org/2000/svg" width="100" height="75" viewBox="0 0 100 75"><rect width="100" height="75" fill="%23eee"/><text x="50" y="40" font-family="Arial" font-size="12" fill="%23999" text-anchor="middle">No Image</text></svg>'; };
|
||||
rightCell.appendChild(rightImg);
|
||||
|
||||
row.insertCell(4).textContent = new Date(image.timestamp * 1000).toISOString();
|
||||
// 标注图片
|
||||
const leftMarkedCell = row.insertCell(4);
|
||||
if (image.left_marked_filename) {
|
||||
const leftMarkedImg = document.createElement('img');
|
||||
leftMarkedImg.src = `/static/received/left_marked/${image.left_marked_filename}`;
|
||||
leftMarkedImg.alt = "Left Marked Image";
|
||||
leftMarkedImg.className = 'image-preview';
|
||||
leftMarkedImg.onerror = function () { this.src = 'data:image/svg+xml;utf8,<svg xmlns="http://www.w3.org/2000/svg" width="100" height="75" viewBox="0 0 100 75"><rect width="100" height="75" fill="%23eee"/><text x="50" y="40" font-family="Arial" font-size="12" fill="%23999" text-anchor="middle">No Image</text></svg>'; };
|
||||
leftMarkedCell.appendChild(leftMarkedImg);
|
||||
} else {
|
||||
leftMarkedCell.textContent = "N/A";
|
||||
}
|
||||
|
||||
const rightMarkedCell = row.insertCell(5);
|
||||
if (image.right_marked_filename) {
|
||||
const rightMarkedImg = document.createElement('img');
|
||||
rightMarkedImg.src = `/static/received/right_marked/${image.right_marked_filename}`;
|
||||
rightMarkedImg.alt = "Right Marked Image";
|
||||
rightMarkedImg.className = 'image-preview';
|
||||
rightMarkedImg.onerror = function () { this.src = 'data:image/svg+xml;utf8,<svg xmlns="http://www.w3.org/2000/svg" width="100" height="75" viewBox="0 0 100 75"><rect width="100" height="75" fill="%23eee"/><text x="50" y="40" font-family="Arial" font-size="12" fill="%23999" text-anchor="middle">No Image</text></svg>'; };
|
||||
rightMarkedCell.appendChild(rightMarkedImg);
|
||||
} else {
|
||||
rightMarkedCell.textContent = "N/A";
|
||||
}
|
||||
|
||||
row.insertCell(6).textContent = new Date(image.timestamp * 1000).toISOString();
|
||||
|
||||
// 添加可编辑的 comment 单元格
|
||||
const commentCell = row.insertCell(5);
|
||||
const commentCell = row.insertCell(7);
|
||||
const commentInput = document.createElement('input');
|
||||
commentInput.type = 'text';
|
||||
commentInput.value = image.comment || '';
|
||||
@@ -171,7 +202,7 @@
|
||||
});
|
||||
commentCell.appendChild(commentInput);
|
||||
|
||||
row.insertCell(6).innerHTML = `<button onclick="deleteImage(${image.id})">Delete</button>`;
|
||||
row.insertCell(8).innerHTML = `<button onclick="deleteImage(${image.id})">Delete</button>`;
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user