Files
project_main/utils.py

484 lines
18 KiB
Python
Raw Normal View History

2024-06-02 17:49:52 +08:00
2024-05-22 18:50:21 +08:00
from enum import Enum
import numpy as np
2024-06-05 16:07:32 +08:00
import erniebot
from simple_pid import PID
2024-06-15 22:04:11 +08:00
from loguru import logger
2024-06-30 21:57:01 +08:00
import threading
# Module-level line-following state:
#   lane_error -- current lane-following (line-tracking) error
#   task_speed -- line-following speed; a task may change it on entry to
#                 control how fast the robot follows the line
lane_error = task_speed = 0
class tlabel(Enum):
    """Class ids emitted by the YOLO detector (column 0 of a detection row).

    Member names are kept exactly as they are (including the ``PILLER``
    spelling) because other code refers to them by name.
    """

    TPLATFORM = 0
    TOWER = 1
    SIGN = 2
    SHELTER = 3
    HOSPITAL = 4
    BASKET = 5
    BASE = 6
    YBALL = 7
    SPILLER = 8
    RMARK = 9
    RBLOCK = 10
    RBALL = 11
    MPILLER = 12
    LPILLER = 13
    LMARK = 14
    BBLOCK = 15
    BBALL = 16


# Fork (junction) parameters.
direction = tlabel.RMARK
direction_left = direction_right = 0
# Sample yolo-server replies for exercising label_filter offline:
# test_resp carries three detections, test1_resp carries none.
# Each row is [label, score, xmin, ymin, xmax, ymax].
test_resp = {
    'code': 0,
    'data': np.array([
        [4., 0.97192055, 26.64415, 228.26755, 170.16872, 357.6216],
        [4., 0.97049206, 474.0152, 251.2854, 612.91644, 381.6831],
        [5., 0.972649, 250.84174, 238.43622, 378.115, 367.34906],
    ]),
}
test1_resp = {'code': 0, 'data': np.array([])}
class label_filter:
    """Query a YOLO detection server and filter its results by label.

    The server is reached through ``socket``, which must provide
    ``send_string(str)`` and ``recv_pyobj()`` (the only methods used here;
    presumably a pyzmq socket -- confirm against the caller).  Each reply is
    expected to be a dict with an integer ``'code'`` (0 on success) and a
    ``'data'`` array whose rows are ``[label, score, xmin, ymin, xmax, ymax]``.

    Attributes:
        num: number of matching boxes counted by the most recent ``get`` call
            that received detections.
        pos: ``[[xmin, ymin, xmax, ymax], ...]`` from the most recent ``get``
            call that found at least one box.
        socket: connection to the YOLO server.
        threshold: a detection is kept only if its score exceeds this.
        img_size: ``(width, height)`` of the detection image; only the width
            is used, to measure horizontal offsets from the image centre.
    """

    def __init__(self, socket, threshold=0.5, img_size=(320, 240)):
        self.num = 0
        self.pos = []
        self.socket = socket
        self.threshold = threshold
        self.img_size = img_size

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _detect(self, box_filter=None):
        """Fetch one detection reply and filter it.

        Args:
            box_filter: callable ``data -> (bool, array)``; defaults to
                ``self.filter_box``.

        Returns:
            The filtered ``(k, 6)`` array, or ``None`` when the server reported
            a non-zero code or nothing passed the filter.
        """
        if box_filter is None:
            box_filter = self.filter_box
        response = self.get_resp()
        if response['code'] != 0:
            return None
        ok, results = box_filter(response['data'])
        return results if ok else None

    @staticmethod
    def _select(results, label):
        """Return the rows of ``results`` whose label column equals ``label.value``."""
        return results[results[:, 0] == label.value, :]

    def _center_error(self, xmin, xmax):
        """Horizontal offset (pixels) of a box centre from the image centre.

        Negative values mean the box centre lies left of the image centre.
        """
        return (xmin + xmax - self.img_size[0]) / 2

    # ------------------------------------------------------------------
    # Server requests
    # ------------------------------------------------------------------
    def get_resp(self):
        """Request one detection result from the YOLO server and return the reply."""
        self.socket.send_string('')
        return self.socket.recv_pyobj()

    def switch_camera(self, camera_id):
        """Switch the YOLO server's video source.

        At the fork the front camera has to be used for detection.

        Args:
            camera_id: 1 or 2, either as a number or as the string '1' / '2'.

        Returns:
            The server's reply.

        Raises:
            ValueError: if ``camera_id`` is not 1 or 2.
        """
        if camera_id not in (1, 2, '1', '2'):
            raise ValueError(f'camera_id must be 1 or 2, got {camera_id!r}')
        self.socket.send_string(f'{camera_id}')
        return self.socket.recv_pyobj()

    # ------------------------------------------------------------------
    # Filters
    # ------------------------------------------------------------------
    def filter_box(self, data):
        """Keep detections whose score exceeds ``self.threshold``.

        Rows whose label is ``<= -1`` are dropped as well (presumably the
        server's marker for an invalid row -- confirm).

        Args:
            data: the ``'data'`` field of a ``get_resp`` reply.

        Returns:
            ``(True, array)`` holding the surviving
            ``[label, score, xmin, ymin, xmax, ymax]`` rows, or
            ``(False, None)`` when nothing survives.
        """
        if data is None or len(data) == 0:
            return False, None
        keep = (data[:, 1] > self.threshold) & (data[:, 0] > -1)
        boxes = data[keep, :6]
        if len(boxes) == 0:
            return False, None
        return True, boxes

    def filter_box_custom(self, data, ymax_limit=180):
        """Like ``filter_box``, but also drop boxes with ``ymax >= ymax_limit``.

        Args:
            data: the ``'data'`` field of a ``get_resp`` reply.
            ymax_limit: only boxes whose ymax is strictly below this survive
                (default 180, the value previously hard-coded).

        Returns:
            ``(True, array)`` or ``(False, None)``, as ``filter_box``.
        """
        ok, boxes = self.filter_box(data)
        if not ok:
            return False, None
        boxes = boxes[boxes[:, 5] < ymax_limit]
        if len(boxes) == 0:
            return False, None
        return True, boxes

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------
    def get(self, tlabel):
        """Return every box carrying ``tlabel``.

        Also stores the count in ``self.num`` and the boxes in ``self.pos``
        (``self.pos`` is only updated when at least one box is found).

        Returns:
            ``(True, [[xmin, ymin, xmax, ymax], ...])`` if at least one box is
            found, otherwise ``(False, [])``.
        """
        results = self._detect()
        if results is not None:
            boxes = self._select(results, tlabel)
            self.num = len(boxes)
            if self.num:
                self.pos = boxes[:, 2:]
                return True, self.pos
        return False, []

    def get_mult_box(self, tlabel_list):
        """Find the first label of ``tlabel_list`` that is present.

        Intended only for reading the direction sign at the fork.

        Returns:
            ``(True, label, error)`` where ``error`` is the horizontal offset of
            that label's first box, or ``(False, None, None)``.
        """
        results = self._detect()
        if results is not None:
            for label in tlabel_list:
                boxes = self._select(results, label)
                if len(boxes):
                    return True, label, self._center_error(boxes[0, 2], boxes[0, 4])
        return False, None, None

    def get_near_box(self, tlabel_list):
        """Among the present labels, pick the one whose first box is nearest the centre.

        Returns:
            ``(True, label, error)`` for the label with the smallest
            ``abs(error)`` (the earliest in ``tlabel_list`` on ties), or
            ``(False, None, None)`` when none is present.
        """
        results = self._detect()
        if results is None:
            return False, None, None
        candidates = []
        for label in tlabel_list:
            boxes = self._select(results, label)
            if len(boxes):
                candidates.append((label, self._center_error(boxes[0, 2], boxes[0, 4])))
        if not candidates:
            return False, None, None
        # min() returns the first of equally-near candidates.
        label, error = min(candidates, key=lambda item: abs(item[1]))
        return True, label, error

    def find(self, tlabel):
        """Return True if at least one box carries ``tlabel``."""
        results = self._detect()
        return results is not None and len(self._select(results, tlabel)) > 0

    def get_two(self, target_label, label):
        """Look for two labels at once; only ``target_label`` boxes are returned.

        Detections are filtered with ``filter_box_custom`` (default ymax cut).

        Returns:
            ``(target_found, label_found, target_boxes)`` where
            ``target_boxes`` is ``[[xmin, ymin, xmax, ymax], ...]`` or None.
        """
        results = self._detect(self.filter_box_custom)
        if results is None:
            return (False, False, None)
        targets = self._select(results, target_label)
        target_found = len(targets) != 0
        target_box = targets[:, 2:] if target_found else None
        label_found = len(self._select(results, label)) != 0
        return (target_found, label_found, target_box)

    def get_two_hanoi(self, target_label, label, flipv):
        """Variant of ``get_two`` for the Hanoi task, with a position check on ``label``.

        Detections are filtered with ``filter_box`` (no ymax cut).  ``label``
        only counts as found when every one of its boxes passes a test that
        depends on the route: with ``flipv`` true (right-hand route) every box
        needs ``ymin > 60``; otherwise (left-hand route) every box needs
        ``ymax < 180``.

        NOTE(review): the original comment for the flipv case said "ymin less
        than 60" while the code requires ymin > 60 -- confirm which is intended.

        Returns:
            ``(target_found, label_found, target_boxes)``, as ``get_two``.
        """
        results = self._detect()
        if results is None:
            return (False, False, None)
        targets = self._select(results, target_label)
        target_found = len(targets) != 0
        target_box = targets[:, 2:] if target_found else None
        # The position check lives here rather than in a filter_* method
        # because it depends on the route (flipv).
        others = self._select(results, label)
        if len(others) == 0:
            label_found = False
        elif flipv:
            label_found = all(row[3] > 60 for row in others)
        else:
            label_found = all(row[5] < 180 for row in others)
        return (target_found, label_found, target_box)

    def find_mult(self, tlabel):
        """Return one bool per label in ``tlabel`` telling whether it is present.

        All entries are False when the server reports an error or nothing
        passes the score filter.
        """
        results = self._detect()
        if results is None:
            return [False] * len(tlabel)
        return [len(self._select(results, label)) != 0 for label in tlabel]

    def aim_left(self, tlabel):
        """Offset from the image centre of the left-most ``tlabel`` box (smallest xmin).

        Returns:
            ``(True, error)``, or ``(False, None)`` when no such box exists
            (the same failure shape as ``aim_right``).
        """
        results = self._detect()
        if results is not None:
            boxes = self._select(results, tlabel)
            if len(boxes):
                row = boxes[np.argmin(boxes[:, 2])]
                return (True, self._center_error(row[2], row[4]))
        return (False, None)

    def aim_right(self, tlabel):
        """Offset from the image centre of the right-most ``tlabel`` box (largest xmax).

        Returns:
            ``(True, error)``, or ``(False, None)`` when no such box exists.
        """
        results = self._detect()
        if results is not None:
            boxes = self._select(results, tlabel)
            if len(boxes):
                row = boxes[np.argmax(boxes[:, 4])]
                return (True, self._center_error(row[2], row[4]))
        return (False, None)

    def aim_near(self, tlabel):
        """Offset of the ``tlabel`` box whose centre is nearest the image centre.

        Returns:
            ``(True, error)``, or ``(False, 0)`` when no such box exists.
        """
        results = self._detect()
        if results is not None:
            boxes = self._select(results, tlabel)
            if len(boxes):
                distances = np.abs(boxes[:, 2] + boxes[:, 4] - self.img_size[0])
                row = boxes[np.argmin(distances)]
                return (True, self._center_error(row[2], row[4]))
        return (False, 0)
