123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159 |
- import socket
- import json
- import time
- from ultralytics import YOLO
- import random
- import logging
- from logging.handlers import TimedRotatingFileHandler
- import os
- from datetime import datetime
# TCP endpoint that receives the JSON detection reports.
server_address = "127.0.0.1"
server_port = 29989
# Replaced by connect_server(); None until the first connection attempt.
client_socket = None

# Both YOLO models are loaded once, at module import time.
classify_model = YOLO('models/classify/best.pt')
detection_model = YOLO('models/detection/best.pt')
def connect_server():
    """(Re)connect the module-level ``client_socket`` to the report server.

    Any previously stored socket is closed first, so repeated reconnect
    attempts do not leak file descriptors. A fresh TCP socket is then created
    and connected to ``(server_address, server_port)``.

    On failure the new, unconnected socket is still stored in
    ``client_socket`` (same as before), so later send/recv calls on it fail
    with an OSError instead of an AttributeError on ``None``.

    Returns:
        True if the TCP connection was established, False otherwise.
    """
    global client_socket
    if client_socket is not None:
        # Release the old descriptor; close() on an already-closed socket is a no-op.
        client_socket.close()
    client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # connect_ex returns an errno value instead of raising.
    err = client_socket.connect_ex((server_address, server_port))
    if err == 0:
        print("Connection established.")
        logging.info("Connection established to %s:%s.", server_address, server_port)
        return True
    print("Connection failed.")
    logging.error(
        "Connection to %s:%s failed (errno %s).", server_address, server_port, err
    )
    return False
# Per-camera result that is serialized to JSON and sent to the server.
# "mist" holds a class index; -1 for "slag_truck"/"person" means no distance
# has been recorded yet. init_response_data() rebinds this to a fresh copy.
response_data = {
    side: {"mist": 0, "slag_truck": -1, "person": -1}
    for side in ("front", "back", "left", "right")
}
def init_response_data():
    """Reset the global ``response_data`` to its empty state.

    The global is rebound to a brand-new dict, and each side gets its own inner
    dict, so nothing from the previous iteration is shared or carried over.
    """
    global response_data
    sides = ("front", "back", "left", "right")
    empty = {"mist": 0, "slag_truck": -1, "person": -1}
    response_data = {side: dict(empty) for side in sides}
# Position of an image in the detection batch -> response_data side.
# Only indices 0 and 1 have distance rules; any other index is ignored.
_DETECTION_SIDES = {0: "front", 1: "back"}

# Constant term of the empirical distance formula:
#     dist = _DISTANCE_OFFSET - coefficient * box_dimension
_DISTANCE_OFFSET = 150

# Detected class id -> (response_data key, coefficient, box dimension used).
_DISTANCE_RULES = {
    0: ("slag_truck", 0.8, "height"),
    1: ("person", 3.5, "width"),
}


def detection_result_process(index, result_list, conf_threshold=0.5):
    """Fold one image's detections into the global ``response_data``.

    For every box whose confidence is above ``conf_threshold``, the box size is
    logged. A distance is then estimated as
    ``round(150 - coefficient * dimension, 2)``, where the coefficient and the
    dimension (box height or width) come from the detected class. The smallest
    distance per (side, class) is kept.

    Args:
        index: position of the image in the detection batch; 0 -> "front",
            1 -> "back". Boxes from other indices are logged but not recorded.
        result_list: rows indexed as ``[x1, y1, x2, y2, conf, cls]``
            (presumably Ultralytics ``boxes.data`` rows in xyxy form — confirm).
        conf_threshold: boxes with confidence not strictly greater than this
            are skipped (default 0.5, as before).
    """
    side = _DETECTION_SIDES.get(index)
    for box in result_list:
        conf = box[4]
        # Written as `not >` so a NaN confidence is skipped, as before.
        if not conf > conf_threshold:
            continue
        classify = box[5]
        print("classify = " + str(classify))
        logging.info("classify = " + str(classify))
        width = box[2] - box[0]
        height = box[3] - box[1]
        print("width = " + str(width))
        print("hight = " + str(height))
        logging.info("width = " + str(width))
        logging.info("hight = " + str(height))

        # Class ids may arrive as floats (e.g. 0.0); they hash equal to the int keys.
        rule = _DISTANCE_RULES.get(classify)
        if side is None or rule is None:
            continue
        key, coefficient, dimension = rule
        size = height if dimension == "height" else width
        dist = round(_DISTANCE_OFFSET - coefficient * size, 2)
        print("dist = " + str(dist))

        current_dist = response_data[side][key]
        # -1 marks "nothing recorded yet"; otherwise keep the nearest object.
        if current_dist == -1 or dist < float(current_dist):
            response_data[side][key] = dist
