"""Inference client: run YOLO mist classification and object detection on
images read from ``fold`` and stream per-direction results as JSON over TCP."""
import json
import logging
import os
import random
import socket
import time
from datetime import datetime
from logging.handlers import TimedRotatingFileHandler

from ultralytics import YOLO

classify_model = YOLO('models/classify/best.pt')
detection_model = YOLO('models/detection/best.pt')

server_address = "127.0.0.1"
server_port = 29989
client_socket = None

# Detection boxes whose confidence is not above this value are ignored.
CONF_THRESHOLD = 0.5

# Image index (position in the batch passed to a model) -> direction key in
# response_data. Indices without an entry are logged but never reported.
_INDEX_TO_DIRECTION = {0: "front", 1: "back"}

# Detection class id -> (response_data field, box dimension used, slope).
# Reported distance = round(_DISTANCE_OFFSET - slope * dimension, 2).
# The constants are carried over unchanged; their derivation is not
# documented here.
_CLASS_DISTANCE_MODEL = {
    0: ("slag_truck", "height", 0.8),
    1: ("person", "width", 3.5),
}
_DISTANCE_OFFSET = 150


def connect_server():
    """(Re)connect the module-level ``client_socket`` to the server.

    Any previously created socket is closed first so repeated reconnects do
    not leak file descriptors.

    Returns:
        True if the TCP connection was established, False otherwise. On
        failure ``client_socket`` is left as an unconnected socket; the next
        send on it raises ``OSError``, which ``start`` handles by reconnecting.
    """
    global client_socket
    if client_socket is not None:
        client_socket.close()
    client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    if client_socket.connect_ex((server_address, server_port)) == 0:
        print("Connection established.")
        return True
    print("Connection failed.")
    return False


def _empty_response():
    """Return a fresh result dict: mist class 0 and distance -1 (none seen)."""
    return {
        direction: {"mist": 0, "slag_truck": -1, "person": -1}
        for direction in ("front", "back", "left", "right")
    }


response_data = _empty_response()


def init_response_data():
    """Reset the module-level ``response_data`` to its empty state."""
    global response_data
    response_data = _empty_response()


def detection_result_process(index, result_list, conf_threshold=CONF_THRESHOLD):
    """Record the nearest slag truck / person per direction in ``response_data``.

    Args:
        index: Position of the image in the detection batch; mapped to a
            direction via ``_INDEX_TO_DIRECTION``. Boxes from unmapped indices
            are still printed/logged but do not update ``response_data``.
        result_list: Box rows from ``result.boxes.data.tolist()``, read here as
            ``[x1, y1, x2, y2, conf, class_id, ...]``.
        conf_threshold: Boxes with confidence not above this are skipped.
    """
    direction = _INDEX_TO_DIRECTION.get(index)
    for box in result_list:
        x1, y1, x2, y2, conf, classify = box[:6]
        if conf <= conf_threshold:
            continue
        print("classify = " + str(classify))
        logging.info("classify = " + str(classify))
        width = x2 - x1
        height = y2 - y1
        print("width = " + str(width))
        print("hight = " + str(height))
        logging.info("width = " + str(width))
        logging.info("hight = " + str(height))

        # class ids arrive as floats (0.0, 1.0); they hash/compare equal to ints.
        model = _CLASS_DISTANCE_MODEL.get(classify)
        if direction is None or model is None:
            continue
        field, dimension, slope = model
        size = height if dimension == "height" else width
        dist = round(_DISTANCE_OFFSET - slope * size, 2)
        current = response_data[direction][field]
        # -1 means nothing recorded yet; otherwise keep the smallest distance.
        if current == -1 or dist < current:
            response_data[direction][field] = dist


fold = '/home/nvidia/newdisk/hkpc/'


def _classify_mist():
    """Run the mist classifier and store each mapped image's arg-max class."""
    classify_results = classify_model([fold + '1.jpg', fold + '2.jpg'])
    for index, result in enumerate(classify_results):
        probs_list = result.probs.data.tolist()
        # list.index returns the first position on ties.
        max_index = probs_list.index(max(probs_list))
        direction = _INDEX_TO_DIRECTION.get(index)
        if direction is not None:
            response_data[direction]["mist"] = max_index


def _detect_objects():
    """Run the detector and fold every image's boxes into ``response_data``."""
    # NOTE(review): '4.jpg' appears twice and indices 2/3 are not mapped to a
    # direction, so "left"/"right" are never populated. Possibly '3.jpg' was
    # intended; confirm before changing the batch.
    detection_results = detection_model(
        [fold + '1.jpg', fold + '2.jpg', fold + '4.jpg', fold + '4.jpg']
    )
    for index, result in enumerate(detection_results):
        print(result.names)
        boxes_list = result.boxes.data.tolist()
        if boxes_list:
            detection_result_process(index, boxes_list)


def _send_response(payload):
    """Send ``payload`` to the server and log its reply; reconnect on failure.

    Returns:
        True if the payload was sent and a non-empty reply was received.
    """
    try:
        # sendall: send() may transmit only part of the buffer.
        client_socket.sendall(payload.encode('utf-8'))
        # errors="replace": recv(1024) may split a multi-byte character.
        response = client_socket.recv(1024).decode('utf-8', errors='replace')
    except BrokenPipeError:
        logging.error("Broken pipe error. Reconnecting to the server.")
    except OSError as exc:
        # e.g. ConnectionResetError, or ENOTCONN after a failed connect.
        logging.error("Socket error (%s). Reconnecting to the server.", exc)
    else:
        if response:
            logging.info("Server response: %s", response)
            return True
        # An empty read means the peer closed the connection.
        logging.error("Server closed the connection. Reconnecting to the server.")
    connect_server()
    time.sleep(2)
    return False


def start():
    """Loop forever: reset, classify, detect, log, send as JSON, sleep 0.2 s.

    The socket is closed if the loop exits via an exception
    (e.g. KeyboardInterrupt).
    """
    if client_socket is None:
        connect_server()
    try:
        while True:
            init_response_data()
            start_time = time.perf_counter()
            _classify_mist()
            _detect_objects()
            logging.info(response_data)
            elapsed_time = time.perf_counter() - start_time
            # "Code run time: <elapsed> seconds"
            print(f"代码运行时间: {elapsed_time} 秒")
            _send_response(json.dumps(response_data))
            time.sleep(0.2)
    finally:
        client_socket.close()


def setup_logging():
    """Configure root logging to stderr plus a midnight-rotated file in logs/.

    NOTE(review): the start date is baked into the file name, so after a
    midnight rotation the active file keeps the old date, and backupCount only
    prunes backups of this run's base name; files from earlier runs are never
    deleted. A fixed base name would let TimedRotatingFileHandler manage
    dates and retention itself.
    """
    log_dir = "logs"
    os.makedirs(log_dir, exist_ok=True)

    # Include the current date in the log file name.
    current_date = datetime.now().strftime("%Y-%m-%d")
    log_filename = os.path.join(log_dir, f"stip_log_{current_date}.log")

    log_format = "%(asctime)s - %(levelname)s - %(message)s"
    # Adds a stderr StreamHandler if the root logger has no handlers yet.
    logging.basicConfig(format=log_format, level=logging.INFO)

    handler = TimedRotatingFileHandler(
        log_filename, when="midnight", interval=1, backupCount=7,
        encoding="utf-8",
    )
    handler.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(handler)


if __name__ == "__main__":
    setup_logging()
    connect_server()
    start()