from hub import methods, Global

import cv2
import base64


class Task(object):
    """Background tasks for face-feature extraction and face-log search.

    Each task reads its input from Redis under ``call_id``, overwrites that
    key with its result (expiring after 600 s) and returns a status dict
    ``{'code': int, 'details': str}`` where ``code == 0`` means success.
    """

    # Directory where raw / masked face images are written when ``save_is`` is set.
    _SAVE_DIR = '/home/server/resources/vms-files'

    @classmethod
    def _save_image(cls, frame, suffix):
        """Write ``frame`` as ``<_SAVE_DIR>/<timestamp>-<suffix>.jpg`` and return its path.

        Returns '' when ``cv2.imwrite`` reports failure, so callers never store
        a path to a file that was not written.
        """
        save_at = methods.now_string('%Y-%m%d-%H%M%S-%f')
        file_path = f"{cls._SAVE_DIR}/{save_at}-{suffix}.jpg"
        return file_path if cv2.imwrite(file_path, frame) else ''

    @staticmethod
    def _frame_to_data_uri(frame):
        """Encode an image array as a ``data:image/jpeg;base64,...`` URI.

        Returns '' when ``frame`` is None or JPEG encoding fails.
        """
        if frame is None:
            return ''
        ok, buffer = cv2.imencode('.jpg', frame)
        if not ok:
            return ''
        return f'data:image/jpeg;base64,{base64.b64encode(buffer).decode()}'

    @classmethod
    def _file_to_data_uri(cls, file_path):
        """Load an image file and return it as a JPEG data URI ('' if missing/unreadable)."""
        if not file_path or not methods.is_file(file_path):
            return ''
        return cls._frame_to_data_uri(cv2.imread(file_path))

    @staticmethod
    def _similarity_to_face(face_features, face):
        """Compare ``face_features`` with the first stored feature vector of a Face record.

        Returns None when the record has no usable stored features, so the
        caller can skip it instead of aborting the whole search.
        """
        feature_info_list = face.get('face_feature_info_list')
        if not feature_info_list:
            return None
        raw_features = feature_info_list[0].get('face_features')
        if raw_features is None:
            return None
        # NOTE(review): this unpickles data read from the database; it is only
        # safe while the Face collection is written exclusively by trusted code.
        stored_features = methods.pickle_loads(raw_features)
        if stored_features is None:
            return None
        return Global.arcface_agent.compare_faces_by_normalization(face_features, stored_features)

    @classmethod
    def task002(cls, **params):
        """Extract face features for the image stored in Redis (feature-extraction endpoint).

        Params:
            call_id: Redis key holding the image bytes; overwritten with the result.
            save_is: when truthy, the raw image and the masked face crop are saved to disk.

        On success the Redis value becomes the tuple
        ``(face_features, probability, raw_image_path,
           masked_face_features, probability, masked_image_path)``
        with a 600 s expiry; the paths are '' when saving is disabled or fails.

        Returns:
            dict with code=0 on success, code=1 when no face is detected,
            code=2 when feature extraction fails, code=-1 (details = traceback)
            on any exception.
        """
        run_at = methods.now_ts()
        try:
            call_id = params.get('call_id')
            save_is = params.get('save_is')
            rdb = Global.get_redis_client()
            image_bytes = rdb.get_one(key=call_id)

            # --- detect and crop faces ---
            image_array = Global.scrfd_agent.image_bytes_to_image_array(image_bytes)
            inference_result = Global.scrfd_agent.inference_with_image_array(image_array)

            # --- optionally keep the raw upload (saved even when no face is found) ---
            if save_is:
                frame = Global.arcface_agent.image_bytes_to_image_array(image_bytes)
                face_file_path = cls._save_image(frame, 'raw')
            else:
                face_file_path = ''

            if not inference_result:
                return dict(code=1, details="something is wrong.")

            # --- aligned crop of the first detected face ---
            # TODO: alignment can make the face image noticeably blurrier.
            face_image = inference_result[0].get('align_face')
            probability = 0.0  # always 0.0 here; no detection score is read from inference_result

            # --- extract face features ---
            face_features = Global.arcface_agent.get_face_features_normalization_by_image_array(face_image)
            if face_features is None:
                return dict(code=2, details="something is wrong.")

            # --- add a masked-face variant (image + features) to the gallery ---
            masked_face_image = Global.face_to_mask_face.add_mask_one(face_image)
            face_features_2 = None
            if masked_face_image is not None:
                face_features_2 = Global.arcface_agent.get_face_features_normalization_by_image_array(
                    masked_face_image)
            if face_features_2 is None:
                # Masking or masked-feature extraction failed: fall back to the unmasked face.
                masked_face_image = face_image
                face_features_2 = face_features

            face_file_path_2 = cls._save_image(masked_face_image, 'masked') if save_is else ''

            # --- publish the result under the same call_id ---
            rdb.set_one(key=call_id, data=(face_features, probability, face_file_path,
                                           face_features_2, probability, face_file_path_2), expire_time=600)
            return dict(code=0, details=f"congratulations. | use time {methods.now_ts() - run_at}s")

        except Exception as exception:
            trace = methods.trace_log()
            methods.debug_log('Task.task002', f"m-76: exception | {exception}")
            methods.debug_log('Task.task002', f"m-76: traceback | {trace}")
            return dict(code=-1, details=f"{trace}")

    @classmethod
    def task003(cls, **params):
        """Search the FaceLog visit log for entries matching a query face (serves endpoint 8301).

        The Redis value under ``call_id`` is read as a dict with keys
        ``image_bytes``, ``start_date_at`` / ``end_date_at`` (format
        '%Y-%m-%d %H:%M:%S'), ``face_similarity`` (minimum similarity to keep
        a row), ``page`` (1-based) and ``size``.

        A FaceLog row matches when its ``base_face_uuid`` refers to a known Face
        whose first stored feature vector has similarity >= ``face_similarity``
        with the query face. On success the Redis value becomes
        ``(rows, total_matches)`` for the requested page, with a 600 s expiry.

        Returns:
            dict with code=0 on success, code=1 when no face is detected in the
            query image, code=2 when feature extraction fails, code=-1
            (details = traceback) on any exception.
        """
        run_at = methods.now_ts()
        try:
            call_id = params.get('call_id')
            rdb = Global.get_redis_client()
            input_data = rdb.get_one(key=call_id)
            image_bytes = input_data.get('image_bytes')

            # --- detect the query face and extract its features ---
            image_array = Global.scrfd_agent.image_bytes_to_image_array(image_bytes)
            inference_result = Global.scrfd_agent.inference_with_image_array(image_array)
            if not inference_result:
                return dict(code=1, details="something is wrong.")
            face_image = inference_result[0].get('align_face')
            face_features = Global.arcface_agent.get_face_features_normalization_by_image_array(face_image)
            if face_features is None:
                return dict(code=2, details="something is wrong.")

            # --- lookup tables: {type_uuid: name} and {face_uuid: face record} ---
            face_type_name_dict = {str(item.get('_id')): item.get('name')
                                   for item in Global.mdb.get_all('FaceType')}
            faces_by_uuid = {str(item.get('_id')): item.copy()
                             for item in Global.mdb.get_all('Face')}

            # --- collect matching FaceLog rows ---
            # FaceLog fields used: base_face_uuid, face_image_path (snapshot), time.
            threshold = input_data.get('face_similarity')
            start_ts = methods.string_to_ts(f"{input_data.get('start_date_at')}", pattern='%Y-%m-%d %H:%M:%S')
            end_ts = methods.string_to_ts(f"{input_data.get('end_date_at')}", pattern='%Y-%m-%d %H:%M:%S') + 1

            # Many log rows share a base face, so its similarity is computed once and cached.
            similarity_by_face = {}
            matches = []  # (log_item, face_record, similarity)
            for item in Global.xdb.filter_by_time_range('FaceLog', start_ts, end_ts):
                base_face_uuid = item.get('base_face_uuid')
                if base_face_uuid not in faces_by_uuid:
                    continue
                face = faces_by_uuid[base_face_uuid]
                if base_face_uuid not in similarity_by_face:
                    similarity_by_face[base_face_uuid] = cls._similarity_to_face(face_features, face)
                dist = similarity_by_face[base_face_uuid]
                if dist is None or dist < threshold:
                    continue
                matches.append((item, face, dist))

            # --- build the requested page ---
            page = input_data.get('page')
            size = input_data.get('size')
            page_matches = matches[(page - 1) * size: page * size]

            # The query image is reported as the comparison photo on every row; encode it once.
            query_image_b64 = ''
            if page_matches:
                query_image_b64 = cls._frame_to_data_uri(
                    Global.arcface_agent.image_bytes_to_image_array(image_bytes))

            rows = []
            for item, face, dist in page_matches:
                face_type_name_list = [face_type_name_dict.get(i)
                                       for i in face.get('face_type_uuid_list') or []
                                       if face_type_name_dict.get(i)]
                rows.append({
                    'face_uuid': item.get('base_face_uuid'),
                    'face_name': face.get('face_name'),
                    'record_at': methods.string_to_ts(item.get('time'), pattern='%Y-%m-%dT%H:%M:%S.%fZ'),
                    'snap_face_image_b64': cls._file_to_data_uri(item.get('face_image_path')),  # snapshot photo
                    'base_face_image_b64': query_image_b64,  # comparison photo
                    'face_type_name_list': face_type_name_list,  # category names
                    'face_similarity': dist,  # similarity
                })

            # --- publish the result under the same call_id ---
            rdb.set_one(key=call_id, data=(rows, len(matches)), expire_time=600)
            return dict(code=0, details=f"congratulations. | use time {methods.now_ts() - run_at}s")

        except Exception as exception:
            trace = methods.trace_log()
            methods.debug_log('Task.task003', f"m-76: exception | {exception}")
            methods.debug_log('Task.task003', f"m-76: traceback | {trace}")
            return dict(code=-1, details=f"{trace}")