feat: 增加部分任务

This commit is contained in:
bmy
2024-06-02 17:49:52 +08:00
parent 49c0499f24
commit 4c1cf9ceb0
6 changed files with 458 additions and 161 deletions

View File

@@ -1,24 +1,9 @@
from enum import Enum
import numpy as np
# 根据标签修改
# class tlabel(Enum):
# BBLOCK = 5 # 蓝色方块
# RBLOCK = 2 # 红色方块
# HOSPITAL = 3 # 医院
# BBALL = 4 # 蓝球
# YBALL = 5 # 黄球
# TOWER = 6 # 通信塔
# RBALL = 7 # 红球
# BASKET = 8 # 球筐
# MARKL = 9 # 指向标
# MARKR = 10 # 指向标
# SPILLAR = 11 # 小柱体 (红色)
# MPILLAR = 12 # 中柱体 (蓝色)
# LPILLAR = 13 # 大柱体 (红色)
# SIGN = 14 # 文字标牌
# TARGET = 15 # 目标靶
# SHELTER = 16 # 停车区
# BASE = 17 # 基地
lane_error = 0
class tlabel(Enum):
TOWER = 0
SIGN = 1
@@ -36,6 +21,9 @@ class tlabel(Enum):
LMARK = 13
BBLOCK = 14
BBALL = 15
'''
description: label_filter 的测试数据
'''
test_resp = {
'code': 0,
'data': np.array([
@@ -48,6 +36,9 @@ test1_resp = {
'code': 0,
'data': np.array([])
}
'''
description: yolo 目标检测标签过滤器,需要传入连接到 yolo server 的 socket 对象
'''
class label_filter:
def __init__(self, socket, threshold=0.6):
self.num = 0
@@ -55,16 +46,33 @@ class label_filter:
self.socket = socket
self.threshold = threshold
self.img_size = (320, 240)
self.img_size = (320, 240)
'''
description: 向 yolo server 请求目标检测数据
param {*} self
return {*}
'''
def get_resp(self):
self.socket.send_string('')
response = self.socket.recv_pyobj()
return response
'''
description: 切换 yolo server 视频源 在分叉路口时目标检测需要使用前摄
param {*} self
param {*} camera_id 1 或者 2 字符串
return {*}
'''
def switch_camera(self,camera_id):
if camera_id == 1 or camera_id == 2:
self.socket.send_string(f'{camera_id}')
response = self.socket.recv_pyobj()
return response
'''
description: 对模型推理推理结果使用 threshold 过滤 默认阈值为 0.5
param {*} self
param {*} data get_resp 返回的数据
return {bool,array}
'''
def filter_box(self,data):
if len(data) > 0:
expect_boxes = (data[:, 1] > self.threshold) & (data[:, 0] > -1)
@@ -83,10 +91,15 @@ class label_filter:
if len(results) > 0:
return True, np.array(results)
return False, None
'''
description: 根据传入的标签过滤,返回该标签的个数以及 box
param {*} self
param {*} tlabel
return {int, array}
'''
def get(self, tlabel):
# 循环查找匹配的标签值
# 返回对应标签的个数,以及坐标列表
# TODO self.filter_box none judge
response = self.get_resp()
if response['code'] == 0:
ret, results = self.filter_box(response['data'])
@@ -97,8 +110,38 @@ class label_filter:
self.pos = boxes[:, 2:] # [[x1 y1 x2 y2]]
return self.num, self.pos
return 0, []
'''
description:
param {*} self
param {*} tlabel_list
return {*}
'''
def get_mult(self, tlabel_list):
response = self.get_resp()
if response['code'] == 0:
ret, results = self.filter_box(response['data'])
target_counts = len(tlabel_list)
counts = 0
if ret:
for tlabel in tlabel_list:
expect_boxes = (results[:, 0] == tlabel.value)
has_true = np.any(expect_boxes)
if has_true:
counts += 1
else:
return False, []
if counts == target_counts:
return True, counts
return False, []
return False, []
return False, []
'''
description: 判断传入的标签是否存在,存在返回 True
param {*} self
param {*} tlabel
return {bool}
'''
def find(self, tlabel):
# 遍历返回的列表,有对应标签则返回 True
response = self.get_resp()
if response['code'] == 0:
ret, results = self.filter_box(response['data'])
@@ -108,6 +151,12 @@ class label_filter:
if len(boxes) != 0:
return True
return False
'''
description: 根据传入的标签,
param {*} self
param {*} tlabel
return {*}
'''
def aim_left(self, tlabel):
# 如果标签存在,则返回列表中位置最靠左的目标框和中心的偏移值
response = self.get_resp()
@@ -151,7 +200,7 @@ class label_filter:
center_x_values = np.abs(boxes[:, 2] + boxes[:, 4] - self.img_size[0])
center_x_index = np.argmin(center_x_values)
error = (boxes[center_x_index][4] + boxes[center_x_index][2] - self.img_size[0]) / 2
return (True, error+15)
return (True, error)
return (False, 0)
# class Calibrate: