工程名称下拉项获取兼容.md文件,同时新增自定义答案合成类
This commit is contained in:
@@ -10,6 +10,8 @@ from sqlalchemy import create_engine
|
|||||||
from util.register import *
|
from util.register import *
|
||||||
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
|
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
|
||||||
from app.engine.retriever.HybridRetriever import HybridRetriever
|
from app.engine.retriever.HybridRetriever import HybridRetriever
|
||||||
|
from app.engine.response.treeSummResponse import CustomTreeResponse
|
||||||
|
from llama_index.core.settings import Settings
|
||||||
|
|
||||||
ModelPlateCategory = '模型平台'
|
ModelPlateCategory = '模型平台'
|
||||||
|
|
||||||
@@ -65,6 +67,14 @@ def get_Retriever(index,**kwargs):
|
|||||||
return retriever
|
return retriever
|
||||||
|
|
||||||
|
|
||||||
|
def get_synthesizer():
    """Create the response synthesizer used to combine retrieved text chunks.

    Returns:
        A ``CustomTreeResponse`` configured with the globally registered LLM
        (``Settings.llm``), the project's ``summary_template``,
        ``use_async=True`` and ``streaming=False``.
    """
    synthesizer_options = {
        "llm": Settings.llm,
        "summary_template": summary_template,
        "use_async": True,
        "streaming": False,
    }
    return CustomTreeResponse(**synthesizer_options)
|
||||||
|
|
||||||
sql_database = None
|
sql_database = None
|
||||||
sql_obj_index = None
|
sql_obj_index = None
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import yaml
|
|||||||
from app.engine.loaders.db import DBLoaderConfig, get_db_documents
|
from app.engine.loaders.db import DBLoaderConfig, get_db_documents
|
||||||
from app.engine.loaders.file import FileLoaderConfig, get_file_documents
|
from app.engine.loaders.file import FileLoaderConfig, get_file_documents
|
||||||
from app.engine.loaders.web import WebLoaderConfig, get_web_documents
|
from app.engine.loaders.web import WebLoaderConfig, get_web_documents
|
||||||
from app.engine.loaders.projectJson import getProjectName
|
from app.engine.loaders.file import getProjectName
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,9 @@ from llama_index.core.readers.base import BaseReader
|
|||||||
from llama_index.core.readers.json import JSONReader
|
from llama_index.core.readers.json import JSONReader
|
||||||
from llama_parse import LlamaParse
|
from llama_parse import LlamaParse
|
||||||
from pydantic import BaseModel, validator
|
from pydantic import BaseModel, validator
|
||||||
|
from app.engine.loaders.markdownReader import ChunkMarkdownReader
|
||||||
|
from app.engine.loaders.projectJson import ProjectJson
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -20,7 +23,6 @@ class FileLoaderConfig(BaseModel):
|
|||||||
raise ValueError(f"Directory '{v}' does not exist")
|
raise ValueError(f"Directory '{v}' does not exist")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
def llama_parse_parser():
|
def llama_parse_parser():
|
||||||
if os.getenv("LLAMA_CLOUD_API_KEY") is None:
|
if os.getenv("LLAMA_CLOUD_API_KEY") is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -35,7 +37,6 @@ def llama_parse_parser():
|
|||||||
)
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def llama_parse_extractor() -> Dict[str, LlamaParse]:
|
def llama_parse_extractor() -> Dict[str, LlamaParse]:
|
||||||
from llama_parse.utils import SUPPORTED_FILE_TYPES
|
from llama_parse.utils import SUPPORTED_FILE_TYPES
|
||||||
|
|
||||||
@@ -43,8 +44,11 @@ def llama_parse_extractor() -> Dict[str, LlamaParse]:
|
|||||||
return {file_type: parser for file_type in SUPPORTED_FILE_TYPES}
|
return {file_type: parser for file_type in SUPPORTED_FILE_TYPES}
|
||||||
|
|
||||||
def llama_local_extractor() -> Dict[str, BaseReader]:
    """Return the local (non-LlamaParse) readers keyed by file extension.

    ``.json`` files are handled by ``JSONReader(clean_json=False,
    levels_back=0)`` and ``.md`` files by ``ChunkMarkdownReader()``.
    """
    readers: Dict[str, BaseReader] = {}
    readers[".json"] = JSONReader(clean_json=False, levels_back=0)
    readers[".md"] = ChunkMarkdownReader()
    return readers
