Files
GraphRAG/graph/entity_extractor.py
T
2025-03-31 17:28:23 +08:00

193 lines
6.8 KiB
Python

import os
import json
import re
from dotenv import load_dotenv
from utils.llm import llm
from utils.prompt import PROMPTS
# Populate os.environ from a .env file (python-dotenv; by default it does not
# override variables that are already set).
# NOTE(review): this runs *after* `from utils.llm import llm` above. If
# utils.llm reads environment variables at import time, it will not see values
# that exist only in .env — confirm, and move this call above that import if so.
load_dotenv()
class EntityRelationExtractor:
    """Extract entities, relationships and keywords from text with an LLM.

    The model is prompted with ``PROMPTS["entity_extraction"]`` and its raw
    answer is parsed by ``_parse_extraction_result``, which accepts one record
    per ``record_delimiter``-separated line, optionally wrapped in parentheses,
    with fields separated by ``tuple_delimiter``::

        ("entity"|||"<name>"|||"<type>"|||"<description>")
        ("relationship"|||"<source>"|||"<target>"|||"<description>"|||"<type>"|||<confidence>)
        ("content_keywords"|||"<kw1>, <kw2>, ...")
    """

    def __init__(self):
        # Shared model object imported from utils.llm; only its
        # `generate(prompt)` method is used in this class.
        self.llm = llm
        # Delimiters substituted into the prompt template. The parser uses the
        # same values to split the model output, so they must stay in sync.
        self.tuple_delimiter = "|||"
        self.record_delimiter = "\n"
        self.completion_delimiter = ""

    @staticmethod
    def _clean_field(value):
        """Strip surrounding whitespace and double quotes from one record field."""
        return value.strip().strip('"').strip()

    @staticmethod
    def _parse_confidence(value):
        """Convert a cleaned confidence field to float, defaulting to 1.0 if malformed."""
        try:
            return float(value)
        except ValueError:
            return 1.0

    def extract_from_text(self, text, entity_types=None):
        """Extract entities and relations from a piece of text.

        Args:
            text: Document text sent to the LLM.
            entity_types: Optional iterable of entity type names offered to the
                model. When ``None`` or empty, ``PROMPTS["DEFAULT_ENTITY_TYPES"]``
                is used.

        Returns:
            ``{"entities": [...], "relations": [...], "keywords": [...]}`` on
            success, or ``{"error": <message>}`` if building the prompt, the
            LLM call, or parsing raised an exception.
        """
        try:
            # Only fall back to the default types when the caller supplied none;
            # previously the argument was unconditionally overwritten.
            if not entity_types:
                entity_types = PROMPTS["DEFAULT_ENTITY_TYPES"]
            relationship_types = PROMPTS["DEFAULT_RELATIONSHIP_TYPES"]

            # Fill the delimiter placeholders of the system prompt template.
            prompt = PROMPTS["entity_extraction"]
            prompt = prompt.replace("{tuple_delimiter}", self.tuple_delimiter)
            prompt = prompt.replace("{record_delimiter}", self.record_delimiter)
            prompt = prompt.replace("{completion_delimiter}", self.completion_delimiter)

            # Append the allowed entity/relationship types and the document.
            entity_types_str = ", ".join(entity_types)
            relationship_types_str = ", ".join(relationship_types)
            user_message = (
                f"实体类型列表: {entity_types_str}\n\n关系类型列表: {relationship_types_str}\n\n文本内容:\n{text}"
            )

            # `generate` takes a single prompt string, so system and user parts
            # are concatenated.
            full_prompt = f"System: {prompt}\n\nUser: {user_message}"
            response = self.llm.generate(full_prompt)

            return self._parse_extraction_result(response)
        except Exception as e:
            # Best-effort API: callers check for the "error" key instead of
            # catching exceptions.
            return {"error": str(e)}

    def _parse_extraction_result(self, result):
        """Parse raw LLM output into entities, relations and keywords.

        Each non-empty record may be wrapped in ``( ... )``. Its fields are
        split on ``tuple_delimiter`` and stripped of surrounding whitespace and
        double quotes. Records with an unknown type or too few fields are
        skipped rather than failing the whole parse.

        Args:
            result: Raw text returned by the LLM.

        Returns:
            dict with ``"entities"`` (name/type/description dicts),
            ``"relations"`` (source/target/description/type/confidence dicts;
            confidence defaults to 1.0 when missing or not numeric) and
            ``"keywords"`` (list of str; if several ``content_keywords``
            records appear, the last one wins).
        """
        entities = []
        relations = []
        keywords = []

        if self.completion_delimiter:
            # Remove the end-of-output marker so it cannot corrupt the last record.
            result = result.replace(self.completion_delimiter, "")

        for raw_line in result.strip().split(self.record_delimiter):
            line = raw_line.strip()
            if not line:
                continue
            # Remove the optional parentheses around a record.
            if line.startswith("(") and line.endswith(")"):
                line = line[1:-1]

            parts = [self._clean_field(part) for part in line.split(self.tuple_delimiter)]
            # content_keywords records carry only two fields, so the minimum is 2.
            # A minimum of 3 would silently drop every keyword record.
            if len(parts) < 2:
                continue

            record_type = parts[0]
            if record_type == "entity" and len(parts) >= 4:
                entities.append(
                    {
                        "name": parts[1],
                        "type": parts[2],
                        "description": parts[3],
                    }
                )
            elif record_type == "relationship" and len(parts) >= 5:
                relations.append(
                    {
                        "source": parts[1],
                        "target": parts[2],
                        "description": parts[3],
                        "type": parts[4],
                        "confidence": self._parse_confidence(parts[5]) if len(parts) > 5 else 1.0,
                    }
                )
            elif record_type == "content_keywords":
                # Accept both ASCII and full-width commas; the prompt is Chinese.
                keywords = [kw.strip() for kw in re.split(r"[,，]", parts[1]) if kw.strip()]

        return {"entities": entities, "relations": relations, "keywords": keywords}

    def extract_from_file(self, file_path, entity_types=None):
        """Read a UTF-8 text file and extract entities and relations from it.

        Returns the same dict shape as ``extract_from_text``. If the file
        cannot be read, returns ``{"error": "读取文件时出错: <reason>"}``.
        """
        try:
            with open(file_path, "r", encoding="utf-8") as file:
                text = file.read()
        except Exception as e:
            return {"error": f"读取文件时出错: {str(e)}"}
        return self.extract_from_text(text, entity_types)

    def extract_from_documents(self, documents, entity_types=None):
        """Extract and merge entities/relations from several documents.

        Documents whose extraction returns an error are skipped. Entities are
        de-duplicated by name (the first occurrence wins). Relations are
        concatenated. Keywords are de-duplicated in first-seen order.
        """
        all_entities = {}
        all_relations = []
        # A dict serves as an insertion-ordered set, giving deterministic output order.
        all_keywords = {}

        for doc in documents:
            result = self.extract_from_text(doc, entity_types)
            if "error" in result:
                continue
            for entity in result.get("entities", []):
                all_entities.setdefault(entity["name"], entity)
            all_relations.extend(result.get("relations", []))
            all_keywords.update(dict.fromkeys(result.get("keywords", [])))

        return {
            "entities": list(all_entities.values()),
            "relations": all_relations,
            "keywords": list(all_keywords),
        }
