# project_main/utils.py
#
# Header captured from a web code viewer: 338 lines, 12 KiB, Python.
# Viewer warning: this file contains ambiguous Unicode characters that might be
# confused with other characters. If that is intentional, the warning can be ignored.
from enum import Enum
import numpy as np
import erniebot
from simple_pid import PID
# Module-level shared state for line following.
lane_error = 0  # line-following (lane) error
task_speed = 0  # line-following speed; a task can change task_speed when it starts to control the speed
class tlabel(Enum):
    """Detection labels; each member's value is the numeric class id.

    label_filter matches a detection when its label column equals
    ``member.value``. Member names (including the PILLER spelling) are part of
    the public interface and are kept as-is.
    """

    TPLATFORM = 0
    TOWER = 1
    SIGN = 2
    SHELTER = 3
    HOSPITAL = 4
    BASKET = 5
    BASE = 6
    YBALL = 7
    SPILLER = 8
    RMARK = 9
    RBLOCK = 10
    RBALL = 11
    MPILLER = 12
    LPILLER = 13
    LMARK = 14
    BBLOCK = 15
    BBALL = 16
# Fork-in-the-road (junction) parameters.
# NOTE(review): RMARK / LMARK look like the right / left direction signs — confirm.
direction = tlabel.RMARK  # initial direction label
direction_left = 0
direction_right = 0
# Sample YOLO-server replies for exercising label_filter offline.
# Each 'data' row is [label, score, xmin, ymin, xmax, ymax].
test_resp = dict(
    code=0,
    data=np.array([
        [4., 0.97192055, 26.64415, 228.26755, 170.16872, 357.6216],   # label 4 (HOSPITAL)
        [4., 0.97049206, 474.0152, 251.2854, 612.91644, 381.6831],    # label 4 (HOSPITAL)
        [5., 0.972649, 250.84174, 238.43622, 378.115, 367.34906],     # label 5 (BASKET)
    ]),
)
# A reply that contains no detections at all.
test1_resp = dict(
    code=0,
    data=np.array([]),
)
class label_filter:
    """Label filter for YOLO object-detection results.

    Needs a connection to the YOLO server: any object with ``send_string(str)``
    and ``recv_pyobj()`` works (presumably a pyzmq REQ socket — confirm against
    the caller). Every query sends one request and reads one reply of the form
    ``{'code': int, 'data': ndarray}``; only replies with ``code == 0`` are used.
    Each ``data`` row is ``[label, score, xmin, ymin, xmax, ymax]``.

    The "error" values returned by the aim_* / get_*_box helpers are the offset
    in pixels of a box's horizontal centre from the image centre line
    (negative = left of centre).
    """

    def __init__(self, socket, threshold=0.5, img_size=(320, 240)):
        """
        param socket: connection to the YOLO server (see class docstring).
        param threshold: detections must score strictly above this to be kept.
        param img_size: (width, height) of the detection image in pixels; only the
            width is used, to locate the image centre line.
        """
        self.num = 0    # number of boxes matched by the most recent get()
        self.pos = []   # boxes [[xmin, ymin, xmax, ymax], ...] from the last successful get()
        self.socket = socket
        self.threshold = threshold
        self.img_size = tuple(img_size)

    def get_resp(self):
        """Request one round of detection results from the YOLO server.

        Sends an empty string and returns the server's reply object unchanged.
        NOTE(review): recv_pyobj presumably unpickles the reply — only use this
        with a trusted server.
        """
        self.socket.send_string('')
        return self.socket.recv_pyobj()

    def switch_camera(self, camera_id):
        """Switch the YOLO server's video source.

        At a fork in the road, detection has to use the front camera.

        param camera_id: 1 or 2, given as an int or a string.
        return: the server's reply.
        raises ValueError: for any other camera_id.
        """
        if camera_id not in (1, 2, '1', '2'):
            raise ValueError(f'camera_id must be 1 or 2, got {camera_id!r}')
        self.socket.send_string(f'{camera_id}')
        return self.socket.recv_pyobj()

    def filter_box(self, data):
        """Keep detections scoring above ``self.threshold`` whose label is > -1.

        param data: 2-D array from get_resp()['data'], one detection per row.
        return: (True, array of the kept rows' first six columns) when at least
            one row survives, otherwise (False, None).
        """
        if len(data) == 0:
            return False, None
        keep = (data[:, 1] > self.threshold) & (data[:, 0] > -1)
        boxes = data[keep, :6]
        if len(boxes) == 0:
            return False, None
        return True, boxes

    def _detections(self):
        """Fetch one reply and return its filtered rows (a (0, 6) array when there are none)."""
        response = self.get_resp()
        if response['code'] == 0:
            ok, results = self.filter_box(response['data'])
            if ok:
                return results
        return np.empty((0, 6))

    @staticmethod
    def _select(results, tlabel):
        """Return the rows of ``results`` whose label column equals ``tlabel.value``."""
        return results[results[:, 0] == tlabel.value]

    def _center_error(self, box):
        """Pixel offset of a detection row's horizontal centre from the image centre."""
        # Columns 2 and 4 of a detection row are xmin and xmax.
        return (box[4] + box[2] - self.img_size[0]) / 2

    def get(self, tlabel):
        """Count and locate the boxes carrying ``tlabel``.

        Updates ``self.num`` on every call (0 when nothing matches) and
        ``self.pos`` only when something matches.

        return: (True, array of [xmin, ymin, xmax, ymax]) or (False, []).
        """
        boxes = self._select(self._detections(), tlabel)
        self.num = len(boxes)
        if self.num:
            self.pos = boxes[:, 2:]  # [[x1 y1 x2 y2]]
            return True, self.pos
        return False, []

    def get_mult_box(self, tlabel_list):
        """Return the first label in ``tlabel_list`` that is present (priority order).

        Intended only for reading the direction sign at a fork in the road.

        return: (True, label, error of that label's first box) or (False, None, None).
        """
        results = self._detections()
        for tlabel in tlabel_list:
            matches = self._select(results, tlabel)
            if len(matches):
                return True, tlabel, self._center_error(matches[0])
        return False, None, None

    def get_near_box(self, tlabel_list):
        """Among the labels in ``tlabel_list`` that are present, pick the one closest to centre.

        Only each label's first box is considered; ties go to the earlier label.

        return: (True, label, error) or (False, None, None).
        """
        results = self._detections()
        candidates = []  # (label, error of that label's first box)
        for tlabel in tlabel_list:
            matches = self._select(results, tlabel)
            if len(matches):
                candidates.append((tlabel, self._center_error(matches[0])))
        if not candidates:
            return False, None, None
        label, error = min(candidates, key=lambda c: abs(c[1]))
        return True, label, error

    def find(self, tlabel):
        """Return True when at least one box carries ``tlabel``."""
        return len(self._select(self._detections(), tlabel)) > 0

    def aim_left(self, tlabel):
        """Return (True, error) for the left-most ``tlabel`` box (smallest xmin), else (False, None)."""
        boxes = self._select(self._detections(), tlabel)
        if len(boxes) == 0:
            return (False, None)
        return (True, self._center_error(boxes[np.argmin(boxes[:, 2])]))

    def aim_right(self, tlabel):
        """Return (True, error) for the right-most ``tlabel`` box (largest xmax), else (False, None)."""
        boxes = self._select(self._detections(), tlabel)
        if len(boxes) == 0:
            return (False, None)
        return (True, self._center_error(boxes[np.argmax(boxes[:, 4])]))

    def aim_near(self, tlabel):
        """Return (True, error) for the ``tlabel`` box nearest the centre line, else (False, 0)."""
        boxes = self._select(self._detections(), tlabel)
        if len(boxes) == 0:
            return (False, 0)
        distances = np.abs(boxes[:, 2] + boxes[:, 4] - self.img_size[0])
        return (True, self._center_error(boxes[np.argmin(distances)]))
