@@ -29,17 +29,20 @@ export class MaxPoolProgram implements WebGPUProgram {
29
29
dispatch : [ number , number , number ] ;
30
30
variableNames = [ 'x' ] ;
31
31
uniforms = 'ivec2 pad, stride, dilation, convDims, filterDims;' ;
32
- workGroupSize : [ number , number , number ] = [ 4 , 4 , 4 ] ;
32
+ // TODO(jiajia.qin@intel.com): Dynamically choose different workGroupSize and
33
+ // workPerThead for different output shapes.
34
+ workGroupSize : [ number , number , number ] = [ 4 , 4 , 1 ] ;
35
+ workPerThread = 16 ;
33
36
34
37
constructor ( convInfo : backend_util . Conv2DInfo ) {
35
38
this . outputShape = convInfo . outShape ;
36
39
37
- this . dispatchLayout = { x : [ 2 ] , y : [ 1 ] , z : [ 0 , 3 ] } ;
40
+ this . dispatchLayout = { x : [ 0 , 1 ] , y : [ 2 ] , z : [ 3 ] } ;
38
41
39
42
this . dispatch = computeDispatch (
40
- this . dispatchLayout , this . outputShape , this . workGroupSize ) ;
43
+ this . dispatchLayout , this . outputShape , this . workGroupSize ,
44
+ [ 1 , 1 , this . workPerThread ] ) ;
41
45
42
- // TODO: Parallelize max computation by thread and merge result.
43
46
this . userCode = `
44
47
float getValue(int batch, int xR, int xC, int d) {
45
48
if (xC < 0 || xC >= convDims.x) {
@@ -50,15 +53,17 @@ export class MaxPoolProgram implements WebGPUProgram {
50
53
51
54
void main() {
52
55
ivec4 coords = getOutputCoords();
53
- int batch = coords[0];
54
- int d = coords[3];
55
-
56
56
if (all(lessThan(coords, outShape))) {
57
+ int batch = coords[0];
57
58
ivec2 xRCCorner = coords.yz * stride - pad;
58
59
int xRCorner = xRCCorner.x;
59
60
int xCCorner = xRCCorner.y;
60
61
61
- float minMaxValue = 0.0;
62
+ float minMaxValue[${ this . workPerThread } ];
63
+ for (int i = 0; i < ${ this . workPerThread } ; i++)
64
+ {
65
+ minMaxValue[i] = 0.0;
66
+ }
62
67
63
68
for (int wR = 0; wR < filterDims.y; wR += dilation.y) {
64
69
int xR = xRCorner + wR;
@@ -69,14 +74,36 @@ export class MaxPoolProgram implements WebGPUProgram {
69
74
70
75
for (int wC = 0; wC < filterDims.x; wC += dilation.x) {
71
76
int xC = xCCorner + wC * dilation.x;
72
- float value = getValue(batch, xR, xC, d);
73
- minMaxValue = max(value, minMaxValue);
77
+ for (int i = 0; i < ${ this . workPerThread } ; i++)
78
+ {
79
+ int d = coords[3] * ${ this . workPerThread } + i;
80
+ if (d < outShape[3])
81
+ {
82
+ float value = getValue(batch, xR, xC, d);
83
+ minMaxValue[i] = max(value, minMaxValue[i]);
84
+ }
85
+ else
86
+ {
87
+ break;
88
+ }
89
+ }
90
+ }
91
+ }
92
+ for (int i = 0; i < ${ this . workPerThread } ; i++)
93
+ {
94
+ int d = coords[3] * ${ this . workPerThread } + i;
95
+ if (d < outShape[3])
96
+ {
97
+ setOutput(batch, coords[1], coords[2], d, minMaxValue[i]);
98
+ }
99
+ else
100
+ {
101
+ break;
74
102
}
75
103
}
76
- setOutput(batch, coords[1], coords[2], d, minMaxValue);
77
104
}
78
105
}
79
106
` ;
80
- this . shaderKey = ' maxpool' ;
107
+ this . shaderKey = ` maxpool${ this . workPerThread } ` ;
81
108
}
82
109
}
0 commit comments