from hub import methods, Global, numpy_method, camera_handle
from factories.line_manage import LineManage

import base64
import cv2
import os


class TEST(object):

    @classmethod
    def test002(cls, **params):
        """Append features of each labelled face's "-fake" image to its feature list.

        For every ``Face`` document whose ``face_name`` is not None, the image path
        of its first ``face_feature_info_list`` entry is turned into
        ``<path without extension>-fake.jpg``. When that file exists, the face is
        detected with SCRFD, normalized ArcFace features are extracted from the
        aligned crop (``align_face``) and a new entry is appended to the
        document's ``face_feature_info_list``.

        Documents without feature entries are skipped. So are fake images that
        are missing, unreadable or contain no detected face, and fake images that
        are already present in the list. Re-running the task therefore does not
        append duplicates.

        Returns:
            dict: ``{'code': 0, 'details': 'end.'}`` when the loop finishes, or
            ``{'code': -1, 'details': <traceback>}`` if an exception escapes.
        """
        try:

            # --- define ---
            run_at = methods.now_ts()
            condition = {
                'face_name': {'$ne': None},
            }
            total = Global.mdb.get_count('Face', condition)
            items = Global.mdb.filter('Face', condition)
            count = 0  # documents visited
            update_count = 0  # documents actually updated

            for item in items:
                # --- log ---
                count += 1
                methods.debug_log('TEST.test002', f"m-24: {count}/{total}")

                # --- check ---
                face_feature_info_list = item.get('face_feature_info_list') or []
                if not face_feature_info_list:
                    continue
                base_image_path = face_feature_info_list[0].get('face_image_path')
                if not base_image_path:
                    continue
                # splitext (instead of slicing off 4 chars) also handles ".jpeg", ".png", ...
                face_image_path = f"{os.path.splitext(base_image_path)[0]}-fake.jpg"
                if not os.path.isfile(face_image_path):
                    continue
                # An earlier run already appended this image: do not add it twice.
                if any(info.get('face_image_path') == face_image_path for info in face_feature_info_list):
                    continue

                # --- extract features ---
                image_array = cv2.imread(face_image_path)
                if image_array is None:  # cv2.imread returns None instead of raising
                    methods.debug_log('TEST.test002', f"m-38: image unreadable! | {face_image_path}")
                    continue
                result = Global.scrfd_agent.inference_with_image_array(image_array)
                if result is None or len(result) == 0:
                    methods.debug_log('TEST.test002', f"m-39: no face detected! | {face_image_path}")
                    continue
                face_image = result[0].get('align_face')
                face_features = Global.arcface_agent.get_face_features_normalization_by_image_array(face_image)

                # --- check null ---
                if face_features is None:
                    methods.debug_log('TEST.test002', f"m-46: face_features is None! | {face_image_path}")
                    continue

                # --- update ---
                face_uuid = str(item.get('_id'))
                face_feature_info_list.append(
                    {
                        'face_object_confidence': 1.0,
                        'face_features': methods.pickle_dumps(face_features),
                        'face_image_path': face_image_path,
                        'detector_name': 'scrfd',
                        'mask_on_face_is': False,
                    }
                )
                Global.mdb.update_one_by_id('Face', face_uuid, {'face_feature_info_list': face_feature_info_list})
                update_count += 1
                methods.debug_log('TEST.test002', f"m-62: face_image_path -> {face_image_path} | {face_uuid}")

            # --- log ---
            methods.debug_log('TEST.test002', f"m-32: time use {round(methods.now_ts() - run_at, 2)}s, "
                                              f"visited {count}, update count is {update_count}")

            return dict(code=0, details=f"end.")

        except Exception as exception:
            methods.debug_log('TEST.test002', f"m-84: exception | {exception}")
            methods.debug_log('TEST.test002', f"m-84: traceback | {methods.trace_log()}")
            return dict(code=-1, details=f"{methods.trace_log()}")

    @classmethod
    def test001(cls, **params):
        """Bulk-import worker face features into the visitor dataset.

        Walks every row of MariaDB ``worker_info`` that has an ``image_path``.
        For each row it writes, via ``update_one`` on ``Global.dataset_name``
        keyed by ``prc_id = worker_id``:

        * the worker name;
        * the public image URL;
        * the pickled face features;
        * the merged project-name / worker-type tags.

        Face features already stored on the Mongo document are reused as-is.
        Otherwise the image is downloaded through the yzw client and decoded, and
        normalized ArcFace features are extracted from the first SCRFD aligned
        face. Rows whose image cannot be downloaded or decoded, or which contain
        no face, are skipped.

        NOTE(review): the original called an undefined ``DetectFaceEngine`` here
        (a NameError at runtime). This version uses the SCRFD + ArcFace pipeline
        used elsewhere in this class. Confirm the feature space matches what
        readers of ``Global.dataset_name`` expect.

        Returns:
            dict: ``{'code': 0, 'details': 'end.'}`` on completion, or
            ``{'code': -1, 'details': <traceback>}`` if an exception escapes.
        """
        try:
            import numpy as np  # local import: only needed to decode downloaded image bytes

            # --- define ---
            api = Global.get_yzw_client()
            db1 = Global.get_mariadb_client()
            db2 = Global.get_mongodb_client()
            run_at = methods.now_ts()

            # --- get mariadb.worker_info ---
            count = 0
            items = db1.get_all('worker_info')
            for item in items:

                # --- debug log ---
                count += 1
                if count % 1000 == 0:
                    methods.debug_log('TEST.test001', f"count: {count}")

                # --- check null ---
                if not item.image_path:
                    continue

                # --- merge tags ---
                # get_one may return None for a worker that is not in the dataset yet.
                data = db2.get_one(Global.dataset_name, {'prc_id': item.worker_id}) or {}
                project_name_tags = data.get('project_name_tags') or list()
                worker_type_tags = data.get('worker_type_tags') or list()
                if item.project_name and item.project_name not in project_name_tags:
                    project_name_tags.append(item.project_name)
                if item.worker_type and item.worker_type not in worker_type_tags:
                    worker_type_tags.append(item.worker_type)

                # --- face features ---
                # The stored value is already methods.pickle_dumps output; reuse it as-is
                # instead of pickling it again (which would add one pickle layer per run).
                face_features_blob = data.get('face_features')
                if not face_features_blob:
                    image_bytes = api.get_image(item.image_path)
                    if not image_bytes:
                        continue

                    # assumes get_image returns encoded image file bytes (jpg/png) — TODO confirm
                    image_array = cv2.imdecode(np.frombuffer(image_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
                    if image_array is None:
                        continue
                    result = Global.scrfd_agent.inference_with_image_array(image_array)
                    if result is None or len(result) == 0:
                        continue
                    face_features = Global.arcface_agent.get_face_features_normalization_by_image_array(
                        result[0].get('align_face'))
                    # `is None` rather than truthiness: truth value of a numpy array is ambiguous
                    if face_features is None:
                        continue
                    face_features_blob = methods.pickle_dumps(face_features)

                # --- set mongodb.VisitorInfo ---
                unique_dict = {
                    'prc_id': item.worker_id,
                }
                update_dict = {
                    'face_name': item.worker_name,
                    'face_path': f"https://lwres.yzw.cn/{item.image_path}",
                    'face_features': face_features_blob,
                    'project_name_tags': project_name_tags,
                    'worker_type_tags': worker_type_tags,
                }
                db2.update_one(Global.dataset_name, unique_dict, update_dict)

            methods.debug_log('TEST.test001', f"t1: {methods.now_ts() - run_at}s")
            return dict(code=0, details=f"end.")

        except Exception as exception:
            methods.debug_log('TEST.test001', f"m-84: exception | {exception}")
            methods.debug_log('TEST.test001', f"m-84: traceback | {methods.trace_log()}")
            return dict(code=-1, details=f"{methods.trace_log()}")

    @classmethod
    def test004(cls, **params):
        """Rebuild the tag rows of the ``VisitorTags`` collection.

        Scans every document of ``Global.dataset_name`` and collects each distinct
        ``project_name_tags`` value (as tag_type ``'ProjectName'``) and
        ``worker_type_tags`` value (as tag_type ``'WorkerType'``). It then calls
        ``update_one('VisitorTags', {'tag_type', 'tag_name'}, {})`` once per
        distinct pair.

        Returns:
            dict: ``{'code': 0}`` on success, or ``{'code': -1, 'details': <traceback>}``.
        """
        run_at = methods.now_ts()
        try:
            mdb = Global.get_mongodb_client()

            # --- collect distinct tags from mongodb.VisitorInfo ---
            # Many visitors share a project / worker type, so deduplicate first instead
            # of issuing one write per visitor per tag. A dict keeps first-seen order.
            tag_keys = dict()
            for item in mdb.get_all(Global.dataset_name):
                for project_name in item.get('project_name_tags') or []:
                    tag_keys[('ProjectName', project_name)] = None
                for worker_type in item.get('worker_type_tags') or []:
                    tag_keys[('WorkerType', worker_type)] = None

            # --- update VisitorTags (ProjectName and WorkerType) ---
            for tag_type, tag_name in tag_keys:
                unique_dict = {
                    'tag_type': tag_type,
                    'tag_name': tag_name,
                }
                update_dict = {}
                mdb.update_one('VisitorTags', unique_dict, update_dict)

            # TODO: iterate VisitorTags and update each row's face_amout field with the
            #  number of people carrying that tag.

            methods.debug_log('TEST.test004', f"m1: {methods.now_ts() - run_at}s")
            return dict(code=0)

        except Exception as exception:
            methods.debug_log('TEST.test004', f"m-124: exception | {exception}")
            methods.debug_log('TEST.test004', f"m-124: traceback | {methods.trace_log()}")
            return dict(code=-1, details=f"{methods.trace_log()}")

    @classmethod
    def test006(cls, **params):
        """Score two hard-coded ``Face`` documents against each other.

        Reads the top-level ``face_features`` field of faces
        ``622fffd65bda389a51c96c24`` and ``622fff612ae2021a236e2c90`` (in that
        order) and unpickles both. It then compares them with
        ``Global.arcface_agent.compare_faces_by_normalization``.

        Returns:
            dict: ``{'code': 0, 'details': 'End.', 'dist': <score>}`` on success,
            otherwise ``{'code': -1, 'details': <traceback>}``.
        """
        try:
            client = Global.get_mongodb_client()
            features_a, features_b = (
                methods.pickle_loads(client.get_one_by_id('Face', face_id).get('face_features'))
                for face_id in ('622fffd65bda389a51c96c24', '622fff612ae2021a236e2c90')
            )

            score = Global.arcface_agent.compare_faces_by_normalization(features_a, features_b)
            methods.debug_log('TEST.test006', f"m-196: dist: {score}")
            return dict(code=0, details="End.", dist=score)

        except Exception as err:
            methods.debug_log('TEST.test006', f"m-124: exception | {err}")
            methods.debug_log('TEST.test006', f"m-124: traceback | {methods.trace_log()}")
            return dict(code=-1, details=f"{methods.trace_log()}")

    @classmethod
    def test007(cls, **params):
        """Report WebSocket line status (URL endpoint for checking websocket online info).

        Returns:
            dict: ``{'code': 0, 'line_total': ..., 'line_state': ...}`` as reported by
            ``LineManage``, or ``{'code': -1, 'details': <traceback>}`` on error.
        """
        try:
            return dict(
                code=0,
                line_total=LineManage.get_line_total(),
                line_state=LineManage.get_line_state(),
            )

        except Exception as err:
            methods.debug_log('TEST.test007', f"m-161: exception | {err}")
            methods.debug_log('TEST.test007', f"m-161: traceback | {methods.trace_log()}")
            return dict(code=-1, details=f"{methods.trace_log()}")

    @classmethod
    def test008(cls, **params):
        """Smoke-test the CSCEC agent API.

        Converts a pickled frame stored under ``/home/server/resources/vms-files``
        into a ``data:image/jpeg;base64,...`` URI. It then calls ``push_face``,
        ``push_face_log``, ``pull_alarm_list`` and ``pull_alarm_group_list`` on
        ``Global.cscec_agent()``.

        Returns:
            dict: ``{'code': 0, 'details': 'end.'}`` on success, or
            ``{'code': -1, 'details': <traceback>}`` on any error, including a
            failed JPEG encode.
        """
        try:
            # --- define ---
            api = Global.cscec_agent()

            # --- get face b64 ---
            # Trusted local file: never make this path request-controlled (it is unpickled).
            file_path = f"/home/server/resources/vms-files/2021-0826-160627-600346.pickle"
            pixel_data = methods.load_pickle_file(file_path)  # load the pickled frame
            frame = numpy_method.to_array(pixel_data)  # list -> numpy array
            encoded, jpeg_buffer = cv2.imencode('.jpg', frame)
            if not encoded:
                # imencode reports failure through its flag; do not push an empty image
                raise ValueError(f"cv2.imencode failed for {file_path}")
            b64_text = base64.b64encode(jpeg_buffer).decode()  # bytes -> base64 str
            face_image_b64 = f'data:image/jpeg;base64,{b64_text}'

            # --- test push_face ---
            methods.debug_log('TEST.test008', f"m-236: {type(methods.now_string())}")
            api.push_face(face_image_b64=face_image_b64, now_at=methods.now_string())

            # --- test push_face_log ---
            api.push_face_log()

            # --- test alarm pulls ---
            api.pull_alarm_list()
            api.pull_alarm_group_list()

            return dict(code=0, details=f"end.")

        except Exception as exception:
            methods.debug_log('TEST.test008', f"m-196: exception | {exception}")
            methods.debug_log('TEST.test008', f"m-196: traceback | {methods.trace_log()}")
            return dict(code=-1, details=f"{methods.trace_log()}")

    @classmethod
    def test009(cls, **params):
        """Return a stored pickled frame as a JPEG data URI.

        Loads ``/home/server/resources/vms-files/2021-0822-161705-929476.pickle``,
        converts it to an array and JPEG-encodes it with OpenCV.

        Returns:
            dict: ``{'code': 0, 'face_image_b64': 'data:image/jpeg;base64,...'}``
            on success, or ``{'code': -1, 'details': <traceback>}`` on any error,
            including a failed JPEG encode.
        """
        try:
            # Trusted local file: never make this path request-controlled (it is unpickled).
            file_path = '/home/server/resources/vms-files/2021-0822-161705-929476.pickle'
            pixel_data = methods.load_pickle_file(file_path)  # load the pickled frame
            frame = numpy_method.to_array(pixel_data)  # list -> numpy array
            encoded, jpeg_buffer = cv2.imencode('.jpg', frame)
            if not encoded:
                # imencode reports failure through its flag; do not return an empty image
                raise ValueError(f"cv2.imencode failed for {file_path}")
            b64_text = base64.b64encode(jpeg_buffer).decode()  # bytes -> base64 str
            face_image_b64 = f'data:image/jpeg;base64,{b64_text}'
            return dict(code=0, face_image_b64=face_image_b64)

        except Exception as exception:
            methods.debug_log('TEST.test009', f"m-214: exception | {exception}")
            methods.debug_log('TEST.test009', f"m-214: traceback | {methods.trace_log()}")
            return dict(code=-1, details=f"{methods.trace_log()}")

    @classmethod
    def test011(cls, **params):
        """Download a voice (mp3) file.

        Example requests:
            http://192.168.30.13:8800/api?module=TEST&method=test011&mp3_name=2021-08-17-14-38-49-223380.mp3
            http://192.168.30.13:8800/api?module=TEST&method=test011&mp3_name=2021-08-17-14-59-15-665168.mp3

        Params:
            mp3_name (str): bare file name inside ``/home/server/resources/mp3-files``.

        Returns:
            The result of ``methods.read_bytes`` for the file on success;
            ``{'code': 1, ...}`` when ``mp3_name`` is missing;
            ``{'code': 2, ...}`` when ``mp3_name`` is not a plain file name or the
            file does not exist;
            ``{'code': -1, 'details': <traceback>}`` on unexpected errors.
        """
        try:
            # --- check ---
            mp3_name = params.get('mp3_name')
            if not mp3_name:
                return dict(code=1, details=f"something is wrong.")

            # mp3_name comes straight from the request URL: accept only a bare file name
            # so values like "../../etc/passwd" cannot escape the mp3 directory.
            if (not isinstance(mp3_name, str)
                    or mp3_name in ('.', '..')
                    or os.path.basename(mp3_name) != mp3_name):
                return dict(code=2, details=f"something is wrong.")

            # --- check ---
            file_path = f"/home/server/resources/mp3-files/{mp3_name}"
            if not methods.is_file(file_path):
                return dict(code=2, details=f"something is wrong.")

            return methods.read_bytes(file_path)

        except Exception as exception:

            methods.debug_log('TEST.test011', f"m-239: exception | {exception}")
            methods.debug_log('TEST.test011', f"m-239: traceback | {methods.trace_log()}")
            return dict(code=-1, details=f"{methods.trace_log()}")

    @classmethod
    def test012(cls, **params):
        """Enroll a sample face image into the ``Face`` collection.

        Steps:

        * Read the hard-coded sample ``qiang.jpg`` from
          ``/home/server/resources/TestData/2022/0315``.
        * Extract normalized ArcFace features from the whole image (the SCRFD
          crop step is currently disabled).
        * Save a copy under ``/home/server/resources/vms-files``.
        * Write the record with ``Global.mdb.update_one('Face', unique, update)``
          keyed by ``prc_id``, presumably an upsert (TODO confirm).

        Returns:
            dict: ``{'code': 0, ...}`` on success; ``{'code': 1, ...}`` if the image
            cannot be read; ``{'code': 2, ...}`` if no features could be extracted;
            ``{'code': -1, 'details': <traceback>}`` on unexpected errors.
        """
        try:
            # --- load image ---
            # Other samples used before: IMG_20220315_101522.jpg, IMG_20220315_101649.jpg,
            # 2022-0324-154547-015729.jpg, bing.jpg
            file_name = 'qiang.jpg'
            image_path = f'/home/server/resources/TestData/2022/0315/{file_name}'
            image_array = cv2.imread(image_path)
            if image_array is None:  # cv2.imread returns None instead of raising
                return dict(code=1, details=f"cannot read image: {image_path}")

            # --- crop face (disabled) ---
            # inference_result = Global.scrfd_agent.inference(image_array)
            # image_array = inference_result[0].get('face_image')

            # --- extract features ---
            face_features = Global.arcface_agent.get_face_features_normalization_by_image_array(image_array)
            if face_features is None:
                # do not save a file or store a pickled None in the face library
                return dict(code=2, details=f"face_features is None: {image_path}")

            # --- save file ---
            save_at = methods.now_string('%Y-%m%d-%H%M%S-%f')
            face_file_path = f"/home/server/resources/vms-files/{save_at}.jpg"
            cv2.imwrite(face_file_path, image_array)

            # --- save mongodb ---
            # Face document fields (Face: stranger-face table):
            #   face_features   - pickled face features
            #   face_image_path - path of the saved face image file
            #   create_at       - enrollment timestamp
            #   prc_id          - ID card number (the sample file name is used here)
            #   face_image_url  - face image URL
            #   face_name       - person's name (earlier samples: 'zhang', 'bing')
            unique_dict = {
                'prc_id': file_name,
            }
            update_dict = {
                'face_features': methods.pickle_dumps(face_features),
                'face_image_path': face_file_path,
                'create_at': methods.now_ts(),
                'face_image_url': None,
                'face_name': '陆强',
            }
            Global.mdb.update_one('Face', unique_dict, update_dict)

            return dict(code=0, details=f"end.")

        except Exception as exception:

            methods.debug_log('TEST.test012', f"m-297: exception | {exception}")
            methods.debug_log('TEST.test012', f"m-297: traceback | {methods.trace_log()}")
            return dict(code=-1, details=f"{methods.trace_log()}")

    @classmethod
    def test013(cls, **params):
        """Compare the similarity of sample face images.

        Extracts normalized ArcFace features from the aligned SCRFD crop
        (``align_face``) of three samples in
        ``/home/server/resources/TestData/2022/0531-1``. It then scores
        ``004.jpg`` and ``104.jpg`` against ``204.jpg``.

        Returns:
            dict: ``{'code': 0, 'dist_1': score(004, 204), 'dist_2': score(104, 204)}``,
            or ``{'code': -1, 'details': <traceback>}`` on error. Errors include an
            unreadable image or one with no detected face.
        """
        def extract_features(face_path):
            """Return normalized ArcFace features of the first face SCRFD finds in ``face_path``."""
            image_array = cv2.imread(face_path)
            if image_array is None:  # cv2.imread returns None instead of raising
                raise ValueError(f"cannot read image: {face_path}")
            result = Global.scrfd_agent.inference_with_image_array(image_array)
            if result is None or len(result) == 0:
                raise ValueError(f"no face detected: {face_path}")
            # 'raw_face' (the unaligned crop) is the alternative key that was also tried
            face_image = result[0].get('align_face')
            return Global.arcface_agent.get_face_features_normalization_by_image_array(face_image)

        try:
            # Alternative sample sets tried (0NN / 1NN / 2NN share the same suffix).
            # The numbers are the scores the original author noted next to each file:
            #   001.jpg 0.7584765553474426             101.jpg 0.7629551887512207
            #   002.jpg 0.8457706868648529             102.jpg 0.8441838622093201
            #   004.jpg 0.8554660677909851 0.8809504210948944
            #                                          104.jpg 0.8168102502822876 0.8709447979927063
            #   009.jpg 0.7084800899028778             109.jpg 0.6636235117912292
            sample_dir = '/home/server/resources/TestData/2022/0531-1'
            face_features_1 = extract_features(f'{sample_dir}/004.jpg')
            face_features_2 = extract_features(f'{sample_dir}/104.jpg')
            face_features_3 = extract_features(f'{sample_dir}/204.jpg')

            dist_1 = Global.arcface_agent.compare_faces_by_normalization(face_features_1, face_features_3)
            dist_2 = Global.arcface_agent.compare_faces_by_normalization(face_features_2, face_features_3)

            return dict(code=0, dist_1=dist_1, dist_2=dist_2)

        except Exception as exception:
            methods.debug_log('TEST.test013', f"m-214: exception | {exception}")
            methods.debug_log('TEST.test013', f"m-214: traceback | {methods.trace_log()}")
            return dict(code=-1, details=f"{methods.trace_log()}")

    @classmethod
    def test014(cls, **params):
        """Compare the similarity of two stored faces.

        Unpickles the ``face_features`` of the first ``face_feature_info_list``
        entry of two ``Face`` documents. The base face is
        ``629eb0fb1ef13cc2688962a4`` and the probe is ``62a008b32bb1ca6081d36b1a``.
        The two are scored with ``compare_faces_by_normalization``.

        Other base faces used before: ``627621d4013c78548cee0020`` (于佑飞) and
        ``62761dc5013c78548cedfff4`` (王彦杰). A probe built from image files in
        ``/home/server/resources/TestData/2022/0527`` was also tried.

        Returns:
            dict: ``{'code': 0, 'dist_1': <score>}``, or
            ``{'code': -1, 'details': <traceback>}`` on error.
        """
        def first_stored_features(face_uuid):
            """Unpickle the features of the first feature entry of ``Face`` ``face_uuid``."""
            record = Global.mdb.get_one_by_id('Face', face_uuid)
            first_entry = record.get('face_feature_info_list')[0]
            return methods.pickle_loads(first_entry.get('face_features'))

        try:
            base_features = first_stored_features('629eb0fb1ef13cc2688962a4')
            probe_features = first_stored_features('62a008b32bb1ca6081d36b1a')

            score = Global.arcface_agent.compare_faces_by_normalization(base_features, probe_features)
            return dict(code=0, dist_1=score)

        except Exception as err:
            methods.debug_log('TEST.test014', f"m-214: exception | {err}")
            methods.debug_log('TEST.test014', f"m-214: traceback | {methods.trace_log()}")
            return dict(code=-1, details=f"{methods.trace_log()}")

    @classmethod
    def test015(cls, **params):
        """Search the labelled face library for the face most similar to a sample image.

        Extracts normalized ArcFace features from the aligned SCRFD crop of
        ``/home/server/resources/TestData/2022/0629/20220629090605.jpg``. It then
        builds ``{face_uuid: [(features, face_image_path), ...]}`` for every
        ``Face`` document with a non-None ``face_name``. Finally it calls
        ``Global.arcface_agent.search_face_v2`` and times only that call.

        Every labelled document gets a key, even if it ends up with an empty
        candidate list. Entries with an empty ``face_features`` field, or one
        that unpickles to None, are skipped.

        Returns:
            dict: ``code=0`` with ``face_uuid``, ``face_dist``,
            ``base_face_image_path`` and ``use_time`` (ms), or
            ``{'code': -1, 'details': <traceback>}`` on error. Errors include an
            unreadable sample or one with no detected face.
        """
        try:
            # --- sample features ---
            face_path = '/home/server/resources/TestData/2022/0629/20220629090605.jpg'
            image_array = cv2.imread(face_path)
            if image_array is None:  # cv2.imread returns None instead of raising
                raise ValueError(f"cannot read image: {face_path}")
            result = Global.scrfd_agent.inference_with_image_array(image_array)
            if result is None or len(result) == 0:
                raise ValueError(f"no face detected: {face_path}")
            face_image = result[0].get('align_face')
            face_features_0 = Global.arcface_agent.get_face_features_normalization_by_image_array(face_image)

            # --- fill face_dict ---
            face_dict = dict()
            condition = {
                'face_name': {'$ne': None},  # only manually labelled faces form the library
            }
            for item in Global.mdb.filter('Face', condition):
                face_uuid = str(item.get('_id'))
                candidates = list()
                face_dict[face_uuid] = candidates

                for face_feature_info in item.get('face_feature_info_list') or []:
                    face_features_blob = face_feature_info.get('face_features')
                    if not face_features_blob:
                        continue
                    face_features = methods.pickle_loads(face_features_blob)  # unpickle once
                    if face_features is None:
                        continue
                    candidates.append((face_features, face_feature_info.get('face_image_path')))

            # --- search ---
            run_at = methods.now_ts(unit='ms')
            face_uuid, face_dist, base_face_image_path = Global.arcface_agent.search_face_v2(face_features_0, face_dict)
            use_time = round((methods.now_ts(unit='ms') - run_at), 2)
            methods.debug_log('TEST.test015', f"m-439: use time {use_time}ms")

            return dict(code=0, face_uuid=face_uuid, face_dist=face_dist, base_face_image_path=base_face_image_path,
                        use_time=use_time)

        except Exception as exception:
            methods.debug_log('TEST.test015', f"m-214: exception | {exception}")
            methods.debug_log('TEST.test015', f"m-214: traceback | {methods.trace_log()}")
            return dict(code=-1, details=f"{methods.trace_log()}")