123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196 |
- import socket
- import json
- import time
- from ultralytics import YOLO
- import random
- import logging
- from logging.handlers import TimedRotatingFileHandler
- import os
- from datetime import datetime
# YOLO models are loaded once, when the module is imported.
classify_model = YOLO('models/classify/best.pt')
detection_model = YOLO('models/detection/best.pt')

# TCP endpoint that receives the JSON detection summary.
server_address, server_port = "127.0.0.1", 29989

# Socket created by connect_server(); None until the first connection attempt.
client_socket = None
def connect_server():
    """(Re)connect the module-level ``client_socket`` to the server.

    Any socket created by a previous call is closed first, so calling this
    repeatedly (e.g. to reconnect after a send failure) does not leak file
    descriptors. On failure ``client_socket`` is left as a new, unconnected
    socket, as before.

    Returns:
        True if the TCP connection was established, False otherwise.
    """
    global client_socket
    if client_socket is not None:
        client_socket.close()
    client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    err = client_socket.connect_ex((server_address, server_port))
    if err == 0:
        logging.info("Connection established.")
        return True
    logging.error(f"Connection failed. (errno {err})")
    return False
def _empty_response_data():
    """Return a new, fully reset response payload.

    Each camera side gets ``mist`` = 0 and ``-1`` for ``slag_truck`` and
    ``person``; -1 is the "no distance recorded yet" sentinel that the
    distance-update code compares against. A new dict (with new inner dicts)
    is built on every call, so resets never share state.
    """
    return {
        side: {"mist": 0, "slag_truck": -1, "person": -1}
        for side in ("front", "back", "left", "right")
    }


# Payload sent to the server; rebuilt at the start of every processing cycle.
response_data = _empty_response_data()


def init_response_data():
    """Reset the module-level ``response_data`` to its initial state."""
    global response_data
    response_data = _empty_response_data()
# Check whether a rectangle and a trapezoid overlap.
def is_overlapping(rectangle, trapezoid):
    """Return True if an axis-aligned rectangle and a trapezoid overlap.

    Parameters:
        rectangle: (x1, y1, x2, y2), where (x1, y1) is the top-left corner and
            (x2, y2) the bottom-right corner. Swapped corners are normalised.
        trapezoid: (x1_top, y1_top, x2_top, y2_top,
                    x1_bottom, y1_bottom, x2_bottom, y2_bottom), where
            (x1_top, y1_top) and (x2_top, y2_top) are the endpoints of the top
            base and (x1_bottom, y1_bottom) and (x2_bottom, y2_bottom) the
            endpoints of the bottom base. An empty tuple means "no zone
            configured" and never overlaps.

    Assumes x1_* is the left endpoint of each base, so the four points form a
    convex, non-self-intersecting quadrilateral. For two convex shapes the
    separating axis theorem is exact: they are disjoint iff their projections
    are disjoint on one of the candidate axes (the rectangle's two edge
    directions plus the normal of every trapezoid edge). Shapes that only
    touch along an edge or at a corner are not considered overlapping.
    """
    if not trapezoid:
        return False

    rect_x1, rect_y1, rect_x2, rect_y2 = rectangle
    (trap_x1_top, trap_y1_top, trap_x2_top, trap_y2_top,
     trap_x1_bottom, trap_y1_bottom, trap_x2_bottom, trap_y2_bottom) = trapezoid

    left, right = sorted((rect_x1, rect_x2))
    top, bottom = sorted((rect_y1, rect_y2))
    rect_pts = ((left, top), (right, top), (right, bottom), (left, bottom))
    # Walk the trapezoid perimeter in order: top-left, top-right,
    # bottom-right, bottom-left.
    trap_pts = (
        (trap_x1_top, trap_y1_top),
        (trap_x2_top, trap_y2_top),
        (trap_x2_bottom, trap_y2_bottom),
        (trap_x1_bottom, trap_y1_bottom),
    )

    # Rectangle edge directions, then the normal of each trapezoid edge.
    axes = [(1, 0), (0, 1)]
    for (ax, ay), (bx, by) in zip(trap_pts, trap_pts[1:] + trap_pts[:1]):
        edge_x, edge_y = bx - ax, by - ay
        if edge_x or edge_y:  # skip degenerate (zero-length) edges
            axes.append((-edge_y, edge_x))

    for nx, ny in axes:
        rect_proj = [nx * x + ny * y for x, y in rect_pts]
        trap_proj = [nx * x + ny * y for x, y in trap_pts]
        if max(rect_proj) <= min(trap_proj) or max(trap_proj) <= min(rect_proj):
            return False  # found a separating axis
    return True
