Skip to content

Commit 677d48e

Browse files
qjia7annxingyuan
andcommitted
webgpu: Increase work per thread for maxpool (tensorflow#2628)
PERF: With this change, maxPool on a [1, 131, 131, 64] input shows a 50%–90% speedup across different platforms. * Add workPerThread to shaderKey. Co-authored-by: Ann Yuan <annyuan@google.com>
1 parent 7ba9a7e commit 677d48e

File tree

2 files changed

+40
-13
lines changed

2 files changed

+40
-13
lines changed

tfjs-backend-webgpu/src/benchmark_ops_test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ describeWebGPU('Ops benchmarks', () => {
159159
it('maxPool', async () => {
160160
const x = tf.randomNormal<tf.Rank.R4>([1, 131, 131, 64]);
161161

162-
await time(() => tf.maxPool(x, 2, 1, 'same'));
162+
await time(() => tf.maxPool(x, 2, 1, 'same'), null, true, 10, 10);
163163
});
164164

165165
it('prelu', async () => {

tfjs-backend-webgpu/src/kernels/maxpool_webgpu.ts

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,20 @@ export class MaxPoolProgram implements WebGPUProgram {
2929
dispatch: [number, number, number];
3030
variableNames = ['x'];
3131
uniforms = 'ivec2 pad, stride, dilation, convDims, filterDims;';
32-
workGroupSize: [number, number, number] = [4, 4, 4];
32+
// TODO(jiajia.qin@intel.com): Dynamically choose different workGroupSize and
33+
// workPerThread for different output shapes.
34+
workGroupSize: [number, number, number] = [4, 4, 1];
35+
workPerThread = 16;
3336

3437
constructor(convInfo: backend_util.Conv2DInfo) {
3538
this.outputShape = convInfo.outShape;
3639

37-
this.dispatchLayout = {x: [2], y: [1], z: [0, 3]};
40+
this.dispatchLayout = {x: [0, 1], y: [2], z: [3]};
3841

3942
this.dispatch = computeDispatch(
40-
this.dispatchLayout, this.outputShape, this.workGroupSize);
43+
this.dispatchLayout, this.outputShape, this.workGroupSize,
44+
[1, 1, this.workPerThread]);
4145

42-
// TODO: Parallelize max computation by thread and merge result.
4346
this.userCode = `
4447
float getValue(int batch, int xR, int xC, int d) {
4548
if (xC < 0 || xC >= convDims.x) {
@@ -50,15 +53,17 @@ export class MaxPoolProgram implements WebGPUProgram {
5053
5154
void main() {
5255
ivec4 coords = getOutputCoords();
53-
int batch = coords[0];
54-
int d = coords[3];
55-
5656
if (all(lessThan(coords, outShape))) {
57+
int batch = coords[0];
5758
ivec2 xRCCorner = coords.yz * stride - pad;
5859
int xRCorner = xRCCorner.x;
5960
int xCCorner = xRCCorner.y;
6061
61-
float minMaxValue = 0.0;
62+
float minMaxValue[${this.workPerThread}];
63+
for (int i = 0; i < ${this.workPerThread}; i++)
64+
{
65+
minMaxValue[i] = 0.0;
66+
}
6267
6368
for (int wR = 0; wR < filterDims.y; wR += dilation.y) {
6469
int xR = xRCorner + wR;
@@ -69,14 +74,36 @@ export class MaxPoolProgram implements WebGPUProgram {
6974
7075
for (int wC = 0; wC < filterDims.x; wC += dilation.x) {
7176
int xC = xCCorner + wC * dilation.x;
72-
float value = getValue(batch, xR, xC, d);
73-
minMaxValue = max(value, minMaxValue);
77+
for (int i = 0; i < ${this.workPerThread}; i++)
78+
{
79+
int d = coords[3] * ${this.workPerThread} + i;
80+
if (d < outShape[3])
81+
{
82+
float value = getValue(batch, xR, xC, d);
83+
minMaxValue[i] = max(value, minMaxValue[i]);
84+
}
85+
else
86+
{
87+
break;
88+
}
89+
}
90+
}
91+
}
92+
for (int i = 0; i < ${this.workPerThread}; i++)
93+
{
94+
int d = coords[3] * ${this.workPerThread} + i;
95+
if (d < outShape[3])
96+
{
97+
setOutput(batch, coords[1], coords[2], d, minMaxValue[i]);
98+
}
99+
else
100+
{
101+
break;
74102
}
75103
}
76-
setOutput(batch, coords[1], coords[2], d, minMaxValue);
77104
}
78105
}
79106
`;
80-
this.shaderKey = 'maxpool';
107+
this.shaderKey = `maxpool${this.workPerThread}`;
81108
}
82109
}

0 commit comments

Comments
 (0)