diff --git a/backend/apps/db/db.py b/backend/apps/db/db.py index 58d4a71a..e526943f 100644 --- a/backend/apps/db/db.py +++ b/backend/apps/db/db.py @@ -404,16 +404,7 @@ def get_tables(ds: CoreDatasource): "excel") else get_engine_config() db = DB.get_db(ds.type) sql, sql_param = get_table_sql(ds, conf, get_version(ds)) - if equals_ignore_case(ds.type, "sqlite"): - engine = get_engine(ds) - with engine.raw_connection() as conn: - cursor = conn.cursor() - cursor.execute(sql) - res = cursor.fetchall() - cursor.close() - res_list = [TableSchema(*item) for item in res] - return res_list - elif db.connect_type == ConnectType.sqlalchemy: + if db.connect_type == ConnectType.sqlalchemy: with get_session(ds) as session: with session.execute(text(sql), {"param": sql_param}) as result: res = result.fetchall() @@ -460,15 +451,12 @@ def get_tables(ds: CoreDatasource): res_list = [TableSchema(*item) for item in res] return res_list elif equals_ignore_case(ds.type, 'hive'): - conn = hive.connect(host=conf.host, port=conf.port, username=conf.username, - database=conf.database, **extra_config_dict) - cursor = conn.cursor() - cursor.execute(sql) - res = cursor.fetchall() - res_list = [TableSchema(*item) for item in res] - cursor.close() - conn.close() - return res_list + with hive.connect(host=conf.host, port=conf.port, username=conf.username, + database=conf.database, **extra_config_dict) as conn, conn.cursor() as cursor: + cursor.execute(sql) + res = cursor.fetchall() + res_list = [TableSchema(*item) for item in res] + return res_list def get_fields(ds: CoreDatasource, table_name: str = None): @@ -476,20 +464,14 @@ def get_fields(ds: CoreDatasource, table_name: str = None): "excel") else get_engine_config() db = DB.get_db(ds.type) sql, p1, p2 = get_field_sql(ds, conf, table_name) - if equals_ignore_case(ds.type, "sqlite"): - engine = get_engine(ds) - with engine.raw_connection() as conn: - cursor = conn.cursor() - cursor.execute(sql) - res = cursor.fetchall() - cursor.close() - res_list = [ColumnSchema(item[1], item[2], '') for item in res] - return res_list - elif db.connect_type == ConnectType.sqlalchemy: + if db.connect_type == ConnectType.sqlalchemy: with get_session(ds) as session: with session.execute(text(sql), {"param1": p1, "param2": p2}) as result: res = result.fetchall() - res_list = [ColumnSchema(*item) for item in res] + if equals_ignore_case(ds.type, "sqlite"): + res_list = [ColumnSchema(item[1], item[2], '') for item in res] + else: + res_list = [ColumnSchema(*item) for item in res] return res_list else: extra_config_dict = get_extra_config(conf) @@ -532,15 +514,12 @@ def get_fields(ds: CoreDatasource, table_name: str = None): res_list = [ColumnSchema(*item) for item in res] return res_list elif equals_ignore_case(ds.type, 'hive'): - conn = hive.connect(host=conf.host, port=conf.port, username=conf.username, - database=conf.database, **extra_config_dict) - cursor = conn.cursor() - cursor.execute(sql) - res = cursor.fetchall() - res_list = [ColumnSchema(*item) for item in res] - cursor.close() - conn.close() - return res_list + with hive.connect(host=conf.host, port=conf.port, username=conf.username, + database=conf.database, **extra_config_dict) as conn, conn.cursor() as cursor: + cursor.execute(sql) + res = cursor.fetchall() + res_list = [ColumnSchema(*item) for item in res] + return res_list def convert_value(value, datetime_format='space'): @@ -730,28 +709,24 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column= except Exception as ex: raise Exception(str(ex)) elif equals_ignore_case(ds.type, 'hive'): - conn = hive.connect(host=conf.host, port=conf.port, username=conf.username, - database=conf.database, **extra_config_dict) - cursor = conn.cursor() - try: - # Hive uses backticks for identifiers; normalize quoted identifiers as a compatibility fallback. - hive_sql = re.sub(r'"([A-Za-z_][A-Za-z0-9_]*)"', r'`\1`', sql) - cursor.execute(hive_sql) - res = cursor.fetchall() - columns = [field[0] for field in cursor.description] if origin_column else [field[0].lower() for - field in - cursor.description] - result_list = [ - {str(columns[i]): convert_value(value) for i, value in enumerate(tuple_item)} for tuple_item in - res - ] - return {"fields": columns, "data": result_list, - "sql": bytes.decode(base64.b64encode(bytes(hive_sql, 'utf-8')))} - except Exception as ex: - raise ParseSQLResultError(str(ex)) - finally: - cursor.close() - conn.close() + with hive.connect(host=conf.host, port=conf.port, username=conf.username, + database=conf.database, **extra_config_dict) as conn, conn.cursor() as cursor: + try: + # Hive uses backticks for identifiers; normalize quoted identifiers as a compatibility fallback. + hive_sql = re.sub(r'"([A-Za-z_][A-Za-z0-9_]*)"', r'`\1`', sql) + cursor.execute(hive_sql) + res = cursor.fetchall() + columns = [field[0] for field in cursor.description] if origin_column else [field[0].lower() for + field in + cursor.description] + result_list = [ + {str(columns[i]): convert_value(value) for i, value in enumerate(tuple_item)} for tuple_item in + res + ] + return {"fields": columns, "data": result_list, + "sql": bytes.decode(base64.b64encode(bytes(hive_sql, 'utf-8')))} + except Exception as ex: + raise ParseSQLResultError(str(ex)) def check_sql_read(sql: str, ds: CoreDatasource | AssistantOutDsSchema):