# client_1.15.py (5.5 KB)
  1. import socket
  2. import json
  3. import time
  4. from ultralytics import YOLO
  5. import random
  6. import logging
  7. from logging.handlers import TimedRotatingFileHandler
  8. import os
  9. from datetime import datetime
  10. classify_model = YOLO('models/classify/best.pt')
  11. detection_model = YOLO('models/detection/best.pt')
  12. server_address = "127.0.0.1"
  13. server_port = 29989
  14. client_socket = None
  15. def connect_server():
  16. global client_socket
  17. client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  18. if client_socket.connect_ex((server_address, server_port)) == 0:
  19. print("Connection established.")
  20. else:
  21. print("Connection failed.")
  22. response_data = {
  23. "front": {"mist": 0, "slag_truck": -1, "person": -1},
  24. "back": {"mist": 0, "slag_truck": -1, "person": -1},
  25. "left": {"mist": 0, "slag_truck": -1, "person": -1},
  26. "right": {"mist": 0, "slag_truck": -1, "person": -1}
  27. }
  28. def init_response_data():
  29. global response_data
  30. response_data = {
  31. "front": {"mist": 0, "slag_truck": -1, "person": -1},
  32. "back": {"mist": 0, "slag_truck": -1, "person": -1},
  33. "left": {"mist": 0, "slag_truck": -1, "person": -1},
  34. "right": {"mist": 0, "slag_truck": -1, "person": -1}
  35. }
  36. def detection_result_process(index, result_list):
  37. for sublist in result_list:
  38. conf = sublist[4]
  39. if conf > 0.5:
  40. classify = sublist[5]
  41. print("classify = "+ str(classify))
  42. logging.info("classify = "+ str(classify))
  43. width = sublist[2] - sublist[0]
  44. hight = sublist[3] - sublist[1]
  45. print("width = " + str(width))
  46. print("hight = " + str(hight))
  47. logging.info("width = " + str(width))
  48. logging.info("hight = " + str(hight))
  49. if classify == 0 and index == 0:
  50. dist = round(150 - 0.8*hight, 2)
  51. print(dist)
  52. current_dist = response_data["front"]["slag_truck"]
  53. if current_dist == -1 or dist < float(current_dist) :
  54. response_data["front"]["slag_truck"] = dist
  55. elif classify == 0 and index == 1:
  56. dist = round(150 - 0.8*hight, 2)
  57. current_dist = response_data["back"]["slag_truck"]
  58. if current_dist == -1 or dist < float(current_dist) :
  59. response_data["back"]["slag_truck"] = dist
  60. elif classify == 1 and index == 0:
  61. dist = round(150 - 3.5*width, 2)
  62. current_dist = response_data["front"]["person"]
  63. if current_dist == -1 or dist < current_dist :
  64. response_data["front"]["person"] = dist
  65. elif classify == 1 and index == 1:
  66. dist = round(150 - 3.5*width, 2)
  67. current_dist = response_data["back"]["person"]
  68. if current_dist == -1 or dist < float(current_dist) :
  69. response_data["back"]["person"] = dist
  70. fold = '/home/nvidia/newdisk/hkpc/'
  71. def start():
  72. while 1:
  73. init_response_data()
  74. start_time = time.time()
  75. # random_numbers = random.sample(range(1, 12), 4)
  76. # print(random_numbers)
  77. classify_results = classify_model([fold+'1.jpg', fold+'2.jpg'])
  78. # import pdb; pdb.set_trace()
  79. for index, result in enumerate(classify_results):
  80. # print(result.names)
  81. probs_list = result.probs.data.tolist()
  82. max_value = max(probs_list)
  83. # print(max_value)
  84. max_index = probs_list.index(max_value)
  85. # print(max_index)
  86. if index == 0:
  87. response_data["front"]["mist"] = max_index
  88. if index == 1:
  89. response_data["back"]["mist"] = max_index
  90. detection_results = detection_model([fold+'1.jpg',fold+'2.jpg', fold+'4.jpg',fold+'4.jpg'])
  91. for index, result in enumerate(detection_results):
  92. print(result.names)
  93. boxes_list = result.boxes.data.tolist()
  94. if len(boxes_list) > 0:
  95. detection_result_process(index, boxes_list)
  96. logging.info(response_data)
  97. end_time = time.time()
  98. elapsed_time = end_time - start_time
  99. print(f"代码运行时间: {elapsed_time} 秒")
  100. # Convert the message to JSON
  101. json_message = json.dumps(response_data)
  102. try:
  103. client_socket.send(json_message.encode('utf-8'))
  104. response = client_socket.recv(1024).decode('utf-8')
  105. logging.info(f"Server response: {response}")
  106. except BrokenPipeError:
  107. logging.error("Broken pipe error. Reconnecting to the server.")
  108. connect_server()
  109. time.sleep(2)
  110. time.sleep(0.2)
  111. # break
  112. # Close the connection after the loop
  113. client_socket.close()
  114. def setup_logging():
  115. log_dir = "logs"
  116. if not os.path.exists(log_dir):
  117. os.makedirs(log_dir)
  118. # Include the current date in the log file name
  119. current_date = datetime.now().strftime("%Y-%m-%d")
  120. log_filename = os.path.join(log_dir, f"stip_log_{current_date}.log")
  121. # 设置日志格式
  122. log_format = "%(asctime)s - %(levelname)s - %(message)s"
  123. logging.basicConfig(format=log_format, level=logging.INFO)
  124. # 创建 TimedRotatingFileHandler
  125. handler = TimedRotatingFileHandler(
  126. log_filename, when="midnight", interval=1, backupCount=7
  127. )
  128. # 设置日志格式
  129. handler.setFormatter(logging.Formatter(log_format))
  130. # 添加 handler 到 root logger
  131. logging.getLogger().addHandler(handler)
  132. if __name__ == "__main__":
  133. setup_logging()
  134. connect_server()
  135. start()