Merge branch 'dev-web' of https://git.97id.com/ly/zjdataai-app into dev-web

This commit is contained in:
2024-09-10 08:43:38 +08:00
38 changed files with 257 additions and 49 deletions
+18 -2
View File
@@ -7,10 +7,26 @@ from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import ResponseMode
from llama_index.readers.database import DatabaseReader
from sqlalchemy import create_engine
from util.register import *
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
from app.engine.retriever.HybridRetriever import HybridRetriever
from app.settings import get_node_postprocessors
ModelPlateCategory = '模型平台'
def get_node_postprocessors():
    """Return the node post-processors (rerankers) selected by the environment.

    Reads ``RERANK_ENABLED``; when reranking is enabled, the platform class
    registered for ``RERANK_PROVIDER`` under ``ModelPlateCategory`` is
    instantiated and the result of its ``rerank()`` method is returned.

    Returns:
        ``[]`` when ``RERANK_ENABLED`` is unset or spells "false" in any
        letter case; otherwise whatever the platform's ``rerank()`` returns.

    Raises:
        ValueError: If no platform class is registered for the provider.
    """
    rerank_enabled = os.getenv("RERANK_ENABLED")
    # os.getenv returns None for an unset variable, so the None check has to
    # come before any string method is called on the value.
    if rerank_enabled is None or rerank_enabled.title() == 'False':
        return []

    Rerank_provider = os.getenv("RERANK_PROVIDER")
    modelPaltCls = ClsRegister.get(ModelPlateCategory, Rerank_provider)
    if modelPaltCls is None:
        raise ValueError(f"Invalid rerank provider: {Rerank_provider}")
    return modelPaltCls().rerank()
def makeDescriptionByEngine(sql_database:SQLDatabase):
reader = DatabaseReader(sql_database)
+1 -1
View File
@@ -39,7 +39,7 @@ def run_pipeline(docstore, vector_store, documents):
#chunk_size=Settings.chunk_size,
#chunk_overlap=Settings.chunk_overlap,
#),
MarkdownNodeParser(),
#MarkdownNodeParser(),
Settings.embed_model,
],
docstore=docstore,
+64
View File
@@ -0,0 +1,64 @@
from app.engine.loaders.projectJson import *
class MarkDown:
    """Render one table as a Markdown document and write it to a file.

    The document consists of a title section (table name, optional table
    comment, one bullet per column with an optional column comment) followed
    by a blank line and a Markdown table holding every record.
    """

    # Field keys that build() leaves out of the generated document.
    _IGNORED_FIELDS = ('_id', 'nodeType', 'relTbId')

    def __init__(self, table: "JsonTable", path: str) -> None:
        """Store the source table and the output file path."""
        self._table = table
        self._path = path

    def build(self):
        """Collect columns and rows from the table and write the Markdown file.

        The file at ``self._path`` is overwritten and encoded as UTF-8.
        """
        columns: list = []
        colComments: list = []
        for name, fld in self._table.fields().items():
            if name in self._IGNORED_FIELDS:
                continue
            columns.append(fld.value('name'))
            colComments.append(fld.value('alias'))

        rowdatas = []
        for record in self._table.records():
            # One cell per column, always, so every row lines up with the
            # header row. A line break inside a cell would end the Markdown
            # table row, hence the replacement with a space.
            rowdatas.append([self._cell_text(record.value(col)) for col in columns])

        content = self.convert(
            self._table.name(), self._table.comment(), columns, colComments, rowdatas
        )
        with open(self._path, 'w', encoding='utf-8') as file:
            file.write(content)

    @staticmethod
    def _cell_text(value) -> str:
        """Return a single-line string for a cell; a missing value becomes ''."""
        if value is None:
            return ''
        return str(value).replace('\n', " ")

    def convert(self, tableName: str, tableComment: str, columns: list, colComments: list, rowdatas: list):
        """Build the Markdown text for one table.

        Args:
            tableName: Text placed in the top-level ``#`` heading.
            tableComment: Optional table remark; omitted when empty or None.
            columns: Column names, used for the bullets and the table header.
            colComments: Per-column remarks, parallel to ``columns``; an empty
                or None entry produces no remark line.
            rowdatas: Rows of cell values, each parallel to ``columns``.

        Returns:
            The title section, a blank line, then the Markdown table.
        """
        title_lines = ["# " + tableName]
        # The remark line depends on the comment itself, not on the name.
        if tableComment:
            title_lines.append(f"备注:{tableComment}")
        for column, comment in zip(columns, colComments):
            title_lines.append(f"- 字段名称:{column}")
            if comment:
                title_lines.append(f" - 备注:{comment}")
        strTitle = "\n".join(title_lines) + "\n"

        table_lines = [
            # Header row with the column names.
            "|" + "|".join(str(column) for column in columns) + "|",
            # Separator row required by Markdown table syntax.
            "|" + "|".join('---' for _ in columns) + "|",
        ]
        # One Markdown row per data row.
        table_lines.extend(
            "|" + "|".join(str(cell) for cell in row) + "|" for row in rowdatas
        )
        markdown_table = "\n".join(table_lines) + "\n"

        return strTitle + "\n" + markdown_table
