Merge branch 'dev-web' of https://git.97id.com/ly/zjdataai-app into dev-web
This commit is contained in:
@@ -7,10 +7,26 @@ from llama_index.core.query_engine import RetrieverQueryEngine
|
||||
from llama_index.core.response_synthesizers import ResponseMode
|
||||
from llama_index.readers.database import DatabaseReader
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
from util.register import *
|
||||
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
|
||||
from app.engine.retriever.HybridRetriever import HybridRetriever
|
||||
from app.settings import get_node_postprocessors
|
||||
|
||||
ModelPlateCategory = '模型平台'
|
||||
|
||||
def get_node_postprocessors():
|
||||
rerank_enabled = os.getenv("RERANK_ENABLED").title()
|
||||
if rerank_enabled is None or rerank_enabled == 'False':
|
||||
return []
|
||||
|
||||
Rerank_provider = os.getenv("RERANK_PROVIDER")
|
||||
modelPaltCls = ClsRegister.get(ModelPlateCategory,Rerank_provider)
|
||||
postprocess = None
|
||||
if modelPaltCls is not None:
|
||||
modelPalt = modelPaltCls()
|
||||
postprocess = modelPalt.rerank()
|
||||
else:
|
||||
raise ValueError(f"Invalid rerank provider: {Rerank_provider}")
|
||||
return postprocess
|
||||
|
||||
def makeDescriptionByEngine(sql_database:SQLDatabase):
|
||||
reader = DatabaseReader(sql_database)
|
||||
|
||||
@@ -39,7 +39,7 @@ def run_pipeline(docstore, vector_store, documents):
|
||||
#chunk_size=Settings.chunk_size,
|
||||
#chunk_overlap=Settings.chunk_overlap,
|
||||
#),
|
||||
MarkdownNodeParser(),
|
||||
#MarkdownNodeParser(),
|
||||
Settings.embed_model,
|
||||
],
|
||||
docstore=docstore,
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
from app.engine.loaders.projectJson import *
|
||||
|
||||
class MarkDown:
|
||||
def __init__(self,table:JsonTable,path:str) -> None:
|
||||
self._table = table
|
||||
self._path = path
|
||||
|
||||
def build(self):
|
||||
flds:Dict[str,Field] = self._table.fields()
|
||||
records:List[Record] = self._table.records()
|
||||
columns:list = []
|
||||
colComments:list = []
|
||||
ignores:List[str] = []
|
||||
for name,fld in flds.items():
|
||||
if name =='_id' or name =='nodeType' or name =='relTbId':
|
||||
ignores.append(name)
|
||||
continue
|
||||
|
||||
columns.append(fld.value('name'))
|
||||
colComments.append(fld.value('alias'))
|
||||
|
||||
rowdatas = []
|
||||
for record in records:
|
||||
datas = []
|
||||
for col in columns:
|
||||
if col in ignores:
|
||||
continue
|
||||
txt:str = record.value(col)
|
||||
datas.append(txt.replace('\n'," "))
|
||||
rowdatas.append(datas)
|
||||
|
||||
content = self.convert(self._table.name(),self._table.comment(),columns,colComments,rowdatas)
|
||||
with open(self._path, 'w',encoding='utf-8') as file:
|
||||
file.write(content)
|
||||
|
||||
def convert(self,tableName:str,tableComment:str,columns:list,colComments:list,rowdatas:list):
|
||||
strTitle = "# " + tableName + '\n'
|
||||
if tableName!='':
|
||||
strTitle+= f"备注:{tableComment}" + '\n'
|
||||
|
||||
for i in range(len(columns)):
|
||||
strTitle+= f"- 字段名称:{columns[i]}" + '\n'
|
||||
comment = colComments[i]
|
||||
if comment!='':
|
||||
strTitle+= f" - 备注:{comment}" + '\n'
|
||||
|
||||
markdown_table = "|"
|
||||
# 添加列标题
|
||||
markdown_table += "|".join(columns) + "|\n"
|
||||
# 添加分隔行
|
||||
markdown_table += "|" + "|".join(['---' for _ in columns]) + "|\n"
|
||||
# 遍历每个数据行
|
||||
for row in rowdatas:
|
||||
# 添加数据行
|
||||
markdown_table += "|" + "|".join(row) + "|\n"
|
||||
return strTitle + "\n" + markdown_table
|
||||
|
||||
|
||||
prjSon = ProjectJson('')
|
||||
prjSon.parse()
|
||||
tables = prjSon.tables()
|
||||
for name,table in tables.items():
|
||||
mdObj = MarkDown(table,f'')
|
||||
mdObj.build()
|
||||
@@ -0,0 +1,65 @@
|
||||
from llama_index.readers.file.markdown import MarkdownReader
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
import re
|
||||
from llama_index.core.utils import get_tokenizer
|
||||
|
||||
|
||||
class ChunkMarkdownReader(MarkdownReader):
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
chunkSize:int = 2048,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self._chunkSize = chunkSize
|
||||
self._tokenizer = get_tokenizer()
|
||||
super().__init__(*args,**kwargs)
|
||||
|
||||
def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
|
||||
markdown_tups: List[Tuple[Optional[str], str]] = []
|
||||
lines = markdown_text.split("\n")
|
||||
|
||||
strTitle = ''
|
||||
tokensNum:int = 0
|
||||
current_lines = []
|
||||
strheader:str = ''
|
||||
headerSize:int = 0
|
||||
for line in lines:
|
||||
tokensNum += self._token_size(line)
|
||||
if tokensNum > self._chunkSize and len(current_lines) > 0:
|
||||
if len(markdown_tups) == 0:
|
||||
markdown_tups.append((strTitle + strheader , "\n".join(current_lines)))
|
||||
else:
|
||||
markdown_tups.append((strheader , "\n".join(current_lines)))
|
||||
tokensNum = headerSize
|
||||
current_lines.clear()
|
||||
current_lines.append(line)
|
||||
|
||||
if line == '\n' or line == '\r':
|
||||
if tokensNum > self._chunkSize:
|
||||
raise ValueError('标题Token数大于chunkSize大小')
|
||||
strTitle = "\n".join(current_lines)
|
||||
#headerSize = headerSize + self._token_size(strTitle)
|
||||
current_lines.clear()
|
||||
|
||||
if line.startswith("|---"):
|
||||
strheader = "\n".join(current_lines)
|
||||
headerSize= headerSize + self._token_size(strheader)
|
||||
current_lines.clear()
|
||||
|
||||
if len(current_lines) > 0:
|
||||
if len(markdown_tups) == 0:
|
||||
markdown_tups.append((strTitle + strheader , "\n".join(current_lines)))
|
||||
else:
|
||||
markdown_tups.append((strheader , "\n".join(current_lines)))
|
||||
|
||||
return [
|
||||
(
|
||||
key if key is None else re.sub(r"#", "", key).strip(),
|
||||
re.sub(r"<.*?>", "", value),
|
||||
)
|
||||
for key, value in markdown_tups
|
||||
]
|
||||
|
||||
def _token_size(self, text: str) -> int:
|
||||
return len(self._tokenizer(text))
|
||||
@@ -24,13 +24,16 @@ class JsonTable:
|
||||
self._filePth = filePth
|
||||
self._fields:Dict[str,Field] = {}
|
||||
self._records:List[Record] = []
|
||||
self._fileName = os.path.splitext(os.path.basename(filePth))[0]
|
||||
self._name = ''
|
||||
self._comment = ''
|
||||
|
||||
def parse(self):
|
||||
with open(self._filePth, 'r',encoding='utf-8') as file:
|
||||
jsObj = json.load(file)
|
||||
data:dict = jsObj.get('table')
|
||||
self._name = data.get('name')
|
||||
self._comment = data.get('comment')
|
||||
Jsfields = data.get('fields')
|
||||
for jsfiled in Jsfields:
|
||||
field = Field(jsfiled)
|
||||
@@ -42,6 +45,16 @@ class JsonTable:
|
||||
|
||||
def records(self):
|
||||
return self._records
|
||||
|
||||
def fields(self):
|
||||
return self._fields
|
||||
|
||||
def name(self):
|
||||
return self._fileName
|
||||
|
||||
def comment(self):
|
||||
return self._comment
|
||||
|
||||
|
||||
class ProjectJson:
|
||||
def __init__(self,dir:str) -> None:
|
||||
@@ -59,6 +72,9 @@ class ProjectJson:
|
||||
|
||||
def table(self,tableName:str):
|
||||
return self._tables[tableName]
|
||||
|
||||
def tables(self):
|
||||
return self._tables
|
||||
|
||||
def getProjectName(dir:str):
|
||||
prjJson = ProjectJson(dir)
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
from typing import Any, List, Optional
|
||||
from llama_index.core.postprocessor import SentenceTransformerRerank
|
||||
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle
|
||||
from llama_index.core.callbacks import CBEventType, EventPayload
|
||||
from llama_index.core.bridge.pydantic import PrivateAttr
|
||||
|
||||
class OllamaRerank(SentenceTransformerRerank):
|
||||
_score_threshold: float = PrivateAttr()
|
||||
def __init__(
|
||||
self,
|
||||
top_n: int = 2,
|
||||
model: str = "cross-encoder/stsb-distilroberta-base",
|
||||
device: Optional[str] = None,
|
||||
keep_retrieval_score: Optional[bool] = False,
|
||||
score_threshold:float = 0.3
|
||||
):
|
||||
self._score_threshold = score_threshold
|
||||
super().__init__(top_n,model,device,keep_retrieval_score)
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
return "OllamaRerank"
|
||||
|
||||
def _postprocess_nodes(
|
||||
self,
|
||||
nodes: List[NodeWithScore],
|
||||
query_bundle: Optional[QueryBundle] = None,
|
||||
) -> List[NodeWithScore]:
|
||||
if query_bundle is None:
|
||||
raise ValueError("Missing query bundle in extra info.")
|
||||
if len(nodes) == 0:
|
||||
return []
|
||||
|
||||
query_and_nodes = [
|
||||
(
|
||||
query_bundle.query_str,
|
||||
node.node.get_content(metadata_mode=MetadataMode.EMBED),
|
||||
)
|
||||
for node in nodes
|
||||
]
|
||||
|
||||
with self.callback_manager.event(
|
||||
CBEventType.RERANKING,
|
||||
payload={
|
||||
EventPayload.NODES: nodes,
|
||||
EventPayload.MODEL_NAME: self.model,
|
||||
EventPayload.QUERY_STR: query_bundle.query_str,
|
||||
EventPayload.TOP_K: self.top_n,
|
||||
},
|
||||
) as event:
|
||||
scores = self._model.predict(query_and_nodes)
|
||||
|
||||
assert len(scores) == len(nodes)
|
||||
|
||||
for node, score in zip(nodes, scores):
|
||||
if self.keep_retrieval_score:
|
||||
node.node.metadata["retrieval_score"] = node.score
|
||||
node.score = score
|
||||
|
||||
for i in range(len(nodes)-1,-1,-1):
|
||||
node = nodes[i]
|
||||
if node.score < self._score_threshold:
|
||||
nodes.remove(node)
|
||||
|
||||
new_nodes = sorted(nodes, key=lambda x: -x.score if x.score else 0)[
|
||||
: self.top_n
|
||||
]
|
||||
event.on_end(payload={EventPayload.NODES: new_nodes})
|
||||
|
||||
return new_nodes
|
||||
+12
-16
@@ -18,21 +18,6 @@ from llama_index.core.callbacks import CallbackManager
|
||||
|
||||
ModelPlateCategory = '模型平台'
|
||||
|
||||
def get_node_postprocessors():
|
||||
rerank_enabled = os.getenv("RERANK_ENABLED").title()
|
||||
if rerank_enabled is None or rerank_enabled == 'False':
|
||||
return []
|
||||
|
||||
Rerank_provider = os.getenv("RERANK_PROVIDER")
|
||||
modelPaltCls:ModelPlatform = ClsRegister.get(ModelPlateCategory,Rerank_provider)
|
||||
postprocess = None
|
||||
if modelPaltCls is not None:
|
||||
modelPalt:ModelPlatform = modelPaltCls()
|
||||
postprocess = modelPalt.rerank()
|
||||
else:
|
||||
raise ValueError(f"Invalid rerank provider: {Rerank_provider}")
|
||||
return postprocess
|
||||
|
||||
def init_settings():
|
||||
model_provider = os.getenv("MODEL_PROVIDER")
|
||||
modelPaltCls:ModelPlatform = ClsRegister.get(ModelPlateCategory,model_provider)
|
||||
@@ -91,7 +76,18 @@ class OllamaPlatform(ModelPlatform):
|
||||
pass
|
||||
|
||||
def rerank(self):
|
||||
pass
|
||||
from app.engine.rerank.ollamRerank import OllamaRerank
|
||||
modelpath = os.getcwd() + os.getenv('RERANK_MODEL')
|
||||
top_n = os.getenv('RERANK_TOP_N',5)
|
||||
threshold = float(os.getenv('RERANK_THRESHOLD',0.3))
|
||||
rerank = OllamaRerank(
|
||||
model=modelpath,
|
||||
top_n=top_n,
|
||||
device="cpu",
|
||||
score_threshold= threshold
|
||||
)
|
||||
return [rerank]
|
||||
|
||||
|
||||
@register(ModelPlateCategory,'xinference')
|
||||
class XinferencePlatform(ModelPlatform):
|
||||
|
||||
Reference in New Issue
Block a user