class LLM:
    """Translate natural-language commands into robot-action JSON via ERNIE Bot.

    Construction starts a daemon thread that runs ``reset`` (sending the
    system prompt); ``get_command_json`` blocks until that has finished.
    """

    # Deletion table for the whitespace characters stripped from replies.
    _WHITESPACE_TABLE = str.maketrans('', '', ' \n\t')

    def __init__(self):
        import os  # local import: only needed for the credential override

        self.init_done_flag = False
        # Set once the background reset() finishes, successfully or not.
        self._init_event = threading.Event()
        self._init_error = None
        erniebot.api_type = "qianfan"
        # SECURITY: these credentials are committed to source control and should
        # be rotated; set EB_AK / EB_SK in the environment to override them.
        erniebot.ak = os.environ.get("EB_AK", "jReawMtWhPu0wrxN9Rp1MzZX")
        erniebot.sk = os.environ.get("EB_SK", "eowS1BqsNgD2i0C9xNnHUVOSNuAzVTh6")
        self.model = 'ernie-3.5'
        self.prompt = '''你是一个机器人动作规划者,需要把我的话翻译成机器人动作规划并生成对应的 json 结果,机器人工作空间参考右手坐标系。
严格按照下面的描述生成给定格式 json从现在开始你仅仅给我返回 json 数据'''
        self.prompt += '''正确的示例如下:
向左移 0.1m, 向左转弯 85 [{'func': 'move', 'x': 0, 'y': 0.1},{'func': 'turn','angle': -85}],
向右移 0.2m, 向前 0.1m [{'func': 'move', 'x': 0, 'y': -0.2},{'func': 'move', 'x': 0.1, 'y': 0}],
向右转 85 向右移 0.1m [{'func': 'turn','angle': 85},{'func': 'move', 'x': 0, 'y': -0.1}],
原地左转 38 [{'func': 'turn','angle': -38}],
蜂鸣器发声 5 [{'func': 'beep', 'time': 5}]
发光或者照亮 5 [{'func': 'light', 'time': 5}]
向右走 30cm照亮 2s [{'func': 'move', 'x': 0, 'y': -0.3}, {'func': 'light', 'time': 2}],
向左移 0.2m, 向后 0.1m [{'func': 'move', 'x': 0, 'y': 0.2},{'func': 'move', 'x': -0.1, 'y': 0}],
鸣叫 3 [{'func': 'beep', 'time': 3}]
前行零点五米 [{'func': 'move', 'x': 0.5, 'y': 0}]
'''
        self.prompt += '''你只需要根据我的示例解析出指令即可,不要给我其他多余的回复;再次强调 你无需给我其他多余的回复 这对我很重要'''
        self.messages = []
        self.resp = None
        worker = threading.Thread(target=self.reset, daemon=True)
        worker.start()

    def reset(self):
        """Start a fresh conversation seeded with ``self.prompt``.

        Sets ``init_done_flag`` on success.  An exception from the API call is
        recorded for ``get_command_json`` and re-raised.
        """
        try:
            self.messages = [self.make_message(self.prompt)]
            self.resp = erniebot.ChatCompletion.create(
                model=self.model,
                messages=self.messages,
            )
            self.messages.append(self.resp.to_message())
            self.init_done_flag = True
            self._init_error = None
            logger.info("LLM init done")
        except Exception as exc:
            self._init_error = exc
            raise
        finally:
            # Always wake get_command_json, even on failure, so it cannot hang.
            self._init_event.set()

    def make_message(self, content):
        """Wrap ``content`` as a user-role chat message."""
        return {'role': 'user', 'content': content}

    @classmethod
    def _strip_whitespace(cls, text):
        """Delete every space, newline and tab from ``text``."""
        return text.translate(cls._WHITESPACE_TABLE)

    def get_command_json(self, chat):
        """Send ``chat`` and return the reply with spaces, tabs and newlines removed.

        Blocks (without spinning the CPU) until initialisation has finished.
        The exchange is appended to ``self.messages``.

        Raises:
            RuntimeError: if initialisation failed.
        """
        self._init_event.wait()
        if not self.init_done_flag:
            raise RuntimeError("LLM initialisation failed") from self._init_error
        self.messages.append(self.make_message(chat))
        self.resp = erniebot.ChatCompletion.create(
            model=self.model,
            messages=self.messages,
        )
        self.messages.append(self.resp.to_message())
        return self._strip_whitespace(self.resp.get_result())