# Module-level driver: parse a project directory and write one Markdown file
# per table. These statements run whenever this module is imported or executed.
# NOTE(review): both the project directory passed to ProjectJson and the
# output path passed to MarkDown are empty strings here - presumably
# placeholders that must be filled in before running; confirm.
prjSon = ProjectJson('')
prjSon.parse()
tables = prjSon.tables()
for name,table in tables.items():
    # Every table is built with the same path argument, so each build()
    # call writes to the same target.
    mdObj = MarkDown(table,f'')
    mdObj.build()
@@ -0,0 +1,65 @@
from llama_index.readers.file.markdown import MarkdownReader
from typing import Any, Dict, List, Optional, Tuple
import re
from llama_index.core.utils import get_tokenizer
class ChunkMarkdownReader(MarkdownReader):
    """Markdown reader that cuts a table document into token-bounded chunks.

    Lines are accumulated until their token count exceeds ``chunkSize``; each
    emitted chunk is keyed by the table header (the lines up to and including
    the ``|---`` separator row), and the first chunk is additionally prefixed
    with the title text.
    """

    def __init__(
        self,
        *args: Any,
        chunkSize: int = 2048,
        **kwargs: Any,
    ) -> None:
        """Create the reader.

        Args:
            *args: Forwarded unchanged to ``MarkdownReader``.
            chunkSize: Token budget for one chunk.
            **kwargs: Forwarded unchanged to ``MarkdownReader``.
        """
        self._chunkSize = chunkSize
        self._tokenizer = get_tokenizer()
        super().__init__(*args, **kwargs)

    def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
        """Split ``markdown_text`` into ``(header, body)`` tuples.

        Args:
            markdown_text: Full Markdown document text.

        Returns:
            One tuple per chunk. The header has every ``#`` removed and is
            stripped; the body has ``<...>`` spans removed.

        Raises:
            ValueError: If the title section alone exceeds ``chunkSize``.
        """
        markdown_tups: List[Tuple[Optional[str], str]] = []
        strTitle = ''
        strheader: str = ''
        headerSize: int = 0
        tokensNum: int = 0
        current_lines: List[str] = []

        for line in markdown_text.split("\n"):
            line_tokens = self._token_size(line)
            tokensNum += line_tokens
            if tokensNum > self._chunkSize and len(current_lines) > 0:
                self._append_chunk(markdown_tups, strTitle, strheader, current_lines)
                # The new chunk repeats the header and starts with the line
                # that overflowed the previous one, so both count against
                # its budget.
                tokensNum = headerSize + line_tokens
                current_lines.clear()
            current_lines.append(line)

            # NOTE(review): after split("\n") no element can equal '\n', and
            # '\r' only occurs for a bare "\r" line in CRLF text, so for
            # LF-only input this title branch never runs and the title lines
            # end up inside the header captured below. Left unchanged because
            # making it live would change which chunks carry the title.
            if line == '\n' or line == '\r':
                if tokensNum > self._chunkSize:
                    raise ValueError('标题Token数大于chunkSize大小')
                strTitle = "\n".join(current_lines)
                current_lines.clear()

            # The separator row ends the table header: everything collected
            # so far becomes the header repeated on every chunk.
            if line.startswith("|---"):
                strheader = "\n".join(current_lines)
                headerSize = headerSize + self._token_size(strheader)
                current_lines.clear()

        if len(current_lines) > 0:
            self._append_chunk(markdown_tups, strTitle, strheader, current_lines)

        return [
            (
                key if key is None else re.sub(r"#", "", key).strip(),
                re.sub(r"<.*?>", "", value),
            )
            for key, value in markdown_tups
        ]

    @staticmethod
    def _append_chunk(
        markdown_tups: List[Tuple[Optional[str], str]],
        title: str,
        header: str,
        lines: List[str],
    ) -> None:
        """Append one chunk; only the first chunk is prefixed with the title."""
        prefix = header if markdown_tups else title + header
        markdown_tups.append((prefix, "\n".join(lines)))

    def _token_size(self, text: str) -> int:
        """Return the number of tokens the configured tokenizer yields for ``text``."""
        return len(self._tokenizer(text))
