
Commit cee1ff4

Add decompositions for median and nanmedian
ghstack-source-id: 12b12d2
Pull Request resolved: #134881
1 parent 1b9f51b commit cee1ff4


6 files changed, +148 -14 lines changed


test/expect/HasDecompTest.test_has_decomposition.expect

Lines changed: 0 additions & 8 deletions
@@ -928,10 +928,6 @@ aten::max_unpool2d
 aten::max_unpool2d.out
 aten::max_unpool3d
 aten::max_unpool3d.out
-aten::median
-aten::median.dim
-aten::median.dim_values
-aten::median.out
 aten::min
 aten::min.dim
 aten::min.dim_min
@@ -994,10 +990,6 @@ aten::multilabel_margin_loss_backward
 aten::multilabel_margin_loss_backward.grad_input
 aten::multinomial
 aten::multinomial.out
-aten::nanmedian
-aten::nanmedian.dim
-aten::nanmedian.dim_values
-aten::nanmedian.out
 aten::native_group_norm.out
 aten::native_norm
 aten::native_norm.ScalarOpt_dim_dtype

torch/_decomp/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -357,6 +357,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
             aten.logsumexp.default,
             aten.masked_fill,
             aten.masked_fill_,
+            aten.median,
             aten.mish,
             aten.mish_,
             aten.mse_loss,
@@ -366,6 +367,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
             aten.mv,
             aten.mvlgamma,
             aten.mvlgamma_,
+            aten.nanmedian,
             aten.nansum,
             aten.nan_to_num,
             aten.nan_to_num_,
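
For context (not part of the diff): once an op is listed in core_aten_decompositions(), its decomposition can also be pulled into a tracing decomposition table explicitly. A minimal sketch, assuming the illustrative names f, x and decomp_table:

    # Illustrative only; f, x and decomp_table are made-up names for this sketch.
    import torch
    from torch._decomp import get_decompositions
    from torch.fx.experimental.proxy_tensor import make_fx

    decomp_table = get_decompositions([torch.ops.aten.median, torch.ops.aten.nanmedian])

    def f(x):
        return x.median(dim=-1)

    x = torch.randn(4, 5)
    # With the decomposition applied, the traced graph should contain
    # sort/index ops rather than a call to aten.median.dim.
    gm = make_fx(f, decomposition_table=decomp_table)(x)
    print(gm.graph)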

torch/_decomp/decompositions.py

Lines changed: 89 additions & 0 deletions
@@ -5088,6 +5088,95 @@ def resize_as(self, other, memory_format=None):
     return aten.resize(self, other.shape, memory_format=memory_format)
 
 
+@register_decomposition([aten.median.default, aten.median.out])
+@out_wrapper()
+def median(x):
+    if x.numel() == 0:
+        return x.new_full([], float("nan")).to(x.dtype)
+
+    return median_impl(x.flatten(), dim=0, keepdim=False, ignore_nan=False)[0]
+
+
+@register_decomposition([aten.nanmedian.default, aten.nanmedian.out])
+@out_wrapper()
+def nanmedian(x):
+    if x.numel() == 0:
+        return x.new_full([], float("nan")).to(x.dtype)
+
+    return median_impl(x.flatten(), dim=0, keepdim=False, ignore_nan=True)[0]
+
+
+@register_decomposition([aten.median.dim, aten.median.dim_values])
+@out_wrapper("values", "indices")
+def median_dim(x, dim, keepdim=False):
+    utils.alert_not_deterministic("median with indices output")
+    return median_impl(x, dim=dim, keepdim=keepdim, ignore_nan=False)
+
+
+@register_decomposition([aten.nanmedian.dim, aten.nanmedian.dim_values])
+@out_wrapper("values", "indices")
+def nanmedian_dim(x, dim, keepdim=False):
+    utils.alert_not_deterministic("median with indices output")
+    return median_impl(x, dim=dim, keepdim=keepdim, ignore_nan=True)
+
+
+def median_impl(x, dim, keepdim=False, ignore_nan=True):
+    dim = utils.canonicalize_dim(x.dim(), dim)
+
+    if x.ndim == 0:
+        return x.clone(), x.new_full(x.shape, 0, dtype=torch.int64)
+
+    size = x.shape[dim]
+    torch._check(
+        size != 0,
+        lambda: f"median(): Expected reduction dim {dim} to have non-zero size.",
+    )
+
+    result_shape = list(x.shape)
+    if keepdim:
+        result_shape[dim] = 1
+    else:
+        del result_shape[dim]
+
+    if x.numel() == 0:
+        return x.new_empty(result_shape), x.new_empty(result_shape, dtype=torch.int64)
+
+    sorted_vals, sorted_idxs = aten.sort(x, dim=dim)
+
+    if ignore_nan:
+        k = ((size - 1) - x.isnan().sum(dim=dim, keepdim=True)) // 2
+        strides = sorted_vals.stride()
+        indices = k * strides[dim]
+        for d in range(x.ndim):
+            if d == dim:
+                continue
+            idx_shape = [1] * x.ndim
+            idx_shape[d] = -1
+            indices = indices + strides[d] * torch.arange(
+                x.shape[d], device=x.device
+            ).view(idx_shape)
+
+        result_val = aten._unsafe_index(sorted_vals.flatten(), [indices.flatten()])
+        result_ind = aten._unsafe_index(sorted_idxs.flatten(), [indices.flatten()])
+    else:
+        k = (size - 1) // 2
+        val_indices: List[Optional[TensorLike]] = [None] * x.ndim
+        val_indices[dim] = torch.tensor([k], device=x.device)
+        mask_indices: List[Optional[TensorLike]] = [None] * x.ndim
+        mask_indices[dim] = torch.tensor([x.shape[dim] - 1], device=x.device)
+
+        result_val = aten._unsafe_index(sorted_vals, val_indices)
+        result_ind = aten._unsafe_index(sorted_idxs, val_indices)
+
+        last_val = aten._unsafe_index(sorted_vals, mask_indices)
+        last_ind = aten._unsafe_index(sorted_idxs, mask_indices)
+
+        result_val = torch.where(last_val.isnan(), last_val, result_val)
+        result_ind = torch.where(last_val.isnan(), last_ind, result_ind)
+
+    return result_val.view(result_shape), result_ind.view(result_shape)
+
+
 register_inplace(aten.addbmm_, aten.addbmm)
 register_inplace(aten.addmm_, aten.addmm)
 register_inplace(aten.addmv_, aten.addmv)
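
For context (not part of the diff): the decomposition mirrors the existing eager conventions, i.e. k = (size - 1) // 2 picks the lower of the two middle elements for even-sized reductions, and the last_val.isnan() check makes plain median propagate NaN (NaNs sort last). A small eager-mode illustration, where ref_median is a hypothetical helper, not a PyTorch API:

    import torch

    def ref_median(x):
        # hypothetical reference: lower of the two middle elements, i.e. sorted[(n - 1) // 2]
        vals, _ = torch.sort(x.flatten())
        return vals[(vals.numel() - 1) // 2]

    x = torch.tensor([3.0, 1.0, 4.0, 2.0])      # even length: expect 2.0, not 2.5
    assert torch.median(x).item() == 2.0
    assert ref_median(x).item() == 2.0

    y = torch.tensor([1.0, float("nan"), 2.0])
    assert torch.median(y).isnan()              # plain median propagates NaN
    assert torch.nanmedian(y).item() == 1.0     # nanmedian ignores it: lower of {1.0, 2.0}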

torch/_inductor/lowering.py

Lines changed: 0 additions & 2 deletions
@@ -2254,8 +2254,6 @@ def is_aligned(x):
 make_fallback(aten.kthvalue)
 make_fallback(aten.topk)
 make_fallback(aten.mode)
-make_fallback(aten.median)
-make_fallback(aten.nanmedian)
 make_fallback(aten.randperm)
 # see: https://github.com/pytorch/pytorch/pull/121354
 make_fallback(aten.resize_)
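
For context (not part of the diff): with these fallbacks removed, Inductor can compile median/nanmedian through the decomposition instead of calling back into the eager kernel. A hedged smoke-test sketch:

    import torch

    @torch.compile
    def f(x):
        # nanmedian along dim=1; with the fallback removed this should go
        # through the sort-based decomposition when compiled.
        return torch.nanmedian(x, dim=1).values

    x = torch.randn(8, 16)
    x[0, 0] = float("nan")
    out = f(x)
    print(out.shape)  # torch.Size([8])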

torch/testing/_internal/common_methods_invocations.py

Lines changed: 42 additions & 2 deletions
@@ -4967,6 +4967,19 @@ def sample_inputs_reduction_count_nonzero(*args, **kwargs):
         sample.kwargs.pop('keepdim', None)
         yield sample
 
+
+def sample_inputs_reduction_unique(*args, **kwargs):
+    # for median with indices, the index might not be unique
+    # and depends on the device and the kernel. We return samples with unique
+    # values if `unique_values` is true and we return samples with non-unique
+    # value if `unique_values` is false.
+    unique_values = kwargs.pop('unique_values', False)
+    for sample in sample_inputs_reduction(*args, **kwargs):
+        is_nonunique = 'dim' in sample.kwargs and sample.args and sample.args[0].unique().numel() != sample.args[0].numel()
+        if is_nonunique == unique_values:
+            yield sample
+
+
 def sample_inputs_leaky_relu(op_info, device, dtype, requires_grad, **kwargs):
     N = 10
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -14002,20 +14015,47 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
            skips=(
            )),
     OpInfo('median',
+           variant_test_name='nonunique',
+           dtypes=all_types_and(torch.bfloat16, torch.float16),
+           # TODO: some signatures of median do support out
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           error_inputs_func=error_inputs_median,
+           sample_inputs_func=partial(sample_inputs_reduction_unique, supports_multiple_dims=False, unique_values=False),
+           skips=(
+               DecorateInfo(unittest.skip("Non-deterministic when non-unique values present"), 'TestDecomp', 'test_comprehensive'),
+               DecorateInfo(unittest.skip("Non-deterministic when non-unique values present"), 'TestDecomp', 'test_quick'),
+           )),
+    OpInfo('median',
+           variant_test_name='unique',
            dtypes=all_types_and(torch.bfloat16, torch.float16),
            # TODO: some signatures of median do support out
            supports_out=False,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            error_inputs_func=error_inputs_median,
-           sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)),
+           sample_inputs_func=partial(sample_inputs_reduction_unique, supports_multiple_dims=False, unique_values=True)),
+    OpInfo('nanmedian',
+           variant_test_name='nonunique',
+           dtypes=all_types_and(torch.bfloat16, torch.float16),
+           # TODO: some signatures of nanmedian do support out
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=partial(sample_inputs_reduction_unique, supports_multiple_dims=False, unique_values=False),
+           skips=(
+               DecorateInfo(unittest.skip("Non-deterministic when non-unique values present"), 'TestDecomp', 'test_comprehensive'),
+               DecorateInfo(unittest.skip("Non-deterministic when non-unique values present"), 'TestDecomp', 'test_quick'),
+           )),
     OpInfo('nanmedian',
+           variant_test_name='unique',
            dtypes=all_types_and(torch.bfloat16, torch.float16),
            # TODO: some signatures of nanmedian do support out
            supports_out=False,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
-           sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)),
+           sample_inputs_func=partial(sample_inputs_reduction_unique, supports_multiple_dims=False, unique_values=True)),
     OpInfo('var_mean',
            dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
            sample_inputs_func=sample_inputs_std_var,
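
For context (not part of the diff): the 'unique'/'nonunique' OpInfo split exists because the median value is well defined under ties, but the returned index is implementation-dependent, so only unique-valued samples are compared against the decomposition. A small illustration of such a tie:

    import torch

    x = torch.tensor([1.0, 1.0, 2.0])
    values, indices = torch.median(x, dim=0)
    print(values)   # tensor(1.) -- the value is unambiguous
    print(indices)  # 0 or 1, depending on the device and kernel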

torch/testing/_internal/opinfo/definitions/_masked.py

Lines changed: 15 additions & 2 deletions
@@ -334,7 +334,8 @@ def masked_samples():
 def sample_inputs_masked_softmax(
     op_info, device, dtype, requires_grad, with_dtype=False, **kwargs
 ):
-    """Sample inputs for masked softmax, log_softmax, and softmin.
+    """Sample inputs for masked softmax, log_softmax, softmin, median,
+    and nanmedian.
 
     Masked normalization operator is a reduction operator with
     trailing mask optional argument. A mask is a bool tensor with the
@@ -856,9 +857,21 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwar
             DecorateInfo(
                 unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
             ),
+            DecorateInfo(
+                unittest.skip("Non-deterministic when non-unique values present"),
+                "TestDecomp",
+                "test_comprehensive",
+            ),
+            DecorateInfo(
+                unittest.skip("Non-deterministic when non-unique values present"),
+                "TestDecomp",
+                "test_quick",
+            ),
         ),
         sample_inputs_func=partial(
-            sample_inputs_masked_softmax, use_zero_dimensions=False
+            sample_inputs_masked_softmax,
+            use_zero_dimensions=False,
+            unique_values=True,
         ),
         gradcheck_wrapper=gradcheck_wrapper_masked_operation,
     ),
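
For context (not part of the diff): these samples feed the masked median OpInfos, which now restrict themselves to unique values for the same tie-breaking reason. A hedged usage sketch, assuming the torch.masked.median entry point:

    import torch

    x = torch.tensor([[1.0, 5.0, 3.0],
                      [4.0, 2.0, 6.0]])
    mask = torch.tensor([[True, False, True],
                         [True, True, False]])

    # Masked-out elements are ignored: row 0 reduces over {1, 3}, row 1 over {4, 2}.
    print(torch.masked.median(x, dim=1, mask=mask))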