# Trapezoid regions in is_overlapping()'s flat format:
# (x1_top, y1_top, x2_top, y2_top, x1_bottom, y1_bottom, x2_bottom, y2_bottom).
# Camera 1 has no region configured (empty tuple).
camera1_trapezoid = tuple()
camera2_trapezoid = (
    380, 240, 900, 240,   # top base
    0, 640, 1280, 640,    # bottom base
)
# Class ids: 0 == slag truck, 1 == person.
# Image index: 0 is 1.jpg (written to "front"), 1 is 2.jpg (written to "back").
def _truck_distance(hight):
    """Estimate slag-truck distance from the box height (empirical linear fit)."""
    return round(150 - 0.8 * hight, 2)


def _person_distance(width, hight):
    """Estimate person distance from the box size.

    Two empirical quadratic fits (one on width, one on height) are evaluated
    and the smaller (nearer) estimate is used. Returns None when that estimate
    is not positive, which is treated as an invalid reading.
    """
    dist_width = round(0.01256 * width * width - 1.345 * width + 41.67)
    logging.info("dist_width = " + str(dist_width))
    dist_hight = round(0.003061 * hight * hight - 0.7771 * hight + 54.58)
    logging.info("dist_hight = " + str(dist_hight))
    dist = min(dist_width, dist_hight)
    return dist if dist > 0 else None


def _record_nearest(side, target, dist):
    """Store *dist* in response_data[side][target] if it is the nearest so far.

    -1 in the slot means nothing has been recorded yet.
    """
    current_dist = response_data[side][target]
    if current_dist == -1 or dist < float(current_dist):
        response_data[side][target] = dist


def detection_result_process(index, result_list):
    """Fold one image's detection boxes into the module-level ``response_data``.

    Args:
        index: Position of the image in the model batch. 0 updates the
            "front" entry and 1 the "back" entry; boxes for any other index
            are logged and otherwise ignored.
        result_list: Rows read as ``[x1, y1, x2, y2, conf, class_id]`` (the
            rows of ``result.boxes.data.tolist()``).

    Only boxes with confidence > 0.5 are used. For each side and class, the
    nearest (smallest) distance across all boxes is kept. On the back camera,
    a box overlapping ``camera2_trapezoid`` reports at most 10.
    """
    side = {0: "front", 1: "back"}.get(index)
    for sublist in result_list:
        conf = sublist[4]
        logging.info("conf = " + str(conf))
        if not float(conf) > 0.5:
            continue
        classify = sublist[5]
        logging.info("classify = " + str(classify))
        width = sublist[2] - sublist[0]
        hight = sublist[3] - sublist[1]
        logging.info("width = " + str(width))
        logging.info("hight = " + str(hight))

        if classify == 0:
            target = "slag_truck"
        elif classify == 1:
            target = "person"
        else:
            continue
        if side is None:
            continue

        # Only the back camera has a configured zone.
        in_zone = side == "back" and is_overlapping(
            (sublist[0], sublist[1], sublist[2], sublist[3]), camera2_trapezoid)
        if in_zone:
            logging.info(f"back {target} is_overlapping true!")

        if target == "slag_truck":
            dist = _truck_distance(hight)
        else:
            dist = _person_distance(width, hight)
        if in_zone:
            # A box inside the zone reports at most 10.
            dist = 10 if dist is None else min(dist, 10)
        if dist is None:
            # Invalid reading: skip this box only. The old `break` silently
            # dropped every remaining box, trucks included.
            continue
        logging.info(f"{side} {target} dist = {dist}")
        # Keep the minimum across boxes. Previously an in-zone box overwrote
        # a nearer value that had already been recorded.
        _record_nearest(side, target, dist)
