
Commit 386a9fe

add code for generating critic scores
1 parent 59ad35a commit 386a9fe

File tree

10 files changed (+157, -35 lines)


README.md

Lines changed: 18 additions & 2 deletions
@@ -29,6 +29,7 @@ Authors:
 * [x] [Running Unit Tests](#running-unit-tests)
 * [x] [Evaluating Programs](#evaluating-programs)
 * [x] [Training Critic](#training-critic)
+* [x] [Generating Critic Scores](#generating-critic-scores)
 * [ ] [Generating Programs with Critic Sampling](#generating-programs-with-critic-sampling)
 * [x] [Example Generated Programs](#example-generated-programs)
 * [x] [Citation](#citation)
@@ -125,7 +126,8 @@ We created `scripts/generate.sh` to generate programs on the APPS benchmark. You
 | `end` | end index of test samples to be generated | 5000 |
 | `num_seqs` | number of total output programs to be generated (for sampling generation) | 1000 |
 | `num_seqs_per_iter` | Depending on the GPU memory limit, we can generate multiple rounds, each with this number of output programs | 50 |
-| `temp` | temperature for sampling generation | 0.6 ||
+| `temp` | temperature for sampling generation | 0.6 |
+| `output_path` | Path to save generated programs | outputs/codes/ |
 
 Other parameters are defined in the file `utils/generate_configs.py`.
 
@@ -162,7 +164,7 @@ To compute the pass@k metrics, rather than using the APPS evaluation metrics, we
 
 ### Training Critic
 
-We can train a critic model as a classifier that predicts the test outcomes of generated samples. For each training sample, we can follow the prior processes to generate programs and evaluate them with available unit tests. On average, we generate 20 programs per training sample (we provided some example generated programs in `data/APPS/train/`).
+We can train a critic model as a classifier that predicts the test outcomes of generated samples. For each training sample, we can follow the prior processes ([generating programs](#generating-programs) and [running unit tests](#running-unit-tests)) to obtain synthetic samples and their annotations of unit test outcomes. On average, we generate 20 programs per training sample (we provide some example generated programs in `data/APPS/train/`).
 
 Once the programs are tested, we can use their test outcomes as annotations to train a critic model initialized from an LM pretrained on source code data (we used a CodeT5-based model in this case).
 
@@ -185,6 +187,20 @@ Other parameters are defined in the file `utils/train_critic_configs.py`.
 
 Running the script will train a critic model as a classifier that receives a problem description + a generated program as input and returns one of 4 test outcomes: compile error, runtime error, failed tests, and passed tests. The model checkpoints are saved in a folder under `exps/`.
 
+### Generating Critic Scores
+
+We created `scripts/generate_critic_scores.sh` to generate critic scores for synthetic programs. We use the same parameters as defined in [the program generation process](#generating-programs), with the following additional parameters:
+
+| **Parameters** | **Description** | **Example Values** |
+|:-----------------:|:--------------------------------------------------------------------------------------------------------:|:------------------------------:|
+| `critic_scores` | Enable this to run inference on critic models and obtain critic scores | N/A |
+| `gt_solutions` | Enable this to run inference on ground-truth programs; else, synthetic programs are used by default | N/A |
+
+Other parameters are defined in the file `utils/generate_configs.py`.
+
+Running the generation script saves one `pkl` (pickle) file per problem, containing the data fields `code` (list of programs), `prompt` (constructed input sequence to the critic model), `gt_error_type` (ground-truth test outcomes), `pred_error_type` (test outcomes predicted by the critic), and `error_hidden_states` (hidden states returned by the critic).
+
 ### Generating Programs with Critic Sampling
 
 We will release the implementation details of our critic sampling procedure.
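To make the saved `pkl` format from the "Generating Critic Scores" section concrete, the sketch below loads one of the example files committed under `outputs/critic_scores/` and inspects its fields. This is illustrative only and not part of the commit; the mapping from class indices to the four test outcomes is defined by the repo's `datasets/utils.py` (`get_error_type`).

```python
import pickle as pkl
import numpy as np

# Load one of the committed example outputs (synthetic programs for problem 1).
with open("outputs/critic_scores/1_gtFalse.pkl", "rb") as f:
    scores = pkl.load(f)                          # {problem_id: {field_name: value}}

problem_id, fields = next(iter(scores.items()))
print(problem_id, sorted(fields))                 # fields: code, error_hidden_states, gt_error_type, pred_error_type, prompt

preds = np.asarray(fields['pred_error_type'])     # test outcome predicted by the critic, one per program
gts = np.asarray(fields['gt_error_type'])         # ground-truth test outcome, one per program
print("programs scored:", len(preds))
print("critic accuracy on this problem:", float((preds == gts).mean()))
print("hidden states shape:", np.asarray(fields['error_hidden_states']).shape)
```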

configs/generate_configs.py

Lines changed: 2 additions & 0 deletions
@@ -13,12 +13,14 @@
 parser.add_argument("--output_path", type=str, help='Path to save output programs')
 parser.add_argument("--model_path", type=str, help='Path of trained model')
 parser.add_argument("--tokenizer_path", type=str, help='Path to the tokenizer')
+parser.add_argument("--critic_scores", default=False, action='store_true', help='if model is a critic model, enable this to output critic scores')
 
 parser.add_argument("--num_seqs", default=5, type=int, help='Number of total generated programs per test sample')
 parser.add_argument('--num_seqs_per_iter', default=5, type=int, help='Number of possible minibatch to generate programs per iteration, depending on GPU memory')
 
 parser.add_argument("--max_len", default=512, type=int, help='Maximum length of output sequence')
 parser.add_argument('--source_len', default=600, type=int, help='Maximum length of input sequence')
+parser.add_argument('--gt_solutions', default=False, action='store_true', help='Only when critic is used, enable this to estimate returns/rewards for ground-truth programs, else synthetic programs by default')
 
 parser.add_argument("--temperature", default=0.6, type=float, help='temperature for sampling tokens')
 parser.add_argument("-s","--start", default=0, type=int, help='start index of test samples')

generate.py

Lines changed: 116 additions & 30 deletions
@@ -1,18 +1,21 @@
 #
-# '''
 # Copyright (c) 2022, salesforce.com, inc.
 # All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
-# '''#
+#
 import json
 import os
 import pprint
 import torch
 import pdb
 import glob
 from tqdm import tqdm
+import pickle as pkl
+import numpy as np
+from collections import Counter
 from transformers import RobertaTokenizer, T5ForConditionalGeneration
+import datasets.utils as dsutils
 
 def generate_prompt(args, test_case_path, prompt_path, solutions_path, tokenizer,
                     starter_path=None):
@@ -46,6 +49,42 @@ def generate_prompt(args, test_case_path, prompt_path, solutions_path, tokenizer
 
     return _input
 
+def generate_critic_inputs(args, test_case_path, prompt_path, solutions_path, tokenizer,
+                           starter_path=None, gt_solutions=False):
+    _input = generate_prompt(args, test_case_path, prompt_path, solutions_path, tokenizer, starter_path)
+
+    q_tokens = tokenizer.encode(_input, verbose=False, max_length=args.source_len)
+    in_tokens = [tokenizer.eos_token_id] * args.source_len
+    in_tokens[:len(q_tokens)] = q_tokens
+    in_tokens = in_tokens[:args.source_len]
+
+    solutions = json.load(open(solutions_path, 'r'))
+
+    all_texts = []
+    gt_errors = []
+    all_codes = []
+
+    for sol_index, solution in enumerate(solutions):
+        if gt_solutions:
+            solution_str = dsutils.reindent_code(solution)
+        else:
+            solution_str = dsutils.reindent_code(solution['code'])
+
+        a_tokens = tokenizer.encode(solution_str)
+        code = [-100] * args.max_len
+        code[:len(a_tokens)] = a_tokens
+        code = code[:args.max_len]
+
+        all_texts.append(in_tokens)
+        all_codes.append(code)
+
+        if gt_solutions:
+            gt_errors.append(dsutils.get_error_type(True))
+        else:
+            gt_errors.append(dsutils.get_error_type(solution['result']))
+
+    return all_texts, all_codes, gt_errors
+
 def main(args):
 
     argsdict = vars(args)
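For readers unfamiliar with the padding convention in `generate_critic_inputs` above: the problem prompt is right-padded with the tokenizer's EOS id up to `source_len`, and each program sequence (the critic's labels) is right-padded with `-100` up to `max_len`, `-100` being the conventional "ignore" label index in Hugging Face models. A minimal sketch of that padding; the helper `pad_to_length` is illustrative and not part of the repo:

```python
def pad_to_length(token_ids, length, pad_value):
    """Right-pad (or truncate) a list of token ids to a fixed length."""
    return list(token_ids[:length]) + [pad_value] * max(0, length - len(token_ids))

# Equivalent to the logic above:
#   in_tokens = pad_to_length(q_tokens, args.source_len, tokenizer.eos_token_id)  # encoder input
#   code      = pad_to_length(a_tokens, args.max_len, -100)                       # label sequence
```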
@@ -71,56 +110,103 @@ def main(args):
     # Set up model
     tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-base', cache_dir=args.tokenizer_path)
     print("Loading model from {}...".format(args.model_path))
-    model = T5ForConditionalGeneration.from_pretrained(args.model_path)
+    if args.critic_scores:
+        model = T5ForConditionalGeneration.from_pretrained(args.model_path, tuning_mode='critic')
+    else:
+        model = T5ForConditionalGeneration.from_pretrained(args.model_path)
+
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.to(device)
 
+    if args.critic_scores:
+        all_preds = []
+        all_gts = []
+
     # main eval loop
     for index, problem in tqdm(enumerate(problems), ncols=0, total=len(problems)):
 
         prob_path = os.path.join(problem)
         print(f"problem path = {prob_path}")
 
         problem_id = int(problem.split('/')[-1])
-        if os.path.exists(os.path.join(args.output_path, f"{problem_id}.json")):
+
+        if args.critic_scores and \
+            os.path.exists(os.path.join(args.output_path, f"{problem_id}_gt{args.gt_solutions}.pkl")):
             continue
-
+        elif os.path.exists(os.path.join(args.output_path, f"{problem_id}.json")):
+            continue
+
         test_case_path = os.path.join(prob_path, "input_output.json")
         prompt_path = os.path.join(prob_path, "question.txt")
         starter_path = os.path.join(prob_path, "starter_code.py")
-        solutions_path = os.path.join(prob_path, "solutions.json")
+        if args.critic_scores and not args.gt_solutions:
+            solutions_path = os.path.join(prob_path, "gen_solutions.json")
+        else:
+            solutions_path = os.path.join(prob_path, "solutions.json")
         if not os.path.exists(starter_path):
            starter_path = None
 
-        input_text = generate_prompt(args, test_case_path, prompt_path, solutions_path,
+        if args.critic_scores:
+            input_texts, input_codes, gt_error_types = generate_critic_inputs(args, test_case_path, prompt_path, solutions_path,
+                                                                              tokenizer, starter_path, args.gt_solutions)
+        else:
+            input_text = generate_prompt(args, test_case_path, prompt_path, solutions_path,
                                      tokenizer, starter_path)
 
         with torch.no_grad():
-            input_ids = torch.LongTensor(tokenizer.encode(input_text,
-                                                          verbose=False,
-                                                          max_length=args.source_len)).unsqueeze(0).cuda()
-
-            num_loops = int(args.num_seqs / args.num_seqs_per_iter)
-            output_programs = []
-            for i in tqdm(range(num_loops), ncols=0, total=num_loops, leave=False):
-                output_ids = model.generate(
-                    input_ids,
-                    do_sample=True,
-                    temperature=args.temperature,
-                    max_length=args.max_len,
-                    num_return_sequences=args.num_seqs_per_iter,
-                    top_p=0.95)
+            if args.critic_scores:
+                text_tensor = torch.tensor(input_texts).to(device)
+                code_tensor = torch.tensor(input_codes).to(device)
+                gt_error_tensor = torch.tensor(gt_error_types).to(device)
+
+                curr_inputs = {'input_ids': text_tensor, 'error_types': gt_error_tensor, 'labels': code_tensor}
+                _, error_preds, error_hidden_states = model(**curr_inputs, return_error_hidden_states=True)
 
-                for output_id in output_ids:
-                    output_programs.append(tokenizer.decode(output_id, skip_special_tokens=True))
+                assert len(gt_error_types) == len(error_preds)
+                all_preds.extend(error_preds.cpu().numpy().tolist())
+                all_gts.extend(gt_error_types)
+
+                saved_critic_scores = {}
+                saved_critic_scores[problem_id] = {'code': input_codes, 'prompt': input_texts,
+                                                   'gt_error_type': gt_error_types,
+                                                   'pred_error_type': error_preds.cpu().numpy(),
+                                                   'error_hidden_states': error_hidden_states.cpu().numpy()}
+                scores_loc = os.path.join(args.output_path, f"{problem_id}_gt{args.gt_solutions}.pkl")
+                pkl.dump(saved_critic_scores, open(scores_loc, 'wb'))
+
+            else:
+                input_ids = torch.LongTensor(tokenizer.encode(input_text,
+                                                              verbose=False,
+                                                              max_length=args.source_len)).unsqueeze(0).cuda()
 
-        saved_codes = {}
-        saved_codes[problem_id] = {'code': output_programs, 'prompt': input_text}
-
-        codes_loc = os.path.join(args.output_path, f"{problem_id}.json")
-        with open(codes_loc, "w") as f:
-            json.dump(saved_codes, f)
-
+                num_loops = int(args.num_seqs / args.num_seqs_per_iter)
+                output_programs = []
+                for i in tqdm(range(num_loops), ncols=0, total=num_loops, leave=False):
+                    output_ids = model.generate(
+                        input_ids,
+                        do_sample=True,
+                        temperature=args.temperature,
+                        max_length=args.max_len,
+                        num_return_sequences=args.num_seqs_per_iter,
+                        top_p=0.95)
+
+                    for output_id in output_ids:
+                        output_programs.append(tokenizer.decode(output_id, skip_special_tokens=True))
+
+                saved_codes = {}
+                saved_codes[problem_id] = {'code': output_programs, 'prompt': input_text}
+
+                codes_loc = os.path.join(args.output_path, f"{problem_id}.json")
+                with open(codes_loc, "w") as f:
+                    json.dump(saved_codes, f)
+
+    if args.critic_scores:
+        print("Total number of samples: {}".format(len(all_gts)))
+        acc = (np.array(all_preds) == np.array(all_gts)).sum()/len(all_gts)
+        print("Error Pred Acc: {}".format(acc))
+        print("Prediction distribution: {}".format(Counter(all_preds)))
+        print("GT distribution: {}".format(Counter(all_gts)))
+
 if __name__ == "__main__":
 
     from configs.generate_configs import *
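Since problems whose `.pkl` file already exists are skipped, the accuracy summary printed at the end of a run only covers newly scored problems. A small offline sketch (assuming the default output layout above; not part of the commit) that recomputes the same statistics from all saved files:

```python
import glob
import os
import pickle as pkl
import numpy as np
from collections import Counter

output_path = "outputs/critic_scores/"   # same directory the generation script writes to
all_preds, all_gts = [], []

for path in glob.glob(os.path.join(output_path, "*_gtFalse.pkl")):   # scores of synthetic programs
    with open(path, "rb") as f:
        scores = pkl.load(f)             # {problem_id: {'gt_error_type': [...], 'pred_error_type': array, ...}}
    for fields in scores.values():
        all_preds.extend(np.asarray(fields['pred_error_type']).tolist())
        all_gts.extend(list(fields['gt_error_type']))

print("Total number of samples:", len(all_gts))
print("Error Pred Acc:", (np.array(all_preds) == np.array(all_gts)).mean())
print("Prediction distribution:", Counter(all_preds))
print("GT distribution:", Counter(all_gts))
```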

outputs/critic_scores/111_gtFalse.pkl (201 KB, binary file not shown)

outputs/critic_scores/111_gtTrue.pkl (11.1 KB, binary file not shown)

outputs/critic_scores/1_gtFalse.pkl (203 KB, binary file not shown)

outputs/critic_scores/1_gtTrue.pkl (242 KB, binary file not shown)

scripts/generate.sh

Lines changed: 1 addition & 2 deletions
@@ -1,10 +1,9 @@
 ##
-## '''
 ## Copyright (c) 2022, salesforce.com, inc.
 ## All rights reserved.
 ## SPDX-License-Identifier: BSD-3-Clause
 ## For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
-## '''##
+##
 model_path=models/codet5_finetuned_codeRL
 tokenizer_path=models/codet5_tokenizer/
 test_path=data/APPS/test/

scripts/generate_critic_scores.sh

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+##
+## Copyright (c) 2022, salesforce.com, inc.
+## All rights reserved.
+## SPDX-License-Identifier: BSD-3-Clause
+## For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+##
+critic_path=models/codet5_finetuned_critic/
+tokenizer_path=models/codet5_tokenizer/
+test_path=data/APPS/train/ #test.json
+
+output_path=outputs/critic_scores/
+
+CUDA_VISIBLE_DEVICES=0 python generate.py \
+    --model_path ${critic_path} \
+    --test_path ${test_path} \
+    --output_path ${output_path} \
+    --critic_scores --gt_solutions
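Note that this example script enables `--gt_solutions`, so it scores the ground-truth programs in each problem's `solutions.json` and writes `{problem_id}_gtTrue.pkl`. Dropping that flag scores the synthetic programs in `gen_solutions.json` and writes `{problem_id}_gtFalse.pkl` instead, matching the two kinds of example outputs committed under `outputs/critic_scores/`.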

transformers/src/transformers/models/t5/modeling_t5.py

Lines changed: 3 additions & 1 deletion
@@ -1542,6 +1542,7 @@ def forward(
         output_hidden_states=None,
         return_dict=None,
         error_types=None,
+        return_error_hidden_states=False
     ):
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1665,7 +1666,8 @@
             error_pred_loss_fct = CrossEntropyLoss()
             error_pred_loss = error_pred_loss_fct(error_logits.view(-1, error_logits.size(-1)), error_types.view(-1))
             _, error_preds = torch.max(error_logits, dim=-1)
-
+            if return_error_hidden_states:
+                return error_pred_loss, error_preds, error_states
             return error_pred_loss, error_preds
 
         if not return_dict:
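When `return_error_hidden_states=True`, the critic's `forward` returns a three-element tuple (the error-prediction loss, the predicted error types, and the error-prediction hidden states) instead of two; this is the tuple that `generate.py` unpacks as `_, error_preds, error_hidden_states` in the critic-scoring branch above.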