# Manual smoke test: run this module directly to extract entities and
# relations from a sample help-document snippet and print them.
if __name__ == "__main__":
    extractor = EntityRelationExtractor()

    # Sample input: a short Markdown help page for the pricing software.
    test_text = """
(电力建设计价通软件) (计价通)导入或清除电子徽标
# (电力建设计价通软件) (计价通)导入或清除电子徽标
## 使用场景
1.单位对打印的报表有要求,需要插入显示企业电子徽标。
2.报表中有电子徽标,现在想要清除。
## 功能入口
【报表输出】界面——“导入电子徽标”按钮。
![导入电子徽标-功能入口](https://172.20.0.145/files/acee3357-594a-4bc2-9f6d-6a015a0b6512/image-preview)
## 操作步骤
### 导入电子徽标
1.左侧勾选需要显示徽标的报表,点击“选择徽标”,选中需要导入的电子徽标(可识别bmp,gif,jpg格式文件),点击“打开”即可导入徽标;
![导入电子徽标](https://172.20.0.145/files/756275ee-ca64-42b4-8dae-858fc1d980ef/image-preview)
### **清除电子徽标**
左侧勾选需要清除徽标的报表,点击“清除徽标”按钮,点击“确定”,即可清除徽标。
![清除电子徽标](https://172.20.0.145/files/9a20c719-4cff-4102-af7a-640967b6ad12/image-preview)
"""

    print("正在提取实体和关系...")
    outcome = extractor.extract_from_text(test_text)

    print("\n提取结果:")
    if "error" in outcome:
        print(f"错误: {outcome['error']}")
    else:
        found_entities = outcome.get("entities", [])
        found_relations = outcome.get("relations", [])
        print(f"发现 {len(found_entities)} 个实体和 {len(found_relations)} 个关系")

        # Entities: name, type, then the description on its own line.
        print("\n实体:")
        for ent in found_entities:
            print(f"- {ent['name']} (类型: {ent['type']})")
            print(f" 描述: {ent['description']}")

        # Relations rendered as "source --[type]--> target".
        print("\n关系:")
        for rel in found_relations:
            print(f"- {rel['source']} --[{rel['type']}]--> {rel['target']}")
            print(f" 描述: {rel['description']}")
            print(f" 置信度: {rel['confidence']}")

    # Keywords are printed only when present and non-empty.
    found_keywords = outcome.get("keywords")
    if found_keywords:
        print("\n关键词:")
        print(", ".join(found_keywords))