+16
View File
@@ -24,13 +24,16 @@ class JsonTable:
self._filePth = filePth
self._fields:Dict[str,Field] = {}
self._records:List[Record] = []
self._fileName = os.path.splitext(os.path.basename(filePth))[0]
self._name = ''
self._comment = ''
def parse(self):
with open(self._filePth, 'r',encoding='utf-8') as file:
jsObj = json.load(file)
data:dict = jsObj.get('table')
self._name = data.get('name')
self._comment = data.get('comment')
Jsfields = data.get('fields')
for jsfiled in Jsfields:
field = Field(jsfiled)
@@ -42,6 +45,16 @@ class JsonTable:
def records(self):
    """Return the list of records held by this table (``self._records``)."""
    return self._records
def fields(self):
    """Return the field mapping held by this table (``self._fields``)."""
    return self._fields
def name(self):
    """Return the table name derived from the source file.

    This is the base name of the JSON file without its extension, not the
    ``name`` value read from the JSON content (that one is kept in
    ``self._name``).
    """
    return self._fileName
def comment(self):
    """Return the table comment: '' until parse() runs, then the JSON ``comment`` value."""
    return self._comment
class ProjectJson:
def __init__(self,dir:str) -> None:
@@ -59,6 +72,9 @@ class ProjectJson:
def table(self,tableName:str):
    """Return the table stored under ``tableName``.

    The lookup uses subscripting on ``self._tables``, so a missing name is
    not defaulted; presumably this raises ``KeyError`` if ``_tables`` is a
    dict - confirm against ``__init__``.
    """
    return self._tables[tableName]
def tables(self):
    """Return the whole table collection (``self._tables``), not a copy."""
    return self._tables
def getProjectName(dir:str):
prjJson = ProjectJson(dir)
+70
View File
@@ -0,0 +1,70 @@
from typing import Any, List, Optional
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.bridge.pydantic import PrivateAttr
class OllamaRerank(SentenceTransformerRerank):
    """Cross-encoder reranker that also drops nodes below a score threshold.

    Nodes are scored against the query with the inherited model, nodes whose
    score is below ``score_threshold`` are discarded, and the best ``top_n``
    of the remainder are returned in descending score order.
    """

    _score_threshold: float = PrivateAttr()

    def __init__(
        self,
        top_n: int = 2,
        model: str = "cross-encoder/stsb-distilroberta-base",
        device: Optional[str] = None,
        keep_retrieval_score: Optional[bool] = False,
        score_threshold: float = 0.3
    ):
        """Create the reranker.

        Args:
            top_n: Maximum number of nodes to return.
            model: Cross-encoder model name or path, passed to the base class.
            device: Device passed to the base class.
            keep_retrieval_score: When true, the pre-rerank score is saved in
                each node's metadata under ``"retrieval_score"``.
            score_threshold: Nodes scoring below this value are discarded.
        """
        # Keyword arguments keep this call correct even if the base class
        # changes the order of its parameters.
        super().__init__(
            top_n=top_n,
            model=model,
            device=device,
            keep_retrieval_score=keep_retrieval_score,
        )
        # Assigned after the base initialiser: a pydantic v2 model cannot
        # hold private attributes until BaseModel.__init__ has run.
        self._score_threshold = score_threshold

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier of this post-processor class."""
        return "OllamaRerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Score, filter and rank ``nodes`` against the query.

        Args:
            nodes: Retrieved nodes; each node's ``score`` is overwritten with
                the rerank score. The list itself is left unmodified.
            query_bundle: Query whose ``query_str`` is paired with each node.

        Returns:
            At most ``top_n`` nodes scoring at least the threshold, best first.

        Raises:
            ValueError: If ``query_bundle`` is None.
        """
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if len(nodes) == 0:
            return []

        query_and_nodes = [
            (
                query_bundle.query_str,
                node.node.get_content(metadata_mode=MetadataMode.EMBED),
            )
            for node in nodes
        ]

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.model,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: self.top_n,
            },
        ) as event:
            scores = self._model.predict(query_and_nodes)

            assert len(scores) == len(nodes)

            for node, score in zip(nodes, scores):
                if self.keep_retrieval_score:
                    node.node.metadata["retrieval_score"] = node.score
                node.score = score

            # Build a filtered copy instead of removing from the caller's
            # list: list.remove is a linear, equality-based search and the
            # in-place removal was a side effect on the argument.
            kept = [
                node for node in nodes
                if not node.score < self._score_threshold
            ]

            new_nodes = sorted(kept, key=lambda x: -x.score if x.score else 0)[
                : self.top_n
            ]
            event.on_end(payload={EventPayload.NODES: new_nodes})

        return new_nodes
