1、修改api文件位置

2、意图识别继承langfuse
This commit is contained in:
2025-08-27 11:22:54 +08:00
parent 53ac47f4a5
commit c9c7f13060
12 changed files with 1385 additions and 1321 deletions
+379
View File
@@ -0,0 +1,379 @@
# from gevent import monkey
# monkey.patch_all()
import os
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Dict, List, Any, Optional
import asyncio
import threading
import queue
import sqlite3
from contextlib import closing
from dotenv import load_dotenv
import json
import time
import datetime
import logging
# 加载环境变量
load_dotenv()
def main(query: str) -> dict:
query = query.strip()
escaped_query = json.dumps(query, ensure_ascii=False)
return {
"format_query": escaped_query,
}
import sys
sys.path.append(os.getcwd())
from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval
# 定义数据库路径
DATA_DIR = os.path.join(os.getcwd(), "data")
DB_DIR = os.path.join(DATA_DIR, "db")
DB_FILE = os.path.join(DB_DIR, "answer_logs.db")
# 创建异步日志队列和工作线程
log_queue = queue.Queue()
worker_thread = None
db_lock = threading.Lock() # 数据库操作锁
# 初始化数据库
def init_database():
os.makedirs(DB_DIR, exist_ok=True)
with db_lock:
with closing(sqlite3.connect(DB_FILE)) as conn:
cursor = conn.cursor()
# 创建查询类型表
cursor.execute('''
CREATE TABLE IF NOT EXISTS query_types (
id INTEGER PRIMARY KEY AUTOINCREMENT,
query_type TEXT NOT NULL,
workflow_run_id TEXT NOT NULL,
timestamp TEXT NOT NULL
)
''')
# 创建点踩原因表
cursor.execute('''
CREATE TABLE IF NOT EXISTS dislike_reasons (
id INTEGER PRIMARY KEY AUTOINCREMENT,
dislike_reason TEXT NOT NULL,
workflow_run_id TEXT NOT NULL,
timestamp TEXT NOT NULL
)
''')
conn.commit()
logger.info("数据库初始化完成")
# 后台工作线程函数
def log_worker():
while True:
try:
# 从队列获取数据,设置超时以允许线程退出
data = log_queue.get(timeout=1.0)
if data is None: # 接收到退出信号
# 处理剩余数据后再退出
while not log_queue.empty():
data = log_queue.get_nowait()
if data is None: # 跳过额外的停止信号
continue
process_log_data(data)
break
process_log_data(data)
log_queue.task_done()
except queue.Empty:
continue
except Exception as e:
logger.error(f"保存查询数据时出错: {str(e)}", exc_info=True)
# 提取数据处理逻辑到单独函数
def process_log_data(data):
try:
with db_lock:
with closing(sqlite3.connect(DB_FILE)) as conn:
cursor = conn.cursor()
if "dislike_reason" in data:
# 保存点踩原因
cursor.execute(
"INSERT INTO dislike_reasons (dislike_reason, workflow_run_id, timestamp) VALUES (?, ?, ?)",
(data["dislike_reason"], data["workflow_run_id"], data["timestamp"])
)
table_name = "dislike_reasons"
else:
# 保存查询类型
cursor.execute(
"INSERT INTO query_types (query_type, workflow_run_id, timestamp) VALUES (?, ?, ?)",
(data["query_type"], data["workflow_run_id"], data["timestamp"])
)
table_name = "query_types"
conn.commit()
logger.info(f"成功保存数据到表: {table_name}")
except Exception as e:
logger.error(f"处理日志数据时出错: {str(e)}", exc_info=True)
# 创建数据目录
os.makedirs(DATA_DIR, exist_ok=True)
# 配置日志 - 同时输出到控制台和文件
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# 创建日志目录
LOG_DIR = os.path.join(DATA_DIR, "logs")
os.makedirs(LOG_DIR, exist_ok=True)
# 创建控制台处理器
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
# 创建文件处理器
file_handler = logging.FileHandler(
os.path.join(LOG_DIR, "answer_type_service.log"),
encoding='utf-8'
)
file_handler.setLevel(logging.INFO)
# 创建日志格式
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
console_handler.setFormatter(formatter)
file_handler.setFormatter(formatter)
# 添加处理器到日志器
logger.addHandler(console_handler)
logger.addHandler(file_handler)
# 设置其他库的日志级别
logging.getLogger('httpx').setLevel(logging.WARNING)
logging.getLogger('openai').setLevel(logging.WARNING)
# 定义请求模型
class AnswerTypeRequest(BaseModel):
query: str
query_type: str
# 创建FastAPI应用
app = FastAPI(
title="提问数据类型",
description="收集用户提问数据类型",
version="1.0"
)
# 添加CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 应用启动事件
@app.on_event("startup")
async def startup_event():
global worker_thread
# 初始化数据库
init_database()
# 启动后台工作线程
worker_thread = threading.Thread(target=log_worker, daemon=True)
worker_thread.start()
logger.info("后台日志工作线程已启动")
# 应用关闭事件
@app.on_event("shutdown")
def shutdown_event():
global worker_thread
if worker_thread:
# 发送退出信号
log_queue.put(None)
# 等待工作线程处理剩余数据
worker_thread.join(timeout=10.0)
if worker_thread.is_alive():
logger.warning("工作线程未在超时时间内退出")
else:
logger.info("后台日志工作线程已停止")
# 添加健康检查端点
@app.get("/health", summary="健康检查")
async def health_check():
return {"status": "ok"}
@app.get("/query_type", summary="异步检索API")
async def query_type(query_type: str, workflow_run_id:str):
try:
# 记录请求
logger.info(f"接收到请求: 类型: {query_type}, workflow_run_id: {workflow_run_id}")
# 准备数据
timestamp = datetime.datetime.now().isoformat()
query_data = {
"query_type": query_type,
"timestamp": timestamp,
"workflow_run_id": workflow_run_id
}
# 将数据放入队列
try:
log_queue.put(query_data)
success = True
logger.info(f"查询数据已加入队列,当前队列大小: {log_queue.qsize()}")
except Exception as e:
success = False
logger.error(f"加入队列时出错: {str(e)}", exc_info=True)
# 返回响应
content = f"<strong>问题类型</strong>: {query_type}<br><strong>操作是否成功</strong>: {'成功' if success else '失败'}"
return HTMLResponse(content=content)
except Exception as e:
logger.error(f"处理请求时出错: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"处理请求时出错: {str(e)}")
@app.get("/dislike_reason", summary="记录点踩原因")
async def dislike_reason(reason: str, workflow_run_id: str):
try:
# 记录请求
logger.info(f"接收到点踩原因: {reason}, workflow_run_id: {workflow_run_id}")
# 准备数据
timestamp = datetime.datetime.now().isoformat()
dislike_data = {
"dislike_reason": reason,
"timestamp": timestamp,
"workflow_run_id": workflow_run_id
}
# 将数据放入队列
try:
log_queue.put(dislike_data)
success = True
logger.info(f"点踩原因数据已加入队列,当前队列大小: {log_queue.qsize()}")
except Exception as e:
success = False
logger.error(f"加入队列时出错: {str(e)}", exc_info=True)
# 返回响应
content = f"<strong>点踩原因</strong>: {reason}<br><strong>操作是否成功</strong>: {'成功' if success else '失败'}"
return HTMLResponse(content=content)
except Exception as e:
logger.error(f"处理点踩原因请求时出错: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"处理点踩原因请求时出错: {str(e)}")
# 添加数据查询API
@app.get("/stats", summary="查询统计数据")
async def get_stats():
try:
with db_lock:
with closing(sqlite3.connect(DB_FILE)) as conn:
conn.row_factory = sqlite3.Row # 启用字典行工厂
cursor = conn.cursor()
# 查询类型统计
cursor.execute("SELECT COUNT(*) as count FROM query_types")
query_count = cursor.fetchone()['count']
# 点踩原因统计
cursor.execute("SELECT COUNT(*) as count FROM dislike_reasons")
dislike_count = cursor.fetchone()['count']
# 最近5条查询记录
cursor.execute("""
SELECT query_type, workflow_run_id, timestamp
FROM query_types
ORDER BY id DESC LIMIT 5
""")
recent_queries = [dict(row) for row in cursor.fetchall()]
# 最近5条点踩记录
cursor.execute("""
SELECT dislike_reason, workflow_run_id, timestamp
FROM dislike_reasons
ORDER BY id DESC LIMIT 5
""")
recent_dislikes = [dict(row) for row in cursor.fetchall()]
return {
"statistics": {
"total_queries": query_count,
"total_dislikes": dislike_count
},
"recent_data": {
"queries": recent_queries,
"dislikes": recent_dislikes
}
}
except Exception as e:
logger.error(f"获取统计数据时出错: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"获取统计数据时出错: {str(e)}")
@app.get("/query_by_workflow_id", summary="根据工作流ID获取查询类型数据")
async def get_query_by_workflow_id(workflow_run_id: str):
try:
with db_lock:
with closing(sqlite3.connect(DB_FILE)) as conn:
conn.row_factory = sqlite3.Row # 启用字典行工厂
cursor = conn.cursor()
# 查询指定工作流ID的查询类型数据
cursor.execute("""
SELECT id, query_type, workflow_run_id, timestamp
FROM query_types
WHERE workflow_run_id = ?
ORDER BY id DESC
""", (workflow_run_id,))
results = [dict(row) for row in cursor.fetchall()]
if not results:
return {"message": f"未找到工作流ID为 {workflow_run_id} 的查询类型数据", "data": []}
return {"data": results}
except Exception as e:
logger.error(f"根据工作流ID获取查询类型数据时出错: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"根据工作流ID获取查询类型数据时出错: {str(e)}")
@app.get("/dislike_by_workflow_id", summary="根据工作流ID获取点踩原因数据")
async def get_dislike_by_workflow_id(workflow_run_id: str):
try:
with db_lock:
with closing(sqlite3.connect(DB_FILE)) as conn:
conn.row_factory = sqlite3.Row # 启用字典行工厂
cursor = conn.cursor()
# 查询指定工作流ID的点踩原因数据
cursor.execute("""
SELECT id, dislike_reason, workflow_run_id, timestamp
FROM dislike_reasons
WHERE workflow_run_id = ?
ORDER BY id DESC
""", (workflow_run_id,))
results = [dict(row) for row in cursor.fetchall()]
if not results:
return {"message": f"未找到工作流ID为 {workflow_run_id} 的点踩原因数据", "data": []}
return {"data": results}
except Exception as e:
logger.error(f"根据工作流ID获取点踩原因数据时出错: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"根据工作流ID获取点踩原因数据时出错: {str(e)}")
if __name__ == "__main__":
# 使用Uvicorn运行FastAPI应用
import uvicorn
uvicorn.run("rag2_0.dify.AnswerType:app", host="0.0.0.0", port=8003, reload=False, workers=1, log_level="info")
# 生产环境可以使用以下命令启动:
# uvicorn rag2_0.dify.AnswerType:app --host 0.0.0.0 --port 8003 --workers 1
+132
View File
@@ -0,0 +1,132 @@
# from gevent import monkey
# monkey.patch_all()
import os
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, ConfigDict
from typing import Dict, List, Any, Optional
import asyncio
from dotenv import load_dotenv
import json
import time
import datetime
import logging
# 加载环境变量
load_dotenv()
import sys
sys.path.append(os.getcwd())
from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval
# 确保日志目录存在
os.makedirs('data/logs', exist_ok=True)
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - [%(thread)d] - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(),
logging.FileHandler(f'data/logs/dify_query_retrieval_{datetime.datetime.now().strftime("%Y%m%d")}.log', encoding='utf-8')
]
)
logging.getLogger('httpx').setLevel(logging.WARNING)
logging.getLogger('openai').setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
# 定义请求模型
class RetrieveRequest(BaseModel):
original_query: str
query_list: str
data_set_list: str
query_expand_dict: dict | str = Field(default="{}")
topk: int = Field(default=4)
metadata_filtering_conditions : dict = Field(default={})
# 创建FastAPI应用
app = FastAPI(
title="Dify查询检索服务",
description="基于Dify的异步查询检索服务",
version="1.0"
)
# 添加CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 全局变量存储DifyQueryRetrieval实例
dify_query_retrieval = None
# 应用启动事件
@app.on_event("startup")
async def startup_event():
global dify_query_retrieval
# 初始化DifyQueryRetrieval实例
dify_query_retrieval = DifyQueryRetrieval(dify_dataset_key=os.getenv("DIFY_DATASET_KEY"), dify_base_url=os.getenv("DIFY_BSAE_URL"))
logger.info("DifyQueryRetrieval初始化完成")
# 添加健康检查端点
@app.get("/health", summary="健康检查")
async def health_check():
return {"status": "ok"}
@app.post("/retrieve", summary="异步检索API")
async def retrieve(request: RetrieveRequest):
"""
异步检索API
Args:
request: 包含原始查询、查询列表和数据集列表的请求对象
Returns:
检索结果
"""
try:
# 解析查询列表和数据集列表
query_list = request.query_list.split("<sub_query>")
data_set_list = request.data_set_list.split("<dataset>")
if isinstance(request.query_expand_dict, str):
query_expand_dict = json.loads(request.query_expand_dict)
else:
query_expand_dict = request.query_expand_dict
# 调用异步检索方法
start_time = time.time()
results = await dify_query_retrieval.retrieve_api_async(
request.original_query,
query_list,
data_set_list,
query_expand_dict=query_expand_dict,
top_k=request.topk,
metadata_filtering_conditions=request.metadata_filtering_conditions
)
end_time = time.time()
logger.info(f"异步检索总耗时: {end_time - start_time:.2f}")
return results
except Exception as e:
logger.error(f"异步检索出错: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
# 使用Uvicorn运行FastAPI应用
import uvicorn
uvicorn.run("rag2_0.api.DifyQueryRetrieval_api:app", host="0.0.0.0", port=9002, reload=False, workers=1, log_level="info")
# # 使用uvicorn启动服务
# import uvicorn
# uvicorn.run(
# "rag2_0.api.DifyQueryRetrieval_api:app",
# host="0.0.0.0",
# port=8001,
# reload=False, # 开发环境启用热重载
# workers=1 # 生产环境可以增加worker数量
# )
# 生产环境可以使用以下命令启动:
# uvicorn rag2_0.dify.DifyQueryRetrieval_api:app --host 0.0.0.0 --port 8002 --workers 10
+185
View File
@@ -0,0 +1,185 @@
import os
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, ConfigDict
from typing import Dict, List, Any, Optional
import asyncio
from dotenv import load_dotenv
import json
import time
import datetime
import logging
# 加载环境变量
load_dotenv()
import sys
sys.path.append(os.getcwd())
from rag2_0.intent_recognition import AsyncIntentRecognizer
# 确保日志目录存在
os.makedirs('data/logs', exist_ok=True)
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(process)d - %(thread)d - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(),
logging.FileHandler(f'data/logs/intent_recognition_{datetime.datetime.now().strftime("%Y%m%d")}.log', encoding='utf-8')
]
)
logging.getLogger('httpx').setLevel(logging.WARNING)
logging.getLogger('openai').setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
# 定义请求模型
class IntentRecognizeRequest(BaseModel):
query: str
enable_query_expansion: bool = True
conversation_context: Dict = None
chat_history: Optional[List] = None
previous_slots: str | Dict = None
# 定义槽位填充响应模型
class SlotFillingResponse(BaseModel):
is_complete: bool = False
missing_slots: Dict[str, Any] = Field(default_factory=dict)
filled_data: Dict[str, Any] = Field(default_factory=dict)
class QueryExpandResponse(BaseModel):
# 必须包含的all字段
all: List[str] = Field(default_factory=list)
model_config = ConfigDict(extra='allow')
# 定义响应模型
class IntentRecognizeResponse(BaseModel):
source_query: str
source_query_keys: List[str]
vertical_classification: str
sub_classification: str
rewrite_query: str
keywords: List[Dict[str, str]] = Field(default_factory=list)
has_slot_filling: bool = False
slot_filling: SlotFillingResponse = Field(default_factory=SlotFillingResponse)
query_expand: QueryExpandResponse = Field(default_factory=QueryExpandResponse)
dinge_qingdan_info: Dict[str, Any] = Field(default_factory=dict)
# 创建FastAPI应用
app = FastAPI(
title="意图识别服务",
description="基于LLM的意图识别和问题改写服务",
version="2.0"
)
# 添加CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 全局变量存储AsyncIntentRecognizer实例
_instance = None
# 应用启动事件
@app.on_event("startup")
async def startup_event():
global _instance
_instance = await AsyncIntentRecognizer.create()
logger.info("AsyncIntentRecognizer初始化完成")
@app.post("/intent_recognize1")
async def intent_recognize(request: Request):
data = await request.json()
print(data)
return {"message": "success"}
@app.post("/intent_recognize", response_model=IntentRecognizeResponse, summary="意图识别", description="识别用户查询的意图并进行问题改写")
async def intent_recognize(request: IntentRecognizeRequest):
try:
if not request.query:
raise HTTPException(status_code=400, detail="缺少query参数")
enable_query_expansion = request.enable_query_expansion
start_time = time.time()
current_softname = request.conversation_context.get("current_softname", "")
result = await _instance.process_query_async(
query=request.query,
conversation_context=request.conversation_context,
chat_history=request.chat_history,
previous_slots=request.previous_slots,
use_jieba=True,
enable_query_expansion=enable_query_expansion,
cur_soft_name=current_softname
)
dinge_qingdan_info = result["dinge_qingdan_info"]
end_time = time.time()
current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S %z")
logger.info(f"意图识别耗时: {end_time - start_time:.2f}")
# 提取分类信息
classification = result["classification"]
# 提取关键词信息
keywords = result["keywords"]
keywords_list = []
if keywords and keywords.get("terms"):
for term in keywords["terms"]:
keywords_list.append({
"名称": term["name"]
})
# 提取槽位填充信息
slot_filling = result.get("slot_filling", {})
# 构建响应
response = IntentRecognizeResponse(
source_query=request.query,
source_query_keys=result["query_keys"],
vertical_classification=classification["vertical_classification"],
sub_classification=classification["sub_classification"],
rewrite_query=result["rewrite"]["rewrite"],
keywords=keywords_list,
has_slot_filling=len(slot_filling) != 0,
slot_filling=SlotFillingResponse(
is_complete=slot_filling.get("is_complete", False),
missing_slots=slot_filling.get("missing_slots", {}),
filled_data=slot_filling.get("filled_data", {})
),
query_expand=QueryExpandResponse(**result["query_expand"]),
dinge_qingdan_info=dinge_qingdan_info
)
return response
except HTTPException as e:
raise e
except Exception as e:
logger.error(f"意图识别出错: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
# 添加健康检查端点
@app.get("/health", summary="健康检查")
async def health_check():
return {"status": "ok"}
if __name__ == "__main__":
# 使用uvicorn启动服务
import uvicorn
uvicorn.run(
"rag2_0.dify.intent_recognition_api:app",
host="0.0.0.0",
port=9001,
reload=True, # 开发环境启用热重载
workers=1 # 生产环境可以增加worker数量
)
# 生产环境可以使用以下命令启动:
# uvicorn rag2_0.dify.intent_recognition_api:app --host 0.0.0.0 --port 8001 --workers 10
+567
View File
@@ -0,0 +1,567 @@
# 添加FastAPI相关导入
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
import uvicorn
import sqlite3
import sys
import os
# 导入ExcelToSQLiteProcessor类
sys.path.append(os.getcwd())
from rag2_0.demo.create_qingdan_dinge_database import ExcelToSQLiteProcessor
# 导入向量检索相关类
from rag2_0.tool.ModelTool import XinferenceEmbeddings
from langchain_community.vectorstores import SQLiteVSS
from rag2_0.tool.APIKeyManager import APIKeyManager
# 创建FastAPI应用
app = FastAPI(title="清单定额库查询API", description="提供清单和定额信息查询接口")
TOP_K = 100
# 响应模型
class QueryResponse(BaseModel):
success: bool
message: str
data: Optional[List[Dict[str, Any]]] = None
# 批量查询请求模型
class DingEInfoList(BaseModel):
dinge_code_list: List[str] = []
dinge_name_list: List[str] = []
class QingDanInfo(BaseModel):
qingdan_code_list: List[str] = []
qingdan_name_list: List[str] = []
class DingEQingDanInfo(BaseModel):
dinge_info_list: DingEInfoList
qingdan_info: QingDanInfo
class BatchQueryRequest(BaseModel):
dinge_qingdan_info: DingEQingDanInfo
scope: Optional[str] = Query(None, description="适用范围")
# 批量查询响应模型
class BatchQueryResponse(BaseModel):
success: bool
message: str
dinge_data: Optional[List[Dict[str, Any]]] = None
qingdan_data: Optional[List[Dict[str, Any]]] = None
# 封装查询数据的相关代码
class QingDanDingEQueryService:
def __init__(self, db_path="/data/QueryRewrite/data/db/qingdan_ding_e_ku.db"):
self.db_path = db_path
self.top_k = TOP_K
# 初始化向量检索相关组件
self.embedding_function = XinferenceEmbeddings(api_key="")
# 初始化向量数据库连接
self.ding_e_vector_db = SQLiteVSS(
table="embeding_ding_e_zimu_name",
connection=None,
embedding=self.embedding_function,
db_file=self.db_path
)
self.qing_dan_vector_db = SQLiteVSS(
table="embeding_qd_zimu_name",
connection=None,
embedding=self.embedding_function,
db_file=self.db_path
)
def get_similar_names_by_vector(self, query_text:str, vector_db:SQLiteVSS, field_map:dict, top_k:int=3, scope:str=None):
"""使用向量检索获取相似名称"""
try:
# 使用向量数据库进行相似性搜索
results = vector_db.similarity_search_with_score(query=query_text, k=30)
# 提取结果中的元数据
similar_items = []
for doc, score in results:
if scope and scope not in doc.metadata[field_map["适用范围"]]:
continue
metadata = doc.metadata
# 添加相似度分数
metadata['similarity_score'] = float(score)
similar_items.append(metadata)
# 按相似度分数排序,分数高的排前面
similar_items.sort(key=lambda x: x['similarity_score'])
return similar_items[:top_k]
except Exception as e:
print(f"向量检索出错: {str(e)}")
return []
def get_db_connection(self):
"""获取数据库连接"""
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row # 设置行工厂,使结果可以通过列名访问
return conn
def create_reverse_field_map(self):
"""创建字段反向映射(数据库字段名 -> 中文字段名)"""
# 定额库字段反向映射
ding_e_reverse_map = {v: k for k, v in ExcelToSQLiteProcessor.ding_e_field_map.items()}
# 清单库字段反向映射
qing_dan_reverse_map = {v: k for k, v in ExcelToSQLiteProcessor.qing_dan_field_map.items()}
return ding_e_reverse_map, qing_dan_reverse_map
def convert_field_names_to_chinese(self, data_list, reverse_map):
"""转换字段名称为中文"""
result = []
for item in data_list:
chinese_item = {}
for field_name, value in item.items():
# 如果字段名在反向映射中存在,则使用中文名称
chinese_field_name = reverse_map.get(field_name, field_name)
chinese_item[chinese_field_name] = value
result.append(chinese_item)
return result
def sort_results_by_exact_match(self, data_list, search_term, field_name):
"""对查询结果进行排序,将完全匹配的结果排在前面"""
exact_matches = []
partial_matches = []
for item in data_list:
# 检查是否为完全匹配
if search_term.upper() == str(item[field_name]).upper():
exact_matches.append(item)
else:
partial_matches.append(item)
# 合并结果,完全匹配的排在前面
return exact_matches + partial_matches
def query_ding_e_by_name(self, name, scope=None):
"""根据定额名称查询定额子目表中详情信息,使用向量检索扩大查询范围"""
try:
conn = self.get_db_connection()
cursor = conn.cursor()
# 获取表名和字段映射
zimu_table = ExcelToSQLiteProcessor.ding_e_table_names["定额子目"]
mulu_table = ExcelToSQLiteProcessor.ding_e_table_names["定额目录"]
attr_table = ExcelToSQLiteProcessor.ding_e_table_names["定额资源库属性"]
field_map = ExcelToSQLiteProcessor.ding_e_field_map
# 1. 先使用向量检索获取相似名称
similar_items = self.get_similar_names_by_vector(query_text=name,
vector_db=self.ding_e_vector_db,
field_map=field_map,
scope=scope)
similar_names = [item[field_map['名称']] for item in similar_items]
# 构建查询条件,始终包含原始名称的模糊匹配
like_conditions = [f"zimu.{field_map['名称']} LIKE ?"]
params = [f'%{name}%']
# 如果有向量检索结果,添加这些结果的模糊匹配条件
for similar_name in similar_names:
like_conditions.append(f"zimu.{field_map['名称']} LIKE ?")
params.append(f'%{similar_name}%')
# 将所有条件用OR连接
like_conditions_str = " OR ".join(like_conditions)
like_conditions_str= f"({like_conditions_str})"
query = f"""
SELECT
zimu.*,
mulu.{field_map['名称']} as mulu_name,
attr.{field_map['发布时间']} as attr_pub_time,
attr.{field_map['适用范围']} as attr_scope
FROM {zimu_table} zimu
LEFT JOIN {mulu_table} mulu ON
zimu.{field_map['章节码']} = mulu.{field_map['章节码']} AND
zimu.{field_map['资源库名称']} = mulu.{field_map['资源库名称']}
LEFT JOIN {attr_table} attr ON
zimu.{field_map['资源库名称']} = attr.{field_map['资源库名称']}
WHERE {like_conditions_str}
"""
# 如果提供了适用范围,添加过滤条件
if scope:
query += f" AND attr.{field_map['适用范围']} LIKE ?"
params.append(f'%{scope}%')
cursor.execute(query, params)
# 获取结果
results = cursor.fetchall()
data = [dict(row) for row in results]
# 对结果进行排序,将全字匹配的排在前面
data = self.sort_results_by_exact_match(data, name, field_map['名称'])
data = data[:self.top_k]
conn.close()
if not data:
return {"success": True, "message": "未找到匹配的定额信息", "data": []}
# 创建反向映射并转换字段名为中文
ding_e_reverse_map, _ = self.create_reverse_field_map()
# 添加自定义字段映射
ding_e_reverse_map['mulu_name'] = '目录名称'
ding_e_reverse_map['attr_pub_time'] = '发布时间'
ding_e_reverse_map['attr_scope'] = '适用范围'
chinese_data = self.convert_field_names_to_chinese(data, ding_e_reverse_map)
return {"success": True, "message": "查询成功", "data": chinese_data}
except Exception as e:
return {"success": False, "message": f"查询出错: {str(e)}"}
def query_ding_e_by_code(self, code, scope=None):
"""根据定额编码查询定额子目表中详情信息"""
try:
code = code.upper()
conn = self.get_db_connection()
cursor = conn.cursor()
# 获取表名和字段映射
zimu_table = ExcelToSQLiteProcessor.ding_e_table_names["定额子目"]
mulu_table = ExcelToSQLiteProcessor.ding_e_table_names["定额目录"]
attr_table = ExcelToSQLiteProcessor.ding_e_table_names["定额资源库属性"]
field_map = ExcelToSQLiteProcessor.ding_e_field_map
# 构建连表查询SQL
query = f"""
SELECT
zimu.*,
mulu.{field_map['名称']} as mulu_name,
attr.{field_map['发布时间']} as attr_pub_time,
attr.{field_map['适用范围']} as attr_scope
FROM {zimu_table} zimu
LEFT JOIN {mulu_table} mulu ON
zimu.{field_map['章节码']} = mulu.{field_map['章节码']} AND
zimu.{field_map['资源库名称']} = mulu.{field_map['资源库名称']}
LEFT JOIN {attr_table} attr ON
zimu.{field_map['资源库名称']} = attr.{field_map['资源库名称']}
WHERE zimu.{field_map['编码']} LIKE ?
"""
params = [f'%{code}%']
# 如果提供了适用范围,添加过滤条件
if scope:
query += f" AND attr.{field_map['适用范围']} LIKE ?"
params.append(f'%{scope}%')
cursor.execute(query, params)
# 获取结果
results = cursor.fetchall()
data = [dict(row) for row in results]
# 对结果进行排序,将全字匹配的排在前面
data = self.sort_results_by_exact_match(data, code, field_map['编码'])
data = data[:self.top_k]
conn.close()
if not data:
return {"success": True, "message": "未找到匹配的定额信息", "data": []}
# 创建反向映射并转换字段名为中文
ding_e_reverse_map, _ = self.create_reverse_field_map()
# 添加自定义字段映射
ding_e_reverse_map['mulu_name'] = '目录名称'
ding_e_reverse_map['attr_pub_time'] = '发布时间'
ding_e_reverse_map['attr_scope'] = '适用范围'
chinese_data = self.convert_field_names_to_chinese(data, ding_e_reverse_map)
return {"success": True, "message": "查询成功", "data": chinese_data}
except Exception as e:
return {"success": False, "message": f"查询出错: {str(e)}"}
def query_qing_dan_by_name(self, name, scope=None):
"""根据清单名称查询清单子目表中详情信息,使用向量检索扩大查询范围"""
try:
conn = self.get_db_connection()
cursor = conn.cursor()
# 获取表名和字段映射
zimu_table = ExcelToSQLiteProcessor.qing_dan_table_names["清单子目"]
mulu_table = ExcelToSQLiteProcessor.qing_dan_table_names["清单目录"]
attr_table = ExcelToSQLiteProcessor.qing_dan_table_names["资源库属性"]
field_map = ExcelToSQLiteProcessor.qing_dan_field_map
# 1. 先使用向量检索获取相似名称
similar_items = self.get_similar_names_by_vector(query_text=name, vector_db=self.qing_dan_vector_db, field_map=field_map, scope=scope)
similar_names = [item['mc'] for item in similar_items]
# 构建查询条件,始终包含原始名称的模糊匹配
like_conditions = [f"zimu.{field_map['名称']} LIKE ?"]
params = [f'%{name}%']
# 如果有向量检索结果,添加这些结果的模糊匹配条件
for similar_name in similar_names:
like_conditions.append(f"zimu.{field_map['名称']} LIKE ?")
params.append(f'%{similar_name}%')
# 将所有条件用OR连接
like_conditions_str = " OR ".join(like_conditions)
like_conditions_str= f"({like_conditions_str})"
query = f"""
SELECT
zimu.*,
mulu.{field_map['名称']} as mulu_name,
attr.{field_map['发布时间']} as attr_pub_time,
attr.{field_map['适用范围']} as attr_scope
FROM {zimu_table} zimu
LEFT JOIN {mulu_table} mulu ON
zimu.{field_map['章节码']} = mulu.{field_map['章节码']} AND
zimu.{field_map['资源库名称']} = mulu.{field_map['资源库名称']}
LEFT JOIN {attr_table} attr ON
zimu.{field_map['资源库名称']} = attr.{field_map['资源库名称']}
WHERE {like_conditions_str}
"""
# 如果提供了适用范围,添加过滤条件
if scope:
query += f" AND attr.{field_map['适用范围']} LIKE ?"
params.append(f'%{scope}%')
cursor.execute(query, params)
# 获取结果
results = cursor.fetchall()
data = [dict(row) for row in results]
# 对结果进行排序,将全字匹配的排在前面
data = self.sort_results_by_exact_match(data, name, field_map['名称'])
data = data[:self.top_k]
conn.close()
if not data:
return {"success": True, "message": "未找到匹配的清单信息", "data": []}
# 创建反向映射并转换字段名为中文
_, qing_dan_reverse_map = self.create_reverse_field_map()
# 添加自定义字段映射
qing_dan_reverse_map['mulu_name'] = '目录名称'
qing_dan_reverse_map['attr_pub_time'] = '发布时间'
qing_dan_reverse_map['attr_scope'] = '适用范围'
chinese_data = self.convert_field_names_to_chinese(data, qing_dan_reverse_map)
return {"success": True, "message": "查询成功", "data": chinese_data}
except Exception as e:
return {"success": False, "message": f"查询出错: {str(e)}"}
def query_qing_dan_by_code(self, code, scope=None):
"""根据清单编码查询清单子目表中详情信息"""
try:
code = code.upper()
conn = self.get_db_connection()
cursor = conn.cursor()
# 获取表名和字段映射
zimu_table = ExcelToSQLiteProcessor.qing_dan_table_names["清单子目"]
mulu_table = ExcelToSQLiteProcessor.qing_dan_table_names["清单目录"]
attr_table = ExcelToSQLiteProcessor.qing_dan_table_names["资源库属性"]
field_map = ExcelToSQLiteProcessor.qing_dan_field_map
# 构建连表查询SQL
query = f"""
SELECT
zimu.*,
mulu.{field_map['名称']} as mulu_name,
attr.{field_map['发布时间']} as attr_pub_time,
attr.{field_map['适用范围']} as attr_scope
FROM {zimu_table} zimu
LEFT JOIN {mulu_table} mulu ON
zimu.{field_map['章节码']} = mulu.{field_map['章节码']} AND
zimu.{field_map['资源库名称']} = mulu.{field_map['资源库名称']}
LEFT JOIN {attr_table} attr ON
zimu.{field_map['资源库名称']} = attr.{field_map['资源库名称']}
WHERE zimu.{field_map['编码']} LIKE ?
"""
params = [f'%{code}%']
# 如果提供了适用范围,添加过滤条件
if scope:
query += f" AND attr.{field_map['适用范围']} LIKE ?"
params.append(f'%{scope}%')
cursor.execute(query, params)
# 获取结果
results = cursor.fetchall()
data = [dict(row) for row in results]
# 对结果进行排序,将全字匹配的排在前面
data = self.sort_results_by_exact_match(data, code, field_map['编码'])
data = data[:self.top_k]
conn.close()
if not data:
return {"success": True, "message": "未找到匹配的清单信息", "data": []}
# 创建反向映射并转换字段名为中文
_, qing_dan_reverse_map = self.create_reverse_field_map()
# 添加自定义字段映射
qing_dan_reverse_map['mulu_name'] = '目录名称'
qing_dan_reverse_map['attr_pub_time'] = '发布时间'
qing_dan_reverse_map['attr_scope'] = '适用范围'
chinese_data = self.convert_field_names_to_chinese(data, qing_dan_reverse_map)
return {"success": True, "message": "查询成功", "data": chinese_data}
except Exception as e:
return {"success": False, "message": f"查询出错: {str(e)}"}
def batch_query(self, requests:BatchQueryRequest):
"""批量查询接口,支持向量检索"""
dinge_results = []
qingdan_results = []
tracking_dict = {} # 用于跟踪已查询过的项目,避免重复
try:
# 获取查询信息
dinge_info = requests.dinge_qingdan_info.dinge_info_list
qingdan_info = requests.dinge_qingdan_info.qingdan_info
scope = requests.scope
# 处理定额编码查询
for code in dinge_info.dinge_code_list or []:
key = f"dinge_code_{code}_{scope}"
if key not in tracking_dict:
result = self.query_ding_e_by_code(code, scope)
if result["success"] and result["data"]:
dinge_results.extend(result["data"])
tracking_dict[key] = True
# 处理定额名称查询
for name in dinge_info.dinge_name_list or []:
key = f"dinge_name_{name}_{scope}"
if key not in tracking_dict:
result = self.query_ding_e_by_name(name, scope)
if result["success"] and result["data"]:
dinge_results.extend(result["data"])
tracking_dict[key] = True
# 处理清单编码查询
for code in qingdan_info.qingdan_code_list or []:
key = f"qingdan_code_{code}_{scope}"
if key not in tracking_dict:
result = self.query_qing_dan_by_code(code, scope)
if result["success"] and result["data"]:
qingdan_results.extend(result["data"])
tracking_dict[key] = True
# 处理清单名称查询
for name in qingdan_info.qingdan_name_list or []:
key = f"qingdan_name_{name}_{scope}"
if key not in tracking_dict:
result = self.query_qing_dan_by_name(name, scope)
if result["success"] and result["data"]:
qingdan_results.extend(result["data"])
tracking_dict[key] = True
# 限制返回结果数量
dinge_results = dinge_results[:self.top_k]
qingdan_results = qingdan_results[:self.top_k]
if not dinge_results and not qingdan_results:
return {
"success": True,
"message": "未找到匹配信息",
"dinge_data": [],
"qingdan_data": []
}
return {
"success": True,
"message": "查询成功",
"dinge_data": dinge_results,
"qingdan_data": qingdan_results
}
except Exception as e:
return {"success": False, "message": f"批量查询出错: {str(e)}", "dinge_data": [], "qingdan_data": []}
# 创建查询服务实例
query_service = QingDanDingEQueryService()
# 1. 根据定额名称查询定额子目表中详情信息(包含资源库属性和目录信息)
@app.get("/api/ding_e/by_name", response_model=QueryResponse)
async def query_ding_e_by_name(
name: str = Query(..., description="定额名称"),
scope: Optional[str] = Query(None, description="适用范围")
):
result = query_service.query_ding_e_by_name(name, scope)
if not result["success"]:
raise HTTPException(status_code=500, detail=result["message"])
return QueryResponse(**result)
# 2. 根据定额编码查询定额子目表中详情信息(包含资源库属性和目录信息)
@app.get("/api/ding_e/by_code", response_model=QueryResponse)
async def query_ding_e_by_code(
code: str = Query(..., description="定额编码"),
scope: Optional[str] = Query(None, description="适用范围")
):
result = query_service.query_ding_e_by_code(code, scope)
if not result["success"]:
raise HTTPException(status_code=500, detail=result["message"])
return QueryResponse(**result)
# 3. 根据清单名称查询清单子目表中详情信息(包含资源库属性和目录信息)
@app.get("/api/qing_dan/by_name", response_model=QueryResponse)
async def query_qing_dan_by_name(
name: str = Query(..., description="清单名称"),
scope: Optional[str] = Query(None, description="适用范围")
):
result = query_service.query_qing_dan_by_name(name, scope)
if not result["success"]:
raise HTTPException(status_code=500, detail=result["message"])
return QueryResponse(**result)
# 4. 根据清单编码查询清单子目表中详情信息(包含资源库属性和目录信息)
@app.get("/api/qing_dan/by_code", response_model=QueryResponse)
async def query_qing_dan_by_code(
code: str = Query(..., description="清单编码"),
scope: Optional[str] = Query(None, description="适用范围")
):
result = query_service.query_qing_dan_by_code(code, scope)
if not result["success"]:
raise HTTPException(status_code=500, detail=result["message"])
return QueryResponse(**result)
# 5. 批量查询定额和清单信息
@app.post("/api/batch_query", response_model=BatchQueryResponse)
async def batch_query(request: BatchQueryRequest):
result = query_service.batch_query(request)
if not result["success"]:
raise HTTPException(status_code=500, detail=result["message"])
return BatchQueryResponse(**result)
# 启动服务器的函数
def start_api_server():
"""启动FastAPI服务器"""
uvicorn.run(app, host="0.0.0.0", port=8005)
# 主函数
def main():
"""主函数"""
print("正在启动API服务器...")
start_api_server()
if __name__ == "__main__":
main()
# uvicorn rag2_0.dify.query_dinge_qingdan_api:app --host 0.0.0.0 --port 8005 --workers 10