class LLM:
    """Chat wrapper around ERNIE Bot that turns instructions into a JSON action plan.

    On construction it configures the global ``erniebot`` module, sends the
    planning prompt as the first user turn, and keeps the conversation in
    ``self.messages`` so later requests carry the context.
    """

    def __init__(self, ak=None, sk=None, model='ernie-3.5'):
        """
        param ak / sk: Qianfan API key / secret key. When omitted they are read
            from the EB_AK / EB_SK environment variables, falling back to the
            legacy hard-coded values.
        param model: model name passed to erniebot.ChatCompletion.create.
        """
        import os

        # NOTE(review): the fallback credentials below were committed to source
        # control and should be treated as leaked — rotate them and remove the
        # fallback once every deployment sets EB_AK / EB_SK.
        erniebot.api_type = "qianfan"
        erniebot.ak = ak or os.environ.get("EB_AK", "jReawMtWhPu0wrxN9Rp1MzZX")
        erniebot.sk = sk or os.environ.get("EB_SK", "eowS1BqsNgD2i0C9xNnHUVOSNuAzVTh6")
        self.model = model
        # Prompt text (runtime string, kept verbatim): asks the model to act as a
        # robot motion planner, answer only with JSON in the format of the
        # examples, and not to reply to this setup message.
        self.prompt = '''你是一个机器人动作规划者,需要把我的话翻译成机器人动作规划并生成对应的 json 结果,机器人工作空间参考右手坐标系。
严格按照下面的描述生成给定格式 json从现在开始你仅仅给我返回 json 数据!'''
        self.prompt += '''正确的示例如下:
向左移 0.1m, 向左转弯 85 度 [{'func': 'move', 'x': 0, 'y': 0.1},{'func': 'turn','angle': -85}],
向右移 0.2m, 向前 0.1m [{'func': 'move', 'x': 0, 'y': -0.2},{'func': 'move', 'x': 0.1, 'y': 0}],
向右转 85 度,向右移 0.1m [{'func': 'turn','angle': 85},{'func': 'move', 'x': 0, 'y': -0.1}],
原地左转 38 度 [{'func': 'turn','angle': -38}],
蜂鸣器发声 5 秒 [{'func': 'beep', 'time': 5}]
发光或者照亮 5 秒 [{'func': 'light', 'time': 5}]
'''
        self.prompt += '''你无需回复我'''
        self.messages = []
        self.resp = None
        self.reset()

    def reset(self):
        """Start a fresh conversation containing only the prompt and the model's reply.

        The new history replaces ``self.messages`` only after the request
        succeeds, so a failed reset keeps the previous (valid) history.
        """
        messages = [self.make_message(self.prompt)]
        self.resp = erniebot.ChatCompletion.create(
            model=self.model,
            messages=messages,
        )
        messages.append(self.resp.to_message())
        self.messages = messages

    def make_message(self, content):
        """Wrap ``content`` as a user-role chat message."""
        return {'role': 'user', 'content': content}

    def get_command_json(self, chat):
        """Send ``chat`` as the next user turn and return the model's reply text.

        The user turn and the reply are appended to ``self.messages`` only after
        the request succeeds. A failed request therefore does not leave a
        dangling user turn, which would break the strict user/assistant
        alternation the ERNIE chat API requires for every later call.
        """
        user_message = self.make_message(chat)
        resp = erniebot.ChatCompletion.create(
            model=self.model,
            messages=self.messages + [user_message],
        )
        self.resp = resp
        self.messages.append(user_message)
        self.messages.append(resp.to_message())
        return resp.get_result()
