def get_synthesizer():
    """Build the response synthesizer used by the engine.

    Returns a ``CustomTreeResponse`` wired to the globally configured LLM
    (``Settings.llm``) and the project's ``summary_template``. Per-chunk
    summaries are requested asynchronously and streaming is disabled.
    """
    synthesizer_options = {
        "llm": Settings.llm,
        "summary_template": summary_template,
        "use_async": True,
        "streaming": False,
    }
    return CustomTreeResponse(**synthesizer_options)
def llama_local_extractor() -> "Dict[str, BaseReader]":
    """Return the readers used for file types that are parsed locally.

    Returns:
        A mapping from file extension to reader instance: ``.json`` files go
        through ``JSONReader`` and ``.md`` files through
        ``ChunkMarkdownReader``.
    """
    return {
        ".json": JSONReader(clean_json=False, levels_back=0),
        ".md": ChunkMarkdownReader(),
    }


def prjFileSuffix(dir: str):
    """Return the extension of the first regular file directly inside *dir*.

    ``os.listdir`` returns entries in arbitrary order, so the names are sorted
    first; the result is then the same on every run and every filesystem.

    Args:
        dir: Directory to inspect (sub-directories are ignored).

    Returns:
        The extension including the leading dot (e.g. ``'.json'``), or ``''``
        when the directory contains no regular files or the first file has no
        extension.

    Raises:
        OSError: If *dir* cannot be listed (propagated from ``os.listdir``).
    """
    file_names = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isfile(os.path.join(dir, entry))
    )
    if not file_names:
        return ''
    return os.path.splitext(file_names[0])[1]


def getProjectName(dir: str):
    """Look up the project name stored in the project data under *dir*.

    The storage format is chosen from the extension reported by
    ``prjFileSuffix``:

    * ``.json`` - the directory is parsed with ``ProjectJson`` and the
      '工程属性' table is searched for the record whose '名称' is '工程名称'.
    * ``.md``   - the file '工程属性.md' is loaded with
      ``ChunkMarkdownReader`` and queried for the same row.

    Args:
        dir: Directory holding the exported project files.

    Returns:
        The value of the '值' field of the matching record, or ``''`` when the
        format is not recognised or no matching record exists.
    """
    suffix = prjFileSuffix(dir)
    if suffix == '.json':
        prjJson = ProjectJson(dir)
        prjJson.parse()
        tb = prjJson.table('工程属性')
        # NOTE(review): assumes table() yields None for an unknown table -
        # confirm against ProjectJson; the guard is harmless otherwise.
        if tb is None:
            return ''
        for record in tb.records():
            if record.value('名称') == '工程名称':
                return record.value('值')
    elif suffix == '.md':
        # Only the file whose base name is exactly '工程属性' is relevant, so
        # address it directly instead of scanning every *.md entry.
        prjPath = os.path.join(dir, '工程属性.md')
        if os.path.isfile(prjPath):
            rd = ChunkMarkdownReader()
            rd.load_data(prjPath)
            return rd.findValue("名称=='工程名称'", '值')
    return ''
def findValue(self, expression: str, Field: str):
    """Return the *Field* cell of the first table row matching *expression*.

    The column names come from the header line captured in
    ``self._colheader`` and the candidate rows from ``self._rows``. For every
    row a ``{column name: cell text}`` mapping is built and *expression* is
    evaluated with those names in scope, e.g. ``"名称=='工程名称'"``.

    Args:
        expression: Python expression referring to column names; all cell
            values are strings.
        Field: Name of the column whose value is returned for the first
            matching row.

    Returns:
        The stripped cell text of *Field* in the first matching row, or
        ``''`` when no row matches or the row has no such column.

    Raises:
        NameError: If *expression* references a column the row does not have.
    """

    def split_cells(line):
        # Drop the outer pipes and surrounding whitespace/newlines, but keep
        # empty inner cells so values stay aligned with their columns.
        text = line.strip()
        if text.startswith('|'):
            text = text[1:]
        if text.endswith('|'):
            text = text[:-1]
        if not text:
            return []
        return [cell.strip() for cell in text.split('|')]

    cols = split_cells(self._colheader)

    for row in self._rows:
        cells = split_cells(row)
        if not cells:
            continue
        row_data = dict(zip(cols, cells))
        # SECURITY: eval() executes arbitrary Python. `expression` must only
        # ever come from trusted, hard-coded callers - never from user or
        # document content. A copy is passed because eval() inserts
        # `__builtins__` into the globals mapping it receives.
        if eval(expression, dict(row_data)):
            return row_data.get(Field, '')
    return ''
