15
15
import unittest
16
16
17
17
import numpy as np
18
- import torch
19
18
import paddle
20
19
import paddlenlp_ops
21
20
26
25
# HPU device index for this test run, selected via environment variable.
# NOTE(review): os.environ.get returns a str when the variable is set, but
# the default here is the int 0, so the type varies — confirm downstream
# users tolerate both.
intel_hpus_module_id = os.environ.get("FLAGS_selected_intel_hpus", 0)
27
26
28
27
29
- def index_copy_torch (input , dim , index , source , dtype ):
30
- dtype_map = {
31
- "float16" : torch . float16 ,
32
- "float32" : torch . float32 ,
33
- "float64" : torch . float64 ,
34
- "int32" : torch . int32 ,
35
- }
36
- torch_dtype = dtype_map [ dtype ]
37
- input_tensor = torch . tensor ( input ). clone (). detach (). to ( dtype = torch_dtype )
38
- index_tensor = torch . tensor ( index ). clone (). detach (). to ( dtype = torch . int64 )
39
- source_tensor = torch . tensor ( source ). clone (). detach (). to ( dtype = torch_dtype )
40
- output = torch . index_copy (
41
- input = input_tensor , dim = dim , index = index_tensor , source = source_tensor
42
- )
43
- return output
28
def index_copy_paddle(input, dim, index, source, dtype):
    """CPU reference for index_copy built from plain Paddle ops.

    Copies slices of ``source`` into ``input`` at the positions given by
    ``index`` along axis ``dim`` and returns the result, mirroring the
    custom ``paddlenlp_ops`` kernel under test.

    NOTE(review): ``dtype`` is accepted for signature parity with the
    callers but is ignored — the computation is pinned to float32
    (presumably because CPU ``scatter_`` support for float16 is limited;
    confirm). ``input`` shadows the builtin name; kept for caller
    compatibility.
    """
    # Work on fresh CPU copies so the device kernel under test is not used.
    input_tensor = paddle.to_tensor(input, dtype="float32").clone().cpu()
    index_tensor = paddle.to_tensor(index, dtype="int64").clone().cpu()
    source_tensor = paddle.to_tensor(source, dtype="float32").clone().cpu()

    shape = input_tensor.shape
    # scatter_ only writes along axis 0, so collapse all axes up to and
    # including ``dim`` and offset the indices into each leading block:
    # block i along the flattened axis starts at i * shape[dim].
    new_index = []
    for i in range(0, int(np.prod(shape[:dim]))):
        new_index.append(index_tensor + i * shape[dim])
    new_index = paddle.concat(new_index)
    # In-place reshape: new_x aliases input_tensor's storage.
    new_x = input_tensor.reshape_([-1] + shape[dim + 1 :])
    new_source = source_tensor.reshape([-1] + shape[dim + 1 :])
    # In-place scatter on the flattened view, then restore the shape.
    y = new_x.scatter_(new_index, new_source).reshape_(shape)

    return y
44
43
45
44
46
45
@skip_check_grad_ci (reason = "index_copy_forward ops not support gradient calculation." )
@@ -56,7 +55,7 @@ def setUp(self):
56
55
def init_dtype(self):
    # Comparison dtype for the tests; "float32" by default (check_result
    # also recognizes float16 — see its dtype branch).
    self.dtype = "float32"