class CountRecord:
    """Debounce a stream of values: report a value only after it keeps repeating.

    Each call compares the new value with the previous one. Consecutive equal
    values increment ``count``; any change resets it to 0. Calling the instance
    returns a truthy result once ``count`` reaches ``stop_count``.
    """

    def __init__(self, stop_count=2) -> None:
        self.last_record = None  # value seen on the previous call
        self.count = 0           # consecutive repeats of last_record
        self.stop_count = stop_count

    # The attribute was originally spelled ``stop_cout``; keep that name working.
    @property
    def stop_cout(self):
        return self.stop_count

    @stop_cout.setter
    def stop_cout(self, value):
        self.stop_count = value

    def get_count(self, val):
        """Record ``val`` and return the updated repeat count.

        Best effort: if comparing ``val`` with the previous value raises (for
        example a multi-element NumPy array, whose ``==`` result has no single
        truth value), the error is printed, the state is left unchanged and
        None is returned.
        """
        try:
            if val == self.last_record:
                self.count += 1
            else:
                self.count = 0
                self.last_record = val
        except Exception as e:  # deliberate best-effort; see docstring
            print(e)
            return None
        return self.count

    def __call__(self, val):
        """Feed ``val``; return False until it has repeated ``stop_count`` times.

        Once the threshold is reached, a bool ``val`` is returned as-is (so a
        stable False still reads as False) and any other value yields True.
        """
        self.get_count(val)
        if self.count >= self.stop_count:
            return val if isinstance(val, bool) else True
        return False
class PidWrap:
    """Convenience wrapper around ``simple_pid.PID`` with symmetric output limits."""

    def __init__(self, kp, ki, kd, setpoint=0, output_limits=1):
        """
        param kp, ki, kd: PID gains.
        param setpoint: initial target value.
        param output_limits: magnitude L; the PID is built with output limits (-L, L).
        """
        limit = output_limits
        self.pid_t = PID(kp, ki, kd, setpoint, output_limits=(-limit, limit))

    def set_target(self, target):
        """Change the setpoint the controller drives towards."""
        self.pid_t.setpoint = target

    def set(self, kp, ki, kd):
        """Replace all three gains at once."""
        self.pid_t.tunings = (kp, ki, kd)

    def get(self, val_in):
        """Run one controller update for measurement ``val_in`` and return its output."""
        return self.pid_t(val_in)
if __name__ == '__main__':
    # Offline demo. label_filter(None) would crash on the first request because
    # None has no send_string(); feed it the canned test_resp reply instead.
    class _StubSocket:
        """Stand-in for the YOLO-server socket that always answers with test_resp."""

        def send_string(self, msg):
            pass

        def recv_pyobj(self):
            return test_resp

    obj = label_filter(_StubSocket())
    print(obj.find(tlabel.BBALL))
    print(obj.aim_left(tlabel.BBALL))
    print(obj.aim_right(tlabel.BBALL))
    print(obj.aim_near(tlabel.BBALL))
    print(obj.get(tlabel.HOSPITAL))

    # Interactive LLM session; end it with Ctrl-D (EOF) or Ctrl-C.
    llm_bot = LLM()
    while True:
        try:
            chat = input("输入:")
        except (EOFError, KeyboardInterrupt):
            break
        print(llm_bot.get_command_json(chat))