diff --git a/ocr_server/ocr_infer_server.py b/ocr_server/ocr_infer_server.py
index f26940a..1a34cb2 100644
--- a/ocr_server/ocr_infer_server.py
+++ b/ocr_server/ocr_infer_server.py
@@ -1,10 +1,60 @@
 import toml
 from loguru import logger
-import logging
 import zmq
-from paddleocr import PaddleOCR
 import cv2
-logging.getLogger('paddleocr').setLevel(logging.CRITICAL)
+import numpy as np
+import requests
+import base64
+
+
+
+def get_access_token():
+    client_id = "MDCGplPqK0kteOgbXwt5cyn0"
+    client_secret = "yIHJQUUiMkkw53nlQqHpiLvRFsLGcqgn"
+    url = "https://aip.baidubce.com/oauth/2.0/token"
+    params = {
+        'grant_type': 'client_credentials',
+        'client_id': client_id,
+        'client_secret': client_secret
+    }
+    headers = {
+        'Content-Type': 'application/json',
+        'Accept': 'application/json'
+    }
+    response = requests.post(url, params=params, headers=headers)
+    response_json = response.json()
+    if 'access_token' in response_json:
+        return response_json['access_token']
+    else:
+        print("Failed to get access_token:", response_json.get('error_description'))
+        return None
+
+def ocr_api_request(image_base64):
+    # url = "https://aip.baidubce.com/rest/2.0/ocr/v1/accurate_basic"  # high accuracy
+    # url = "https://aip.baidubce.com/rest/2.0/ocr/v1/accurate"  # high accuracy, with positions
+    url = "https://aip.baidubce.com/rest/2.0/ocr/v1/general"  # standard accuracy, with positions
+    headers = {
+        'Content-Type': 'application/json'
+    }
+    params = {
+        'access_token': get_access_token(),
+        'image': image_base64,
+        'probability': 'true'
+    }
+
+    try:
+        response = requests.post(url, headers=headers, data=params, timeout=5)
+
+        try:
+            return response.json()
+        except requests.exceptions.JSONDecodeError:
+            return None
+
+    except requests.exceptions.Timeout:
+        return None
+    except requests.exceptions.RequestException as e:
+        return None
+
 
 
 if __name__ == "__main__":
@@ -13,14 +63,11 @@ if __name__ == "__main__":
     # configure log output
     logger.add(cfg['debug']['logger_filename'], format=cfg['debug']['logger_format'], retention = 5, level="INFO")
 
-    # connect to the camera
-    cap = cv2.VideoCapture(4)
-    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 320)
-    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 240)
+    context1 = zmq.Context()
+    camera_socket = context1.socket(zmq.REQ)
+    camera_socket.connect(f"tcp://localhost:{cfg['camera']['camera2_port']}")
+    logger.info("connect camera success")
 
-    # initialize the paddle inference engine
-    predictor = PaddleOCR(use_angle_cls=False, use_gpu=True)
-    logger.info("ocr model load success")
 
     # initialize the server
     context = zmq.Context()
@@ -29,17 +76,26 @@ if __name__ == "__main__":
     socket.bind(f"tcp://*:{cfg['server']['ocr_infer_port']}")
 
     while True:
-        socket.recv_string("")
-        ret, frame = cap.read()
-        try:
-            if ret:
-                result = predictor.ocr(frame)
-                response = {'code': 0, 'data': result}
-                socket.send_pyobj(response)
-            else:
-                socket.send_pyobj({'code': -1, 'data': None})
-        except:
-            socket.send_pyobj({'code': -1, 'data': None})
+        message1 = socket.recv_string()
+        logger.info("recv client request")
+        for _ in range(5):
+            camera_socket.send_string("")
+            message = camera_socket.recv()
+
+        np_array = np.frombuffer(message, dtype=np.uint8)
+        image = cv2.imdecode(np_array, cv2.IMREAD_COLOR)
+        output_file_path = 'output_image.jpg'
+        success = cv2.imwrite(output_file_path, image)
+
+        encoded_image = base64.b64encode(message).decode('utf-8')
+
+        result = ocr_api_request(encoded_image)
+        print(result)
+        if result != None:
+            socket.send_pyobj({'code': 0, 'content': result.get('words_result')})
+        else:
+            socket.send_pyobj({'code': -1, 'content': None})
+
        if cv2.waitKey(1) == 27:
             break
     logger.info("ocr infer server exit")
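
For reference, a minimal client sketch (not part of the diff) that exercises the reworked REQ/REP interface: it sends an empty request string and reads back the {'code', 'content'} reply produced by the server above. The host and the literal port value are assumptions; in practice the port is whatever cfg['server']['ocr_infer_port'] resolves to in the server's config file.

import zmq

# Minimal REQ client sketch for the OCR infer server in this diff.
# Assumption: the server runs on localhost and cfg['server']['ocr_infer_port'] is 5560
# (hypothetical value; read the real port from the same config the server loads).
OCR_INFER_PORT = 5560

context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect(f"tcp://localhost:{OCR_INFER_PORT}")

socket.send_string("")        # any request string triggers one capture + OCR round trip
reply = socket.recv_pyobj()   # {'code': 0, 'content': words_result} or {'code': -1, 'content': None}

if reply['code'] == 0:
    # 'content' holds the words_result list from the Baidu general OCR endpoint;
    # each entry typically carries 'words', 'location' and (with probability=true) 'probability'
    for item in reply['content'] or []:
        print(item.get('words'))
else:
    print("OCR request failed")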