更新环境变量配置,调整模型名称获取方式,新增Dify API相关配置,删除无用的脚本文件,优化意图识别逻辑,添加LLM提取词条逻辑
This commit is contained in:
@@ -75,15 +75,8 @@ class QueryRewriteProcessor:
|
||||
dify_base_url: Dify API基础URL
|
||||
"""
|
||||
# 初始化意图识别器
|
||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||
self.base_url = base_url or os.getenv("OPENAI_API_BASE")
|
||||
self.model_name = model_name or os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
# 使用asyncio.run()运行异步create方法
|
||||
self.recognizer_async = asyncio.run(AsyncIntentRecognizer.create(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url,
|
||||
model_name=self.model_name
|
||||
))
|
||||
self.recognizer_async = asyncio.run(AsyncIntentRecognizer.create())
|
||||
self.dify_query_retrieval = DifyQueryRetrieval(api_key=dify_api_key, base_url=dify_base_url)
|
||||
|
||||
def is_retrieved_doc_relevant(self, query: str, retrieved_doc: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
@@ -174,7 +167,7 @@ class QueryRewriteProcessor:
|
||||
return []
|
||||
|
||||
def process_query(self, query: str,
|
||||
conversation_context: str = "",
|
||||
conversation_context: Dict = None,
|
||||
chat_history: List[Dict[str, str]] = None,
|
||||
previous_slots: Dict[str, str] = None,
|
||||
enable_retrieval: bool = False):
|
||||
@@ -196,12 +189,17 @@ class QueryRewriteProcessor:
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
if conversation_context is None:
|
||||
conversation_context = {}
|
||||
|
||||
current_softname = conversation_context.get("current_softname", "")
|
||||
result = asyncio.run(self.recognizer_async.process_query_async(query,
|
||||
conversation_context=conversation_context,
|
||||
chat_history=chat_history,
|
||||
previous_slots=previous_slots,
|
||||
enable_query_expansion=True,
|
||||
use_jieba=True))
|
||||
use_jieba=True,
|
||||
cur_soft_name=current_softname))
|
||||
|
||||
# 提取分类信息
|
||||
classification = result["classification"]
|
||||
@@ -414,7 +412,7 @@ def main():
|
||||
# 从环境变量中获取配置,命令行参数优先
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
model_name = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
|
||||
enable_retrieval = args.enable_retrieval
|
||||
|
||||
# 初始化查询改写处理器
|
||||
@@ -441,8 +439,10 @@ def main():
|
||||
for idx, query in enumerate(examples):
|
||||
if query.strip() == "":
|
||||
continue
|
||||
query="811619150828能看一下这个锁是16的马"
|
||||
conversation_context="当前使用软件:配网计价通D3软件"
|
||||
query="怎么把一个批次拆分成多个批次工程"
|
||||
conversation_context={
|
||||
"current_softname": "配网计价通D3软件"
|
||||
}
|
||||
# 在调试模式下使用完整的参数
|
||||
print(json.dumps(processor.process_query(
|
||||
query,
|
||||
|
||||
Reference in New Issue
Block a user