上传文件至 /
This commit is contained in:
+126
@@ -0,0 +1,126 @@
|
|||||||
|
import os
|
||||||
|
from langchain_community.vectorstores import FAISS
|
||||||
|
# from langchain_huggingface import HuggingFaceEmbeddings
|
||||||
|
|
||||||
|
# embedding_path = "/data/Z/Z_llm_dm/vector_data/bge-m3"
|
||||||
|
# embeddings = HuggingFaceEmbeddings(model_name=embedding_path)
|
||||||
|
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
import requests
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
|
||||||
|
|
||||||
|
class SiliconFlowEmbeddings(Embeddings):
|
||||||
|
def __init__(self, api_key: str, model: str = "bge-m3"):
|
||||||
|
self.api_key = api_key
|
||||||
|
self.model = model
|
||||||
|
self.url = "http://10.1.16.39:9995/v1/embeddings"
|
||||||
|
self.headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
def _embed(self, input: List[str]) -> List[List[float]]:
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"input": input,
|
||||||
|
"encoding_format": "float"
|
||||||
|
}
|
||||||
|
response = requests.post(self.url, json=payload, headers=self.headers)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
return [item["embedding"] for item in data["data"]]
|
||||||
|
|
||||||
|
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
|
return self._embed(texts)
|
||||||
|
|
||||||
|
def embed_query(self, text: str) -> List[float]:
|
||||||
|
return self._embed([text])[0]
|
||||||
|
|
||||||
|
embeddings = SiliconFlowEmbeddings(api_key="sk-ftnofbucchwnscojohyxwmfzgaykdxihafnlphohsinftkbr")
|
||||||
|
|
||||||
|
def Mixed_retrieval(input_path):
|
||||||
|
file_name = os.path.splitext(os.path.basename(input_path))[0]
|
||||||
|
faiss_archived = f"./faiss_data/{file_name}"
|
||||||
|
|
||||||
|
txt_list = []
|
||||||
|
with open(input_path, 'r', encoding='utf-8') as file:
|
||||||
|
txt_list = [line.strip() for line in file]
|
||||||
|
vectorstore_txt_faiss = FAISS.from_texts(txt_list, embeddings)
|
||||||
|
vectorstore_txt_faiss.save_local(faiss_archived)
|
||||||
|
|
||||||
|
# vectorstore_txt_faiss = FAISS.load_local(vectorstore_txt_faiss,
|
||||||
|
# embeddings=embeddings,
|
||||||
|
# allow_dangerous_deserialization=True)
|
||||||
|
|
||||||
|
retriever_txt_faiss1 = vectorstore_txt_faiss.as_retriever(search_kwargs={"k": 5})
|
||||||
|
retriever_txt_faiss2 = vectorstore_txt_faiss.as_retriever(
|
||||||
|
search_type="mmr",
|
||||||
|
search_kwargs={"k": 5, # 检索结果
|
||||||
|
"fetch_k": 2, # 候选结果数量
|
||||||
|
"lambda_mult": 0.1} # 平衡指数,1为相关性;0为多样性
|
||||||
|
)
|
||||||
|
retriever_txt_faiss3 = vectorstore_txt_faiss.as_retriever(
|
||||||
|
search_type="similarity_score_threshold",
|
||||||
|
search_kwargs={"score_threshold": 0.3}
|
||||||
|
)
|
||||||
|
|
||||||
|
return retriever_txt_faiss1, retriever_txt_faiss2, retriever_txt_faiss3
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def interface_search(input_str, retriever_txt_faiss1, retriever_txt_faiss2, retriever_txt_faiss3):
|
||||||
|
index_keyword1 = []
|
||||||
|
for i in retriever_txt_faiss1.invoke(input_str):
|
||||||
|
index_keyword1.append(i.page_content)
|
||||||
|
index_keyword2 = []
|
||||||
|
for i in retriever_txt_faiss2.invoke(input_str):
|
||||||
|
index_keyword2.append(i.page_content)
|
||||||
|
index_keyword3 = []
|
||||||
|
for i in retriever_txt_faiss3.invoke(input_str):
|
||||||
|
index_keyword3.append(i.page_content)
|
||||||
|
|
||||||
|
return list(set(index_keyword1) & set(index_keyword2) & set(index_keyword3))
|
||||||
|
|
||||||
|
|
||||||
|
def Building_search_dictionary(input_csv_path1, input_csv_path2, index_keyword):
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
df1 = pd.read_csv(input_csv_path1, encoding='utf-8')
|
||||||
|
df2 = pd.read_csv(input_csv_path2, encoding='utf-8', names=['path', 'id'])
|
||||||
|
#df2 = pd.read_csv(input_csv_path2, encoding='utf-8')
|
||||||
|
|
||||||
|
matching_path = df1.loc[df1['name'] == index_keyword, 'index']
|
||||||
|
|
||||||
|
# print(matching_path)
|
||||||
|
|
||||||
|
# print(matching_path.tolist()[0] )
|
||||||
|
|
||||||
|
# todo: bug修改: 避免matching_path和matching_ids没有映射
|
||||||
|
if matching_path.empty:
|
||||||
|
return(None, None)
|
||||||
|
else:
|
||||||
|
matching_ids = df2.loc[df2['path'] == matching_path.tolist()[0], 'id']
|
||||||
|
|
||||||
|
# print(matching_ids)
|
||||||
|
if matching_ids.empty:
|
||||||
|
return (matching_path.tolist()[0], None)
|
||||||
|
else:
|
||||||
|
return (matching_path.tolist()[0], int(matching_ids.values[0]))
|
||||||
|
|
||||||
|
|
||||||
|
def Official_website_kg_search(input_id):
|
||||||
|
# info = WikijsTool.get_all_documents()
|
||||||
|
|
||||||
|
import re
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
from booway_kg_api.WikijsTool import WikijsTool
|
||||||
|
|
||||||
|
html_text = WikijsTool.query_doc_info(input_id)['content']
|
||||||
|
cleaned_img_text = re.sub(r'<img\s+[^>]*>', '', html_text)
|
||||||
|
|
||||||
|
soup = BeautifulSoup(cleaned_img_text, "html.parser")
|
||||||
|
plain_text = soup.get_text()
|
||||||
|
|
||||||
|
return plain_text
|
||||||
Reference in New Issue
Block a user