1、修改api文件位置
2、意图识别继承langfuse
This commit is contained in:
@@ -1,379 +0,0 @@
|
||||
# from gevent import monkey
|
||||
# monkey.patch_all()
|
||||
|
||||
import os
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, HTMLResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, List, Any, Optional
|
||||
import asyncio
|
||||
import threading
|
||||
import queue
|
||||
import sqlite3
|
||||
from contextlib import closing
|
||||
|
||||
from dotenv import load_dotenv
|
||||
import json
|
||||
import time
|
||||
import datetime
|
||||
import logging
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
def main(query: str) -> dict:
|
||||
query = query.strip()
|
||||
escaped_query = json.dumps(query, ensure_ascii=False)
|
||||
return {
|
||||
"format_query": escaped_query,
|
||||
}
|
||||
|
||||
|
||||
import sys
|
||||
sys.path.append(os.getcwd())
|
||||
from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval
|
||||
|
||||
# 定义数据库路径
|
||||
DATA_DIR = os.path.join(os.getcwd(), "data")
|
||||
DB_DIR = os.path.join(DATA_DIR, "db")
|
||||
DB_FILE = os.path.join(DB_DIR, "answer_logs.db")
|
||||
|
||||
# 创建异步日志队列和工作线程
|
||||
log_queue = queue.Queue()
|
||||
worker_thread = None
|
||||
db_lock = threading.Lock() # 数据库操作锁
|
||||
|
||||
# 初始化数据库
|
||||
def init_database():
|
||||
os.makedirs(DB_DIR, exist_ok=True)
|
||||
|
||||
with db_lock:
|
||||
with closing(sqlite3.connect(DB_FILE)) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 创建查询类型表
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS query_types (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
query_type TEXT NOT NULL,
|
||||
workflow_run_id TEXT NOT NULL,
|
||||
timestamp TEXT NOT NULL
|
||||
)
|
||||
''')
|
||||
|
||||
# 创建点踩原因表
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS dislike_reasons (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
dislike_reason TEXT NOT NULL,
|
||||
workflow_run_id TEXT NOT NULL,
|
||||
timestamp TEXT NOT NULL
|
||||
)
|
||||
''')
|
||||
|
||||
conn.commit()
|
||||
|
||||
logger.info("数据库初始化完成")
|
||||
|
||||
# 后台工作线程函数
|
||||
def log_worker():
|
||||
while True:
|
||||
try:
|
||||
# 从队列获取数据,设置超时以允许线程退出
|
||||
data = log_queue.get(timeout=1.0)
|
||||
if data is None: # 接收到退出信号
|
||||
# 处理剩余数据后再退出
|
||||
while not log_queue.empty():
|
||||
data = log_queue.get_nowait()
|
||||
if data is None: # 跳过额外的停止信号
|
||||
continue
|
||||
process_log_data(data)
|
||||
break
|
||||
|
||||
process_log_data(data)
|
||||
log_queue.task_done()
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"保存查询数据时出错: {str(e)}", exc_info=True)
|
||||
|
||||
# 提取数据处理逻辑到单独函数
|
||||
def process_log_data(data):
|
||||
try:
|
||||
with db_lock:
|
||||
with closing(sqlite3.connect(DB_FILE)) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
if "dislike_reason" in data:
|
||||
# 保存点踩原因
|
||||
cursor.execute(
|
||||
"INSERT INTO dislike_reasons (dislike_reason, workflow_run_id, timestamp) VALUES (?, ?, ?)",
|
||||
(data["dislike_reason"], data["workflow_run_id"], data["timestamp"])
|
||||
)
|
||||
table_name = "dislike_reasons"
|
||||
else:
|
||||
# 保存查询类型
|
||||
cursor.execute(
|
||||
"INSERT INTO query_types (query_type, workflow_run_id, timestamp) VALUES (?, ?, ?)",
|
||||
(data["query_type"], data["workflow_run_id"], data["timestamp"])
|
||||
)
|
||||
table_name = "query_types"
|
||||
|
||||
conn.commit()
|
||||
|
||||
logger.info(f"成功保存数据到表: {table_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"处理日志数据时出错: {str(e)}", exc_info=True)
|
||||
|
||||
# 创建数据目录
|
||||
os.makedirs(DATA_DIR, exist_ok=True)
|
||||
|
||||
# 配置日志 - 同时输出到控制台和文件
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
# 创建日志目录
|
||||
LOG_DIR = os.path.join(DATA_DIR, "logs")
|
||||
os.makedirs(LOG_DIR, exist_ok=True)
|
||||
|
||||
# 创建控制台处理器
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.INFO)
|
||||
|
||||
# 创建文件处理器
|
||||
file_handler = logging.FileHandler(
|
||||
os.path.join(LOG_DIR, "answer_type_service.log"),
|
||||
encoding='utf-8'
|
||||
)
|
||||
file_handler.setLevel(logging.INFO)
|
||||
|
||||
# 创建日志格式
|
||||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
console_handler.setFormatter(formatter)
|
||||
file_handler.setFormatter(formatter)
|
||||
|
||||
# 添加处理器到日志器
|
||||
logger.addHandler(console_handler)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
# 设置其他库的日志级别
|
||||
logging.getLogger('httpx').setLevel(logging.WARNING)
|
||||
logging.getLogger('openai').setLevel(logging.WARNING)
|
||||
|
||||
# 定义请求模型
|
||||
class AnswerTypeRequest(BaseModel):
|
||||
query: str
|
||||
query_type: str
|
||||
|
||||
# 创建FastAPI应用
|
||||
app = FastAPI(
|
||||
title="提问数据类型",
|
||||
description="收集用户提问数据类型",
|
||||
version="1.0"
|
||||
)
|
||||
|
||||
# 添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 应用启动事件
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
global worker_thread
|
||||
# 初始化数据库
|
||||
init_database()
|
||||
|
||||
# 启动后台工作线程
|
||||
worker_thread = threading.Thread(target=log_worker, daemon=True)
|
||||
worker_thread.start()
|
||||
logger.info("后台日志工作线程已启动")
|
||||
|
||||
# 应用关闭事件
|
||||
@app.on_event("shutdown")
|
||||
def shutdown_event():
|
||||
global worker_thread
|
||||
if worker_thread:
|
||||
# 发送退出信号
|
||||
log_queue.put(None)
|
||||
# 等待工作线程处理剩余数据
|
||||
worker_thread.join(timeout=10.0)
|
||||
if worker_thread.is_alive():
|
||||
logger.warning("工作线程未在超时时间内退出")
|
||||
else:
|
||||
logger.info("后台日志工作线程已停止")
|
||||
|
||||
# 添加健康检查端点
|
||||
@app.get("/health", summary="健康检查")
|
||||
async def health_check():
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/query_type", summary="异步检索API")
|
||||
async def query_type(query_type: str, workflow_run_id:str):
|
||||
try:
|
||||
# 记录请求
|
||||
logger.info(f"接收到请求: 类型: {query_type}, workflow_run_id: {workflow_run_id}")
|
||||
|
||||
# 准备数据
|
||||
timestamp = datetime.datetime.now().isoformat()
|
||||
query_data = {
|
||||
"query_type": query_type,
|
||||
"timestamp": timestamp,
|
||||
"workflow_run_id": workflow_run_id
|
||||
}
|
||||
|
||||
# 将数据放入队列
|
||||
try:
|
||||
log_queue.put(query_data)
|
||||
success = True
|
||||
logger.info(f"查询数据已加入队列,当前队列大小: {log_queue.qsize()}")
|
||||
except Exception as e:
|
||||
success = False
|
||||
logger.error(f"加入队列时出错: {str(e)}", exc_info=True)
|
||||
|
||||
# 返回响应
|
||||
content = f"<strong>问题类型</strong>: {query_type}<br><strong>操作是否成功</strong>: {'成功' if success else '失败'}"
|
||||
return HTMLResponse(content=content)
|
||||
except Exception as e:
|
||||
logger.error(f"处理请求时出错: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"处理请求时出错: {str(e)}")
|
||||
|
||||
@app.get("/dislike_reason", summary="记录点踩原因")
|
||||
async def dislike_reason(reason: str, workflow_run_id: str):
|
||||
try:
|
||||
# 记录请求
|
||||
logger.info(f"接收到点踩原因: {reason}, workflow_run_id: {workflow_run_id}")
|
||||
|
||||
# 准备数据
|
||||
timestamp = datetime.datetime.now().isoformat()
|
||||
dislike_data = {
|
||||
"dislike_reason": reason,
|
||||
"timestamp": timestamp,
|
||||
"workflow_run_id": workflow_run_id
|
||||
}
|
||||
|
||||
# 将数据放入队列
|
||||
try:
|
||||
log_queue.put(dislike_data)
|
||||
success = True
|
||||
logger.info(f"点踩原因数据已加入队列,当前队列大小: {log_queue.qsize()}")
|
||||
except Exception as e:
|
||||
success = False
|
||||
logger.error(f"加入队列时出错: {str(e)}", exc_info=True)
|
||||
|
||||
# 返回响应
|
||||
content = f"<strong>点踩原因</strong>: {reason}<br><strong>操作是否成功</strong>: {'成功' if success else '失败'}"
|
||||
return HTMLResponse(content=content)
|
||||
except Exception as e:
|
||||
logger.error(f"处理点踩原因请求时出错: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"处理点踩原因请求时出错: {str(e)}")
|
||||
|
||||
# 添加数据查询API
|
||||
@app.get("/stats", summary="查询统计数据")
|
||||
async def get_stats():
|
||||
try:
|
||||
with db_lock:
|
||||
with closing(sqlite3.connect(DB_FILE)) as conn:
|
||||
conn.row_factory = sqlite3.Row # 启用字典行工厂
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 查询类型统计
|
||||
cursor.execute("SELECT COUNT(*) as count FROM query_types")
|
||||
query_count = cursor.fetchone()['count']
|
||||
|
||||
# 点踩原因统计
|
||||
cursor.execute("SELECT COUNT(*) as count FROM dislike_reasons")
|
||||
dislike_count = cursor.fetchone()['count']
|
||||
|
||||
# 最近5条查询记录
|
||||
cursor.execute("""
|
||||
SELECT query_type, workflow_run_id, timestamp
|
||||
FROM query_types
|
||||
ORDER BY id DESC LIMIT 5
|
||||
""")
|
||||
recent_queries = [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
# 最近5条点踩记录
|
||||
cursor.execute("""
|
||||
SELECT dislike_reason, workflow_run_id, timestamp
|
||||
FROM dislike_reasons
|
||||
ORDER BY id DESC LIMIT 5
|
||||
""")
|
||||
recent_dislikes = [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
return {
|
||||
"statistics": {
|
||||
"total_queries": query_count,
|
||||
"total_dislikes": dislike_count
|
||||
},
|
||||
"recent_data": {
|
||||
"queries": recent_queries,
|
||||
"dislikes": recent_dislikes
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取统计数据时出错: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"获取统计数据时出错: {str(e)}")
|
||||
|
||||
@app.get("/query_by_workflow_id", summary="根据工作流ID获取查询类型数据")
|
||||
async def get_query_by_workflow_id(workflow_run_id: str):
|
||||
try:
|
||||
with db_lock:
|
||||
with closing(sqlite3.connect(DB_FILE)) as conn:
|
||||
conn.row_factory = sqlite3.Row # 启用字典行工厂
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 查询指定工作流ID的查询类型数据
|
||||
cursor.execute("""
|
||||
SELECT id, query_type, workflow_run_id, timestamp
|
||||
FROM query_types
|
||||
WHERE workflow_run_id = ?
|
||||
ORDER BY id DESC
|
||||
""", (workflow_run_id,))
|
||||
|
||||
results = [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
if not results:
|
||||
return {"message": f"未找到工作流ID为 {workflow_run_id} 的查询类型数据", "data": []}
|
||||
|
||||
return {"data": results}
|
||||
except Exception as e:
|
||||
logger.error(f"根据工作流ID获取查询类型数据时出错: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"根据工作流ID获取查询类型数据时出错: {str(e)}")
|
||||
|
||||
@app.get("/dislike_by_workflow_id", summary="根据工作流ID获取点踩原因数据")
|
||||
async def get_dislike_by_workflow_id(workflow_run_id: str):
|
||||
try:
|
||||
with db_lock:
|
||||
with closing(sqlite3.connect(DB_FILE)) as conn:
|
||||
conn.row_factory = sqlite3.Row # 启用字典行工厂
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 查询指定工作流ID的点踩原因数据
|
||||
cursor.execute("""
|
||||
SELECT id, dislike_reason, workflow_run_id, timestamp
|
||||
FROM dislike_reasons
|
||||
WHERE workflow_run_id = ?
|
||||
ORDER BY id DESC
|
||||
""", (workflow_run_id,))
|
||||
|
||||
results = [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
if not results:
|
||||
return {"message": f"未找到工作流ID为 {workflow_run_id} 的点踩原因数据", "data": []}
|
||||
|
||||
return {"data": results}
|
||||
except Exception as e:
|
||||
logger.error(f"根据工作流ID获取点踩原因数据时出错: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"根据工作流ID获取点踩原因数据时出错: {str(e)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 使用Uvicorn运行FastAPI应用
|
||||
import uvicorn
|
||||
uvicorn.run("rag2_0.dify.AnswerType:app", host="0.0.0.0", port=8003, reload=False, workers=1, log_level="info")
|
||||
# 生产环境可以使用以下命令启动:
|
||||
# uvicorn rag2_0.dify.AnswerType:app --host 0.0.0.0 --port 8003 --workers 1
|
||||
@@ -23,12 +23,13 @@ class DifyQueryRetrieval:
|
||||
datasets_json = datasets.json()
|
||||
return {dataset["name"]:dataset for dataset in datasets_json["data"]}
|
||||
|
||||
def retrieve_by_dataset(self, query: str, dataset_name: str) -> Dict[str, Any]:
|
||||
def retrieve_by_dataset(self, query: str, dataset_name: str, metadata_filtering_conditions:dict = {}) -> Dict[str, Any]:
|
||||
try:
|
||||
dataset_id = self._datasets_list[dataset_name]["id"]
|
||||
retrieval_model = self._datasets_list[dataset_name]["retrieval_model_dict"]
|
||||
knowledge_base_client = KnowledgeBaseClient(api_key=self._dify_dataset_key, base_url=self._dify_base_url, dataset_id=dataset_id)
|
||||
|
||||
if len(metadata_filtering_conditions) !=0:
|
||||
retrieval_model["metadata_filtering_conditions"]=metadata_filtering_conditions
|
||||
documents = knowledge_base_client.retrieve(query, retrieval_model=retrieval_model, timeout=300)
|
||||
retrieved_documents = documents.json().get("records", [])
|
||||
|
||||
@@ -51,7 +52,7 @@ class DifyQueryRetrieval:
|
||||
"documents": []
|
||||
}
|
||||
|
||||
async def retrieve_by_dataset_async(self, query: str, dataset_name: str) -> Dict[str, Any]:
|
||||
async def retrieve_by_dataset_async(self, query: str, dataset_name: str, metadata_filtering_conditions:dict = {}) -> Dict[str, Any]:
|
||||
"""
|
||||
异步版本的retrieve_by_dataset方法
|
||||
|
||||
@@ -67,7 +68,8 @@ class DifyQueryRetrieval:
|
||||
return await asyncio.to_thread(
|
||||
self.retrieve_by_dataset,
|
||||
query,
|
||||
dataset_name
|
||||
dataset_name,
|
||||
metadata_filtering_conditions
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"异步检索数据集 {dataset_name} 时出错: {str(e)}", exc_info=True)
|
||||
@@ -77,7 +79,13 @@ class DifyQueryRetrieval:
|
||||
"documents": []
|
||||
}
|
||||
|
||||
async def retrieve_api_async(self, original_query: str, query_list: List[str], data_set_list: List[str], query_expand_dict: dict, top_k: int = 5)->Dict[str, Any]:
|
||||
async def retrieve_api_async(self,
|
||||
original_query: str,
|
||||
query_list: List[str],
|
||||
data_set_list: List[str],
|
||||
query_expand_dict: dict,
|
||||
top_k: int = 5,
|
||||
metadata_filtering_conditions:dict = {})->Dict[str, Any]:
|
||||
"""
|
||||
异步版本的retrieve_api方法,使用asyncio代替线程池
|
||||
|
||||
@@ -105,7 +113,7 @@ class DifyQueryRetrieval:
|
||||
continue
|
||||
|
||||
# 创建异步任务
|
||||
task = self.retrieve_by_dataset_async(query, dataset)
|
||||
task = self.retrieve_by_dataset_async(query, dataset, metadata_filtering_conditions)
|
||||
tasks.append(task)
|
||||
|
||||
# 并发执行所有异步任务
|
||||
|
||||
@@ -1,130 +0,0 @@
|
||||
# from gevent import monkey
|
||||
# monkey.patch_all()
|
||||
|
||||
import os
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Dict, List, Any, Optional
|
||||
import asyncio
|
||||
|
||||
from dotenv import load_dotenv
|
||||
import json
|
||||
import time
|
||||
import datetime
|
||||
import logging
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
import sys
|
||||
sys.path.append(os.getcwd())
|
||||
from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval
|
||||
|
||||
# 确保日志目录存在
|
||||
os.makedirs('data/logs', exist_ok=True)
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - [%(thread)d] - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(),
|
||||
logging.FileHandler(f'data/logs/dify_query_retrieval_{datetime.datetime.now().strftime("%Y%m%d")}.log', encoding='utf-8')
|
||||
]
|
||||
)
|
||||
logging.getLogger('httpx').setLevel(logging.WARNING)
|
||||
logging.getLogger('openai').setLevel(logging.WARNING)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 定义请求模型
|
||||
class RetrieveRequest(BaseModel):
|
||||
original_query: str
|
||||
query_list: str
|
||||
data_set_list: str
|
||||
query_expand_dict: dict | str = Field(default="{}")
|
||||
topk: int = Field(default=4)
|
||||
|
||||
# 创建FastAPI应用
|
||||
app = FastAPI(
|
||||
title="Dify查询检索服务",
|
||||
description="基于Dify的异步查询检索服务",
|
||||
version="1.0"
|
||||
)
|
||||
|
||||
# 添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 全局变量存储DifyQueryRetrieval实例
|
||||
dify_query_retrieval = None
|
||||
|
||||
# 应用启动事件
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
global dify_query_retrieval
|
||||
# 初始化DifyQueryRetrieval实例
|
||||
dify_query_retrieval = DifyQueryRetrieval(dify_dataset_key=os.getenv("DIFY_DATASET_KEY"), dify_base_url=os.getenv("DIFY_BSAE_URL"))
|
||||
logger.info("DifyQueryRetrieval初始化完成")
|
||||
|
||||
# 添加健康检查端点
|
||||
@app.get("/health", summary="健康检查")
|
||||
async def health_check():
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.post("/retrieve", summary="异步检索API")
|
||||
async def retrieve(request: RetrieveRequest):
|
||||
"""
|
||||
异步检索API
|
||||
|
||||
Args:
|
||||
request: 包含原始查询、查询列表和数据集列表的请求对象
|
||||
|
||||
Returns:
|
||||
检索结果
|
||||
"""
|
||||
try:
|
||||
# 解析查询列表和数据集列表
|
||||
query_list = request.query_list.split("<sub_query>")
|
||||
data_set_list = request.data_set_list.split("<dataset>")
|
||||
if isinstance(request.query_expand_dict, str):
|
||||
query_expand_dict = json.loads(request.query_expand_dict)
|
||||
else:
|
||||
query_expand_dict = request.query_expand_dict
|
||||
# 调用异步检索方法
|
||||
start_time = time.time()
|
||||
results = await dify_query_retrieval.retrieve_api_async(
|
||||
request.original_query,
|
||||
query_list,
|
||||
data_set_list,
|
||||
query_expand_dict=query_expand_dict,
|
||||
top_k=request.topk
|
||||
)
|
||||
end_time = time.time()
|
||||
|
||||
logger.info(f"异步检索总耗时: {end_time - start_time:.2f}秒")
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"异步检索出错: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 使用Uvicorn运行FastAPI应用
|
||||
import uvicorn
|
||||
uvicorn.run("rag2_0.dify.DifyQueryRetrieval_api:app", host="0.0.0.0", port=8002, reload=False, workers=1, log_level="info")
|
||||
# # 使用uvicorn启动服务
|
||||
# import uvicorn
|
||||
# uvicorn.run(
|
||||
# "rag2_0.dify.intent_recognition_api:app",
|
||||
# host="0.0.0.0",
|
||||
# port=8001,
|
||||
# reload=False, # 开发环境启用热重载
|
||||
# workers=1 # 生产环境可以增加worker数量
|
||||
# )
|
||||
# 生产环境可以使用以下命令启动:
|
||||
# uvicorn rag2_0.dify.DifyQueryRetrieval_api:app --host 0.0.0.0 --port 8002 --workers 10
|
||||
@@ -1,185 +0,0 @@
|
||||
import os
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Dict, List, Any, Optional
|
||||
import asyncio
|
||||
|
||||
from dotenv import load_dotenv
|
||||
import json
|
||||
import time
|
||||
import datetime
|
||||
import logging
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
import sys
|
||||
sys.path.append(os.getcwd())
|
||||
from rag2_0.intent_recognition import AsyncIntentRecognizer
|
||||
|
||||
# 确保日志目录存在
|
||||
os.makedirs('data/logs', exist_ok=True)
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(process)d - %(thread)d - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(),
|
||||
logging.FileHandler(f'data/logs/intent_recognition_{datetime.datetime.now().strftime("%Y%m%d")}.log', encoding='utf-8')
|
||||
]
|
||||
)
|
||||
logging.getLogger('httpx').setLevel(logging.WARNING)
|
||||
logging.getLogger('openai').setLevel(logging.WARNING)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
|
||||
# 定义请求模型
|
||||
class IntentRecognizeRequest(BaseModel):
|
||||
query: str
|
||||
enable_query_expansion: bool = True
|
||||
conversation_context: Dict = None
|
||||
chat_history: Optional[List] = None
|
||||
previous_slots: str | Dict = None
|
||||
|
||||
# 定义槽位填充响应模型
|
||||
class SlotFillingResponse(BaseModel):
|
||||
is_complete: bool = False
|
||||
missing_slots: Dict[str, Any] = Field(default_factory=dict)
|
||||
filled_data: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
class QueryExpandResponse(BaseModel):
|
||||
# 必须包含的all字段
|
||||
all: List[str] = Field(default_factory=list)
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
# 定义响应模型
|
||||
class IntentRecognizeResponse(BaseModel):
|
||||
source_query: str
|
||||
source_query_keys: List[str]
|
||||
vertical_classification: str
|
||||
sub_classification: str
|
||||
rewrite_query: str
|
||||
keywords: List[Dict[str, str]] = Field(default_factory=list)
|
||||
has_slot_filling: bool = False
|
||||
slot_filling: SlotFillingResponse = Field(default_factory=SlotFillingResponse)
|
||||
query_expand: QueryExpandResponse = Field(default_factory=QueryExpandResponse)
|
||||
dinge_qingdan_info: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# 创建FastAPI应用
|
||||
app = FastAPI(
|
||||
title="意图识别服务",
|
||||
description="基于LLM的意图识别和问题改写服务",
|
||||
version="2.0"
|
||||
)
|
||||
|
||||
# 添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 全局变量存储AsyncIntentRecognizer实例
|
||||
_instance = None
|
||||
|
||||
# 应用启动事件
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
global _instance
|
||||
_instance = await AsyncIntentRecognizer.create()
|
||||
logger.info("AsyncIntentRecognizer初始化完成")
|
||||
|
||||
@app.post("/intent_recognize1")
|
||||
async def intent_recognize(request: Request):
|
||||
data = await request.json()
|
||||
print(data)
|
||||
return {"message": "success"}
|
||||
|
||||
@app.post("/intent_recognize", response_model=IntentRecognizeResponse, summary="意图识别", description="识别用户查询的意图并进行问题改写")
|
||||
async def intent_recognize(request: IntentRecognizeRequest):
|
||||
try:
|
||||
if not request.query:
|
||||
raise HTTPException(status_code=400, detail="缺少query参数")
|
||||
|
||||
enable_query_expansion = request.enable_query_expansion
|
||||
start_time = time.time()
|
||||
current_softname = request.conversation_context.get("current_softname", "")
|
||||
result = await _instance.process_query_async(
|
||||
query=request.query,
|
||||
conversation_context=request.conversation_context,
|
||||
chat_history=request.chat_history,
|
||||
previous_slots=request.previous_slots,
|
||||
use_jieba=True,
|
||||
enable_query_expansion=enable_query_expansion,
|
||||
cur_soft_name=current_softname
|
||||
)
|
||||
dinge_qingdan_info = result["dinge_qingdan_info"]
|
||||
|
||||
end_time = time.time()
|
||||
current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S %z")
|
||||
logger.info(f"意图识别耗时: {end_time - start_time:.2f}秒")
|
||||
|
||||
# 提取分类信息
|
||||
classification = result["classification"]
|
||||
|
||||
# 提取关键词信息
|
||||
keywords = result["keywords"]
|
||||
keywords_list = []
|
||||
if keywords and keywords.get("terms"):
|
||||
for term in keywords["terms"]:
|
||||
keywords_list.append({
|
||||
"名称": term["name"]
|
||||
})
|
||||
|
||||
# 提取槽位填充信息
|
||||
slot_filling = result.get("slot_filling", {})
|
||||
|
||||
# 构建响应
|
||||
response = IntentRecognizeResponse(
|
||||
source_query=request.query,
|
||||
source_query_keys=result["query_keys"],
|
||||
vertical_classification=classification["vertical_classification"],
|
||||
sub_classification=classification["sub_classification"],
|
||||
rewrite_query=result["rewrite"]["rewrite"],
|
||||
keywords=keywords_list,
|
||||
has_slot_filling=len(slot_filling) != 0,
|
||||
slot_filling=SlotFillingResponse(
|
||||
is_complete=slot_filling.get("is_complete", False),
|
||||
missing_slots=slot_filling.get("missing_slots", {}),
|
||||
filled_data=slot_filling.get("filled_data", {})
|
||||
),
|
||||
query_expand=QueryExpandResponse(**result["query_expand"]),
|
||||
dinge_qingdan_info=dinge_qingdan_info
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"意图识别出错: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# 添加健康检查端点
|
||||
@app.get("/health", summary="健康检查")
|
||||
async def health_check():
|
||||
return {"status": "ok"}
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 使用uvicorn启动服务
|
||||
import uvicorn
|
||||
uvicorn.run(
|
||||
"rag2_0.dify.intent_recognition_api:app",
|
||||
host="0.0.0.0",
|
||||
port=9001,
|
||||
reload=True, # 开发环境启用热重载
|
||||
workers=1 # 生产环境可以增加worker数量
|
||||
)
|
||||
# 生产环境可以使用以下命令启动:
|
||||
# uvicorn rag2_0.dify.intent_recognition_api:app --host 0.0.0.0 --port 8001 --workers 10
|
||||
@@ -1,567 +0,0 @@
|
||||
# 添加FastAPI相关导入
|
||||
from fastapi import FastAPI, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Dict, Any
|
||||
import uvicorn
|
||||
import sqlite3
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 导入ExcelToSQLiteProcessor类
|
||||
sys.path.append(os.getcwd())
|
||||
from rag2_0.demo.create_qingdan_dinge_database import ExcelToSQLiteProcessor
|
||||
# 导入向量检索相关类
|
||||
from rag2_0.tool.ModelTool import XinferenceEmbeddings
|
||||
from langchain_community.vectorstores import SQLiteVSS
|
||||
from rag2_0.tool.APIKeyManager import APIKeyManager
|
||||
|
||||
# 创建FastAPI应用
|
||||
app = FastAPI(title="清单定额库查询API", description="提供清单和定额信息查询接口")
|
||||
TOP_K = 100
|
||||
|
||||
# 响应模型
|
||||
class QueryResponse(BaseModel):
|
||||
success: bool
|
||||
message: str
|
||||
data: Optional[List[Dict[str, Any]]] = None
|
||||
|
||||
# 批量查询请求模型
|
||||
class DingEInfoList(BaseModel):
|
||||
dinge_code_list: List[str] = []
|
||||
dinge_name_list: List[str] = []
|
||||
|
||||
class QingDanInfo(BaseModel):
|
||||
qingdan_code_list: List[str] = []
|
||||
qingdan_name_list: List[str] = []
|
||||
|
||||
class DingEQingDanInfo(BaseModel):
|
||||
dinge_info_list: DingEInfoList
|
||||
qingdan_info: QingDanInfo
|
||||
|
||||
class BatchQueryRequest(BaseModel):
|
||||
dinge_qingdan_info: DingEQingDanInfo
|
||||
scope: Optional[str] = Query(None, description="适用范围")
|
||||
|
||||
# 批量查询响应模型
|
||||
class BatchQueryResponse(BaseModel):
|
||||
success: bool
|
||||
message: str
|
||||
dinge_data: Optional[List[Dict[str, Any]]] = None
|
||||
qingdan_data: Optional[List[Dict[str, Any]]] = None
|
||||
|
||||
# 封装查询数据的相关代码
|
||||
class QingDanDingEQueryService:
|
||||
def __init__(self, db_path="/data/QueryRewrite/data/db/qingdan_ding_e_ku.db"):
|
||||
self.db_path = db_path
|
||||
self.top_k = TOP_K
|
||||
|
||||
# 初始化向量检索相关组件
|
||||
self.embedding_function = XinferenceEmbeddings(api_key="")
|
||||
|
||||
# 初始化向量数据库连接
|
||||
self.ding_e_vector_db = SQLiteVSS(
|
||||
table="embeding_ding_e_zimu_name",
|
||||
connection=None,
|
||||
embedding=self.embedding_function,
|
||||
db_file=self.db_path
|
||||
)
|
||||
|
||||
self.qing_dan_vector_db = SQLiteVSS(
|
||||
table="embeding_qd_zimu_name",
|
||||
connection=None,
|
||||
embedding=self.embedding_function,
|
||||
db_file=self.db_path
|
||||
)
|
||||
|
||||
def get_similar_names_by_vector(self, query_text:str, vector_db:SQLiteVSS, field_map:dict, top_k:int=3, scope:str=None):
|
||||
"""使用向量检索获取相似名称"""
|
||||
try:
|
||||
# 使用向量数据库进行相似性搜索
|
||||
results = vector_db.similarity_search_with_score(query=query_text, k=30)
|
||||
|
||||
# 提取结果中的元数据
|
||||
similar_items = []
|
||||
for doc, score in results:
|
||||
if scope and scope not in doc.metadata[field_map["适用范围"]]:
|
||||
continue
|
||||
|
||||
metadata = doc.metadata
|
||||
# 添加相似度分数
|
||||
metadata['similarity_score'] = float(score)
|
||||
similar_items.append(metadata)
|
||||
|
||||
# 按相似度分数排序,分数高的排前面
|
||||
similar_items.sort(key=lambda x: x['similarity_score'])
|
||||
return similar_items[:top_k]
|
||||
except Exception as e:
|
||||
print(f"向量检索出错: {str(e)}")
|
||||
return []
|
||||
|
||||
def get_db_connection(self):
|
||||
"""获取数据库连接"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row # 设置行工厂,使结果可以通过列名访问
|
||||
return conn
|
||||
|
||||
def create_reverse_field_map(self):
|
||||
"""创建字段反向映射(数据库字段名 -> 中文字段名)"""
|
||||
# 定额库字段反向映射
|
||||
ding_e_reverse_map = {v: k for k, v in ExcelToSQLiteProcessor.ding_e_field_map.items()}
|
||||
# 清单库字段反向映射
|
||||
qing_dan_reverse_map = {v: k for k, v in ExcelToSQLiteProcessor.qing_dan_field_map.items()}
|
||||
return ding_e_reverse_map, qing_dan_reverse_map
|
||||
|
||||
def convert_field_names_to_chinese(self, data_list, reverse_map):
|
||||
"""转换字段名称为中文"""
|
||||
result = []
|
||||
for item in data_list:
|
||||
chinese_item = {}
|
||||
for field_name, value in item.items():
|
||||
# 如果字段名在反向映射中存在,则使用中文名称
|
||||
chinese_field_name = reverse_map.get(field_name, field_name)
|
||||
chinese_item[chinese_field_name] = value
|
||||
result.append(chinese_item)
|
||||
return result
|
||||
|
||||
def sort_results_by_exact_match(self, data_list, search_term, field_name):
|
||||
"""对查询结果进行排序,将完全匹配的结果排在前面"""
|
||||
exact_matches = []
|
||||
partial_matches = []
|
||||
|
||||
for item in data_list:
|
||||
# 检查是否为完全匹配
|
||||
if search_term.upper() == str(item[field_name]).upper():
|
||||
exact_matches.append(item)
|
||||
else:
|
||||
partial_matches.append(item)
|
||||
|
||||
# 合并结果,完全匹配的排在前面
|
||||
return exact_matches + partial_matches
|
||||
|
||||
def query_ding_e_by_name(self, name, scope=None):
|
||||
"""根据定额名称查询定额子目表中详情信息,使用向量检索扩大查询范围"""
|
||||
try:
|
||||
conn = self.get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 获取表名和字段映射
|
||||
zimu_table = ExcelToSQLiteProcessor.ding_e_table_names["定额子目"]
|
||||
mulu_table = ExcelToSQLiteProcessor.ding_e_table_names["定额目录"]
|
||||
attr_table = ExcelToSQLiteProcessor.ding_e_table_names["定额资源库属性"]
|
||||
field_map = ExcelToSQLiteProcessor.ding_e_field_map
|
||||
|
||||
# 1. 先使用向量检索获取相似名称
|
||||
similar_items = self.get_similar_names_by_vector(query_text=name,
|
||||
vector_db=self.ding_e_vector_db,
|
||||
field_map=field_map,
|
||||
scope=scope)
|
||||
similar_names = [item[field_map['名称']] for item in similar_items]
|
||||
|
||||
# 构建查询条件,始终包含原始名称的模糊匹配
|
||||
like_conditions = [f"zimu.{field_map['名称']} LIKE ?"]
|
||||
params = [f'%{name}%']
|
||||
|
||||
# 如果有向量检索结果,添加这些结果的模糊匹配条件
|
||||
for similar_name in similar_names:
|
||||
like_conditions.append(f"zimu.{field_map['名称']} LIKE ?")
|
||||
params.append(f'%{similar_name}%')
|
||||
|
||||
# 将所有条件用OR连接
|
||||
like_conditions_str = " OR ".join(like_conditions)
|
||||
like_conditions_str= f"({like_conditions_str})"
|
||||
|
||||
query = f"""
|
||||
SELECT
|
||||
zimu.*,
|
||||
mulu.{field_map['名称']} as mulu_name,
|
||||
attr.{field_map['发布时间']} as attr_pub_time,
|
||||
attr.{field_map['适用范围']} as attr_scope
|
||||
FROM {zimu_table} zimu
|
||||
LEFT JOIN {mulu_table} mulu ON
|
||||
zimu.{field_map['章节码']} = mulu.{field_map['章节码']} AND
|
||||
zimu.{field_map['资源库名称']} = mulu.{field_map['资源库名称']}
|
||||
LEFT JOIN {attr_table} attr ON
|
||||
zimu.{field_map['资源库名称']} = attr.{field_map['资源库名称']}
|
||||
WHERE {like_conditions_str}
|
||||
"""
|
||||
|
||||
# 如果提供了适用范围,添加过滤条件
|
||||
if scope:
|
||||
query += f" AND attr.{field_map['适用范围']} LIKE ?"
|
||||
params.append(f'%{scope}%')
|
||||
|
||||
cursor.execute(query, params)
|
||||
|
||||
# 获取结果
|
||||
results = cursor.fetchall()
|
||||
data = [dict(row) for row in results]
|
||||
|
||||
# 对结果进行排序,将全字匹配的排在前面
|
||||
data = self.sort_results_by_exact_match(data, name, field_map['名称'])
|
||||
data = data[:self.top_k]
|
||||
|
||||
conn.close()
|
||||
|
||||
if not data:
|
||||
return {"success": True, "message": "未找到匹配的定额信息", "data": []}
|
||||
|
||||
# 创建反向映射并转换字段名为中文
|
||||
ding_e_reverse_map, _ = self.create_reverse_field_map()
|
||||
|
||||
# 添加自定义字段映射
|
||||
ding_e_reverse_map['mulu_name'] = '目录名称'
|
||||
ding_e_reverse_map['attr_pub_time'] = '发布时间'
|
||||
ding_e_reverse_map['attr_scope'] = '适用范围'
|
||||
|
||||
chinese_data = self.convert_field_names_to_chinese(data, ding_e_reverse_map)
|
||||
|
||||
return {"success": True, "message": "查询成功", "data": chinese_data}
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"查询出错: {str(e)}"}
|
||||
|
||||
def query_ding_e_by_code(self, code, scope=None):
|
||||
"""根据定额编码查询定额子目表中详情信息"""
|
||||
try:
|
||||
code = code.upper()
|
||||
conn = self.get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 获取表名和字段映射
|
||||
zimu_table = ExcelToSQLiteProcessor.ding_e_table_names["定额子目"]
|
||||
mulu_table = ExcelToSQLiteProcessor.ding_e_table_names["定额目录"]
|
||||
attr_table = ExcelToSQLiteProcessor.ding_e_table_names["定额资源库属性"]
|
||||
field_map = ExcelToSQLiteProcessor.ding_e_field_map
|
||||
|
||||
# 构建连表查询SQL
|
||||
query = f"""
|
||||
SELECT
|
||||
zimu.*,
|
||||
mulu.{field_map['名称']} as mulu_name,
|
||||
attr.{field_map['发布时间']} as attr_pub_time,
|
||||
attr.{field_map['适用范围']} as attr_scope
|
||||
FROM {zimu_table} zimu
|
||||
LEFT JOIN {mulu_table} mulu ON
|
||||
zimu.{field_map['章节码']} = mulu.{field_map['章节码']} AND
|
||||
zimu.{field_map['资源库名称']} = mulu.{field_map['资源库名称']}
|
||||
LEFT JOIN {attr_table} attr ON
|
||||
zimu.{field_map['资源库名称']} = attr.{field_map['资源库名称']}
|
||||
WHERE zimu.{field_map['编码']} LIKE ?
|
||||
"""
|
||||
|
||||
params = [f'%{code}%']
|
||||
|
||||
# 如果提供了适用范围,添加过滤条件
|
||||
if scope:
|
||||
query += f" AND attr.{field_map['适用范围']} LIKE ?"
|
||||
params.append(f'%{scope}%')
|
||||
|
||||
cursor.execute(query, params)
|
||||
|
||||
# 获取结果
|
||||
results = cursor.fetchall()
|
||||
data = [dict(row) for row in results]
|
||||
|
||||
# 对结果进行排序,将全字匹配的排在前面
|
||||
data = self.sort_results_by_exact_match(data, code, field_map['编码'])
|
||||
data = data[:self.top_k]
|
||||
|
||||
conn.close()
|
||||
|
||||
if not data:
|
||||
return {"success": True, "message": "未找到匹配的定额信息", "data": []}
|
||||
|
||||
# 创建反向映射并转换字段名为中文
|
||||
ding_e_reverse_map, _ = self.create_reverse_field_map()
|
||||
|
||||
# 添加自定义字段映射
|
||||
ding_e_reverse_map['mulu_name'] = '目录名称'
|
||||
ding_e_reverse_map['attr_pub_time'] = '发布时间'
|
||||
ding_e_reverse_map['attr_scope'] = '适用范围'
|
||||
|
||||
chinese_data = self.convert_field_names_to_chinese(data, ding_e_reverse_map)
|
||||
|
||||
return {"success": True, "message": "查询成功", "data": chinese_data}
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"查询出错: {str(e)}"}
|
||||
|
||||
def query_qing_dan_by_name(self, name, scope=None):
|
||||
"""根据清单名称查询清单子目表中详情信息,使用向量检索扩大查询范围"""
|
||||
try:
|
||||
conn = self.get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 获取表名和字段映射
|
||||
zimu_table = ExcelToSQLiteProcessor.qing_dan_table_names["清单子目"]
|
||||
mulu_table = ExcelToSQLiteProcessor.qing_dan_table_names["清单目录"]
|
||||
attr_table = ExcelToSQLiteProcessor.qing_dan_table_names["资源库属性"]
|
||||
field_map = ExcelToSQLiteProcessor.qing_dan_field_map
|
||||
|
||||
# 1. 先使用向量检索获取相似名称
|
||||
similar_items = self.get_similar_names_by_vector(query_text=name, vector_db=self.qing_dan_vector_db, field_map=field_map, scope=scope)
|
||||
similar_names = [item['mc'] for item in similar_items]
|
||||
|
||||
# 构建查询条件,始终包含原始名称的模糊匹配
|
||||
like_conditions = [f"zimu.{field_map['名称']} LIKE ?"]
|
||||
params = [f'%{name}%']
|
||||
|
||||
# 如果有向量检索结果,添加这些结果的模糊匹配条件
|
||||
for similar_name in similar_names:
|
||||
like_conditions.append(f"zimu.{field_map['名称']} LIKE ?")
|
||||
params.append(f'%{similar_name}%')
|
||||
|
||||
# 将所有条件用OR连接
|
||||
like_conditions_str = " OR ".join(like_conditions)
|
||||
like_conditions_str= f"({like_conditions_str})"
|
||||
|
||||
query = f"""
|
||||
SELECT
|
||||
zimu.*,
|
||||
mulu.{field_map['名称']} as mulu_name,
|
||||
attr.{field_map['发布时间']} as attr_pub_time,
|
||||
attr.{field_map['适用范围']} as attr_scope
|
||||
FROM {zimu_table} zimu
|
||||
LEFT JOIN {mulu_table} mulu ON
|
||||
zimu.{field_map['章节码']} = mulu.{field_map['章节码']} AND
|
||||
zimu.{field_map['资源库名称']} = mulu.{field_map['资源库名称']}
|
||||
LEFT JOIN {attr_table} attr ON
|
||||
zimu.{field_map['资源库名称']} = attr.{field_map['资源库名称']}
|
||||
WHERE {like_conditions_str}
|
||||
"""
|
||||
|
||||
# 如果提供了适用范围,添加过滤条件
|
||||
if scope:
|
||||
query += f" AND attr.{field_map['适用范围']} LIKE ?"
|
||||
params.append(f'%{scope}%')
|
||||
|
||||
cursor.execute(query, params)
|
||||
|
||||
# 获取结果
|
||||
results = cursor.fetchall()
|
||||
data = [dict(row) for row in results]
|
||||
|
||||
# 对结果进行排序,将全字匹配的排在前面
|
||||
data = self.sort_results_by_exact_match(data, name, field_map['名称'])
|
||||
data = data[:self.top_k]
|
||||
|
||||
conn.close()
|
||||
|
||||
if not data:
|
||||
return {"success": True, "message": "未找到匹配的清单信息", "data": []}
|
||||
|
||||
# 创建反向映射并转换字段名为中文
|
||||
_, qing_dan_reverse_map = self.create_reverse_field_map()
|
||||
|
||||
# 添加自定义字段映射
|
||||
qing_dan_reverse_map['mulu_name'] = '目录名称'
|
||||
qing_dan_reverse_map['attr_pub_time'] = '发布时间'
|
||||
qing_dan_reverse_map['attr_scope'] = '适用范围'
|
||||
|
||||
chinese_data = self.convert_field_names_to_chinese(data, qing_dan_reverse_map)
|
||||
|
||||
return {"success": True, "message": "查询成功", "data": chinese_data}
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"查询出错: {str(e)}"}
|
||||
|
||||
def query_qing_dan_by_code(self, code, scope=None):
|
||||
"""根据清单编码查询清单子目表中详情信息"""
|
||||
try:
|
||||
code = code.upper()
|
||||
conn = self.get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 获取表名和字段映射
|
||||
zimu_table = ExcelToSQLiteProcessor.qing_dan_table_names["清单子目"]
|
||||
mulu_table = ExcelToSQLiteProcessor.qing_dan_table_names["清单目录"]
|
||||
attr_table = ExcelToSQLiteProcessor.qing_dan_table_names["资源库属性"]
|
||||
field_map = ExcelToSQLiteProcessor.qing_dan_field_map
|
||||
|
||||
# 构建连表查询SQL
|
||||
query = f"""
|
||||
SELECT
|
||||
zimu.*,
|
||||
mulu.{field_map['名称']} as mulu_name,
|
||||
attr.{field_map['发布时间']} as attr_pub_time,
|
||||
attr.{field_map['适用范围']} as attr_scope
|
||||
FROM {zimu_table} zimu
|
||||
LEFT JOIN {mulu_table} mulu ON
|
||||
zimu.{field_map['章节码']} = mulu.{field_map['章节码']} AND
|
||||
zimu.{field_map['资源库名称']} = mulu.{field_map['资源库名称']}
|
||||
LEFT JOIN {attr_table} attr ON
|
||||
zimu.{field_map['资源库名称']} = attr.{field_map['资源库名称']}
|
||||
WHERE zimu.{field_map['编码']} LIKE ?
|
||||
"""
|
||||
|
||||
params = [f'%{code}%']
|
||||
|
||||
# 如果提供了适用范围,添加过滤条件
|
||||
if scope:
|
||||
query += f" AND attr.{field_map['适用范围']} LIKE ?"
|
||||
params.append(f'%{scope}%')
|
||||
|
||||
cursor.execute(query, params)
|
||||
|
||||
# 获取结果
|
||||
results = cursor.fetchall()
|
||||
data = [dict(row) for row in results]
|
||||
|
||||
# 对结果进行排序,将全字匹配的排在前面
|
||||
data = self.sort_results_by_exact_match(data, code, field_map['编码'])
|
||||
data = data[:self.top_k]
|
||||
|
||||
conn.close()
|
||||
|
||||
if not data:
|
||||
return {"success": True, "message": "未找到匹配的清单信息", "data": []}
|
||||
|
||||
# 创建反向映射并转换字段名为中文
|
||||
_, qing_dan_reverse_map = self.create_reverse_field_map()
|
||||
|
||||
# 添加自定义字段映射
|
||||
qing_dan_reverse_map['mulu_name'] = '目录名称'
|
||||
qing_dan_reverse_map['attr_pub_time'] = '发布时间'
|
||||
qing_dan_reverse_map['attr_scope'] = '适用范围'
|
||||
|
||||
chinese_data = self.convert_field_names_to_chinese(data, qing_dan_reverse_map)
|
||||
|
||||
return {"success": True, "message": "查询成功", "data": chinese_data}
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"查询出错: {str(e)}"}
|
||||
|
||||
def batch_query(self, requests:BatchQueryRequest):
|
||||
"""批量查询接口,支持向量检索"""
|
||||
dinge_results = []
|
||||
qingdan_results = []
|
||||
tracking_dict = {} # 用于跟踪已查询过的项目,避免重复
|
||||
|
||||
try:
|
||||
# 获取查询信息
|
||||
dinge_info = requests.dinge_qingdan_info.dinge_info_list
|
||||
qingdan_info = requests.dinge_qingdan_info.qingdan_info
|
||||
scope = requests.scope
|
||||
|
||||
# 处理定额编码查询
|
||||
for code in dinge_info.dinge_code_list or []:
|
||||
key = f"dinge_code_{code}_{scope}"
|
||||
if key not in tracking_dict:
|
||||
result = self.query_ding_e_by_code(code, scope)
|
||||
if result["success"] and result["data"]:
|
||||
dinge_results.extend(result["data"])
|
||||
tracking_dict[key] = True
|
||||
|
||||
# 处理定额名称查询
|
||||
for name in dinge_info.dinge_name_list or []:
|
||||
key = f"dinge_name_{name}_{scope}"
|
||||
if key not in tracking_dict:
|
||||
result = self.query_ding_e_by_name(name, scope)
|
||||
if result["success"] and result["data"]:
|
||||
dinge_results.extend(result["data"])
|
||||
tracking_dict[key] = True
|
||||
|
||||
# 处理清单编码查询
|
||||
for code in qingdan_info.qingdan_code_list or []:
|
||||
key = f"qingdan_code_{code}_{scope}"
|
||||
if key not in tracking_dict:
|
||||
result = self.query_qing_dan_by_code(code, scope)
|
||||
if result["success"] and result["data"]:
|
||||
qingdan_results.extend(result["data"])
|
||||
tracking_dict[key] = True
|
||||
|
||||
# 处理清单名称查询
|
||||
for name in qingdan_info.qingdan_name_list or []:
|
||||
key = f"qingdan_name_{name}_{scope}"
|
||||
if key not in tracking_dict:
|
||||
result = self.query_qing_dan_by_name(name, scope)
|
||||
if result["success"] and result["data"]:
|
||||
qingdan_results.extend(result["data"])
|
||||
tracking_dict[key] = True
|
||||
|
||||
# 限制返回结果数量
|
||||
dinge_results = dinge_results[:self.top_k]
|
||||
qingdan_results = qingdan_results[:self.top_k]
|
||||
|
||||
if not dinge_results and not qingdan_results:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "未找到匹配信息",
|
||||
"dinge_data": [],
|
||||
"qingdan_data": []
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "查询成功",
|
||||
"dinge_data": dinge_results,
|
||||
"qingdan_data": qingdan_results
|
||||
}
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"批量查询出错: {str(e)}", "dinge_data": [], "qingdan_data": []}
|
||||
|
||||
# 创建查询服务实例
|
||||
query_service = QingDanDingEQueryService()
|
||||
|
||||
# 1. 根据定额名称查询定额子目表中详情信息(包含资源库属性和目录信息)
|
||||
@app.get("/api/ding_e/by_name", response_model=QueryResponse)
|
||||
async def query_ding_e_by_name(
|
||||
name: str = Query(..., description="定额名称"),
|
||||
scope: Optional[str] = Query(None, description="适用范围")
|
||||
):
|
||||
result = query_service.query_ding_e_by_name(name, scope)
|
||||
if not result["success"]:
|
||||
raise HTTPException(status_code=500, detail=result["message"])
|
||||
return QueryResponse(**result)
|
||||
|
||||
# 2. 根据定额编码查询定额子目表中详情信息(包含资源库属性和目录信息)
|
||||
@app.get("/api/ding_e/by_code", response_model=QueryResponse)
|
||||
async def query_ding_e_by_code(
|
||||
code: str = Query(..., description="定额编码"),
|
||||
scope: Optional[str] = Query(None, description="适用范围")
|
||||
):
|
||||
result = query_service.query_ding_e_by_code(code, scope)
|
||||
if not result["success"]:
|
||||
raise HTTPException(status_code=500, detail=result["message"])
|
||||
return QueryResponse(**result)
|
||||
|
||||
# 3. 根据清单名称查询清单子目表中详情信息(包含资源库属性和目录信息)
|
||||
@app.get("/api/qing_dan/by_name", response_model=QueryResponse)
|
||||
async def query_qing_dan_by_name(
|
||||
name: str = Query(..., description="清单名称"),
|
||||
scope: Optional[str] = Query(None, description="适用范围")
|
||||
):
|
||||
result = query_service.query_qing_dan_by_name(name, scope)
|
||||
if not result["success"]:
|
||||
raise HTTPException(status_code=500, detail=result["message"])
|
||||
return QueryResponse(**result)
|
||||
|
||||
# 4. 根据清单编码查询清单子目表中详情信息(包含资源库属性和目录信息)
|
||||
@app.get("/api/qing_dan/by_code", response_model=QueryResponse)
|
||||
async def query_qing_dan_by_code(
|
||||
code: str = Query(..., description="清单编码"),
|
||||
scope: Optional[str] = Query(None, description="适用范围")
|
||||
):
|
||||
result = query_service.query_qing_dan_by_code(code, scope)
|
||||
if not result["success"]:
|
||||
raise HTTPException(status_code=500, detail=result["message"])
|
||||
return QueryResponse(**result)
|
||||
|
||||
# 5. 批量查询定额和清单信息
|
||||
@app.post("/api/batch_query", response_model=BatchQueryResponse)
|
||||
async def batch_query(request: BatchQueryRequest):
|
||||
result = query_service.batch_query(request)
|
||||
if not result["success"]:
|
||||
raise HTTPException(status_code=500, detail=result["message"])
|
||||
return BatchQueryResponse(**result)
|
||||
|
||||
# 启动服务器的函数
|
||||
def start_api_server():
|
||||
"""启动FastAPI服务器"""
|
||||
uvicorn.run(app, host="0.0.0.0", port=8005)
|
||||
|
||||
# 主函数
|
||||
def main():
|
||||
"""主函数"""
|
||||
print("正在启动API服务器...")
|
||||
start_api_server()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# uvicorn rag2_0.dify.query_dinge_qingdan_api:app --host 0.0.0.0 --port 8005 --workers 10
|
||||
Reference in New Issue
Block a user