class CustomTreeResponse(TreeSummarize):
    """Tree-summarize synthesizer that packs whole chunks per LLM call.

    Instead of re-splitting the concatenated text, :meth:`repack` groups the
    incoming chunks (kept intact) into batches whose token count fits the
    context left over after the prompt and the reserved output. Each batch is
    summarized, and the summaries are summarized again recursively until a
    single response remains.
    """

    def __init__(
        self,
        llm: Optional[LLMPredictorType] = None,
        callback_manager: Optional[CallbackManager] = None,
        prompt_helper: Optional[PromptHelper] = None,
        summary_template: Optional[BasePromptTemplate] = None,
        output_cls: Optional[BaseModel] = None,
        streaming: bool = False,
        use_async: bool = False,
        verbose: bool = False,
        service_context: Optional[ServiceContext] = None,
    ) -> None:
        """Create the synthesizer; all arguments are forwarded to ``TreeSummarize``."""
        # Tokenizer used by _token_size() to measure prompts and chunks.
        self._tokenizer = get_tokenizer()
        # Keyword arguments keep the call correct even if the parent's
        # positional parameter order differs.
        super().__init__(
            llm=llm,
            callback_manager=callback_manager,
            prompt_helper=prompt_helper,
            summary_template=summary_template,
            output_cls=output_cls,
            streaming=streaming,
            use_async=use_async,
            verbose=verbose,
            service_context=service_context,
        )

    async def aget_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Asynchronously get the tree-summarize response.

        Args:
            query_str: The user query, bound into the summary template.
            text_chunks: Context chunks to summarize.
            **response_kwargs: Extra arguments forwarded to the LLM calls.

        Returns:
            The final LLM response (text, token stream, or an ``output_cls``
            instance depending on configuration).

        Raises:
            ValueError: If a single chunk does not fit the available context.
        """
        summary_template = self._summary_template.partial_format(query_str=query_str)

        text_chunks = self.repack(text_chunks=text_chunks)

        if self._verbose:
            print(f"{len(text_chunks)} text chunks after repacking")

        # Give the final response once at most one chunk is left. An empty
        # list is answered with an empty context; recursing on it would
        # never reach the base case.
        if len(text_chunks) <= 1:
            context_str = text_chunks[0] if text_chunks else ""
            response: RESPONSE_TEXT_TYPE
            if self._streaming:
                response = await self._llm.astream(
                    summary_template, context_str=context_str, **response_kwargs
                )
            elif self._output_cls is None:
                response = await self._llm.apredict(
                    summary_template,
                    context_str=context_str,
                    **response_kwargs,
                )
            else:
                # Returns a pydantic object when output_cls is specified.
                response = await self._llm.astructured_predict(
                    self._output_cls,
                    summary_template,
                    context_str=context_str,
                    **response_kwargs,
                )
            return response

        # Summarize each chunk concurrently.
        if self._output_cls is None:
            tasks = [
                self._llm.apredict(
                    summary_template,
                    context_str=text_chunk,
                    **response_kwargs,
                )
                for text_chunk in text_chunks
            ]
        else:
            tasks = [
                self._llm.astructured_predict(
                    self._output_cls,
                    summary_template,
                    context_str=text_chunk,
                    **response_kwargs,
                )
                for text_chunk in text_chunks
            ]

        summary_responses = await asyncio.gather(*tasks)
        if self._output_cls is not None:
            summaries = [summary.json() for summary in summary_responses]
        else:
            summaries = summary_responses

        # Recursively summarize the summaries.
        return await self.aget_response(
            query_str=query_str,
            text_chunks=summaries,
            **response_kwargs,
        )

    def get_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Get the tree-summarize response.

        Args:
            query_str: The user query, bound into the summary template.
            text_chunks: Context chunks to summarize.
            **response_kwargs: Extra arguments forwarded to the LLM calls.

        Returns:
            The final LLM response (text, token stream, or an ``output_cls``
            instance depending on configuration).

        Raises:
            ValueError: If a single chunk does not fit the available context.
        """
        summary_template = self._summary_template.partial_format(query_str=query_str)
        text_chunks = self.repack(text_chunks=text_chunks)

        if self._verbose:
            print(f"{len(text_chunks)} text chunks after repacking")

        # Give the final response once at most one chunk is left. An empty
        # list is answered with an empty context; recursing on it would
        # never reach the base case.
        if len(text_chunks) <= 1:
            context_str = text_chunks[0] if text_chunks else ""
            response: RESPONSE_TEXT_TYPE
            if self._streaming:
                response = self._llm.stream(
                    summary_template, context_str=context_str, **response_kwargs
                )
            elif self._output_cls is None:
                response = self._llm.predict(
                    summary_template,
                    context_str=context_str,
                    **response_kwargs,
                )
            else:
                response = self._llm.structured_predict(
                    self._output_cls,
                    summary_template,
                    context_str=context_str,
                    **response_kwargs,
                )
            return response

        # Summarize each chunk.
        if self._use_async:
            if self._output_cls is None:
                tasks = [
                    self._llm.apredict(
                        summary_template,
                        context_str=text_chunk,
                        **response_kwargs,
                    )
                    for text_chunk in text_chunks
                ]
            else:
                tasks = [
                    self._llm.astructured_predict(
                        self._output_cls,
                        summary_template,
                        context_str=text_chunk,
                        **response_kwargs,
                    )
                    for text_chunk in text_chunks
                ]

            summary_responses = run_async_tasks(tasks)

            if self._output_cls is not None:
                summaries = [summary.json() for summary in summary_responses]
            else:
                summaries = summary_responses
        elif self._output_cls is None:
            summaries = [
                self._llm.predict(
                    summary_template,
                    context_str=text_chunk,
                    **response_kwargs,
                )
                for text_chunk in text_chunks
            ]
        else:
            structured = [
                self._llm.structured_predict(
                    self._output_cls,
                    summary_template,
                    context_str=text_chunk,
                    **response_kwargs,
                )
                for text_chunk in text_chunks
            ]
            summaries = [summary.json() for summary in structured]

        # Recursively summarize the summaries.
        return self.get_response(
            query_str=query_str, text_chunks=summaries, **response_kwargs
        )

    def repack(self, text_chunks: Sequence[str],) -> List[str]:
        """Greedily merge consecutive chunks into context-sized batches.

        Chunks are never split: each one is appended to the current batch
        until adding the next would exceed the available context size, at
        which point the batch is merged and a new one is started.

        Args:
            text_chunks: Chunks to pack, in order.

        Returns:
            The merged batches, in order; empty when *text_chunks* is empty.

        Raises:
            ValueError: If any single chunk is larger than the available
                context size.
        """
        prompt_str = get_empty_prompt_txt(self._summary_template)
        num_prompt_tokens = self._token_size(prompt_str)
        available_size = self._get_available_context_size(num_prompt_tokens)

        results = []
        batch = []
        batch_size = 0
        for text_chunk in text_chunks:
            chunk_size = self._token_size(text_chunk)
            if chunk_size > available_size:
                raise ValueError("文本块大小大于可用上下文大小")
            if batch and batch_size + chunk_size > available_size:
                results.append(self._merge_chunks(batch))
                batch = []
                batch_size = 0
            # The chunk that opens a new batch must count toward its size,
            # otherwise the batch could overflow the context window.
            batch.append(text_chunk)
            batch_size += chunk_size
        if batch:
            results.append(self._merge_chunks(batch))
        return results

    def _get_available_context_size(self, num_prompt_tokens: int) -> int:
        """Return the tokens left for context after prompt and reserved output.

        Raises:
            ValueError: If the prompt plus reserved output exceed the LLM's
                context window.
        """
        llm_metadata = self._llm.metadata
        context_size_tokens = (
            llm_metadata.context_window - num_prompt_tokens - llm_metadata.num_output
        )
        if context_size_tokens < 0:
            raise ValueError(
                f"Calculated available context size {context_size_tokens} was"
                " not non-negative."
            )
        return context_size_tokens

    def _token_size(self, text: str) -> int:
        """Return the number of tokens the configured tokenizer yields for *text*."""
        return len(self._tokenizer(text))

    def _merge_chunks(self, ava_chunks: list):
        """Join the non-blank chunks, stripped, separated by blank lines."""
        return "\n\n".join([c.strip() for c in ava_chunks if c.strip()])