# Directory the camera frames (1.jpg, 2.jpg, ...) are read from.
fold = '/home/nvidia/newdisk/hkpc/'


def _reconnect(reason):
    """Log *reason*, re-open the server connection, and back off for 2 s."""
    logging.error(reason)
    connect_server()
    time.sleep(2)


def _run_detection():
    """Run the detector on the current frame(s) and fold boxes into response_data."""
    start_time = time.time()
    # classify_results = classify_model([fold+'1.jpg', fold+'2.jpg'])
    # for index, result in enumerate(classify_results):
    #     probs_list = result.probs.data.tolist()
    #     max_value = max(probs_list)
    #     max_index = probs_list.index(max_value)
    #     if index == 0:
    #         response_data["front"]["mist"] = max_index
    #     if index == 1:
    #         response_data["back"]["mist"] = max_index
    # detection_results = detection_model([fold+'1.jpg',fold+'2.jpg', fold+'3.jpg',fold+'4.jpg'])

    # NOTE(review): only 2.jpg is processed, so its boxes are passed as
    # index 0 (the "front" slot of response_data). Confirm this is intended.
    detection_results = detection_model([fold + '2.jpg'])
    for index, result in enumerate(detection_results):
        logging.debug(f"class names: {result.names}")
        boxes_list = result.boxes.data.tolist()
        if boxes_list:
            detection_result_process(index, boxes_list)
    logging.info(response_data)
    elapsed_time = time.time() - start_time
    logging.info(f"coding run time: {elapsed_time} s")


def _send_response():
    """Send response_data as JSON, log the server's reply, and reconnect on failure."""
    payload = json.dumps(response_data).encode('utf-8')
    try:
        # sendall: plain send() may write only part of the payload.
        client_socket.sendall(payload)
        reply = client_socket.recv(1024)
    except BrokenPipeError:
        _reconnect("Broken pipe error. Reconnecting to the server.")
        return
    except OSError as err:
        # Reset connections, unconnected sockets, etc. previously crashed the loop.
        _reconnect(f"Socket error: {err}. Reconnecting to the server.")
        return
    if not reply:
        # An empty recv means the server closed the connection.
        _reconnect("Server closed the connection. Reconnecting to the server.")
        return
    logging.info(f"server_data: {reply.decode('utf-8', errors='replace')}")


def start():
    """Run the detect-and-report loop forever (one cycle every ~5 s).

    Each cycle resets response_data, fills it from the detector, and only
    then sends it. The old order sent the freshly reset payload before
    detection ran, and the next reset discarded the results, so detections
    never reached the server.
    """
    while True:
        init_response_data()
        _run_detection()
        _send_response()
        time.sleep(5)
def setup_logging(log_dir="logs"):
    """Configure root logging: console output plus a daily-rotated log file.

    Args:
        log_dir: Directory for the log files, created if missing. Defaults to
            "logs" relative to the current working directory.

    Calling this more than once does not attach a second file handler for
    the same file.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(log_dir, exist_ok=True)
    # Include the current date in the log file name.
    current_date = datetime.now().strftime("%Y-%m-%d")
    log_filename = os.path.join(log_dir, f"stip_log_{current_date}.log")
    # NOTE(review): the date in the base name is fixed at start-up. After a
    # midnight rollover the handler keeps writing to this name, rotated
    # backups get a second date suffix, and backupCount only prunes backups
    # of this exact base name (not those from earlier runs). Consider a fixed
    # base name if log retention matters.
    log_format = "%(asctime)s - %(levelname)s - %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)

    root = logging.getLogger()
    target = os.path.abspath(log_filename)
    if any(isinstance(h, TimedRotatingFileHandler) and h.baseFilename == target
           for h in root.handlers):
        return  # already configured for this file

    handler = TimedRotatingFileHandler(
        log_filename, when="midnight", interval=1, backupCount=7,
        encoding="utf-8",  # do not depend on the platform's default encoding
    )
    handler.setFormatter(logging.Formatter(log_format))
    root.addHandler(handler)
def main():
    """Script entry point: set up logging, connect to the server, run the loop."""
    setup_logging()
    connect_server()
    start()


if __name__ == "__main__":
    main()
|