# update: 2021-7-2-9
"""
pymysql==0.9.3
peewee==3.13.1  # db orm
"""
from peewee import MySQLDatabase as PeeweeEngine  # todo UserWarning: Unable to determine MySQL version
# from playhouse.mysql_ext import MySQLConnectorDatabase as PeeweeEngine  # todo MySQL connector not installed!
from playhouse.reflection import generate_models


class Client(PeeweeEngine):
    """
    MySQL/MariaDB client that reflects the existing schema into peewee models.

    On construction the database is introspected with
    ``playhouse.reflection.generate_models``; the resulting model classes are
    stored in ``self.tables`` keyed by table name.
    """

    def __init__(self, host: str = 'test-stack-mariadb', port: int = 3306, database: str = 'vms',
                 username: str = 'root', password: str = 'root'):
        """
        Configure the connection and reflect every table into a model class.

        Internal access:
            host: test-stack-mariadb
            port: 3306
        Remote access:
            host: 118.190.217.96 or 192.168.20.162
            port: 17706

        :param host: database server host name or IP address
        :param port: database server TCP port
        :param database: name of the schema to open and reflect
        :param username: login user
        :param password: login password
        """
        super().__init__(database, user=username, password=password, host=host, port=port)
        tables = generate_models(self)
        # Per-instance registry: using the module namespace here would mix
        # models with unrelated module names and share them between clients.
        self.tables = dict(tables)
        # Kept for backward compatibility only: the reflected models are still
        # exported as module-level names.
        # NOTE(review): a table whose name matches a module-level name
        # (e.g. ``Client``) will shadow it - consider removing this export.
        globals().update(tables)

        # --- vuls ---
        # self.vuls = dict()
        # table = self.tables['vul_detail']
        # for item in table.select():
        #     self.vuls[item.cve_id] = item.__data__

        # def get_vul_by_cve_id(self, cve_id):
        #     return self.vuls.get(cve_id, {})

    def get_all(self, table_name: str):
        """
        Fetch all rows of a table.

        :param table_name: name of a reflected table
        :return: the peewee select query over the table's model; iterate it
            to obtain the rows
        :raises KeyError: if ``table_name`` is not a reflected table
        """
        return self.tables[table_name].select()

    def get_one(self, table_name: str, unique_dict: dict):
        """
        Fetch a single row matching every ``column == value`` pair.

        :param table_name: name of a reflected table
        :param unique_dict: mapping of column (model attribute) name to the
            value it must equal; an empty mapping applies no filter
        :return: the first matching row as its ``__data__`` dict, or ``None``
            when nothing matches
        :raises KeyError: if ``table_name`` is not a reflected table
        :raises AttributeError: if a key is not an attribute of the model
        """
        table = self.tables[table_name]
        conditions = [getattr(table, column) == value for column, value in unique_dict.items()]
        query = table.select()
        if conditions:
            # where() AND-reduces its arguments and fails on an empty list,
            # so only apply it when there is something to filter on.
            query = query.where(*conditions)
        # first() limits the query to one row instead of fetching everything.
        row = query.first()
        return None if row is None else row.__data__

    def execute_sql(self, sql, *args, **kwargs):
        """
        Execute a raw SQL statement on this connection (best effort).

        peewee itself dispatches its queries through this method with extra
        ``params``/``commit`` arguments, so they are accepted and forwarded
        unchanged to the parent implementation.

        :param sql: SQL text to execute
        :param args: extra positional arguments for the parent ``execute_sql``
        :param kwargs: extra keyword arguments for the parent ``execute_sql``
        :return: the cursor returned by the parent implementation, or ``None``
            if executing the statement raised an exception
        """
        try:
            # Delegate to the parent class on *this* instance rather than to a
            # module-level ``db`` global, which may not exist and would recurse.
            result = super().execute_sql(sql, *args, **kwargs)
        except Exception as exception:
            # Best-effort contract: report the failure and signal it with None.
            print(f'{exception.__class__.__name__}: {exception}')
            return None
        print(result.__class__.__name__)
        return result


if __name__ == '__main__':
    # Manual smoke test: connect, run one aggregate query, print what comes back.

    # --- init ---
    # db = Client()
    # db = Client(host='192.168.30.49', port=7033)
    # NOTE(review): credentials are hard-coded below; consider loading them
    # from the environment or a config file instead of committing them.
    db = Client(host='127.0.0.1', port=3306, database='ar', username='root', password='20221212!')

    try:
        # --- test ---
        sql = '''select sum(timestampdiff(second, createtime, updatetime)) from ar.ar_phone_call where status = 1'''
        result = db.execute_sql(sql)
        last_row = None
        # execute_sql returns None when the statement failed; skip the loop then.
        if result is not None:
            for row in result:
                print(' --- --- ---- ---')
                print(row)
                print(row.__class__.__name__)
                last_row = row
        # Only index into a row if the query actually produced one.
        if last_row is not None:
            print(last_row[0])
    finally:
        # Release the connection even if the query or the printing failed.
        db.close()

    # --- test ---
    # items = db.get_all('vul_detail')
    # for item in items:
    #     print(item.cve_id)
    #     # print(item.vul_name)

    # --- test ---
    # out = db.get_one('vul_detail', {'cve_id': 'CVE-2007-1858'})
    # print(type(out.get('cvss_point')))
    # print(out.get('cvss_point'))