diff --git a/rag2_0/api/query_dinge_qingdan_api.py b/rag2_0/api/query_dinge_qingdan_api.py index b074bcd..809dd4f 100644 --- a/rag2_0/api/query_dinge_qingdan_api.py +++ b/rag2_0/api/query_dinge_qingdan_api.py @@ -53,12 +53,12 @@ class BatchQueryResponse(BaseModel): # 封装查询数据的相关代码 class QingDanDingEQueryService: - def __init__(self, db_path="/data/QueryRewrite/data/db/qingdan_ding_e_ku.db"): - self.db_path = db_path + def __init__(self): + self.db_path = f"{os.getcwd()}/data/db/qingdan_ding_e_ku.db" self.top_k = TOP_K # 初始化向量检索相关组件 - self.embedding_function = XinferenceEmbeddings(api_key="") + self.embedding_function = XinferenceEmbeddings() # 初始化向量数据库连接 self.ding_e_vector_db = SQLiteVSS( diff --git a/rag2_0/demo/create_qingdan_dinge_database.py b/rag2_0/demo/create_qingdan_dinge_database.py index ed974c0..9cb1ddb 100644 --- a/rag2_0/demo/create_qingdan_dinge_database.py +++ b/rag2_0/demo/create_qingdan_dinge_database.py @@ -462,10 +462,10 @@ class ExcelToSQLiteProcessor: print("数据库事务已提交,连接已关闭") class CreateEmbedingData(): - def __init__(self, db_path, api_key="aa"): + def __init__(self, db_path): self.db_path = db_path self.conn = sqlite3.connect(db_path) - self.embedding_function = XinferenceEmbeddings(api_key=api_key) + self.embedding_function = XinferenceEmbeddings() def create_ding_e_zimu_embedding(self): """创建定额子目名称的向量索引""" @@ -535,22 +535,24 @@ def main(): print("开始处理定额库和清单库Excel文件...") # 配置参数 - ding_e_base_dir = "/data/QueryRewrite/data/excel/Excel版 清单定额库/定额库" - qing_dan_base_dir = "/data/QueryRewrite/data/excel/Excel版 清单定额库/清单库" - db_path = "/data/QueryRewrite/data/db/qingdan_ding_e_ku copy.db" - + ding_e_base_dir = f"{os.getcwd()}/data/excel/Excel版 清单定额库/定额库" + qing_dan_base_dir = f"{os.getcwd()}/data/excel/Excel版 清单定额库/清单库" + db_path = f"{os.getcwd()}/data/db/qingdan_ding_e_ku.db" + if os.path.exists(db_path): + print("数据库文件已存在, 任务结束...") + return # 创建处理器实例 - # processor = ExcelToSQLiteProcessor(db_path) + processor = ExcelToSQLiteProcessor(db_path) try: # 处理定额库文件 - # processor.process_ding_e_files(ding_e_base_dir) + processor.process_ding_e_files(ding_e_base_dir) # # 处理清单库文件 - # processor.process_qing_dan_files(qing_dan_base_dir) + processor.process_qing_dan_files(qing_dan_base_dir) # # 提交并关闭 - # processor.commit_and_close() + processor.commit_and_close() print("=" * 50) print("所有Excel文件处理完成!数据已成功导入SQLite数据库") diff --git a/rag2_0/intent_recognition/ProfessionalNounVector.py b/rag2_0/intent_recognition/ProfessionalNounVector.py index b589c83..ccd313c 100755 --- a/rag2_0/intent_recognition/ProfessionalNounVector.py +++ b/rag2_0/intent_recognition/ProfessionalNounVector.py @@ -28,7 +28,7 @@ def get_embedding_model(api_key: str = None) -> Embeddings: Returns: 嵌入模型实例 """ - return XinferenceEmbeddings(api_key=api_key) + return XinferenceEmbeddings() class ProfessionalNounVectorizer: diff --git a/rag2_0/tool/ModelTool.py b/rag2_0/tool/ModelTool.py index 5927891..1bb92da 100755 --- a/rag2_0/tool/ModelTool.py +++ b/rag2_0/tool/ModelTool.py @@ -23,13 +23,13 @@ from urllib.parse import urljoin class XinferenceEmbeddings(Embeddings): """SiliconFlow嵌入模型封装""" - def __init__(self, api_key: str, model: str = os.getenv("EMBEDDING_MODEL_NAME", "bge-m3")): - self.api_key = api_key + def __init__(self, model: str = os.getenv("EMBEDDING_MODEL_NAME", "bge-m3")): self.model = model - base_url = os.getenv("XINFERENCE_URL", "http://10.1.16.39:9995") + base_url = os.getenv("EMBEDDING_BASE_URL", "http://10.1.16.39:9995") self.url = urljoin(base_url.rstrip('/') + '/', 'v1/embeddings') + api_key = os.getenv("EMBEDDING_API_KEY", "") self.headers = { - "Authorization": f"Bearer {self.api_key}", + "Authorization": f"Bearer {api_key}", "Content-Type": "application/json" } @@ -89,12 +89,13 @@ class XinferenceReRankerModel: List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引 """ - base_url = os.getenv("XINFERENCE_URL", "http://10.1.16.39:9995") + base_url = os.getenv("RERANKER_BASE_URL", "http://10.1.16.39:9995") model_name = os.getenv("RERANKER_MODEL_NAME", "bge-reranker-v2-m3") + api_key = os.getenv("RERANKER_API_KEY", "") rerank_url = urljoin(base_url.rstrip('/') + '/', 'v1/rerank') params = {"documents": documents, "query": query, "top_n": top_k, "return_documents": True, "model": model_name} headers = { - "Authorization": "Bearer ", # 这里需要替换为实际的token + "Authorization": f"Bearer {api_key}", # 这里需要替换为实际的token "Content-Type": "application/json" } @@ -123,12 +124,14 @@ class XinferenceReRankerModel: Returns: List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引 """ - base_url = os.getenv("XINFERENCE_URL", "http://10.1.16.39:9995") + base_url = os.getenv("RERANKER_BASE_URL", "http://10.1.16.39:9995") rerank_url = urljoin(base_url.rstrip('/') + '/', 'v1/rerank') model_name = os.getenv("RERANKER_MODEL_NAME", "bge-reranker-v2-m3") + api_key = os.getenv("RERANKER_API_KEY", "") + params = {"documents": documents, "query": query, "top_n": top_k, "return_documents": True, "model": model_name} headers = { - "Authorization": "Bearer ", # 这里需要替换为实际的token + "Authorization": f"Bearer {api_key}", # 这里需要替换为实际的token "Content-Type": "application/json" } @@ -214,7 +217,9 @@ class OpenAiLLM: messages=[{'role': 'user', 'content': user_prompt}], **kwargs ) - return completion.choices[0].message + message = completion.choices[0].message + message.usage = completion.usage + return message except Exception as e: raise RuntimeError(f"OpenAiLLM:ainvoke:error:{str(e)}") from e