新增工程信息、检索的知识片段节点回传、下一轮建议问题列表

This commit is contained in:
wanyaokun
2024-08-30 16:42:36 +08:00
parent 73565b26e4
commit e7628809ad
11 changed files with 411 additions and 134 deletions
+127 -50
View File
@@ -3,7 +3,6 @@ import json
import logging import logging
import time import time
from typing import Dict, List, Any, Optional, AsyncGenerator from typing import Dict, List, Any, Optional, AsyncGenerator
from collections import deque
from aiostream import stream from aiostream import stream
from fastapi import APIRouter, Request,HTTPException from fastapi import APIRouter, Request,HTTPException
@@ -13,17 +12,19 @@ from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.callbacks import CBEventType from llama_index.core.callbacks import CBEventType
from llama_index.core.chat_engine.types import StreamingAgentChatResponse from llama_index.core.chat_engine.types import StreamingAgentChatResponse
from llama_index.core.tools import ToolOutput from llama_index.core.tools import ToolOutput
from llama_index.core.schema import NodeWithScore
from pydantic import BaseModel from pydantic import BaseModel
from app.api.routers.request.base import userMng, conversations,message,parameter,feedback from app.api.routers.request.base import userMng, conversations,message,ProjectInfo,feedback
from app.api.routers.request.baseConfig import * from app.api.routers.request.baseConfig import *
from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest
from app.engine import get_chat_engine from app.engine import get_chat_engine
import uuid import uuid
from app.api.routers.services.fileServices import FileLoadService from app.api.routers.services.fileServices import PrjFileLoadService,ChatFileService
from app.api.routers.services.suggestion import NextQuestionSuggestion
import time
logger = logging.getLogger("uvicorn") logger = logging.getLogger("uvicorn")
api_router = r = APIRouter()
v1_router = v = APIRouter() v1_router = v = APIRouter()
class ChatCallbackEvent(BaseModel): class ChatCallbackEvent(BaseModel):
@@ -32,7 +33,7 @@ class ChatCallbackEvent(BaseModel):
def get_common_param(self)-> dict: def get_common_param(self)-> dict:
return { return {
'event': self.event_type.name, 'event': self.event_type.value,
'conversation_id':self.payload.get("conversation_id"), 'conversation_id':self.payload.get("conversation_id"),
'message_id': self.payload.get("message_id"), 'message_id': self.payload.get("message_id"),
'created_at': int(time.time()), 'created_at': int(time.time()),
@@ -48,7 +49,7 @@ class ChatCallbackEvent(BaseModel):
"workflow_id": self.payload.get('workflow_id'), "workflow_id": self.payload.get('workflow_id'),
"sequence_number": 1709, "sequence_number": 1709,
"inputs": { "inputs": {
"sys.query": self.payload.get('query'), "sys.query": f"开始查询 {self.payload.get('query')}",
"sys.files": [], "sys.files": [],
"sys.conversation_id": self.payload.get('conversation_id'), "sys.conversation_id": self.payload.get('conversation_id'),
"sys.user_id": self.payload.get('use_id') "sys.user_id": self.payload.get('use_id')
@@ -93,7 +94,7 @@ class ChatCallbackEvent(BaseModel):
"id": self.payload.get('nodeid'), "id": self.payload.get('nodeid'),
"node_id": self.payload.get('nodeid'), "node_id": self.payload.get('nodeid'),
"node_type": "http-request", "node_type": "http-request",
"title": self.payload.get('title'), "title": f"正在执行事件:{self.payload.get('title')}",
"index": self.payload.get('index'), "index": self.payload.get('index'),
"predecessor_node_id": self.payload.get('predecessor_node_id'), "predecessor_node_id": self.payload.get('predecessor_node_id'),
"inputs": '', "inputs": '',
@@ -111,7 +112,7 @@ class ChatCallbackEvent(BaseModel):
"id": self.payload.get('nodeid'), "id": self.payload.get('nodeid'),
"node_id": self.payload.get('nodeid'), "node_id": self.payload.get('nodeid'),
"node_type": "http-request", "node_type": "http-request",
"title": self.payload.get('title'), "title": f"事件执行结束:{self.payload.get('title')}",
"index": self.payload.get('index'), "index": self.payload.get('index'),
"predecessor_node_id": self.payload.get('predecessor_node_id'), "predecessor_node_id": self.payload.get('predecessor_node_id'),
"inputs": '', "inputs": '',
@@ -138,15 +139,54 @@ class ChatCallbackEvent(BaseModel):
def get_MessageEnd_param(self) -> dict: def get_MessageEnd_param(self) -> dict:
params = self.get_common_param() params = self.get_common_param()
nodeInfos = []
source_nodes = self.payload.get('source_node')
if source_nodes is not None:
for i in range(len(source_nodes)):
source_node:NodeWithScore = source_nodes[i]
metadata:dict = source_node.node.metadata
nodeInfo = {
"position": i,
"dataset_id": metadata.get("pipeline_id"),
"dataset_name": metadata.get("file_name"),
"document_id": source_node.node_id,
"document_name": metadata.get("file_name"),
"data_source_type": "upload_file",
"segment_id": source_node.node_id,
"retriever_from": "workflow",
"score": source_node.score,
"hit_count": 1,
"word_count": 632,
"segment_position": i,
"index_node_hash": "",
"content": source_node.text
}
nodeInfos.append(nodeInfo)
params.update({ params.update({
'id':self.payload.get('message_id'), 'id':self.payload.get('message_id'),
'metadata':self.payload.get('metadata') 'metadata':{
"retriever_resources":nodeInfos,
"usage":{
"prompt_tokens": 4972,
"prompt_unit_price": "0.0",
"prompt_price_unit": "0.0",
"prompt_price": "0.0",
"completion_tokens": 332,
"completion_unit_price": "0.0",
"completion_price_unit": "0.0",
"completion_price": "0.0",
"total_tokens": 5304,
"total_price": "0.0",
"currency": "USD",
"latency": 4.897703120019287
}
}
}) })
return params return params
def to_response(self)-> dict|None: def to_response(self)-> dict|None:
try: try:
match self.event_type: match self.event_type.value:
case "workflow_started": case "workflow_started":
return self.get_WorkflowStart_param() return self.get_WorkflowStart_param()
case "workflow_finished": case "workflow_finished":
@@ -180,24 +220,18 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
] ]
super().__init__(ignored_events, ignored_events) super().__init__(ignored_events, ignored_events)
self._aqueue = asyncio.Queue() self._aqueue = asyncio.Queue()
self._response:str = '' self._response: StreamingAgentChatResponse = None
self._params:Dict[str,Any] = params self._params:Dict[str,Any] = params
self._nodeStack:deque = deque() self._nodeStack:List[str] = []
#添加工作流开始事件 #添加工作流开始事件
data:ChatRequestData = self._params['data'] wf_event = self.makeWorkflow_startEvent()
args:Dict[str,Any] = self._params['ids']
args.update(
{
'use_id': data.user,
'query': data.query,
'conversation_id': data.conversation_id
}
)
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args)
if wf_event.to_response() is not None: if wf_event.to_response() is not None:
self._aqueue.put_nowait(wf_event) self._aqueue.put_nowait(wf_event)
def setResponse(self,response: StreamingAgentChatResponse):
    """Remember the streaming chat response of the current request.

    The stored object is read later by ``makeMessage_EndEvent``, which
    forwards its ``source_nodes`` in the MESSAGE_END payload.
    """
    self._response = response
def on_event_start( def on_event_start(
self, self,
event_type: CBEventType, event_type: CBEventType,
@@ -208,7 +242,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload)) logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload))
self._nodeStack.append(event_id) self._nodeStack.append(event_id)
nindex = self._nodeStack.count() - 1 nindex = len(self._nodeStack) - 1
args:Dict[str,Any] = self._params['ids'] args:Dict[str,Any] = self._params['ids']
args.update( args.update(
{ {
@@ -222,7 +256,6 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
if nd_event.to_response() is not None: if nd_event.to_response() is not None:
self._aqueue.put_nowait(nd_event) self._aqueue.put_nowait(nd_event)
def on_event_end( def on_event_end(
self, self,
event_type: CBEventType, event_type: CBEventType,
@@ -236,7 +269,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
args:Dict[str,Any] = self._params['ids'] args:Dict[str,Any] = self._params['ids']
nodeID = self._nodeStack[-1] nodeID = self._nodeStack[-1]
if nodeID == event_id: if nodeID == event_id:
nindex = self._nodeStack.count() - 1 nindex = len(self._nodeStack) - 1
args.update( args.update(
{ {
'nodeid':event_id, 'nodeid':event_id,
@@ -250,7 +283,6 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
self._aqueue.put_nowait(nd_event) self._aqueue.put_nowait(nd_event)
self._nodeStack.pop() self._nodeStack.pop()
def start_trace(self, trace_id: Optional[str] = None) -> None: def start_trace(self, trace_id: Optional[str] = None) -> None:
"""No-op.""" """No-op."""
logger.info("trace_start:{}\n".format(trace_id)) logger.info("trace_start:{}\n".format(trace_id))
@@ -262,21 +294,12 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
) -> None: ) -> None:
"""No-op.""" """No-op."""
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map)) logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
data:ChatRequestData = self._params['data']
args:Dict[str,Any] = self._params['ids'] wf_event = self.makeWorkflow_finishedEvent()
args.update(
{
'response':self._response,
'conversation_id': data.conversation_id
}
)
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args)
if wf_event.to_response() is not None: if wf_event.to_response() is not None:
self._aqueue.put_nowait(wf_event) self._aqueue.put_nowait(wf_event)
msgEnt_event = self.makeMessage_EndEvent()
args:Dict[str,Any] = self._params['ids']
msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args)
if msgEnt_event.to_response() is not None: if msgEnt_event.to_response() is not None:
self._aqueue.put_nowait(msgEnt_event) self._aqueue.put_nowait(msgEnt_event)
@@ -287,6 +310,38 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
except asyncio.TimeoutError: except asyncio.TimeoutError:
pass pass
def makeWorkflow_startEvent(self)->ChatCallbackEvent:
    """Build the WORKFLOW_START event for the current request.

    The request's user, query and conversation id are merged into the
    shared ``ids`` dict (mutated in place), and that dict is used as the
    event payload.
    """
    request: ChatRequestData = self._params['data']
    payload: Dict[str, Any] = self._params['ids']
    payload.update(
        use_id=request.user,
        query=request.query,
        conversation_id=request.conversation_id,
    )
    return ChatCallbackEvent(event_type=ChatEventType.WORKFLOW_START, payload=payload)
def makeWorkflow_finishedEvent(self)->ChatCallbackEvent:
    """Build the WORKFLOW_FINISHED event for the current request.

    The payload is the shared ``ids`` dict (mutated in place) with an empty
    ``response`` and the request's conversation id merged in.
    """
    request: ChatRequestData = self._params['data']
    payload: Dict[str, Any] = self._params['ids']
    payload.update(response='', conversation_id=request.conversation_id)
    return ChatCallbackEvent(event_type=ChatEventType.WORKFLOW_FINISHED, payload=payload)
def makeMessage_EndEvent(self)->ChatCallbackEvent:
    """Build the MESSAGE_END event for the current request.

    When a chat response has been registered, its ``source_nodes`` are
    stored in the shared ``ids`` dict under ``'source_node'`` before the
    dict is used as the event payload.
    """
    payload: Dict[str, Any] = self._params['ids']
    response = self._response
    if response is not None:
        payload['source_node'] = response.source_nodes
    return ChatCallbackEvent(event_type=ChatEventType.MESSAGE_END, payload=payload)
class IDManager: class IDManager:
def createID(self): def createID(self):
return { return {
@@ -390,7 +445,7 @@ async def post_conversations(request: Request, data: ChatRequestData):
params = data.inputs or {} params = data.inputs or {}
# 获取聊天引擎对象 # 获取聊天引擎对象
chat_engine = get_chat_engine(filters=filters, params=params,prjFlag = data.prjFlag) chat_engine = get_chat_engine(filters=filters, params=params)
# 启动聊天事件监听 # 启动聊天事件监听
ids = IDManager().createID() ids = IDManager().createID()
@@ -399,7 +454,7 @@ async def post_conversations(request: Request, data: ChatRequestData):
# 执行异步聊天 # 执行异步聊天
response = await chat_engine.astream_chat(data.query) response = await chat_engine.astream_chat(data.query)
event_handler.setResponse(response)
# 返回异步消息回应 # 返回异步消息回应
return ChatStreamResponse(request, event_handler, response, data,ids) return ChatStreamResponse(request, event_handler, response, data,ids)
@@ -468,29 +523,51 @@ async def query_conversations(user:str, first_id:str = None, limit:str = None, p
@v.get("/parameters") @v.get("/parameters")
async def query_parameters(user:str): async def query_parameters(user:str):
params = parameter().get(user) prjObj = ProjectInfo()
if len(params) == 0: return BaseConfig().ParamterCfg(projectInfo = prjObj.projectNames())
params = BaseConfig().ParamterCfg()
return params
@v.post("/messages/{message_id}/feedbacks") @v.post("/messages/{message_id}/feedbacks")
async def post_feedbacks(request: Request,message_id:str,params:Dict[str,Any]): async def post_feedbacks(request: Request,message_id:str,params:Dict[str,Any]):
if params['rating'] =='null': if params['rating'] is None:
feedback().delete(message_id) feedback().delete(message_id)
else: else:
condition = {'id':message_id} results = message().query(message_id)
results = message().query(**condition)
if len(results) > 0: if len(results) > 0:
result = results[0] result = results[0]
feedback().add(message_id=message_id,query=result['query'], feedback().add(message_id=message_id,query=result['query'],
answer=result['answer'],rating=params['rating']) answer=result['answer'],rating=params['rating'])
@v.post("") @v.post("/files/upload")
def upload_file(request: ChatFileUploadRequest) -> List[str]: def upload_file(request: ChatFileUploadRequest):
try: try:
logger.info("Processing file") logger.info("Processing file")
return FileLoadService.process_file(request.base64) resluts = ChatFileService.process_file(request.base64)
return {
'id':resluts.get('id'),
'name': resluts.get('name'),
'size': resluts.get('size'),
'extension':resluts.get('extension'),
'mime_type':resluts.get('mime_type'),
'created_by':str(uuid.uuid4()),
'created_at':int(time.time())
}
except Exception as e: except Exception as e:
logger.error(f"Error processing file: {e}", exc_info=True) logger.error(f"Error processing file: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Error processing file") raise HTTPException(status_code=500, detail="Error processing file")
@v.post("/applications")
def upload_file(request: ChatFileUploadRequest):
    """Convert, index and persist an uploaded project file.

    Args:
        request: Upload request whose ``base64`` field carries the file.

    Returns:
        The value produced by ``PrjFileLoadService.process_file``.

    Raises:
        HTTPException: 500 when processing raises, or when the service
            reports failure by returning ``None``.
    """
    try:
        logger.info("Processing file")
        prj_flag = PrjFileLoadService.process_file(request.base64)
    except Exception as e:
        logger.error(f"Error processing file: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Error processing file") from e
    # The service signals a failed conversion with None; without this check
    # the client would receive HTTP 200 with a null body.
    if prj_flag is None:
        logger.error("Error processing file: project file conversion produced no result")
        raise HTTPException(status_code=500, detail="Error processing file")
    return prj_flag
@v.post("/messages/{message_id}/suggested")
async def post_suggested(request: Request,message_id:str,user:str):
    """Return follow-up question suggestions for the given message.

    ``request`` and ``user`` are accepted but not used in this handler.
    """
    suggestions = await NextQuestionSuggestion.suggest_next_questions(message_id)
    return {"result": "success", "data": suggestions}
+33 -2
View File
@@ -2,7 +2,7 @@ from datetime import datetime
import uuid import uuid
from app.api.routers.request.baseConfig import BaseConfig from app.api.routers.request.baseConfig import BaseConfig
from app.api.routers.request.dbOrm import DBManager from app.api.routers.request.dbOrm import DBManager
from typing import List
dbManage = DBManager() dbManage = DBManager()
class conversations: class conversations:
@@ -122,8 +122,9 @@ class message:
def delete(self,user_id:str): def delete(self,user_id:str):
dbManage.delete(self._tableName,user_id = user_id) dbManage.delete(self._tableName,user_id = user_id)
def query(self,**condition): def query(self,id:str):
results = [] results = []
condition = {'id':id}
records = dbManage.query(self._tableName,**condition) records = dbManage.query(self._tableName,**condition)
for record in records: for record in records:
results.append(record.dict()) results.append(record.dict())
@@ -153,3 +154,33 @@ class feedback:
if len(records) > 0: if len(records) > 0:
return records[0].dict() return records[0].dict()
return None return None
class ProjectInfo:
    """Accessor for the ``projectInfos`` table (project name / project flag pairs)."""

    def __init__(self) -> None:
        self._tableName = 'projectInfos'
        # NOTE(review): presumably creates the table only if it is missing - confirm in DBManager.
        dbManage.createTable(self._tableName)

    def add(self,name:str,flag:str):
        """Insert one project record with the given name and flag."""
        dbManage.addRecord(self._tableName, {'prjectName': name, 'prjFlag': flag})

    def projectNames(self)->List[str]:
        """Return the stored project names, skipping empty-string names."""
        stored = (row.dict().get('prjectName') for row in dbManage.query(self._tableName))
        return [project for project in stored if project != '']

    def prjFalg(self,name:str):
        """Return the flag of the first record named ``name``, or ``''`` if none matches."""
        for row in dbManage.query(self._tableName):
            fields: dict = row.dict()
            if fields.get('prjectName') == name:
                return fields['prjFlag']
        return ''
+14 -2
View File
@@ -5,7 +5,8 @@ from enum import Enum
class BaseConfig(BaseModel): class BaseConfig(BaseModel):
projectInfo:str = os.getenv("PROJECT_TITLE","您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!") projectInfo:str = os.getenv("PROJECT_TITLE","您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!")
def ParamterCfg(self): def ParamterCfg(self,**args):
projectInfo = args.get('projectInfo')
questions = os.getenv("CONVERSATION_STARTERS", "dev") questions = os.getenv("CONVERSATION_STARTERS", "dev")
return{ return{
"opening_statement": self.projectInfo, "opening_statement": self.projectInfo,
@@ -30,7 +31,18 @@ class BaseConfig(BaseModel):
"more_like_this": { "more_like_this": {
"enabled": False "enabled": False
}, },
"user_input_form": [], "user_input_form": [
{
"select": {
"variable": "projectname",
"label": "\u5de5\u7a0b\u540d\u79f0",
"type": "select",
"max_length": 48,
"required": True,
"options": [projectInfo]
}
}
],
"sensitive_word_avoidance": { "sensitive_word_avoidance": {
"enabled": False "enabled": False
}, },
+18
View File
@@ -55,6 +55,13 @@ class FeedBackOrm(Base):
answer = Column(String) answer = Column(String)
rating = Column(String) rating = Column(String)
class ProjectInfoOrm(Base):
    """ORM mapping for the ``projectInfos`` table."""
    __tablename__ = "projectInfos"

    # Project flag; primary key of the table.
    prjFlag = Column(String,primary_key=True)
    # Project name. NOTE(review): "prjectName" is the actual column name; the
    # same spelling appears to be used as a record key elsewhere - rename only
    # together with every reader/writer.
    prjectName = Column(String)
#数据结构 #数据结构
class ConversationModel(BaseModel): class ConversationModel(BaseModel):
id: str id: str
@@ -121,6 +128,17 @@ class FeedBackModel(BaseModel):
def orm(cls): def orm(cls):
return FeedBackOrm return FeedBackOrm
class ProjectInfoModel(BaseModel):
    """Pydantic schema for one ``projectInfos`` record."""
    prjectName:str
    prjFlag:str

    class Config:
        # Lets the model be populated from attribute-bearing objects
        # (presumably ProjectInfoOrm rows - confirm against DBManager).
        from_attributes=True

    @classmethod
    def orm(cls):
        """Return the ORM class paired with this schema."""
        return ProjectInfoOrm
class DBManager: class DBManager:
def __init__(self) -> None: def __init__(self) -> None:
DATABASE_URL = os.getenv("SQLITE_DATABASE_URL") DATABASE_URL = os.getenv("SQLITE_DATABASE_URL")
@@ -10,7 +10,6 @@ class ChatRequestData(BaseModel):
response_mode: str response_mode: str
files: Any files: Any
conversation_id: str = None conversation_id: str = None
prjFlag:Optional[str] = ''
class ChatFileUploadRequest(BaseModel): class ChatFileUploadRequest(BaseModel):
base64: str base64: str
@@ -1,16 +1,23 @@
import base64,os import base64,os,mimetypes,requests,tempfile
from typing import List from typing import List,Dict,Any
from uuid import uuid4 from uuid import uuid4
import requests
from app.settings import init_settings from app.settings import init_settings
from app.engine.loaders import get_document_Types, get_documents,getFileCacahePath from app.engine.loaders import get_document_Types, get_documents,getFileCacahePath
from app.engine.vectordb import get_vector_store from app.engine.vectordb import get_vector_store
from app.engine.generate import get_doc_store,run_pipeline,persist_storage from app.engine.generate import get_doc_store,run_pipeline,persist_storage
import tempfile from llama_index.core.schema import Document
from pathlib import Path
from llama_index.core.readers.file.base import (
_try_loading_included_file_formats as get_file_loaders_map,
)
from llama_index.readers.file import FlatReader
from llama_index.core.ingestion import IngestionPipeline
from llama_index.core import VectorStoreIndex
from app.engine.index import get_index
STORAGE_DIR = os.getenv("STORAGE_DIR", "storage") STORAGE_DIR = os.getenv("STORAGE_DIR", "storage")
class FileLoadService: class PrjFileLoadService:
@staticmethod @staticmethod
def store_and_parse_file(file_data): def store_and_parse_file(file_data):
prjtoJson_url = os.getenv('PRJTOJSON_URL') prjtoJson_url = os.getenv('PRJTOJSON_URL')
@@ -20,12 +27,19 @@ class FileLoadService:
url = convert_url, url = convert_url,
files=files files=files
) )
if response1.text is None or response1.text=='':
return None
load_url = prjtoJson_url +'/file_download' load_url = prjtoJson_url +'/file_download'
response2 = requests.post( response2 = requests.post(
url = load_url, url = load_url,
data=response1.text data=response1.text
) )
tempFilePath:str = tempfile.gettempdir() + f"\\{str(uuid4())}.zip" if response2.text is None or response2.content=='':
return None
try:
tempFilePath:str = tempfile.gettempdir() + f"\\{uuid4().hex}.zip"
with open(tempFilePath,'wb') as file: with open(tempFilePath,'wb') as file:
file.write(response2.content) file.write(response2.content)
@@ -39,10 +53,14 @@ class FileLoadService:
zip_File.extract(zip_info,filePath) zip_File.extract(zip_info,filePath)
os.remove(tempFilePath) os.remove(tempFilePath)
return f'Projects_{prjID}' return f'Projects_{prjID}'
except Exception as e:
return None
@staticmethod @staticmethod
def process_file(base64_content: str) -> str: def process_file(base64_content: str) -> str:
prjFlag = FileLoadService.store_and_parse_file(base64_content) prjFlag = PrjFileLoadService.store_and_parse_file(base64_content)
if prjFlag is None:
return None
#生成向量并持久化至本地 #生成向量并持久化至本地
documents = get_documents(prjFlag) documents = get_documents(prjFlag)
for doc in documents: for doc in documents:
@@ -53,3 +71,64 @@ class FileLoadService:
persist_storage(docstore, vector_store) persist_storage(docstore, vector_store)
return prjFlag return prjFlag
class ChatFileService:
    """Store a file uploaded during a chat, index its content and report its metadata.

    Metadata is collected per call in a fresh dict, so concurrent uploads
    cannot overwrite each other's values.
    """

    PRIVATE_STORE_PATH = os.getenv('CHAT_UPLOAD_FILECACHE','output/uploaded')
    # Legacy class-level view of the metadata. The helpers write here only when
    # no per-call dict is supplied; process_file rebinds it to the dict it
    # returns so code reading this attribute still sees the latest upload.
    resluts:Dict[str,Any] = {}

    @staticmethod
    def process_file(base64_content: str) -> dict:
        """Decode, store, parse and index an uploaded file.

        Args:
            base64_content: Data URL of the form ``data:<mime>;base64,<payload>``.

        Returns:
            Metadata of the stored file with keys ``mime_type``, ``extension``,
            ``id``, ``file_name``, ``name`` and ``size``.

        Raises:
            ValueError: If the data URL is malformed or the file extension has
                no registered reader.
        """
        info: Dict[str, Any] = {}
        file_data, extension = ChatFileService.preprocess_base64_file(base64_content, info)
        documents = ChatFileService.store_and_parse_file(file_data, extension, info)

        # Run ingestion once; the previous code ran the same pipeline twice on
        # the same documents and discarded the first result.
        nodes = IngestionPipeline().run(documents=documents)

        current_index = get_index()
        if current_index is None:
            current_index = VectorStoreIndex(nodes=nodes)
        else:
            current_index.insert_nodes(nodes=nodes)
        current_index.storage_context.persist(
            persist_dir=os.environ.get("STORAGE_DIR", "storage")
        )

        ChatFileService.resluts = info
        return info

    @staticmethod
    def preprocess_base64_file(base64_content: str, info: Dict[str, Any] = None) -> tuple:
        """Split a data URL into decoded bytes and a guessed file extension.

        Args:
            base64_content: Data URL of the form ``data:<mime>;base64,<payload>``.
            info: Dict receiving ``mime_type`` and ``extension``; defaults to the
                class-level ``resluts`` dict.

        Returns:
            ``(file_bytes, extension)``; ``extension`` is ``None`` when the MIME
            type is unknown to ``mimetypes``.

        Raises:
            ValueError: If the content is not a data URL.
        """
        target = ChatFileService.resluts if info is None else info
        header, separator, data = base64_content.partition(",")
        mime_part = header.split(";")[0]
        if not separator or ":" not in mime_part:
            raise ValueError("Expected a data URL of the form 'data:<mime>;base64,<payload>'")
        mime_type = mime_part.split(":", 1)[1]
        extension = mimetypes.guess_extension(mime_type)
        target['mime_type'] = mime_type
        target['extension'] = extension
        return base64.b64decode(data), extension

    @staticmethod
    def store_and_parse_file(file_data, extension, info: Dict[str, Any] = None) -> "List[Document]":
        """Write the file into the private store and load it as documents.

        Args:
            file_data: Raw file bytes.
            extension: File extension including the leading dot.
            info: Dict receiving ``id``, ``file_name``, ``name`` and ``size``;
                defaults to the class-level ``resluts`` dict.

        Returns:
            The loaded documents, each tagged with ``file_name`` and
            ``private`` metadata.

        Raises:
            ValueError: If no reader is registered for ``extension``.
        """
        target = ChatFileService.resluts if info is None else info

        # Resolve the reader first so an unsupported upload is rejected before
        # anything is written to disk.
        reader_cls = ChatFileService.default_file_loaders_map().get(extension)
        if reader_cls is None:
            raise ValueError(f"File extension {extension} is not supported")

        os.makedirs(ChatFileService.PRIVATE_STORE_PATH, exist_ok=True)
        fileID = uuid4().hex
        file_name = f"{fileID}{extension}"
        file_path = Path(ChatFileService.PRIVATE_STORE_PATH) / file_name
        with open(file_path, "wb") as f:
            f.write(file_data)

        target['id'] = fileID
        target['file_name'] = file_name
        # The upload endpoint reads 'name'; keep 'file_name' for existing readers.
        target['name'] = file_name
        target['size'] = os.path.getsize(file_path)

        documents = reader_cls().load_data(file_path)
        for doc in documents:
            doc.metadata["file_name"] = file_name
            doc.metadata["private"] = "true"
        return documents

    @staticmethod
    def default_file_loaders_map():
        """Return the extension-to-reader map, with ``.txt`` handled by ``FlatReader``."""
        default_loaders = get_file_loaders_map()
        default_loaders[".txt"] = FlatReader
        return default_loaders
@@ -0,0 +1,43 @@
from typing import List
from app.api.routers.request.base import message
from llama_index.core.prompts import PromptTemplate
from llama_index.core.settings import Settings
from pydantic import BaseModel
# Prompt asking the LLM for follow-up questions. Both {conversation} and
# {number_of_questions} are template variables filled in at predict time.
NEXT_QUESTIONS_SUGGESTION_PROMPT = PromptTemplate(
    "你是一个乐于助人的助手!你的任务是对用户可能会问的下一个问题给出建议。 "
    "\n这是对话历史记录"
    "\n---------------------\n{conversation}\n---------------------"
    "考虑到对话历史记录,仅限于现在知识库已有内容, 请给我 {number_of_questions} 个你接下来可能会问题的问题!"
)
# Default number of suggestions requested from the LLM.
N_QUESTION_TO_GENERATE = 3


class NextQuestions(BaseModel):
    """A list of questions that user might ask next"""

    questions: List[str]


class NextQuestionSuggestion:
    """Suggests follow-up questions from a stored question/answer pair."""

    @staticmethod
    async def suggest_next_questions(
        message_id: str,
        number_of_questions: int = N_QUESTION_TO_GENERATE,
    ) -> List[str]:
        """Ask the configured LLM for questions the user might ask next.

        Args:
            message_id: Id of the stored message whose query and answer form
                the conversation context.
            number_of_questions: How many suggestions to request.

        Returns:
            The suggested questions, or an empty list when no message with
            ``message_id`` is found.
        """
        results = message().query(message_id)
        if len(results) == 0:
            return []

        last_user_message = results[0]['query']
        last_assistant_message = results[0]['answer']
        conversation: str = f"{last_user_message}\n{last_assistant_message}"

        # The keyword must match the template variable name; the previous
        # "nun_questions" matched nothing, so the count was never sent.
        output: NextQuestions = await Settings.llm.astructured_predict(
            NextQuestions,
            prompt=NEXT_QUESTIONS_SUGGESTION_PROMPT,
            conversation=conversation,
            number_of_questions=number_of_questions,
        )
        return output.questions
+18 -2
View File
@@ -8,9 +8,18 @@ from app.engine.engine import create_query_engine, create_summary_query_engine
from app.engine.index import get_index from app.engine.index import get_index
#from app.engine.loaders.db import makeDescriptionByEngine #from app.engine.loaders.db import makeDescriptionByEngine
from app.engine.tools import ToolFactory from app.engine.tools import ToolFactory
from app.api.routers.request.base import ProjectInfo
def getPrjFalg(params:dict=None)->str:
    """Resolve the project flag for the project name selected in the request.

    Args:
        params: Request parameters. The project name is read from
            ``params['inputs']['projectname']`` when a nested ``inputs`` dict
            exists, otherwise from ``params['projectname']``.

    Returns:
        The flag stored for that project, or ``''`` when no project name is
        supplied or no record matches.
    """
    if not params:
        return ''
    nested = params.get('inputs')
    # NOTE(review): the chat endpoint appears to pass the request's inputs
    # dict itself as params, so fall back to the top level when there is no
    # nested 'inputs' - confirm against the caller.
    source = nested if isinstance(nested, dict) else params
    project_name = source.get('projectname')
    if not project_name:
        return ''
    # prjFalg is an instance method; calling it on the class raised TypeError.
    return ProjectInfo().prjFalg(project_name)
def get_chat_engine(filters=None, params=None,**args): def get_chat_engine(filters=None, params:dict=None):
system_prompt = os.getenv("SYSTEM_PROMPT") system_prompt = os.getenv("SYSTEM_PROMPT")
top_k = int(os.getenv("TOP_K", "3")) top_k = int(os.getenv("TOP_K", "3"))
use_reranker = os.getenv("RERANK_ENABLED") use_reranker = os.getenv("RERANK_ENABLED")
@@ -24,7 +33,13 @@ def get_chat_engine(filters=None, params=None,**args):
#tools.append(sql_query_tool) #tools.append(sql_query_tool)
# Add query tool if index exists # Add query tool if index exists
index = get_index(**args) prjFlag = ''
if params is not None:
inputs:dict = params.get('inputs')
if inputs is not None:
prjFlag = inputs.get('projectname')
index = get_index(prjFlag = getPrjFalg(params))
if index is not None: if index is not None:
summary_query_engine = create_summary_query_engine(index,top_k,use_reranker,filters) summary_query_engine = create_summary_query_engine(index,top_k,use_reranker,filters)
summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool", summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool",
@@ -57,6 +72,7 @@ def get_chat_engine(filters=None, params=None,**args):
verbose=True, verbose=True,
) )
return agentrunner return agentrunner
# create the function calling worker for reasoning # create the function calling worker for reasoning
# worker = FunctionCallingAgentWorker.from_tools( # worker = FunctionCallingAgentWorker.from_tools(
# tools, verbose=True # tools, verbose=True
+3 -3
View File
@@ -7,15 +7,15 @@ logger = logging.getLogger("uvicorn")
def get_index(**args): def get_index(**args):
logger.info("Connecting vector store...") logger.info("Connecting vector store...")
if 'prjFlag' in args:
prjFlags = get_document_Types() prjFlags = get_document_Types()
if len(prjFlags)<=0: if len(prjFlags)<=0:
return None return None
prjFlag = args.get('prjFlag','') prjFlag = args.get('prjFlag','')
flag = prjFlags[0] if prjFlag not in prjFlags else prjFlag flag = prjFlags[0] if prjFlag not in prjFlags else prjFlag
else:
flag = ''
store = get_vector_store(flag) store = get_vector_store(flag)
# Load the index from the vector store
# If you are using a vector store that doesn't store text,
# you must load the index from both the vector store and the document store
index = VectorStoreIndex.from_vector_store(store) index = VectorStoreIndex.from_vector_store(store)
logger.info("Finished load index from vector store.") logger.info("Finished load index from vector store.")
return index return index
+2
View File
@@ -58,10 +58,12 @@ def get_document_Types():
def get_documents(docType:str): def get_documents(docType:str):
documents = [] documents = []
config = load_configs() config = load_configs()
if config is None or len(config.items()) == 0: if config is None or len(config.items()) == 0:
return documents return documents
for loader_type, loader_config in config.items(): for loader_type, loader_config in config.items():
if loader_config.get('enable', True): # 检查 enable 字段
logger.info( logger.info(
f"Loading documents from loader: {loader_type}, config: {loader_config}" f"Loading documents from loader: {loader_type}, config: {loader_config}"
) )
+32 -32
View File
@@ -3,46 +3,46 @@ file:
# use_llama_parse: Use LlamaParse if `true`. Needs a `LLAMA_CLOUD_API_KEY` from https://cloud.llamaindex.ai set as environment variable # use_llama_parse: Use LlamaParse if `true`. Needs a `LLAMA_CLOUD_API_KEY` from https://cloud.llamaindex.ai set as environment variable
use_llama_parse: false use_llama_parse: false
db: #db:
# The configuration for the database loader, only supports MySQL and PostgreSQL databases for now. # The configuration for the database loader, only supports MySQL and PostgreSQL databases for now.
# uri: The URI for the database. E.g.: mysql+pymysql://user:password@localhost:3306/db or postgresql+psycopg2://user:password@localhost:5432/db # uri: The URI for the database. E.g.: mysql+pymysql://user:password@localhost:3306/db or postgresql+psycopg2://user:password@localhost:5432/db
# query: The query to fetch data from the database. E.g.: SELECT * FROM table # query: The query to fetch data from the database. E.g.: SELECT * FROM table
- uri: mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1 #- uri: mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
enable: true # 添加 enable 字段 #enable: false # 添加 enable 字段
queries: #queries:
- sql: select * from ProjectProperties; #- sql: select * from ProjectProperties;
explanation: "工程属性表数据,层级关系包含在博微电力造价工程文件格式_ProjectProperties.json文件中。" #explanation: "工程属性表数据,层级关系包含在博微电力造价工程文件格式_ProjectProperties.json文件中。"
- sql: select Id, ParentId, Level, Name, Code, Amount, Amount_Total from TotalCalculateTable; #- sql: select Id, ParentId, Level, Name, Code, Amount, Amount_Total from TotalCalculateTable;
explanation: "总算表数据,层级关系包含在博微电力造价工程文件格式_TotalCalculateTable.json文件中。" #explanation: "总算表数据,层级关系包含在博微电力造价工程文件格式_TotalCalculateTable.json文件中。"
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '线路'; #- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '线路';
explanation: "专业类型为线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" #explanation: "专业类型为线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '余物清理'; #- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '余物清理';
explanation: "专业类型为余物清理的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" #explanation: "专业类型为余物清理的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '拆除线路'; #- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '拆除线路';
explanation: "专业类型为拆除线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" #explanation: "专业类型为拆除线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
- sql: select Id, ParentId, Level, Name, Code, Rate, Amount from OtherFee; #- sql: select Id, ParentId, Level, Name, Code, Rate, Amount from OtherFee;
explanation: "其他费用表数据,层级关系包含在博微电力造价工程文件格式_OtherFee.json文件中" #explanation: "其他费用表数据,层级关系包含在博微电力造价工程文件格式_OtherFee.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表' #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(调试工程)aa' #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(调试工程)aa'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '大型土石方取费表' #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '大型土石方取费表'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)' #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)(1)' #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)(1)'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(拆除)' #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(拆除)'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '线路' #- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '线路'
explanation: "专业类型为线路的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中" #explanation: "专业类型为线路的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '余物清理' #- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '余物清理'
explanation: "专业类型为余物清理的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中" #explanation: "专业类型为余物清理的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
#web: #web:
# driver_arguments: # driver_arguments:
# # The arguments to pass to the webdriver. E.g.: add --headless to run in headless mode # # The arguments to pass to the webdriver. E.g.: add --headless to run in headless mode