1、修改api文件位置
2、意图识别继承langfuse
This commit is contained in:
@@ -23,12 +23,13 @@ class DifyQueryRetrieval:
|
||||
datasets_json = datasets.json()
|
||||
return {dataset["name"]:dataset for dataset in datasets_json["data"]}
|
||||
|
||||
def retrieve_by_dataset(self, query: str, dataset_name: str) -> Dict[str, Any]:
|
||||
def retrieve_by_dataset(self, query: str, dataset_name: str, metadata_filtering_conditions:dict = {}) -> Dict[str, Any]:
|
||||
try:
|
||||
dataset_id = self._datasets_list[dataset_name]["id"]
|
||||
retrieval_model = self._datasets_list[dataset_name]["retrieval_model_dict"]
|
||||
knowledge_base_client = KnowledgeBaseClient(api_key=self._dify_dataset_key, base_url=self._dify_base_url, dataset_id=dataset_id)
|
||||
|
||||
if len(metadata_filtering_conditions) !=0:
|
||||
retrieval_model["metadata_filtering_conditions"]=metadata_filtering_conditions
|
||||
documents = knowledge_base_client.retrieve(query, retrieval_model=retrieval_model, timeout=300)
|
||||
retrieved_documents = documents.json().get("records", [])
|
||||
|
||||
@@ -51,7 +52,7 @@ class DifyQueryRetrieval:
|
||||
"documents": []
|
||||
}
|
||||
|
||||
async def retrieve_by_dataset_async(self, query: str, dataset_name: str) -> Dict[str, Any]:
|
||||
async def retrieve_by_dataset_async(self, query: str, dataset_name: str, metadata_filtering_conditions:dict = {}) -> Dict[str, Any]:
|
||||
"""
|
||||
异步版本的retrieve_by_dataset方法
|
||||
|
||||
@@ -67,7 +68,8 @@ class DifyQueryRetrieval:
|
||||
return await asyncio.to_thread(
|
||||
self.retrieve_by_dataset,
|
||||
query,
|
||||
dataset_name
|
||||
dataset_name,
|
||||
metadata_filtering_conditions
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"异步检索数据集 {dataset_name} 时出错: {str(e)}", exc_info=True)
|
||||
@@ -77,7 +79,13 @@ class DifyQueryRetrieval:
|
||||
"documents": []
|
||||
}
|
||||
|
||||
async def retrieve_api_async(self, original_query: str, query_list: List[str], data_set_list: List[str], query_expand_dict: dict, top_k: int = 5)->Dict[str, Any]:
|
||||
async def retrieve_api_async(self,
|
||||
original_query: str,
|
||||
query_list: List[str],
|
||||
data_set_list: List[str],
|
||||
query_expand_dict: dict,
|
||||
top_k: int = 5,
|
||||
metadata_filtering_conditions:dict = {})->Dict[str, Any]:
|
||||
"""
|
||||
异步版本的retrieve_api方法,使用asyncio代替线程池
|
||||
|
||||
@@ -105,7 +113,7 @@ class DifyQueryRetrieval:
|
||||
continue
|
||||
|
||||
# 创建异步任务
|
||||
task = self.retrieve_by_dataset_async(query, dataset)
|
||||
task = self.retrieve_by_dataset_async(query, dataset, metadata_filtering_conditions)
|
||||
tasks.append(task)
|
||||
|
||||
# 并发执行所有异步任务
|
||||
|
||||
Reference in New Issue
Block a user