"""Tree-summarize response synthesizer with a whole-chunk, token-budgeted repacker."""

import asyncio
from typing import Any, List, Optional, Sequence

from llama_index.core.async_utils import run_async_tasks
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.indices.prompt_helper import PromptHelper
from llama_index.core.prompts import BasePromptTemplate
from llama_index.core.prompts.prompt_utils import get_empty_prompt_txt
from llama_index.core.response_synthesizers.tree_summarize import TreeSummarize
from llama_index.core.service_context import ServiceContext
from llama_index.core.service_context_elements.llm_predictor import LLMPredictorType
from llama_index.core.types import RESPONSE_TEXT_TYPE, BaseModel
from llama_index.core.utils import get_tokenizer


class CustomTreeResponse(TreeSummarize):
    """TreeSummarize variant whose ``repack`` groups input chunks whole.

    ``repack`` greedily concatenates consecutive text chunks (joined by a
    blank line) as long as the group fits in the LLM's remaining context
    window; chunks are never split. Each group is summarized, and the
    summaries are summarized again recursively until a single group remains.
    """

    # Separator inserted between chunks merged into one LLM call.
    _CHUNK_SEPARATOR = "\n\n"

    def __init__(
        self,
        llm: Optional[LLMPredictorType] = None,
        callback_manager: Optional[CallbackManager] = None,
        prompt_helper: Optional[PromptHelper] = None,
        summary_template: Optional[BasePromptTemplate] = None,
        output_cls: Optional[BaseModel] = None,
        streaming: bool = False,
        use_async: bool = False,
        verbose: bool = False,
        service_context: Optional[ServiceContext] = None,
    ) -> None:
        # Tokenizer used by _token_size() to measure prompts and chunks.
        self._tokenizer = get_tokenizer()
        # Keyword arguments so a reordering of the base signature cannot
        # silently bind values to the wrong parameters.
        super().__init__(
            llm=llm,
            callback_manager=callback_manager,
            prompt_helper=prompt_helper,
            summary_template=summary_template,
            output_cls=output_cls,
            streaming=streaming,
            use_async=use_async,
            verbose=verbose,
            service_context=service_context,
        )

    async def aget_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Get tree summarize response (async).

        Repacks ``text_chunks``; if one packed chunk remains, it is answered
        directly (a stream when streaming, an ``output_cls`` instance when
        ``output_cls`` is set, otherwise text). Otherwise every packed chunk
        is summarized concurrently and the summaries are fed back in
        recursively.
        """
        summary_template = self._summary_template.partial_format(query_str=query_str)
        # Budget against the query-filled template so the query's tokens count.
        text_chunks = self.repack(text_chunks=text_chunks, prompt=summary_template)

        if self._verbose:
            print(f"{len(text_chunks)} text chunks after repacking")

        # give final response if there is only one chunk
        if len(text_chunks) == 1:
            response: RESPONSE_TEXT_TYPE
            if self._streaming:
                response = await self._llm.astream(
                    summary_template, context_str=text_chunks[0], **response_kwargs
                )
            elif self._output_cls is None:
                response = await self._llm.apredict(
                    summary_template,
                    context_str=text_chunks[0],
                    **response_kwargs,
                )
            else:
                # return pydantic object if output_cls is specified
                response = await self._llm.astructured_predict(
                    self._output_cls,
                    summary_template,
                    context_str=text_chunks[0],
                    **response_kwargs,
                )
            return response

        # summarize each chunk concurrently
        tasks = self._summary_tasks(summary_template, text_chunks, response_kwargs)
        summary_responses = await asyncio.gather(*tasks)
        summaries = self._summaries_to_text(summary_responses)

        # recursively summarize the summaries
        return await self.aget_response(
            query_str=query_str,
            text_chunks=summaries,
            **response_kwargs,
        )

    def get_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Get tree summarize response (sync).

        Same algorithm as :meth:`aget_response`; intermediate summaries are
        produced with ``run_async_tasks`` when ``use_async`` is set, and with
        sequential blocking LLM calls otherwise.
        """
        summary_template = self._summary_template.partial_format(query_str=query_str)
        # Budget against the query-filled template so the query's tokens count.
        text_chunks = self.repack(text_chunks=text_chunks, prompt=summary_template)

        if self._verbose:
            print(f"{len(text_chunks)} text chunks after repacking")

        # give final response if there is only one chunk
        if len(text_chunks) == 1:
            response: RESPONSE_TEXT_TYPE
            if self._streaming:
                response = self._llm.stream(
                    summary_template, context_str=text_chunks[0], **response_kwargs
                )
            elif self._output_cls is None:
                response = self._llm.predict(
                    summary_template,
                    context_str=text_chunks[0],
                    **response_kwargs,
                )
            else:
                response = self._llm.structured_predict(
                    self._output_cls,
                    summary_template,
                    context_str=text_chunks[0],
                    **response_kwargs,
                )
            return response

        # summarize each chunk
        if self._use_async:
            tasks = self._summary_tasks(summary_template, text_chunks, response_kwargs)
            summary_responses = run_async_tasks(tasks)
        elif self._output_cls is None:
            summary_responses = [
                self._llm.predict(
                    summary_template,
                    context_str=text_chunk,
                    **response_kwargs,
                )
                for text_chunk in text_chunks
            ]
        else:
            summary_responses = [
                self._llm.structured_predict(
                    self._output_cls,
                    summary_template,
                    context_str=text_chunk,
                    **response_kwargs,
                )
                for text_chunk in text_chunks
            ]
        summaries = self._summaries_to_text(summary_responses)

        # recursively summarize the summaries
        return self.get_response(
            query_str=query_str, text_chunks=summaries, **response_kwargs
        )

    def _summary_tasks(
        self,
        summary_template: BasePromptTemplate,
        text_chunks: Sequence[str],
        response_kwargs: dict,
    ) -> List[Any]:
        """Build one un-awaited summarization coroutine per text chunk."""
        if self._output_cls is None:
            return [
                self._llm.apredict(
                    summary_template,
                    context_str=text_chunk,
                    **response_kwargs,
                )
                for text_chunk in text_chunks
            ]
        return [
            self._llm.astructured_predict(
                self._output_cls,
                summary_template,
                context_str=text_chunk,
                **response_kwargs,
            )
            for text_chunk in text_chunks
        ]

    def _summaries_to_text(self, summary_responses: Sequence[Any]) -> List[str]:
        """Turn per-chunk LLM results into text for the next summarization round.

        Structured results (when ``output_cls`` is set) are serialized with
        ``.json()``; plain text results are passed through.
        """
        if self._output_cls is not None:
            return [summary.json() for summary in summary_responses]
        return list(summary_responses)

    def repack(
        self,
        text_chunks: Sequence[str],
        prompt: Optional[BasePromptTemplate] = None,
    ) -> List[str]:
        """Greedily group whole text chunks so each group fits the context window.

        Args:
            text_chunks: Chunks to pack, in order. Chunks are never split.
            prompt: Template whose not-yet-filled text is counted against the
                context window. Defaults to ``self._summary_template``; pass
                the query-formatted template so the query's tokens are
                budgeted as well.

        Returns:
            One string per group: the group's stripped, non-empty chunks
            joined by ``_CHUNK_SEPARATOR``. Returns ``[""]`` for empty input
            so callers always receive at least one chunk.

        Raises:
            ValueError: If a single chunk alone exceeds the available size.
        """
        template = self._summary_template if prompt is None else prompt
        num_prompt_tokens = self._token_size(get_empty_prompt_txt(template))
        available_size = self._get_available_context_size(num_prompt_tokens)
        separator_size = self._token_size(self._CHUNK_SEPARATOR)

        results: List[str] = []
        batch: List[str] = []
        batch_size = 0
        for text_chunk in text_chunks:
            chunk_size = self._token_size(text_chunk)
            if chunk_size > available_size:
                # Message: "text chunk size exceeds the available context size".
                raise ValueError("文本块大小大于可用上下文大小")
            # Appending to a non-empty batch also costs one separator. Summing
            # per-piece token counts approximates the merged string's count.
            added_size = chunk_size + (separator_size if batch else 0)
            if batch and batch_size + added_size > available_size:
                results.append(self._merge_chunks(batch))
                batch = []
                # The current chunk opens the new batch, so its tokens must be
                # counted there (resetting to 0 would undercount every batch).
                added_size = chunk_size
                batch_size = 0
            batch.append(text_chunk)
            batch_size += added_size

        if batch:
            results.append(self._merge_chunks(batch))
        # An empty list would send get_response/aget_response into the
        # multi-chunk branch with nothing to summarize, recursing forever.
        return results or [""]

    def _get_available_context_size(self, num_prompt_tokens: int) -> int:
        """Return tokens left for context after the prompt and reserved output.

        Computed as ``context_window - num_prompt_tokens - num_output`` from
        the LLM's metadata.

        Raises:
            ValueError: If the result is negative.
        """
        llm_metadata = self._llm.metadata
        context_size_tokens = (
            llm_metadata.context_window - num_prompt_tokens - llm_metadata.num_output
        )
        if context_size_tokens < 0:
            raise ValueError(
                f"Calculated available context size {context_size_tokens} was"
                " not non-negative."
            )
        return context_size_tokens

    def _token_size(self, text: str) -> int:
        """Return the number of tokens ``self._tokenizer`` produces for ``text``."""
        return len(self._tokenizer(text))

    def _merge_chunks(self, ava_chunks: list) -> str:
        """Strip each chunk, drop empty ones, and join the rest with the separator."""
        stripped = (chunk.strip() for chunk in ava_chunks)
        return self._CHUNK_SEPARATOR.join(piece for piece in stripped if piece)