|
||||||
|
|
||||||
def get_file_documents(config: FileLoaderConfig,childPath: str):
|
def get_file_documents(config: FileLoaderConfig,childPath: str):
|
||||||
from llama_index.core.readers import SimpleDirectoryReader
|
from llama_index.core.readers import SimpleDirectoryReader
|
||||||
@@ -86,3 +90,32 @@ def get_file_documents(config: FileLoaderConfig,childPath: str):
|
|||||||
else:
|
else:
|
||||||
# Raise the error if it is not the case of empty data dir
|
# Raise the error if it is not the case of empty data dir
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
def prjFileSuffix(dir: str):
    """Return the extension of the first regular file directly inside ``dir``.

    Entries are examined in sorted name order, so the answer is stable and
    does not depend on the arbitrary order in which ``os.listdir`` reports
    directory entries. Sub-directories are ignored.

    Args:
        dir: Path of the directory to inspect.

    Returns:
        The extension including its leading dot (for example ``'.json'``).
        ``''`` is returned when the directory contains no regular file, or
        when the first file has no extension.

    Raises:
        OSError: Propagated from ``os.listdir`` when ``dir`` cannot be read.
    """
    # os.listdir gives no ordering guarantee; sorting keeps "first" well defined.
    for entry in sorted(os.listdir(dir)):
        if os.path.isfile(os.path.join(dir, entry)):
            return os.path.splitext(entry)[1]
    return ''
|
||||||
|
|
||||||
|
def getProjectName(dir:str):
    """Look up the project name (the '工程名称' entry) stored under ``dir``.

    The storage format is chosen from the suffix reported by
    ``prjFileSuffix(dir)``:

    * ``.json`` - the directory is parsed with ``ProjectJson`` and the
      '工程属性' table is scanned for the record whose '名称' is '工程名称'.
    * ``.md`` - the markdown file whose base name is '工程属性' is loaded with
      ``ChunkMarkdownReader`` and queried for the same row.

    Returns:
        The '值' entry of the matching row, or ``''`` when nothing matches
        or the suffix is neither ``.json`` nor ``.md``.
    """
    kind = prjFileSuffix(dir)

    if kind == '.json':
        project = ProjectJson(dir)
        project.parse()
        properties = project.table('工程属性')
        for row in properties.records():
            if row.value('名称') == '工程名称':
                return row.value('值')
        return ''

    if kind == '.md':
        for entry in os.listdir(dir):
            if not entry.endswith('.md'):
                continue
            if os.path.splitext(entry)[0] != '工程属性':
                continue
            # Only the project-properties document carries the name row.
            reader = ChunkMarkdownReader()
            reader.load_data(os.path.join(dir, entry))
            return reader.findValue("名称=='工程名称'", '值')

    return ''