class CountRecord:
    """Debounce a stream of readings: report once the same value keeps repeating.

    Args:
        stop_count: number of consecutive repeats (after the first sighting of
            a value) needed before ``__call__`` reports it as stable.
    """

    def __init__(self, stop_count=2) -> None:
        self.last_record = None
        self.count = 0
        # Attribute name (a typo of "stop_count") kept because callers may read it.
        self.stop_cout = stop_count

    @staticmethod
    def _same(a, b):
        """Equality test that also works for numpy arrays (whose ``==`` is element-wise)."""
        try:
            return bool(a == b)
        except (ValueError, TypeError):
            # bool() of a multi-element array is ambiguous, and mismatched
            # shapes may not broadcast; compare the arrays as a whole instead.
            return bool(np.array_equal(a, b))

    def get_count(self, val):
        """Record ``val`` and return how many times in a row it has repeated.

        The first sighting of a new value resets the count to 0.
        """
        if self._same(val, self.last_record):
            self.count += 1
        else:
            self.count = 0
        self.last_record = val
        return self.count

    def __call__(self, val):
        """Feed ``val`` and report whether it has become stable.

        Returns False until ``val`` has repeated ``stop_cout`` times.  Once it
        has, boolean readings (``bool`` or ``numpy.bool_``) return the stable
        value itself, so a stable False yields False; other values return True.
        """
        self.get_count(val)
        if self.count < self.stop_cout:
            return False
        if isinstance(val, (bool, np.bool_)):
            return bool(val)
        return True
