Skip to content

Commit 1641edd

Browse files
Merge pull request #171 from gomate-community/pipeline
Pipeline
2 parents 1cee5ac + 54f91ca commit 1641edd

File tree

7 files changed

+455
-10
lines changed

7 files changed

+455
-10
lines changed

docs/quickstart.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
## GoMate快速上手教程
1+
## TrustRAG快速上手教程
22

33
## 🛠️ 安装
44

@@ -22,7 +22,7 @@ pip install gomate
2222
1. 下载源码
2323

2424
```shell
25-
git clone https://github.com/gomate-community/GoMate.git
25+
git clone https://github.com/gomate-community/TrustRAG.git
2626
```
2727

2828
2. 安装依赖

docs/xinference.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
docker run -e XINFERENCE_MODEL_SRC=modelscope -p 9998:9997 --gpus all xprobe/xinference:v<your_version> xinference-local -H 0.0.0.0 --log-level debug
2+
docker run \
3+
-v </your/home/path>/.xinference:/root/.xinference \
4+
-v </your/home/path>/.cache/huggingface:/root/.cache/huggingface \
5+
-v </your/home/path>/.cache/modelscope:/root/.cache/modelscope \
6+
-p 9997:9997 \
7+
--gpus all \
8+
xprobe/xinference:v<your_version> \
9+
xinference-local -H 0.0.0.0
Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
# 步骤一:安装依赖库
2+
from langchain_community.document_loaders import WebBaseLoader
3+
from langchain.text_splitter import RecursiveCharacterTextSplitter
4+
from langchain_community.embeddings import DashScopeEmbeddings
5+
from pymilvus import MilvusClient, DataType, Function, FunctionType
6+
7+
dashscope_api_key = "<YOUR_DASHSCOPE_API_KEY>"
8+
milvus_url = "<YOUR_MMILVUS_URL>"
9+
user_name = "root"
10+
password = "<YOUR_PASSWORD>"
11+
collection_name = "milvus_overview"
12+
dense_dim = 1536
13+
14+
# 步骤二:数据准备
15+
loader = WebBaseLoader([
16+
'https://raw.githubusercontent.com/milvus-io/milvus-docs/refs/heads/v2.5.x/site/en/about/overview.md'
17+
])
18+
19+
docs = loader.load()
20+
21+
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=256)
22+
23+
# 使用LangChain将输入文档安照chunk_size切分
24+
all_splits = text_splitter.split_documents(docs)
25+
26+
embeddings = DashScopeEmbeddings(
27+
model="text-embedding-v2", dashscope_api_key=dashscope_api_key
28+
)
29+
30+
text_contents = [doc.page_content for doc in all_splits]
31+
32+
vectors = embeddings.embed_documents(text_contents)
33+
34+
35+
client = MilvusClient(
36+
uri=f"http://{milvus_url}:19530",
37+
token=f"{user_name}:{password}",
38+
)
39+
40+
schema = MilvusClient.create_schema(
41+
enable_dynamic_field=True,
42+
)
43+
44+
analyzer_params = {
45+
"type": "english"
46+
}
47+
48+
# Add fields to schema
49+
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True, auto_id=True)
50+
schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=65535, enable_analyzer=True, analyzer_params=analyzer_params, enable_match=True)
51+
schema.add_field(field_name="sparse_bm25", datatype=DataType.SPARSE_FLOAT_VECTOR)
52+
schema.add_field(field_name="dense", datatype=DataType.FLOAT_VECTOR, dim=dense_dim)
53+
54+
bm25_function = Function(
55+
name="bm25",
56+
function_type=FunctionType.BM25,
57+
input_field_names=["text"],
58+
output_field_names="sparse_bm25",
59+
)
60+
schema.add_function(bm25_function)
61+
62+
index_params = client.prepare_index_params()
63+
64+
# Add indexes
65+
index_params.add_index(
66+
field_name="dense",
67+
index_name="dense_index",
68+
index_type="IVF_FLAT",
69+
metric_type="IP",
70+
params={"nlist": 128},
71+
)
72+
73+
index_params.add_index(
74+
field_name="sparse_bm25",
75+
index_name="sparse_bm25_index",
76+
index_type="SPARSE_WAND",
77+
metric_type="BM25"
78+
)
79+
80+
# Create collection
81+
client.create_collection(
82+
collection_name=collection_name,
83+
schema=schema,
84+
index_params=index_params
85+
)
86+
87+
data = [
88+
{"dense": vectors[idx], "text": doc}
89+
for idx, doc in enumerate(text_contents)
90+
]
91+
92+
# Insert data
93+
res = client.insert(
94+
collection_name=collection_name,
95+
data=data
96+
)
97+
98+
print(f"生成 {len(vectors)} 个向量,维度:{len(vectors[0])}")
99+
100+
# 同样,在处理中文文档时,Milvus 2.5版本也支持指定相应的中文分析器。
101+
# # 定义分词器参数
102+
# analyzer_params = {
103+
# "type": "chinese" # 指定分词器类型为中文
104+
# }
105+
#
106+
# # 添加文本字段到 Schema,并启用分词器
107+
# schema.add_field(
108+
# field_name="text", # 字段名称
109+
# datatype=DataType.VARCHAR, # 数据类型:字符串(VARCHAR)
110+
# max_length=65535, # 最大长度:65535 字符
111+
# enable_analyzer=True, # 启用分词器
112+
# analyzer_params=analyzer_params # 分词器参数
113+
# )
114+
115+
# 步骤三:全文检索
116+
from pymilvus import MilvusClient
117+
118+
# 创建Milvus Client。
119+
client = MilvusClient(
120+
uri="http://c-xxxx.milvus.aliyuncs.com:19530", # Milvus实例的公网地址。
121+
token="<yourUsername>:<yourPassword>", # 登录Milvus实例的用户名和密码。
122+
db_name="default" # 待连接的数据库名称,本文示例为默认的default。
123+
)
124+
125+
search_params = {
126+
'params': {'drop_ratio_search': 0.2},
127+
}
128+
129+
full_text_search_res = client.search(
130+
collection_name='milvus_overview',
131+
data=['what makes milvus so fast?'],
132+
anns_field='sparse_bm25',
133+
limit=3,
134+
search_params=search_params,
135+
output_fields=["text"],
136+
)
137+
138+
for hits in full_text_search_res:
139+
for hit in hits:
140+
print(hit)
141+
print("\n")
142+
143+
# 步骤四:关键词匹配
144+
# filter = "TEXT_MATCH(text, 'query') and TEXT_MATCH(text, 'node')"
145+
#
146+
# text_match_res = client.search(
147+
# collection_name="milvus_overview",
148+
# anns_field="dense",
149+
# data=query_embeddings,
150+
# filter=filter,
151+
# search_params={"params": {"nprobe": 10}},
152+
# limit=2,
153+
# output_fields=["text"]
154+
# )
155+
156+
# 步骤五:混合检索与RAG
157+
from pymilvus import MilvusClient
158+
from pymilvus import AnnSearchRequest, RRFRanker
159+
from langchain_community.embeddings import DashScopeEmbeddings
160+
from dashscope import Generation
161+
162+
# 创建Milvus Client。
163+
client = MilvusClient(
164+
uri="http://c-xxxx.milvus.aliyuncs.com:19530", # Milvus实例的公网地址。
165+
token="<yourUsername>:<yourPassword>", # 登录Milvus实例的用户名和密码。
166+
db_name="default" # 待连接的数据库名称,本文示例为默认的default。
167+
)
168+
169+
collection_name = "milvus_overview"
170+
171+
# 替换为您的 DashScope API-KEY
172+
dashscope_api_key = "<YOUR_DASHSCOPE_API_KEY>"
173+
174+
# 初始化 Embedding 模型
175+
embeddings = DashScopeEmbeddings(
176+
model="text-embedding-v2", # 使用text-embedding-v2模型。
177+
dashscope_api_key=dashscope_api_key
178+
)
179+
180+
# Define the query
181+
query = "Why does Milvus run so scalable?"
182+
183+
# Embed the query and generate the corresponding vector representation
184+
query_embeddings = embeddings.embed_documents([query])
185+
186+
# Set the top K result count
187+
top_k = 5 # Get the top 5 docs related to the query
188+
189+
# Define the parameters for the dense vector search
190+
search_params_dense = {
191+
"metric_type": "IP",
192+
"params": {"nprobe": 2}
193+
}
194+
195+
# Create a dense vector search request
196+
request_dense = AnnSearchRequest([query_embeddings[0]], "dense", search_params_dense, limit=top_k)
197+
198+
# Define the parameters for the BM25 text search
199+
search_params_bm25 = {
200+
"metric_type": "BM25"
201+
}
202+
203+
# Create a BM25 text search request
204+
request_bm25 = AnnSearchRequest([query], "sparse_bm25", search_params_bm25, limit=top_k)
205+
206+
# Combine the two requests
207+
reqs = [request_dense, request_bm25]
208+
209+
# Initialize the RRF ranking algorithm
210+
ranker = RRFRanker(100)
211+
212+
# Perform the hybrid search
213+
hybrid_search_res = client.hybrid_search(
214+
collection_name=collection_name,
215+
reqs=reqs,
216+
ranker=ranker,
217+
limit=top_k,
218+
output_fields=["text"]
219+
)
220+
221+
# Extract the context from hybrid search results
222+
context = []
223+
print("Top K Results:")
224+
for hits in hybrid_search_res: # Use the correct variable here
225+
for hit in hits:
226+
context.append(hit['entity']['text']) # Extract text content to the context list
227+
print(hit['entity']['text']) # Output each retrieved document
228+
229+
230+
# Define a function to get an answer based on the query and context
231+
def getAnswer(query, context):
232+
prompt = f'''Please answer my question based on the content within:
233+
```
234+
{context}
235+
```
236+
My question is: {query}.
237+
'''
238+
# Call the generation module to get an answer
239+
rsp = Generation.call(model='qwen-turbo', prompt=prompt)
240+
return rsp.output.text
241+
242+
# Get the answer
243+
answer = getAnswer(query, context)
244+
245+
print(answer)
246+
247+
248+
# Expected output excerpt
249+
"""
250+
Milvus is highly scalable due to its cloud-native and highly decoupled system architecture. This architecture allows the system to continuously expand as data grows. Additionally, Milvus supports three deployment modes that cover a wide...
251+
"""
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import openai
2+
3+
client = openai.Client(
4+
api_key="api-key",
5+
base_url="http://localhost:9997/v1"
6+
)
7+
response=client.embeddings.create(
8+
model="bge-m3",
9+
input=["What is the capital of China?"]
10+
)
11+
print(type(response.data[0].embedding),len(response.data[0].embedding),response.data[0].embedding,)
12+
# <class 'list'> 1024 [-0.031030284240841866, ]
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import requests
2+
import traceback
3+
from random import randrange
4+
5+
6+
def extract(row: dict,
7+
api_base: str = "http://10.208.62.156:6200/api/file/_extract",
8+
name_key: str = "name",
9+
data_key: str = "data",
10+
md_key: str = 'md',
11+
image_key: str = 'images',
12+
method: str = "auto",
13+
response_content: str = "markdown",
14+
**kwargs):
15+
"""
16+
基于MinerU服务(封装)抽取文件(支持pdf/word等),按指定格式返回(默认markdown)
17+
:param row 待处理的dict记录
18+
:param api_base 自封装的MinerU服务地址
19+
:param name_key 待抽取的文件的名称字段,默认为`name`
20+
:param data_key 待抽取的文件内容(bytes)或文件名
21+
:param md_key 输出的markdown字段名 默认`md`
22+
:param image_key 输出的图片字段名 默认`images`
23+
:param method 抽取的方法,支持text/ocr/auto,默认为auto,表示自动识别
24+
:param response_content 返回内容类型,支持markdown/json,默认为markdown
25+
"""
26+
if isinstance(api_base, list):
27+
api_base = api_base[randrange(len(api_base))]
28+
29+
content = row[data_key]
30+
assert isinstance(content, bytes) or isinstance(content, str), f"content field `{data_key}`must be bytes or str"
31+
32+
filename = row.get(name_key, 'auto_file')
33+
34+
if isinstance(content, bytes):
35+
files = {'file': (filename, content)}
36+
else:
37+
with open(content, 'rb') as reader:
38+
files = {'file': (filename, reader.read())}
39+
40+
data = {
41+
'method': method,
42+
'response_content': response_content
43+
}
44+
45+
try:
46+
response = requests.post(api_base, files=files, data=data)
47+
response_data = response.json()
48+
if 'data' in response_data:
49+
data = response_data['data']
50+
if isinstance(data, dict) and 'extract_data' in data:
51+
row[md_key] = data['extract_data']
52+
return row
53+
error = response.text
54+
print('ERROR', filename, error, api_base)
55+
row['ERROR'] = error
56+
except:
57+
print('ERROR', filename)
58+
traceback.print_exc()
59+
60+
return row
61+
62+
63+
if __name__ == '__main__':
64+
content = extract({"data": "../../../data/paper/16400599.pdf", "name": "16400599.pdf"},
65+
api_base="http://10.208.62.156:6201/api/file/_extract")
66+
print(content)

0 commit comments

Comments
 (0)