diff --git a/hf_mini/utils.py b/hf_mini/utils.py
index 9fee8ac..cca09e9 100644
--- a/hf_mini/utils.py
+++ b/hf_mini/utils.py
@@ -1091,11 +1091,16 @@
 import re
 import time
 
+import torch
+from typing import Union, Dict
+from transformers import AutoTokenizer
+
 from hf_mini.filter import SensitiveInforRM
 
-is_security = SensitiveInforRM()
 
-def input_wrapper(code_string, later_code: str = "", path: str = "") -> str:
+is_security = SensitiveInforRM()
+
+def input_wrapper(tokenizer: AutoTokenizer, code_string: str, later_code: str = "", path: str = "", pad_token: str = "☺") -> Union[Dict, None]:
     start = time.time()
     _sequerity = True
     for i in [code_string, later_code, path]:
@@ -1104,7 +1109,7 @@ def input_wrapper(code_string, later_code: str = "", path: str = "") -> str:
             break
     print(f"Done inputs checking with {(time.time()-start) * 1000:.2f}ms", flush=True)
     if not _sequerity:
-        return ""
+        return None
 
     extension_pattern = re.compile(r"(\.\w+)$")
     p = ""
@@ -1119,4 +1124,18 @@ def input_wrapper(code_string, later_code: str = "", path: str = "") -> str:
         des = LANGUAGE_WRAPPER.get(lang, "")
         if len(des) > 0 and "<AIX-SPE>" in des:
             p = des.replace("<AIX-SPE>", f"the file path is: {path}") + "\n"
-    return f"▁<AIX-SPAN-PRE>▁<AIX-SPAN-POST>{later_code}▁<AIX-SPAN-MIDDLE>{p}{code_string}"
\ No newline at end of file
+
+    # SPM: build the ids segment by segment so SentencePiece's word-start handling stays consistent
+    pad_ids = tokenizer(pad_token, return_tensors="pt", return_token_type_ids=False)
+    pad_len = len(pad_ids["input_ids"][0])
+    pre_code_ids = tokenizer("▁<AIX-SPAN-PRE>▁<AIX-SPAN-POST>", return_tensors="pt", return_token_type_ids=False)
+
+    later_code_ids = tokenizer(pad_token + later_code, return_tensors="pt", return_token_type_ids=False)
+    later_code_ids["input_ids"] = later_code_ids["input_ids"][:, pad_len:]
+    later_code_ids["attention_mask"] = later_code_ids["attention_mask"][:, pad_len:]
+
+    code_string_ids = tokenizer(f"▁<AIX-SPAN-MIDDLE>{p}{code_string}", return_tensors="pt", return_token_type_ids=False)
+    code_string_ids["input_ids"] = torch.cat([pre_code_ids["input_ids"], later_code_ids["input_ids"], code_string_ids["input_ids"]], dim=1)
+    code_string_ids["attention_mask"] = torch.cat([pre_code_ids["attention_mask"], later_code_ids["attention_mask"], code_string_ids["attention_mask"]], dim=1)
+
+    return code_string_ids
\ No newline at end of file
diff --git a/sess_huggingface.py b/sess_huggingface.py
index 3fb91fa..d5984a8 100755
--- a/sess_huggingface.py
+++ b/sess_huggingface.py
@@ -9,19 +9,18 @@
 model = AutoModelForCausalLM.from_pretrained("aiXcoder/aixcoder-7b-base", torch_dtype=torch.bfloat16)
 
-text = input_wrapper(
+inputs = input_wrapper(
+    tokenizer=tokenizer,
     code_string="# 快速排序算法",
     later_code="\n",
-    path="test.py"
+    path="test.py",
 )
 
-if len(text) == 0:
+if inputs is None:
     sys.exit()
 
-inputs = tokenizer(text, return_tensors="pt", return_token_type_ids=False)
-
 inputs = inputs.to(device)
 model.to(device)
 
 outputs = model.generate(**inputs, max_new_tokens=256)
-print(tokenizer.decode(outputs[0], skip_special_tokens=False))
+print(tokenizer.decode(outputs[0], skip_special_tokens=False))
\ No newline at end of file
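
Note on the pad-token trick in the utils.py hunk: a SentencePiece tokenizer marks the first piece of a string as word-initial (the "▁" prefix) and, if configured to, prepends a BOS token, so encoding later_code on its own would not yield the ids it would get in the middle of a sequence. The patch therefore encodes pad_token + later_code and slices off the pad's tokens. Below is a minimal sketch of the same trick, using the checkpoint from sess_huggingface.py; the helper name encode_mid_sequence is ours, not the patch's:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("aiXcoder/aixcoder-7b-base")
    PAD = "☺"  # any short marker the tokenizer encodes the same way in every context

    def encode_mid_sequence(text: str) -> list:
        # pad_len counts BOS (if the tokenizer adds one) plus the pad marker's
        # pieces; slicing them off leaves the ids `text` would receive
        # mid-sequence, with no spurious word-start marker.
        pad_len = len(tokenizer(PAD)["input_ids"])
        return tokenizer(PAD + text)["input_ids"][pad_len:]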
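
The segments are concatenated as span markers, then the code after the cursor, then the middle marker with the path comment and the code before the cursor, so generation continues the "middle" of the file. A quick usage check of the new interface (illustrative, not part of the patch; the driver script's code_string is a Chinese comment meaning "quick sort algorithm", swapped for English here). Because input_wrapper writes the concatenated tensors back into the object the tokenizer returned, the result still behaves like a transformers BatchEncoding, which is why sess_huggingface.py can call inputs.to(device) and model.generate(**inputs, ...) directly:

    from transformers import AutoTokenizer
    from hf_mini.utils import input_wrapper

    tokenizer = AutoTokenizer.from_pretrained("aiXcoder/aixcoder-7b-base")
    inputs = input_wrapper(
        tokenizer=tokenizer,
        code_string="# quick sort algorithm",
        later_code="\n",
        path="test.py",
    )

    assert inputs is not None  # None signals a failed sensitivity check
    assert inputs["input_ids"].shape == inputs["attention_mask"].shape
    print(tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=False))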