class PidWrap:
    """Thin wrapper around ``simple_pid.PID`` with symmetric output limits by default.

    Args:
        kp, ki, kd: PID gains.
        setpoint: initial target value.
        output_limits: a number ``m`` clamps the output to ``(-m, m)``;
            ``None`` leaves it unbounded; a ``(lower, upper)`` pair is used as-is.
    """

    def __init__(self, kp, ki, kd, setpoint=0, output_limits=1):
        self.pid_t = PID(kp, ki, kd, setpoint,
                         output_limits=self._normalize_limits(output_limits))

    @staticmethod
    def _normalize_limits(output_limits):
        """Convert ``output_limits`` to the ``(lower, upper)`` tuple PID expects."""
        if output_limits is None:
            return (None, None)
        if isinstance(output_limits, (tuple, list)):
            lower, upper = output_limits
            return (lower, upper)
        # Symmetric limit; written as 0 - m to match the original exactly.
        return (0 - output_limits, output_limits)

    def set_target(self, target):
        """Change the controller setpoint."""
        self.pid_t.setpoint = target

    def set(self, kp, ki, kd):
        """Replace the three gains and log the new values."""
        self.pid_t.Kp = kp
        self.pid_t.Ki = ki
        self.pid_t.Kd = kd
        logger.info(f"[PID]# 更新 PID 参数Kp({kp:.2f}) Ki({ki:.2f}) Kd({kd:.2f})")

    def get(self, val_in):
        """Feed the measured value ``val_in`` and return the controller output."""
        return self.pid_t(val_in)
if __name__ == '__main__':
    class _FakeYoloSocket:
        """Offline stand-in for the YOLO server socket that always replies with test_resp."""

        def send_string(self, _msg):
            pass

        def recv_pyobj(self):
            return test_resp

    # No YOLO server is available when running this file directly, so replay
    # the bundled sample data instead of passing a None socket (which made the
    # very first query fail with AttributeError).
    obj = label_filter(_FakeYoloSocket())
    print(obj.find(tlabel.BBALL))
    print(obj.aim_left(tlabel.BBALL))
    print(obj.aim_right(tlabel.BBALL))
    print(obj.aim_near(tlabel.BBALL))
    print(obj.get(tlabel.HOSPITAL))

    llm_bot = LLM()
    while True:
        try:
            chat = input("输入:")
        except (EOFError, KeyboardInterrupt):
            # Leave the interactive loop cleanly on Ctrl-D / Ctrl-C.
            break
        print(llm_bot.get_command_json(chat))