125 lines
4.8 KiB
Python
Executable File
125 lines
4.8 KiB
Python
Executable File
# Patch the standard library for gevent cooperative I/O as early as possible,
# before other modules import the pieces being patched; this must stay first.
from gevent import monkey

monkey.patch_all()

from flask import Flask, request, Response
import os
from dotenv import load_dotenv
import json
import time
# Greenlet-aware re-entrant lock (used below for the recognizer locks).
from gevent.lock import RLock
import datetime
import logging

# Load environment variables (python-dotenv reads a `.env` file) before the
# project module below is imported and before any os.getenv() lookups.
load_dotenv()

import sys

# Make modules under the current working directory importable
# (presumably the project root, so that `rag2_0` resolves — confirm).
sys.path.append(os.getcwd())

from rag2_0.intent_recognition import IntentRecognizer
|
|
|
|
|
|
# Root logging: INFO and above through a StreamHandler (stderr by default),
# each line prefixed with timestamp, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler()],
)

# Only let WARNING and above through from these client libraries
# (presumably to silence their per-request INFO chatter).
for _noisy_logger_name in ('httpx', 'openai'):
    logging.getLogger(_noisy_logger_name).setLevel(logging.WARNING)
del _noisy_logger_name  # keep the module namespace free of the loop variable

# Logger for this module's request handlers.
logger = logging.getLogger(__name__)
|
|
|
|
# The WSGI application object served by the routes below.
app = Flask(import_name=__name__)

# Greenlet-aware re-entrant lock meant to guard shared resources.
# NOTE(review): nothing in this file acquires `recognizer_lock` (the
# singleton class keeps its own lock); either wire it in or remove it.
recognizer_lock = RLock()
|
|
|
|
class RecognizerSingleton:
    """Lazily build and cache a single shared ``IntentRecognizer``.

    The instance is created on first use from the environment variables
    ``OPENAI_API_KEY``, ``OPENAI_API_BASE`` and ``LLM_MODEL_NAME``
    (the latter defaulting to ``"gpt-3.5-turbo"``), then reused.
    """

    _instance = None  # cached recognizer; stays None until first construction
    _lock = RLock()   # serializes the first-time construction

    @classmethod
    def get_instance(cls):
        """Return the shared recognizer, constructing it on the first call.

        The unlocked check is a fast path once the instance exists; the
        re-check under the lock ensures only one caller constructs it.
        """
        cached = cls._instance
        if cached is not None:
            return cached
        with cls._lock:
            if cls._instance is None:
                cls._instance = IntentRecognizer(
                    api_key=os.getenv("OPENAI_API_KEY"),
                    base_url=os.getenv("OPENAI_API_BASE"),
                    model_name=os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo"),
                )
            return cls._instance
|
|
|
|
def _json_response(payload, status: int = 200):
    """Wrap *payload* in a UTF-8 JSON ``Response`` (non-ASCII text kept unescaped).

    Args:
        payload: Any ``json.dumps``-serializable object.
        status: HTTP status code (default 200).
    """
    return Response(
        json.dumps(payload, ensure_ascii=False),
        content_type='application/json; charset=utf-8',
        status=status,
    )


def _keyword_names(keywords):
    """Reduce the recognizer's ``keywords`` result to ``[{"名称": name}, ...]``.

    Returns ``""`` (not ``[]``) when there are no terms, keeping the response
    shape that existing API clients already receive.
    """
    if not keywords or not keywords.get("terms"):
        return ""
    # Only the term name is exposed; synonym/description fields were
    # commented out in earlier revisions of this endpoint.
    return [{"名称": term["name"]} for term in keywords["terms"]]


@app.route('/intent_recognize', methods=['POST'])
def intent_recognize():
    """Run intent recognition for one query and return the result as JSON.

    Request body (a JSON object; the Content-Type header is not enforced):
        query (str, required): the user query; must be non-blank.
        conversation_context (optional, default ""): passed to ``process_query``.
        chat_history (optional, default None): passed to ``process_query``.
        previous_slots (optional, default None): passed to ``process_query``.

    Returns:
        200 with the recognition payload (source query, query keys,
        classifications, rewritten query, keyword names, slot filling,
        query expansion); 400 with ``{"error": ...}`` when the body is not a
        JSON object or ``query`` is missing/blank/not a string; 500 with
        ``{"error": ...}`` for any other failure.
    """
    try:
        # silent=True turns malformed JSON into None (answered with 400 below)
        # instead of a BadRequest that the generic handler would report as 500.
        data = request.get_json(force=True, silent=True)
        if not isinstance(data, dict):
            return _json_response({"error": "请求体必须是JSON对象"}, status=400)

        query = data.get('query')
        if not isinstance(query, str) or not query.strip():
            return _json_response({"error": "缺少query参数"}, status=400)

        conversation_context = data.get('conversation_context', "")
        chat_history = data.get('chat_history')
        previous_slots = data.get('previous_slots')

        # NOTE(review): the shared recognizer is used without acquiring any
        # lock; this assumes process_query tolerates concurrent calls from
        # multiple greenlets — confirm.
        started = time.perf_counter()
        recognizer = RecognizerSingleton.get_instance()
        result = recognizer.process_query(
            query=query,
            conversation_context=conversation_context,
            chat_history=chat_history,
            previous_slots=previous_slots,
            use_jieba=False,
            enable_query_expansion=True,
        )
        logger.info("[%s] 意图识别耗时: %.2f秒", os.getpid(), time.perf_counter() - started)

        classification = result["classification"]
        # `or {}` also absorbs an explicit None, which would otherwise crash
        # the len()/.get() calls when building the response.
        slot_filling = result.get("slot_filling") or {}

        response_result = {
            "source_query": query,
            "source_query_keys": result["query_keys"],
            "vertical_classification": classification["vertical_classification"],
            "sub_classification": classification["sub_classification"],
            "rewrite_query": result["rewrite"]["rewrite"],
            "keywords": _keyword_names(result["keywords"]),
            "has_slot_filling": bool(slot_filling),
            "slot_filling": {
                "is_complete": slot_filling.get("is_complete", False),
                "missing_slots": slot_filling.get("missing_slots", {}),
                "filled_data": slot_filling.get("filled_data", {}),
            },
            "query_expand": result["query_expand"],
        }
        return _json_response(response_result)
    except Exception as e:
        # Top-level request boundary: log the full traceback, report the message.
        # NOTE(review): returning str(e) may expose internal details to clients.
        logger.exception("意图识别出错: %s", e)
        return _json_response({"error": str(e)}, status=500)
|
|
|
|
if __name__ == "__main__":
    # Development only: Flask's built-in server.
    # In production use gunicorn with gevent workers for high concurrency, e.g.:
    #   uv run gunicorn -w 10 -k gevent -b 0.0.0.0:8001 rag2_0.dify.intent_recognition_api:app
    #   uv run gunicorn -w 10 -k gevent --preload -b 0.0.0.0:8001 rag2_0.dify.intent_recognition_api:app
    # Host/port may be overridden via INTENT_API_HOST / INTENT_API_PORT;
    # the defaults (0.0.0.0:8001) match the gunicorn commands above.
    app.run(
        host=os.getenv("INTENT_API_HOST", "0.0.0.0"),
        port=int(os.getenv("INTENT_API_PORT", "8001")),
        threaded=True,
    )