File tree Expand file tree Collapse file tree 1 file changed +14
-1
lines changed
aten/src/ATen/native/cuda Expand file tree Collapse file tree 1 file changed +14
-1
lines changed Original file line number Diff line number Diff line change @@ -209,6 +209,10 @@ struct ReduceConfig {
209
209
int values_per_thread () const {
210
210
return div_up (num_inputs, step_input);
211
211
}
212
+
213
+ int mock_values_per_thread (int parallelism) {
214
+ return div_up (num_inputs, step_input * parallelism);
215
+ }
212
216
};
213
217
214
218
std::ostream& operator <<(std::ostream& out, const ReduceConfig& config);
@@ -1166,8 +1170,17 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
1166
1170
else if (config.ctas_per_output < 16 )
1167
1171
config.ctas_per_output = 1 ;
1168
1172
bool is_channel_last = iter.tensor_base (1 ).is_contiguous (at::MemoryFormat::ChannelsLast);
1169
- if (iter.ndim () == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last)
1173
+ if (iter.ndim () == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last) {
1170
1174
config.ctas_per_output = 4 ;
1175
+ int vpt = config.values_per_thread ();
1176
+ // Capping the number of values per thread to 2048 for now
1177
+ // based on known use cases.
1178
+ while (vpt >= 2048 ) {
1179
+ config.ctas_per_output *= 2 ;
1180
+ // Computes the new values per thread without side effects
1181
+ vpt = config.mock_values_per_thread (config.ctas_per_output );
1182
+ }
1183
+ }
1171
1184
#endif
1172
1185
if (config.ctas_per_output > 1 ) {
1173
1186
config.input_mult [2 ] = config.split_input (config.ctas_per_output );
You can’t perform that action at this time.
0 commit comments