# Provenance: recovered from git patch 0d855d449227babd62c30b472f6ba6027c2ff8c6
# (author: zoujiwen, date: Tue, 1 Apr 2025 09:28:01 +0800,
# subject: "Upload files to /"), which creates this module as vector_load2.py.
"""Build FAISS retrievers over a text file and resolve hits to documents.

The module offers four steps that callers chain together:

1. ``Mixed_retrieval`` embeds every line of a text file, stores the vectors
   in a FAISS index saved under ``./faiss_data/`` and returns three
   retrievers (plain similarity, MMR, score threshold).
2. ``interface_search`` runs a query through the three retrievers and keeps
   only the texts that all of them return.
3. ``Building_search_dictionary`` maps such a text to a path and a document
   id using two CSV lookup tables.
4. ``Official_website_kg_search`` fetches that document through
   ``WikijsTool`` and returns its plain text.
"""

import os
import re
from typing import List, Optional, Tuple

import requests
from langchain.embeddings.base import Embeddings
from langchain_community.vectorstores import FAISS

# NOTE: a local alternative to the HTTP embedding service was sketched in the
# original file and left disabled:
#   from langchain_huggingface import HuggingFaceEmbeddings
#   embeddings = HuggingFaceEmbeddings(
#       model_name="/data/Z/Z_llm_dm/vector_data/bge-m3")

_DEFAULT_EMBEDDING_URL = "http://10.1.16.39:9995/v1/embeddings"
_DEFAULT_TIMEOUT_SECONDS = 120.0
_DEFAULT_BATCH_SIZE = 32

# Matches a complete ``<img ...>`` tag so it can be dropped before parsing.
_IMG_TAG_PATTERN = re.compile(r"<img[^>]*>")


class SiliconFlowEmbeddings(Embeddings):
    """LangChain ``Embeddings`` implementation backed by an HTTP endpoint.

    Each request POSTs a JSON body with ``model``, ``input`` and
    ``encoding_format`` keys and expects a JSON response whose ``data`` list
    holds one object per input text, each with an ``embedding`` field.

    Args:
        api_key: Token sent as ``Authorization: Bearer <api_key>``.
        model: Model name placed in the request payload.
        url: Endpoint that receives the POST request.
        timeout: Seconds passed to ``requests.post`` as ``timeout`` so a
            stalled server cannot block the caller forever.
        batch_size: Maximum number of texts sent in a single request.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "bge-m3",
        url: str = _DEFAULT_EMBEDDING_URL,
        timeout: float = _DEFAULT_TIMEOUT_SECONDS,
        batch_size: int = _DEFAULT_BATCH_SIZE,
    ):
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        self.api_key = api_key
        self.model = model
        self.url = url
        self.timeout = timeout
        self.batch_size = batch_size
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

    def _embed(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per entry of ``texts``.

        The texts are sent in chunks of ``self.batch_size``; an empty list
        yields an empty result without contacting the server.

        Raises:
            requests.HTTPError: If the server answers with an error status.
            ValueError: If a response does not contain exactly one embedding
                per text that was sent.
        """
        vectors: List[List[float]] = []
        for start in range(0, len(texts), self.batch_size):
            batch = texts[start:start + self.batch_size]
            payload = {
                "model": self.model,
                "input": batch,
                "encoding_format": "float"
            }
            response = requests.post(
                self.url,
                json=payload,
                headers=self.headers,
                timeout=self.timeout,
            )
            response.raise_for_status()
            items = response.json()["data"]
            # A short or long answer would silently misalign texts and
            # vectors once the caller zips them together.
            if len(items) != len(batch):
                raise ValueError(
                    f"expected {len(batch)} embeddings, got {len(items)}"
                )
            vectors.extend(item["embedding"] for item in items)
        return vectors

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents, preserving their order."""
        return self._embed(list(texts))

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string."""
        return self._embed([text])[0]


# SECURITY NOTE(review): an API key is committed in source below. It is kept
# only as a fallback so existing deployments keep working; set the
# SILICONFLOW_API_KEY environment variable to override it, and rotate the
# committed key and remove the literal once every deployment does so.
embeddings = SiliconFlowEmbeddings(
    api_key=os.environ.get(
        "SILICONFLOW_API_KEY",
        "sk-ftnofbucchwnscojohyxwmfzgaykdxihafnlphohsinftkbr",
    )
)


