# File listing metadata: 234 lines, 9.0 KiB, Python.
import asyncio
from typing import Any, Callable, Dict, List, Optional, Sequence

from llama_index.core.async_utils import run_async_tasks
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.indices.prompt_helper import PromptHelper
from llama_index.core.prompts import BasePromptTemplate
from llama_index.core.prompts.prompt_utils import get_empty_prompt_txt
from llama_index.core.response_synthesizers.tree_summarize import TreeSummarize
from llama_index.core.service_context import ServiceContext
from llama_index.core.service_context_elements.llm_predictor import LLMPredictorType
from llama_index.core.types import BaseModel, RESPONSE_TEXT_TYPE
from llama_index.core.utils import get_tokenizer


class CustomTreeResponse(TreeSummarize):
    """Tree-summarize synthesizer with its own chunk-packing step.

    Each round does the following:

    1. Formats the summary template with the query.
    2. Greedily packs the input chunks (whole chunks only; see
       :meth:`repack`) so every packed group fits the LLM's context window.
    3. Summarizes each group.
    4. Recurses on the summaries until a single group remains.

    That last group produces the final response.
    """

    def __init__(
        self,
        llm: Optional[LLMPredictorType] = None,
        callback_manager: Optional[CallbackManager] = None,
        prompt_helper: Optional[PromptHelper] = None,
        summary_template: Optional[BasePromptTemplate] = None,
        output_cls: Optional[BaseModel] = None,
        streaming: bool = False,
        use_async: bool = False,
        verbose: bool = False,
        service_context: Optional[ServiceContext] = None,
        tokenizer: Optional[Callable[[str], List]] = None,
    ) -> None:
        """Create the synthesizer.

        All arguments except ``tokenizer`` are forwarded unchanged to
        ``TreeSummarize.__init__``.

        Args:
            tokenizer: Callable returning the token sequence for a string.
                Only the sequence's length is used, to budget the context
                window. Defaults to ``get_tokenizer()``. Pass the target
                model's own tokenizer when it differs, so packing matches
                the real limit.
        """
        self._tokenizer = tokenizer if tokenizer is not None else get_tokenizer()
        # Forward by keyword so arguments cannot silently shift if the base
        # class's parameter order changes.
        super().__init__(
            llm=llm,
            callback_manager=callback_manager,
            prompt_helper=prompt_helper,
            summary_template=summary_template,
            output_cls=output_cls,
            streaming=streaming,
            use_async=use_async,
            verbose=verbose,
            service_context=service_context,
        )

    async def aget_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Get tree summarize response.

        Per-chunk summaries within a round always run concurrently via
        ``asyncio.gather``.

        Args:
            query_str: Query substituted into the summary template.
            text_chunks: Context chunks to summarize.
            **response_kwargs: Extra keyword arguments passed to every LLM call.

        Returns:
            The final response. This is a stream when ``streaming`` is
            enabled, an ``output_cls`` instance when ``output_cls`` is set,
            and text otherwise.
        """
        summary_template = self._summary_template.partial_format(query_str=query_str)
        text_chunks = self._repack_for_query(summary_template, text_chunks)

        # Give the final response if there is only one chunk.
        if len(text_chunks) == 1:
            if self._streaming:
                return await self._llm.astream(
                    summary_template, context_str=text_chunks[0], **response_kwargs
                )
            # Returns a pydantic object if output_cls is specified.
            return await self._asummarize_chunk(
                summary_template, text_chunks[0], response_kwargs
            )

        # Summarize each chunk concurrently.
        summary_responses = await asyncio.gather(
            *(
                self._asummarize_chunk(summary_template, text_chunk, response_kwargs)
                for text_chunk in text_chunks
            )
        )

        # Recursively summarize the summaries.
        return await self.aget_response(
            query_str=query_str,
            text_chunks=self._responses_to_texts(summary_responses),
            **response_kwargs,
        )

    def get_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Get tree summarize response.

        Per-chunk summaries within a round run concurrently through
        ``run_async_tasks`` when ``use_async`` is set, and sequentially
        otherwise.

        Args:
            query_str: Query substituted into the summary template.
            text_chunks: Context chunks to summarize.
            **response_kwargs: Extra keyword arguments passed to every LLM call.

        Returns:
            The final response. This is a stream when ``streaming`` is
            enabled, an ``output_cls`` instance when ``output_cls`` is set,
            and text otherwise.
        """
        summary_template = self._summary_template.partial_format(query_str=query_str)
        text_chunks = self._repack_for_query(summary_template, text_chunks)

        # Give the final response if there is only one chunk.
        if len(text_chunks) == 1:
            if self._streaming:
                return self._llm.stream(
                    summary_template, context_str=text_chunks[0], **response_kwargs
                )
            # Returns a pydantic object if output_cls is specified.
            return self._summarize_chunk(
                summary_template, text_chunks[0], response_kwargs
            )

        # Summarize each chunk.
        if self._use_async:
            tasks = [
                self._asummarize_chunk(summary_template, text_chunk, response_kwargs)
                for text_chunk in text_chunks
            ]
            summary_responses = run_async_tasks(tasks)
        else:
            summary_responses = [
                self._summarize_chunk(summary_template, text_chunk, response_kwargs)
                for text_chunk in text_chunks
            ]

        # Recursively summarize the summaries.
        return self.get_response(
            query_str=query_str,
            text_chunks=self._responses_to_texts(summary_responses),
            **response_kwargs,
        )

    def _repack_for_query(
        self, summary_template: BasePromptTemplate, text_chunks: Sequence[str]
    ) -> List[str]:
        """Repack chunks for one round against the query-formatted template.

        Falls back to a single empty chunk when repacking yields nothing (for
        example, empty input). Without this, the recursive step would receive
        an empty list every round and never terminate.
        """
        packed = self.repack(text_chunks=text_chunks, prompt=summary_template) or [""]
        if self._verbose:
            print(f"{len(packed)} text chunks after repacking")
        return packed

    def _summarize_chunk(
        self,
        summary_template: BasePromptTemplate,
        text_chunk: str,
        response_kwargs: Dict[str, Any],
    ) -> Any:
        """Run one blocking summary call.

        Uses ``structured_predict`` when ``output_cls`` is set and
        ``predict`` otherwise.
        """
        if self._output_cls is None:
            return self._llm.predict(
                summary_template, context_str=text_chunk, **response_kwargs
            )
        return self._llm.structured_predict(
            self._output_cls, summary_template, context_str=text_chunk, **response_kwargs
        )

    async def _asummarize_chunk(
        self,
        summary_template: BasePromptTemplate,
        text_chunk: str,
        response_kwargs: Dict[str, Any],
    ) -> Any:
        """Async counterpart of :meth:`_summarize_chunk`."""
        if self._output_cls is None:
            return await self._llm.apredict(
                summary_template, context_str=text_chunk, **response_kwargs
            )
        return await self._llm.astructured_predict(
            self._output_cls, summary_template, context_str=text_chunk, **response_kwargs
        )

    def _responses_to_texts(self, responses: Sequence[Any]) -> List[str]:
        """Turn per-chunk responses into text chunks for the next round.

        Structured (``output_cls``) responses are serialized with
        ``.json()``. Plain text responses pass through unchanged.
        """
        if self._output_cls is None:
            return list(responses)
        return [response.json() for response in responses]

    def repack(
        self,
        text_chunks: Sequence[str],
        prompt: Optional[BasePromptTemplate] = None,
    ) -> List[str]:
        """Greedily merge consecutive chunks so each group fits the context.

        Chunks are never split. Consecutive chunks are concatenated (through
        :meth:`_merge_chunks`) while the group's token count, including the
        separators between chunks, stays within the available context size.

        Args:
            text_chunks: Chunks to pack, in order.
            prompt: Template whose fixed text is charged against the context
                window. Defaults to ``self._summary_template``. Pass the
                query-formatted template so the query's tokens are counted.

        Returns:
            The merged groups, in input order. Empty input yields ``[]``.

        Raises:
            ValueError: If a single chunk alone exceeds the available size.
        """
        template = self._summary_template if prompt is None else prompt
        prompt_str = get_empty_prompt_txt(template)
        available_size = self._get_available_context_size(self._token_size(prompt_str))
        # _merge_chunks joins chunks with "\n\n"; budget for it too.
        separator_size = self._token_size("\n\n")

        packed: List[str] = []
        group: List[str] = []
        group_size = 0
        for text_chunk in text_chunks:
            chunk_size = self._token_size(text_chunk)
            if chunk_size > available_size:
                # Message: "text chunk size is larger than the available context size".
                raise ValueError("文本块大小大于可用上下文大小")
            if group and group_size + separator_size + chunk_size > available_size:
                # Flush the full group. The chunk that did not fit opens the
                # next group, so its size is counted toward that group below.
                packed.append(self._merge_chunks(group))
                group = []
                group_size = 0
            if group:
                group_size += separator_size
            group_size += chunk_size
            group.append(text_chunk)
        if group:
            packed.append(self._merge_chunks(group))
        return packed

    def _get_available_context_size(self, num_prompt_tokens: int) -> int:
        """Return the tokens left for context after the prompt and output.

        Computed from the LLM metadata as
        ``context_window - num_prompt_tokens - num_output``.

        Raises:
            ValueError: If the result is negative.
        """
        llm_metadata = self._llm.metadata
        context_size_tokens = (
            llm_metadata.context_window - num_prompt_tokens - llm_metadata.num_output
        )
        if context_size_tokens < 0:
            raise ValueError(
                f"Calculated available context size {context_size_tokens} was"
                " not non-negative."
            )
        return context_size_tokens

    def _token_size(self, text: str) -> int:
        """Return the number of tokens the configured tokenizer produces for *text*."""
        return len(self._tokenizer(text))

    def _merge_chunks(self, ava_chunks: list) -> str:
        """Join chunks with blank lines, stripping each and dropping empty ones."""
        stripped = (chunk.strip() for chunk in ava_chunks)
        return "\n\n".join(chunk for chunk in stripped if chunk)