Compare commits
13 Commits
786c4d05f6
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
| e634746a52 | |||
| d12800e14e | |||
| c1df0d1bba | |||
| 0664952ecd | |||
| 7023b54246 | |||
| aee6aa3c04 | |||
| 680e24c516 | |||
| 6663ee8976 | |||
| 0a5f335981 | |||
| 2901bd9eaf | |||
| 453b3ca55c | |||
| f0afd1a4bb | |||
| eb572eff27 |
@@ -1,3 +1,8 @@
|
|||||||
|
JIEBA_DATA=./nltk_data
|
||||||
|
NLTK_DATA=./nltk_data
|
||||||
|
SQLITE_DATABASE_URL=sqlite:///./source.db
|
||||||
|
DATA_SOURCE_CACHE=./restapi
|
||||||
|
|
||||||
# The Llama Cloud API key.
|
# The Llama Cloud API key.
|
||||||
# LLAMA_CLOUD_API_KEY=
|
# LLAMA_CLOUD_API_KEY=
|
||||||
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
||||||
|
|||||||
@@ -1,3 +1,8 @@
|
|||||||
|
JIEBA_DATA=./nltk_data
|
||||||
|
NLTK_DATA=./nltk_data
|
||||||
|
SQLITE_DATABASE_URL=sqlite:///./source.db
|
||||||
|
DATA_SOURCE_CACHE=./restapi
|
||||||
|
|
||||||
# The Llama Cloud API key.
|
# The Llama Cloud API key.
|
||||||
# LLAMA_CLOUD_API_KEY=
|
# LLAMA_CLOUD_API_KEY=
|
||||||
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
from typing import Any, Dict, List, Union, Callable, NamedTuple
|
from typing import Any, Dict, List, Union, Callable, NamedTuple
|
||||||
from bm25s.tokenization import *
|
from bm25s.tokenization import *
|
||||||
|
|
||||||
@@ -8,9 +9,12 @@ except ImportError:
|
|||||||
def tqdm(iterable, *args, **kwargs):
|
def tqdm(iterable, *args, **kwargs):
|
||||||
return iterable
|
return iterable
|
||||||
|
|
||||||
|
import jieba
|
||||||
|
jiebapath = os.environ.get("JIEBA_DATA", "")
|
||||||
|
jieba.set_dictionary(os.path.join(jiebapath, 'dict.txt')) #设置字典
|
||||||
|
jieba.initialize()  # initialize jieba
|
||||||
|
|
||||||
def chinese_tokenizer(text: str) -> List[str]:
    """Segment *text* with jieba and drop tokens found in NLTK's Chinese stopword list.

    Args:
        text: The string to tokenize.

    Returns:
        The tokens produced by ``jieba.lcut`` that are not Chinese stopwords,
        in their original order.
    """
    from nltk.corpus import stopwords

    # Fetch the stopword list once per call and keep it in a set. The previous
    # comprehension called stopwords.words('chinese') again for every token and
    # then scanned the resulting list linearly.
    chinese_stopwords = set(stopwords.words('chinese'))
    # `jieba` is the module-level import that was configured with the custom dictionary.
    return [token for token in jieba.lcut(text) if token not in chinese_stopwords]
