Compare commits
2 Commits
a934f2c398
...
bd0f86ff61
| Author | SHA1 | Date | |
|---|---|---|---|
| bd0f86ff61 | |||
| 8a58fef1a7 |
@@ -16416,13 +16416,6 @@
|
||||
"synonymous": [],
|
||||
"description": "中标清单中的材料或机械消耗量"
|
||||
},
|
||||
{
|
||||
"name": "清单项",
|
||||
"synonymous": [
|
||||
"清单项目"
|
||||
],
|
||||
"description": "工程量清单中的具体项目"
|
||||
},
|
||||
{
|
||||
"name": "结算工程解锁",
|
||||
"synonymous": [],
|
||||
@@ -17476,13 +17469,6 @@
|
||||
"synonymous": [],
|
||||
"description": "软件中用于输出各类成果文件或展示工程计价结果的界面"
|
||||
},
|
||||
{
|
||||
"name": "南方电网接口格式",
|
||||
"synonymous": [
|
||||
"南网规约接口"
|
||||
],
|
||||
"description": "符合南方电网数据交换标准的接口规范,用于导出符合南方电网规范的接口数据,以便上传至基建一体化信息系统。"
|
||||
},
|
||||
{
|
||||
"name": "投标限价数据",
|
||||
"synonymous": [],
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -62,8 +62,8 @@ class QueryRewriteProcessor:
|
||||
api_key: str = None,
|
||||
base_url: str = None,
|
||||
model_name: str = None,
|
||||
dify_api_key: str = "dataset-skLjmPVonjHo119OWNf3kAmY",
|
||||
dify_base_url: str = "http://172.20.0.145/v1"):
|
||||
dify_dataset_key: str = None,
|
||||
dify_base_url: str = None):
|
||||
"""
|
||||
初始化查询改写处理器
|
||||
|
||||
@@ -71,13 +71,17 @@ class QueryRewriteProcessor:
|
||||
api_key: API密钥,默认使用环境变量
|
||||
base_url: API基础URL,默认使用环境变量
|
||||
model_name: 模型名称,默认使用环境变量或默认模型
|
||||
dify_api_key: Dify API密钥
|
||||
dify_dataset_key: Dify API密钥
|
||||
dify_base_url: Dify API基础URL
|
||||
"""
|
||||
# 初始化意图识别器
|
||||
# 使用asyncio.run()运行异步create方法
|
||||
self.recognizer_async = asyncio.run(AsyncIntentRecognizer.create())
|
||||
self.dify_query_retrieval = DifyQueryRetrieval(api_key=dify_api_key, base_url=dify_base_url)
|
||||
if not dify_dataset_key:
|
||||
dify_dataset_key = os.getenv("DIFY_DATASET_KEY")
|
||||
if not dify_base_url:
|
||||
dify_base_url = os.getenv("DIFY_BSAE_URL")
|
||||
self.dify_query_retrieval = DifyQueryRetrieval(dify_dataset_key=dify_dataset_key, dify_base_url=dify_base_url)
|
||||
|
||||
def is_retrieved_doc_relevant(self, query: str, retrieved_doc: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -205,14 +209,13 @@ class QueryRewriteProcessor:
|
||||
classification = result["classification"]
|
||||
original_query = result["rewrite"]["rewrite"]
|
||||
query_list = result["query_expand"]["all"]
|
||||
soft_name = result.get("slot_filling", {}).get("filled_data", {}).get("software_name","")
|
||||
# 将字典转换为Classification对象
|
||||
classification_obj = Classification(**classification)
|
||||
|
||||
# 根据enable_retrieval参数决定是否进行文档检索
|
||||
retrieved_doc = None
|
||||
if enable_retrieval:
|
||||
retrieved_doc = self.dify_query_retrieval.retrieve(original_query, query_list, classification_obj, soft_name)
|
||||
retrieved_doc = self.dify_query_retrieval.retrieve(original_query, query_list, classification_obj, current_softname)
|
||||
|
||||
# 判断检索文档是否相关
|
||||
relevance_result = {}
|
||||
@@ -439,9 +442,9 @@ def main():
|
||||
for idx, query in enumerate(examples):
|
||||
if query.strip() == "":
|
||||
continue
|
||||
query="怎么把一个批次拆分成多个批次工程"
|
||||
query="怎么调整报表顺序"
|
||||
conversation_context={
|
||||
"current_softname": "配网计价通D3软件"
|
||||
"current_softname": "储能计价通C1软件"
|
||||
}
|
||||
# 在调试模式下使用完整的参数
|
||||
print(json.dumps(processor.process_query(
|
||||
|
||||
@@ -27,26 +27,29 @@ class DifyQueryRetrieval:
|
||||
"西藏造价软件知识(new)","新能源造价知识(new)","配网造价知识(new)","技改造价知识(new)",
|
||||
"配网造价软件知识(new)"]}
|
||||
|
||||
def __init__(self, api_key: str, base_url: str):
|
||||
self._api_key = api_key
|
||||
self._base_url = base_url
|
||||
def __init__(self, dify_dataset_key: str, dify_base_url: str):
|
||||
self._dify_dataset_key = dify_dataset_key
|
||||
self._dify_base_url = dify_base_url
|
||||
self._datasets_list = self.get_datasets_list()
|
||||
|
||||
def get_datasets_list(self) -> Dict[str, str]:
|
||||
client = KnowledgeBaseClient(api_key=self._api_key, base_url=self._base_url)
|
||||
client = KnowledgeBaseClient(api_key=self._dify_dataset_key, base_url=self._dify_base_url)
|
||||
datasets = client.list_datasets(page_size=50)
|
||||
datasets_json = datasets.json()
|
||||
return {dataset["name"]:dataset["id"] for dataset in datasets_json["data"]}
|
||||
return {dataset["name"]:dataset for dataset in datasets_json["data"]}
|
||||
|
||||
def retrieve_by_dataset(self, query: str, dataset_name: str) -> List[Dict[str, Any]]:
|
||||
try:
|
||||
knowledge_base_client = KnowledgeBaseClient(api_key=self._api_key, base_url=self._base_url, dataset_id=self._datasets_list[dataset_name])
|
||||
documents = knowledge_base_client.retrieve(query)
|
||||
dataset_id = self._datasets_list[dataset_name]["id"]
|
||||
retrieval_model = self._datasets_list[dataset_name]["retrieval_model_dict"]
|
||||
knowledge_base_client = KnowledgeBaseClient(api_key=self._dify_dataset_key, base_url=self._dify_base_url, dataset_id=dataset_id)
|
||||
|
||||
documents = knowledge_base_client.retrieve(query, retrieval_model=retrieval_model, timeout=300)
|
||||
retrieved_documents = documents.json().get("records", [])
|
||||
|
||||
# 添加数据集信息
|
||||
for retrieved_document in retrieved_documents:
|
||||
retrieved_document["dataset_id"] = self._datasets_list[dataset_name]
|
||||
retrieved_document["dataset_id"] = dataset_id
|
||||
retrieved_document["dataset_name"] = dataset_name
|
||||
|
||||
return retrieved_documents
|
||||
@@ -78,6 +81,7 @@ class DifyQueryRetrieval:
|
||||
|
||||
def retrieve(self, original_query: str, query_list: List[str], classification: Classification, software_name: str) -> Optional[List[Dict[str, Any]]]:
|
||||
datasets = self.get_datasets_by_classification(classification, software_name)
|
||||
datasets=["电力建设计价通(2018)软件知识(new)", "主网造价知识(new)", "下载安装注册(new)"]
|
||||
if len(datasets) == 0:
|
||||
return None
|
||||
|
||||
@@ -103,6 +107,7 @@ class DifyQueryRetrieval:
|
||||
return await self.retrieve_api_async(original_query, query_list, datasets)
|
||||
|
||||
def retrieve_api(self, original_query: str, query_list: List[str],data_set_list: List[str], top_k: int = 5)->List[Dict[str, Any]]:
|
||||
ssss = self.retrieve_by_dataset("怎么调整报表顺序", "电力建设计价通(2018)软件知识(new)")
|
||||
all_documents=[]
|
||||
# 使用线程池替代无限制创建线程
|
||||
# 设置合理的最大线程数,这里使用min(32, len(query_list) * len(datasets))来限制
|
||||
@@ -112,7 +117,7 @@ class DifyQueryRetrieval:
|
||||
futures = {}
|
||||
for query in query_list:
|
||||
for dataset in data_set_list:
|
||||
if dataset not in self._datasets_list:
|
||||
if dataset not in list(self._datasets_list.keys()):
|
||||
raise ValueError(f"dataset {dataset} not in datasets_list")
|
||||
|
||||
futures[executor.submit(self.retrieve_by_dataset, query, dataset)] = query
|
||||
@@ -140,7 +145,7 @@ class DifyQueryRetrieval:
|
||||
|
||||
# 对所有检索出来的文档进行重排序
|
||||
time_start = time.time()
|
||||
processed_documents = self.data_post_processor(original_query, deduplicated_documents, top_k)
|
||||
processed_documents = self.data_post_processor("怎么调整报表顺序", deduplicated_documents, top_k)
|
||||
time_end = time.time()
|
||||
logging.info(f"检索后重排序耗时: {time_end - time_start:.2f}秒")
|
||||
|
||||
@@ -166,7 +171,7 @@ class DifyQueryRetrieval:
|
||||
tasks = []
|
||||
for query in query_list:
|
||||
for dataset in data_set_list:
|
||||
if dataset not in self._datasets_list:
|
||||
if dataset not in list(self._datasets_list.keys()):
|
||||
logging.error(f"dataset {dataset} not in datasets_list")
|
||||
continue
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ dify_query_retrieval = None
|
||||
async def startup_event():
|
||||
global dify_query_retrieval
|
||||
# 初始化DifyQueryRetrieval实例
|
||||
dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://10.1.16.39/v1")
|
||||
dify_query_retrieval = DifyQueryRetrieval(dify_dataset_key=os.getenv("DIFY_DATASET_KEY"), dify_base_url=os.getenv("DIFY_BSAE_URL"))
|
||||
logger.info("DifyQueryRetrieval初始化完成")
|
||||
|
||||
# 添加健康检查端点
|
||||
|
||||
@@ -19,7 +19,7 @@ for index, row in pd_data.iterrows():
|
||||
if "存在抱怨" in answer:
|
||||
answer = answer.split("存在抱怨")[0]
|
||||
|
||||
content = f"问题:{query}\n解决方案:{answer}"
|
||||
content = f"问题:{query}\n回答:{answer}"
|
||||
segments_list.append({
|
||||
"content": str(content),
|
||||
"answer": "",
|
||||
|
||||
@@ -8,7 +8,7 @@ class DifyClient:
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
|
||||
def _send_request(self, method, endpoint, json=None, params=None, stream=False):
|
||||
def _send_request(self, method, endpoint, json=None, params=None, stream=False, timeout=300):
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
@@ -16,7 +16,7 @@ class DifyClient:
|
||||
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
response = requests.request(
|
||||
method, url, json=json, params=params, headers=headers, stream=stream, verify=False
|
||||
method, url, json=json, params=params, headers=headers, stream=stream, verify=False, timeout=timeout
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@@ -117,9 +117,14 @@ class DifyExporter:
|
||||
intent_result = json.loads(intent_node_execution_info[0]["outputs"])
|
||||
vertical_classification = intent_result.get("vertical_classification", "")
|
||||
sub_classification = intent_result.get("sub_classification", "")
|
||||
if vertical_classification == "固定话术类":
|
||||
if sub_classification == "固定话术类":
|
||||
return "使用固定话术"
|
||||
|
||||
worker_node_execution_info = [node_execution_info for node_execution_info in msg_debug_info['workflow_node_executions_info']
|
||||
if node_execution_info["title"] == "检索工单数据"]
|
||||
if len(worker_node_execution_info) != 0:
|
||||
return "检索工单"
|
||||
|
||||
return ""
|
||||
|
||||
def get_node_info_by_title(self, workflow_node_executions_info:list, title:str) -> dict:
|
||||
@@ -198,13 +203,12 @@ class DifyExporter:
|
||||
return None
|
||||
|
||||
wiki_list = self.get_wiki_list(msg_debug_info)
|
||||
# 获取备注
|
||||
remark = self.get_remark(msg_debug_info)
|
||||
|
||||
if len(wiki_list) ==0:
|
||||
wiki_list_str = self.get_remark(msg_debug_info)
|
||||
else:
|
||||
wiki_list = list(set(wiki_list))
|
||||
wiki_list_str = "\n".join(wiki_list)
|
||||
if wiki_list_str == "":
|
||||
wiki_list_str = "无"
|
||||
rating = self.dify_pgsql.get_message_rating(msg_id)
|
||||
# 直接通过字典键获取query_type
|
||||
workflow_run_id = message['workflow_run_id']
|
||||
@@ -220,7 +224,6 @@ class DifyExporter:
|
||||
"评价": rating,
|
||||
"问题分类": query_type,
|
||||
"检索到的词条": wiki_list_str,
|
||||
"备注": remark
|
||||
}
|
||||
|
||||
def process_conversations(self):
|
||||
@@ -274,7 +277,7 @@ class DifyExporter:
|
||||
# 设置列的顺序
|
||||
columns_order = [
|
||||
"msg_id","当前软件", "提问", "回答", "提问人", "提问时间",
|
||||
"评价", "问题分类", "检索到的词条", "备注"
|
||||
"评价", "问题分类", "检索到的词条"
|
||||
]
|
||||
|
||||
# 确保所有列都存在,如果不存在则添加空列
|
||||
|
||||
@@ -309,7 +309,7 @@ class IntentAndSlotResult(BaseModel):
|
||||
class StepBackPrompt(BaseModel):
|
||||
"""后退提示数据模型"""
|
||||
original_query: str = Field(description="原始查询")
|
||||
can_use_back_prompt: bool = Field(description="原始查询是否可以进行后退提示(True/False),如果原始查询没有限定词或其他限定词语,则不能进行后退提示")
|
||||
can_use_back_prompt: bool = Field(description="原始查询是否可以进行后退提示(true/false),如果原始查询没有限定词或其他限定词语,则不能进行后退提示")
|
||||
step_back_query: List[str] = Field(description="后退提示生成的抽象查询(多个)")
|
||||
|
||||
class FollowUpQuestions(BaseModel):
|
||||
|
||||
@@ -38,7 +38,7 @@ class SiliconFlowEmbeddings(Embeddings):
|
||||
"input": input,
|
||||
"encoding_format": "float"
|
||||
}
|
||||
response = requests.post(self.url, json=payload, headers=self.headers)
|
||||
response = requests.post(self.url, json=payload, headers=self.headers, timeout=300)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return [item["embedding"] for item in data["data"]]
|
||||
@@ -50,7 +50,7 @@ class SiliconFlowEmbeddings(Embeddings):
|
||||
"input": input,
|
||||
"encoding_format": "float"
|
||||
}
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with httpx.AsyncClient(timeout=300) as client:
|
||||
response = await client.post(self.url, json=payload, headers=self.headers)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
@@ -101,7 +101,7 @@ class SiliconFlowReRankerModel:
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
try:
|
||||
response = requests.post(url, json=payload, headers=headers)
|
||||
response = requests.post(url, json=payload, headers=headers, timeout=300)
|
||||
response.raise_for_status()
|
||||
results = response.json()
|
||||
return [{"document": item["document"]["text"], "score": item["relevance_score"], "index": item["index"]} for item in results["results"]]
|
||||
@@ -138,7 +138,7 @@ class SiliconFlowReRankerModel:
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with httpx.AsyncClient(timeout=300) as client:
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
response.raise_for_status()
|
||||
results = response.json()
|
||||
@@ -173,7 +173,7 @@ class XinferenceReRankerModel:
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(url, json=params, headers=headers)
|
||||
response = requests.post(url, json=params, headers=headers, timeout=300)
|
||||
response.raise_for_status() # 检查响应状态
|
||||
results = response.json()
|
||||
|
||||
@@ -206,7 +206,7 @@ class XinferenceReRankerModel:
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with httpx.AsyncClient(timeout=300) as client:
|
||||
response = await client.post(url, json=params, headers=headers)
|
||||
response.raise_for_status() # 检查响应状态
|
||||
results = response.json()
|
||||
|
||||
Reference in New Issue
Block a user