feat: add annotation-related features

cam_web.py (172 lines changed)
@@ -14,6 +14,9 @@ from datetime import datetime
 import zipfile
 import tempfile
 
+# Import the vision-processing modules
+from llm_req import VisionAPIClient, DetectionResult
+
 from cap_trigger import ImageClient
 
 
@@ -21,6 +24,8 @@ from cap_trigger import ImageClient
 # --- Configuration ---
 SAVE_PATH_LEFT = "./static/received/left"
 SAVE_PATH_RIGHT = "./static/received/right"
+SAVE_PATH_LEFT_MARKED = "./static/received/left_marked"    # annotated image save path
+SAVE_PATH_RIGHT_MARKED = "./static/received/right_marked"  # annotated image save path
 FLASK_HOST = "0.0.0.0"
 FLASK_PORT = 5000
 MAX_LIVE_FRAMES = 2  # keep only the latest few frames for live display
@@ -36,12 +41,14 @@ def init_db():
     """Initialize the SQLite database and table"""
     conn = sqlite3.connect(DATABASE_PATH)
     cursor = conn.cursor()
-    # Create the table that stores the image records
+    # Create the table that stores the image records, now with annotated-image columns
     cursor.execute('''
         CREATE TABLE IF NOT EXISTS images (
             id INTEGER PRIMARY KEY AUTOINCREMENT,
             left_filename TEXT NOT NULL,
             right_filename TEXT NOT NULL,
+            left_marked_filename TEXT,
+            right_marked_filename TEXT,
             timestamp REAL NOT NULL,
             metadata TEXT,
             comment TEXT,
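
Note: because the table is created with CREATE TABLE IF NOT EXISTS, a database file created before this commit will not gain the two new columns automatically. A minimal migration sketch (a hypothetical helper; it assumes DATABASE_PATH points at the existing database):

    import sqlite3

    def migrate_add_marked_columns(db_path: str) -> None:
        """Add the annotated-image columns to an existing 'images' table if missing."""
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()
        existing = {row[1] for row in cursor.execute("PRAGMA table_info(images)")}
        for column in ("left_marked_filename", "right_marked_filename"):
            if column not in existing:
                cursor.execute(f"ALTER TABLE images ADD COLUMN {column} TEXT")  # nullable TEXT, like the new schema
        conn.commit()
        conn.close()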
@@ -72,6 +79,51 @@ image_client = ImageClient("tcp://127.0.0.1:54321", client_id="local")
 # Initialize the database
 init_db()
 
+
+# --- Helper functions ---
+
+def draw_detections_on_image(image: np.ndarray, detections: list) -> np.ndarray:
+    """Draw detection boxes on an image"""
+    # Copy the original so it is not modified
+    marked_image = image.copy()
+
+    # Define the color map
+    color_map = {
+        1: (0, 255, 0),    # green - ammunition box
+        2: (255, 0, 0),    # blue - soldier
+        3: (0, 0, 255),    # red - gun
+        4: (255, 255, 0)   # cyan - number plate
+    }
+
+    # Get the image dimensions
+    h, w = image.shape[:2]
+
+    # Draw each detection box
+    for detection in detections:
+        # Get the detection fields
+        obj_id = detection.get("id", 0)
+        label = detection.get("label", "")
+        bbox = detection.get("bbox", [])
+
+        if len(bbox) == 4:
+            # Convert the normalized coordinates to actual pixel coordinates
+            x_min = int(bbox[0] * w / 999)
+            y_min = int(bbox[1] * h / 999)
+            x_max = int(bbox[2] * w / 999)
+            y_max = int(bbox[3] * h / 999)
+
+            # Pick the color
+            color = color_map.get(obj_id, (255, 255, 255))  # default white
+
+            # Draw the bounding box
+            cv2.rectangle(marked_image, (x_min, y_min), (x_max, y_max), color, 2)
+
+            # Add the label
+            label_text = f"{label} ({obj_id})"
+            cv2.putText(marked_image, label_text, (x_min, y_min - 10),
+                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
+
+    return marked_image
+
+
 # --- Flask routes ---
 
 @app.route('/')
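
The bbox values are expected in the 0-999 normalized range used by the prompt in llm_req.py; a quick sanity check of the scaling, assuming a 1280x720 frame:

    w, h = 1280, 720
    bbox = [250, 500, 750, 900]     # example values from a detection
    x_min = int(bbox[0] * w / 999)  # 320
    y_min = int(bbox[1] * h / 999)  # 360
    x_max = int(bbox[2] * w / 999)  # 960
    y_max = int(bbox[3] * h / 999)  # 648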
@@ -89,8 +141,8 @@ def get_images_api():
     """API: get the image list (JSON format)"""
     conn = sqlite3.connect(DATABASE_PATH)
     cursor = conn.cursor()
-    # Sort by timestamp, newest first
-    cursor.execute("SELECT id, left_filename, right_filename, timestamp, metadata, comment, created_at FROM images ORDER BY timestamp DESC")
+    # Sort by timestamp, newest first, and include the annotated-image columns
+    cursor.execute("SELECT id, left_filename, right_filename, left_marked_filename, right_marked_filename, timestamp, metadata, comment, created_at FROM images ORDER BY timestamp DESC")
     rows = cursor.fetchall()
     conn.close()
 
@@ -100,10 +152,12 @@ def get_images_api():
             "id": row[0],
             "left_filename": row[1],
             "right_filename": row[2],
-            "timestamp": row[3],
-            "metadata": row[4],
-            "comment": row[5] or "",  # empty string when there is no comment
-            "created_at": row[6]
+            "left_marked_filename": row[3] or "",   # empty string when there is no annotated image
+            "right_marked_filename": row[4] or "",  # empty string when there is no annotated image
+            "timestamp": row[5],
+            "metadata": row[6],
+            "comment": row[7] or "",                # empty string when there is no comment
+            "created_at": row[8]
         })
     return jsonify(images)
 
@@ -116,22 +170,25 @@ def delete_image_api():
 
     conn = sqlite3.connect(DATABASE_PATH)
     cursor = conn.cursor()
-    # Look up the filenames
-    cursor.execute("SELECT left_filename, right_filename FROM images WHERE id = ?", (image_id,))
+    # Look up the filenames, including the annotated images
+    cursor.execute("SELECT left_filename, right_filename, left_marked_filename, right_marked_filename FROM images WHERE id = ?", (image_id,))
     row = cursor.fetchone()
     if not row:
         conn.close()
         return jsonify({"error": "Image not found"}), 404
 
-    left_filename, right_filename = row
+    left_filename, right_filename, left_marked_filename, right_marked_filename = row
     # Delete the database record
     cursor.execute("DELETE FROM images WHERE id = ?", (image_id,))
     conn.commit()
     conn.close()
 
-    # Delete the corresponding files
+    # Delete the corresponding files, including the annotated images
     left_path = os.path.join(SAVE_PATH_LEFT, left_filename)
     right_path = os.path.join(SAVE_PATH_RIGHT, right_filename)
+    left_marked_path = os.path.join(SAVE_PATH_LEFT_MARKED, left_marked_filename) if left_marked_filename else None
+    right_marked_path = os.path.join(SAVE_PATH_RIGHT_MARKED, right_marked_filename) if right_marked_filename else None
+
     try:
         if os.path.exists(left_path):
             os.remove(left_path)
@@ -139,6 +196,12 @@ def delete_image_api():
         if os.path.exists(right_path):
             os.remove(right_path)
             logger.info(f"Deleted file: {right_path}")
+        if left_marked_path and os.path.exists(left_marked_path):
+            os.remove(left_marked_path)
+            logger.info(f"Deleted file: {left_marked_path}")
+        if right_marked_path and os.path.exists(right_marked_path):
+            os.remove(right_marked_path)
+            logger.info(f"Deleted file: {right_marked_path}")
     except OSError as e:
         logger.error(f"Error deleting files: {e}")
         # Even if deleting the files fails, the database record is already gone, so return success
@@ -148,7 +211,7 @@ def delete_image_api():
 
 @app.route('/api/images/export', methods=['POST'])
 def export_images_api():
-    """API: package and export the selected images"""
+    """API: package and export the selected images, preferring annotated images"""
     selected_ids = request.json.get('ids', [])
     if not selected_ids:
         return jsonify({"error": "No image IDs selected"}), 400
@@ -156,7 +219,8 @@ def export_images_api():
     conn = sqlite3.connect(DATABASE_PATH)
     cursor = conn.cursor()
     placeholders = ','.join('?' * len(selected_ids))
-    cursor.execute(f"SELECT left_filename, right_filename FROM images WHERE id IN ({placeholders})", selected_ids)
+    # Query the filenames, including the annotated images
+    cursor.execute(f"SELECT left_filename, right_filename, left_marked_filename, right_marked_filename FROM images WHERE id IN ({placeholders})", selected_ids)
     rows = cursor.fetchall()
     conn.close()
 
@@ -169,13 +233,20 @@ def export_images_api():
 
     try:
         with zipfile.ZipFile(temp_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
-            for left_fn, right_fn in rows:
-                left_path = os.path.join(SAVE_PATH_LEFT, left_fn)
-                right_path = os.path.join(SAVE_PATH_RIGHT, right_fn)
-                if os.path.exists(left_path):
-                    zipf.write(left_path, os.path.join('left', left_fn))
-                if os.path.exists(right_path):
-                    zipf.write(right_path, os.path.join('right', right_fn))
+            for left_fn, right_fn, left_marked_fn, right_marked_fn in rows:
+                # Prefer the annotated image; fall back to the original if there is none
+                left_export_fn = left_marked_fn if left_marked_fn else left_fn
+                right_export_fn = right_marked_fn if right_marked_fn else right_fn
+
+                # Resolve the file paths
+                left_export_path = os.path.join(SAVE_PATH_LEFT_MARKED if left_marked_fn else SAVE_PATH_LEFT, left_export_fn)
+                right_export_path = os.path.join(SAVE_PATH_RIGHT_MARKED if right_marked_fn else SAVE_PATH_RIGHT, right_export_fn)
+
+                # Add to the ZIP archive
+                if os.path.exists(left_export_path):
+                    zipf.write(left_export_path, os.path.join('left', left_export_fn))
+                if os.path.exists(right_export_path):
+                    zipf.write(right_export_path, os.path.join('right', right_export_fn))
 
         logger.info(f"Exported {len(rows)} image pairs to {temp_zip_path}")
         # Return the ZIP file to the client
@@ -190,13 +261,13 @@ def export_images_api():
 
 @app.route('/upload', methods=['POST'])
 def upload_images():
-    """Receive the left/right camera images, save them, and push an update"""
+    """Receive the left/right camera images, save them, push an update, and generate annotated images"""
     try:
         # Get the files from the multipart/form-data payload
         left_file = request.files.get('left_image')
        right_file = request.files.get('right_image')
         metadata_str = request.form.get('metadata')  # in case the metadata needs processing
         comment = request.form.get('comment', '')    # get the comment field
 
         if not left_file or not right_file:
             logger.warning("Received request without required image files.")
@@ -230,26 +301,67 @@ def upload_images():
         # Generate the filenames
         left_filename = f"left_{timestamp_str_safe}.jpg"
         right_filename = f"right_{timestamp_str_safe}.jpg"
+        left_marked_filename = f"left_marked_{timestamp_str_safe}.jpg"    # annotated image filename
+        right_marked_filename = f"right_marked_{timestamp_str_safe}.jpg"  # annotated image filename
 
-        # Save the images locally
+        # Save the original images locally
         left_path = os.path.join(SAVE_PATH_LEFT, left_filename)
         right_path = os.path.join(SAVE_PATH_RIGHT, right_filename)
 
         # Make sure the directories exist
         os.makedirs(SAVE_PATH_LEFT, exist_ok=True)
         os.makedirs(SAVE_PATH_RIGHT, exist_ok=True)
+        os.makedirs(SAVE_PATH_LEFT_MARKED, exist_ok=True)   # create the annotated-image directory
+        os.makedirs(SAVE_PATH_RIGHT_MARKED, exist_ok=True)  # create the annotated-image directory
 
         cv2.imwrite(left_path, img_left)
         cv2.imwrite(right_path, img_right)
-        logger.info(f"Saved images: {left_path}, {right_path}")
+        logger.info(f"Saved original images: {left_path}, {right_path}")
 
-        # Write the image record to the database
+        # Process the images with VisionAPIClient and generate the annotated images
+        left_marked_path = None
+        right_marked_path = None
+
+        try:
+            with VisionAPIClient() as client:
+                # Process the left image
+                left_task_id = client.submit_task(image_id=1, image=img_left)
+                # Process the right image
+                right_task_id = client.submit_task(image_id=2, image=img_right)
+
+                # Wait for the tasks to finish
+                client.task_queue.join()
+
+                # Fetch the results
+                left_result = client.get_result(left_task_id)
+                right_result = client.get_result(right_task_id)
+
+                # Generate the annotated images
+                if left_result and left_result.success:
+                    marked_left_img = draw_detections_on_image(img_left, left_result.detections)
+                    left_marked_path = os.path.join(SAVE_PATH_LEFT_MARKED, left_marked_filename)
+                    cv2.imwrite(left_marked_path, marked_left_img)
+                    logger.info(f"Saved marked left image: {left_marked_path}")
+
+                if right_result and right_result.success:
+                    marked_right_img = draw_detections_on_image(img_right, right_result.detections)
+                    right_marked_path = os.path.join(SAVE_PATH_RIGHT_MARKED, right_marked_filename)
+                    cv2.imwrite(right_marked_path, marked_right_img)
+                    logger.info(f"Saved marked right image: {right_marked_path}")
+        except Exception as e:
+            logger.error(f"Error processing images with VisionAPIClient: {e}")
+            # Even if processing fails, continue; the original images are already saved
+
+        # Write the image record to the database, including the annotated-image columns
         conn = sqlite3.connect(DATABASE_PATH)
         cursor = conn.cursor()
         cursor.execute('''
-            INSERT INTO images (left_filename, right_filename, timestamp, metadata, comment)
-            VALUES (?, ?, ?, ?, ?)
-        ''', (left_filename, right_filename, float(timestamp_str), json.dumps(metadata), comment))
+            INSERT INTO images (left_filename, right_filename, left_marked_filename, right_marked_filename, timestamp, metadata, comment)
+            VALUES (?, ?, ?, ?, ?, ?, ?)
+        ''', (left_filename, right_filename,
+              left_marked_filename if left_marked_path else None,
+              right_marked_filename if right_marked_path else None,
+              float(timestamp_str), json.dumps(metadata), comment))
         conn.commit()
         image_id = cursor.lastrowid  # get the ID of the newly inserted row
         conn.close()
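
For manual testing, the endpoint can be exercised with a small client. A rough sketch follows; the field names match the ones read by upload_images above, while the URL, file paths, and metadata values are placeholders, and any additional fields handled in parts of upload_images not shown in this hunk would still need to be supplied:

    import json
    import requests

    files = {
        "left_image": open("left.jpg", "rb"),
        "right_image": open("right.jpg", "rb"),
    }
    data = {
        "metadata": json.dumps({"source": "manual-test"}),  # placeholder metadata
        "comment": "test pair",
    }
    resp = requests.post("http://127.0.0.1:5000/upload", files=files, data=data)
    print(resp.status_code, resp.text)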
@@ -282,7 +394,7 @@ def upload_images():
 
 @app.route('/api/images/comment', methods=['PUT'])
 def update_image_comment():
     """API: update an image's comment"""
     data = request.json
     image_id = data.get('id')
     comment = data.get('comment', '')
@@ -292,7 +404,7 @@ def update_image_comment():
 
     conn = sqlite3.connect(DATABASE_PATH)
     cursor = conn.cursor()
     # Update the comment field
     cursor.execute("UPDATE images SET comment = ? WHERE id = ?", (comment, image_id))
     conn.commit()
     conn.close()

cap_trigger.py

@@ -4,7 +4,7 @@ import uuid
 import time
 
 class ImageClient:
-    def __init__(self, server_address="tcp://10.42.70.1:54321", client_id=None):
+    def __init__(self, server_address="tcp://:54321", client_id=None):
         self.server_address = server_address
         self.client_id = client_id or f"client_{uuid.uuid4().hex[:8]}"
         self.socket = None
@@ -43,7 +43,7 @@ class ImageClient:
         """Send a synchronous request to the server"""
         socket = None
         try:
             # Create a new socket instance each time to avoid reusing a closed socket
             socket = pynng.Req0()
             socket.dial(self.server_address, block=True)  # blocking mode to make sure the connection is established
             client_id_bytes = self.client_id.encode('utf-8')
@@ -54,7 +54,7 @@ class ImageClient:
             print(f"Client {self.client_id} error: {e}")
             return None
         finally:
             # Make sure the socket is properly closed after use
             if socket:
                 try:
                     socket.close()

llm_req.py (new file, 610 lines)

@@ -0,0 +1,610 @@
import base64
import json
import cv2
import numpy as np
from typing import List, Dict, Any, Optional, Callable
from openai import OpenAI
import asyncio
from concurrent.futures import ThreadPoolExecutor
import threading
from dataclasses import dataclass
from enum import Enum
import queue
import time

# Configuration
API_KEY = "sk-e3a0287ece6a41bb9b79b2c285f10197"
BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
MODEL_NAME = "qwen-vl-plus"

# Category mapping
CATEGORY_MAPPING = {
    1: "caisson",
    2: "soldier",
    3: "gun",
    4: "number"
}

CATEGORY_COLORS = {
    1: (0, 255, 0),    # Green for caisson
    2: (0, 255, 255),  # Yellow for soldier
    3: (0, 0, 255),    # Red for gun
    4: (255, 0, 0)     # Blue for number
}

@dataclass
class DetectionResult:
    """Detection result data class"""
    image_id: str
    original_image: np.ndarray
    detections: List[Dict[str, Any]]
    marked_image: Optional[np.ndarray] = None
    success: bool = False
    error_message: Optional[str] = None
    timestamp: float = 0.0

class TaskStatus(Enum):
    """Task status enumeration"""
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"

@dataclass
class Task:
    """Task data class for queue"""
    task_id: str
    image_id: str
    image: np.ndarray  # OpenCV Mat format
    prompt: str
    callback: Optional[Callable[[DetectionResult], None]] = None
    timestamp: float = 0.0

class VisionAPIClient:
    """Vision API Client for asynchronous processing with OpenCV Mat input/output"""

    def __init__(self, api_key: str = API_KEY, base_url: str = BASE_URL,
                 model_name: str = MODEL_NAME, max_workers: int = 4):
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name
        self.max_workers = max_workers
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.client = OpenAI(api_key=api_key, base_url=base_url)

        # Task management
        self.task_queue = queue.Queue()
        self.processing_tasks = {}
        self.results_cache = {}
        self._running = False
        self._worker_thread = None

        self.default_prompt = """Please perform object detection on the image, identifying and localizing the following four types of targets:
- Category 1: Green ammunition box (caisson)
- Category 2: Dummy soldier wearing digital camouflage uniform (soldier)
- Category 3: Gun (gun)
- Category 4: Round blue number plate (number)

Please follow these requirements for the output:
1. The output must be in valid JSON format.
2. The JSON structure should contain a list named "detections".
3. Each element in the list represents a detected target, containing the following fields:
- "id" (integer): Target category ID (1, 2, 3, 4).
- "label" (string): Target category name ("caisson", "soldier", "gun", "number").
- "bbox" (list of int): Bounding box coordinates in format [x_min, y_min, x_max, y_max], where (x_min, y_min) is the top-left coordinate and (x_max, y_max) is the bottom-right coordinate. Coordinate values are integers normalized to the 0-999 range (0,0 represents top-left, 999,999 represents bottom-right).
4. If no targets are detected in the image, "detections" should be an empty list [].
5. Please output only JSON, no other explanatory text.

JSON output example (when targets are detected):
{
"detections": [
{
"id": 1,
"label": "caisson",
"bbox": [x1, y1, x2, y2] // x1, y1, x2, y2 are integers in the 0-999 range
},
{
"id": 2,
"label": "soldier",
"bbox": [x3, y3, x4, y4] // x3, y3, x4, y4 are integers in the 0-999 range
}
]
}

JSON output example (when no targets are detected):
{
"detections": []
}"""

    def encode_cv_image(self, image: np.ndarray) -> str:
        """
        Encodes an OpenCV image (Mat) to base64 string.

        Args:
            image: OpenCV image (Mat) format

        Returns:
            Base64 encoded string of the image
        """
        # Encode the image to JPEG format
        _, buffer = cv2.imencode('.jpg', image)
        return base64.b64encode(buffer).decode('utf-8')

    def validate_and_extract_json(self, response_text: str) -> Optional[Dict[str, Any]]:
        """
        Validates and extracts JSON from API response text.

        Args:
            response_text: Raw response text from API

        Returns:
            Parsed JSON dictionary if valid, None otherwise
        """
        # Try to find JSON within the response text (in case of additional text)
        start_idx = response_text.find('{')
        end_idx = response_text.rfind('}')

        if start_idx == -1 or end_idx == -1 or start_idx > end_idx:
            print("No valid JSON structure found in response.")
            return None

        json_str = response_text[start_idx:end_idx+1]

        try:
            parsed_json = json.loads(json_str)
            return parsed_json
        except json.JSONDecodeError as e:
            print(f"JSON parsing failed: {e}")
            print(f"Problematic JSON string: {json_str[:200]}...")  # Show first 200 chars
            return None

    def validate_detections_format(self, data: Dict[str, Any]) -> bool:
        """
        Validates the structure and content of the detections JSON.

        Args:
            data: Parsed JSON data

        Returns:
            True if format is valid, False otherwise
        """
        if not isinstance(data, dict) or "detections" not in data:
            print("Missing 'detections' key in response.")
            return False

        detections = data["detections"]
        if not isinstance(detections, list):
            print("'detections' is not a list.")
            return False

        for i, detection in enumerate(detections):
            if not isinstance(detection, dict):
                print(f"Detection item {i} is not a dictionary.")
                return False

            required_keys = ["id", "label", "bbox"]
            for key in required_keys:
                if key not in detection:
                    print(f"Missing required key '{key}' in detection {i}.")
                    return False

            # Validate ID
            if not isinstance(detection["id"], int) or detection["id"] not in [1, 2, 3, 4]:
                print(f"Invalid ID in detection {i}: {detection['id']}")
                return False

            # Validate label
            if not isinstance(detection["label"], str) or detection["label"] not in CATEGORY_MAPPING.values():
                print(f"Invalid label in detection {i}: {detection['label']}")
                return False

            # Validate bbox
            bbox = detection["bbox"]
            if not isinstance(bbox, list) or len(bbox) != 4:
                print(f"Invalid bbox format in detection {i}: {bbox}")
                return False

            for coord in bbox:
                if not isinstance(coord, (int, float)) or not (0 <= coord <= 999):
                    print(f"Invalid bbox coordinate in detection {i}: {coord}")
                    return False

            # Validate confidence if present
            if "confidence" in detection:
                conf = detection["confidence"]
                if not isinstance(conf, (int, float)) or not (0.0 <= conf <= 1.0):
                    print(f"Invalid confidence in detection {i}: {conf}")
                    return False

        return True

    def call_vision_api_sync(self, image: np.ndarray, prompt: str) -> Optional[Dict[str, Any]]:
        """
        Synchronous call to the vision API with OpenCV Mat input.

        Args:
            image: OpenCV image (Mat) format
            prompt: The prompt to send to the API

        Returns:
            Parsed JSON response if successful, None otherwise
        """
        # Resize image to 1000x1000 if needed
        h, w = image.shape[:2]
        if h != 1000 or w != 1000:
            image = cv2.resize(image, (1000, 1000))

        # Encode the image directly to base64
        image_base64 = self.encode_cv_image(image)
        image_url = f"data:image/jpeg;base64,{image_base64}"

        try:
            # Create the completion request
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {
                        'role': 'user',
                        'content': [
                            {"type": "text", "text": prompt},
                            {"type": "image_url", "image_url": {"url": image_url}}
                        ]
                    }
                ],
                stream=False  # Set to False for single response instead of streaming
            )

            # Extract content from response
            if response.choices and len(response.choices) > 0:
                content = response.choices[0].message.content
            else:
                print("No choices returned from API.")
                return None

            if not content:
                print("No content returned from API.")
                return None

            # Validate and parse JSON from response
            parsed_data = self.validate_and_extract_json(content)

            if parsed_data is None:
                return None

            # Validate the structure of the detections
            if not self.validate_detections_format(parsed_data):
                print("Invalid detections format in response.")
                return None

            return parsed_data

        except Exception as e:
            print(f"API request failed: {e}")
            return None

    def draw_detections_on_image(self, image: np.ndarray, detections: List[Dict[str, Any]]) -> np.ndarray:
        """
        Draws bounding boxes and labels on the OpenCV Mat image (without confidence).

        Args:
            image: OpenCV image (Mat) format
            detections: List of detection dictionaries

        Returns:
            Image with drawn detections as numpy array
        """
        # Work on a copy to avoid modifying the original
        result_image = image.copy()

        # Get image dimensions
        img_h, img_w = result_image.shape[:2]

        for detection in detections:
            # Get bounding box coordinates (normalized to 0-999 range)
            bbox = detection["bbox"]
            x1_norm, y1_norm, x2_norm, y2_norm = bbox

            # Convert normalized coordinates to pixel coordinates
            x1 = int((x1_norm / 999) * img_w)
            y1 = int((y1_norm / 999) * img_h)
            x2 = int((x2_norm / 999) * img_w)
            y2 = int((y2_norm / 999) * img_h)

            # Ensure coordinates are within image bounds
            x1 = max(0, min(x1, img_w - 1))
            y1 = max(0, min(y1, img_h - 1))
            x2 = max(0, min(x2, img_w - 1))
            y2 = max(0, min(y2, img_h - 1))

            # Get color and label
            category_id = detection["id"]
            label = detection["label"]
            color = CATEGORY_COLORS.get(category_id, (255, 255, 255))  # Default white if not found

            # Draw bounding box
            cv2.rectangle(result_image, (x1, y1), (x2, y2), color, 2)

            # Prepare label text (no confidence)
            label_text = label

            # Calculate text size and position
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_scale = 0.6
            thickness = 2
            (text_width, text_height), baseline = cv2.getTextSize(label_text, font, font_scale, thickness)

            # Draw label background
            cv2.rectangle(result_image, (x1, y1 - text_height - 10), (x1 + text_width, y1), color, -1)

            # Draw label text
            cv2.putText(result_image, label_text, (x1, y1 - 5), font, font_scale, (0, 0, 0), thickness)

        return result_image

    def process_single_image(self, image_id: str, image: np.ndarray, prompt: str = None) -> DetectionResult:
        """
        Process a single OpenCV Mat image synchronously.

        Args:
            image_id: Unique identifier for the image
            image: OpenCV image (Mat) format
            prompt: The prompt for the vision API (optional)

        Returns:
            DetectionResult containing the results
        """
        start_time = time.time()

        # Validate input image
        if image is None or image.size == 0:
            error_msg = f"Invalid image for image_id: {image_id}"
            print(error_msg)
            return DetectionResult(
                image_id=image_id,
                original_image=image,
                detections=[],
                success=False,
                error_message=error_msg,
                timestamp=start_time
            )

        # Use provided prompt or default
        use_prompt = prompt if prompt is not None else self.default_prompt

        # Call the vision API
        print(f"Calling vision API for image {image_id}...")
        result = self.call_vision_api_sync(image, use_prompt)

        if result is None:
            error_msg = "Failed to get valid response from API."
            print(error_msg)
            return DetectionResult(
                image_id=image_id,
                original_image=image,
                detections=[],
                success=False,
                error_message=error_msg,
                timestamp=start_time
            )

        # Extract detections
        detections = result.get("detections", [])
        print(f"Found {len(detections)} detections for image {image_id}.")

        # Draw detections on image
        try:
            marked_image = self.draw_detections_on_image(image, detections)
        except Exception as e:
            error_msg = f"Error drawing detections on image: {e}"
            print(error_msg)
            return DetectionResult(
                image_id=image_id,
                original_image=image,
                detections=detections,
                success=False,
                error_message=error_msg,
                timestamp=start_time
            )

        # Return successful result
        return DetectionResult(
            image_id=image_id,
            original_image=image,
            detections=detections,
            marked_image=marked_image,
            success=True,
            timestamp=start_time
        )

    def start_worker(self):
        """Start the background worker thread for processing tasks"""
        if self._running:
            return

        self._running = True
        self._worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
        self._worker_thread.start()
        print("Vision API worker started.")

    def stop_worker(self):
        """Stop the background worker thread"""
        self._running = False
        if self._worker_thread:
            self._worker_thread.join(timeout=5.0)  # Wait up to 5 seconds
        print("Vision API worker stopped.")

    def _worker_loop(self):
        """Background worker loop for processing tasks from queue"""
        while self._running:
            try:
                # Get task from queue with timeout
                task = self.task_queue.get(timeout=1.0)

                # Mark as processing
                self.processing_tasks[task.task_id] = TaskStatus.PROCESSING

                # Process the task
                result = self.process_single_image(task.image_id, task.image, task.prompt)

                # Store result
                self.results_cache[task.task_id] = result

                # Update task status
                self.processing_tasks[task.task_id] = TaskStatus.COMPLETED if result.success else TaskStatus.FAILED

                # Call callback if provided
                if task.callback:
                    try:
                        task.callback(result)
                    except Exception as e:
                        print(f"Callback execution failed for task {task.task_id}: {e}")

                # Mark task as done
                self.task_queue.task_done()

            except queue.Empty:
                continue  # Timeout, continue loop
            except Exception as e:
                print(f"Worker error: {e}")
                continue

    def submit_task(self, image_id: int, image: np.ndarray, prompt: str = None,
                    callback: Callable[[DetectionResult], None] = None) -> str:
        """
        Submit a task to the processing queue with OpenCV Mat input.

        Args:
            image_id: Unique identifier for the image
            image: OpenCV image (Mat) format
            prompt: The prompt for the vision API (optional)
            callback: Callback function to be called when processing is complete (optional)

        Returns:
            Task ID for tracking the task
        """
        task_id = f"task_{int(time.time() * 1000000)}_{image_id}"  # Generate unique task ID
        task = Task(
            task_id=task_id,
            image_id=image_id,
            image=image,
            prompt=prompt if prompt is not None else self.default_prompt,
            callback=callback,
            timestamp=time.time()
        )

        self.task_queue.put(task)
        self.processing_tasks[task_id] = TaskStatus.PENDING

        return task_id

    def get_result(self, task_id: str) -> Optional[DetectionResult]:
        """
        Get the result for a specific task.

        Args:
            task_id: The task ID to retrieve result for

        Returns:
            DetectionResult if available, None otherwise
        """
        return self.results_cache.get(task_id)

    def get_task_status(self, task_id: str) -> Optional[TaskStatus]:
        """
        Get the status of a specific task.

        Args:
            task_id: The task ID to check status for

        Returns:
            TaskStatus if task exists, None otherwise
        """
        return self.processing_tasks.get(task_id)

    def get_queue_size(self) -> int:
        """Get the current size of the task queue"""
        return self.task_queue.qsize()

    def get_processing_count(self) -> int:
        """Get the number of currently processing tasks"""
        return sum(1 for status in self.processing_tasks.values()
                   if status == TaskStatus.PROCESSING)

    def get_completed_count(self) -> int:
        """Get the number of completed tasks"""
        return sum(1 for status in self.processing_tasks.values()
                   if status in [TaskStatus.COMPLETED, TaskStatus.FAILED])

    def clear_results(self):
        """Clear the results cache"""
        self.results_cache.clear()

    def __enter__(self):
        """Context manager entry"""
        self.start_worker()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit"""
        self.stop_worker()


# Example usage
def example_callback(result: DetectionResult):
    """Example callback function"""
    if result.success:
        print(f"Callback: Processing completed for image {result.image_id}, found {len(result.detections)} detections")
        # The result.marked_image is the OpenCV Mat with detections drawn
        marked_image = result.marked_image
        # You can now use the marked_image for further processing
    else:
        print(f"Callback: Processing failed for image {result.image_id}: {result.error_message}")


def main():
    original_image = cv2.imread("/home/evan/Desktop/received/left/left_1761388243_7673044.jpg")  # Replace with your image source

    if original_image is None:
        print("Could not load image")
        return

    # Example usage with context manager
    with VisionAPIClient() as client:
        # Submit a task with OpenCV Mat
        task_id = client.submit_task(
            image_id=1,
            image=original_image,
            callback=example_callback
        )

        print(f"Submitted task {task_id}")

        # Wait for the task to complete
        print("Waiting for task to complete...")
        client.task_queue.join()  # Wait for all tasks in queue to be processed

        # Get the result
        result = client.get_result(task_id)
        if result:
            if result.success:
                print(f"Task completed successfully! Found {len(result.detections)} detections.")
                # result.marked_image is the OpenCV Mat with detections drawn
                marked_image = result.marked_image

                # Display the result (optional)
                cv2.imshow("Original Image", original_image)
                cv2.imshow("Marked Image", marked_image)
                print("Press any key to close windows...")
                cv2.waitKey(0)
                cv2.destroyAllWindows()

                # Or save the result
                # cv2.imwrite("marked_image.jpg", marked_image)
            else:
                print(f"Task failed: {result.error_message}")
        else:
            print("No result found")


if __name__ == "__main__":
    main()
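
Besides the queue-based flow shown in main(), the client can also be used synchronously through process_single_image; a minimal sketch (the image path is a placeholder):

    import cv2
    from llm_req import VisionAPIClient

    client = VisionAPIClient()
    frame = cv2.imread("example.jpg")  # placeholder path
    result = client.process_single_image(image_id="example", image=frame)
    if result.success:
        cv2.imwrite("example_marked.jpg", result.marked_image)
    else:
        print(result.error_message)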

HTML template (saved images list page)

@@ -1,9 +1,11 @@
 <!DOCTYPE html>
-<html>
+<html lang="en">
 
 <head>
-    <title>Saved Images List</title>
-    <script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.7.2/socket.io.min.js "></script>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Saved Images</title>
+    <script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.0.0/socket.io.js"></script>
     <style>
         body {
             font-family: Arial, sans-serif;
@@ -15,11 +17,12 @@
         }
 
         button {
-            padding: 8px 16px;
+            padding: 10px 15px;
             margin-right: 10px;
             background-color: #007bff;
             color: white;
             border: none;
+            border-radius: 4px;
             cursor: pointer;
         }
 
@@ -97,6 +100,8 @@
                 <th>ID</th>
                 <th>Left Image</th>
                 <th>Right Image</th>
+                <th>Left Marked Image</th>   <!-- new annotated-image column -->
+                <th>Right Marked Image</th>  <!-- new annotated-image column -->
                 <th>Timestamp</th>
                 <th>Comment</th>
                 <th>Actions</th>
@@ -138,6 +143,7 @@
                 row.insertCell(0).innerHTML = `<input type="checkbox" class="selectCheckbox" data-id="${image.id}">`;
                 row.insertCell(1).textContent = image.id;
 
+                // Original images
                 const leftCell = row.insertCell(2);
                 const leftImg = document.createElement('img');
                 // Changed here: use the Flask static-file path
@@ -156,10 +162,35 @@
                 rightImg.onerror = function () { this.src = 'data:image/svg+xml;utf8,<svg xmlns="http://www.w3.org/2000/svg" width="100" height="75" viewBox="0 0 100 75"><rect width="100" height="75" fill="%23eee"/><text x="50" y="40" font-family="Arial" font-size="12" fill="%23999" text-anchor="middle">No Image</text></svg>'; };
                 rightCell.appendChild(rightImg);
 
-                row.insertCell(4).textContent = new Date(image.timestamp * 1000).toISOString();
-
-                // Add an editable comment cell
-                const commentCell = row.insertCell(5);
+                // Annotated images
+                const leftMarkedCell = row.insertCell(4);
+                if (image.left_marked_filename) {
+                    const leftMarkedImg = document.createElement('img');
+                    leftMarkedImg.src = `/static/received/left_marked/${image.left_marked_filename}`;
+                    leftMarkedImg.alt = "Left Marked Image";
+                    leftMarkedImg.className = 'image-preview';
+                    leftMarkedImg.onerror = function () { this.src = 'data:image/svg+xml;utf8,<svg xmlns="http://www.w3.org/2000/svg" width="100" height="75" viewBox="0 0 100 75"><rect width="100" height="75" fill="%23eee"/><text x="50" y="40" font-family="Arial" font-size="12" fill="%23999" text-anchor="middle">No Image</text></svg>'; };
+                    leftMarkedCell.appendChild(leftMarkedImg);
+                } else {
+                    leftMarkedCell.textContent = "N/A";
+                }
+
+                const rightMarkedCell = row.insertCell(5);
+                if (image.right_marked_filename) {
+                    const rightMarkedImg = document.createElement('img');
+                    rightMarkedImg.src = `/static/received/right_marked/${image.right_marked_filename}`;
+                    rightMarkedImg.alt = "Right Marked Image";
+                    rightMarkedImg.className = 'image-preview';
+                    rightMarkedImg.onerror = function () { this.src = 'data:image/svg+xml;utf8,<svg xmlns="http://www.w3.org/2000/svg" width="100" height="75" viewBox="0 0 100 75"><rect width="100" height="75" fill="%23eee"/><text x="50" y="40" font-family="Arial" font-size="12" fill="%23999" text-anchor="middle">No Image</text></svg>'; };
+                    rightMarkedCell.appendChild(rightMarkedImg);
+                } else {
+                    rightMarkedCell.textContent = "N/A";
+                }
+
+                row.insertCell(6).textContent = new Date(image.timestamp * 1000).toISOString();
+
+                // Add an editable comment cell
+                const commentCell = row.insertCell(7);
                 const commentInput = document.createElement('input');
                 commentInput.type = 'text';
                 commentInput.value = image.comment || '';
@@ -171,11 +202,11 @@
                 });
                 commentCell.appendChild(commentInput);
 
-                row.insertCell(6).innerHTML = `<button onclick="deleteImage(${image.id})">Delete</button>`;
+                row.insertCell(8).innerHTML = `<button onclick="deleteImage(${image.id})">Delete</button>`;
             });
         }
 
         // Function to update the comment field
         async function updateComment(id, comment) {
             try {
                 const response = await fetch('/api/images/comment', {