|
||||||
|
|||||||
@@ -3,11 +3,10 @@ from typing import Dict
|
|||||||
|
|
||||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||||
from llama_index.core.settings import Settings
|
from llama_index.core.settings import Settings
|
||||||
|
from app.xinference.base import XinferenceEmbedding, XinferenceRerank
|
||||||
from llama_index.llms.xinference import Xinference
|
from llama_index.llms.xinference import Xinference
|
||||||
from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
|
from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
|
||||||
|
|
||||||
from app.xinference.base import XinferenceEmbedding, XinferenceRerank
|
|
||||||
|
|
||||||
|
|
||||||
def get_node_postprocessors():
|
def get_node_postprocessors():
|
||||||
rerank_enabled = os.getenv("RERANK_ENABLED").title()
|
rerank_enabled = os.getenv("RERANK_ENABLED").title()
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
+349046
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,121 @@
|
|||||||
|
from dotenv import load_dotenv
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
from llama_index.core.evaluation import CorrectnessEvaluator
|
||||||
|
from app.engine import get_chat_engine
|
||||||
|
from app.engine.index import get_index
|
||||||
|
from app.observability import init_observability
|
||||||
|
from app.settings import init_settings
|
||||||
|
|
||||||
|
init_settings()
|
||||||
|
init_observability()
|
||||||
|
|
||||||
|
index = get_index()
|
||||||
|
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
import nest_asyncio
|
||||||
|
nest_asyncio.apply()
|
||||||
|
from llama_index.core.prompts import (
|
||||||
|
ChatMessage,
|
||||||
|
ChatPromptTemplate,
|
||||||
|
MessageRole
|
||||||
|
)
|
||||||
|
|
||||||
|
DEFAULT_SYSTEM_TEMPLATE = """
|
||||||
|
您是一个问答聊天机器人的专业评估系统。
|
||||||
|
|
||||||
|
您将获得以下信息:
|
||||||
|
|
||||||
|
- 用户查询,
|
||||||
|
- 生成的回答,
|
||||||
|
|
||||||
|
也可能提供一个参考答案作为评估的依据。
|
||||||
|
|
||||||
|
您的任务是判断生成回答的相关性和正确性。
|
||||||
|
输出一个代表全面评估的单一分数。
|
||||||
|
您必须在一行中仅返回该分数。
|
||||||
|
不要以其他任何格式返回答案。
|
||||||
|
在单独的一行提供给定分数的理由。
|
||||||
|
|
||||||
|
请遵循以下评分指南:
|
||||||
|
|
||||||
|
- 您的分数必须在1到5之间,其中1是最差,5是最好的。
|
||||||
|
-如果生成的回答与用户查询不相关,您应该给出1分。
|
||||||
|
-如果生成的回答相关但包含错误,您应该给出2到3分之间的分数。
|
||||||
|
-如果生成的回答相关且完全正确,您应该给出4到5分之间的分数。
|
||||||
|
示例响应:
|
||||||
|
4.0
|
||||||
|
生成的回答与参考答案的指标完全相同,但不够精炼。
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_USER_TEMPLATE = """
|
||||||
|
## User Query
|
||||||
|
{query}
|
||||||
|
|
||||||
|
## Reference Answer
|
||||||
|
{reference_answer}
|
||||||
|
|
||||||
|
## Generated Answer
|
||||||
|
{generated_answer}
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_EVAL_TEMPLATE = ChatPromptTemplate(
|
||||||
|
message_templates=[
|
||||||
|
ChatMessage(role=MessageRole.SYSTEM, content=DEFAULT_SYSTEM_TEMPLATE),
|
||||||
|
ChatMessage(role=MessageRole.USER, content=DEFAULT_USER_TEMPLATE),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Initialise the chat engine and the correctness evaluator.
chat_engine = get_chat_engine()
# Hand the Chinese grading prompt (DEFAULT_EVAL_TEMPLATE, built earlier in this
# script) to the evaluator. Without this argument the template is never used
# and the evaluator falls back to the library's default prompt.
corr_evaluator_qwen = CorrectnessEvaluator(eval_template=DEFAULT_EVAL_TEMPLATE)
|
||||||
|
|
||||||
|
# 加载本地问题回答文件
|
||||||
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
file_path = os.path.join(script_dir, 'questions_and_answers.json')
|
||||||
|
output_file_path = file_path.replace('.json', '_test.json')
|
||||||
|
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
# 异步函数用于评估查询
|
||||||
|
async def evaluate_query(question, answer, index, output_file):
    """Ask the chat engine one question, grade the reply and append the record to a file.

    Args:
        question: Query text sent to ``chat_engine``.
        answer: Reference answer passed to the evaluator.
        index: Sequence number stored under the "编号" key of the record.
        output_file: Path of the file the record is appended to (UTF-8).
    """
    response = await chat_engine.astream_chat(question)

    # The graded text is the first entry of response.sources; use a placeholder
    # when the engine returned no sources at all.
    content_str = str(response.sources[0]) if response.sources else "<无回答>"

    # Await the evaluator's coroutine API directly. The synchronous evaluate()
    # would have to re-enter the event loop that is already running this
    # coroutine, which only works because nest_asyncio is patched in.
    result = await corr_evaluator_qwen.aevaluate(
        query=question,
        response=content_str,
        reference=answer,
    )

    result_dict = {
        "编号": index,
        "问题": question,
        "答案": answer,
        "回答": result.response,
        "得分(1~5)": result.score,
        "评价": result.feedback,
    }

    # NOTE(review): every record is appended followed by ",\n", so the file as a
    # whole is not a valid JSON document. The format is kept unchanged here;
    # confirm what downstream consumers expect before changing it.
    with open(output_file, 'a', encoding='utf-8') as f:
        f.write(json.dumps(result_dict, ensure_ascii=False, indent=4))
        f.write(',\n')
|
||||||
|
|
||||||
|
# 主异步函数
|
||||||
|
async def main():
    """Evaluate every entry of ``data`` sequentially, numbering the records from 1."""
    position = 1
    for item in data:
        # Awaited one at a time so records are appended in input order.
        await evaluate_query(item['question'], item['answer'], position, output_file_path)
        position += 1
|
||||||
|
|
||||||
|
# 运行主协程
|
||||||
|
asyncio.run(main())
|
||||||
@@ -0,0 +1,55 @@
|
|||||||
|
Attribute_Prompt = (
|
||||||
|
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
|
||||||
|
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
|
||||||
|
"现在需要你针对数据中属性一列进行提问和回答。"
|
||||||
|
"问题和回答的示例应该是这种类型的,示例:'工程总投资(万元),工程总投资(万元)是77469835.590045万元','尖峰及施工基面土石方量,尖峰及施工基面土石方量是8377.6','截止阀的编码,截止阀的编码是F01010203',"
|
||||||
|
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
|
||||||
|
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
|
||||||
|
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Amount_Prompt = (
|
||||||
|
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
|
||||||
|
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
|
||||||
|
"现在需要你针对上下文信息中的金额或者合价进行提问和回答。"
|
||||||
|
"问题和回答的示例应该是这种类型的,示例:'项目建设技术服务费的金额,项目建设技术服务费的金额是16855957065.4302','项目后评价费的费率,项目后评价费的费率是0.5','架空输电线路本体工程的金额,架空输电线路本体工程的金额是55105688268.5176','工程静态投资的金额,工程静态投资的金额是715035853336.391'"
|
||||||
|
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
|
||||||
|
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
|
||||||
|
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Units_Prompt = (
|
||||||
|
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
|
||||||
|
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
|
||||||
|
"现在需要你针对上下文信息来进行单位转化问题提问和回答。"
|
||||||
|
"问题和回答的示例应该是这种类型的,示例:'工程总投资(万元)结果用元表示,工程总投资(万元)是774698355900.45元','本体工程(元)结果用万元表示,本体工程(元)是5490494.261046万元'"
|
||||||
|
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
|
||||||
|
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
|
||||||
|
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
|
||||||
|
)
|
||||||
|
|
||||||
|
Name_Prompt = (
|
||||||
|
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
|
||||||
|
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
|
||||||
|
"现在需要你针对上下文信息中的重名问题进行提问和回答。"
|
||||||
|
"问题和回答的示例应该是这种类型的,示例:'专业类型为线路的杆塔工程项目划分的合价,专业类型为线路的杆塔工程项目划分的合价是220969744.905856','专业类型为线路清理的杆塔工程项目划分的合价,电缆工程的合价是0'"
|
||||||
|
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
|
||||||
|
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
|
||||||
|
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
All_Amount_Prompt = (
|
||||||
|
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
|
||||||
|
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
|
||||||
|
"现在需要你针对上下文信息中的总体金额进行提问和回答。"
|
||||||
|
"问题和回答的示例应该是这种类型的,示例:'架空输电线路本体工程的总体金额,架空输电线路本体工程的总体金额是7.706703','工程静态投资的总体金额,工程静态投资的总体金额是100'"
|
||||||
|
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
|
||||||
|
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
|
||||||
|
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -0,0 +1,144 @@
|
|||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
from app.observability import init_observability
|
||||||
|
from app.settings import init_settings
|
||||||
|
|
||||||
|
import nest_asyncio
|
||||||
|
nest_asyncio.apply()
|
||||||
|
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
from llama_index.core.evaluation import DatasetGenerator
|
||||||
|
|
||||||
|
import prompts
|
||||||
|
|
||||||
|
init_settings()
|
||||||
|
init_observability()
|
||||||
|
|
||||||
|
# 读取所有文档(即所有表格)
|
||||||
|
documents = SimpleDirectoryReader("D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test").load_data()
|
||||||
|
|
||||||
|
# 定义表格名称和索引的对应关系
|
||||||
|
table_names = {
|
||||||
|
"工程信息表": 0,
|
||||||
|
"其他费用表": 1,
|
||||||
|
"取费表": 2,
|
||||||
|
"项目划分表": 3,
|
||||||
|
"项目划分_费用预览表": 4,
|
||||||
|
"总算表": 5,
|
||||||
|
"工程量表": 6
|
||||||
|
}
|
||||||
|
|
||||||
|
# 定义中文提示词和Python代码中提示词名称的映射
|
||||||
|
prompt_mapping = {
|
||||||
|
"普通属性": "Attribute_Prompt",
|
||||||
|
"金额查询": "Amount_Prompt",
|
||||||
|
"单位换算": "Units_Prompt",
|
||||||
|
"重名项目划分": "Name_Prompt",
|
||||||
|
"总体金额查询": "All_Amount_Prompt"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 定义表格与其对应的查询类别
|
||||||
|
table_prompt_mapping = {
|
||||||
|
"工程信息表": ["普通属性", "单位换算"],
|
||||||
|
"其他费用表": ["金额查询", "单位换算"],
|
||||||
|
"取费表": ["金额查询"],
|
||||||
|
"总算表": ["金额查询", "总体金额查询"],
|
||||||
|
"工程量表": ["普通属性", "重名项目划分"]
|
||||||
|
}
|
||||||
|
|
||||||
|
# 根据表格名称选择特定的表格
|
||||||
|
def select_document(documents, table_name):
    """Return the loaded document that ``table_names`` maps *table_name* to.

    Args:
        documents: Indexable sequence of loaded documents.
        table_name: A key of the module-level ``table_names`` mapping.

    Returns:
        A one-element list containing the selected document.

    Raises:
        ValueError: If *table_name* is not a key of ``table_names``.
        IndexError: If the mapped position lies beyond the end of *documents*.
    """
    if table_name not in table_names:
        raise ValueError(f"未找到名为 '{table_name}' 的表格")
    index = table_names[table_name]
    # Fail with a descriptive message when fewer documents were loaded than the
    # hard-coded positions assume, instead of a bare "list index out of range".
    if index >= len(documents):
        raise IndexError(
            f"表格 '{table_name}' 对应的索引 {index} 超出已加载文档数量 {len(documents)}"
        )
    return [documents[index]]  # a list holding only the selected table
|
||||||
|
|
||||||
|
# 选择提示词
|
||||||
|
def select_prompt(prompt_category):
    """Look up the prompt object registered for a prompt category.

    Args:
        prompt_category: A key of the module-level ``prompt_mapping``.

    Returns:
        The attribute of the ``prompts`` module named by the mapping.

    Raises:
        ValueError: If the category is unknown, or if the ``prompts`` module
            has no attribute with the mapped name.
    """
    prompt_name = prompt_mapping.get(prompt_category)
    if not prompt_name:
        raise ValueError(f"未找到名为 '{prompt_category}' 的提示词")
    try:
        return getattr(prompts, prompt_name)
    except AttributeError as exc:
        # Chain explicitly so the traceback shows the missing attribute as the cause.
        raise ValueError(f"未找到提示词 '{prompt_name}' 对应的函数") from exc
|
||||||
|
|
||||||
|
# 生成问题和答案
|
||||||
|
def generate_questions_from_document(document, quest_prompt, num_questions):
    """Generate "question, answer" strings from documents and split them into pairs.

    Args:
        document: Documents handed to ``DatasetGenerator.from_documents``.
        quest_prompt: Prompt text used as ``question_gen_query``.
        num_questions: Passed both as ``num_questions_per_chunk`` and to
            ``generate_questions_from_nodes``.

    Returns:
        A list of ``{"question": ..., "answer": ...}`` dicts, one per generated
        item that contains a separator. Items without one are reported via
        ``print`` and skipped.
    """
    question_generator = DatasetGenerator.from_documents(
        documents=document,
        question_gen_query=quest_prompt,
        num_questions_per_chunk=num_questions,
    )

    eval_questions = question_generator.generate_questions_from_nodes(num_questions)
    print(eval_questions)

    qa_pairs = []
    for qa in eval_questions:
        # Prefer the ASCII comma, exactly as before. Only when it is absent fall
        # back to the full-width comma, so items that were previously rejected
        # for using Chinese punctuation are no longer dropped.
        separator = next((sep for sep in (",", ",") if sep in qa), None)
        if separator is None:
            print(f"无法处理的问题和答案: {qa}")
            continue
        question, answer = qa.split(separator, 1)
        qa_pairs.append({
            "question": question.strip(),
            "answer": answer.strip(),
        })

    return qa_pairs
|
||||||
|
|
||||||
|
# 主函数,控制生成多个表格的问题和使用多个提示词,并将结果合并到一个文件中
|
||||||
|
def main(documents, table_names_input, prompt_categories_input, num_questions_per_prompt):
    """Generate QA pairs for the chosen tables and prompt categories and save them to one file.

    Args:
        documents: Loaded documents, passed on to ``select_document``.
        table_names_input: ``"all"`` (every key of ``table_prompt_mapping``) or a
            comma-separated list of table names, optionally wrapped in ``[]``.
        prompt_categories_input: ``"all"`` (every category configured for the
            table) or a comma-separated list of categories, optionally wrapped
            in ``[]``.
        num_questions_per_prompt: Number substituted into the prompt and passed
            to the question generator.

    The results are written to ``combined_test.json`` in the current working
    directory, keyed by ``"test:<table>_<category>"``.
    """
    if table_names_input == "all":
        selected_tables = list(table_prompt_mapping.keys())
    else:
        selected_tables = table_names_input.strip('[]').split(',')

    all_results = {}

    for table_name in selected_tables:
        table_name = table_name.strip()  # drop surrounding whitespace
        document = select_document(documents, table_name)

        # Some tables are valid for select_document but have no entry in
        # table_prompt_mapping; treat them as having no categories instead of
        # failing with KeyError.
        allowed_prompts = table_prompt_mapping.get(table_name, [])
        if not allowed_prompts:
            print(f"跳过表格 '{table_name}',因为没有为该表配置任何提示词类别")
            continue

        if prompt_categories_input == "all":
            selected_prompts = allowed_prompts
        else:
            selected_prompts = [
                p.strip() for p in prompt_categories_input.strip('[]').split(',')
            ]

        for prompt_category in selected_prompts:
            if prompt_category not in allowed_prompts:
                print(f"跳过表格 '{table_name}' 的提示词 '{prompt_category}',因为该表中不包含该类别的信息")
                continue

            quest_prompt = select_prompt(prompt_category).format(num_questions_per_chunk=num_questions_per_prompt)
            qa_pairs = generate_questions_from_document(document, quest_prompt, num_questions_per_prompt)

            label = f"test:{table_name}_{prompt_category}"
            all_results[label] = qa_pairs

    # Fixed output file name.
    output_file = "combined_test.json"

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(all_results, f, ensure_ascii=False, indent=4)

    print(f"All questions and answers have been saved to '{output_file}'")
|
||||||
|
|
||||||
|
# 获取命令行参数
|
||||||
|
if __name__ == "__main__":
    # Expect exactly three positional arguments after the script name.
    if len(sys.argv) != 4:
        # Report misuse on stderr and exit non-zero so shells and CI can detect it.
        print(
            "Usage: python script.py <table_names_input> <prompt_categories_input> <num_questions_per_prompt>",
            file=sys.stderr,
        )
        sys.exit(1)

    table_names_input = sys.argv[1]
    prompt_categories_input = sys.argv[2]
    num_questions_per_prompt = int(sys.argv[3])

    main(documents, table_names_input, prompt_categories_input, num_questions_per_prompt)
|
||||||
@@ -1,9 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
import phoenix as px
|
import phoenix as px
|
||||||
|
|
||||||
|
|
||||||
os.environ['PHOENIX_HOST'] = "0.0.0.0"
|
|
||||||
|
|
||||||
session = px.launch_app(use_temp_dir=False)
|
session = px.launch_app(use_temp_dir=False)
|
||||||
|
|
||||||
import msvcrt
|
import msvcrt
|
||||||
|
|||||||
Reference in New Issue
Block a user