|
||||||
@@ -13,6 +13,8 @@ class ChunkMarkdownReader(MarkdownReader):
|
|||||||
) -> None:
|
) -> None:
|
||||||
self._chunkSize = chunkSize
|
self._chunkSize = chunkSize
|
||||||
self._tokenizer = get_tokenizer()
|
self._tokenizer = get_tokenizer()
|
||||||
|
self._colheader = ''
|
||||||
|
self._rows = []
|
||||||
super().__init__(*args,**kwargs)
|
super().__init__(*args,**kwargs)
|
||||||
|
|
||||||
def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
|
def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
|
||||||
@@ -34,6 +36,8 @@ class ChunkMarkdownReader(MarkdownReader):
|
|||||||
tokensNum = headerSize
|
tokensNum = headerSize
|
||||||
current_lines.clear()
|
current_lines.clear()
|
||||||
current_lines.append(line)
|
current_lines.append(line)
|
||||||
|
if strTitle!='' and strheader!='':
|
||||||
|
self._rows.append(line)
|
||||||
|
|
||||||
if line == '\n' or line == '\r':
|
if line == '\n' or line == '\r':
|
||||||
if tokensNum > self._chunkSize:
|
if tokensNum > self._chunkSize:
|
||||||
@@ -43,10 +47,12 @@ class ChunkMarkdownReader(MarkdownReader):
|
|||||||
current_lines.clear()
|
current_lines.clear()
|
||||||
|
|
||||||
if line.startswith("|---"):
|
if line.startswith("|---"):
|
||||||
|
self._colheader = current_lines[0]
|
||||||
strheader = "\n".join(current_lines)
|
strheader = "\n".join(current_lines)
|
||||||
headerSize= headerSize + self._token_size(strheader)
|
headerSize= headerSize + self._token_size(strheader)
|
||||||
current_lines.clear()
|
current_lines.clear()
|
||||||
|
|
||||||
|
|
||||||
if len(current_lines) > 0:
|
if len(current_lines) > 0:
|
||||||
if len(markdown_tups) == 0:
|
if len(markdown_tups) == 0:
|
||||||
markdown_tups.append((strTitle + strheader , "\n".join(current_lines)))
|
markdown_tups.append((strTitle + strheader , "\n".join(current_lines)))
|
||||||
@@ -62,4 +68,22 @@ class ChunkMarkdownReader(MarkdownReader):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def _token_size(self, text: str) -> int:
|
def _token_size(self, text: str) -> int:
|
||||||
return len(self._tokenizer(text))
|
return len(self._tokenizer(text))
|
||||||
|
|
||||||
|
def findValue(self, expression: str, Field: str):
    """Return column ``Field`` of the first table row matching ``expression``.

    The header line kept in ``self._colheader`` supplies the column names and
    every line in ``self._rows`` is treated as one pipe-delimited table row.
    For each row the column names are exposed as variables and
    ``expression`` is evaluated against them, e.g. ``"名称=='工程名称'"``.

    Args:
        expression: Python expression over the column names; a truthy result
            selects the row.
        Field: Name of the column whose value is returned.

    Returns:
        The stripped cell text of ``Field`` in the first matching row, or
        ``''`` when no row matches or the row has no such column.
    """
    def split_cells(line):
        # Only the single outer pipe on each side is decoration. Empty inner
        # cells are kept so every value stays under its own column, and
        # surrounding whitespace / line endings are stripped from each cell.
        text = line.strip()
        if text.startswith('|'):
            text = text[1:]
        if text.endswith('|'):
            text = text[:-1]
        return [cell.strip() for cell in text.split('|')]

    cols = split_cells(self._colheader)

    for row in self._rows:
        values = split_cells(row)
        if not any(values):
            # Blank line or a row made only of empty cells.
            continue
        record = {name: value for name, value in zip(cols, values) if name}
        # SECURITY: eval() executes arbitrary Python. ``expression`` must come
        # from trusted code only, never from user input or document content.
        # A copy is passed because eval() inserts ``__builtins__`` into the
        # globals mapping it receives.
        if eval(expression, dict(record)):
            return record.get(Field, '')
    return ''
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -55,7 +55,6 @@ class JsonTable:
|
|||||||
def comment(self):
    """Return the value stored in ``self._comment``."""
    stored_comment = self._comment
    return stored_comment
|
||||||
|
|
||||||
|
|
||||||
class ProjectJson:
|
class ProjectJson:
|
||||||
def __init__(self,dir:str) -> None:
|
def __init__(self,dir:str) -> None:
|
||||||
self._dir = dir
|
self._dir = dir
|
||||||
@@ -76,14 +75,5 @@ class ProjectJson:
|
|||||||
def tables(self):
    """Return the object stored in ``self._tables`` (the same object, not a copy)."""
    parsed_tables = self._tables
    return parsed_tables
