From fcb09c04f27648de07e0625eb4ae9f84b7176028 Mon Sep 17 00:00:00 2001 From: paituo <330435863@qq.com> Date: Mon, 7 Jul 2025 12:47:34 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 + data/code.jsonl | 46 ++++ main_streamlt.py | 14 +- main_streamlt_qa.py | 219 ++++++++++++++++++ src/project.py | 2 +- .../interaction_generated_code.py | 10 +- tools/zhibiao_to_interaction.py | 184 +++++++++++++++ 7 files changed, 464 insertions(+), 13 deletions(-) create mode 100644 data/code.jsonl create mode 100644 main_streamlt_qa.py rename tests/test_generated_code.py => tools/interaction_generated_code.py (94%) create mode 100644 tools/zhibiao_to_interaction.py diff --git a/.gitignore b/.gitignore index a2dc25a..b7fa45d 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,5 @@ src/__pycache__/* .vscode/* .cursor/* logs/* +cache/* +转化服务/* \ No newline at end of file diff --git a/data/code.jsonl b/data/code.jsonl new file mode 100644 index 0000000..302dde0 --- /dev/null +++ b/data/code.jsonl @@ -0,0 +1,46 @@ +{"name": "杆塔总基数", "query": "获取项目划分【架空输电线路本体工程/基础工程】下编码包含【YX2-1~7】的所有【定额】的【数量】之和", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n \n def expand_code_range(code_range):\n if '~' not in code_range:\n return code_range\n prefix, range_part = code_range.split('-')\n start, end = map(int, range_part.split('~'))\n return '/'.join([f\"{prefix}-{i}\" for i in range(start, end + 1)])\n \n expanded_codes = expand_code_range('YX2-1~7')\n \n result_dict = project.get_quantities_node_by_parent_and_code(\n parent_path='架空输电线路本体工程/基础工程',\n quantity_type='定额',\n code=expanded_codes\n )\n \n status = result_dict.get('status', False)\n if not status:\n return result_dict\n \n data = result_dict.get('data', [])\n if not isinstance(data, list):\n data = [data]\n \n total_quantity = 0.0\n for item in data:\n if isinstance(item, dict) and '数量' in item:\n try:\n quantity = float(item['数量'])\n total_quantity += quantity\n except (ValueError, TypeError):\n pass\n \n return {\n 'code': 200,\n 'message': 'Ok',\n 'status': True,\n 'data': {\n 'total_quantity': total_quantity,\n 'details': data\n }\n }"} +{"name": "角钢塔_塔材量", "query": "获取项目划分【架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立】及名称包含【角钢】的所有【主材】的【数量】之和", "code": "根据用户问题和上下文信息,我将生成对应的Python代码:\n\n```python\ndef project_get_calculate_function():\n project = ProjectBuilder.build()\n division_path = \"架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立\"\n result = project.get_quantities_node_by_parent_and_name(\n parent_path=division_path,\n quantity_type=\"主材\",\n partial_name=\"角钢\"\n )\n \n if not result.get('status', False):\n return result\n \n total_quantity = 0.0\n materials = result.get('data', [])\n for material in materials:\n if isinstance(material, dict) and material.get('类型') == '1':\n try:\n quantity = float(material.get('数量', 0))\n total_quantity += quantity\n except (ValueError, TypeError):\n continue\n \n return {\n 'code': 200,\n 'message': 'Ok',\n 'status': True,\n 'data': {\n 'total_quantity': total_quantity,\n 'unit': '个' if materials else '',\n 'materials_count': len(materials)\n }\n }\n```"} +{"name": "角钢塔_其中:高强钢塔材量", "query": "获取项目划分【架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立】及名称包含【角钢、高强】的所有【主材】的【数量】之和", "code": "根据用户问题和上下文信息,我将生成对应的Python代码:\n\n```python\ndef project_get_calculate_function():\n project = ProjectBuilder.build()\n \n division_path = \"架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立\"\n quantity_type = \"主材\"\n partial_name = \"角钢、高强\"\n \n result = project.get_quantities_node_by_parent_and_name(division_path, quantity_type, partial_name)\n if not result.get('status', False):\n return result\n \n materials = result.get('data', [])\n if not isinstance(materials, list):\n return {\n 'code': 201,\n 'message': '获取主材数据失败: 返回数据格式不正确',\n 'status': False,\n 'data': []\n }\n \n total_quantity = 0.0\n for material in materials:\n if not isinstance(material, dict):\n continue\n quantity_str = material.get('数量', '0')\n try:\n quantity = float(quantity_str)\n total_quantity += quantity\n except (ValueError, TypeError):\n continue\n \n return {\n 'code': 200,\n 'message': 'Ok',\n 'status': True,\n 'data': {\n 'total_quantity': total_quantity,\n 'unit': '个' if materials else '',\n 'material_count': len(materials)\n }\n }\n```"} +{"name": "角钢塔_塔材装材费_元", "query": "获取项目划分【架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立】及名称包含【角钢】的所有【主材】的【单价】之和", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n division_path = \"架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立\"\n result = project.get_division_item_by_path(division_path)\n if not result.get('status', False):\n return result\n \n division_data = result.get('data', {})\n if not division_data:\n return {'code': 201, 'message': '项目划分不存在', 'status': False, 'data': []}\n \n parent_path = division_path\n quantity_type = \"主材\"\n partial_name = \"角钢\"\n result = project.get_quantities_node_by_parent_and_name(parent_path, quantity_type, partial_name)\n if not result.get('status', False):\n return result\n \n materials = result.get('data', [])\n if not materials:\n return {'code': 201, 'message': '未找到匹配的主材', 'status': False, 'data': []}\n \n total_price = 0.0\n for material in materials:\n if isinstance(material, dict):\n price_str = material.get('单价含税', '0')\n else:\n price_str = material.单价含税 if hasattr(material, '单价含税') else '0'\n try:\n price = float(price_str)\n total_price += price\n except ValueError:\n continue\n \n return {'code': 200, 'message': 'Ok', 'status': True, 'data': total_price}"} +{"name": "角钢塔_其中:高强钢塔材费用_元", "query": "获取项目划分【架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立】及名称包含【角钢、高强】的所有【主材】的【单价】之和", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n division_path = \"架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立\"\n result = project.get_quantities_node_by_parent_and_name(division_path, \"主材\", \"角钢、高强\")\n status = result.get('status', False)\n data = result.get('data', [])\n if not status:\n return result\n \n total_price = 0.0\n for material in data:\n if isinstance(material, Material) and \"单价不含税\" in material.__dict__:\n try:\n price = float(material.单价不含税)\n total_price += price\n except (ValueError, TypeError):\n continue\n \n return {\n 'code': 200,\n 'message': 'Ok',\n 'status': True,\n 'data': total_price\n }"} +{"name": "钢管塔_塔材量", "query": "获取项目划分【架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立】及名称包含【钢管塔】的所有【主材】的【数量】之和", "code": "根据用户问题和上下文信息,我将生成对应的Python代码:\n\n```python\ndef project_get_calculate_function():\n project = ProjectBuilder.build()\n \n division_path = \"架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立\"\n division_result = project.get_division_item_by_path(division_path)\n if not division_result.get('status', False):\n return division_result\n \n quantities_result = project.get_quantities_node_by_parent_and_name(\n parent_path=division_path,\n quantity_type=\"主材\",\n partial_name=\"钢管塔\"\n )\n \n if not quantities_result.get('status', False):\n return quantities_result\n \n total_quantity = 0.0\n data = quantities_result.get('data', [])\n for item in data:\n if isinstance(item, dict) and item.get('类型') == '主材' and '钢管塔' in item.get('name', ''):\n try:\n quantity = float(item.get('数量', 0))\n total_quantity += quantity\n except (ValueError, TypeError):\n continue\n \n return {\n 'code': 200,\n 'message': 'Ok',\n 'status': True,\n 'data': {\n 'total_quantity': total_quantity,\n 'unit': '个' if data else '',\n 'details': [{'name': item.get('name'), 'quantity': item.get('数量')} for item in data if isinstance(item, dict)]\n }\n }\n```"} +{"name": "钢管塔_钢管价格_元", "query": "获取项目划分【架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立】及名称包含【钢管塔】的所有【主材】的【单价】之和", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n division_path = \"架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立\"\n result = project.get_division_item_by_path(division_path)\n if not result.get('status', False):\n return result\n \n division_data = result.get('data', {})\n if not division_data:\n return {'code': 201, 'message': '项目划分节点不存在', 'status': False, 'data': []}\n \n parent_path = division_path\n quantity_type = \"主材\"\n partial_name = \"钢管塔\"\n result = project.get_quantities_node_by_parent_and_name(parent_path, quantity_type, partial_name)\n if not result.get('status', False):\n return result\n \n materials = result.get('data', [])\n if not materials:\n return {'code': 201, 'message': '未找到匹配的主材', 'status': False, 'data': []}\n \n total_price = 0.0\n for material in materials:\n if isinstance(material, dict) and material.get('类型') == '1':\n price_str = material.get('单价含税', '0')\n try:\n price = float(price_str) if price_str else 0.0\n total_price += price\n except ValueError:\n continue\n \n return {'code': 200, 'message': 'Ok', 'status': True, 'data': total_price}"} +{"name": "钢管杆_塔材量", "query": "获取项目划分【架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立】及名称包含【钢管杆】的所有【主材】的【数量】之和", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n division_path = \"架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立\"\n result_dict = project.get_quantities_node_by_parent_and_name(division_path, \"主材\", \"钢管杆\")\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', [])\n if not status:\n return {'status': False, 'message': message, 'code': code, 'data': []}\n total_quantity = 0.0\n for item in data:\n if isinstance(item, dict) and item.get('类型') == '主材' and '钢管杆' in item.get('name', ''):\n try:\n quantity = float(item.get('数量', 0))\n total_quantity += quantity\n except (ValueError, TypeError):\n continue\n return {'status': True, 'message': 'Ok', 'code': 200, 'data': total_quantity}"} +{"name": "钢管杆_钢管价格_元", "query": "获取项目划分【架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立】及名称包含【钢管杆】的所有【主材】的【单价】之和", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n division_path = \"架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立\"\n result = project.get_quantities_node_by_parent_and_name(division_path, \"主材\", \"钢管杆\")\n if not result.get('status', False):\n return result\n \n total_price = 0.0\n materials = result.get('data', [])\n for material in materials:\n if isinstance(material, Material) and \"钢管杆\" in material.name:\n try:\n price = float(material.单价含税) if material.单价含税 else 0.0\n total_price += price\n except (ValueError, AttributeError):\n continue\n \n return {\n 'code': 200,\n 'message': 'Ok',\n 'status': True,\n 'data': total_price\n }"} +{"name": "导线及线材_分裂数", "query": "获取项目划分【架空输电线路本体工程/架线工程】下编码包含【['YX5-67', 'YX5-68', 'YX5-69', 'YX5-78', 'YX5-79', 'YX5-59', 'YX5-60', 'YX5-61', 'YX5-62', 'YX5-63', 'YX5-64', 'YX5-65', 'YX5-66', 'YX5-74', 'YX5-75', 'YX5-76', 'YX5-77', 'YX5-52', 'YX5-53', 'YX5-54', 'YX5-55', 'YX5-56', 'YX5-57', 'YX5-58', 'YX5-70', 'YX5-71', 'YX5-72', 'YX5-73', 'YX5-14', 'YX5-15', 'YX5-16', 'YX5-17', 'YX5-43', 'YX5-44', 'YX5-45', 'YX5-46', 'YX5-47', 'YX5-48', 'YX5-49', 'YX5-50', 'YX5-51', 'YX5-10', 'YX5-11', 'YX5-12', 'YX5-13', 'YX5-38', 'YX5-39', 'YX5-40', 'YX5-41', 'YX5-42', 'YX5-8', 'YX5-9']】的所有【定额】的【参数】之和", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n codes = ['YX5-67', 'YX5-68', 'YX5-69', 'YX5-78', 'YX5-79', 'YX5-59', 'YX5-60', 'YX5-61', 'YX5-62', 'YX5-63', 'YX5-64', 'YX5-65', 'YX5-66', 'YX5-74', 'YX5-75', 'YX5-76', 'YX5-77', 'YX5-52', 'YX5-53', 'YX5-54', 'YX5-55', 'YX5-56', 'YX5-57', 'YX5-58', 'YX5-70', 'YX5-71', 'YX5-72', 'YX5-73', 'YX5-14', 'YX5-15', 'YX5-16', 'YX5-17', 'YX5-43', 'YX5-44', 'YX5-45', 'YX5-46', 'YX5-47', 'YX5-48', 'YX5-49', 'YX5-50', 'YX5-51', 'YX5-10', 'YX5-11', 'YX5-12', 'YX5-13', 'YX5-38', 'YX5-39', 'YX5-40', 'YX5-41', 'YX5-42', 'YX5-8', 'YX5-9']\n total_params = 0\n path = '架空输电线路本体工程/架线工程'\n \n for code in codes:\n result_dict = project.get_quantities_node_by_parent_and_code(path, '定额', code)\n if not result_dict.get('status', False):\n return result_dict\n data = result_dict.get('data', [])\n if isinstance(data, list):\n for item in data:\n if isinstance(item, dict) and '参数' in item:\n try:\n total_params += float(item['参数'])\n except (ValueError, TypeError):\n pass\n \n return {\n 'code': 200,\n 'message': 'Ok',\n 'status': True,\n 'data': total_params\n }"} +{"name": "导线及线材_其中:节能导线量", "query": "获取项目划分【架空输电线路本体工程/架线工程/导地线架设@@架空输电线路本体工程/架线工程/导地线跨越架设】及名称包含【高导电率】的所有【主材】的【数量】之和", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n total_quantity = 0.0\n \n division_paths = [\n \"架空输电线路本体工程/架线工程/导地线架设\",\n \"架空输电线路本体工程/架线工程/导地线跨越架设\"\n ]\n \n for path in division_paths:\n quantities_result = project.get_quantities_node_by_parent_and_name(\n parent_path=path,\n quantity_type=\"主材\",\n partial_name=\"高导电率\"\n )\n \n if not quantities_result.get('status', False):\n return quantities_result\n \n materials = quantities_result.get('data', [])\n for material in materials:\n if isinstance(material, dict) and material.get('类型') == '1':\n try:\n quantity = float(material.get('数量', '0'))\n total_quantity += quantity\n except ValueError:\n pass\n \n return {\n 'code': 200,\n 'message': 'Ok',\n 'status': True,\n 'data': total_quantity\n }"} +{"name": "导线及线材_导线装材费_元", "query": "获取项目划分【架空输电线路本体工程/架线工程/导地线架设@@架空输电线路本体工程/架线工程/导地线跨越架设】及名称包含【线】的所有【主材】的【单价】之和", "code": "根据用户问题和上下文信息,我将生成对应的Python代码:\n\n```python\ndef project_get_calculate_function():\n project = ProjectBuilder.build()\n total_price = 0.0\n \n division_paths = [\n \"架空输电线路本体工程/架线工程/导地线架设\",\n \"架空输电线路本体工程/架线工程/导地线跨越架设\"\n ]\n \n for path in division_paths:\n quantities_result = project.get_quantities_node_by_parent_and_name(\n parent_path=path,\n quantity_type=\"主材\",\n partial_name=\"线\"\n )\n \n if not quantities_result.get('status', False):\n return quantities_result\n \n materials = quantities_result.get('data', [])\n for material in materials:\n if isinstance(material, Material):\n try:\n price = float(material.单价含税) if material.单价含税 else 0.0\n total_price += price\n except (ValueError, TypeError):\n continue\n \n return {\n 'code': 200,\n 'message': 'Ok',\n 'status': True,\n 'data': total_price\n }\n```"} +{"name": "导线及线材_其中:节能导线费用_元", "query": "获取项目划分【架空输电线路本体工程/架线工程/导地线架设@@架空输电线路本体工程/架线工程/导地线跨越架设】及名称包含【高导电率】的所有【主材】的【单价】之和", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n total_price = 0.0\n \n division_paths = [\n \"架空输电线路本体工程/架线工程/导地线架设\",\n \"架空输电线路本体工程/架线工程/导地线跨越架设\"\n ]\n \n for path in division_paths:\n quantities_result = project.get_quantities_node_by_parent_and_name(\n parent_path=path,\n quantity_type=\"主材\",\n partial_name=\"高导电率\"\n )\n \n if not quantities_result.get('status', False):\n return quantities_result\n \n materials = quantities_result.get('data', [])\n for material in materials:\n if isinstance(material, Material):\n try:\n price = float(material.单价含税) if material.单价含税 else 0.0\n total_price += price\n except (ValueError, TypeError):\n continue\n \n return {\n 'code': 200,\n 'message': 'Ok',\n 'status': True,\n 'data': total_price\n }"} +{"name": "导线及线材_导线类型", "query": "获取项目划分【架空输电线路本体工程/架线工程/导地线架设@@架空输电线路本体工程/架线工程/导地线跨越架设@@架空输电线路本体工程/架线工程/其他架线工程】及名称包含【['钢芯铝绞线', '铝包钢芯铝绞线', '中强度铝合金绞线', '铝合金芯铝绞线', '铝合金芯高导电率铝绞线', '钢芯高导电率铝绞线', '特高强度钢芯铝合金绞线', '扩径导线', '耐热导线', '碳纤维导线']】的所有【主材】的【参数】之和", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n division_paths = [\n \"架空输电线路本体工程/架线工程/导地线架设\",\n \"架空输电线路本体工程/架线工程/导地线跨越架设\",\n \"架空输电线路本体工程/架线工程/其他架线工程\"\n ]\n material_names = [\n '钢芯铝绞线', '铝包钢芯铝绞线', '中强度铝合金绞线', '铝合金芯铝绞线', \n '铝合金芯高导电率铝绞线', '钢芯高导电率铝绞线', '特高强度钢芯铝合金绞线', \n '扩径导线', '耐热导线', '碳纤维导线'\n ]\n \n total_params = {}\n \n for path in division_paths:\n for name in material_names:\n result = project.get_quantities_node_by_parent_and_name(\n parent_path=path,\n quantity_type=\"主材\",\n partial_name=name\n )\n if not result.get('status', False):\n continue\n \n materials = result.get('data', [])\n for material in materials:\n if not isinstance(material, Material):\n continue\n \n for attr, value in material.__dict__.items():\n if not attr.startswith('_') and isinstance(value, (int, float)):\n if attr not in total_params:\n total_params[attr] = 0\n try:\n total_params[attr] += float(value)\n except (ValueError, TypeError):\n pass\n \n return {\n 'code': 200,\n 'message': 'Ok',\n 'status': True,\n 'data': total_params\n }"} +{"name": "基础钢材量", "query": "获取项目划分【架空输电线路本体工程/基础工程/基础砌筑】及名称包含【圆钢】的所有【主材】的【数量】之和", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n division_path = \"架空输电线路本体工程/基础工程/基础砌筑\"\n result = project.get_quantities_node_by_parent_and_name(division_path, \"主材\", \"圆钢\")\n if not result.get('status', False):\n return result\n \n total_quantity = 0.0\n for material in result.get('data', []):\n if isinstance(material, dict) and material.get('类型') == '主材' and '圆钢' in material.get('name', ''):\n try:\n quantity = float(material.get('数量', 0))\n total_quantity += quantity\n except (ValueError, TypeError):\n continue\n \n return {\n 'code': 200,\n 'message': 'Ok',\n 'status': True,\n 'data': total_quantity\n }"} +{"name": "基础钢材价格", "query": "获取项目划分【架空输电线路本体工程/基础工程/基础砌筑】及名称包含【圆钢】的所有【主材】的【单价】之和", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n division_path = \"架空输电线路本体工程/基础工程/基础砌筑\"\n quantities_result = project.get_quantities_by_paths(division_path)\n if not quantities_result.get('status', False):\n return quantities_result\n \n materials_result = project.get_quantities_node_by_parent_and_name(division_path, \"主材\", \"圆钢\")\n if not materials_result.get('status', False):\n return materials_result\n \n total_price = 0.0\n for material in materials_result.get('data', []):\n if isinstance(material, dict) and material.get('类型') == '主材' and '圆钢' in material.get('name', ''):\n price_str = material.get('单价含税', '0')\n try:\n price = float(price_str) if price_str else 0.0\n total_price += price\n except ValueError:\n continue\n \n return {\n 'code': 200,\n 'message': 'Ok',\n 'status': True,\n 'data': total_price\n }"} +{"name": "本体费用合计_元", "query": "从【工程费用】中获取【架空输电线路本体工程】的【合计费】的值", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_schedule_on_Engineering_Cost_table(\"工程费用\", \"架空输电线路本体工程\", \"合计费\")\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return data\n else:\n raise Exception(f\"Failed to get fee: {message}\")"} +{"name": "本体工程人工费_本体_元", "query": "获取项目划分【架空输电线路本体工程】下取费名称等于【人工费】的费用", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_table_by_project_division(\"架空输电线路本体工程\", \"人工费\")\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return result_dict\n else:\n return result_dict"} +{"name": "本体工程人工费_调试_元", "query": "获取项目划分【架空输电线路本体工程】下取费名称等于【人工费】的费用", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_table_by_project_division(\"架空输电线路本体工程\", \"人工费\")\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return result_dict\n else:\n return result_dict"} +{"name": "本体工程机械费_本体_元", "query": "获取项目划分【架空输电线路本体工程】下取费名称等于【施工机械使用费】的费用", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_table_by_project_division(\"架空输电线路本体工程\", \"施工机械使用费\")\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return result_dict\n else:\n return result_dict"} +{"name": "本体工程机械费_调试_元", "query": "获取项目划分【架空输电线路本体工程】下取费名称等于【施工机械使用费】的费用", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_table_by_project_division(\"架空输电线路本体工程\", \"施工机械使用费\")\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return result_dict\n else:\n return result_dict"} +{"name": "基础工程费用_元", "query": "获取项目划分【架空输电线路本体工程/基础工程】下取费名称等于【合计】的费用", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_table_by_project_division(\"架空输电线路本体工程/基础工程\", \"合计\")\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return result_dict\n else:\n return result_dict"} +{"name": "杆塔工程费用_元", "query": "获取项目划分【架空输电线路本体工程/杆塔工程】下取费名称等于【合计】的费用", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_table_by_project_division(\"架空输电线路本体工程/杆塔工程\", \"合计\")\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return result_dict\n else:\n return result_dict"} +{"name": "接地工程费用_元", "query": "获取项目划分【架空输电线路本体工程/接地工程】下取费名称等于【合计】的费用", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_table_by_project_division(\"架空输电线路本体工程/接地工程\", \"合计\")\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return result_dict\n else:\n return result_dict"} +{"name": "架线工程费用_元", "query": "获取项目划分【架空输电线路本体工程/架线工程】下取费名称等于【合计】的费用", "code": "根据用户问题和上下文信息,我们需要获取项目划分\"架空输电线路本体工程/架线工程\"下取费名称为\"合计\"的费用。根据工程数据访问库,最匹配的方法是`get_fee_table_by_project_division`。\n\n```python\ndef project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_table_by_project_division(\"架空输电线路本体工程/架线工程\", \"合计\")\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return result_dict\n else:\n return result_dict\n```"} +{"name": "附件工程费用_元", "query": "获取项目划分【架空输电线路本体工程/附件安装工程】下取费名称等于【合计】的费用", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_table_by_project_division(\"架空输电线路本体工程/附件安装工程\", \"合计\")\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return result_dict\n else:\n return result_dict"} +{"name": "辅助工程费用_元", "query": "获取项目划分【架空输电线路本体工程/辅助工程】下取费名称等于【合计】的费用", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_table_by_project_division(\"架空输电线路本体工程/辅助工程\", \"合计\")\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return result_dict\n else:\n return result_dict"} +{"name": "辅助工程费用_调试_元", "query": "获取项目划分【架空输电线路本体工程/辅助工程】下取费名称等于【合计】的费用", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_table_by_project_division(\"架空输电线路本体工程/辅助工程\", \"合计\")\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return result_dict\n else:\n return result_dict"} +{"name": "辅助设施工程_元", "query": "从【工程费用】中获取【辅助设施工程】的【合计费】的值", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_schedule_on_Engineering_Cost_table(\"工程费用\", \"辅助设施工程\", \"合计费\")\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return result_dict\n else:\n return result_dict"} +{"name": "其他费用合计_元", "query": "从【工程费用】中获取【其他费用】的【合计费】的值", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_schedule_on_Engineering_Cost_table('工程费用表', '其他费用', '合计费')\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return {'status': True, 'data': data}\n else:\n return {'status': False, 'message': message, 'code': code}"} +{"name": "建场费合计_元", "query": "从【其他费用】中获取【建设场地征用及清理费】的值", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_schedule_on_other_expense_table(\"其他费用表\", \"建设场地征用及清理费\", \"amount\")\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return result_dict\n else:\n return result_dict"} +{"name": "项目建设管理费合计_元", "query": "从【其他费用】中获取【项目建设管理费】的值", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_schedule_on_other_expense_table('其他费用表', '项目建设管理费', 'amount')\n status = result_dict.get('status', False)\n if not status:\n raise Exception(f\"获取项目建设管理费失败: {result_dict.get('message', '未知错误')}\")\n return result_dict"} +{"name": "其中:工程监理费_元", "query": "从【其他费用】中获取【工程监理费】的值", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_schedule_on_other_expense_table('其他费用表', '工程监理费', 'amount')\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return result_dict\n else:\n return result_dict"} +{"name": "项目建设技术服务费合计_元", "query": "从【其他费用】中获取【项目建设技术服务费】的值", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_schedule_on_other_expense_table('其他费用表', '项目建设技术服务费', 'amount')\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return result_dict\n else:\n return result_dict"} +{"name": "其中:项目前期工作费_元", "query": "从【其他费用】中获取【项目前期工作费】的值", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_schedule_on_other_expense_table('其他费用表', '项目前期工作费', 'amount')\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return result_dict\n else:\n return result_dict"} +{"name": "其中:勘察费_元", "query": "从【其他费用】中获取【勘察费】的值", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_schedule_on_other_expense_table('其他费用表', '勘察费', 'amount')\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return result_dict\n else:\n return result_dict"} +{"name": "其中:设计费_元", "query": "从【其他费用】中获取【设计费】的值", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_schedule_on_other_expense_table('其他费用表', '设计费', 'amount')\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return {'code': code, 'message': message, 'status': status, 'data': data}\n else:\n return {'code': code, 'message': message, 'status': status, 'data': data}"} +{"name": "其中:工程建设检测费_元", "query": "从【其他费用】中获取【工程建设检测费】的值", "code": "根据用户问题和上下文信息,我们需要从\"其他费用\"表中获取\"工程建设检测费\"的值。根据上下文信息,\"其他费用\"是一个FeeScheduleItem类型,而\"工程建设检测费\"是一个Fee类型,且其parent为\"其他费用\"。\n\n最匹配的方法是`get_fee_schedule_on_other_expense_table`,因为它专门用于在其他费用表中查找费用。\n\n```python\ndef project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_schedule_on_other_expense_table(\n table_name=\"其他费用表\",\n fee_name=\"工程建设检测费\",\n fee_attribute=\"amount\"\n )\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return result_dict\n else:\n return result_dict\n```"} +{"name": "生产准备费_元", "query": "从【其他费用】中获取【生产准备费】的值", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_schedule_on_other_expense_table(\"其他费用表\", \"生产准备费\", \"amount\")\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return result_dict\n else:\n return result_dict"} +{"name": "其中:安全文明施工费_线路_元", "query": "获取项目划分【架空输电线路本体工程】下取费名称等于【安全文明施工费】的费用", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_table_by_project_division(\"架空输电线路本体工程\", \"安全文明施工费\")\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return result_dict\n else:\n return result_dict"} +{"name": "其中:安全文明施工费_调试_元", "query": "获取项目划分【架空输电线路本体工程】下取费名称等于【安全文明施工费】的费用", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_table_by_project_division(\"架空输电线路本体工程\", \"安全文明施工费\")\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return result_dict\n else:\n return result_dict"} +{"name": "基本预备费_元", "query": "从【工程费用】中获取【基本预备费】的【合计费】的值", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_schedule_on_Engineering_Cost_table(\"工程费用表\", \"基本预备费\", \"合计费\")\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return data\n else:\n raise Exception(f\"Failed to get fee: {message}\")"} +{"name": "静态投资_元", "query": "从【工程费用】中获取【工程静态投资(一~七项合计)】的【合计费】的值", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_schedule_on_Engineering_Cost_table(\"工程费用表\", \"工程静态投资(一~七项合计)\", \"合计费\")\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return result_dict\n else:\n return result_dict"} +{"name": "建设期利息_元", "query": "从【工程费用】中获取【建设期贷款利息】的【合计费】的值", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_schedule_on_Engineering_Cost_table('工程费用表', '建设期贷款利息', '合计费')\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return {'status': True, 'code': code, 'message': message, 'data': data}\n else:\n return {'status': False, 'code': code, 'message': message, 'data': data}"} +{"name": "动态投资_元", "query": "从【工程费用】中获取【工程动态投资(一~八项合计)】的【合计费】的值", "code": "def project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_schedule_on_Engineering_Cost_table(\"工程费用表\", \"工程动态投资(一~八项合计)\", \"合计费\")\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return result_dict\n else:\n return result_dict"} +{"name": "增值税抵扣税额_元", "query": "从【工程费用】中获取【其中:可抵扣增值税额】的【合计费】的值", "code": "根据用户问题和上下文信息,我们需要从工程费用表中获取\"其中:可抵扣增值税额\"的\"合计费\"值。根据提供的工程数据访问库,最匹配的方法是`get_fee_schedule_on_Engineering_Cost_table`。\n\n```python\ndef project_get_calculate_function():\n project = ProjectBuilder.build()\n result_dict = project.get_fee_schedule_on_Engineering_Cost_table(\n table_name=\"工程费用表\",\n fee_name=\"其中:可抵扣增值税额\",\n fee_attribute=\"合计费\"\n )\n status = result_dict.get('status', False)\n message = result_dict.get('message', '')\n code = result_dict.get('code', '')\n data = result_dict.get('data', {})\n if status:\n return result_dict\n else:\n return result_dict\n```"} diff --git a/main_streamlt.py b/main_streamlt.py index b556101..fbecb37 100644 --- a/main_streamlt.py +++ b/main_streamlt.py @@ -591,7 +591,7 @@ def render_indicator_management_expander(): def render_new_indicator_button(): """渲染新建指标按钮""" - if st.button("新建"): + if st.button("➕新建"): try: # 生成新指标名称 indicator_count = len(st.session_state.indicator_manager.get_all_indicators()) @@ -613,7 +613,7 @@ def render_new_indicator_button(): def render_clear_indicators_button(): """渲染清空指标按钮""" - if st.button("清空", help="删除所有指标", type="secondary"): + if st.button("🗑️清空", help="删除所有指标", type="secondary"): st.session_state.indicator_manager.clear_all_indicators() st.session_state.current_indicator_id = None # 清除已生成的代码 @@ -633,7 +633,7 @@ def render_indicator_list(): else: for indicator in sorted(indicators, key=lambda x: x.created_at): if st.button( - indicator.name, + f"📄 {indicator.name}", key=f"btn_{indicator.id}", use_container_width=True, type="primary" if st.session_state.current_indicator_id == indicator.id else "secondary" @@ -694,7 +694,7 @@ def render_library_indicator_list(records): """渲染指标库列表""" st.subheader("指标列表") for i, record in enumerate(records): - if st.button(record.get("name", f"指标 {i+1}"), key=f"lib_btn_{i}", use_container_width=True): + if st.button(f"📚 {record.get('name', f'指标 {i+1}')}", key=f"lib_btn_{i}", use_container_width=True): st.session_state.selected_index = i # 清除第一个expander的选中状态 st.session_state.current_indicator_id = None @@ -756,7 +756,7 @@ def render_library_indicator_code(library_indicator): with result_col: title_code_result = st.subheader("") with run_btn_col: - if st.button("执行代码", key="execute_lib_indicator", use_container_width=True): + if st.button("▶️ 执行代码", key="execute_lib_indicator", use_container_width=True): execute_indicator_code(library_indicator.get('code', ''), title_code_result) # 显示代码 @@ -906,7 +906,7 @@ def render_code_execution_section(current_indicator): with result_col: title_code_result = st.subheader("") with run_btn_col: - if st.button("执行代码", key="execute_code_btn", use_container_width=True): + if st.button("▶️ 执行代码", key="execute_code_btn", use_container_width=True): with st.spinner("正在执行代码..."): if "code_executor" in st.session_state and "generated_code" in st.session_state: execute_and_update_code(current_indicator, title_code_result) @@ -948,4 +948,4 @@ def execute_indicator_code(code, result_container): if __name__ == "__main__": - main() \ No newline at end of file + main() \ No newline at end of file diff --git a/main_streamlt_qa.py b/main_streamlt_qa.py new file mode 100644 index 0000000..a56efb5 --- /dev/null +++ b/main_streamlt_qa.py @@ -0,0 +1,219 @@ +import streamlit as st +import asyncio +import time +from datetime import datetime +import logging +from typing import List, Tuple, Optional, Dict, Any +import sys +import os + +# 将当前路径添加到系统路径,确保可以导入其他模块 +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +# 配置日志 +current_file = os.path.splitext(os.path.basename(__file__))[0] +now_str = datetime.now().strftime("%Y%m%d%H%M%S") +log_filename = f"{current_file}_{now_str}.log" + +# 确保日志目录存在并设置在logs目录下 +log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs") +os.makedirs(log_dir, exist_ok=True) + +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[ + logging.FileHandler(os.path.join(log_dir, log_filename), encoding="utf-8"), + logging.StreamHandler() + ], +) + +logger = logging.getLogger(current_file) + +def setup_logger(logger_name): + """ + 设置指定名称的logger,将其级别设置为WARNING并禁用传播 + :param logger_name: logger的名称 + """ + logger = logging.getLogger(logger_name) + logger.setLevel(logging.WARNING) # 设置httpcore及其子模块的级别 + logger.propagate = False # 可选:禁用传播(防止被根logger处理) + return logger + + +logger_names = ["httpx", "openai", "langsmith.client", "neo4j", "urllib3", "httpcore"] +for name in logger_names: + setup_logger(name) + +from src.code_executor import CodeExecutor +from src.config import Config + +# 页面配置 +st.set_page_config( + page_title="博微造价工程数据问答系统", + page_icon="📊", + layout="wide", +) + +# 初始化会话状态 +if "messages" not in st.session_state: + st.session_state.messages = [] +if "processing" not in st.session_state: + st.session_state.processing = False + +# 加载配置和初始化组件 +@st.cache_resource +def initialize_components(): + """加载配置并初始化必要组件""" + config = Config() + + # 导入必要模块 + from src.multi_llm_client import MultiAPIKeyChatOpenAI + from src.prompt_manager import PromptManager + from src.dialog_manager import DialogManager + from src.document_loader import load_file + from src.neo4j_raw_retriever import Neo4jRawRetriever + from src.embedding_client import EmbeddingClient + + # 初始化LLM客户端 + llm_client = MultiAPIKeyChatOpenAI(config.openai) + llm_client_coder = MultiAPIKeyChatOpenAI(config.openai_coder) + + # 初始化提示词管理器 + prompt_manager = PromptManager() + + # 初始化代码执行器 - 设置最大重试次数为3 + code_executor = CodeExecutor(prompt_manager.prompts, llm_client_coder, max_retries=3) + + # 加载API文档 + api_docs_path = config.bowei_api_docs_path + bowei_api_docs = load_file(api_docs_path) + + # 加载业务对象结构 + business_structure_path = config.business_object_structure_path + business_structure = load_file(business_structure_path) + + # 初始化嵌入客户端 + embedding_client = EmbeddingClient(config.openai) + + # 初始化知识检索器 + knowledge_retriever = Neo4jRawRetriever(config.neo4j_conf) + + # 初始化对话管理器 + dialog_manager = DialogManager( + llm_client, + business_structure, + bowei_api_docs, + code_executor, + knowledge_retriever, + prompt_manager, + ) + st.session_state.dialog_manager = dialog_manager + + return code_executor, config + +# 初始化组件 +try: + code_executor, config = initialize_components() +except Exception as e: + st.error(f"初始化组件失败: {str(e)}") + st.stop() + +# 页面标题 +st.title("博微造价工程数据问答系统") +st.markdown("---") + +# 侧边栏 +with st.sidebar: + st.header("系统信息") + st.info("博微造价工程数据问答系统是一个基于LLM的智能工具,可以帮助您生成并执行Python代码来回答问题。") + + st.markdown("### 使用说明") + st.markdown("1. 在输入框中输入您的问题") + st.markdown("2. 系统将生成并执行相关代码") + st.markdown("3. 查看执行结果和生成的代码") + + st.markdown("---") + + if st.button("清空对话历史"): + st.session_state.messages = [] + st.rerun() + +# 显示聊天历史 +for message in st.session_state.messages: + with st.chat_message(message["role"]): + st.write(message["content"]) + +# 处理用户问题 +async def process_user_question(user_question): + """处理用户输入的问题""" + st.session_state.processing = True + + # 将用户问题添加到会话历史 + st.session_state.messages.append({"role": "user", "content": user_question}) + + # 显示用户消息 + with st.chat_message("user"): + st.write(user_question) + + # 显示系统思考中消息 + with st.chat_message("assistant"): + thinking_placeholder = st.empty() + thinking_placeholder.markdown("🤔 思考中...") + + # 理解用户问题并获取重写结果 + try: + # 获取对话管理器 + dialog_manager = st.session_state.get('dialog_manager') + if not dialog_manager: + raise Exception("对话管理器未初始化") + + rewritten_results = await dialog_manager.understand_user_question_stream(user_question) + if not rewritten_results: + thinking_placeholder.markdown("抱歉,没有找到符合要求的数据,请尝试其他问题。") + st.session_state.messages.append({"role": "assistant", "content": "抱歉,没有找到符合要求的数据,请尝试其他问题。"}) + st.session_state.processing = False + return + + # 保存重写结果到会话状态 + st.session_state.rewritten_results = rewritten_results + + # 显示选择项 + thinking_placeholder.empty() + result_container = st.container() + with result_container: + st.markdown("**系统为您理解并改写了以下访问请求,请选择:**") + + # 为每个重写结果创建按钮 + cols = st.columns(min(len(rewritten_results), 3)) + for idx, (rewritten, knowledge) in enumerate(rewritten_results): + col_idx = idx % 3 + with cols[col_idx]: + if st.button(f"选项 {idx+1}: {rewritten[:50]}...", key=f"option_{idx}"): + st.session_state.selected_rewritten = rewritten + st.session_state.selected_knowledge = knowledge + st.rerun() + + # 将选择提示添加到会话历史 + options_text = "**系统为您理解并改写了以下访问请求,请选择:**\n\n" + for idx, (rewritten, _) in enumerate(rewritten_results): + options_text += f"{idx+1}. {rewritten}\n" + + st.session_state.messages.append({"role": "assistant", "content": options_text}) + + except Exception as e: + thinking_placeholder.markdown(f"处理问题时出错: {str(e)}") + logger.error(f"处理问题时出错: {str(e)}") + st.session_state.messages.append({"role": "assistant", "content": f"处理问题时出错: {str(e)}"}) + + st.session_state.processing = False + +# 用户输入框 +if not st.session_state.processing: + user_question = st.chat_input("请输入您的问题...") + if user_question: + asyncio.run(process_user_question(user_question)) + +if __name__ == "__main__": + # Streamlit运行此文件时的入口点 + pass \ No newline at end of file diff --git a/src/project.py b/src/project.py index 7b28cd0..3f8449a 100644 --- a/src/project.py +++ b/src/project.py @@ -65,7 +65,7 @@ class ProjectToolkit(ABC): - code (int): 状态码,固定为 200(成功)或 201(失败) - message (str): 成功时为 "Ok",失败时包含错误信息和辅助信息 - status (bool): 成功为 True,失败为 False - - data (List[Dict[str, Any]]): 成功时返回的数据列表,失败时为空列表 + - data (ProjectDivisionItem | None]): 成功时返回的ProjectDivisionItem,失败时为None """ pass diff --git a/tests/test_generated_code.py b/tools/interaction_generated_code.py similarity index 94% rename from tests/test_generated_code.py rename to tools/interaction_generated_code.py index c3bb449..e0a5a8e 100644 --- a/tests/test_generated_code.py +++ b/tools/interaction_generated_code.py @@ -2,7 +2,7 @@ import os import sys import logging sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - +current_file = os.path.splitext(os.path.basename(__file__))[0] from datetime import datetime # 获取当前时间,格式化为字符串 @@ -19,7 +19,7 @@ logging.basicConfig( ], ) -logger = logging.getLogger("test_code1") +logger = logging.getLogger(current_file) import logging def setup_logger(logger_name): @@ -96,7 +96,7 @@ def main(): # 加载 zhibiao.jsonl zhibiao_data = [] - with open('./tests/zhibiao.jsonl', 'r', encoding='utf-8') as f: + with open('./data/zhibiao.jsonl', 'r', encoding='utf-8') as f: for line in f: zhibiao_data.append(json.loads(line)) @@ -148,13 +148,13 @@ if __name__ == "__main__": print(error) # 保存成功结果到 jsonl 文件 - success_filename = f'./tests/code_{now_str}.jsonl' + success_filename = f'./data/code.jsonl' with open(success_filename, 'w', encoding='utf-8') as f: for item in success_results: f.write(json.dumps(item, ensure_ascii=False) + '\n') # 保存失败结果到 jsonl 文件 - fail_filename = f'./tests/fail_{now_str}.jsonl' + fail_filename = f'./data/fail_{now_str}.jsonl' with open(fail_filename, 'w', encoding='utf-8') as f: for item in fail_results: f.write(json.dumps(item, ensure_ascii=False) + '\n') diff --git a/tools/zhibiao_to_interaction.py b/tools/zhibiao_to_interaction.py new file mode 100644 index 0000000..6a03252 --- /dev/null +++ b/tools/zhibiao_to_interaction.py @@ -0,0 +1,184 @@ +# tests/test_userinteraction.py + +import os +import sys +import json +import logging +import os +from datetime import datetime +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +current_file = os.path.splitext(os.path.basename(__file__))[0] +now_str = datetime.now().strftime("%Y%m%d%H%M%S") +log_filename = f"{current_file}_{now_str}.log" + +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[ + logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"), + logging.StreamHandler() + ], +) + +logger = logging.getLogger(current_file) + +def setup_logger(logger_name): + """ + 设置指定名称的logger,将其级别设置为WARNING并禁用传播 + :param logger_name: logger的名称 + """ + logger = logging.getLogger(logger_name) + logger.setLevel(logging.WARNING) # 设置httpcore及其子模块的级别 + logger.propagate = False # 可选:禁用传播(防止被根logger处理) + return logger + + +logger_names = ["httpx", "openai", "langsmith.client", "neo4j", "urllib3", "httpcore"] +for name in logger_names: + setup_logger(name) + +from src.config import Config +from src.document_loader import load_file +from src.multi_llm_client import MultiAPIKeyChatOpenAI +from src.user_interaction import UserInteraction +import json +from src.dialog_manager import DialogManager +from src.multi_llm_client import MultiAPIKeyChatOpenAI +from src.code_executor import CodeExecutor +from src.neo4j_raw_retriever import Neo4jRawRetriever +from src.prompt_manager import PromptManager +import yaml +from src.config import Config +from src.document_loader import load_file +from src.embedding_client import EmbeddingClient + +from src.project import ProjectBuilder, ProjectToolkit +from src.project_implementation import ProjectToolkitNeo4j + +success_count = 0 +fail_count = 0 +questions = [] +error_list = [] +success_list = [] + +def main(): + global success_count, fail_count, questions, error_list, success_list + + config = Config() + + business_structure = load_file(config.business_object_structure_path) + bowei_api_docs = load_file(config.bowei_api_docs_path) + + llm_client = MultiAPIKeyChatOpenAI(config.openai) + user_interaction = UserInteraction(llm_client.llm, business_structure) + + llm_client_coder = MultiAPIKeyChatOpenAI(config.openai_coder) + + prompt_manager = PromptManager() + + neo4j_conf = config.neo4j_conf + embedding_conf = config.embedding + + embedding_client = EmbeddingClient(embedding_conf) + + # 创建Neo4j检索器 + knowledge_retriever = Neo4jRawRetriever(neo4j_conf) + + ProjectBuilder.register(ProjectToolkitNeo4j, knowledge_retriever.driver) + + code_executor = CodeExecutor(prompt_manager.prompts, llm_client_coder, config.max_retries) + + dialog_manager = DialogManager( + llm_client, + business_structure, + bowei_api_docs, + code_executor, + knowledge_retriever, + prompt_manager, + ) + + # 读取 zhibiao.json + zhibiao_path = os.path.join(os.path.dirname(__file__), "../data/zhibiao.json") + with open(zhibiao_path, "r", encoding="utf-8") as f: + zhibiao_data = json.load(f) + + isTest = True + isTest = False + + if isTest: + zhibiao_data = [ + { + "指标名称": "杆塔总基数", + "指标描述": { + "指标映射": "从【架空输电线路本体工程/附件安装工程】项目划分中获取名称属于【'合计'】的费用", + "映射规则": "YX2-1~7" + }, + "code": "", + "单位": "基", + "单价类型": None, + "序号": "1", + "提取方式": None, + "指标类型": "工程量指标", + "数据来源": "定额数量" + } + ] + + for idx, item in enumerate(zhibiao_data): + name = item.get("指标名称", "") + datasource = item.get("数据来源", "") + if datasource in ("报表指标", "指标库"): + logger.info(f"跳过索引 {idx},数据来源为 {datasource}") + continue + + query = item.get("指标描述", {}).get("指标映射", "") + if not query: + logger.warning(f"索引 {idx} 缺少指标映射,跳过") + continue + + try: + # 调用用户交互理解接口(同步调用) + result = user_interaction.understand(query) + if not result: + logger.error(f"问题: {query} 没有找到符合要求的数据") + fail_count += 1 + error_list.append(f"指标名称 {name} 问题 {query} 调用 understand 返回空结果") + continue + + # 这里示例只打印理解结果,你可以根据业务逻辑替换为后续处理 + logger.info( + f"指标名称 {name} 问题: {query} 理解结果: " + f"{[{'name': r.get('name'), 'constraints': r.get('constraints')} for r in result]}" + ) + + success_list.append({ + "name": name, + "query": query, + "result": [{'name': r.get('name'), 'constraints': r.get('constraints')} for r in result] + }) + success_count += 1 + + except Exception as e: + logger.error(f"指标名称 {name} 问题: {query} 处理异常: {e}") + fail_count += 1 + error_list.append(f"指标名称 {name} 问题 {query} 异常: {e}") + + total = success_count + fail_count + success_rate = (success_count / total) * 100 if total > 0 else 0 + fail_rate = (fail_count / total) * 100 if total > 0 else 0 + + print(f"问题总数: {total}") + print(f"成功比例: {success_rate:.2f}%)") + print(f"失败比例: {fail_rate:.2f}%)") + print("错误列表:") + for error in error_list: + print(error) + + # 将成功内容保存为 jsonl 文件 + success_jsonl_path = os.path.join(os.path.dirname(__file__), f"../data/zhibiao.jsonl") + with open(success_jsonl_path, "w", encoding="utf-8") as f: + for item in success_list: + f.write(json.dumps(item, ensure_ascii=False) + "\n") + +if __name__ == "__main__": + main()