forked from haonan-yuan/RAG-GFM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsearch_motif_db.py
More file actions
82 lines (69 loc) · 3.77 KB
/
search_motif_db.py
File metadata and controls
82 lines (69 loc) · 3.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import torch
import numpy as np
import logging
from typing import List, Dict, Optional, Union
from torch_geometric.data import Data
from nano_vectordb import NanoVectorDB
from train_motif_finder import SubgraphEncoder
logging.getLogger('nano-vectordb').setLevel(logging.ERROR)
class MotifRetriever:
    """Retrieve the most similar motifs for a query subgraph from a per-domain vector DB.

    For each domain, lazily loads a trained ``SubgraphEncoder`` plus a
    ``NanoVectorDB`` from ``motif_lib_path/<domain>/`` and caches them, then
    ranks stored motif vectors against the encoded query by cosine similarity.
    """

    def __init__(self, motif_lib_path: str, motif_db_path: str, device: Optional[str] = None):
        """
        Args:
            motif_lib_path: Root directory with one subdirectory per domain,
                each containing ``encoder.pth``, ``config.pth`` and
                ``motif_vectordb.json``.
            motif_db_path: Motif DB path (stored on the instance; not read
                directly by this class — presumably used by callers).
            device: Torch device string (e.g. ``'cpu'``, ``'cuda:0'``).
                Defaults to CUDA when available, else CPU.
        """
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.motif_lib_path = motif_lib_path
        self.motif_db_path = motif_db_path
        # Per-domain caches so repeated searches do not reload from disk.
        self.loaded_encoders = {}
        self.loaded_dbs = {}

    def _load_assets(self, domain_name: str) -> None:
        """Load and cache the encoder and vector DB for ``domain_name`` (idempotent).

        Raises:
            FileNotFoundError: if the domain's encoder, database, or config
                file is missing.
        """
        if domain_name in self.loaded_encoders and domain_name in self.loaded_dbs:
            return
        domain_dir = os.path.join(self.motif_lib_path, domain_name)
        encoder_path = os.path.join(domain_dir, 'encoder.pth')
        db_path = os.path.join(domain_dir, 'motif_vectordb.json')
        config_path = os.path.join(domain_dir, 'config.pth')
        if not os.path.exists(encoder_path):
            raise FileNotFoundError(f"encoder file not found: {encoder_path}")
        if not os.path.exists(db_path):
            raise FileNotFoundError(f"database file not found: {db_path}")
        # Fix: the original loaded the config without an existence check, so a
        # missing config surfaced as an opaque torch.load error instead of the
        # clear FileNotFoundError used for the other two assets.
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"config file not found: {config_path}")
        config = torch.load(config_path, map_location=self.device)
        encoder = SubgraphEncoder(
            input_dim=config['struct_input_dim'],
            hidden_dim=config['hidden_dim'],
            output_dim=config['output_dim'],
        )
        encoder.load_state_dict(torch.load(encoder_path, map_location=self.device))
        encoder.to(self.device)
        encoder.eval()  # inference only: freeze dropout/batch-norm behavior
        self.loaded_encoders[domain_name] = encoder
        self.loaded_dbs[domain_name] = NanoVectorDB(config['output_dim'], storage_file=db_path)

    @torch.no_grad()
    def search(self, query_subgraph: Data, search_domain: str, k: int = 1) -> List[Dict]:
        """Return the top-``k`` stored motifs most similar to ``query_subgraph``.

        Args:
            query_subgraph: A single PyG graph; a zero batch vector is attached
                so the encoder pools it as one graph.
            search_domain: Domain whose encoder/DB to use.
            k: Number of results to return (capped at the DB size).

        Returns:
            List of dicts sorted by decreasing similarity, each with
            ``'__vector__'`` (stored vector), ``'metadata'`` and
            ``'__metrics__'`` (cosine distance, ``1 - similarity``).
            Empty list when the DB is empty or the query embedding is zero.

        Raises:
            AttributeError: if NanoVectorDB's internal storage layout changed
                and can no longer be accessed.
        """
        self._load_assets(search_domain)
        encoder = self.loaded_encoders[search_domain]
        db = self.loaded_dbs[search_domain]
        query_subgraph.batch = torch.zeros(query_subgraph.num_nodes, dtype=torch.long, device=self.device)
        query_vector_torch = encoder(query_subgraph.to(self.device))
        query_vector_np = query_vector_torch.cpu().numpy().flatten()
        try:
            # NOTE(review): reaches into NanoVectorDB's name-mangled private
            # storage — brittle against library upgrades; no public bulk-read
            # API appears to be used here.
            candidate_docs = db._NanoVectorDB__storage['data']
            candidate_vectors = db._NanoVectorDB__storage['matrix']
        except (AttributeError, KeyError) as e:
            # Fix: chain the original exception so the root cause is visible
            # (the original `raise` discarded `e`).
            raise AttributeError(f"cannot access internal storage of '{search_domain}' MotifDB.") from e
        if len(candidate_vectors) == 0:
            return []
        query_norm = np.linalg.norm(query_vector_np)
        if query_norm == 0:
            return []  # zero embedding: cosine similarity undefined
        vectors_norm = np.linalg.norm(candidate_vectors, axis=1)
        vectors_norm[vectors_norm == 0] = 1e-9  # avoid division by zero for degenerate stored vectors
        similarities = np.dot(candidate_vectors, query_vector_np) / (vectors_norm * query_norm)
        actual_k = min(k, len(similarities))
        if actual_k == 0:
            return []
        # argpartition is O(n); only the selected k entries are fully sorted.
        top_k_indices = np.argpartition(similarities, -actual_k)[-actual_k:]
        top_k_sorted_indices = top_k_indices[np.argsort(similarities[top_k_indices])][::-1]
        results = []
        for idx in top_k_sorted_indices:
            doc_metadata = dict(candidate_docs[idx])
            doc_vector = candidate_vectors[idx]
            doc_similarity = similarities[idx]
            complete_doc = {
                '__vector__': doc_vector,
                'metadata': doc_metadata.get('metadata', {}),
                '__metrics__': 1 - doc_similarity
            }
            results.append(complete_doc)
        return results