@@ -1091,11 +1091,16 @@
 
 import re
 import time
+import torch
+from typing import Union, Dict
+from transformers import AutoTokenizer
+
 from hf_mini.filter import SensitiveInforRM
-is_security = SensitiveInforRM()
 
-def input_wrapper(code_string, later_code: str = "", path: str = "") -> str:
 
+is_security = SensitiveInforRM()
+
+def input_wrapper(tokenizer: AutoTokenizer, code_string: str, later_code: str = "", path: str = "", pad_token: str = "☺") -> Union[Dict, None]:
     start = time.time()
     _sequerity = True
     for i in [code_string, later_code, path]:
@@ -1104,7 +1109,7 @@ def input_wrapper(code_string, later_code: str = "", path: str = "") -> str:
             break
     print(f"Done inputs checking with {(time.time()-start) * 1000:.2f} ms", flush=True)
     if not _sequerity:
-        return ""
+        return None
 
     extension_pattern = re.compile(r"(\.\w+)$")
     p = ""
@@ -1119,4 +1124,18 @@ def input_wrapper(code_string, later_code: str = "", path: str = "") -> str:
     des = LANGUAGE_WRAPPER.get(lang, "")
     if len(des) > 0 and "<AIX-SPE>" in des:
         p = des.replace("<AIX-SPE>", f"the file path is: {path}") + "\n"
-    return f"<s>▁<AIX-SPAN-PRE>▁<AIX-SPAN-POST>{later_code}▁<AIX-SPAN-MIDDLE>{p}{code_string}"
+
+    # SPM
+    pad_ids = tokenizer(pad_token, return_tensors="pt", return_token_type_ids=False)
+    pad_len = len(pad_ids["input_ids"][0])
+    pre_code_ids = tokenizer("<s>▁<AIX-SPAN-PRE>▁<AIX-SPAN-POST>", return_tensors="pt", return_token_type_ids=False)
+
+    later_code_ids = tokenizer(pad_token + later_code, return_tensors="pt", return_token_type_ids=False)
+    later_code_ids["input_ids"] = later_code_ids["input_ids"][:, pad_len:]
+    later_code_ids["attention_mask"] = later_code_ids["attention_mask"][:, pad_len:]
+
+    code_string_ids = tokenizer(f"▁<AIX-SPAN-MIDDLE>{p}{code_string}", return_tensors="pt", return_token_type_ids=False)
+    code_string_ids["input_ids"] = torch.cat([pre_code_ids["input_ids"], later_code_ids["input_ids"], code_string_ids["input_ids"]], dim=1)
+    code_string_ids["attention_mask"] = torch.cat([pre_code_ids["attention_mask"], later_code_ids["attention_mask"], code_string_ids["attention_mask"]], dim=1)
+
+    return code_string_ids
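Note for call sites: the wrapper no longer returns a prompt string but a dict of input_ids/attention_mask tensors that can be passed straight to model.generate, and it now returns None when the sensitive-information filter rejects the input. A minimal usage sketch follows; the checkpoint id (aiXcoder/aixcoder-7b-base), the example prompt, and the generation settings are illustrative assumptions, not part of this commit.

# Minimal usage sketch for the new tokenizer-based signature.
# Assumed: checkpoint id, prompt text, and generation settings.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from hf_mini.utils import input_wrapper

tokenizer = AutoTokenizer.from_pretrained("aiXcoder/aixcoder-7b-base")  # assumed checkpoint id
model = AutoModelForCausalLM.from_pretrained("aiXcoder/aixcoder-7b-base", torch_dtype=torch.bfloat16)

inputs = input_wrapper(
    tokenizer=tokenizer,
    code_string="# write a quick sort function\ndef quick_sort(",  # prefix context
    later_code="\n",                                               # suffix context
    path="sort.py",
)
if inputs is None:
    # the wrapper now signals a SensitiveInforRM rejection with None instead of ""
    raise ValueError("input rejected by the sensitive-information filter")

# move the returned tensors onto the model's device before generating
inputs = {k: v.to(model.device) for k, v in inputs.items()}
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(output[0], skip_special_tokens=False))

The ☺ pad-token detour around later_code looks like the usual workaround for SentencePiece-style tokenizers, which encode the first piece of a string differently from the same text mid-sequence (a word-boundary marker is prepended). Tokenizing pad_token + later_code and slicing off the first pad_len ids recovers the ids later_code would receive as a continuation; any sentinel string unlikely to occur in real code would do.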