Skip to content

Commit 70792b5

Browse files
committed
Limit number of values per thread for reductions on three dimensions
1 parent 1465757 commit 70792b5

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

aten/src/ATen/native/cuda/Reduce.cuh

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,10 @@ struct ReduceConfig {
209209
// Returns div_up(num_inputs, step_input) for this config. Declared const, so
// it does not modify the config.
// NOTE(review): div_up, num_inputs and step_input are defined outside this
// excerpt; div_up is presumably a round-up integer division — confirm.
int values_per_thread() const {
210210
return div_up(num_inputs, step_input);
211211
}
212+
213+
// "What-if" variant of values_per_thread(): returns
// div_up(num_inputs, step_input * parallelism), i.e. the same computation with
// the divisor scaled by `parallelism`. Only reads members; nothing is assigned.
// The caller visible in this diff passes config.ctas_per_output as `parallelism`.
// NOTE(review): not declared const, unlike values_per_thread(), even though
// the body does not mutate the config — consider adding const.
// NOTE(review): step_input * parallelism is an int multiplication; presumably
// both stay small enough not to overflow — confirm against callers.
int mock_values_per_thread(int parallelism) {
214+
return div_up(num_inputs, step_input * parallelism);
215+
}
212216
};
213217

214218
std::ostream& operator<<(std::ostream& out, const ReduceConfig& config);
@@ -1166,8 +1170,17 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
11661170
else if (config.ctas_per_output < 16)
11671171
config.ctas_per_output = 1;
11681172
bool is_channel_last = iter.tensor_base(1).is_contiguous(at::MemoryFormat::ChannelsLast);
1169-
if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last)
1173+
if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last) {
11701174
config.ctas_per_output = 4;
1175+
int vpt = config.values_per_thread();
1176+
// Keep the number of values per thread below 2048 for now,
1177+
// based on known use cases.
1178+
while (vpt >= 2048) {
1179+
config.ctas_per_output *= 2;
1180+
// Computes the new values per thread without side effects
1181+
vpt = config.mock_values_per_thread(config.ctas_per_output);
1182+
}
1183+
}
11711184
#endif
11721185
if (config.ctas_per_output > 1) {
11731186
config.input_mult[2] = config.split_input(config.ctas_per_output);

0 commit comments

Comments
 (0)