fold = '/home/nvidia/newdisk/hkpc/'


def _classify_mist():
    """Classify the front/back images and store each top-1 class index as "mist"."""
    images = [fold + '1.jpg', fold + '2.jpg']
    results = classify_model(images)
    # zip stops after the sides list, like the original index == 0 / 1 checks.
    for side, result in zip(("front", "back"), results):
        probs = result.probs.data.tolist()
        response_data[side]["mist"] = probs.index(max(probs))


def _run_detection():
    """Run the detector on the camera images and merge its boxes into response_data."""
    # NOTE(review): '4.jpg' is passed twice (batch indices 2 and 3); '3.jpg'
    # may have been intended. detection_result_process only records indices
    # 0 and 1, so these two inferences do not affect the report — confirm.
    images = [fold + '1.jpg', fold + '2.jpg', fold + '4.jpg', fold + '4.jpg']
    for index, result in enumerate(detection_model(images)):
        print(result.names)
        boxes_list = result.boxes.data.tolist()
        if boxes_list:
            detection_result_process(index, boxes_list)


def _send_response(message):
    """Send one JSON message and log the reply; reconnect on any connection problem.

    Returns:
        True if the message was sent and a non-empty reply was read,
        False if a reconnect was triggered.
    """
    try:
        # send() may transmit only part of the buffer; sendall() retries until done.
        client_socket.sendall(message.encode('utf-8'))
        reply = client_socket.recv(1024)
    except OSError as err:
        # Covers BrokenPipeError, ConnectionResetError, ENOTCONN, etc.
        logging.error("Connection error (%s). Reconnecting to the server.", err)
    else:
        if reply:
            # errors='replace': a 1024-byte read may split a multi-byte character.
            logging.info("Server response: %s", reply.decode('utf-8', errors='replace'))
            return True
        # recv() returning b'' means the server closed the connection.
        logging.error("Server closed the connection. Reconnecting to the server.")
    connect_server()
    time.sleep(2)
    return False


def start():
    """Run the detect-and-report loop until the process is interrupted.

    Each iteration:
    1. resets response_data;
    2. classifies mist on the front/back images;
    3. runs object detection;
    4. logs the result and the processing time;
    5. sends the result as JSON to the server;
    6. sleeps 0.2 s.

    The client socket is closed when the loop exits for any reason
    (e.g. KeyboardInterrupt).
    """
    try:
        while True:
            init_response_data()
            start_time = time.perf_counter()
            _classify_mist()
            _run_detection()
            logging.info(response_data)
            elapsed_time = time.perf_counter() - start_time
            print(f"代码运行时间: {elapsed_time} 秒")
            _send_response(json.dumps(response_data))
            time.sleep(0.2)
    finally:
        if client_socket is not None:
            client_socket.close()
def setup_logging(log_dir="logs", backup_count=7):
    """Configure root logging: console output plus a midnight-rotating log file.

    The log file is ``<log_dir>/stip_log_<YYYY-MM-DD>.log``, using the date on
    which this function runs. Calling the function again for the same file
    does not attach a second file handler.

    Args:
        log_dir: directory for log files; created (with parents) if missing.
        backup_count: number of rotated files the handler keeps (default 7).
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(log_dir, exist_ok=True)
    current_date = datetime.now().strftime("%Y-%m-%d")
    log_filename = os.path.join(log_dir, f"stip_log_{current_date}.log")
    # NOTE(review): TimedRotatingFileHandler appends its own date suffix on
    # rotation and prunes backups only for this base name, so files created by
    # earlier runs (with other start dates) are not removed by backup_count.

    log_format = "%(asctime)s - %(levelname)s - %(message)s"
    # Console handler + INFO level; basicConfig is a no-op if root already has handlers.
    logging.basicConfig(format=log_format, level=logging.INFO)

    root = logging.getLogger()
    target = os.path.abspath(log_filename)
    if any(getattr(h, "baseFilename", None) == target for h in root.handlers):
        # Already configured for this file; a second handler would duplicate every line.
        return

    handler = TimedRotatingFileHandler(
        log_filename,
        when="midnight",
        interval=1,
        backupCount=backup_count,
        encoding="utf-8",
    )
    handler.setFormatter(logging.Formatter(log_format))
    root.addHandler(handler)
def _main():
    """Script entry point: set up logging, connect to the server, then run the loop."""
    setup_logging()
    connect_server()
    start()


if __name__ == "__main__":
    _main()
|