client.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. import socket
  2. import json
  3. import time
  4. from ultralytics import YOLO
  5. import random
  6. import logging
  7. from logging.handlers import TimedRotatingFileHandler
  8. import os
  9. from datetime import datetime
  10. classify_model = YOLO('models/classify/best.pt')
  11. detection_model = YOLO('models/detection/best.pt')
  12. server_address = "127.0.0.1"
  13. server_port = 29989
  14. client_socket = None
  15. def connect_server():
  16. global client_socket
  17. client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  18. if client_socket.connect_ex((server_address, server_port)) == 0:
  19. print("Connection established.")
  20. else:
  21. print("Connection failed.")
  22. response_data = {
  23. "front": {"mist": 0, "slag_truck": -1, "person": -1},
  24. "back": {"mist": 0, "slag_truck": -1, "person": -1},
  25. "left": {"mist": 0, "slag_truck": -1, "person": -1},
  26. "right": {"mist": 0, "slag_truck": -1, "person": -1}
  27. }
  28. def init_response_data():
  29. global response_data
  30. response_data = {
  31. "front": {"mist": 0, "slag_truck": -1, "person": -1},
  32. "back": {"mist": 0, "slag_truck": -1, "person": -1},
  33. "left": {"mist": 0, "slag_truck": -1, "person": -1},
  34. "right": {"mist": 0, "slag_truck": -1, "person": -1}
  35. }
  36. # 判断矩形和梯形是否有重合
  37. def is_overlapping(rectangle, trapezoid):
  38. """
  39. 判断矩形和梯形是否重叠
  40. 参数:
  41. rectangle: 矩形,表示为 (x1, y1, x2, y2),其中 (x1, y1) 是左上角坐标,(x2, y2) 是右下角坐标
  42. trapezoid: 梯形,表示为 (x1_top, y1_top, x2_top, y2_top, x1_bottom, y1_bottom, x2_bottom, y2_bottom),其中 (x1_top, y1_top) 和 (x2_top, y2_top) 是梯形的上底两个端点坐标,(x1_bottom, y1_bottom) 和 (x2_bottom, y2_bottom) 是梯形的下底两个端点坐标
  43. """
  44. rect_x1, rect_y1, rect_x2, rect_y2 = rectangle
  45. trap_x1_top, trap_y1_top, trap_x2_top, trap_y2_top, trap_x1_bottom, trap_y1_bottom, trap_x2_bottom, trap_y2_bottom = trapezoid
  46. horizontal_overlap = (rect_x1 < max(trap_x2_top, trap_x2_bottom) and rect_x2 > min(trap_x1_top, trap_x1_bottom))
  47. vertical_overlap = (rect_y1 > min(trap_y2_top, trap_y2_bottom) and rect_y2 < max(trap_y1_top, trap_y1_bottom))
  48. return horizontal_overlap and vertical_overlap
  49. camera1_trapezoid = ()
  50. camera2_trapezoid = (380,240, 900,240, 0,640, 1280,640)
  51. # classify 0==truck 1==person index=0为1.jpg index=1为2.jpg
  52. def detection_result_process(index, result_list):
  53. for sublist in result_list:
  54. conf = sublist[4]
  55. logging.info("conf = " + str(conf))
  56. if float(conf) > 0.5:
  57. classify = sublist[5]
  58. # print("classify = "+ str(classify))
  59. logging.info("classify = "+ str(classify))
  60. width = sublist[2] - sublist[0]
  61. hight = sublist[3] - sublist[1]
  62. # print("width = " + str(width))
  63. # print("hight = " + str(hight))
  64. logging.info("width = " + str(width))
  65. logging.info("hight = " + str(hight))
  66. rectangle_coords = (sublist[0], sublist[1], sublist[2], sublist[3])
  67. if classify == 0 and index == 0:
  68. dist = round(150 - 0.8*hight, 2)
  69. print(dist)
  70. current_dist = response_data["front"]["slag_truck"]
  71. if current_dist == -1 or dist < float(current_dist) :
  72. response_data["front"]["slag_truck"] = dist
  73. elif classify == 0 and index == 1:
  74. if is_overlapping(rectangle_coords, camera2_trapezoid):
  75. response_data["back"]["slag_truck"] = 10
  76. logging.info("back slag_truck is_overlapping true!")
  77. dist = round(150 - 0.8*hight, 2)
  78. current_dist = response_data["back"]["slag_truck"]
  79. if current_dist == -1 or dist < float(current_dist) :
  80. response_data["back"]["slag_truck"] = dist
  81. elif classify == 1 and index == 0:
  82. dist_width = round(0.01256*width*width - 1.345*width + 41.67)
  83. logging.info("dist_width = " + str(dist_width))
  84. dist_hight = round(0.003061*hight*hight - 0.7771*hight + 54.58)
  85. logging.info("dist_hight = " + str(dist_hight))
  86. dist = dist_width if dist_width<dist_hight else dist_hight
  87. if dist <= 0:
  88. break
  89. current_dist = response_data["front"]["person"]
  90. if current_dist == -1 or dist < current_dist :
  91. response_data["front"]["person"] = dist
  92. elif classify == 1 and index == 1:
  93. if is_overlapping(rectangle_coords, camera2_trapezoid):
  94. logging.info("back person is_overlapping true!")
  95. response_data["back"]["person"] = 10
  96. dist_width = round(0.01256*width*width - 1.345*width + 41.67)
  97. logging.info("dist_width = " + str(dist_width))
  98. dist_hight = round(0.003061*hight*hight - 0.7771*hight + 54.58)
  99. logging.info("dist_hight = " + str(dist_hight))
  100. dist = dist_width if dist_width<dist_hight else dist_hight
  101. if dist <= 0:
  102. break
  103. current_dist = response_data["back"]["person"]
  104. if current_dist == -1 or dist < float(current_dist) :
  105. response_data["back"]["person"] = dist
  106. fold = '/home/nvidia/newdisk/hkpc/'
  107. def start():
  108. while 1:
  109. init_response_data()
  110. # Convert the message to JSON
  111. json_message = json.dumps(response_data)
  112. try:
  113. client_socket.send(json_message.encode('utf-8'))
  114. server_data = client_socket.recv(1024).decode('utf-8')
  115. logging.info(f"server_data: {server_data}")
  116. except BrokenPipeError:
  117. logging.error("Broken pipe error. Reconnecting to the server.")
  118. connect_server()
  119. time.sleep(2)
  120. # classify_results = classify_model([fold+'1.jpg', fold+'2.jpg'])
  121. # for index, result in enumerate(classify_results):
  122. # probs_list = result.probs.data.tolist()
  123. # max_value = max(probs_list)
  124. # max_index = probs_list.index(max_value)
  125. # if index == 0:
  126. # response_data["front"]["mist"] = max_index
  127. # if index == 1:
  128. # response_data["back"]["mist"] = max_index
  129. # detection_results = detection_model([fold+'1.jpg',fold+'2.jpg', fold+'3.jpg',fold+'4.jpg'])
  130. start_time = time.time()
  131. detection_results = detection_model([fold+'2.jpg'])
  132. for index, result in enumerate(detection_results):
  133. print(result.names)
  134. boxes_list = result.boxes.data.tolist()
  135. if len(boxes_list) > 0:
  136. detection_result_process(index, boxes_list)
  137. logging.info(response_data)
  138. end_time = time.time()
  139. elapsed_time = end_time - start_time
  140. print(f"coding run time: {elapsed_time} s")
  141. # time.sleep(0.2)
  142. time.sleep(5)
  143. # client_socket.close()
  144. def setup_logging():
  145. log_dir = "logs"
  146. if not os.path.exists(log_dir):
  147. os.makedirs(log_dir)
  148. # Include the current date in the log file name
  149. current_date = datetime.now().strftime("%Y-%m-%d")
  150. log_filename = os.path.join(log_dir, f"stip_log_{current_date}.log")
  151. # 设置日志格式
  152. log_format = "%(asctime)s - %(levelname)s - %(message)s"
  153. logging.basicConfig(format=log_format, level=logging.INFO)
  154. # 创建 TimedRotatingFileHandler
  155. handler = TimedRotatingFileHandler(
  156. log_filename, when="midnight", interval=1, backupCount=7
  157. )
  158. # 设置日志格式
  159. handler.setFormatter(logging.Formatter(log_format))
  160. # 添加 handler 到 root logger
  161. logging.getLogger().addHandler(handler)
  162. if __name__ == "__main__":
  163. setup_logging()
  164. connect_server()
  165. start()