
Commit ffe7bea

[INTEL_HPU] add HPU stack kernel (#1872)
1 parent 7a8cc78 commit ffe7bea

File tree

2 files changed: 190 additions & 0 deletions
Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "kernels/funcs.h"
#include "kernels/hpu_funcs.h"
#include "kernels/hpu_operator.h"
#include "utils/utils.h"

namespace custom_kernel {

class Stack : public HpuFusedOperator {
 public:
  explicit Stack(synDataType dtype)
      : HpuFusedOperator("stack"), dtype_(dtype) {}

  void AddNode(ConvertTensors& ct, unsigned params) {
    auto inputs = ct.GetTensors();
    auto outputs = ct.GetTensors(false);

    std::vector<synTensor> syn_inputs;
    for (size_t i = 0; i < inputs.size(); i++) {
      syn_inputs.push_back(createTensorFromCT(&ct, i));
    }

    auto concat_dims = outputs[0].dims;

    // Merge concat_dims[reduce_dim] and concat_dims[reduce_dim + 1] (the
    // stack axis and its inner neighbor) into one dimension, so the concat
    // output can later be reshaped into the stacked shape.
    auto reduce_dim = concat_dims.size() - 2 - params;
    concat_dims[reduce_dim] *= concat_dims[reduce_dim + 1];
    concat_dims.erase(concat_dims.begin() + reduce_dim + 1);

    std::vector<synTensor> outputs_concat;
    auto concated = createTensorNoPresist("concat", dtype_, concat_dims);
    outputs_concat.push_back(concated);

    synConcatenateParams concatParams;
    concatParams.axis = params;
    AddNodeConcat(syn_inputs, outputs_concat, concatParams, guid_ + "concat");

    std::vector<synTensor> syn_outputs;
    auto stacked = createTensorFromCT(&ct, 0, false);
    syn_outputs.push_back(stacked);

    AddNodeReshape(outputs_concat, syn_outputs, guid_ + "reshape");
  }

 protected:
  synDataType dtype_;
};

template <typename T, typename Context>
void StackKernel(const Context& dev_ctx,
                 const std::vector<const phi::DenseTensor*>& x,
                 int axis,
                 phi::DenseTensor* y) {
  dev_ctx.template Alloc<T>(y);

  ConvertTensors ct;
  for (size_t i = 0; i < x.size(); i++) {
    ct.Add(x[i]);
  }
  ct.Add(y, false);

  // Normalize a possibly negative axis, then flip it into Synapse's
  // reversed dimension order.
  axis = CanonicalAxis(static_cast<int64_t>(axis),
                       static_cast<int64_t>(x[0]->dims().size()));
  axis = static_cast<int64_t>(x[0]->dims().size()) - 1 - axis;
  unsigned params = static_cast<unsigned>(axis);

  std::vector<DIMS> inputs_dims = ct.GetDims();
  OpCacheOperator op_info;
  op_info.prepareOpInfo<T, unsigned>("StackKernel", inputs_dims, &params);
  auto recipe = op_info.GetRecipe();

  if (recipe == nullptr) {
    Stack op(op_info.datatype_);
    op.AddNode(ct, params);
    op.Compile();
    op_info.setOp(op);
    recipe = op_info.GetRecipe();
  }

  RecipeRunner runner(recipe);
  auto tensors = ct.GetDeviceAddr();
  runner.Run(reinterpret_cast<C_Stream>(dev_ctx.stream()), tensors);
}

}  // namespace custom_kernel

PD_REGISTER_PLUGIN_KERNEL(stack,
                          intel_hpu,
                          ALL_LAYOUT,
                          custom_kernel::StackKernel,
                          float,
                          int64_t,
                          phi::dtype::float16,
                          phi::dtype::bfloat16,
                          phi::dtype::float8_e4m3fn) {}
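
The kernel lowers stack to a single concat followed by a reshape: the inputs are concatenated along the stack axis, and the merged dimension is then split back into (num_inputs, original_dim). A minimal numpy sketch of that equivalence (illustration only, not part of the commit; numpy's concatenate/reshape stand in for the Synapse concat and reshape nodes, and the reversed Synapse axis order is omitted):

import numpy as np

# Stack four tensors of shape (5, 6, 7) along axis=1.
xs = [np.random.rand(5, 6, 7) for _ in range(4)]
axis = 1

# Reference result.
stacked = np.stack(xs, axis=axis)         # shape (5, 4, 6, 7)

# The kernel's lowering: concatenate along the stack axis, then reshape.
# Concat along axis=1 lays the inputs out block by block as (5, 24, 7);
# viewing that merged axis as (4, 6) recovers the stacked layout.
concat = np.concatenate(xs, axis=axis)    # shape (5, 24, 7)
reshaped = concat.reshape(stacked.shape)  # shape (5, 4, 6, 7)

assert np.array_equal(stacked, reshaped)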
Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest

import numpy as np
import paddle
from tests.op_test import OpTest, convert_float_to_uint16, convert_uint16_to_float

import os

intel_hpus_module_id = os.environ.get("FLAGS_selected_intel_hpus", 0)


class TestStackOpBf16(OpTest):
    def initDefaultParameters(self):
        self.num_inputs = 4
        self.input_dim = (5, 6, 7)
        self.axis = 0

    def initParameters(self):
        pass

    def get_x_names(self):
        x_names = []
        for i in range(self.num_inputs):
            x_names.append("x{}".format(i))
        return x_names

    def setUp(self):
        self.initDefaultParameters()
        self.initParameters()
        self.op_type = "stack"
        self.set_hpu()
        self.init_dtype()
        self.x = []
        self.y = []
        for i in range(self.num_inputs):
            self.x.append(
                convert_float_to_uint16(
                    np.random.random(size=self.input_dim).astype(np.float32)
                )
            )

        tmp = []
        x_names = self.get_x_names()
        for i in range(self.num_inputs):
            tmp.append((x_names[i], self.x[i]))

        self.inputs = {"X": tmp}
        for i in self.x:
            self.y.append(convert_uint16_to_float(i))
        self.outputs = {"Y": np.stack(self.y, axis=self.axis)}
        self.attrs = {"axis": self.axis}

    def set_hpu(self):
        self.__class__.use_custom_device = True
        self.__class__.no_need_check_grad = True
        self.place = paddle.CustomPlace("intel_hpu", int(intel_hpus_module_id))

    def init_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        self.check_output_with_place(self.place)


if __name__ == "__main__":
    unittest.main()
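
Beyond the OpTest case above, a hand-run smoke test along these lines should hit the new kernel. This is a hypothetical sketch, assuming a PaddleCustomDevice build with the intel_hpu plugin installed and an HPU visible as device 0:

import numpy as np
import paddle

# Assumes the intel_hpu plugin is installed and device 0 is available.
paddle.set_device("intel_hpu:0")

xs = [
    paddle.to_tensor(np.random.rand(5, 6, 7).astype("float32"))
    for _ in range(4)
]
out = paddle.stack(xs, axis=0)  # dispatched to custom_kernel::StackKernel

print(out.shape)  # [4, 5, 6, 7]
np.testing.assert_allclose(
    out.numpy(), np.stack([x.numpy() for x in xs], axis=0)
)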
