Skip to content

Commit 08f5b6c

Browse files
authored
Merge pull request gangiswag#2 from rryisthebest/main
Add Relevance Feedback Implementation + Update readme
2 parents 82cc36e + 47fc076 commit 08f5b6c

12 files changed

+396
-11
lines changed

README.md

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,36 @@ Training and accelerate configs are at `{REPO_DIR}/bash/run_train.sh` and `{REPO
7878

7979
To train the model, run:
8080
```
81-
bash bash/run_train.sh
81+
bash bash/beir/run_train.sh
8282
```
8383

8484
To train a gated model, log in to Hugging Face and get an access token at huggingface.co/settings/tokens.
8585
```
8686
huggingface-cli login
8787
```
88+
## 4. Relevance Feedback
89+
### 4a. Dataset preparation for relevance feedback
90+
To prepare dataset(s) for relevance feedback, run:
91+
```
92+
bash bash/beir/run_prepare_distill.sh <Path of precomputed BEIR encodings>
93+
```
94+
### 4b. Distillation
95+
Distillation config / settings are at `{REPO_DIR}/bash/beir/run_eval.sh`
96+
To perform the distillation step, run:
97+
```
98+
bash bash/beir/run_distill.sh
99+
```
100+
101+
### 4c. 2nd Retrieval
102+
To perform the retrieval step after distillation, run:
103+
```
104+
bash bash/beir/run_2nd_retrieval.sh <Path of precomputed BEIR encodings>
105+
```
106+
107+
### 4d. Relevance feedback evaluation
108+
To get the 2nd Retrieval evaluation, run:
109+
```
110+
bash bash/beir/run_eval.sh rank_refit
111+
```
112+
88113

bash/beir/run_1st_retrieval.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ data_dir="${REPO_DIR}/datasets/beir"
1515
mkdir -p "$output_dir" "$data_dir"
1616

1717
# Datasets to process
18-
datasets=('trec-covid') # 'climate-fever' 'dbpedia-entity' 'fever' 'fiqa' 'hotpotqa' 'msmarco' 'nfcorpus' 'nq' 'scidocs' 'scifact' 'trec-covid'
18+
datasets=('trec-covid') # 'climate-fever' 'dbpedia-entity' 'fever' 'fiqa' 'hotpotqa' 'msmarco' 'nfcorpus' 'nq' 'scidocs' 'scifact'
1919

2020
# Iterate over datasets
2121
for dataset in "${datasets[@]}"; do

bash/beir/run_2nd_retrieval.sh

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#!/bin/bash
# Re-run dense retrieval using the distilled (refit) query embeddings.
#
# Usage: run_2nd_retrieval.sh <input_directory>
#   <input_directory> must contain precomputed BEIR passage encodings at
#   <input_directory>/<dataset>/original_corpus/*.pt
#
# Requires: REPO_DIR set to the repository root, and qry_refit.pt (produced
# by run_distill.sh) in each dataset's output directory.

# Fail early if REPO_DIR is unset instead of writing into "/outputs/beir".
: "${REPO_DIR:?REPO_DIR must be set}"

# Check if input directory is provided
if [ -z "$1" ]; then
    echo "Usage: $0 <input_directory>" >&2
    exit 1
fi

input_dir=$1

# Check if output directory exists
mkdir -p "${REPO_DIR}/outputs/beir"

# List of datasets to process
datasets=('trec-covid' 'dbpedia-entity') #'climate-fever' 'fever' 'hotpotqa' 'msmarco' 'nfcorpus' 'nq' 'fiqa' 'scidocs' 'scifact'

# Process each dataset
for dataset in "${datasets[@]}"; do
    echo "Processing dataset: ${dataset}"

    dataset_output_dir="${REPO_DIR}/outputs/beir/${dataset}"
    mkdir -p "$dataset_output_dir"

    # The *.pt glob is intentionally quoted: the retriever expands it itself.
    if ! python -m tevatron.faiss_retriever \
        --query_reps "${dataset_output_dir}/qry_refit.pt" \
        --passage_reps "${input_dir}/${dataset}/original_corpus/*.pt" \
        --depth 1000 \
        --batch_size -1 \
        --save_text \
        --save_ranking_to "${dataset_output_dir}/rank_refit.tsv"; then
        echo "Error processing dataset: ${dataset}" >&2
        exit 1
    fi

    echo "Finished processing dataset: ${dataset}"
done

echo "All datasets processed successfully."

bash/beir/run_convert_results.sh

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@ data_dir=${REPO_DIR}/datasets/beir/
22
output_dir=${REPO_DIR}/outputs/beir/
33

44
# List of datasets to process
5-
datasets=('trec-covid') # 'climate-fever' 'fever' 'hotpotqa' 'msmarco' 'nfcorpus' 'nq' 'fiqa' 'scidocs' 'scifact' 'dbpedia-entity' 'trec-covid'
5+
datasets=('trec-covid') # 'climate-fever' 'fever' 'hotpotqa' 'msmarco' 'nfcorpus' 'nq' 'fiqa' 'scidocs' 'scifact' 'dbpedia-entity'
66

77
# Iterate over datasets and process each one
88
for datasets in "${datasets[@]}"; do
99
echo "Processing dataset: ${datasets}"
1010

11-
# Execute the conversion script with error handling
1211
if python "${REPO_DIR}/scripts/convert_results.py" \
1312
--dataset "${datasets}" \
1413
--output_dir "${output_dir}" \

bash/beir/run_distill.sh

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#!/bin/bash
# Run the distillation step: refit query embeddings against reranker scores
# for each BEIR dataset, producing qry_refit.pt per dataset.
#
# Requires: REPO_DIR set to the repository root, plus distill_input.pt and
# the rerank_*.json files in each dataset's output directory.

# Fail early if REPO_DIR is unset instead of writing into "/outputs/beir".
: "${REPO_DIR:?REPO_DIR must be set}"

# Check if output directory exists
mkdir -p "${REPO_DIR}/outputs/beir"
output_dir="${REPO_DIR}/outputs/beir/"
data_dir="${REPO_DIR}/datasets/beir/"

# Configuration flags
use_logits=1 # Whether to use FIRST single token logit decoding
use_alpha=1  # Whether to use Alphabetic Identifiers

# List of datasets to process
datasets=('trec-covid' 'dbpedia-entity') #'climate-fever' 'fever' 'hotpotqa' 'msmarco' 'nfcorpus' 'nq' 'fiqa' 'scidocs' 'scifact'

# Process each dataset
for dataset in "${datasets[@]}"; do
    echo "Processing dataset: ${dataset}"

    if ! python "${REPO_DIR}/scripts/distill.py" \
        --inp_path "${output_dir}/${dataset}/distill_input.pt" \
        --rerank_path "${output_dir}/${dataset}" \
        --output_path "${output_dir}/${dataset}/qry_refit.pt" \
        --ce_top_k 100 \
        --llm_top_k 100 \
        --use_logits "${use_logits}" \
        --use_alpha "${use_alpha}" \
        --loss_path "${output_dir}/${dataset}" \
        --llm_loss ranknet; then
        echo "Error processing dataset: ${dataset}" >&2
        exit 1
    fi

    echo "Finished processing dataset: ${dataset}"
done

echo "All datasets processed successfully."

bash/beir/run_eval.sh

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,12 @@ DATA_DIR="${REPO_DIR}/datasets/beir/"
1111
OUTPUT_DIR="${REPO_DIR}/outputs/beir/"
1212

1313
# List of datasets to process
14-
DATASETS=('trec-covid') # 'climate-fever' 'fever' 'hotpotqa' 'msmarco' 'nfcorpus' 'nq' 'fiqa' 'scidocs' 'scifact' 'dbpedia-entity' 'trec-covid'
14+
DATASETS=('trec-covid') # 'climate-fever' 'fever' 'hotpotqa' 'msmarco' 'nfcorpus' 'nq' 'fiqa' 'scidocs' 'scifact' 'dbpedia-entity'
1515

1616
# Iterate over datasets and process each one
1717
for DATASET in "${DATASETS[@]}"; do
1818
echo "Evaluating dataset: ${DATASET}"
1919

20-
# Execute the evaluation script
2120
# suffix: ce -> cross encoder reranker | llm_FIRST_alpha -> FIRST Model
2221
if python "${REPO_DIR}/scripts/eval.py" \
2322
--dataset "${DATASET}" \

bash/beir/run_prepare_distill.sh

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#!/bin/bash
# Prepare per-dataset distillation inputs from precomputed BEIR encodings.
#
# Usage: run_prepare_distill.sh <input_directory>
#   <input_directory> must contain, per dataset:
#     <dataset>/original_corpus/*.pt   (passage embeddings)
#     <dataset>/original_query/qry.pt  (query embeddings)
#
# Requires: REPO_DIR set to the repository root, and rank.tsv from the
# 1st retrieval in each dataset's output directory.

# Fail early if REPO_DIR is unset.
: "${REPO_DIR:?REPO_DIR must be set}"

# Check if input directory is provided
if [ -z "$1" ]; then
    echo "Usage: $0 <input_directory>" >&2
    exit 1
fi

input_dir=$1
output_dir="${REPO_DIR}/outputs/beir/"

# List of datasets to process
datasets=('trec-covid' 'dbpedia-entity') # 'climate-fever' 'fever' 'hotpotqa' 'msmarco' 'nfcorpus' 'nq' 'fiqa' 'scidocs' 'scifact'

# Process each dataset
for dataset in "${datasets[@]}"; do
    echo "Processing dataset: ${dataset}"

    if ! python "${REPO_DIR}/scripts/prepare_distill.py" \
        --output_path "${output_dir}/${dataset}/distill_input.pt" \
        --rank_path "${output_dir}/${dataset}/rank.tsv" \
        --psg_embs_dir "${input_dir}/${dataset}/original_corpus/" \
        --qry_embs_path "${input_dir}/${dataset}/original_query/qry.pt"; then
        echo "Error processing dataset: ${dataset}" >&2
        exit 1
    fi

    echo "Finished processing dataset: ${dataset}"
done

bash/beir/run_rerank_CE.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ DATASETS=('trec-covid') # 'climate-fever' 'fever' 'hotpotqa' 'msmarco' 'nfcorpus
1111
for DATASET in "${DATASETS[@]}"; do
1212
echo "Reranking dataset: ${DATASET}"
1313

14-
# Execute the rerank script with error handling
1514
if python "${REPO_DIR}/scripts/rerank_CE.py" \
1615
--dataset "${DATASET}" \
1716
--output_dir "${OUTPUT_DIR}" \

bash/beir/run_rerank_llm.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ DATASETS=('dbpedia-entity') # 'climate-fever' 'fever' 'hotpotqa' 'msmarco' 'nfco
1616
for DATASET in "${DATASETS[@]}"; do
1717
echo "Reranking dataset: ${DATASET}"
1818

19-
# Execute the rerank script with error handling
2019
if python "${REPO_DIR}/scripts/rerank_llm.py" \
2120
--model "${MODEL_IN_USE}" \
2221
--dataset "${DATASET}" \

scripts/distill.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
import os
import sys
import json
import pickle
from argparse import ArgumentParser
from copy import deepcopy
from itertools import product

import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch import Tensor

from utils.loss import loss_dict
15+
16+
class QueryImpModel(nn.Module):
    """Trainable query embedding distilled against a teacher score distribution.

    The only learnable parameter is the query vector; ``forward`` returns
    log-probabilities over the given passages, suitable for a KL loss.
    """

    def __init__(self, query_rep, scaler):
        super().__init__()
        # Query vector is the sole trainable parameter.
        self.query_rep = nn.Parameter(torch.FloatTensor(query_rep), requires_grad=True)
        self.scaler = scaler

    def forward(self, psg_embs: Tensor, attn_mask: Tensor = None):
        """Return log-softmax of scaled query/passage dot products."""
        similarities = torch.matmul(self.query_rep, psg_embs.transpose(0, 1))
        scores = (self.scaler / 2) * similarities
        if attn_mask is not None:
            # Push masked-out positions toward -inf before the softmax.
            scores = scores + (1.0 - attn_mask) * torch.finfo(scores.dtype).min
        return nn.functional.log_softmax(scores, dim=-1)
29+
30+
31+
class QueryScoreModel(nn.Module):
    """Trainable query embedding distilled against raw teacher scores.

    Unlike QueryImpModel, ``forward`` returns unnormalized scores with a
    leading batch dimension, as expected by the ranking losses.
    """

    def __init__(self, query_rep, scaler=2.0):
        super().__init__()
        # Query vector is the sole trainable parameter.
        self.query_rep = nn.Parameter(torch.FloatTensor(query_rep), requires_grad=True)
        self.scaler = scaler

    def forward(self, psg_embs: Tensor, attn_mask: Tensor = None):
        """Return scaled query/passage dot products, shaped (1, n_passages)."""
        similarities = torch.matmul(self.query_rep, psg_embs.transpose(0, 1))
        scores = (self.scaler / 2) * similarities
        if attn_mask is not None:
            # Push masked-out positions toward -inf.
            scores = scores + (1.0 - attn_mask) * torch.finfo(scores.dtype).min
        return scores.unsqueeze(0)
43+
44+
45+
def load_results(inp_path, rerank_path, ce_top_k, llm_top_k, use_logits, use_alpha):
    """Load pickled distillation examples plus optional CE / LLM rerank scores.

    Rerank files are looked up under ``rerank_path``:
      * LLM: rerank_{llm_top_k}_llm_{FIRST|gen}_{alpha|num}.json
      * CE:  rerank_{ce_top_k}_ce.json
    A top_k of 0 (or less) disables that reranker and yields None for it.
    """
    ce_rerank = None
    llm_rerank = None

    if llm_top_k > 0:
        decode = "FIRST" if use_logits else "gen"
        ident = "alpha" if use_alpha else "num"
        llm_file = os.path.join(rerank_path, f"rerank_{llm_top_k}_llm_{decode}_{ident}.json")
        with open(llm_file) as fh:
            llm_rerank = json.load(fh)

    if ce_top_k > 0:
        with open(os.path.join(rerank_path, f"rerank_{ce_top_k}_ce.json")) as fh:
            ce_rerank = json.load(fh)

    with open(inp_path, "rb") as fh:
        examples = pickle.load(fh)
    return examples, ce_rerank, llm_rerank
60+
61+
62+
def prepare_distill_ce(data, ce_rerank, ce_top_k):
    """Build tensors for distilling cross-encoder scores for one query.

    Returns (passage_reps, target_log_probs, scaler), where ``scaler``
    rescales the raw dot-product scores to the teacher scores' dynamic range.
    """
    qid = data["query_id"]
    top_pids = data["passage_ids"][:ce_top_k]

    # Map every passage id to (a copy of) its embedding.
    emb_by_pid = {pid: deepcopy(emb) for pid, emb in zip(data["passage_ids"], data["passage_embs"])}

    teacher_scores = torch.FloatTensor([ce_rerank[qid][pid] for pid in top_pids])
    target_probs = nn.functional.log_softmax(teacher_scores, dim=-1)

    passage_reps = torch.FloatTensor(np.array([emb_by_pid[pid] for pid in top_pids]))
    query_rep = torch.FloatTensor(data["query_rep"])

    # Match the spread of the initial retrieval scores to the teacher's spread.
    init_scores = torch.matmul(query_rep, passage_reps.transpose(0, 1))
    scaler = (teacher_scores.max() - teacher_scores.min()) / (init_scores.max().item() - init_scores.min().item())

    return passage_reps, target_probs, scaler
80+
81+
82+
def prepare_distill_llm(data, llm_rerank, query_rep, llm_top_k):
    """Build tensors for distilling LLM reranker scores for one query.

    ``query_rep`` is the (possibly CE-refit) query tensor used to estimate the
    scaler. Returns (passage_reps, teacher_scores with batch dim, scaler).
    """
    qid = data["query_id"]
    top_pids = data["passage_ids"][:llm_top_k]

    # Map every passage id to (a copy of) its embedding.
    emb_by_pid = {pid: deepcopy(emb) for pid, emb in zip(data["passage_ids"], data["passage_embs"])}

    teacher_scores = torch.FloatTensor([llm_rerank[qid][pid] for pid in top_pids])
    passage_reps = torch.FloatTensor(np.array([emb_by_pid[pid] for pid in top_pids]))

    # Match the spread of the current retrieval scores to the teacher's spread.
    init_scores = torch.matmul(query_rep, passage_reps.transpose(0, 1))
    scaler = (teacher_scores.max() - teacher_scores.min()) / (init_scores.max().item() - init_scores.min().item())

    return passage_reps, teacher_scores.unsqueeze(0), scaler
98+
99+
100+
def run_query_teacher_importance_learner(inp_path, rerank_path, output_path, loss_path, ce_top_k, llm_top_k, learning_rate,
                                         num_updates, use_logits, use_alpha, llm_loss):
    """Refit each query embedding so its retrieval scores match teacher rerankers.

    Up to two distillation stages run per query:
      1. Cross-encoder (ce_top_k > 0): KL divergence between the query/passage
         log-softmax and the CE score distribution.
      2. LLM reranker (llm_top_k > 0): a ranking loss from ``loss_dict``
         (e.g. "ranknet") against the LLM scores, starting from the CE-refit query.

    The refit reps (numpy array) and their query ids are pickled to
    ``output_path``. ``loss_path`` is accepted for CLI compatibility but is
    currently unused.
    """
    assert llm_loss in loss_dict
    examples, ce_rerank, llm_rerank = load_results(inp_path, rerank_path, ce_top_k, llm_top_k, use_logits, use_alpha)

    reps = []
    ids = []

    for data in tqdm(examples):
        baseline_rep = torch.FloatTensor(data["query_rep"])

        try:
            learned_rep = baseline_rep

            # Stage 1: distill cross-encoder scores via KL on log-probabilities.
            if ce_top_k > 0:
                passage_reps, target_probs, scaler = prepare_distill_ce(data, ce_rerank, ce_top_k)
                ce_dstl_model = QueryImpModel(query_rep=baseline_rep.numpy(), scaler=scaler)
                loss_function = nn.KLDivLoss(reduction="batchmean", log_target=True)
                optimizer = optim.Adam(ce_dstl_model.parameters(), lr=learning_rate)

                for _ in range(num_updates):
                    optimizer.zero_grad()
                    pred_probs = ce_dstl_model(psg_embs=passage_reps)
                    loss = loss_function(pred_probs.unsqueeze(0), target_probs.unsqueeze(0))
                    loss.backward()
                    optimizer.step()

                learned_rep = ce_dstl_model.query_rep.data.cpu().detach()

            # Stage 2: distill LLM reranker scores with a ranking loss.
            # Guarded: previously this ran unconditionally and raised for every
            # query when llm_top_k <= 0 (llm_rerank is None in that case).
            if llm_top_k > 0:
                reranked_passage_reps, reranked_target_scores, scaler = prepare_distill_llm(
                    data, llm_rerank, learned_rep, llm_top_k)
                llm_dstl_model = QueryScoreModel(query_rep=learned_rep.numpy(), scaler=scaler)
                # Smaller step / fewer updates: this stage fine-tunes the CE-refit query.
                optimizer = optim.Adam(llm_dstl_model.parameters(), lr=learning_rate / 5)

                for _ in range(num_updates // 5):
                    optimizer.zero_grad()
                    pred_scores = llm_dstl_model(psg_embs=reranked_passage_reps)
                    loss = loss_dict[llm_loss](pred_scores, reranked_target_scores,
                                               weighted=(llm_loss == "ranknet"))
                    loss.backward()
                    optimizer.step()

                learned_rep = llm_dstl_model.query_rep.data.cpu().detach()

            reps.append(learned_rep.numpy())
            ids.append(data["query_id"])
        except Exception as e:
            # Best-effort: skip queries with missing/corrupt rerank entries,
            # reporting on stderr so logs are not polluted.
            print(f"Error for query ID {data['query_id']}: {e}", file=sys.stderr)

    with open(output_path, "wb") as fh:
        pickle.dump((np.array(reps), ids), fh)
147+
148+
149+
if __name__ == "__main__":
    # CLI entry point: parse distillation options and run the learner.
    arg_parser = ArgumentParser()
    for required_flag in ('--inp_path', '--rerank_path', '--output_path', '--loss_path'):
        arg_parser.add_argument(required_flag, required=True)
    arg_parser.add_argument('--ce_top_k', type=int, default=100)
    arg_parser.add_argument('--llm_top_k', type=int, default=9)
    arg_parser.add_argument('--learning_rate', type=float, default=0.005)
    arg_parser.add_argument('--num_updates', type=int, default=100)
    arg_parser.add_argument('--use_logits', type=int, default=0)
    arg_parser.add_argument('--use_alpha', type=int, default=0)
    arg_parser.add_argument('--llm_loss', type=str, default="lambdarank")

    cli = arg_parser.parse_args()

    run_query_teacher_importance_learner(
        cli.inp_path, cli.rerank_path, cli.output_path, cli.loss_path,
        cli.ce_top_k, cli.llm_top_k, cli.learning_rate, cli.num_updates,
        cli.use_logits, cli.use_alpha, cli.llm_loss)

0 commit comments

Comments
 (0)