diff --git a/rag2_0/demo/intent_recognition_example.py b/rag2_0/demo/intent_recognition_example.py index b136309..6d0b3d6 100755 --- a/rag2_0/demo/intent_recognition_example.py +++ b/rag2_0/demo/intent_recognition_example.py @@ -62,8 +62,8 @@ class QueryRewriteProcessor: api_key: str = None, base_url: str = None, model_name: str = None, - dify_api_key: str = "dataset-skLjmPVonjHo119OWNf3kAmY", - dify_base_url: str = "http://172.20.0.145/v1"): + dify_dataset_key: str = None, + dify_base_url: str = None): """ 初始化查询改写处理器 @@ -71,13 +71,17 @@ class QueryRewriteProcessor: api_key: API密钥,默认使用环境变量 base_url: API基础URL,默认使用环境变量 model_name: 模型名称,默认使用环境变量或默认模型 - dify_api_key: Dify API密钥 + dify_dataset_key: Dify API密钥 dify_base_url: Dify API基础URL """ # 初始化意图识别器 # 使用asyncio.run()运行异步create方法 self.recognizer_async = asyncio.run(AsyncIntentRecognizer.create()) - self.dify_query_retrieval = DifyQueryRetrieval(api_key=dify_api_key, base_url=dify_base_url) + if not dify_dataset_key: + dify_dataset_key = os.getenv("DIFY_DATASET_KEY") + if not dify_base_url: + dify_base_url = os.getenv("DIFY_BSAE_URL") + self.dify_query_retrieval = DifyQueryRetrieval(dify_dataset_key=dify_dataset_key, dify_base_url=dify_base_url) def is_retrieved_doc_relevant(self, query: str, retrieved_doc: List[Dict[str, Any]]) -> Dict[str, Any]: """ @@ -205,14 +209,13 @@ class QueryRewriteProcessor: classification = result["classification"] original_query = result["rewrite"]["rewrite"] query_list = result["query_expand"]["all"] - soft_name = result.get("slot_filling", {}).get("filled_data", {}).get("software_name","") # 将字典转换为Classification对象 classification_obj = Classification(**classification) # 根据enable_retrieval参数决定是否进行文档检索 retrieved_doc = None if enable_retrieval: - retrieved_doc = self.dify_query_retrieval.retrieve(original_query, query_list, classification_obj, soft_name) + retrieved_doc = self.dify_query_retrieval.retrieve(original_query, query_list, classification_obj, current_softname) # 判断检索文档是否相关 relevance_result = {} @@ -439,9 +442,9 @@ def main(): for idx, query in enumerate(examples): if query.strip() == "": continue - query="怎么把一个批次拆分成多个批次工程" + query="怎么调整报表顺序" conversation_context={ - "current_softname": "配网计价通D3软件" + "current_softname": "储能计价通C1软件" } # 在调试模式下使用完整的参数 print(json.dumps(processor.process_query( diff --git a/rag2_0/dify/DifyQueryRetrieval.py b/rag2_0/dify/DifyQueryRetrieval.py index 38b8568..07dadec 100644 --- a/rag2_0/dify/DifyQueryRetrieval.py +++ b/rag2_0/dify/DifyQueryRetrieval.py @@ -27,26 +27,29 @@ class DifyQueryRetrieval: "西藏造价软件知识(new)","新能源造价知识(new)","配网造价知识(new)","技改造价知识(new)", "配网造价软件知识(new)"]} - def __init__(self, api_key: str, base_url: str): - self._api_key = api_key - self._base_url = base_url + def __init__(self, dify_dataset_key: str, dify_base_url: str): + self._dify_dataset_key = dify_dataset_key + self._dify_base_url = dify_base_url self._datasets_list = self.get_datasets_list() def get_datasets_list(self) -> Dict[str, str]: - client = KnowledgeBaseClient(api_key=self._api_key, base_url=self._base_url) + client = KnowledgeBaseClient(api_key=self._dify_dataset_key, base_url=self._dify_base_url) datasets = client.list_datasets(page_size=50) datasets_json = datasets.json() - return {dataset["name"]:dataset["id"] for dataset in datasets_json["data"]} + return {dataset["name"]:dataset for dataset in datasets_json["data"]} def retrieve_by_dataset(self, query: str, dataset_name: str) -> List[Dict[str, Any]]: try: - knowledge_base_client = KnowledgeBaseClient(api_key=self._api_key, base_url=self._base_url, dataset_id=self._datasets_list[dataset_name]) - documents = knowledge_base_client.retrieve(query, timeout=300) + dataset_id = self._datasets_list[dataset_name]["id"] + retrieval_model = self._datasets_list[dataset_name]["retrieval_model_dict"] + knowledge_base_client = KnowledgeBaseClient(api_key=self._dify_dataset_key, base_url=self._dify_base_url, dataset_id=dataset_id) + + documents = knowledge_base_client.retrieve(query, retrieval_model=retrieval_model, timeout=300) retrieved_documents = documents.json().get("records", []) # 添加数据集信息 for retrieved_document in retrieved_documents: - retrieved_document["dataset_id"] = self._datasets_list[dataset_name] + retrieved_document["dataset_id"] = dataset_id retrieved_document["dataset_name"] = dataset_name return retrieved_documents @@ -78,6 +81,7 @@ class DifyQueryRetrieval: def retrieve(self, original_query: str, query_list: List[str], classification: Classification, software_name: str) -> Optional[List[Dict[str, Any]]]: datasets = self.get_datasets_by_classification(classification, software_name) + datasets=["电力建设计价通(2018)软件知识(new)", "主网造价知识(new)", "下载安装注册(new)"] if len(datasets) == 0: return None @@ -103,6 +107,7 @@ class DifyQueryRetrieval: return await self.retrieve_api_async(original_query, query_list, datasets) def retrieve_api(self, original_query: str, query_list: List[str],data_set_list: List[str], top_k: int = 5)->List[Dict[str, Any]]: + ssss = self.retrieve_by_dataset("怎么调整报表顺序", "电力建设计价通(2018)软件知识(new)") all_documents=[] # 使用线程池替代无限制创建线程 # 设置合理的最大线程数,这里使用min(32, len(query_list) * len(datasets))来限制 @@ -112,7 +117,7 @@ class DifyQueryRetrieval: futures = {} for query in query_list: for dataset in data_set_list: - if dataset not in self._datasets_list: + if dataset not in list(self._datasets_list.keys()): raise ValueError(f"dataset {dataset} not in datasets_list") futures[executor.submit(self.retrieve_by_dataset, query, dataset)] = query @@ -140,10 +145,10 @@ class DifyQueryRetrieval: # 对所有检索出来的文档进行重排序 time_start = time.time() - processed_documents = self.data_post_processor(original_query, deduplicated_documents, top_k) + processed_documents = self.data_post_processor("怎么调整报表顺序", deduplicated_documents, top_k) time_end = time.time() logging.info(f"检索后重排序耗时: {time_end - time_start:.2f}秒") - + return processed_documents async def retrieve_api_async(self, original_query: str, query_list: List[str], data_set_list: List[str], top_k: int = 5)->List[Dict[str, Any]]: @@ -166,7 +171,7 @@ class DifyQueryRetrieval: tasks = [] for query in query_list: for dataset in data_set_list: - if dataset not in self._datasets_list: + if dataset not in list(self._datasets_list.keys()): logging.error(f"dataset {dataset} not in datasets_list") continue diff --git a/rag2_0/dify/DifyQueryRetrieval_api.py b/rag2_0/dify/DifyQueryRetrieval_api.py index 4e6c508..ca80adc 100644 --- a/rag2_0/dify/DifyQueryRetrieval_api.py +++ b/rag2_0/dify/DifyQueryRetrieval_api.py @@ -64,7 +64,7 @@ dify_query_retrieval = None async def startup_event(): global dify_query_retrieval # 初始化DifyQueryRetrieval实例 - dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://10.1.16.39/v1") + dify_query_retrieval = DifyQueryRetrieval(dify_dataset_key=os.getenv("DIFY_DATASET_KEY"), dify_base_url=os.getenv("DIFY_BSAE_URL")) logger.info("DifyQueryRetrieval初始化完成") # 添加健康检查端点 diff --git a/rag2_0/dify/WorkorderToDify.py b/rag2_0/dify/WorkorderToDify.py index d74e01a..dd994ce 100644 --- a/rag2_0/dify/WorkorderToDify.py +++ b/rag2_0/dify/WorkorderToDify.py @@ -19,7 +19,7 @@ for index, row in pd_data.iterrows(): if "存在抱怨" in answer: answer = answer.split("存在抱怨")[0] - content = f"问题:{query}\n解决方案:{answer}" + content = f"问题:{query}\n回答:{answer}" segments_list.append({ "content": str(content), "answer": "", diff --git a/rag2_0/dify/dify_client/client.py b/rag2_0/dify/dify_client/client.py index 972c97b..1646810 100755 --- a/rag2_0/dify/dify_client/client.py +++ b/rag2_0/dify/dify_client/client.py @@ -8,7 +8,7 @@ class DifyClient: self.api_key = api_key self.base_url = base_url - def _send_request(self, method, endpoint, json=None, params=None, stream=False): + def _send_request(self, method, endpoint, json=None, params=None, stream=False, timeout=300): headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", @@ -16,7 +16,7 @@ class DifyClient: url = f"{self.base_url}{endpoint}" response = requests.request( - method, url, json=json, params=params, headers=headers, stream=stream, verify=False + method, url, json=json, params=params, headers=headers, stream=stream, verify=False, timeout=timeout ) return response diff --git a/rag2_0/intent_recognition/DataModels.py b/rag2_0/intent_recognition/DataModels.py index 8158418..546f8bf 100755 --- a/rag2_0/intent_recognition/DataModels.py +++ b/rag2_0/intent_recognition/DataModels.py @@ -309,7 +309,7 @@ class IntentAndSlotResult(BaseModel): class StepBackPrompt(BaseModel): """后退提示数据模型""" original_query: str = Field(description="原始查询") - can_use_back_prompt: bool = Field(description="原始查询是否可以进行后退提示(True/False),如果原始查询没有限定词或其他限定词语,则不能进行后退提示") + can_use_back_prompt: bool = Field(description="原始查询是否可以进行后退提示(true/false),如果原始查询没有限定词或其他限定词语,则不能进行后退提示") step_back_query: List[str] = Field(description="后退提示生成的抽象查询(多个)") class FollowUpQuestions(BaseModel):