diff --git a/cam_web.py b/cam_web.py
index 7198710..ac5e773 100644
--- a/cam_web.py
+++ b/cam_web.py
@@ -14,6 +14,9 @@ from datetime import datetime
 import zipfile
 import tempfile
 
+# Import the vision-processing modules
+from llm_req import VisionAPIClient, DetectionResult
+
 from cap_trigger import ImageClient
 
@@ -21,6 +24,8 @@ from cap_trigger import ImageClient
 # --- Configuration ---
 SAVE_PATH_LEFT = "./static/received/left"
 SAVE_PATH_RIGHT = "./static/received/right"
+SAVE_PATH_LEFT_MARKED = "./static/received/left_marked"    # Save path for annotated images
+SAVE_PATH_RIGHT_MARKED = "./static/received/right_marked"  # Save path for annotated images
 FLASK_HOST = "0.0.0.0"
 FLASK_PORT = 5000
 MAX_LIVE_FRAMES = 2  # Keep the latest few frames for live display
@@ -36,12 +41,14 @@ def init_db():
     """Initialize the SQLite database and tables"""
     conn = sqlite3.connect(DATABASE_PATH)
     cursor = conn.cursor()
-    # Create the table that stores image info
+    # Create the table that stores image info, with annotated-image columns added
     cursor.execute('''
         CREATE TABLE IF NOT EXISTS images (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            left_filename TEXT NOT NULL,
            right_filename TEXT NOT NULL,
+           left_marked_filename TEXT,
+           right_marked_filename TEXT,
            timestamp REAL NOT NULL,
            metadata TEXT,
            comment TEXT,
@@ -72,6 +79,51 @@ image_client = ImageClient("tcp://127.0.0.1:54321", client_id="local")
 # Initialize the database
 init_db()
 
+# --- Helper functions ---
+
+def draw_detections_on_image(image: np.ndarray, detections: list) -> np.ndarray:
+    """Draw detection boxes on an image"""
+    # Copy the input so the original image is not modified
+    marked_image = image.copy()
+
+    # Color map per category ID (BGR)
+    color_map = {
+        1: (0, 255, 0),    # Green - ammunition box
+        2: (255, 0, 0),    # Blue - soldier
+        3: (0, 0, 255),    # Red - gun
+        4: (255, 255, 0)   # Cyan - number plate
+    }
+
+    # Image dimensions
+    h, w = image.shape[:2]
+
+    # Draw each detection box
+    for detection in detections:
+        # Detection fields
+        obj_id = detection.get("id", 0)
+        label = detection.get("label", "")
+        bbox = detection.get("bbox", [])
+
+        if len(bbox) == 4:
+            # Convert normalized (0-999) coordinates to pixel coordinates
+            x_min = int(bbox[0] * w / 999)
+            y_min = int(bbox[1] * h / 999)
+            x_max = int(bbox[2] * w / 999)
+            y_max = int(bbox[3] * h / 999)
+
+            # Pick the color (default: white)
+            color = color_map.get(obj_id, (255, 255, 255))
+
+            # Draw the bounding box
+            cv2.rectangle(marked_image, (x_min, y_min), (x_max, y_max), color, 2)
+
+            # Add the label
+            label_text = f"{label} ({obj_id})"
+            cv2.putText(marked_image, label_text, (x_min, y_min - 10),
+                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
+
+    return marked_image
+
 # --- Flask routes ---
 
 @app.route('/')
@@ -89,8 +141,8 @@ def get_images_api():
     """API: get the image list (JSON format)"""
     conn = sqlite3.connect(DATABASE_PATH)
     cursor = conn.cursor()
-    # Sort by time, newest first
-    cursor.execute("SELECT id, left_filename, right_filename, timestamp, metadata, comment, created_at FROM images ORDER BY timestamp DESC")
+    # Sort by time, newest first, including the annotated-image columns
+    cursor.execute("SELECT id, left_filename, right_filename, left_marked_filename, right_marked_filename, timestamp, metadata, comment, created_at FROM images ORDER BY timestamp DESC")
     rows = cursor.fetchall()
     conn.close()
 
@@ -100,10 +152,12 @@ def get_images_api():
             "id": row[0],
             "left_filename": row[1],
             "right_filename": row[2],
-            "timestamp": row[3],
-            "metadata": row[4],
-            "comment": row[5] or "",  # Empty string if there is no comment
-            "created_at": row[6]
+            "left_marked_filename": row[3] or "",   # Empty string if there is no annotated image
+            "right_marked_filename": row[4] or "",  # Empty string if there is no annotated image
+            "timestamp": row[5],
+            "metadata": row[6],
+            "comment": row[7] or "",  # Empty string if there is no comment
+            "created_at": row[8]
         })
     return jsonify(images)
 
@@ -116,22 +170,25 @@ def delete_image_api():
     conn = sqlite3.connect(DATABASE_PATH)
     cursor = conn.cursor()
 
-    # Look up the filenames
-    cursor.execute("SELECT left_filename, right_filename FROM images WHERE id = ?", (image_id,))
+    # Look up the filenames, including the annotated images
+    cursor.execute("SELECT left_filename, right_filename, left_marked_filename, right_marked_filename FROM images WHERE id = ?", (image_id,))
     row = cursor.fetchone()
 
     if not row:
         conn.close()
         return jsonify({"error": "Image not found"}), 404
 
-    left_filename, right_filename = row
+    left_filename, right_filename, left_marked_filename, right_marked_filename = row
 
     # Delete the database record
     cursor.execute("DELETE FROM images WHERE id = ?", (image_id,))
     conn.commit()
     conn.close()
 
-    # Delete the corresponding files
+    # Delete the corresponding files, including the annotated images
     left_path = os.path.join(SAVE_PATH_LEFT, left_filename)
     right_path = os.path.join(SAVE_PATH_RIGHT, right_filename)
+    left_marked_path = os.path.join(SAVE_PATH_LEFT_MARKED, left_marked_filename) if left_marked_filename else None
+    right_marked_path = os.path.join(SAVE_PATH_RIGHT_MARKED, right_marked_filename) if right_marked_filename else None
+
     try:
         if os.path.exists(left_path):
             os.remove(left_path)
@@ -139,6 +196,12 @@ def delete_image_api():
         if os.path.exists(right_path):
             os.remove(right_path)
             logger.info(f"Deleted file: {right_path}")
+        if left_marked_path and os.path.exists(left_marked_path):
+            os.remove(left_marked_path)
+            logger.info(f"Deleted file: {left_marked_path}")
+        if right_marked_path and os.path.exists(right_marked_path):
+            os.remove(right_marked_path)
+            logger.info(f"Deleted file: {right_marked_path}")
     except OSError as e:
         logger.error(f"Error deleting files: {e}")
         # Even if file deletion fails, the DB record is already gone; report success
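One gap worth noting in the schema change above: `CREATE TABLE IF NOT EXISTS` does nothing for databases created before this change, so existing deployments never get the two new columns and the widened `SELECT`s will raise `sqlite3.OperationalError`. A minimal migration sketch — the `migrate_db` helper name is hypothetical, and the call would belong next to `init_db()`:

```python
import sqlite3

DATABASE_PATH = "./images.db"  # assumed; use the app's actual configured path

def migrate_db():
    """Add the annotated-image columns to databases created before this change."""
    conn = sqlite3.connect(DATABASE_PATH)
    cursor = conn.cursor()
    # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk)
    existing = {row[1] for row in cursor.execute("PRAGMA table_info(images)")}
    for column in ("left_marked_filename", "right_marked_filename"):
        if column not in existing:
            # SQLite can add nullable columns in place
            cursor.execute(f"ALTER TABLE images ADD COLUMN {column} TEXT")
    conn.commit()
    conn.close()
```

`ALTER TABLE ... ADD COLUMN` is cheap for nullable columns in SQLite, so this can safely run unconditionally at startup.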
@@ -148,7 +211,7 @@ def delete_image_api():
 
 @app.route('/api/images/export', methods=['POST'])
 def export_images_api():
-    """API: export the selected images as a ZIP archive"""
+    """API: export the selected images as a ZIP archive, preferring annotated images"""
     selected_ids = request.json.get('ids', [])
     if not selected_ids:
         return jsonify({"error": "No image IDs selected"}), 400
@@ -156,7 +219,8 @@ def export_images_api():
     conn = sqlite3.connect(DATABASE_PATH)
     cursor = conn.cursor()
     placeholders = ','.join('?' * len(selected_ids))
-    cursor.execute(f"SELECT left_filename, right_filename FROM images WHERE id IN ({placeholders})", selected_ids)
+    # Include the annotated-image filenames in the query
+    cursor.execute(f"SELECT left_filename, right_filename, left_marked_filename, right_marked_filename FROM images WHERE id IN ({placeholders})", selected_ids)
     rows = cursor.fetchall()
     conn.close()
 
@@ -169,13 +233,20 @@ def export_images_api():
     try:
         with zipfile.ZipFile(temp_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
-            for left_fn, right_fn in rows:
-                left_path = os.path.join(SAVE_PATH_LEFT, left_fn)
-                right_path = os.path.join(SAVE_PATH_RIGHT, right_fn)
-                if os.path.exists(left_path):
-                    zipf.write(left_path, os.path.join('left', left_fn))
-                if os.path.exists(right_path):
-                    zipf.write(right_path, os.path.join('right', right_fn))
+            for left_fn, right_fn, left_marked_fn, right_marked_fn in rows:
+                # Prefer the annotated image; fall back to the original
+                left_export_fn = left_marked_fn if left_marked_fn else left_fn
+                right_export_fn = right_marked_fn if right_marked_fn else right_fn
+
+                # Resolve the file paths
+                left_export_path = os.path.join(SAVE_PATH_LEFT_MARKED if left_marked_fn else SAVE_PATH_LEFT, left_export_fn)
+                right_export_path = os.path.join(SAVE_PATH_RIGHT_MARKED if right_marked_fn else SAVE_PATH_RIGHT, right_export_fn)
+
+                # Add to the ZIP archive
+                if os.path.exists(left_export_path):
+                    zipf.write(left_export_path, os.path.join('left', left_export_fn))
+                if os.path.exists(right_export_path):
+                    zipf.write(right_export_path, os.path.join('right', right_export_fn))
 
         logger.info(f"Exported {len(rows)} image pairs to {temp_zip_path}")
         # Return the ZIP file to the client
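The export fallback can be smoke-tested end to end with a short client script. A sketch only, assuming the app is reachable on the configured `FLASK_HOST`/`FLASK_PORT` (`127.0.0.1:5000` here); the `requests` dependency is illustrative and not part of this change:

```python
import requests

# Ask the server to zip images 1-3; annotated versions are preferred when present
resp = requests.post(
    "http://127.0.0.1:5000/api/images/export",
    json={"ids": [1, 2, 3]},
    timeout=60,
)
resp.raise_for_status()
with open("export.zip", "wb") as f:
    f.write(resp.content)  # ZIP containing left/ and right/ subfolders
```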
@@ -190,13 +261,13 @@ def export_images_api():
 
 @app.route('/upload', methods=['POST'])
 def upload_images():
-    """Receive the left/right camera images, save them, and push an update"""
+    """Receive the left/right camera images, save them, push an update, and generate annotated images"""
     try:
         # Get the files from the multipart/form-data payload
         left_file = request.files.get('left_image')
         right_file = request.files.get('right_image')
         metadata_str = request.form.get('metadata')  # In case the metadata needs processing
-        comment = request.form.get('comment', '')  # Get comment field
+        comment = request.form.get('comment', '')  # Get the comment field
 
         if not left_file or not right_file:
             logger.warning("Received request without required image files.")
@@ -230,26 +301,67 @@ def upload_images():
         # Generate filenames
         left_filename = f"left_{timestamp_str_safe}.jpg"
         right_filename = f"right_{timestamp_str_safe}.jpg"
+        left_marked_filename = f"left_marked_{timestamp_str_safe}.jpg"    # Annotated-image filename
+        right_marked_filename = f"right_marked_{timestamp_str_safe}.jpg"  # Annotated-image filename
 
-        # Save the images locally
+        # Save the original images locally
         left_path = os.path.join(SAVE_PATH_LEFT, left_filename)
         right_path = os.path.join(SAVE_PATH_RIGHT, right_filename)
 
         # Make sure the directories exist
         os.makedirs(SAVE_PATH_LEFT, exist_ok=True)
         os.makedirs(SAVE_PATH_RIGHT, exist_ok=True)
+        os.makedirs(SAVE_PATH_LEFT_MARKED, exist_ok=True)   # Create the annotated-image directory
+        os.makedirs(SAVE_PATH_RIGHT_MARKED, exist_ok=True)  # Create the annotated-image directory
 
         cv2.imwrite(left_path, img_left)
         cv2.imwrite(right_path, img_right)
-        logger.info(f"Saved images: {left_path}, {right_path}")
+        logger.info(f"Saved original images: {left_path}, {right_path}")
 
-        # Write the image info to the database
+        # Process the images with VisionAPIClient and generate annotated versions
+        left_marked_path = None
+        right_marked_path = None
+
+        try:
+            with VisionAPIClient() as client:
+                # Process the left image
+                left_task_id = client.submit_task(image_id=1, image=img_left)
+                # Process the right image
+                right_task_id = client.submit_task(image_id=2, image=img_right)
+
+                # Wait for the tasks to finish
+                client.task_queue.join()
+
+                # Fetch the processing results
+                left_result = client.get_result(left_task_id)
+                right_result = client.get_result(right_task_id)
+
+                # Generate the annotated images
+                if left_result and left_result.success:
+                    marked_left_img = draw_detections_on_image(img_left, left_result.detections)
+                    left_marked_path = os.path.join(SAVE_PATH_LEFT_MARKED, left_marked_filename)
+                    cv2.imwrite(left_marked_path, marked_left_img)
+                    logger.info(f"Saved marked left image: {left_marked_path}")
+
+                if right_result and right_result.success:
+                    marked_right_img = draw_detections_on_image(img_right, right_result.detections)
+                    right_marked_path = os.path.join(SAVE_PATH_RIGHT_MARKED, right_marked_filename)
+                    cv2.imwrite(right_marked_path, marked_right_img)
+                    logger.info(f"Saved marked right image: {right_marked_path}")
+        except Exception as e:
+            logger.error(f"Error processing images with VisionAPIClient: {e}")
+            # Even if processing fails, continue: the originals are already saved
+
+        # Write the image info, including the annotated-image columns, to the database
         conn = sqlite3.connect(DATABASE_PATH)
         cursor = conn.cursor()
         cursor.execute('''
-            INSERT INTO images (left_filename, right_filename, timestamp, metadata, comment)
-            VALUES (?, ?, ?, ?, ?)
-        ''', (left_filename, right_filename, float(timestamp_str), json.dumps(metadata), comment))
+            INSERT INTO images (left_filename, right_filename, left_marked_filename, right_marked_filename, timestamp, metadata, comment)
+            VALUES (?, ?, ?, ?, ?, ?, ?)
+        ''', (left_filename, right_filename,
+              left_marked_filename if left_marked_path else None,
+              right_marked_filename if right_marked_path else None,
+              float(timestamp_str), json.dumps(metadata), comment))
         conn.commit()
         image_id = cursor.lastrowid  # ID of the newly inserted row
         conn.close()
@@ -282,7 +394,7 @@ def upload_images():
 
 @app.route('/api/images/comment', methods=['PUT'])
 def update_image_comment():
-    """API: update the image comment"""
+    """API: update an image's comment"""
     data = request.json
     image_id = data.get('id')
     comment = data.get('comment', '')
@@ -292,7 +404,7 @@ def update_image_comment():
 
     conn = sqlite3.connect(DATABASE_PATH)
     cursor = conn.cursor()
-    # Update comment field
+    # Update the comment field
     cursor.execute("UPDATE images SET comment = ? WHERE id = ?", (comment, image_id))
     conn.commit()
     conn.close()
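Note that `upload_images` now blocks on `client.task_queue.join()` until the remote vision model has answered both requests, and it builds a fresh `VisionAPIClient` (worker thread plus OpenAI client) per upload, so upload latency is coupled directly to model latency. If that becomes a problem, one option is to respond immediately and annotate in the background, updating the row afterwards. A sketch only — the shared `vision_client` instance and the `annotate_in_background` helper are hypothetical:

```python
import threading

from llm_req import VisionAPIClient

vision_client = VisionAPIClient()  # hypothetical module-level shared instance
vision_client.start_worker()       # started once at import time

def annotate_in_background(image_id, img_left, img_right,
                           left_marked_filename, right_marked_filename):
    """Run detection after the HTTP response has been sent (sketch)."""
    def _work():
        left_task = vision_client.submit_task(image_id=1, image=img_left)
        right_task = vision_client.submit_task(image_id=2, image=img_right)
        vision_client.task_queue.join()
        # ... write the marked images, then UPDATE the row for `image_id`
        #     with left_marked_filename / right_marked_filename here ...
    threading.Thread(target=_work, daemon=True).start()
```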
WHERE id = ?", (comment, image_id)) conn.commit() conn.close() diff --git a/cap_trigger.py b/cap_trigger.py index 6da23c9..ec39e45 100644 --- a/cap_trigger.py +++ b/cap_trigger.py @@ -4,7 +4,7 @@ import uuid import time class ImageClient: - def __init__(self, server_address="tcp://10.42.70.1:54321", client_id=None): + def __init__(self, server_address="tcp://:54321", client_id=None): self.server_address = server_address self.client_id = client_id or f"client_{uuid.uuid4().hex[:8]}" self.socket = None @@ -43,7 +43,7 @@ class ImageClient: """发送同步请求到服务器""" socket = None try: - # 每次都创建一个新的socket实例以避免复用已关闭的socket + # 每次都创建一个新的 socket 实例以避免复用已关闭的 socket socket = pynng.Req0() socket.dial(self.server_address, block=True) # 使用阻塞模式确保连接建立 client_id_bytes = self.client_id.encode('utf-8') @@ -54,7 +54,7 @@ class ImageClient: print(f"Client {self.client_id} error: {e}") return None finally: - # 确保socket在使用后被正确关闭 + # 确保 socket 在使用后被正确关闭 if socket: try: socket.close() diff --git a/llm_req.py b/llm_req.py new file mode 100644 index 0000000..048d253 --- /dev/null +++ b/llm_req.py @@ -0,0 +1,610 @@ +import base64 +import json +import cv2 +import numpy as np +from typing import List, Dict, Any, Optional, Callable +from openai import OpenAI +import asyncio +from concurrent.futures import ThreadPoolExecutor +import threading +from dataclasses import dataclass +from enum import Enum +import queue +import time + +# Configuration +API_KEY = "sk-e3a0287ece6a41bb9b79b2c285f10197" +BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1" +MODEL_NAME = "qwen-vl-plus" + +# Category mapping +CATEGORY_MAPPING = { + 1: "caisson", + 2: "soldier", + 3: "gun", + 4: "number" +} + +CATEGORY_COLORS = { + 1: (0, 255, 0), # Green for caisson + 2: (0, 255, 255), # Yellow for soldier + 3: (0, 0, 255), # Red for gun + 4: (255, 0, 0) # Blue for number +} + +@dataclass +class DetectionResult: + """Detection result data class""" + image_id: str + original_image: np.ndarray + detections: List[Dict[str, Any]] + marked_image: Optional[np.ndarray] = None + success: bool = False + error_message: Optional[str] = None + timestamp: float = 0.0 + +class TaskStatus(Enum): + """Task status enumeration""" + PENDING = "pending" + PROCESSING = "processing" + COMPLETED = "completed" + FAILED = "failed" + +@dataclass +class Task: + """Task data class for queue""" + task_id: str + image_id: str + image: np.ndarray # OpenCV Mat format + prompt: str + callback: Optional[Callable[[DetectionResult], None]] = None + timestamp: float = 0.0 + +class VisionAPIClient: + """Vision API Client for asynchronous processing with OpenCV Mat input/output""" + + def __init__(self, api_key: str = API_KEY, base_url: str = BASE_URL, + model_name: str = MODEL_NAME, max_workers: int = 4): + self.api_key = api_key + self.base_url = base_url + self.model_name = model_name + self.max_workers = max_workers + self.executor = ThreadPoolExecutor(max_workers=max_workers) + self.client = OpenAI(api_key=api_key, base_url=base_url) + + # Task management + self.task_queue = queue.Queue() + self.processing_tasks = {} + self.results_cache = {} + self._running = False + self._worker_thread = None + + self.default_prompt = """Please perform object detection on the image, identifying and localizing the following four types of targets: + - Category 1: Green ammunition box (caisson) + - Category 2: Dummy soldier wearing digital camouflage uniform (soldier) + - Category 3: Gun (gun) + - Category 4: Round blue number plate (number) + + Please follow these requirements for the output: + 1. 
diff --git a/llm_req.py b/llm_req.py
new file mode 100644
index 0000000..048d253
--- /dev/null
+++ b/llm_req.py
@@ -0,0 +1,610 @@
+import base64
+import json
+import cv2
+import numpy as np
+from typing import List, Dict, Any, Optional, Callable
+from openai import OpenAI
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+import threading
+from dataclasses import dataclass
+from enum import Enum
+import queue
+import time
+
+# Configuration
+API_KEY = "sk-e3a0287ece6a41bb9b79b2c285f10197"
+BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
+MODEL_NAME = "qwen-vl-plus"
+
+# Category mapping
+CATEGORY_MAPPING = {
+    1: "caisson",
+    2: "soldier",
+    3: "gun",
+    4: "number"
+}
+
+CATEGORY_COLORS = {
+    1: (0, 255, 0),    # Green for caisson
+    2: (0, 255, 255),  # Yellow for soldier
+    3: (0, 0, 255),    # Red for gun
+    4: (255, 0, 0)     # Blue for number
+}
+
+@dataclass
+class DetectionResult:
+    """Detection result data class"""
+    image_id: str
+    original_image: np.ndarray
+    detections: List[Dict[str, Any]]
+    marked_image: Optional[np.ndarray] = None
+    success: bool = False
+    error_message: Optional[str] = None
+    timestamp: float = 0.0
+
+class TaskStatus(Enum):
+    """Task status enumeration"""
+    PENDING = "pending"
+    PROCESSING = "processing"
+    COMPLETED = "completed"
+    FAILED = "failed"
+
+@dataclass
+class Task:
+    """Task data class for the queue"""
+    task_id: str
+    image_id: str
+    image: np.ndarray  # OpenCV Mat format
+    prompt: str
+    callback: Optional[Callable[[DetectionResult], None]] = None
+    timestamp: float = 0.0
+
+class VisionAPIClient:
+    """Vision API client for asynchronous processing with OpenCV Mat input/output"""
+
+    def __init__(self, api_key: str = API_KEY, base_url: str = BASE_URL,
+                 model_name: str = MODEL_NAME, max_workers: int = 4):
+        self.api_key = api_key
+        self.base_url = base_url
+        self.model_name = model_name
+        self.max_workers = max_workers
+        self.executor = ThreadPoolExecutor(max_workers=max_workers)
+        self.client = OpenAI(api_key=api_key, base_url=base_url)
+
+        # Task management
+        self.task_queue = queue.Queue()
+        self.processing_tasks = {}
+        self.results_cache = {}
+        self._running = False
+        self._worker_thread = None
+
+        self.default_prompt = """Please perform object detection on the image, identifying and localizing the following four types of targets:
+    - Category 1: Green ammunition box (caisson)
+    - Category 2: Dummy soldier wearing digital camouflage uniform (soldier)
+    - Category 3: Gun (gun)
+    - Category 4: Round blue number plate (number)
+
+    Please follow these requirements for the output:
+    1. The output must be in valid JSON format.
+    2. The JSON structure should contain a list named "detections".
+    3. Each element in the list represents a detected target, containing the following fields:
+       - "id" (integer): Target category ID (1, 2, 3, 4).
+       - "label" (string): Target category name ("caisson", "soldier", "gun", "number").
+       - "bbox" (list of int): Bounding box coordinates in format [x_min, y_min, x_max, y_max], where (x_min, y_min) is the top-left coordinate and (x_max, y_max) is the bottom-right coordinate. Coordinate values are integers normalized to the 0-999 range (0,0 represents top-left, 999,999 represents bottom-right).
+    4. If no targets are detected in the image, "detections" should be an empty list [].
+    5. Please output only JSON, no other explanatory text.
+
+    JSON output example (when targets are detected):
+    {
+      "detections": [
+        {
+          "id": 1,
+          "label": "caisson",
+          "bbox": [x1, y1, x2, y2]  // x1, y1, x2, y2 are integers in the 0-999 range
+        },
+        {
+          "id": 2,
+          "label": "soldier",
+          "bbox": [x3, y3, x4, y4]  // x3, y3, x4, y4 are integers in the 0-999 range
+        }
+      ]
+    }
+
+    JSON output example (when no targets are detected):
+    {
+      "detections": []
+    }"""
+
+    def encode_cv_image(self, image: np.ndarray) -> str:
+        """
+        Encodes an OpenCV image (Mat) to a base64 string.
+
+        Args:
+            image: OpenCV image (Mat) format
+
+        Returns:
+            Base64-encoded string of the image
+        """
+        # Encode the image to JPEG format
+        _, buffer = cv2.imencode('.jpg', image)
+        return base64.b64encode(buffer).decode('utf-8')
+
+    def validate_and_extract_json(self, response_text: str) -> Optional[Dict[str, Any]]:
+        """
+        Validates and extracts JSON from the API response text.
+
+        Args:
+            response_text: Raw response text from the API
+
+        Returns:
+            Parsed JSON dictionary if valid, None otherwise
+        """
+        # Try to find JSON within the response text (in case of additional text)
+        start_idx = response_text.find('{')
+        end_idx = response_text.rfind('}')
+
+        if start_idx == -1 or end_idx == -1 or start_idx > end_idx:
+            print("No valid JSON structure found in response.")
+            return None
+
+        json_str = response_text[start_idx:end_idx+1]
+
+        try:
+            parsed_json = json.loads(json_str)
+            return parsed_json
+        except json.JSONDecodeError as e:
+            print(f"JSON parsing failed: {e}")
+            print(f"Problematic JSON string: {json_str[:200]}...")  # Show the first 200 chars
+            return None
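Because `validate_and_extract_json` slices from the first `{` to the last `}`, it tolerates the common failure mode where the model wraps its answer in a Markdown fence despite the "output only JSON" instruction. A quick illustration with an invented response string:

```python
from llm_req import VisionAPIClient

client = VisionAPIClient()
raw = ("Here is the result:\n```json\n"
       "{\"detections\": [{\"id\": 3, \"label\": \"gun\", \"bbox\": [120, 340, 410, 560]}]}"
       "\n```")
data = client.validate_and_extract_json(raw)  # fence and preamble are stripped
print(data["detections"][0]["label"])         # -> gun
```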
+    def validate_detections_format(self, data: Dict[str, Any]) -> bool:
+        """
+        Validates the structure and content of the detections JSON.
+
+        Args:
+            data: Parsed JSON data
+
+        Returns:
+            True if the format is valid, False otherwise
+        """
+        if not isinstance(data, dict) or "detections" not in data:
+            print("Missing 'detections' key in response.")
+            return False
+
+        detections = data["detections"]
+        if not isinstance(detections, list):
+            print("'detections' is not a list.")
+            return False
+
+        for i, detection in enumerate(detections):
+            if not isinstance(detection, dict):
+                print(f"Detection item {i} is not a dictionary.")
+                return False
+
+            required_keys = ["id", "label", "bbox"]
+            for key in required_keys:
+                if key not in detection:
+                    print(f"Missing required key '{key}' in detection {i}.")
+                    return False
+
+            # Validate ID
+            if not isinstance(detection["id"], int) or detection["id"] not in [1, 2, 3, 4]:
+                print(f"Invalid ID in detection {i}: {detection['id']}")
+                return False
+
+            # Validate label
+            if not isinstance(detection["label"], str) or detection["label"] not in CATEGORY_MAPPING.values():
+                print(f"Invalid label in detection {i}: {detection['label']}")
+                return False
+
+            # Validate bbox
+            bbox = detection["bbox"]
+            if not isinstance(bbox, list) or len(bbox) != 4:
+                print(f"Invalid bbox format in detection {i}: {bbox}")
+                return False
+
+            for coord in bbox:
+                if not isinstance(coord, (int, float)) or not (0 <= coord <= 999):
+                    print(f"Invalid bbox coordinate in detection {i}: {coord}")
+                    return False
+
+            # Validate confidence if present
+            if "confidence" in detection:
+                conf = detection["confidence"]
+                if not isinstance(conf, (int, float)) or not (0.0 <= conf <= 1.0):
+                    print(f"Invalid confidence in detection {i}: {conf}")
+                    return False
+
+        return True
+
+    def call_vision_api_sync(self, image: np.ndarray, prompt: str) -> Optional[Dict[str, Any]]:
+        """
+        Synchronous call to the vision API with OpenCV Mat input.
+
+        Args:
+            image: OpenCV image (Mat) format
+            prompt: The prompt to send to the API
+
+        Returns:
+            Parsed JSON response if successful, None otherwise
+        """
+        # Resize the image to 1000x1000 if needed
+        h, w = image.shape[:2]
+        if h != 1000 or w != 1000:
+            image = cv2.resize(image, (1000, 1000))
+
+        # Encode the image directly to base64
+        image_base64 = self.encode_cv_image(image)
+        image_url = f"data:image/jpeg;base64,{image_base64}"
+
+        try:
+            # Create the completion request
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=[
+                    {
+                        'role': 'user',
+                        'content': [
+                            {"type": "text", "text": prompt},
+                            {"type": "image_url", "image_url": {"url": image_url}}
+                        ]
+                    }
+                ],
+                stream=False  # Single response instead of streaming
+            )
+
+            # Extract the content from the response
+            if response.choices and len(response.choices) > 0:
+                content = response.choices[0].message.content
+            else:
+                print("No choices returned from API.")
+                return None
+
+            if not content:
+                print("No content returned from API.")
+                return None
+
+            # Validate and parse JSON from the response
+            parsed_data = self.validate_and_extract_json(content)
+
+            if parsed_data is None:
+                return None
+
+            # Validate the structure of the detections
+            if not self.validate_detections_format(parsed_data):
+                print("Invalid detections format in response.")
+                return None
+
+            return parsed_data
+
+        except Exception as e:
+            print(f"API request failed: {e}")
+            return None
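`call_vision_api_sync` collapses every failure (transport error, empty content, malformed or out-of-spec JSON) into `None`, so callers cannot tell transient errors from permanent ones. If retries are wanted, a thin wrapper is enough; a sketch under that assumption, not part of the module:

```python
import time

from llm_req import VisionAPIClient

def call_with_retries(client: VisionAPIClient, image, prompt,
                      attempts: int = 3, backoff_s: float = 2.0):
    """Retry the synchronous API call with linear backoff (sketch)."""
    for attempt in range(1, attempts + 1):
        result = client.call_vision_api_sync(image, prompt)
        if result is not None:
            return result
        if attempt < attempts:
            time.sleep(backoff_s * attempt)  # 2s, 4s, ...
    return None
```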
+    def draw_detections_on_image(self, image: np.ndarray, detections: List[Dict[str, Any]]) -> np.ndarray:
+        """
+        Draws bounding boxes and labels on the OpenCV Mat image (without confidence).
+
+        Args:
+            image: OpenCV image (Mat) format
+            detections: List of detection dictionaries
+
+        Returns:
+            Image with drawn detections as a numpy array
+        """
+        # Work on a copy to avoid modifying the original
+        result_image = image.copy()
+
+        # Get the image dimensions
+        img_h, img_w = result_image.shape[:2]
+
+        for detection in detections:
+            # Get the bounding box coordinates (normalized to the 0-999 range)
+            bbox = detection["bbox"]
+            x1_norm, y1_norm, x2_norm, y2_norm = bbox
+
+            # Convert normalized coordinates to pixel coordinates
+            x1 = int((x1_norm / 999) * img_w)
+            y1 = int((y1_norm / 999) * img_h)
+            x2 = int((x2_norm / 999) * img_w)
+            y2 = int((y2_norm / 999) * img_h)
+
+            # Ensure the coordinates are within the image bounds
+            x1 = max(0, min(x1, img_w - 1))
+            y1 = max(0, min(y1, img_h - 1))
+            x2 = max(0, min(x2, img_w - 1))
+            y2 = max(0, min(y2, img_h - 1))
+
+            # Get the color and label
+            category_id = detection["id"]
+            label = detection["label"]
+            color = CATEGORY_COLORS.get(category_id, (255, 255, 255))  # Default white if not found
+
+            # Draw the bounding box
+            cv2.rectangle(result_image, (x1, y1), (x2, y2), color, 2)
+
+            # Prepare the label text (no confidence)
+            label_text = label
+
+            # Calculate the text size and position
+            font = cv2.FONT_HERSHEY_SIMPLEX
+            font_scale = 0.6
+            thickness = 2
+            (text_width, text_height), baseline = cv2.getTextSize(label_text, font, font_scale, thickness)
+
+            # Draw the label background
+            cv2.rectangle(result_image, (x1, y1 - text_height - 10), (x1 + text_width, y1), color, -1)
+
+            # Draw the label text
+            cv2.putText(result_image, label_text, (x1, y1 - 5), font, font_scale, (0, 0, 0), thickness)
+
+        return result_image
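As a sanity check on the 0-999 normalization shared by this method, the prompt, and the helper in `cam_web.py`: a bbox of `[100, 250, 499, 749]` on a 1920x1080 frame maps to the pixel box below (integer truncation via `int()`, as in the code):

```python
w, h = 1920, 1080
bbox = [100, 250, 499, 749]    # normalized to the 0-999 range
x1 = int((bbox[0] / 999) * w)  # -> 192
y1 = int((bbox[1] / 999) * h)  # -> 270
x2 = int((bbox[2] / 999) * w)  # -> 959
y2 = int((bbox[3] / 999) * h)  # -> 809
print((x1, y1, x2, y2))
```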
+    def process_single_image(self, image_id: str, image: np.ndarray, prompt: str = None) -> DetectionResult:
+        """
+        Process a single OpenCV Mat image synchronously.
+
+        Args:
+            image_id: Unique identifier for the image
+            image: OpenCV image (Mat) format
+            prompt: The prompt for the vision API (optional)
+
+        Returns:
+            DetectionResult containing the results
+        """
+        start_time = time.time()
+
+        # Validate the input image
+        if image is None or image.size == 0:
+            error_msg = f"Invalid image for image_id: {image_id}"
+            print(error_msg)
+            return DetectionResult(
+                image_id=image_id,
+                original_image=image,
+                detections=[],
+                success=False,
+                error_message=error_msg,
+                timestamp=start_time
+            )
+
+        # Use the provided prompt or the default
+        use_prompt = prompt if prompt is not None else self.default_prompt
+
+        # Call the vision API
+        print(f"Calling vision API for image {image_id}...")
+        result = self.call_vision_api_sync(image, use_prompt)
+
+        if result is None:
+            error_msg = "Failed to get valid response from API."
+            print(error_msg)
+            return DetectionResult(
+                image_id=image_id,
+                original_image=image,
+                detections=[],
+                success=False,
+                error_message=error_msg,
+                timestamp=start_time
+            )
+
+        # Extract the detections
+        detections = result.get("detections", [])
+        print(f"Found {len(detections)} detections for image {image_id}.")
+
+        # Draw the detections on the image
+        try:
+            marked_image = self.draw_detections_on_image(image, detections)
+        except Exception as e:
+            error_msg = f"Error drawing detections on image: {e}"
+            print(error_msg)
+            return DetectionResult(
+                image_id=image_id,
+                original_image=image,
+                detections=detections,
+                success=False,
+                error_message=error_msg,
+                timestamp=start_time
+            )
+
+        # Return the successful result
+        return DetectionResult(
+            image_id=image_id,
+            original_image=image,
+            detections=detections,
+            marked_image=marked_image,
+            success=True,
+            timestamp=start_time
+        )
+
+    def start_worker(self):
+        """Start the background worker thread for processing tasks"""
+        if self._running:
+            return
+
+        self._running = True
+        self._worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
+        self._worker_thread.start()
+        print("Vision API worker started.")
+
+    def stop_worker(self):
+        """Stop the background worker thread"""
+        self._running = False
+        if self._worker_thread:
+            self._worker_thread.join(timeout=5.0)  # Wait up to 5 seconds
+        print("Vision API worker stopped.")
+
+    def _worker_loop(self):
+        """Background worker loop for processing tasks from the queue"""
+        while self._running:
+            try:
+                # Get a task from the queue with a timeout
+                task = self.task_queue.get(timeout=1.0)
+
+                # Mark it as processing
+                self.processing_tasks[task.task_id] = TaskStatus.PROCESSING
+
+                # Process the task
+                result = self.process_single_image(task.image_id, task.image, task.prompt)
+
+                # Store the result (before task_done, so join() callers can read it)
+                self.results_cache[task.task_id] = result
+
+                # Update the task status
+                self.processing_tasks[task.task_id] = TaskStatus.COMPLETED if result.success else TaskStatus.FAILED
+
+                # Call the callback if provided
+                if task.callback:
+                    try:
+                        task.callback(result)
+                    except Exception as e:
+                        print(f"Callback execution failed for task {task.task_id}: {e}")
+
+                # Mark the task as done
+                self.task_queue.task_done()
+
+            except queue.Empty:
+                continue  # Timeout; continue the loop
+            except Exception as e:
+                print(f"Worker error: {e}")
+                continue
+
+    def submit_task(self, image_id: int, image: np.ndarray, prompt: str = None,
+                    callback: Callable[[DetectionResult], None] = None) -> str:
+        """
+        Submit a task to the processing queue with OpenCV Mat input.
+
+        Args:
+            image_id: Unique identifier for the image
+            image: OpenCV image (Mat) format
+            prompt: The prompt for the vision API (optional)
+            callback: Callback function invoked when processing completes (optional)
+
+        Returns:
+            Task ID for tracking the task
+        """
+        task_id = f"task_{int(time.time() * 1000000)}_{image_id}"  # Generate a unique task ID
+        task = Task(
+            task_id=task_id,
+            image_id=image_id,
+            image=image,
+            prompt=prompt if prompt is not None else self.default_prompt,
+            callback=callback,
+            timestamp=time.time()
+        )
+
+        self.task_queue.put(task)
+        self.processing_tasks[task_id] = TaskStatus.PENDING
+
+        return task_id
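`get_result` is non-blocking, and `task_queue.join()` waits for *all* queued tasks, so there is currently no way to wait on a single task. A hypothetical per-task polling helper, assuming `TaskStatus` is imported alongside the client:

```python
import time

from llm_req import VisionAPIClient, TaskStatus

def wait_for_result(client: VisionAPIClient, task_id: str,
                    timeout_s: float = 30.0, poll_s: float = 0.1):
    """Poll until the given task completes or the timeout expires (sketch)."""
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        status = client.get_task_status(task_id)
        if status in (TaskStatus.COMPLETED, TaskStatus.FAILED):
            return client.get_result(task_id)
        time.sleep(poll_s)
    return None  # timed out; the task may still complete later
```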
+    def get_result(self, task_id: str) -> Optional[DetectionResult]:
+        """
+        Get the result for a specific task.
+
+        Args:
+            task_id: The task ID to retrieve the result for
+
+        Returns:
+            DetectionResult if available, None otherwise
+        """
+        return self.results_cache.get(task_id)
+
+    def get_task_status(self, task_id: str) -> Optional[TaskStatus]:
+        """
+        Get the status of a specific task.
+
+        Args:
+            task_id: The task ID to check the status for
+
+        Returns:
+            TaskStatus if the task exists, None otherwise
+        """
+        return self.processing_tasks.get(task_id)
+
+    def get_queue_size(self) -> int:
+        """Get the current size of the task queue"""
+        return self.task_queue.qsize()
+
+    def get_processing_count(self) -> int:
+        """Get the number of currently processing tasks"""
+        return sum(1 for status in self.processing_tasks.values()
+                   if status == TaskStatus.PROCESSING)
+
+    def get_completed_count(self) -> int:
+        """Get the number of completed tasks"""
+        return sum(1 for status in self.processing_tasks.values()
+                   if status in [TaskStatus.COMPLETED, TaskStatus.FAILED])
+
+    def clear_results(self):
+        """Clear the results cache"""
+        self.results_cache.clear()
+
+    def __enter__(self):
+        """Context manager entry"""
+        self.start_worker()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Context manager exit"""
+        self.stop_worker()
+
+
+# Example usage
+def example_callback(result: DetectionResult):
+    """Example callback function"""
+    if result.success:
+        print(f"Callback: Processing completed for image {result.image_id}, found {len(result.detections)} detections")
+        # result.marked_image is the OpenCV Mat with detections drawn
+        marked_image = result.marked_image
+        # The marked_image can now be used for further processing
+    else:
+        print(f"Callback: Processing failed for image {result.image_id}: {result.error_message}")
+
+
+def main():
+    original_image = cv2.imread("/home/evan/Desktop/received/left/left_1761388243_7673044.jpg")  # Replace with your image source
+
+    if original_image is None:
+        print("Could not load image")
+        return
+
+    # Example usage with the context manager
+    with VisionAPIClient() as client:
+        # Submit a task with an OpenCV Mat
+        task_id = client.submit_task(
+            image_id=1,
+            image=original_image,
+            callback=example_callback
+        )
+
+        print(f"Submitted task {task_id}")
+
+        # Wait for the task to complete
+        print("Waiting for task to complete...")
+        client.task_queue.join()  # Wait for all tasks in the queue to be processed
+
+        # Get the result
+        result = client.get_result(task_id)
+        if result:
+            if result.success:
+                print(f"Task completed successfully! Found {len(result.detections)} detections.")
+                # result.marked_image is the OpenCV Mat with detections drawn
+                marked_image = result.marked_image
+
+                # Display the result (optional)
+                cv2.imshow("Original Image", original_image)
+                cv2.imshow("Marked Image", marked_image)
+                print("Press any key to close windows...")
+                cv2.waitKey(0)
+                cv2.destroyAllWindows()
+
+                # Or save the result
+                # cv2.imwrite("marked_image.jpg", marked_image)
+            else:
+                print(f"Task failed: {result.error_message}")
+        else:
+            print("No result found")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/templates/list_images.html b/templates/list_images.html
index f58c7b7..214b6d3 100644
--- a/templates/list_images.html
+++ b/templates/list_images.html
@@ -1,9 +1,11 @@
-
+
-