+12 -16
View File
@@ -18,21 +18,6 @@ from llama_index.core.callbacks import CallbackManager
ModelPlateCategory = '模型平台'
def get_node_postprocessors():
    """Return the node post-processors (rerankers) selected by the environment.

    Reads ``RERANK_ENABLED``; when reranking is enabled, the platform class
    registered for ``RERANK_PROVIDER`` under ``ModelPlateCategory`` is
    instantiated and the result of its ``rerank()`` method is returned.

    Returns:
        ``[]`` when ``RERANK_ENABLED`` is unset or spells "false" in any
        letter case; otherwise whatever the platform's ``rerank()`` returns.

    Raises:
        ValueError: If no platform class is registered for the provider.
    """
    rerank_enabled = os.getenv("RERANK_ENABLED")
    # os.getenv returns None for an unset variable, so the None check has to
    # come before any string method is called on the value.
    if rerank_enabled is None or rerank_enabled.title() == 'False':
        return []

    Rerank_provider = os.getenv("RERANK_PROVIDER")
    modelPaltCls: ModelPlatform = ClsRegister.get(ModelPlateCategory, Rerank_provider)
    if modelPaltCls is None:
        raise ValueError(f"Invalid rerank provider: {Rerank_provider}")
    modelPalt: ModelPlatform = modelPaltCls()
    return modelPalt.rerank()
def init_settings():
model_provider = os.getenv("MODEL_PROVIDER")
modelPaltCls:ModelPlatform = ClsRegister.get(ModelPlateCategory,model_provider)
@@ -91,7 +76,18 @@ class OllamaPlatform(ModelPlatform):
pass
def rerank(self):
    """Build the reranker post-processor list for this platform.

    Environment variables read:
        RERANK_MODEL: Model path, appended verbatim to the current working
            directory. Required.
        RERANK_TOP_N: Maximum number of nodes to keep (default 5).
        RERANK_THRESHOLD: Minimum rerank score to keep a node (default 0.3).

    Returns:
        A one-element list holding the configured ``OllamaRerank``.

    Raises:
        ValueError: If ``RERANK_MODEL`` is unset, or if ``RERANK_TOP_N`` /
            ``RERANK_THRESHOLD`` are not valid numbers.
    """
    # Imported here rather than at module level, as in the original code.
    from app.engine.rerank.ollamRerank import OllamaRerank

    model_suffix = os.getenv('RERANK_MODEL')
    if model_suffix is None:
        raise ValueError("RERANK_MODEL environment variable is not set")
    # Plain concatenation is kept on purpose: os.path.join would discard the
    # working directory whenever the configured value starts with a separator.
    modelpath = os.getcwd() + model_suffix

    # Environment values are strings when set, so convert explicitly instead
    # of handing a str to an int parameter.
    top_n = int(os.getenv('RERANK_TOP_N', 5))
    threshold = float(os.getenv('RERANK_THRESHOLD', 0.3))

    rerank = OllamaRerank(
        model=modelpath,
        top_n=top_n,
        device="cpu",
        score_threshold=threshold
    )
    return [rerank]
@register(ModelPlateCategory,'xinference')
class XinferencePlatform(ModelPlatform):