feat: 增加标注相关功能
This commit is contained in:
610
llm_req.py
Normal file
610
llm_req.py
Normal file
@@ -0,0 +1,610 @@
|
||||
import base64
|
||||
import json
|
||||
import cv2
|
||||
import numpy as np
|
||||
from typing import List, Dict, Any, Optional, Callable
|
||||
from openai import OpenAI
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import queue
|
||||
import time
|
||||
|
||||
# Configuration
|
||||
API_KEY = "sk-e3a0287ece6a41bb9b79b2c285f10197"
|
||||
BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
MODEL_NAME = "qwen-vl-plus"
|
||||
|
||||
# Category mapping
|
||||
CATEGORY_MAPPING = {
|
||||
1: "caisson",
|
||||
2: "soldier",
|
||||
3: "gun",
|
||||
4: "number"
|
||||
}
|
||||
|
||||
CATEGORY_COLORS = {
|
||||
1: (0, 255, 0), # Green for caisson
|
||||
2: (0, 255, 255), # Yellow for soldier
|
||||
3: (0, 0, 255), # Red for gun
|
||||
4: (255, 0, 0) # Blue for number
|
||||
}
|
||||
|
||||
@dataclass
|
||||
class DetectionResult:
|
||||
"""Detection result data class"""
|
||||
image_id: str
|
||||
original_image: np.ndarray
|
||||
detections: List[Dict[str, Any]]
|
||||
marked_image: Optional[np.ndarray] = None
|
||||
success: bool = False
|
||||
error_message: Optional[str] = None
|
||||
timestamp: float = 0.0
|
||||
|
||||
class TaskStatus(Enum):
|
||||
"""Task status enumeration"""
|
||||
PENDING = "pending"
|
||||
PROCESSING = "processing"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
@dataclass
|
||||
class Task:
|
||||
"""Task data class for queue"""
|
||||
task_id: str
|
||||
image_id: str
|
||||
image: np.ndarray # OpenCV Mat format
|
||||
prompt: str
|
||||
callback: Optional[Callable[[DetectionResult], None]] = None
|
||||
timestamp: float = 0.0
|
||||
|
||||
class VisionAPIClient:
|
||||
"""Vision API Client for asynchronous processing with OpenCV Mat input/output"""
|
||||
|
||||
def __init__(self, api_key: str = API_KEY, base_url: str = BASE_URL,
|
||||
model_name: str = MODEL_NAME, max_workers: int = 4):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.model_name = model_name
|
||||
self.max_workers = max_workers
|
||||
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
self.client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
|
||||
# Task management
|
||||
self.task_queue = queue.Queue()
|
||||
self.processing_tasks = {}
|
||||
self.results_cache = {}
|
||||
self._running = False
|
||||
self._worker_thread = None
|
||||
|
||||
self.default_prompt = """Please perform object detection on the image, identifying and localizing the following four types of targets:
|
||||
- Category 1: Green ammunition box (caisson)
|
||||
- Category 2: Dummy soldier wearing digital camouflage uniform (soldier)
|
||||
- Category 3: Gun (gun)
|
||||
- Category 4: Round blue number plate (number)
|
||||
|
||||
Please follow these requirements for the output:
|
||||
1. The output must be in valid JSON format.
|
||||
2. The JSON structure should contain a list named "detections".
|
||||
3. Each element in the list represents a detected target, containing the following fields:
|
||||
- "id" (integer): Target category ID (1, 2, 3, 4).
|
||||
- "label" (string): Target category name ("caisson", "soldier", "gun", "number").
|
||||
- "bbox" (list of int): Bounding box coordinates in format [x_min, y_min, x_max, y_max], where (x_min, y_min) is the top-left coordinate and (x_max, y_max) is the bottom-right coordinate. Coordinate values are integers normalized to the 0-999 range (0,0 represents top-left, 999,999 represents bottom-right).
|
||||
4. If no targets are detected in the image, "detections" should be an empty list [].
|
||||
5. Please output only JSON, no other explanatory text.
|
||||
|
||||
JSON output example (when targets are detected):
|
||||
{
|
||||
"detections": [
|
||||
{
|
||||
"id": 1,
|
||||
"label": "caisson",
|
||||
"bbox": [x1, y1, x2, y2] // x1, y1, x2, y2 are integers in the 0-999 range
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"label": "soldier",
|
||||
"bbox": [x3, y3, x4, y4] // x3, y3, x4, y4 are integers in the 0-999 range
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
JSON output example (when no targets are detected):
|
||||
{
|
||||
"detections": []
|
||||
}"""
|
||||
|
||||
def encode_cv_image(self, image: np.ndarray) -> str:
|
||||
"""
|
||||
Encodes an OpenCV image (Mat) to base64 string.
|
||||
|
||||
Args:
|
||||
image: OpenCV image (Mat) format
|
||||
|
||||
Returns:
|
||||
Base64 encoded string of the image
|
||||
"""
|
||||
# Encode the image to JPEG format
|
||||
_, buffer = cv2.imencode('.jpg', image)
|
||||
return base64.b64encode(buffer).decode('utf-8')
|
||||
|
||||
def validate_and_extract_json(self, response_text: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Validates and extracts JSON from API response text.
|
||||
|
||||
Args:
|
||||
response_text: Raw response text from API
|
||||
|
||||
Returns:
|
||||
Parsed JSON dictionary if valid, None otherwise
|
||||
"""
|
||||
# Try to find JSON within the response text (in case of additional text)
|
||||
start_idx = response_text.find('{')
|
||||
end_idx = response_text.rfind('}')
|
||||
|
||||
if start_idx == -1 or end_idx == -1 or start_idx > end_idx:
|
||||
print("No valid JSON structure found in response.")
|
||||
return None
|
||||
|
||||
json_str = response_text[start_idx:end_idx+1]
|
||||
|
||||
try:
|
||||
parsed_json = json.loads(json_str)
|
||||
return parsed_json
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"JSON parsing failed: {e}")
|
||||
print(f"Problematic JSON string: {json_str[:200]}...") # Show first 200 chars
|
||||
return None
|
||||
|
||||
def validate_detections_format(self, data: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Validates the structure and content of the detections JSON.
|
||||
|
||||
Args:
|
||||
data: Parsed JSON data
|
||||
|
||||
Returns:
|
||||
True if format is valid, False otherwise
|
||||
"""
|
||||
if not isinstance(data, dict) or "detections" not in data:
|
||||
print("Missing 'detections' key in response.")
|
||||
return False
|
||||
|
||||
detections = data["detections"]
|
||||
if not isinstance(detections, list):
|
||||
print("'detections' is not a list.")
|
||||
return False
|
||||
|
||||
for i, detection in enumerate(detections):
|
||||
if not isinstance(detection, dict):
|
||||
print(f"Detection item {i} is not a dictionary.")
|
||||
return False
|
||||
|
||||
required_keys = ["id", "label", "bbox"]
|
||||
for key in required_keys:
|
||||
if key not in detection:
|
||||
print(f"Missing required key '{key}' in detection {i}.")
|
||||
return False
|
||||
|
||||
# Validate ID
|
||||
if not isinstance(detection["id"], int) or detection["id"] not in [1, 2, 3, 4]:
|
||||
print(f"Invalid ID in detection {i}: {detection['id']}")
|
||||
return False
|
||||
|
||||
# Validate label
|
||||
if not isinstance(detection["label"], str) or detection["label"] not in CATEGORY_MAPPING.values():
|
||||
print(f"Invalid label in detection {i}: {detection['label']}")
|
||||
return False
|
||||
|
||||
# Validate bbox
|
||||
bbox = detection["bbox"]
|
||||
if not isinstance(bbox, list) or len(bbox) != 4:
|
||||
print(f"Invalid bbox format in detection {i}: {bbox}")
|
||||
return False
|
||||
|
||||
for coord in bbox:
|
||||
if not isinstance(coord, (int, float)) or not (0 <= coord <= 999):
|
||||
print(f"Invalid bbox coordinate in detection {i}: {coord}")
|
||||
return False
|
||||
|
||||
# Validate confidence if present
|
||||
if "confidence" in detection:
|
||||
conf = detection["confidence"]
|
||||
if not isinstance(conf, (int, float)) or not (0.0 <= conf <= 1.0):
|
||||
print(f"Invalid confidence in detection {i}: {conf}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def call_vision_api_sync(self, image: np.ndarray, prompt: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Synchronous call to the vision API with OpenCV Mat input.
|
||||
|
||||
Args:
|
||||
image: OpenCV image (Mat) format
|
||||
prompt: The prompt to send to the API
|
||||
|
||||
Returns:
|
||||
Parsed JSON response if successful, None otherwise
|
||||
"""
|
||||
# Resize image to 1000x1000 if needed
|
||||
h, w = image.shape[:2]
|
||||
if h != 1000 or w != 1000:
|
||||
image = cv2.resize(image, (1000, 1000))
|
||||
|
||||
# Encode the image directly to base64
|
||||
image_base64 = self.encode_cv_image(image)
|
||||
image_url = f"data:image/jpeg;base64,{image_base64}"
|
||||
|
||||
try:
|
||||
# Create the completion request
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=[
|
||||
{
|
||||
'role': 'user',
|
||||
'content': [
|
||||
{"type": "text", "text": prompt},
|
||||
{"type": "image_url", "image_url": {"url": image_url}}
|
||||
]
|
||||
}
|
||||
],
|
||||
stream=False # Set to False for single response instead of streaming
|
||||
)
|
||||
|
||||
# Extract content from response
|
||||
if response.choices and len(response.choices) > 0:
|
||||
content = response.choices[0].message.content
|
||||
else:
|
||||
print("No choices returned from API.")
|
||||
return None
|
||||
|
||||
if not content:
|
||||
print("No content returned from API.")
|
||||
return None
|
||||
|
||||
# Validate and parse JSON from response
|
||||
parsed_data = self.validate_and_extract_json(content)
|
||||
|
||||
if parsed_data is None:
|
||||
return None
|
||||
|
||||
# Validate the structure of the detections
|
||||
if not self.validate_detections_format(parsed_data):
|
||||
print("Invalid detections format in response.")
|
||||
return None
|
||||
|
||||
return parsed_data
|
||||
|
||||
except Exception as e:
|
||||
print(f"API request failed: {e}")
|
||||
return None
|
||||
|
||||
def draw_detections_on_image(self, image: np.ndarray, detections: List[Dict[str, Any]]) -> np.ndarray:
|
||||
"""
|
||||
Draws bounding boxes and labels on the OpenCV Mat image (without confidence).
|
||||
|
||||
Args:
|
||||
image: OpenCV image (Mat) format
|
||||
detections: List of detection dictionaries
|
||||
|
||||
Returns:
|
||||
Image with drawn detections as numpy array
|
||||
"""
|
||||
# Work on a copy to avoid modifying the original
|
||||
result_image = image.copy()
|
||||
|
||||
# Get image dimensions
|
||||
img_h, img_w = result_image.shape[:2]
|
||||
|
||||
for detection in detections:
|
||||
# Get bounding box coordinates (normalized to 0-999 range)
|
||||
bbox = detection["bbox"]
|
||||
x1_norm, y1_norm, x2_norm, y2_norm = bbox
|
||||
|
||||
# Convert normalized coordinates to pixel coordinates
|
||||
x1 = int((x1_norm / 999) * img_w)
|
||||
y1 = int((y1_norm / 999) * img_h)
|
||||
x2 = int((x2_norm / 999) * img_w)
|
||||
y2 = int((y2_norm / 999) * img_h)
|
||||
|
||||
# Ensure coordinates are within image bounds
|
||||
x1 = max(0, min(x1, img_w - 1))
|
||||
y1 = max(0, min(y1, img_h - 1))
|
||||
x2 = max(0, min(x2, img_w - 1))
|
||||
y2 = max(0, min(y2, img_h - 1))
|
||||
|
||||
# Get color and label
|
||||
category_id = detection["id"]
|
||||
label = detection["label"]
|
||||
color = CATEGORY_COLORS.get(category_id, (255, 255, 255)) # Default white if not found
|
||||
|
||||
# Draw bounding box
|
||||
cv2.rectangle(result_image, (x1, y1), (x2, y2), color, 2)
|
||||
|
||||
# Prepare label text (no confidence)
|
||||
label_text = label
|
||||
|
||||
# Calculate text size and position
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
font_scale = 0.6
|
||||
thickness = 2
|
||||
(text_width, text_height), baseline = cv2.getTextSize(label_text, font, font_scale, thickness)
|
||||
|
||||
# Draw label background
|
||||
cv2.rectangle(result_image, (x1, y1 - text_height - 10), (x1 + text_width, y1), color, -1)
|
||||
|
||||
# Draw label text
|
||||
cv2.putText(result_image, label_text, (x1, y1 - 5), font, font_scale, (0, 0, 0), thickness)
|
||||
|
||||
return result_image
|
||||
|
||||
def process_single_image(self, image_id: str, image: np.ndarray, prompt: str = None) -> DetectionResult:
|
||||
"""
|
||||
Process a single OpenCV Mat image synchronously.
|
||||
|
||||
Args:
|
||||
image_id: Unique identifier for the image
|
||||
image: OpenCV image (Mat) format
|
||||
prompt: The prompt for the vision API (optional)
|
||||
|
||||
Returns:
|
||||
DetectionResult containing the results
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Validate input image
|
||||
if image is None or image.size == 0:
|
||||
error_msg = f"Invalid image for image_id: {image_id}"
|
||||
print(error_msg)
|
||||
return DetectionResult(
|
||||
image_id=image_id,
|
||||
original_image=image,
|
||||
detections=[],
|
||||
success=False,
|
||||
error_message=error_msg,
|
||||
timestamp=start_time
|
||||
)
|
||||
|
||||
# Use provided prompt or default
|
||||
use_prompt = prompt if prompt is not None else self.default_prompt
|
||||
|
||||
# Call the vision API
|
||||
print(f"Calling vision API for image {image_id}...")
|
||||
result = self.call_vision_api_sync(image, use_prompt)
|
||||
|
||||
if result is None:
|
||||
error_msg = "Failed to get valid response from API."
|
||||
print(error_msg)
|
||||
return DetectionResult(
|
||||
image_id=image_id,
|
||||
original_image=image,
|
||||
detections=[],
|
||||
success=False,
|
||||
error_message=error_msg,
|
||||
timestamp=start_time
|
||||
)
|
||||
|
||||
# Extract detections
|
||||
detections = result.get("detections", [])
|
||||
print(f"Found {len(detections)} detections for image {image_id}.")
|
||||
|
||||
# Draw detections on image
|
||||
try:
|
||||
marked_image = self.draw_detections_on_image(image, detections)
|
||||
except Exception as e:
|
||||
error_msg = f"Error drawing detections on image: {e}"
|
||||
print(error_msg)
|
||||
return DetectionResult(
|
||||
image_id=image_id,
|
||||
original_image=image,
|
||||
detections=detections,
|
||||
success=False,
|
||||
error_message=error_msg,
|
||||
timestamp=start_time
|
||||
)
|
||||
|
||||
# Return successful result
|
||||
return DetectionResult(
|
||||
image_id=image_id,
|
||||
original_image=image,
|
||||
detections=detections,
|
||||
marked_image=marked_image,
|
||||
success=True,
|
||||
timestamp=start_time
|
||||
)
|
||||
|
||||
def start_worker(self):
|
||||
"""Start the background worker thread for processing tasks"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
|
||||
self._worker_thread.start()
|
||||
print("Vision API worker started.")
|
||||
|
||||
def stop_worker(self):
|
||||
"""Stop the background worker thread"""
|
||||
self._running = False
|
||||
if self._worker_thread:
|
||||
self._worker_thread.join(timeout=5.0) # Wait up to 5 seconds
|
||||
print("Vision API worker stopped.")
|
||||
|
||||
def _worker_loop(self):
|
||||
"""Background worker loop for processing tasks from queue"""
|
||||
while self._running:
|
||||
try:
|
||||
# Get task from queue with timeout
|
||||
task = self.task_queue.get(timeout=1.0)
|
||||
|
||||
# Mark as processing
|
||||
self.processing_tasks[task.task_id] = TaskStatus.PROCESSING
|
||||
|
||||
# Process the task
|
||||
result = self.process_single_image(task.image_id, task.image, task.prompt)
|
||||
|
||||
# Store result
|
||||
self.results_cache[task.task_id] = result
|
||||
|
||||
# Update task status
|
||||
self.processing_tasks[task.task_id] = TaskStatus.COMPLETED if result.success else TaskStatus.FAILED
|
||||
|
||||
# Call callback if provided
|
||||
if task.callback:
|
||||
try:
|
||||
task.callback(result)
|
||||
except Exception as e:
|
||||
print(f"Callback execution failed for task {task.task_id}: {e}")
|
||||
|
||||
# Mark task as done
|
||||
self.task_queue.task_done()
|
||||
|
||||
except queue.Empty:
|
||||
continue # Timeout, continue loop
|
||||
except Exception as e:
|
||||
print(f"Worker error: {e}")
|
||||
continue
|
||||
|
||||
def submit_task(self, image_id: int, image: np.ndarray, prompt: str = None,
|
||||
callback: Callable[[DetectionResult], None] = None) -> str:
|
||||
"""
|
||||
Submit a task to the processing queue with OpenCV Mat input.
|
||||
|
||||
Args:
|
||||
image_id: Unique identifier for the image
|
||||
image: OpenCV image (Mat) format
|
||||
prompt: The prompt for the vision API (optional)
|
||||
callback: Callback function to be called when processing is complete (optional)
|
||||
|
||||
Returns:
|
||||
Task ID for tracking the task
|
||||
"""
|
||||
task_id = f"task_{int(time.time() * 1000000)}_{image_id}" # Generate unique task ID
|
||||
task = Task(
|
||||
task_id=task_id,
|
||||
image_id=image_id,
|
||||
image=image,
|
||||
prompt=prompt if prompt is not None else self.default_prompt,
|
||||
callback=callback,
|
||||
timestamp=time.time()
|
||||
)
|
||||
|
||||
self.task_queue.put(task)
|
||||
self.processing_tasks[task_id] = TaskStatus.PENDING
|
||||
|
||||
return task_id
|
||||
|
||||
def get_result(self, task_id: str) -> Optional[DetectionResult]:
|
||||
"""
|
||||
Get the result for a specific task.
|
||||
|
||||
Args:
|
||||
task_id: The task ID to retrieve result for
|
||||
|
||||
Returns:
|
||||
DetectionResult if available, None otherwise
|
||||
"""
|
||||
return self.results_cache.get(task_id)
|
||||
|
||||
def get_task_status(self, task_id: str) -> Optional[TaskStatus]:
|
||||
"""
|
||||
Get the status of a specific task.
|
||||
|
||||
Args:
|
||||
task_id: The task ID to check status for
|
||||
|
||||
Returns:
|
||||
TaskStatus if task exists, None otherwise
|
||||
"""
|
||||
return self.processing_tasks.get(task_id)
|
||||
|
||||
def get_queue_size(self) -> int:
|
||||
"""Get the current size of the task queue"""
|
||||
return self.task_queue.qsize()
|
||||
|
||||
def get_processing_count(self) -> int:
|
||||
"""Get the number of currently processing tasks"""
|
||||
return sum(1 for status in self.processing_tasks.values()
|
||||
if status == TaskStatus.PROCESSING)
|
||||
|
||||
def get_completed_count(self) -> int:
|
||||
"""Get the number of completed tasks"""
|
||||
return sum(1 for status in self.processing_tasks.values()
|
||||
if status in [TaskStatus.COMPLETED, TaskStatus.FAILED])
|
||||
|
||||
def clear_results(self):
|
||||
"""Clear the results cache"""
|
||||
self.results_cache.clear()
|
||||
|
||||
def __enter__(self):
|
||||
"""Context manager entry"""
|
||||
self.start_worker()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Context manager exit"""
|
||||
self.stop_worker()
|
||||
|
||||
|
||||
# Example usage
|
||||
def example_callback(result: DetectionResult):
|
||||
"""Example callback function"""
|
||||
if result.success:
|
||||
print(f"Callback: Processing completed for image {result.image_id}, found {len(result.detections)} detections")
|
||||
# The result.marked_image is the OpenCV Mat with detections drawn
|
||||
marked_image = result.marked_image
|
||||
# You can now use the marked_image for further processing
|
||||
else:
|
||||
print(f"Callback: Processing failed for image {result.image_id}: {result.error_message}")
|
||||
|
||||
|
||||
def main():
|
||||
original_image = cv2.imread("/home/evan/Desktop/received/left/left_1761388243_7673044.jpg") # Replace with your image source
|
||||
|
||||
if original_image is None:
|
||||
print("Could not load image")
|
||||
return
|
||||
|
||||
# Example usage with context manager
|
||||
with VisionAPIClient() as client:
|
||||
# Submit a task with OpenCV Mat
|
||||
task_id = client.submit_task(
|
||||
image_id=1,
|
||||
image=original_image,
|
||||
callback=example_callback
|
||||
)
|
||||
|
||||
print(f"Submitted task {task_id}")
|
||||
|
||||
# Wait for the task to complete
|
||||
print("Waiting for task to complete...")
|
||||
client.task_queue.join() # Wait for all tasks in queue to be processed
|
||||
|
||||
# Get the result
|
||||
result = client.get_result(task_id)
|
||||
if result:
|
||||
if result.success:
|
||||
print(f"Task completed successfully! Found {len(result.detections)} detections.")
|
||||
# result.marked_image is the OpenCV Mat with detections drawn
|
||||
marked_image = result.marked_image
|
||||
|
||||
# Display the result (optional)
|
||||
cv2.imshow("Original Image", original_image)
|
||||
cv2.imshow("Marked Image", marked_image)
|
||||
print("Press any key to close windows...")
|
||||
cv2.waitKey(0)
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
# Or save the result
|
||||
# cv2.imwrite("marked_image.jpg", marked_image)
|
||||
else:
|
||||
print(f"Task failed: {result.error_message}")
|
||||
else:
|
||||
print("No result found")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user