951 lines
35 KiB
Python
951 lines
35 KiB
Python
import streamlit as st
|
||
import json
|
||
import os
|
||
import uuid
|
||
import time
|
||
from typing import Dict, List, Optional, Any
|
||
import logging
|
||
import sys
|
||
from datetime import datetime
|
||
from textcomplete import textcomplete, StrategyProps
|
||
|
||
# 添加项目根目录到路径
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
# 确保必要的目录存在
|
||
CACHE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache")
|
||
DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
|
||
os.makedirs(CACHE_DIR, exist_ok=True)
|
||
os.makedirs(DATA_DIR, exist_ok=True)
|
||
|
||
# 配置日志
|
||
current_file = os.path.splitext(os.path.basename(__file__))[0]
|
||
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
|
||
log_filename = f"{current_file}_{now_str}.log"
|
||
|
||
# 确保日志目录存在并设置在logs目录下
|
||
log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs")
|
||
os.makedirs(log_dir, exist_ok=True)
|
||
|
||
logging.basicConfig(
|
||
level=logging.DEBUG,
|
||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||
handlers=[
|
||
logging.FileHandler(os.path.join(log_dir, log_filename), encoding="utf-8"),
|
||
logging.StreamHandler()
|
||
],
|
||
)
|
||
|
||
logger = logging.getLogger(current_file)
|
||
|
||
def setup_logger(logger_name):
|
||
"""
|
||
设置指定名称的logger,将其级别设置为WARNING并禁用传播
|
||
:param logger_name: logger的名称
|
||
"""
|
||
logger = logging.getLogger(logger_name)
|
||
logger.setLevel(logging.WARNING) # 设置httpcore及其子模块的级别
|
||
logger.propagate = False # 可选:禁用传播(防止被根logger处理)
|
||
return logger
|
||
|
||
|
||
logger_names = ["httpx", "openai", "langsmith.client", "neo4j", "urllib3", "httpcore"]
|
||
for name in logger_names:
|
||
setup_logger(name)
|
||
|
||
# 延迟导入依赖,只有在需要时才加载
|
||
@st.cache_resource
|
||
def load_dependencies():
|
||
"""
|
||
延迟加载依赖项,并缓存资源以提高性能
|
||
"""
|
||
from src.config import Config
|
||
from src.document_loader import load_file
|
||
from src.multi_llm_client import MultiAPIKeyChatOpenAI
|
||
from src.user_interaction import UserInteraction
|
||
from src.dialog_manager import DialogManager
|
||
from src.code_executor import CodeExecutor
|
||
from src.neo4j_raw_retriever import Neo4jRawRetriever
|
||
from src.prompt_manager import PromptManager
|
||
from src.embedding_client import EmbeddingClient
|
||
from src.project import ProjectBuilder, ProjectToolkit
|
||
from src.project_implementation import ProjectToolkitNeo4j
|
||
|
||
return {
|
||
"Config": Config,
|
||
"load_file": load_file,
|
||
"MultiAPIKeyChatOpenAI": MultiAPIKeyChatOpenAI,
|
||
"UserInteraction": UserInteraction,
|
||
"DialogManager": DialogManager,
|
||
"CodeExecutor": CodeExecutor,
|
||
"Neo4jRawRetriever": Neo4jRawRetriever,
|
||
"PromptManager": PromptManager,
|
||
"EmbeddingClient": EmbeddingClient,
|
||
"ProjectBuilder": ProjectBuilder,
|
||
"ProjectToolkitNeo4j": ProjectToolkitNeo4j
|
||
}
|
||
|
||
# 后台数据结构
|
||
class Indicator:
|
||
def __init__(self, name: str, query: str = "", code: str = "", id: Optional[str] = None):
|
||
self.id = id if id else str(uuid.uuid4())
|
||
self.name = name
|
||
self.query = query
|
||
self.code = code
|
||
self.created_at = time.time()
|
||
self.updated_at = time.time()
|
||
|
||
def to_dict(self) -> Dict[str, Any]:
|
||
return {
|
||
"id": self.id,
|
||
"name": self.name,
|
||
"query": self.query,
|
||
"code": self.code,
|
||
"created_at": self.created_at,
|
||
"updated_at": self.updated_at
|
||
}
|
||
|
||
@classmethod
|
||
def from_dict(cls, data: Dict[str, Any]) -> 'Indicator':
|
||
indicator = cls(
|
||
name=data["name"],
|
||
query=data["query"],
|
||
code=data.get("code", ""), # 使用 get 方法处理可能不存在的 code 字段
|
||
id=data["id"]
|
||
)
|
||
indicator.created_at = data.get("created_at", time.time())
|
||
indicator.updated_at = data.get("updated_at", time.time())
|
||
return indicator
|
||
|
||
|
||
class IndicatorManager:
|
||
def __init__(self, save_path: Optional[str] = None):
|
||
if save_path is None:
|
||
save_path = os.path.join(CACHE_DIR, "indicator_library.json")
|
||
self.indicators: Dict[str, Indicator] = {}
|
||
self.save_path = save_path
|
||
self.load_indicators()
|
||
|
||
def load_indicators(self) -> None:
|
||
"""从文件加载指标库"""
|
||
try:
|
||
if os.path.exists(self.save_path):
|
||
with open(self.save_path, "r", encoding="utf-8") as f:
|
||
data = json.load(f)
|
||
for indicator_data in data:
|
||
indicator = Indicator.from_dict(indicator_data)
|
||
self.indicators[indicator.id] = indicator
|
||
except Exception as e:
|
||
st.error(f"加载指标库失败: {str(e)}")
|
||
|
||
def save_indicators(self) -> None:
|
||
"""保存指标库到文件"""
|
||
try:
|
||
with open(self.save_path, "w", encoding="utf-8") as f:
|
||
json.dump([ind.to_dict() for ind in self.indicators.values()], f, ensure_ascii=False, indent=2)
|
||
except Exception as e:
|
||
st.error(f"保存指标库失败: {str(e)}")
|
||
|
||
def add_indicator(self, name: str, query: str = "") -> Indicator:
|
||
"""添加新指标"""
|
||
if not name.strip():
|
||
raise ValueError("指标名称不能为空")
|
||
|
||
indicator = Indicator(name=name, query=query)
|
||
self.indicators[indicator.id] = indicator
|
||
self.save_indicators()
|
||
return indicator
|
||
|
||
def update_indicator(self, id: str, name: Optional[str] = None, query: Optional[str] = None, code: Optional[str] = None) -> Optional[Indicator]:
|
||
"""更新指标"""
|
||
if id not in self.indicators:
|
||
return None
|
||
|
||
if name is not None and not name.strip():
|
||
raise ValueError("指标名称不能为空")
|
||
|
||
indicator = self.indicators[id]
|
||
if name is not None:
|
||
indicator.name = name
|
||
if query is not None:
|
||
indicator.query = query
|
||
if code is not None:
|
||
indicator.code = code
|
||
|
||
indicator.updated_at = time.time()
|
||
self.save_indicators()
|
||
return indicator
|
||
|
||
def delete_indicator(self, id: str) -> bool:
|
||
"""删除指标"""
|
||
if id in self.indicators:
|
||
del self.indicators[id]
|
||
self.save_indicators()
|
||
return True
|
||
return False
|
||
|
||
def clear_all_indicators(self) -> None:
|
||
"""清除所有指标"""
|
||
self.indicators.clear()
|
||
self.save_indicators()
|
||
|
||
def get_indicator(self, id: str) -> Optional[Indicator]:
|
||
"""获取指标"""
|
||
return self.indicators.get(id)
|
||
|
||
def get_all_indicators(self) -> List[Indicator]:
|
||
"""获取所有指标"""
|
||
return list(self.indicators.values())
|
||
|
||
|
||
# 保存会话状态
|
||
@st.cache_data(ttl=3600)
|
||
def save_session_state():
|
||
"""保存会话状态到文件"""
|
||
try:
|
||
state = {
|
||
"current_indicator_id": st.session_state.get("current_indicator_id", None)
|
||
}
|
||
state_file = os.path.join(CACHE_DIR, "indicator_creator_state.json")
|
||
with open(state_file, "w", encoding="utf-8") as f:
|
||
json.dump(state, f, ensure_ascii=False, indent=2)
|
||
except Exception as e:
|
||
st.error(f"保存会话状态失败: {str(e)}")
|
||
|
||
|
||
# 加载会话状态
|
||
@st.cache_data(ttl=3600)
|
||
def load_session_state():
|
||
"""从文件加载会话状态"""
|
||
try:
|
||
state_file = os.path.join(CACHE_DIR, "indicator_creator_state.json")
|
||
if os.path.exists(state_file):
|
||
with open(state_file, "r", encoding="utf-8") as f:
|
||
state = json.load(f)
|
||
for key, value in state.items():
|
||
if key not in st.session_state:
|
||
st.session_state[key] = value
|
||
except Exception as e:
|
||
st.error(f"加载会话状态失败: {str(e)}")
|
||
|
||
|
||
# 测试指标查询
|
||
@st.cache_data(ttl=60) # 缓存结果1分钟
|
||
def generate_indicator_code(query: str) -> str:
|
||
"""生成指标代码"""
|
||
if not query.strip():
|
||
return "查询语句为空,无法生成代码"
|
||
|
||
result = ""
|
||
|
||
try:
|
||
# 步骤1:通过调用user_interaction.understand(query)实现
|
||
result = "正在理解查询语句..."
|
||
if "user_interaction" not in st.session_state:
|
||
st.session_state.user_interaction = st.session_state.dialog_manager.user_interaction
|
||
|
||
understanding_result = st.session_state.user_interaction.understand(query)
|
||
if not understanding_result:
|
||
return "❌ 无法理解查询语句"
|
||
|
||
# 步骤2:验证步骤一返回值
|
||
result = "正在验证理解结果..."
|
||
selected_knowledge = []
|
||
entity_info = []
|
||
for item in understanding_result:
|
||
entity_info.append(f"识别到实体: {item.get('name')}, 约束条件: {item.get('constraints')}")
|
||
selected_knowledge.append(item)
|
||
|
||
if not selected_knowledge:
|
||
return "❌ 没有找到相关知识"
|
||
|
||
# 显示识别到的实体信息
|
||
result = "\n".join(entity_info)
|
||
|
||
# 步骤3:dialog_manager.generated_code(query, selected_knowledge) 实现
|
||
result = "正在生成指标代码..."
|
||
if "dialog_manager" not in st.session_state:
|
||
return "❌ 对话管理器未初始化"
|
||
|
||
generated_result = st.session_state.dialog_manager.generated_code(query, selected_knowledge)
|
||
|
||
# 处理返回结构
|
||
if not generated_result:
|
||
return "❌ 代码生成失败"
|
||
|
||
# 检查返回的是否是字典结构
|
||
if isinstance(generated_result, dict):
|
||
if not generated_result.get('status', False):
|
||
return f"❌ 代码生成失败: {generated_result.get('message', '未知错误')}"
|
||
|
||
# 提取代码内容
|
||
code = generated_result.get('data', '')
|
||
if not code:
|
||
return "❌ 生成的代码为空"
|
||
|
||
# 将生成的代码保存到session_state中
|
||
st.session_state.generated_code = code
|
||
else:
|
||
# 如果直接返回的是代码字符串
|
||
st.session_state.generated_code = generated_result
|
||
|
||
if "current_indicator_id" in st.session_state:
|
||
st.session_state.current_query_id = st.session_state.current_indicator_id
|
||
|
||
return "✅ 指标代码生成成功,可以在下方查看并执行"
|
||
|
||
except Exception as e:
|
||
import traceback
|
||
error_details = traceback.format_exc()
|
||
return f"❌ 指标代码生成失败: {str(e)}"
|
||
|
||
return result
|
||
|
||
@st.cache_resource
|
||
def init_app_data():
|
||
"""
|
||
初始化应用数据,使用缓存资源以避免重复加载
|
||
"""
|
||
# 加载依赖
|
||
deps = load_dependencies()
|
||
|
||
config = deps["Config"]()
|
||
|
||
business_structure = deps["load_file"](config.business_object_structure_path)
|
||
bowei_api_docs = deps["load_file"](config.bowei_api_docs_path)
|
||
|
||
llm_client = deps["MultiAPIKeyChatOpenAI"](config.openai)
|
||
user_interaction = deps["UserInteraction"](llm_client.llm, business_structure)
|
||
|
||
llm_client_coder = deps["MultiAPIKeyChatOpenAI"](config.openai_coder)
|
||
|
||
prompt_manager = deps["PromptManager"]()
|
||
|
||
neo4j_conf = config.neo4j_conf
|
||
embedding_conf = config.embedding
|
||
|
||
embedding_client = deps["EmbeddingClient"](embedding_conf)
|
||
|
||
# 创建Neo4j检索器
|
||
knowledge_retriever = deps["Neo4jRawRetriever"](neo4j_conf)
|
||
|
||
deps["ProjectBuilder"].register(deps["ProjectToolkitNeo4j"], knowledge_retriever.driver)
|
||
|
||
code_executor = deps["CodeExecutor"](prompt_manager.prompts, llm_client_coder, config.max_retries)
|
||
|
||
dialog_manager = deps["DialogManager"](
|
||
llm_client,
|
||
business_structure,
|
||
bowei_api_docs,
|
||
code_executor,
|
||
knowledge_retriever,
|
||
prompt_manager,
|
||
)
|
||
|
||
# 将必要的对象存储在session_state中
|
||
return {
|
||
"dialog_manager": dialog_manager,
|
||
"user_interaction": user_interaction,
|
||
"code_executor": code_executor,
|
||
"knowledge_retriever": knowledge_retriever
|
||
}
|
||
|
||
|
||
# 自动补全策略相关函数
|
||
def create_project_division_strategy():
|
||
"""创建项目划分自动补全策略"""
|
||
return StrategyProps(
|
||
id="projectDivision",
|
||
match="从项目划分【(.*)】",
|
||
search="""async (term, callback) => {
|
||
const divisions = [
|
||
"架空输电线路本体工程/基础工程",
|
||
"架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立",
|
||
"架空输电线路本体工程/架线工程",
|
||
"架空输电线路本体工程/附件安装工程"
|
||
];
|
||
const matches = divisions.filter(div => div.toLowerCase().includes(term.toLowerCase()));
|
||
callback(matches);
|
||
}""",
|
||
replace="(division) => `从项目划分【${division}】`",
|
||
template="(division) => `📁 ${division}`",
|
||
)
|
||
|
||
def create_search_type_strategy():
|
||
"""创建搜索类型自动补全策略"""
|
||
return StrategyProps(
|
||
id="searchType",
|
||
match="中查找(名称|编码)(包含|等于)【(.*)】",
|
||
search="""async (term, callback) => {
|
||
const types = ["名称包含", "编码包含", "名称等于", "编码等于"];
|
||
const matches = types.filter(type => type.toLowerCase().includes(term.toLowerCase()));
|
||
callback(matches);
|
||
}""",
|
||
replace="(type) => `中查找${type}【`",
|
||
template="(type) => `🔍 ${type}`",
|
||
)
|
||
|
||
def create_item_type_strategy():
|
||
"""创建项目类型自动补全策略"""
|
||
return StrategyProps(
|
||
id="itemType",
|
||
match="的所有【(.*)】",
|
||
search="""async (term, callback) => {
|
||
const types = ["定额", "主材", "取费名称"];
|
||
const matches = types.filter(type => type.toLowerCase().includes(term.toLowerCase()));
|
||
callback(matches);
|
||
}""",
|
||
replace="(type) => `的所有【${type}】`",
|
||
template="(type) => `📋 ${type}`",
|
||
)
|
||
|
||
def create_attribute_strategy():
|
||
"""创建属性自动补全策略"""
|
||
return StrategyProps(
|
||
id="attribute",
|
||
match="的【(.*)】",
|
||
search="""async (term, callback) => {
|
||
const attributes = ["数量", "单价", "合计费", "属性"];
|
||
const matches = attributes.filter(attr => attr.toLowerCase().includes(term.toLowerCase()));
|
||
callback(matches);
|
||
}""",
|
||
replace="(attr) => `的【${attr}】`",
|
||
template="(attr) => `📊 ${attr}`",
|
||
)
|
||
|
||
def create_fee_type_strategy():
|
||
"""创建费用类型自动补全策略"""
|
||
return StrategyProps(
|
||
id="feeType",
|
||
match="从【(.*)】中",
|
||
search="""async (term, callback) => {
|
||
const types = ["工程费用", "其他费用"];
|
||
const matches = types.filter(type => type.toLowerCase().includes(term.toLowerCase()));
|
||
callback(matches);
|
||
}""",
|
||
replace="(type) => `从【${type}】中`",
|
||
template="(type) => `💰 ${type}`",
|
||
)
|
||
|
||
def create_project_name_strategy():
|
||
"""创建项目名称自动补全策略"""
|
||
return StrategyProps(
|
||
id="projectName",
|
||
match="获取【(.*)】的",
|
||
search="""async (term, callback) => {
|
||
const names = [
|
||
"架空输电线路本体工程",
|
||
"建设场地征用及清理费"
|
||
];
|
||
const matches = names.filter(name => name.toLowerCase().includes(term.toLowerCase()));
|
||
callback(matches);
|
||
}""",
|
||
replace="(name) => `获取【${name}】的`",
|
||
template="(name) => `🏗️ ${name}`",
|
||
)
|
||
|
||
def create_code_value_strategy():
|
||
"""创建编码值自动补全策略"""
|
||
return StrategyProps(
|
||
id="codeValue",
|
||
match="编码包含【(.*)】",
|
||
search="""async (term, callback) => {
|
||
const codes = ["YX2-1~7", "YX5-9"];
|
||
const matches = codes.filter(code => code.toLowerCase().includes(term.toLowerCase()));
|
||
callback(matches);
|
||
}""",
|
||
replace="(code) => `编码包含【${code}】`",
|
||
template="(code) => `🔢 ${code}`",
|
||
)
|
||
|
||
def create_name_value_strategy():
|
||
"""创建名称值自动补全策略"""
|
||
return StrategyProps(
|
||
id="nameValue",
|
||
match="名称包含【(.*)】",
|
||
search="""async (term, callback) => {
|
||
const names = ["角钢", "钢管杆"];
|
||
const matches = names.filter(name => name.toLowerCase().includes(term.toLowerCase()));
|
||
callback(matches);
|
||
}""",
|
||
replace="(name) => `名称包含【${name}】`",
|
||
template="(name) => `📝 ${name}`",
|
||
)
|
||
|
||
def create_name_equal_strategy():
|
||
"""创建名称等于自动补全策略"""
|
||
return StrategyProps(
|
||
id="nameEqual",
|
||
match="取费名称等于【(.*)】",
|
||
search="""async (term, callback) => {
|
||
const names = ["合计"];
|
||
const matches = names.filter(name => name.toLowerCase().includes(term.toLowerCase()));
|
||
callback(matches);
|
||
}""",
|
||
replace="(name) => `取费名称等于【${name}】`",
|
||
template="(name) => `📌 ${name}`",
|
||
)
|
||
|
||
def create_query_template_strategy():
|
||
"""创建查询模板自动补全策略"""
|
||
return StrategyProps(
|
||
id="queryTemplate",
|
||
match="^$",
|
||
search="""async (term, callback) => {
|
||
const templates = [
|
||
"从项目划分【架空输电线路本体工程/基础工程】中查找编码包含【YX2-1~7】的所有【定额】的【数量】之和",
|
||
"从项目划分【架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立】中查找名称包含【角钢】的所有【主材】的【数量】之和",
|
||
"从项目划分【架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立】中查找名称包含【钢管杆】的所有【主材】的【单价】之和",
|
||
"从项目划分【架空输电线路本体工程/架线工程】中查找编码中包含【YX5-9】的所有【定额】的【数量】之和",
|
||
"从【工程费用】中获取【架空输电线路本体工程】的【合计费】属性",
|
||
"从【其他费用】中获取【建设场地征用及清理费】的属性",
|
||
"从项目划分【架空输电线路本体工程/附件安装工程】中获取取费名称等于【合计】的费用"
|
||
];
|
||
const matches = templates.filter(template => template.toLowerCase().includes(term.toLowerCase()));
|
||
callback(matches);
|
||
}""",
|
||
replace="(template) => template",
|
||
template="(template) => `📋 ${template}`",
|
||
)
|
||
|
||
def apply_autocomplete_to_query_input(area_label="输入查询语句", max_count=5):
|
||
"""应用自动补全到查询输入框"""
|
||
# 初始化textcomplete组件
|
||
textcomplete(
|
||
area_label=area_label,
|
||
strategies=[
|
||
create_query_template_strategy(),
|
||
create_project_division_strategy(),
|
||
create_search_type_strategy(),
|
||
create_item_type_strategy(),
|
||
create_attribute_strategy(),
|
||
create_fee_type_strategy(),
|
||
create_project_name_strategy(),
|
||
create_code_value_strategy(),
|
||
create_name_value_strategy(),
|
||
create_name_equal_strategy()
|
||
],
|
||
max_count=max_count
|
||
)
|
||
|
||
# Streamlit 应用界面
|
||
def main():
|
||
st.set_page_config(page_title="造价工程指标管理器", layout="wide")
|
||
|
||
# 使用进度条显示加载过程
|
||
with st.spinner("正在初始化应用..."):
|
||
# 初始化应用数据
|
||
app_data = init_app_data()
|
||
|
||
# 将必要的对象存储在session_state中
|
||
for key, value in app_data.items():
|
||
st.session_state[key] = value
|
||
|
||
# 初始化会话状态
|
||
initialize_session_state()
|
||
|
||
# 侧边栏
|
||
render_sidebar()
|
||
|
||
# 主区域
|
||
render_main_area()
|
||
|
||
|
||
def initialize_session_state():
|
||
"""初始化会话状态"""
|
||
if "indicator_manager" not in st.session_state:
|
||
st.session_state.indicator_manager = IndicatorManager()
|
||
|
||
if "current_indicator_id" not in st.session_state:
|
||
st.session_state.current_indicator_id = None
|
||
|
||
# 加载保存的会话状态
|
||
load_session_state()
|
||
|
||
|
||
def render_sidebar():
|
||
"""渲染侧边栏内容"""
|
||
with st.sidebar:
|
||
# 第一个expander:指标管理
|
||
render_indicator_management_expander()
|
||
|
||
# 第二个expander:指标库浏览
|
||
render_indicator_library_expander()
|
||
|
||
|
||
def render_indicator_management_expander():
|
||
"""渲染指标管理扩展面板"""
|
||
with st.expander("新建指标工具", expanded=True):
|
||
# 显示清除指标链接
|
||
col1, col2, col3 = st.columns([2, 1, 1])
|
||
with col1:
|
||
st.subheader("指标列表")
|
||
with col2:
|
||
render_new_indicator_button()
|
||
with col3:
|
||
render_clear_indicators_button()
|
||
|
||
# 显示指标列表
|
||
render_indicator_list()
|
||
|
||
|
||
def render_new_indicator_button():
|
||
"""渲染新建指标按钮"""
|
||
if st.button("➕新建"):
|
||
try:
|
||
# 生成新指标名称
|
||
indicator_count = len(st.session_state.indicator_manager.get_all_indicators())
|
||
new_name = f"指标 {indicator_count + 1}"
|
||
|
||
# 创建新指标
|
||
new_indicator = st.session_state.indicator_manager.add_indicator(name=new_name)
|
||
st.session_state.current_indicator_id = new_indicator.id
|
||
# 清除已生成的代码
|
||
if "generated_code" in st.session_state:
|
||
del st.session_state.generated_code
|
||
# 清除第二个expander的状态
|
||
st.session_state.view_indicator_detail = False
|
||
save_session_state()
|
||
st.rerun()
|
||
except ValueError as e:
|
||
st.error(str(e))
|
||
|
||
|
||
def render_clear_indicators_button():
|
||
"""渲染清空指标按钮"""
|
||
if st.button("🗑️清空", help="删除所有指标", type="secondary"):
|
||
st.session_state.indicator_manager.clear_all_indicators()
|
||
st.session_state.current_indicator_id = None
|
||
# 清除已生成的代码
|
||
if "generated_code" in st.session_state:
|
||
del st.session_state.generated_code
|
||
# 清除第二个expander的状态
|
||
st.session_state.view_indicator_detail = False
|
||
save_session_state()
|
||
st.rerun()
|
||
|
||
|
||
def render_indicator_list():
|
||
"""渲染指标列表"""
|
||
indicators = st.session_state.indicator_manager.get_all_indicators()
|
||
if not indicators:
|
||
st.info("暂无指标,请点击\"新建指标\"按钮创建")
|
||
else:
|
||
for indicator in sorted(indicators, key=lambda x: x.created_at):
|
||
if st.button(
|
||
f"📄 {indicator.name}",
|
||
key=f"btn_{indicator.id}",
|
||
use_container_width=True,
|
||
type="primary" if st.session_state.current_indicator_id == indicator.id else "secondary"
|
||
):
|
||
st.session_state.current_indicator_id = indicator.id
|
||
# 更新显示的代码
|
||
if indicator.code:
|
||
st.session_state.generated_code = indicator.code
|
||
elif "generated_code" in st.session_state:
|
||
# 如果当前指标没有代码,清除session中的代码
|
||
del st.session_state.generated_code
|
||
# 清除第二个expander的状态
|
||
st.session_state.view_indicator_detail = False
|
||
save_session_state()
|
||
st.rerun()
|
||
|
||
|
||
def render_indicator_library_expander():
|
||
"""渲染指标库浏览扩展面板"""
|
||
with st.expander("查看指标库", expanded=False):
|
||
# 设置默认文件路径
|
||
file_path = os.path.join(DATA_DIR, "code.jsonl")
|
||
|
||
# 检查文件是否存在
|
||
if not os.path.exists(file_path):
|
||
st.warning(f"文件不存在: {file_path}")
|
||
else:
|
||
# 加载JSONL文件
|
||
records = load_jsonl(file_path)
|
||
|
||
if not records:
|
||
st.warning("没有找到有效的指标")
|
||
else:
|
||
# 初始化会话状态
|
||
if "selected_index" not in st.session_state:
|
||
st.session_state.selected_index = 0
|
||
|
||
# 显示指标列表
|
||
render_library_indicator_list(records)
|
||
|
||
|
||
@st.cache_data
|
||
def load_jsonl(file_path):
|
||
"""加载JSONL文件并返回JSON指标列表"""
|
||
records = []
|
||
try:
|
||
with open(file_path, "r", encoding="utf-8") as file:
|
||
for line in file:
|
||
if line.strip():
|
||
records.append(json.loads(line))
|
||
return records
|
||
except Exception as e:
|
||
st.error(f"加载文件失败: {str(e)}")
|
||
return []
|
||
|
||
|
||
def render_library_indicator_list(records):
|
||
"""渲染指标库列表"""
|
||
st.subheader("指标列表")
|
||
for i, record in enumerate(records):
|
||
if st.button(f"📚 {record.get('name', f'指标 {i+1}')}", key=f"lib_btn_{i}", use_container_width=True):
|
||
st.session_state.selected_index = i
|
||
# 清除第一个expander的选中状态
|
||
st.session_state.current_indicator_id = None
|
||
# 设置查看指标详情的逻辑
|
||
st.session_state.view_indicator_detail = True
|
||
st.session_state.current_library_indicator = record
|
||
save_session_state()
|
||
st.rerun()
|
||
|
||
|
||
def render_main_area():
|
||
"""渲染主区域内容"""
|
||
if "view_indicator_detail" in st.session_state and st.session_state.view_indicator_detail and "current_library_indicator" in st.session_state:
|
||
render_library_indicator_detail()
|
||
elif st.session_state.current_indicator_id is not None:
|
||
render_current_indicator_detail()
|
||
else:
|
||
st.info("请在左侧创建或选择一个指标")
|
||
|
||
|
||
def render_library_indicator_detail():
|
||
"""渲染指标库详情"""
|
||
library_indicator = st.session_state.current_library_indicator
|
||
|
||
# 指标名称和操作按钮
|
||
col1, col2, col3 = st.columns([1, 4, 2])
|
||
with col1:
|
||
st.markdown('<div><b style="font-size: 24px;">指标名称:</b></div>', unsafe_allow_html=True)
|
||
with col2:
|
||
st.markdown(f'<div style="font-size: 20px; padding-top: 4px;">{library_indicator.get("name", "未命名指标")}</div>', unsafe_allow_html=True)
|
||
|
||
|
||
# 查询语句部分
|
||
title_col, run_btn_col, result_col = st.columns([1, 1, 4])
|
||
with title_col:
|
||
st.subheader("查询语句")
|
||
with result_col:
|
||
title_query_result = st.subheader("")
|
||
|
||
# 显示查询语句
|
||
st.text_area(
|
||
"查询语句",
|
||
value=library_indicator.get('query', '无查询语句'),
|
||
height=70,
|
||
key="lib_query_display",
|
||
disabled=True,
|
||
label_visibility="collapsed"
|
||
)
|
||
|
||
# 指标代码部分
|
||
render_library_indicator_code(library_indicator)
|
||
|
||
|
||
def render_library_indicator_code(library_indicator):
|
||
"""渲染指标库代码部分"""
|
||
title_col, run_btn_col, result_col = st.columns([1, 1, 4])
|
||
with title_col:
|
||
st.subheader("指标代码")
|
||
with result_col:
|
||
title_code_result = st.subheader("")
|
||
with run_btn_col:
|
||
if st.button("▶️ 执行代码", key="execute_lib_indicator", use_container_width=True):
|
||
execute_indicator_code(library_indicator.get('code', ''), title_code_result)
|
||
|
||
# 显示代码
|
||
st.code(library_indicator.get('code', '无代码'), language='python')
|
||
|
||
|
||
def render_current_indicator_detail():
|
||
"""渲染当前指标详情"""
|
||
current_indicator = st.session_state.indicator_manager.get_indicator(st.session_state.current_indicator_id)
|
||
|
||
if current_indicator:
|
||
# 指标名称和保存按钮
|
||
render_indicator_header(current_indicator)
|
||
|
||
# 查询语句
|
||
render_query_input_section(current_indicator)
|
||
|
||
# 代码执行部分
|
||
render_code_execution_section(current_indicator)
|
||
else:
|
||
st.warning("找不到当前指标,可能已被删除")
|
||
|
||
|
||
def render_indicator_header(current_indicator):
|
||
"""渲染指标头部(名称和保存按钮)"""
|
||
col1, col2, col3 = st.columns([1, 4, 2])
|
||
|
||
with col1:
|
||
st.markdown('<div><b style="font-size: 24px;">指标名称:</b></div>', unsafe_allow_html=True)
|
||
|
||
with col2:
|
||
new_name = st.text_input("指标名称", value=current_indicator.name, key="indicator_name", label_visibility="collapsed")
|
||
if new_name != current_indicator.name:
|
||
if new_name is None or not new_name.strip():
|
||
st.error("指标名称不能为空")
|
||
else:
|
||
try:
|
||
st.session_state.indicator_manager.update_indicator(
|
||
current_indicator.id, name=new_name
|
||
)
|
||
except ValueError as e:
|
||
st.error(str(e))
|
||
|
||
with col3:
|
||
render_save_indicator_button(current_indicator, new_name)
|
||
|
||
|
||
def render_save_indicator_button(current_indicator, new_name):
|
||
"""渲染保存指标按钮"""
|
||
if st.button("💾 保存指标", use_container_width=True):
|
||
if new_name is None or not new_name.strip():
|
||
st.error("指标名称不能为空")
|
||
elif not current_indicator.query.strip():
|
||
st.error("请先输入查询语句")
|
||
elif not current_indicator.code:
|
||
st.error("请先生成指标代码")
|
||
else:
|
||
try:
|
||
save_indicator_to_library(current_indicator)
|
||
st.toast("指标已保存")
|
||
st.rerun() # 刷新页面以更新指标库列表
|
||
except Exception as e:
|
||
st.error(f"保存失败: {str(e)}")
|
||
|
||
|
||
def save_indicator_to_library(current_indicator):
|
||
"""保存指标到指标库"""
|
||
# 保存到指标管理器
|
||
st.session_state.indicator_manager.save_indicators()
|
||
|
||
# 同时保存到指标库
|
||
library_path = os.path.join(DATA_DIR, "code.jsonl")
|
||
new_record = {
|
||
"name": current_indicator.name,
|
||
"query": current_indicator.query,
|
||
"code": current_indicator.code,
|
||
"created_at": time.time()
|
||
}
|
||
|
||
# 追加到文件
|
||
with open(library_path, "a", encoding="utf-8") as f:
|
||
f.write(json.dumps(new_record, ensure_ascii=False) + "\n")
|
||
|
||
# 清除指标库的缓存,以便重新加载
|
||
load_jsonl.clear()
|
||
|
||
|
||
def render_query_input_section(current_indicator):
|
||
"""渲染查询输入部分"""
|
||
title_col, test_btn_col, result_col = st.columns([1, 1, 4])
|
||
with title_col:
|
||
st.subheader("查询语句")
|
||
with result_col:
|
||
title_query_result = st.subheader("")
|
||
|
||
# 查询输入框
|
||
query = st.text_area(
|
||
"输入查询语句",
|
||
value=current_indicator.query,
|
||
height=70,
|
||
key="query_input",
|
||
placeholder="在此输入查询语句...",
|
||
label_visibility="collapsed"
|
||
)
|
||
|
||
# 应用自动补全
|
||
apply_autocomplete_to_query_input()
|
||
|
||
if query != current_indicator.query:
|
||
st.session_state.indicator_manager.update_indicator(
|
||
current_indicator.id, query=query
|
||
)
|
||
|
||
# 生成代码按钮
|
||
with test_btn_col:
|
||
render_generate_code_button(query, title_query_result, current_indicator)
|
||
|
||
|
||
def render_generate_code_button(query, title_query_result, current_indicator):
|
||
"""渲染生成代码按钮"""
|
||
if st.button("🔄 生成代码", key="generate_code_btn", use_container_width=True):
|
||
with st.spinner("正在生成指标代码..."):
|
||
# 使用缓存的测试函数
|
||
result = generate_indicator_code(query)
|
||
|
||
# 根据结果显示不同的状态
|
||
if result.startswith("✅"):
|
||
title_query_result.success(result)
|
||
# 生成成功后,立即保存代码到指标对象中
|
||
if "generated_code" in st.session_state:
|
||
st.session_state.indicator_manager.update_indicator(
|
||
current_indicator.id,
|
||
code=st.session_state.generated_code
|
||
)
|
||
elif result.startswith("❌"):
|
||
title_query_result.error(result)
|
||
else:
|
||
# 显示实体识别结果或其他信息
|
||
title_query_result.info(result)
|
||
|
||
|
||
def render_code_execution_section(current_indicator):
|
||
"""渲染代码执行部分"""
|
||
title_col, run_btn_col, result_col = st.columns([1, 1, 4])
|
||
with title_col:
|
||
st.subheader("指标代码")
|
||
with result_col:
|
||
title_code_result = st.subheader("")
|
||
with run_btn_col:
|
||
if st.button("▶️ 执行代码", key="execute_code_btn", use_container_width=True):
|
||
with st.spinner("正在执行代码..."):
|
||
if "code_executor" in st.session_state and "generated_code" in st.session_state:
|
||
execute_and_update_code(current_indicator, title_code_result)
|
||
else:
|
||
title_code_result.error("代码执行器未初始化或没有生成的代码")
|
||
|
||
# 显示代码
|
||
st.code(st.session_state.get("generated_code", "# 暂无代码,请输入有效的查询语句并点击测试按钮"), language="python")
|
||
|
||
|
||
def execute_and_update_code(current_indicator, title_code_result):
|
||
"""执行并更新代码"""
|
||
result = st.session_state.code_executor.execute_code(st.session_state.generated_code)
|
||
if result and hasattr(result, 'get') and result.get('status'):
|
||
title_code_result.success("执行成功")
|
||
title_code_result.json(result.get('data', {}))
|
||
# 更新指标的代码
|
||
st.session_state.indicator_manager.update_indicator(
|
||
current_indicator.id,
|
||
code=st.session_state.generated_code
|
||
)
|
||
else:
|
||
title_code_result.error(f"执行失败: {result.get('message', '未知错误') if hasattr(result, 'get') else str(result)}")
|
||
|
||
|
||
def execute_indicator_code(code, result_container):
|
||
"""执行指标代码"""
|
||
with st.spinner('正在执行代码...'):
|
||
# 确保code_executor已初始化
|
||
if "code_executor" in st.session_state:
|
||
result = st.session_state.code_executor.execute_code(code)
|
||
if result and hasattr(result, 'get') and result.get('status'):
|
||
result_container.success("执行成功")
|
||
result_container.json(result.get('data', {}))
|
||
else:
|
||
result_container.error(f"执行失败: {result.get('message', '未知错误') if hasattr(result, 'get') else str(result)}")
|
||
else:
|
||
result_container.error("代码执行器未初始化")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main() |