上传问题改写、意图识别模块代码
This commit is contained in:
@@ -0,0 +1,54 @@
|
||||
from flask import Flask, request, Response
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from rag2_0.intent_recognition import IntentRecognizer
|
||||
import json
|
||||
import time
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
# 初始化意图识别器
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
recognizer = IntentRecognizer(api_key=api_key, base_url=base_url, model_name=model_name)
|
||||
|
||||
@app.route('/intent_recognize', methods=['POST'])
|
||||
def intent_recognize():
|
||||
try:
|
||||
data = request.get_json(force=True)
|
||||
query = data.get('query')
|
||||
if not query:
|
||||
return Response(json.dumps({"error": "缺少query参数"}, ensure_ascii=False), content_type='application/json; charset=utf-8', status=400)
|
||||
start_time = time.time()
|
||||
classification, keywords, rewrite, query_keys = recognizer.process_query(query)
|
||||
end_time = time.time()
|
||||
print(f"意图识别耗时: {end_time - start_time:.2f}秒")
|
||||
# keywords对象转为字符串
|
||||
keywords_str = ""
|
||||
if keywords and keywords.terms:
|
||||
term_details = []
|
||||
for term in keywords.terms:
|
||||
term_info = {
|
||||
"名称": term.name,
|
||||
"同义词": ";".join(term.synonymous) if term.synonymous else "",
|
||||
"描述": term.description
|
||||
}
|
||||
term_details.append(term_info)
|
||||
keywords_str = term_details
|
||||
result = {
|
||||
"source_query": query,
|
||||
"source_query_keys": query_keys,
|
||||
"vertical_classification": classification.vertical_classification,
|
||||
"sub_classification": classification.sub_classification,
|
||||
"rewrite_query": rewrite.rewrite,
|
||||
"keywords": keywords_str
|
||||
}
|
||||
return Response(json.dumps(result, ensure_ascii=False), content_type='application/json; charset=utf-8')
|
||||
except Exception as e:
|
||||
return Response(json.dumps({"error": str(e)}, ensure_ascii=False), content_type='application/json; charset=utf-8', status=500)
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(host="0.0.0.0", port=8001)
|
||||
Reference in New Issue
Block a user