44
44
#define TEMPLATE (name ,type ) CONCAT(name,type)
45
45
#define Dtype float
46
46
47
- void TEMPLATE (max_pool_forward_impl , Dtype )(
48
- const int nthreads , __global const Dtype * bottom_data , const int num ,
47
+ #if defined KERNEL_MAX_POOL
48
+
49
+ __kernel void
50
+ #ifdef HAVE_MASK
51
+ TEMPLATE (max_pool_forward_mask , Dtype )
52
+ #else
53
+ TEMPLATE (max_pool_forward , Dtype )
54
+ #endif
55
+ (
56
+ const int nthreads , __global const Dtype * bottom_data ,
49
57
const int channels , const int height , const int width ,
50
- const int pooled_height , const int pooled_width , const int kernel_h ,
51
- const int kernel_w , const int stride_h , const int stride_w , const int pad_h ,
52
- const int pad_w ,
53
- __global Dtype * top_data ,
54
- const int use_mask , __global int * mask , __global Dtype * top_mask , bool no_mask )
58
+ const int pooled_height , const int pooled_width ,
59
+ __global Dtype * top_data
60
+ #ifdef HAVE_MASK
61
+ , __global Dtype * mask
62
+ #endif
63
+ )
55
64
{
56
65
for (int index = get_global_id (0 ); index < nthreads ;
57
66
index += get_global_size (0 ))
@@ -60,10 +69,10 @@ void TEMPLATE(max_pool_forward_impl, Dtype)(
60
69
const int ph = (index / pooled_width ) % pooled_height ;
61
70
const int c = (index / pooled_width / pooled_height ) % channels ;
62
71
const int n = index / pooled_width / pooled_height / channels ;
63
- int hstart = ph * stride_h - pad_h ;
64
- int wstart = pw * stride_w - pad_w ;
65
- const int hend = min (hstart + kernel_h , height );
66
- const int wend = min (wstart + kernel_w , width );
72
+ int hstart = ph * STRIDE_H - PAD_H ;
73
+ int wstart = pw * STRIDE_W - PAD_W ;
74
+ const int hend = min (hstart + KERNEL_H , height );
75
+ const int wend = min (wstart + KERNEL_W , width );
67
76
hstart = max (hstart , (int )0 );
68
77
wstart = max (wstart , (int )0 );
69
78
Dtype maxval = - FLT_MAX ;
@@ -79,38 +88,19 @@ void TEMPLATE(max_pool_forward_impl, Dtype)(
79
88
}
80
89
}
81
90
top_data [index ] = maxval ;
82
- if (!no_mask ) {
83
- if (use_mask == 1 ) {
84
- mask [index ] = maxidx ;
85
- } else {
86
- top_mask [index ] = maxidx ;
87
- }
88
- }
91
+ #ifdef HAVE_MASK
92
+ mask [index ] = maxidx ;
93
+ #endif
89
94
}
90
95
}
91
96
92
- __kernel void TEMPLATE (max_pool_forward , Dtype )(
93
- const int nthreads , __global const Dtype * bottom_data , const int num ,
94
- const int channels , const int height , const int width ,
95
- const int pooled_height , const int pooled_width , const int kernel_h ,
96
- const int kernel_w , const int stride_h , const int stride_w , const int pad_h ,
97
- const int pad_w ,
98
- __global Dtype * top_data ,
99
- const int use_mask , __global int * mask , __global Dtype * top_mask )
100
- {
101
- TEMPLATE (max_pool_forward_impl , Dtype )(
102
- nthreads , bottom_data , num , channels , height , width ,
103
- pooled_height , pooled_width , kernel_h ,
104
- kernel_w , stride_h , stride_w , pad_h , pad_w , top_data , use_mask , mask , top_mask , false
105
- );
106
- }
97
+ #elif defined KERNEL_AVE_POOL
107
98
108
99
__kernel void TEMPLATE (ave_pool_forward , Dtype )(
109
- const int nthreads , __global const Dtype * const bottom_data , const int num ,
100
+ const int nthreads , __global const Dtype * const bottom_data ,
110
101
const int channels , const int height , const int width ,
111
- const int pooled_height , const int pooled_width , const int kernel_h ,
112
- const int kernel_w , const int stride_h , const int stride_w , const int pad_h ,
113
- const int pad_w , __global Dtype * top_data )
102
+ const int pooled_height , const int pooled_width ,
103
+ __global Dtype * top_data )
114
104
{
115
105
for (int index = get_global_id (0 ); index < nthreads ;
116
106
index += get_global_size (0 ))
@@ -120,10 +110,10 @@ __kernel void TEMPLATE(ave_pool_forward, Dtype)(
120
110
const int ph = (index / pooled_width ) % pooled_height ;
121
111
const int c = (index / pooled_width / pooled_height ) % channels ;
122
112
const int n = index / pooled_width / pooled_height / channels ;
123
- int hstart = ph * stride_h - pad_h ;
124
- int wstart = pw * stride_w - pad_w ;
125
- int hend = min (hstart + kernel_h , height + pad_h );
126
- int wend = min (wstart + kernel_w , width + pad_w );
113
+ int hstart = ph * STRIDE_H - PAD_H ;
114
+ int wstart = pw * STRIDE_W - PAD_W ;
115
+ int hend = min (hstart + KERNEL_H , height + PAD_H );
116
+ int wend = min (wstart + KERNEL_W , width + PAD_W );
127
117
const int pool_size = (hend - hstart ) * (wend - wstart );
128
118
hstart = max (hstart , (int )0 );
129
119
wstart = max (wstart , (int )0 );
@@ -142,11 +132,12 @@ __kernel void TEMPLATE(ave_pool_forward, Dtype)(
142
132
}
143
133
}
144
134
135
+ #elif defined KERNEL_STO_POOL
136
+
145
137
__kernel void TEMPLATE (sto_pool_forward_test ,Dtype )(
146
- const int nthreads , __global const Dtype * const bottom_data , const int num ,
138
+ const int nthreads , __global const Dtype * const bottom_data ,
147
139
const int channels , const int height , const int width ,
148
- const int pooled_height , const int pooled_width , const int kernel_h ,
149
- const int kernel_w , const int stride_h , const int stride_w ,
140
+ const int pooled_height , const int pooled_width ,
150
141
__global Dtype * top_data )
151
142
{
152
143
for (int index = get_global_id (0 ); index < nthreads ;
@@ -156,10 +147,10 @@ __kernel void TEMPLATE(sto_pool_forward_test,Dtype)(
156
147
const int ph = (index / pooled_width ) % pooled_height ;
157
148
const int c = (index / pooled_width / pooled_height ) % channels ;
158
149
const int n = index / pooled_width / pooled_height / channels ;
159
- const int hstart = ph * stride_h ;
160
- const int hend = min (hstart + kernel_h , height );
161
- const int wstart = pw * stride_w ;
162
- const int wend = min (wstart + kernel_w , width );
150
+ const int hstart = ph * STRIDE_H ;
151
+ const int hend = min (hstart + KERNEL_H , height );
152
+ const int wstart = pw * STRIDE_W ;
153
+ const int wend = min (wstart + KERNEL_W , width );
163
154
// We set cumsum to be 0 to avoid divide-by-zero problems
164
155
Dtype cumsum = FLT_MIN ;
165
156
Dtype cumvalues = 0. ;
@@ -168,10 +159,13 @@ __kernel void TEMPLATE(sto_pool_forward_test,Dtype)(
168
159
// First pass: get sum
169
160
for (int h = hstart ; h < hend ; ++ h ) {
170
161
for (int w = wstart ; w < wend ; ++ w ) {
171
- cumsum += bottom_slice [h * width + w ];
172
- cumvalues += bottom_slice [h * width + w ] * bottom_slice [h * width + w ];
162
+ Dtype v = bottom_slice [h * width + w ];
163
+ cumsum += v ;
164
+ cumvalues += v * v ;
173
165
}
174
166
}
175
167
top_data [index ] = cumvalues / cumsum ;
176
168
}
177
169
}
170
+
171
+ #endif // KERNEL_*
0 commit comments