
[inductor] [cpu] [edge case] When processing torch.nan_to_num-.long(), inductor outputs the reciprocal of eager #151510

@shaoyuyoung


🐛 Describe the bug

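Symptom: `torch.nan_to_num` correctly replaces `float("inf")` with the `posinf` value. After the result is cast with `.long()`, however, the CPU inductor output is the opposite int64 extreme of eager: eager returns `-9223372036854775808` (INT64_MIN) and inductor returns `9223372036854775807` (INT64_MAX).
Device backend: only CPP (the Triton backend matches eager, see the logs).
Repro: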
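Why the cast is out of range in the first place: the input is float32 (the default dtype), so `nan_to_num` stores the `posinf` replacement as float32, and `torch.iinfo(torch.int64).max` (2**63 - 1) is not representable in a float. It rounds up to 2**63, one past the largest int64, so the following `.long()` is an out-of-range conversion. The sketch below checks this with the standard-library `struct` module (no torch needed) and also computes the largest float32 strictly below 2**63. Passing that value as `posinf` is a possible workaround that keeps the cast in range; it is an untested suggestion, not a confirmed fix.

```python
import struct


def to_float32(value: float) -> float:
    """Round a Python float (a double) to the nearest float32 and back."""
    return struct.unpack("<f", struct.pack("<f", value))[0]


INT64_MAX = 2**63 - 1

# INT64_MAX needs 63 significant bits; float32 has 24, so it rounds up to 2**63,
# which is one past the largest representable int64.
posinf_as_f32 = to_float32(float(INT64_MAX))
print(posinf_as_f32 == 2**63)     # True
print(posinf_as_f32 > INT64_MAX)  # True

# Largest float32 strictly below 2**63: step the bit pattern of 2**63 down by one.
bits = struct.unpack("<I", struct.pack("<f", 2.0**63))[0]
largest_safe = struct.unpack("<f", struct.pack("<I", bits - 1))[0]
print(int(largest_safe) == 2**63 - 2**39)  # True, and it fits in int64
```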

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch._inductor import config

config.fallback_random = True
torch.set_grad_enabled(False)
torch.manual_seed(0)


class Model(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Replace nan/+inf/-inf with 0 and the int64 limits, then cast to int64.
        x = torch.nan_to_num(x, nan=0, posinf=torch.iinfo(torch.int64).max, neginf=torch.iinfo(torch.int64).min)
        x = x.long()
        return x


model = Model()

x = torch.tensor([[float("inf")]])  # float32 tensor on CPU
inputs = [x]


def run_test(model, inputs, backend):
    # Compile with the requested backend unless running eagerly.
    if backend != "eager":
        model = torch.compile(model, backend=backend)
    torch.manual_seed(0)
    output = model(*inputs)
    return output


output = run_test(model, inputs, 'eager')
c_output = run_test(model, inputs, 'inductor')
print(output)    # eager result
print(c_output)  # inductor result
```

Error logs

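CPP (first line: eager, second line: inductor)

```
tensor([[-9223372036854775808]])
tensor([[9223372036854775807]])
```

triton (same repro with the input on `cuda:0`; first line: eager, second line: inductor)

```
tensor([[9223372036854775807]], device='cuda:0')
tensor([[9223372036854775807]], device='cuda:0')
```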
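The two CPU results are exactly the two int64 extremes, which is what you would expect if the two code paths handle an out-of-range float-to-int64 cast differently. C++ leaves such a conversion undefined, so neither value is mandated by the language. The plain-Python sketch below models two common policies: the x86 truncating-convert instructions return the "integer indefinite" value INT64_MIN for out-of-range input, while a saturating conversion (PTX `cvt` on CUDA clamps by default) returns INT64_MAX. The function names are hypothetical, and it is an assumption, not something established in this issue, that CPU eager and the inductor CPP kernel follow these exact semantics.

```python
INT64_MIN = -(2**63)
INT64_MAX = 2**63 - 1


def cast_x86_indefinite(value: float) -> int:
    """Hypothetical model of an x86 truncating cast (cvttss2si/cvttsd2si):
    NaN or out-of-range inputs produce the 'integer indefinite' value INT64_MIN."""
    if value != value or not (INT64_MIN <= value < 2**63):
        return INT64_MIN
    return int(value)  # int() truncates toward zero


def cast_saturating(value: float) -> int:
    """Hypothetical model of a saturating cast: NaN maps to 0 and
    out-of-range inputs are clamped to the int64 limits."""
    if value != value:
        return 0
    if value >= 2**63:
        return INT64_MAX
    if value < INT64_MIN:
        return INT64_MIN
    return int(value)


replaced = float(2**63)  # posinf=INT64_MAX after rounding to float32
print(cast_x86_indefinite(replaced))  # -9223372036854775808, like CPU eager above
print(cast_saturating(replaced))      # 9223372036854775807, like inductor above
```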

Versions

nightly 20250414

cc @chauhang @penguinwu
