-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathreranking.ts
231 lines (203 loc) · 8.54 KB
/
reranking.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
import { RerankingOptions, SearchResult } from '../types';
import * as distanceMetrics from '../utils/distance_metrics'; // Import distance functions
/**
 * A utility class for reranking search results using various strategies.
 *
 * SearchReranker provides several ways to reorder search results to improve
 * diversity, relevance, or other custom criteria. Its algorithms refine the
 * initial results beyond simple distance/similarity sorting, which can improve
 * result relevance and user experience by taking factors such as diversity or
 * weighted metadata attributes into account.
 *
 * Supports three reranking strategies:
 * - `standard`: Preserves the original ranking, optionally limiting to the top k results
 * - `diversity`: Implements Maximal Marginal Relevance (MMR) to balance relevance and diversity
 * - `weighted`: Adjusts the ranking based on weighted metadata attributes
 *
 * @class SearchReranker
 *
 * @example
 * ```typescript
 * const reranker = new SearchReranker();
 *
 * // Standard reranking (limit to top 5)
 * const topResults = reranker.rerank(initialResults, { method: 'standard', k: 5 });
 *
 * // Diversity reranking
 * const diverseResults = reranker.rerank(initialResults, {
 *   method: 'diversity',
 *   queryVector: query,
 *   vectorsMap: vectors,
 *   lambda: 0.7
 * });
 *
 * // Weighted reranking based on metadata
 * const weightedResults = reranker.rerank(initialResults, {
 *   method: 'weighted',
 *   metadataMap: metadata,
 *   weights: { recency: 0.3, popularity: 0.5 }
 * });
 * ```
 */
export class SearchReranker {
  /**
   * Rerank search results using the specified method.
   *
   * This is the public entry point. It checks that `results` is an array and
   * then dispatches on `options.method`; a missing or unrecognised method
   * falls back to `'standard'`. The input array itself is never mutated.
   *
   * @param results The initial list of search results, typically sorted by distance/similarity.
   * @param options Configuration for the reranking process, including the method to use.
   * @returns A new list of reranked search results (empty when `results` is not an array).
   */
  public rerank(results: SearchResult[], options: RerankingOptions = {}): SearchResult[] {
    const { method = 'standard' } = options;

    // Guard against untyped (plain JS) callers handing over a non-array.
    if (!Array.isArray(results)) {
      console.error('Reranker received invalid input: results is not an array.');
      return [];
    }

    switch (method) {
      case 'diversity':
        console.log('Dispatching to diversity reranking...'); // Debug log
        return this._diversityReranking(results, options);
      case 'weighted':
        console.log('Dispatching to weighted reranking...'); // Debug log
        return this._weightedReranking(results, options);
      case 'standard':
      default:
        console.log('Dispatching to standard reranking (default)...'); // Debug log
        return this._standardReranking(results, options);
    }
  }

  /**
   * Convert the optional `k` option into a safe result count.
   *
   * `undefined` means "no limit". Zero, negative and `NaN` values yield 0 so
   * that every strategy returns an empty list for them (a raw negative `k`
   * passed to `Array.prototype.slice` would instead drop items from the end).
   * Fractional values are floored and the count never exceeds `total`.
   *
   * @param k Requested maximum number of results, if any.
   * @param total Number of results available.
   * @returns An integer in the range `[0, total]`.
   */
  private _resolveLimit(k: number | undefined, total: number): number {
    if (k === undefined) {
      return total;
    }
    // `!(k > 0)` is also true for NaN, which a plain `k <= 0` would miss.
    if (!(k > 0)) {
      return 0;
    }
    return Math.min(Math.floor(k), total);
  }

  /**
   * Basic reranking: keeps the incoming order and only applies the `k` limit.
   *
   * @param results Results in their original order.
   * @param options Only `k` is read.
   * @returns A shallow copy of the first `k` results.
   */
  private _standardReranking(results: SearchResult[], options: RerankingOptions): SearchResult[] {
    const limit = this._resolveLimit(options.k, results.length);
    return results.slice(0, limit);
  }

