更新DifyQueryRetrieval类的初始化参数,改为使用环境变量获取API密钥和基础URL;优化意图识别示例中的参数传递;调整问题和回答的格式描述;增加请求超时设置。
This commit is contained in:
@@ -62,8 +62,8 @@ class QueryRewriteProcessor:
|
||||
api_key: str = None,
|
||||
base_url: str = None,
|
||||
model_name: str = None,
|
||||
dify_api_key: str = "dataset-skLjmPVonjHo119OWNf3kAmY",
|
||||
dify_base_url: str = "http://172.20.0.145/v1"):
|
||||
dify_dataset_key: str = None,
|
||||
dify_base_url: str = None):
|
||||
"""
|
||||
初始化查询改写处理器
|
||||
|
||||
@@ -71,13 +71,17 @@ class QueryRewriteProcessor:
|
||||
api_key: API密钥,默认使用环境变量
|
||||
base_url: API基础URL,默认使用环境变量
|
||||
model_name: 模型名称,默认使用环境变量或默认模型
|
||||
dify_api_key: Dify API密钥
|
||||
dify_dataset_key: Dify API密钥
|
||||
dify_base_url: Dify API基础URL
|
||||
"""
|
||||
# 初始化意图识别器
|
||||
# 使用asyncio.run()运行异步create方法
|
||||
self.recognizer_async = asyncio.run(AsyncIntentRecognizer.create())
|
||||
self.dify_query_retrieval = DifyQueryRetrieval(api_key=dify_api_key, base_url=dify_base_url)
|
||||
if not dify_dataset_key:
|
||||
dify_dataset_key = os.getenv("DIFY_DATASET_KEY")
|
||||
if not dify_base_url:
|
||||
dify_base_url = os.getenv("DIFY_BSAE_URL")
|
||||
self.dify_query_retrieval = DifyQueryRetrieval(dify_dataset_key=dify_dataset_key, dify_base_url=dify_base_url)
|
||||
|
||||
def is_retrieved_doc_relevant(self, query: str, retrieved_doc: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -205,14 +209,13 @@ class QueryRewriteProcessor:
|
||||
classification = result["classification"]
|
||||
original_query = result["rewrite"]["rewrite"]
|
||||
query_list = result["query_expand"]["all"]
|
||||
soft_name = result.get("slot_filling", {}).get("filled_data", {}).get("software_name","")
|
||||
# 将字典转换为Classification对象
|
||||
classification_obj = Classification(**classification)
|
||||
|
||||
# 根据enable_retrieval参数决定是否进行文档检索
|
||||
retrieved_doc = None
|
||||
if enable_retrieval:
|
||||
retrieved_doc = self.dify_query_retrieval.retrieve(original_query, query_list, classification_obj, soft_name)
|
||||
retrieved_doc = self.dify_query_retrieval.retrieve(original_query, query_list, classification_obj, current_softname)
|
||||
|
||||
# 判断检索文档是否相关
|
||||
relevance_result = {}
|
||||
@@ -439,9 +442,9 @@ def main():
|
||||
for idx, query in enumerate(examples):
|
||||
if query.strip() == "":
|
||||
continue
|
||||
query="怎么把一个批次拆分成多个批次工程"
|
||||
query="怎么调整报表顺序"
|
||||
conversation_context={
|
||||
"current_softname": "配网计价通D3软件"
|
||||
"current_softname": "储能计价通C1软件"
|
||||
}
|
||||
# 在调试模式下使用完整的参数
|
||||
print(json.dumps(processor.process_query(
|
||||
|
||||
Reference in New Issue
Block a user