feat: 增加 hanoi 跳过过滤条件(存在问题)

This commit is contained in:
bmy
2024-07-14 13:34:18 +08:00
parent b1300fc8f1
commit 1edd292ac6
7 changed files with 215 additions and 28 deletions

View File

@@ -106,6 +106,53 @@ class label_filter:
if len(results) > 0:
return True, np.array(results)
return False, None
# '''
# description: 对模型推理推理结果使用 threshold 和其他条件过滤 默认阈值为 0.5
# param {*} self
# param {*} data get_resp 返回的数据
# return {bool,array}
# '''
# def filter_box_custom(self, data, ymax_range):
# if len(data) > 0:
# expect_boxes = (data[:, 1] > self.threshold) & (data[:, 0] > -1)
# np_boxes = data[expect_boxes, :]
# results = [
# [
# item[0], # 'label':
# item[1], # 'score':
# item[2], # 'xmin':
# item[3], # 'ymin':
# item[4], # 'xmax':
# item[5], # 'ymax':
# not (ymax_range[0] < item[3] < ymax_range[1]), # 如果 ymin 处在范围内则返回 False认为该目标不符合要求
# not (ymax_range[0] < item[5] < ymax_range[1]) # 如果 ymax 处在范围内则返回 False认为该目标不符合要求
# ]
# for item in np_boxes
# ]
# if len(results) > 0:
# return True, np.array(results)
# return False, None
#原来的函数
def filter_box_custom(self,data):
if len(data) > 0:
expect_boxes = (data[:, 1] > self.threshold) & (data[:, 0] > -1)
np_boxes = data[expect_boxes, :]
results = [
[
item[0], # 'label':
item[1], # 'score':
item[2], # 'xmin':
item[3], # 'ymin':
item[4], # 'xmax':
item[5] # 'ymax':
]
for item in np_boxes
if item[5] < 180
]
if len(results) > 0:
return True, np.array(results)
return False, None
'''
description: 根据传入的标签过滤返回该标签的个数、box
param {*} self
@@ -198,7 +245,7 @@ class label_filter:
def get_two(self, target_label, label):
response = self.get_resp()
if response['code'] == 0:
ret, results = self.filter_box(response['data'])
ret, results = self.filter_box_custom(response['data'])
if ret:
expect_boxes = (results[:, 0] == target_label.value)
boxes = results[expect_boxes, :]
@@ -216,6 +263,35 @@ class label_filter:
label_bool = False
return (target_bool, label_bool, target_box)
return (False, False, None)
# '''
# description: 查询两个目标 只有 target_label 返回 box
# param {*} self
# param {*} tlabel
# return {[bool]}
# '''
# def get_two_hanoi(self, target_label, label, ymax_range):
# response = self.get_resp()
# if response['code'] == 0:
# ret, results = self.filter_box_custom(response['data'], ymax_range)
# if ret:
# expect_boxes = (results[:, 0] == target_label.value)
# boxes = results[expect_boxes, :]
# if len(boxes) != 0:
# target_bool = True
# target_box = boxes[:, 2:]
# else:
# target_bool = False
# target_box = None
# expect_boxes = (results[:, 0] == label.value)
# boxes = results[expect_boxes, :]
# if len(boxes) != 0:
# label_bool = True
# else:
# label_bool = False
# return (target_bool, label_bool, target_box)
# return (False, False, None)
'''
description: 判断传入的多目标标签是否存在,存在返回 True
param {*} self