
Commit 6bb5096

Merge pull request #35 from Fly-Pluche/main
Add pad token
2 parents 75b3f08 + a1a495c

2 files changed: 28 additions, 10 deletions

hf_mini/utils.py

Lines changed: 23 additions & 4 deletions
@@ -1091,11 +1091,16 @@
 
 import re
 import time
+import torch
+from typing import Union, Dict
+from transformers import AutoTokenizer
+
 from hf_mini.filter import SensitiveInforRM
-is_security = SensitiveInforRM()
 
-def input_wrapper(code_string, later_code: str = "", path: str = "") -> str:
 
+is_security = SensitiveInforRM()
+
+def input_wrapper(tokenizer: AutoTokenizer, code_string: str, later_code: str = "", path: str = "", pad_token: str = "☺" ) -> Union[Dict,None]:
     start = time.time()
     _sequerity = True
     for i in [code_string, later_code, path]:
@@ -1104,7 +1109,7 @@ def input_wrapper(code_string, later_code: str = "", path: str = "") -> str:
             break
     print(f"Done inputs checking with {(time.time()-start) * 1000:.2f}ms", flush=True)
     if not _sequerity:
-        return ""
+        return None
 
     extension_pattern = re.compile(r"(\.\w+)$")
     p = ""
@@ -1119,4 +1124,18 @@
         des = LANGUAGE_WRAPPER.get(lang, "")
         if len(des) > 0 and "<AIX-SPE>" in des:
             p = des.replace("<AIX-SPE>", f"the file path is: {path}") + "\n"
-    return f"<s>▁<AIX-SPAN-PRE>▁<AIX-SPAN-POST>{later_code}▁<AIX-SPAN-MIDDLE>{p}{code_string}"
+
+    # SPM
+    pad_ids = tokenizer(pad_token, return_tensors="pt", return_token_type_ids=False)
+    pad_len = len(pad_ids["input_ids"][0])
+    pre_code_ids = tokenizer("<s>▁<AIX-SPAN-PRE>▁<AIX-SPAN-POST>", return_tensors="pt", return_token_type_ids=False)
+
+    later_code_ids = tokenizer(pad_token + later_code, return_tensors="pt", return_token_type_ids=False)
+    later_code_ids["input_ids"] = later_code_ids["input_ids"][:,pad_len:]
+    later_code_ids["attention_mask"] = later_code_ids["attention_mask"][:,pad_len:]
+
+    code_string_ids = tokenizer(f"▁<AIX-SPAN-MIDDLE>{p}{code_string}", return_tensors="pt", return_token_type_ids=False)
+    code_string_ids["input_ids"] = torch.cat([pre_code_ids["input_ids"], later_code_ids["input_ids"], code_string_ids["input_ids"]], dim = 1)
+    code_string_ids["attention_mask"] = torch.cat([pre_code_ids["attention_mask"], later_code_ids["attention_mask"], code_string_ids["attention_mask"]], dim = 1)
+
+    return code_string_ids
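
The new "# SPM" block appears to work around a common SentencePiece quirk: tokenized on its own, a suffix such as later_code gets start-of-string treatment (BOS insertion, leading-whitespace normalization), so its ids differ from what the same text would receive in the middle of a prompt. Encoding the throwaway pad_token first and slicing its pad_len ids off leaves later_code tokenized as mid-sequence text, after which the three segments are joined with torch.cat. A minimal sketch of the trick in isolation, assuming the tokenizer loads as in sess_huggingface.py; encode_suffix is an illustrative helper, not part of the commit:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("aiXcoder/aixcoder-7b-base")
PAD = "☺"  # rare character, unlikely to merge with neighbouring tokens

def encode_suffix(text: str) -> torch.Tensor:
    # Tokenize PAD + text, then drop PAD's ids (and any BOS the tokenizer
    # prepends); what remains encodes `text` as it would appear mid-sequence.
    pad_len = tokenizer(PAD, return_tensors="pt")["input_ids"].shape[1]
    ids = tokenizer(PAD + text, return_tensors="pt")["input_ids"]
    return ids[:, pad_len:]

# Compare with naive encoding of the same suffix: the leading "\n" ids can differ.
print(encode_suffix("\ndef partition(arr):"))
print(tokenizer("\ndef partition(arr):", return_tensors="pt")["input_ids"])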

sess_huggingface.py

Lines changed: 5 additions & 6 deletions
@@ -9,19 +9,18 @@
 model = AutoModelForCausalLM.from_pretrained("aiXcoder/aixcoder-7b-base", torch_dtype=torch.bfloat16)
 
 
-text = input_wrapper(
+inputs = input_wrapper(
+    tokenizer=tokenizer,
     code_string="# 快速排序算法",
     later_code="\n",
-    path="test.py"
+    path="test.py",
 )
 
-if len(text) == 0:
+if inputs is None:
     sys.exit()
 
-inputs = tokenizer(text, return_tensors="pt", return_token_type_ids=False)
-
 inputs = inputs.to(device)
 model.to(device)
 
 outputs = model.generate(**inputs, max_new_tokens=256)
-print(tokenizer.decode(outputs[0], skip_special_tokens=False))
+print(tokenizer.decode(outputs[0], skip_special_tokens=False))
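
With this change, input_wrapper returns the tokenizer's output directly (a BatchEncoding carrying input_ids and attention_mask, which is why inputs.to(device) still works), or None when the sensitive-information filter rejects the input. The caller therefore checks "inputs is None" rather than testing for an empty string, and no longer tokenizes the prompt itself. The prompt comment "# 快速排序算法" is Chinese for "# quicksort algorithm".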