|
||||||
|
|
||||||
def getProjectName(dir:str):
    """Read the project name from the JSON project data stored under ``dir``.

    The directory is parsed with ``ProjectJson`` and the '工程属性' table is
    scanned for the first record whose '名称' equals '工程名称'.

    Returns:
        That record's '值' entry, or ``''`` when no record matches.
    """
    project = ProjectJson(dir)
    project.parse()
    properties:JsonTable = project.table('工程属性')
    matching_values = (
        row.value('值')
        for row in properties.records()
        if row.value('名称') == '工程名称'
    )
    return next(matching_values, '')
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,234 @@
|
|||||||
|
from llama_index.core.response_synthesizers.tree_summarize import TreeSummarize
|
||||||
|
from typing import Any, Optional, Sequence,List
|
||||||
|
import asyncio
|
||||||
|
from llama_index.core.callbacks.base import CallbackManager
|
||||||
|
from llama_index.core.indices.prompt_helper import PromptHelper
|
||||||
|
from llama_index.core.prompts import BasePromptTemplate
|
||||||
|
from llama_index.core.service_context import ServiceContext
|
||||||
|
from llama_index.core.service_context_elements.llm_predictor import LLMPredictorType
|
||||||
|
from llama_index.core.types import BaseModel,RESPONSE_TEXT_TYPE
|
||||||
|
from llama_index.core.async_utils import run_async_tasks
|
||||||
|
from llama_index.core.utils import get_tokenizer
|
||||||
|
from llama_index.core.prompts.prompt_utils import get_empty_prompt_txt
|
||||||
|
|
||||||
|
class CustomTreeResponse(TreeSummarize):
    """Tree-summarize synthesizer whose chunk grouping is done by ``repack``.

    Incoming text chunks are never split: ``repack`` joins whole chunks into
    groups that fit the LLM context left over after the summary prompt. Each
    group is summarized, and the summaries are summarized again recursively
    until a single group remains, whose answer is returned.
    """

    def __init__(
        self,
        llm: Optional[LLMPredictorType] = None,
        callback_manager: Optional[CallbackManager] = None,
        prompt_helper: Optional[PromptHelper] = None,
        summary_template: Optional[BasePromptTemplate] = None,
        output_cls: Optional[BaseModel] = None,
        streaming: bool = False,
        use_async: bool = False,
        verbose: bool = False,
        service_context: Optional[ServiceContext] = None,
    ) -> None:
        """Store a tokenizer for size accounting and forward all options to the base class."""
        self._tokenizer = get_tokenizer()
        # Keyword arguments keep the call correct even if the base class
        # orders its parameters differently from this signature.
        super().__init__(
            llm=llm,
            callback_manager=callback_manager,
            prompt_helper=prompt_helper,
            summary_template=summary_template,
            output_cls=output_cls,
            streaming=streaming,
            use_async=use_async,
            verbose=verbose,
            service_context=service_context,
        )

    async def aget_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Asynchronously build the tree-summarized answer for ``query_str``.

        Args:
            query_str: The question inserted into the summary template.
            text_chunks: Texts to summarize; they are regrouped by ``repack``.
            **response_kwargs: Extra keyword arguments forwarded to the LLM calls.

        Returns:
            The LLM response for the final single group: a stream when
            streaming is enabled, an ``output_cls`` object when one is
            configured, otherwise the predicted text.

        Raises:
            ValueError: Propagated from ``repack`` when a chunk cannot fit.
        """
        summary_template = self._summary_template.partial_format(query_str=query_str)
        text_chunks = self.repack(text_chunks=text_chunks)

        if self._verbose:
            print(f"{len(text_chunks)} text chunks after repacking")

        # Give the final response if there is only one chunk.
        if len(text_chunks) == 1:
            response: RESPONSE_TEXT_TYPE
            if self._streaming:
                response = await self._llm.astream(
                    summary_template, context_str=text_chunks[0], **response_kwargs
                )
            elif self._output_cls is None:
                response = await self._llm.apredict(
                    summary_template,
                    context_str=text_chunks[0],
                    **response_kwargs,
                )
            else:
                # Returns a pydantic object because output_cls is specified.
                response = await self._llm.astructured_predict(
                    self._output_cls,
                    summary_template,
                    context_str=text_chunks[0],
                    **response_kwargs,
                )
            return response

        # Summarize each chunk concurrently.
        if self._output_cls is None:
            tasks = [
                self._llm.apredict(
                    summary_template,
                    context_str=text_chunk,
                    **response_kwargs,
                )
                for text_chunk in text_chunks
            ]
        else:
            tasks = [
                self._llm.astructured_predict(
                    self._output_cls,
                    summary_template,
                    context_str=text_chunk,
                    **response_kwargs,
                )
                for text_chunk in text_chunks
            ]

        summary_responses = await asyncio.gather(*tasks)
        if self._output_cls is not None:
            # Structured results are serialized so they can be summarized as text.
            summaries = [summary.json() for summary in summary_responses]
        else:
            summaries = summary_responses

        # Recursively summarize the summaries.
        return await self.aget_response(
            query_str=query_str,
            text_chunks=summaries,
            **response_kwargs,
        )

    def get_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Synchronously build the tree-summarized answer for ``query_str``.

        When ``use_async`` was enabled, the per-group summaries are issued as
        async LLM calls and executed through ``run_async_tasks``; otherwise
        they are requested one after another.

        Args:
            query_str: The question inserted into the summary template.
            text_chunks: Texts to summarize; they are regrouped by ``repack``.
            **response_kwargs: Extra keyword arguments forwarded to the LLM calls.

        Returns:
            The LLM response for the final single group: a stream when
            streaming is enabled, an ``output_cls`` object when one is
            configured, otherwise the predicted text.

        Raises:
            ValueError: Propagated from ``repack`` when a chunk cannot fit.
        """
        summary_template = self._summary_template.partial_format(query_str=query_str)
        text_chunks = self.repack(text_chunks=text_chunks)

        if self._verbose:
            print(f"{len(text_chunks)} text chunks after repacking")

        # Give the final response if there is only one chunk.
        if len(text_chunks) == 1:
            response: RESPONSE_TEXT_TYPE
            if self._streaming:
                response = self._llm.stream(
                    summary_template, context_str=text_chunks[0], **response_kwargs
                )
            elif self._output_cls is None:
                response = self._llm.predict(
                    summary_template,
                    context_str=text_chunks[0],
                    **response_kwargs,
                )
            else:
                response = self._llm.structured_predict(
                    self._output_cls,
                    summary_template,
                    context_str=text_chunks[0],
                    **response_kwargs,
                )
            return response

        # Summarize each chunk.
        if self._use_async:
            if self._output_cls is None:
                tasks = [
                    self._llm.apredict(
                        summary_template,
                        context_str=text_chunk,
                        **response_kwargs,
                    )
                    for text_chunk in text_chunks
                ]
            else:
                tasks = [
                    self._llm.astructured_predict(
                        self._output_cls,
                        summary_template,
                        context_str=text_chunk,
                        **response_kwargs,
                    )
                    for text_chunk in text_chunks
                ]

            summary_responses = run_async_tasks(tasks)

            if self._output_cls is not None:
                summaries = [summary.json() for summary in summary_responses]
            else:
                summaries = summary_responses
        elif self._output_cls is None:
            summaries = [
                self._llm.predict(
                    summary_template,
                    context_str=text_chunk,
                    **response_kwargs,
                )
                for text_chunk in text_chunks
            ]
        else:
            summaries = [
                self._llm.structured_predict(
                    self._output_cls,
                    summary_template,
                    context_str=text_chunk,
                    **response_kwargs,
                )
                for text_chunk in text_chunks
            ]
            summaries = [summary.json() for summary in summaries]

        # Recursively summarize the summaries.
        return self.get_response(
            query_str=query_str, text_chunks=summaries, **response_kwargs
        )

    def repack(self, text_chunks: Sequence[str],) -> List[str]:
        """Greedily join whole chunks into groups that fit the LLM context.

        Chunks are taken in order and appended to the current group until
        adding the next one would push the group's token count past the
        available context size; the group is then closed and the chunk
        starts a new one.

        Args:
            text_chunks: Texts to group; none of them is ever split.

        Returns:
            The merged text of every group, in input order.

        Raises:
            ValueError: If a single chunk is larger than the available
                context, or the prompt alone leaves no room at all.
        """
        prompt_str = get_empty_prompt_txt(self._summary_template)
        num_prompt_tokens = self._token_size(prompt_str)
        available_size = self._get_available_context_size(num_prompt_tokens)

        results = []
        group = []
        group_size = 0
        for text_chunk in text_chunks:
            chunk_size = self._token_size(text_chunk)
            if chunk_size > available_size:
                raise ValueError("文本块大小大于可用上下文大小")
            if group and group_size + chunk_size > available_size:
                results.append(self._merge_chunks(group))
                group = []
                group_size = 0
            group.append(text_chunk)
            # The chunk that opens a new group must be counted too; otherwise
            # the following group can silently outgrow the context window.
            group_size += chunk_size
            # NOTE: the "\n\n" separators added by _merge_chunks are not counted.

        if group:
            results.append(self._merge_chunks(group))
        return results

    def _get_available_context_size(self, num_prompt_tokens: int) -> int:
        """Return the tokens left for context after the prompt and reserved output.

        Computed as ``context_window - num_prompt_tokens - num_output`` from
        the LLM metadata.

        Raises:
            ValueError: If the result is negative.
        """
        llm_metadata = self._llm.metadata
        context_size_tokens = (
            llm_metadata.context_window - num_prompt_tokens - llm_metadata.num_output
        )
        if context_size_tokens < 0:
            raise ValueError(
                f"Calculated available context size {context_size_tokens} was"
                " not non-negative."
            )
        return context_size_tokens

    def _token_size(self, text: str) -> int:
        """Return the number of tokens ``self._tokenizer`` yields for ``text``."""
        return len(self._tokenizer(text))

    def _merge_chunks(self, ava_chunks: list):
        """Join the stripped, non-blank chunks with a blank line between them."""
        return "\n\n".join([c.strip() for c in ava_chunks if c.strip()])
|
||||||
Reference in New Issue
Block a user