  /**
   * Diversity-based reranking using a Maximal Marginal Relevance (MMR) style
   * greedy selection.
   *
   * The result with the smallest `dist` is picked first. Each following pick
   * maximises `lambda * relevance + (1 - lambda) * diversity`, where
   * `relevance = 1 / (1 + dist)` and `diversity` is the candidate's smallest
   * distance (under `distanceMetric`) to any already selected result.
   *
   * Results whose id has no entry in `vectorsMap` are left out of the output
   * (a warning is logged for each). When `queryVector` or `vectorsMap` is
   * missing, or there are fewer than two results, the original order capped
   * at `k` is returned instead. `queryVector` only acts as a precondition
   * here: relevance is derived from each result's existing `dist`.
   *
   * @param initialResults Results carrying `id` and `dist`.
   * @param options Reads `k`, `queryVector`, `lambda` (default 0.7, clamped to
   *   [0, 1]), `vectorsMap` and `distanceMetric` (default `'euclidean'`).
   * @returns At most `k` results in MMR selection order.
   */
  private _diversityReranking(initialResults: SearchResult[], options: RerankingOptions): SearchResult[] {
    const { k, queryVector, lambda = 0.7, vectorsMap, distanceMetric = 'euclidean' } = options;
    const limit = this._resolveLimit(k, initialResults.length);

    // --- Input validation ---
    if (!queryVector || !vectorsMap || vectorsMap.size === 0 || initialResults.length <= 1) {
      console.warn('Diversity reranking skipped: Missing queryVector, vectorsMap, or insufficient results.');
      return initialResults.slice(0, limit);
    }
    // Nothing was requested: return before the first pick is made, otherwise
    // one result would be returned even for k = 0.
    if (limit === 0) {
      return [];
    }
    // A NaN lambda would turn every score into NaN and stall the selection;
    // values outside [0, 1] would invert one of the two terms.
    const balance = Number.isFinite(lambda) ? Math.min(1, Math.max(0, lambda)) : 0.7;

    // --- Setup ---
    const distanceFunc = distanceMetrics.getDistanceFunction(distanceMetric);

    type Candidate = { result: SearchResult; vector: Float32Array; nearestSelected: number };
    const candidates = new Map<number | string, Candidate>();
    for (const result of initialResults) {
      const vector = vectorsMap.get(result.id);
      if (vector) {
        candidates.set(result.id, { result, vector, nearestSelected: Infinity });
      } else {
        console.warn(`Vector for result ID ${result.id} not found in vectorsMap. Skipping for diversity rerank.`);
      }
    }
    if (candidates.size === 0) {
      console.warn('No results with available vectors for diversity reranking.');
      return initialResults.slice(0, limit);
    }

    // --- MMR selection ---
    // 1. Seed with the most relevant result (smallest distance; the earliest
    //    one wins ties because the comparison is strict).
    let seed: Candidate | null = null;
    let smallestDist = Infinity;
    for (const candidate of candidates.values()) {
      if (candidate.result.dist < smallestDist) {
        smallestDist = candidate.result.dist;
        seed = candidate;
      }
    }
    if (!seed) {
      console.error('Could not determine the first result for MMR.');
      return initialResults.slice(0, limit);
    }

    const selected: SearchResult[] = [seed.result];
    candidates.delete(seed.result.id);
    let lastSelectedVector = seed.vector;

    // 2. Greedily add the candidate with the best combined score.
    while (selected.length < limit && candidates.size > 0) {
      let best: Candidate | null = null;
      let bestScore = -Infinity;

      for (const candidate of candidates.values()) {
        // Only the most recent pick can lower a candidate's minimum distance
        // to the selected set, so fold it into the running minimum instead of
        // re-measuring against every selected result on each round.
        candidate.nearestSelected = Math.min(
          candidate.nearestSelected,
          distanceFunc(candidate.vector, lastSelectedVector),
        );

        // Similarity proxy derived from the original search distance.
        const relevance = 1.0 / (1.0 + candidate.result.dist);
        // NOTE(review): relevance lies in (0, 1] while the diversity term is a
        // raw distance, so with unnormalised vectors diversity can outweigh
        // relevance regardless of lambda. Left as-is to keep existing rankings.
        const score = balance * relevance + (1 - balance) * candidate.nearestSelected;

        if (score > bestScore) {
          bestScore = score;
          best = candidate;
        }
      }

      if (best === null) {
        // Every remaining score was NaN (e.g. the distance function failed).
        console.warn('MMR iteration finished without selecting a candidate.');
        break;
      }
      selected.push(best.result);
      candidates.delete(best.result.id);
      lastSelectedVector = best.vector;
    }

    return selected;
  }

  /**
   * Weighted reranking based on numeric metadata attributes.
   *
   * Each result starts from its `dist`; for every weighted key that holds a
   * finite number in the result's metadata, `value * weight` is subtracted.
   * Results are then sorted ascending by that score (lower ranks first).
   * Non-finite weights and metadata values are ignored, because a NaN score
   * would make the sort comparator inconsistent.
   *
   * When `metadataMap` or `weights` is missing or empty the original order,
   * capped at `k`, is returned.
   *
   * @param results Results carrying `id` and `dist`.
   * @param options Reads `k`, `weights` and `metadataMap`.
   * @returns At most `k` shallow copies of the results, each with an added
   *   `weightedScore` property.
   */
  private _weightedReranking(results: SearchResult[], options: RerankingOptions): SearchResult[] {
    const { k, weights = {}, metadataMap } = options;
    const limit = this._resolveLimit(k, results.length);

    if (!metadataMap || metadataMap.size === 0 || Object.keys(weights).length === 0) {
      console.warn('Weighted reranking skipped: Missing metadataMap or weights.');
      return results.slice(0, limit); // Apply k limit even if not reranking
    }

    // Filter the weights once rather than re-checking them for every result.
    const usableWeights = Object.entries(weights).filter(([, weight]) => Number.isFinite(weight));

    const scored = results.map((result) => {
      const itemMetadata = metadataMap.get(result.id) || {};
      let weightedScore = result.dist; // Start with original distance
      for (const [key, weight] of usableWeights) {
        if (key in itemMetadata && typeof itemMetadata[key] === 'number') {
          const value = itemMetadata[key] as number;
          if (Number.isFinite(value)) {
            weightedScore -= value * weight;
          }
        }
      }
      return { ...result, weightedScore };
    });

    // `scored` is a fresh array, so the in-place sort leaves the input alone.
    return scored.sort((a, b) => a.weightedScore - b.weightedScore).slice(0, limit);
  }
}
export default SearchReranker;