def Mixed_retrieval(input_path):
    """Index the lines of a text file in FAISS and return three retrievers.

    Every non-blank line of ``input_path`` (stripped of surrounding
    whitespace) becomes one document. The index is saved to
    ``./faiss_data/<file name without extension>``.

    Args:
        input_path: Path of a UTF-8 text file with one entry per line.

    Returns:
        A tuple ``(similarity, mmr, threshold)`` of retrievers built on the
        same vector store:

        * similarity search returning 5 results,
        * MMR search (``k=5``, ``fetch_k=2``, ``lambda_mult=0.1``),
        * similarity search filtered by ``score_threshold=0.3``.

    Raises:
        ValueError: If the file contains no non-blank line.
    """
    file_name = os.path.splitext(os.path.basename(input_path))[0]
    faiss_archived = f"./faiss_data/{file_name}"

    with open(input_path, 'r', encoding='utf-8') as file:
        stripped_lines = [line.strip() for line in file]
    # Blank lines carry no meaning and would be sent to the embedding
    # service as empty strings.
    txt_list = [text for text in stripped_lines if text]
    if not txt_list:
        raise ValueError(f"no text lines found in {input_path!r}")

    vectorstore_txt_faiss = FAISS.from_texts(txt_list, embeddings)
    vectorstore_txt_faiss.save_local(faiss_archived)

    # NOTE: the original file carried a disabled FAISS.load_local(...) call
    # (embeddings=embeddings, allow_dangerous_deserialization=True) for
    # reusing a saved index instead of rebuilding it; its first argument
    # was the store variable rather than the saved path.

    retriever_txt_faiss1 = vectorstore_txt_faiss.as_retriever(search_kwargs={"k": 5})
    # NOTE(review): fetch_k (candidate pool) is smaller than k (results
    # requested), so this retriever presumably returns at most 2 documents.
    # Values are left untouched because changing them alters search
    # results; confirm whether fetch_k >= k was intended.
    retriever_txt_faiss2 = vectorstore_txt_faiss.as_retriever(
        search_type="mmr",
        search_kwargs={"k": 5,  # number of results returned
                       "fetch_k": 2,  # number of candidates considered
                       "lambda_mult": 0.1}  # balance: 1 = relevance, 0 = diversity
    )
    retriever_txt_faiss3 = vectorstore_txt_faiss.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": 0.3}
    )

    return retriever_txt_faiss1, retriever_txt_faiss2, retriever_txt_faiss3


def interface_search(input_str, retriever_txt_faiss1, retriever_txt_faiss2, retriever_txt_faiss3):
    """Return the texts that all three retrievers agree on for a query.

    Args:
        input_str: Query passed to each retriever's ``invoke``.
        retriever_txt_faiss1: First retriever; its ranking defines the order
            of the returned list.
        retriever_txt_faiss2: Second retriever.
        retriever_txt_faiss3: Third retriever.

    Returns:
        A list of distinct ``page_content`` strings present in the results
        of all three retrievers, in the order the first retriever returned
        them. The list is empty when there is no common result.
    """
    index_keyword1 = [doc.page_content for doc in retriever_txt_faiss1.invoke(input_str)]
    index_keyword2 = {doc.page_content for doc in retriever_txt_faiss2.invoke(input_str)}
    index_keyword3 = {doc.page_content for doc in retriever_txt_faiss3.invoke(input_str)}

    shared = index_keyword2 & index_keyword3
    # dict.fromkeys removes duplicates while keeping first-seen order, which
    # makes the result deterministic (a plain set intersection is not).
    return list(dict.fromkeys(text for text in index_keyword1 if text in shared))


def Building_search_dictionary(input_csv_path1, input_csv_path2, index_keyword):
    """Resolve a keyword to its path and document id via two CSV tables.

    Args:
        input_csv_path1: CSV with a header row providing ``name`` and
            ``index`` columns.
        input_csv_path2: CSV read without a header row; its two columns are
            treated as ``path`` and ``id``.
        index_keyword: Value looked up in the ``name`` column of the first
            table.

    Returns:
        ``(None, None)`` when the keyword is not in the first table,
        ``(path, None)`` when the path has no row in the second table,
        otherwise ``(path, id)`` with ``id`` converted to ``int``. When
        several rows match, the first one wins.
    """
    import pandas as pd

    df1 = pd.read_csv(input_csv_path1, encoding='utf-8')
    df2 = pd.read_csv(input_csv_path2, encoding='utf-8', names=['path', 'id'])

    matching_path = df1.loc[df1['name'] == index_keyword, 'index']

    # Bug fix kept from the original: guard both lookups so that a keyword
    # without a path, or a path without an id, does not raise.
    if matching_path.empty:
        return (None, None)

    # tolist() converts the value to a native Python object.
    first_path = matching_path.tolist()[0]
    matching_ids = df2.loc[df2['path'] == first_path, 'id']
    if matching_ids.empty:
        return (first_path, None)
    return (first_path, int(matching_ids.values[0]))


def Official_website_kg_search(input_id):
    """Fetch a document through ``WikijsTool`` and return its plain text.

    Args:
        input_id: Identifier passed to ``WikijsTool.query_doc_info``.

    Returns:
        The text of the document's ``content`` field with ``<img>`` tags
        removed and all remaining HTML markup stripped.
    """
    from bs4 import BeautifulSoup
    from booway_kg_api.WikijsTool import WikijsTool

    html_text = WikijsTool.query_doc_info(input_id)['content']
    # The pattern in the recovered source was truncated to ``]*>``, which
    # deletes every ``>`` and corrupts the markup; the variable name shows
    # the intent was to drop image tags.
    cleaned_img_text = _IMG_TAG_PATTERN.sub('', html_text)

    soup = BeautifulSoup(cleaned_img_text, "html.parser")
    plain_text = soup.get_text()

    return plain_text