58
57
59
- def check_result (self , torch_res , ops_res ):
58
+ def check_result (self , paddle_res , ops_res ):
60
59
if self .dtype == "float32" :
61
60
rtol = 1e-5
62
61
atol = 1e-6
@@ -73,7 +72,7 @@ def check_result(self, torch_res, ops_res):
73
72
float16 and float32, but got "
74
73
+ self .dtype ,
75
74
)
76
- np .testing .assert_allclose (torch_res , ops_res , rtol = rtol , atol = atol )
75
+ np .testing .assert_allclose (paddle_res , ops_res , rtol = rtol , atol = atol )
77
76
78
77
def index_copy_custom (self , input , dim , index , source ):
79
78
input_tensor = paddle .to_tensor (input , dtype = self .dtype ).clone ()
@@ -121,78 +120,78 @@ def prepare_input(
121
120
def test_index_copy_dim0_index0(self):
    """Custom kernel must match the Paddle reference for dim=0, index=0."""
    input, index, source, dim = self.prepare_input(dim=0, index=0)
    custom_res = self.index_copy_custom(input, dim, index, source)
    paddle_res = index_copy_paddle(input, dim, index, source, dtype=self.dtype)
    # Compare as numpy arrays on both sides, consistent with the
    # dim=1/2/3 tests which call custom_res.numpy().
    self.check_result(paddle_res.numpy(), custom_res.numpy())
126
125
127
126
def test_index_copy_dim0_index1(self):
    """Custom kernel must match the Paddle reference for dim=0, index=1."""
    input, index, source, dim = self.prepare_input(dim=0, index=1)
    custom_res = self.index_copy_custom(input, dim, index, source)
    paddle_res = index_copy_paddle(input, dim, index, source, dtype=self.dtype)
    # Compare as numpy arrays on both sides, consistent with the
    # dim=1/2/3 tests which call custom_res.numpy().
    self.check_result(paddle_res.numpy(), custom_res.numpy())
132
131
133
132
def test_index_copy_dim0_index_max(self):
    """dim=0 with the largest valid index (num_heads - 1)."""
    index = max(self.num_heads - 1, 0)
    input, index, source, dim = self.prepare_input(dim=0, index=index)
    custom_res = self.index_copy_custom(input, dim, index, source)
    paddle_res = index_copy_paddle(input, dim, index, source, dtype=self.dtype)
    # Compare as numpy arrays on both sides, consistent with the
    # dim=1/2/3 tests which call custom_res.numpy().
    self.check_result(paddle_res.numpy(), custom_res.numpy())
139
138
140
139
def test_index_copy_dim1_index0(self):
    """dim=1, index=0: custom op output must match the Paddle reference."""
    inp, idx, src, axis = self.prepare_input(dim=1, index=0)
    got = self.index_copy_custom(inp, axis, idx, src)
    want = index_copy_paddle(inp, axis, idx, src, dtype=self.dtype)
    self.check_result(want.numpy(), got.numpy())
145
144
146
145
def test_index_copy_dim1_index1(self):
    """dim=1, index=1: custom op output must match the Paddle reference."""
    inp, idx, src, axis = self.prepare_input(dim=1, index=1)
    got = self.index_copy_custom(inp, axis, idx, src)
    want = index_copy_paddle(inp, axis, idx, src, dtype=self.dtype)
    self.check_result(want.numpy(), got.numpy())
151
150
152
151
def test_index_copy_dim1_index_max(self):
    """dim=1 with the largest valid index (head_dim - 1)."""
    last = max(self.head_dim - 1, 0)
    inp, idx, src, axis = self.prepare_input(dim=1, index=last)
    got = self.index_copy_custom(inp, axis, idx, src)
    want = index_copy_paddle(inp, axis, idx, src, dtype=self.dtype)
    self.check_result(want.numpy(), got.numpy())
158
157
159
158
def test_index_copy_dim2_index0(self):
    """dim=2, index=0: custom op output must match the Paddle reference."""
    inp, idx, src, axis = self.prepare_input(dim=2, index=0)
    got = self.index_copy_custom(inp, axis, idx, src)
    want = index_copy_paddle(inp, axis, idx, src, dtype=self.dtype)
    self.check_result(want.numpy(), got.numpy())
164
163
165
164
def test_index_copy_dim2_index1(self):
    """dim=2, index=1: custom op output must match the Paddle reference."""
    inp, idx, src, axis = self.prepare_input(dim=2, index=1)
    got = self.index_copy_custom(inp, axis, idx, src)
    want = index_copy_paddle(inp, axis, idx, src, dtype=self.dtype)
    self.check_result(want.numpy(), got.numpy())
170
169
171
170
def test_index_copy_dim2_index_max(self):
    """dim=2 with the largest valid index (seq_length - 1)."""
    last = max(self.seq_length - 1, 0)
    inp, idx, src, axis = self.prepare_input(dim=2, index=last)
    got = self.index_copy_custom(inp, axis, idx, src)
    want = index_copy_paddle(inp, axis, idx, src, dtype=self.dtype)
    self.check_result(want.numpy(), got.numpy())
177
176
178
177
def test_index_copy_dim3_index0(self):
    """dim=3, index=0: custom op output must match the Paddle reference."""
    inp, idx, src, axis = self.prepare_input(dim=3, index=0)
    got = self.index_copy_custom(inp, axis, idx, src)
    want = index_copy_paddle(inp, axis, idx, src, dtype=self.dtype)
    self.check_result(want.numpy(), got.numpy())
183
182
184
183
def test_index_copy_dim3_index1(self):
    """dim=3, index=1: custom op output must match the Paddle reference."""
    inp, idx, src, axis = self.prepare_input(dim=3, index=1)
    got = self.index_copy_custom(inp, axis, idx, src)
    want = index_copy_paddle(inp, axis, idx, src, dtype=self.dtype)
    self.check_result(want.numpy(), got.numpy())
189
188
190
189
def test_index_copy_dim3_index_max(self):
    """dim=3 with the largest valid index (batch_size - 1)."""
    last = max(self.batch_size - 1, 0)
    inp, idx, src, axis = self.prepare_input(dim=3, index=last)
    got = self.index_copy_custom(inp, axis, idx, src)
    want = index_copy_paddle(inp, axis, idx, src, dtype=self.dtype)
    self.check_result(want.numpy(), got.numpy())
196
195
197
196
198
197
@skip_check_grad_ci (reason = "index_copy_forward ops not support gradient calculation." )
0 commit comments