[cuda][complex] Use scaling to compute the absolute value of complex number to avoid overflow #158557

Open · wants to merge 1 commit into main

Conversation

@thenumberouscode (Contributor) commented Jul 17, 2025

Fixes #158412

mini repro

  • copied directly from the issue:
import torch
max_float64 = 1.7976931348623157e+308
real = torch.ones((3, 3), dtype=torch.float64)
imag = torch.full((3, 3), max_float64, dtype=torch.float64)
input = torch.complex(real, imag)
input_gpu = input.cuda()
try:
    out_gpu = torch.absolute(input=input_gpu)
    print(out_gpu)  # all inf
except Exception as e:
    print(e)
input_cpu = input.cpu()
try:
    out_cpu = torch.absolute(input=input_cpu)
    print(out_cpu)  # all 1.7977e+308
except Exception as e:
    print(e)

The output on CPU is 1.7976931348623157e+308, while on CUDA it is inf.

summary

proof procedure

cuda

  • Let us write a small .cu program that uses a complex number whose real part is 1.0 and whose imaginary part is the maximum double value, to reproduce the bug:
#include <complex>
#include <cuda_runtime.h>
#include <cmath>
#include <cstdio>  // for device-side printf

template <typename T>
__device__ void abs_kernel(T c) {
    printf("%f\n", hypot(c.real(), c.imag()));
}

__device__ void call_abs_kernel() {
    // 1.7976931348623157e+308 is the max value of double.
    std::complex<double> c(1.0, 1.7976931348623157e+308);
    abs_kernel(c);  
}


__global__ void wrapper_call() {
    call_abs_kernel();
}

int main() {
    const int numBlocks = 2;
    const int numThreadsPerBlock = 4;
    wrapper_call<<<numBlocks, numThreadsPerBlock>>>();
    cudaDeviceSynchronize();

    return 0;
}

use nvcc to compile it:

nvcc -o reproduce_abs_kernel reproduce_abs_kernel.cu --expt-relaxed-constexpr

the output is

inf
inf
inf
inf
inf
inf
inf
inf

c++

I wrote a C++ program based on the PyTorch CPU absolute implementation, which reduces to:

return std::abs(static_cast<std::complex<T>>(z));

#include <iostream>
#include <complex>
#include <limits>
#include <cmath>

int main() {
    // std::numeric_limits<double>::max() equals 1.7976931348623157e+308
    std::complex<double> myComplex(1.0, std::numeric_limits<double>::max());

    double absValue = std::abs(myComplex);
    std::cout << "The absolute value of the complex number is: " << absValue << std::endl;

    return 0;
}

use g++ to compile it:

g++ -o complex_cpp complex_cpp.cpp

and the output is

The absolute value of the complex number is: 1.79769e+308

conclusion

  • The current absolute operation for complex numbers in CUDA cannot handle overflow correctly, so I used scaling to prevent overflow; a minimal sketch of the idea is shown below.
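For reference, the idea is the classic overflow-safe magnitude computation: factor out the larger component so the squared term stays in [0, 1]. The snippet below is only a minimal host-side sketch under that assumption (the name scaled_abs is illustrative; it is not the PR's kernel code):

#include <cmath>
#include <cstdio>
#include <utility>

// Overflow-safe |re + i*im|: divide both components by the larger magnitude,
// so the intermediate square cannot overflow, then multiply the scale back in.
double scaled_abs(double re, double im) {
    double a = std::fabs(re);
    double b = std::fabs(im);
    if (a < b) std::swap(a, b);         // ensure a >= b
    if (a == 0.0) return 0.0;           // both components are zero
    double r = b / a;                   // r lies in [0, 1], so r * r is safe
    return a * std::sqrt(1.0 + r * r);  // mathematically sqrt(re*re + im*im)
}

int main() {
    // The case from the issue: a naive sqrt(re*re + im*im) returns inf here.
    std::printf("%g\n", scaled_abs(1.0, 1.7976931348623157e+308));  // ~1.79769e+308
    return 0;
}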

pytorch-bot bot commented Jul 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158557

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4a1ec62 with merge base 2ad5c25:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@thenumberouscode (Contributor Author)

@pytorchbot label "release notes: cuda"

@pytorch-bot bot added the release notes: cuda label on Jul 17, 2025
@thenumberouscode (Contributor Author)

@eqy @syed-ahmed @mruberry Could you review my PR when you have a moment? thanks

@janeyx99 added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Jul 17, 2025
@thenumberouscode (Contributor Author)

@eqy @syed-ahmed @mruberry I would really appreciate it if you could review the PR at your convenience.

@thenumberouscode (Contributor Author) commented Jul 21, 2025

@malfet @ngimel I would really appreciate it if you could review the PR at your convenience.

@ngimel (Collaborator) commented Jul 21, 2025

Please provide performance benchmarks for this change.

@ngimel (Collaborator) commented Jul 21, 2025

Also, due to the additional division step, this change reduces the precision of the abs calculation.

@thenumberouscode (Contributor Author) commented Jul 23, 2025

Please provide performance benchmarks for this change.

@ngimel I used the following code to benchmark the new abs kernel:

import torch
import time

max_float64 = 1.7976931348623157e+308
# real = torch.ones((100, 100), dtype=torch.float64)
real = torch.rand((100, 100), dtype=torch.float64)
imag = torch.full((100, 100), max_float64, dtype=torch.float64)
input = torch.complex(real, imag)
input_gpu = input.cuda()
total = 0.0
for _ in range(100):
   t0 = time.time_ns()
   torch.cuda.synchronize()
   out_gpu = torch.absolute(input=input_gpu)
   torch.cuda.synchronize()
   t1 = time.time_ns()
   total += (t1-t0)/1e06
print(f"time consumed Delta: {total/10:.3E} ms")

the new abs kernel:

time consumed Delta: 1.274E+01 ms
time consumed Delta: 4.471E+00 ms
time consumed Delta: 4.485E+00 ms
time consumed Delta: 4.314E+00 ms
time consumed Delta: 4.381E+00 ms

the old one:

time consumed Delta: 1.048E+01 ms
time consumed Delta: 4.396E+00 ms
time consumed Delta: 4.394E+00 ms
time consumed Delta: 4.363E+00 ms
time consumed Delta: 4.338E+00 ms

summary:

  1. The first execution may take noticeably longer in both implementations (although, oddly, sometimes it does not).
  2. There is no significant difference in performance between the new and old kernels.

@thenumberouscode (Contributor Author)

@ngimel I added precision comparisons between the CUDA and CPU implementations in the unit test. It seems that precision may not be an issue based on the current torch.allclose parameters. Do you have a better method to test the precision?

@thenumberouscode (Contributor Author)

@ngimel Shall we continue with the review? Thank you.

@ngimel (Collaborator) commented Jul 25, 2025

Your benchmark is for tiny sizes and measures overhead only. The allclose parameters are too lax to judge a 1 ulp difference in precision. Honestly, I'm not sure that edge case is worth fixing.
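For what it's worth, one way to quantify such small differences is to compare the bit patterns of the two results directly. The sketch below is illustrative only (to_ordered is an ad hoc helper, not a PyTorch or libc API) and reports the distance in ulps between hypot and a scaled computation for positive finite doubles:

#include <cstdint>
#include <cstring>
#include <cmath>
#include <cstdio>

// For positive finite doubles, the raw IEEE-754 bit pattern is monotonically
// ordered, so differences of these integers are distances in ulps.
std::int64_t to_ordered(double x) {
    std::int64_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    return bits;
}

int main() {
    double re = 3.0, im = 4.0;
    double reference = std::hypot(re, im);                    // overflow-avoiding library routine
    double a = std::fmax(std::fabs(re), std::fabs(im));
    double b = std::fmin(std::fabs(re), std::fabs(im));
    double scaled = a * std::sqrt(1.0 + (b / a) * (b / a));   // scaling formulation
    std::int64_t diff = to_ordered(reference) - to_ordered(scaled);
    if (diff < 0) diff = -diff;
    std::printf("ulp distance: %lld\n", static_cast<long long>(diff));
    return 0;
}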

@thenumberouscode (Contributor Author)

Your benchmark is for tiny sizes and measures overhead only. The allclose parameters are too lax to judge a 1 ulp difference in precision. Honestly, I'm not sure that edge case is worth fixing.

@ngimel Your suggestions regarding unit testing are correct. Concerning the bug fix's value, I believe it is necessary because the result becomes infinity (inf) whenever real^2 + imag^2 exceeds the type's maximum value. Note that this already happens if either the real or the imaginary component is larger than the square root of the type's maximum value (about 1.34e154 for float64).
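To make that threshold concrete, here is a small host-only illustration (not PyTorch code): with the naive formulation sqrt(re*re + im*im), any component above sqrt(DBL_MAX), roughly 1.34e154, already overflows the intermediate square, even though the true magnitude is representable:

#include <cmath>
#include <cstdio>
#include <limits>

int main() {
    // sqrt(DBL_MAX) is about 1.34e154; pick an imaginary part just above it.
    double im = 1.5e154;
    double naive = std::sqrt(1.0 * 1.0 + im * im);  // im*im overflows, so this is inf
    double safe  = std::hypot(1.0, im);             // overflow-avoiding, about 1.5e154
    std::printf("naive = %g, hypot = %g, DBL_MAX = %g\n",
                naive, safe, std::numeric_limits<double>::max());
    return 0;
}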

@ngimel (Collaborator) commented Jul 28, 2025

Thrust/cccl is using hypot for abs:

template <typename T>
_CCCL_HOST_DEVICE T abs(const complex<T>& z)
{
  return hypot(z.real(), z.imag());
}

There is no reason for our implementation to diverge. Similarly, the implementation you are changing comes directly from LLVM's libc++ (https://github.com/llvm/llvm-project/blob/main/libcxx/include/complex), and we should have very good reasons to deviate from that.

@thenumberouscode (Contributor Author)

https://github.com/llvm/llvm-project/blob/main/libcxx/include/complex

@ngimel Got it. I agree with your perspective on modifying such foundational code. I still have some questions about how to benchmark the performance of my changes. Does PyTorch have any tools for that? I just want to learn the general approach for future changes.

@thenumberouscode (Contributor Author)

@ngimel I've changed my PR to keep the original logic as much as possible: it now scales the real and imaginary parts only when the result of hypot is infinity. I've tested it locally; the time cost is roughly equal to the current implementation, but it roughly doubles when the result overflows to infinity. Additionally, I've added torch.allclose checks with explicit rtol and atol. Please check the changes and leave any comments at your convenience.
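For illustration, here is a rough host-side sketch of that fallback strategy (an assumed shape only, not the exact PR code; most host math libraries already make std::hypot avoid this overflow, so the fallback mainly matters on the device path, where the repro above shows it does not):

#include <cmath>
#include <cstdio>

// Fast path first; rescale and retry only if hypot overflowed on finite inputs.
double abs_with_fallback(double re, double im) {
    double result = std::hypot(re, im);
    if (std::isinf(result) && std::isfinite(re) && std::isfinite(im)) {
        // Scale both components down by 2^-512, recompute, then scale back up.
        const double down = std::ldexp(1.0, -512);
        const double up = std::ldexp(1.0, 512);
        result = std::hypot(re * down, im * down) * up;
    }
    return result;
}

int main() {
    std::printf("%g\n", abs_with_fallback(1.0, 1.7976931348623157e+308));  // ~1.79769e+308
    return 0;
}

With this shape, the extra scaling work is only paid when the fast path overflows, which matches the observation that the cost roughly doubles in that case.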

…numbers, since hypot cannot handle overflow properly.
@ngimel (Collaborator) commented Jul 29, 2025

https://docs.pytorch.org/docs/stable/benchmark_utils.html#module-torch.utils.benchmark run for a variety of sizes (so that runtime changes from microseconds to hundreds of microseconds)

@thenumberouscode (Contributor Author)

@ngimel I used torch.utils.benchmark to compare the new and old implementations of abs. test code:

from torch.utils.benchmark import Fuzzer, FuzzedParameter, FuzzedTensor
import torch.utils.benchmark as benchmark
import torch
# Generates random complex128 tensors with 128 to 100000000 elements, with sizes k0 and k1
# drawn from a ``loguniform`` distribution in [1, 100000]; about 40% of the tensors will be
# discontiguous on average. Note that FuzzedTensor allocates CPU tensors by default
# (cuda=False), so as written this benchmarks the CPU kernel unless x is moved to the GPU.
example_fuzzer = Fuzzer(
    parameters = [
        FuzzedParameter('k0', minval=1, maxval=100000, distribution='loguniform'),
        FuzzedParameter('k1', minval=1, maxval=100000, distribution='loguniform'),
    ],
    tensors = [
        FuzzedTensor('x', size=('k0', 'k1'), min_elements=128, max_elements=100000000, probability_contiguous=0.6, dtype=torch.complex128)
    ],
    seed=0,
)

def absolute(x):
    return torch.abs(x)

results = []
for tensors, tensor_params, params in example_fuzzer.take(100):
    # description is the column label
    sub_label=f"{params['k0']:<6} x {params['k1']:<4} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}"
    results.append(benchmark.Timer(
        stmt='absolute(x)',
        setup='from __main__ import absolute',
        globals=tensors,
        label='abs',
        sub_label=sub_label,
        description='abs',
    ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.trim_significant_figures()
compare.print()

The output below includes an abs_old column that was added manually from earlier test results. The performance of abs_new and abs_old is approximately the same.

[--------------------- abs ---------------------------------------------------]
                                      |    abs_new   | abs_old
1 threads: --------------------------------------------------------------
      3766   x 1032                   |    420000 |     420000
      131    x 1696                   |     23700 |      23800
      82     x 9091                   |     79700 |      80000
      692    x 42460                  |   3196000 |    3200000
      9903   x 202  (discontiguous)   |    210000 |     214000
      4010   x 813                    |    353000 |     350000
      6209   x 3                      |      2130 |       2130
      8      x 4837                   |      4280 |       4280
      4      x 41                     |       119 |        120
      12     x 85   (discontiguous)   |       223 |        224
      62     x 153  (discontiguous)   |      1150 |       1150
      2255   x 11                     |      2793 |       2796
      37     x 65                     |       364 |        364
      155    x 87472                  |   1476000 |    1475000
      1843   x 18                     |      3690 |       3690
      3      x 1912                   |       723 |        725
      9      x 69   (discontiguous)   |       181 |        180
      1300   x 2335 (discontiguous)   |    328600 |     329100
      353    x 1                      |       142 |        142
      1      x 161                    |       120 |        121
      2775   x 25256 (discontiguous)  |   7615000 |    7630000
      669    x 21159                  |   1539000 |    1545000
      1      x 16282                  |      1870 |       1870
      27     x 8                      |       124 |        125
      1      x 13954                  |      1620 |       1620
      2449   x 22   (discontiguous)   |      5900 |       5900
      17     x 759                    |      1510 |       1500
      726    x 13   (discontiguous)   |      1140 |       1140
      176    x 78136                  |   1495000 |    1498000
      6840   x 114  (discontiguous)   |     83330 |      84000
      806    x 25625 (discontiguous)  |   2247000 |    2250000
      1987   x 463                    |     99000 |      98600
      94     x 1241                   |     12500 |      12500
      224    x 3805                   |     91100 |      91000
      82     x 5569 (discontiguous)   |     48800 |      48900
      707    x 900                    |     68000 |      68000
      1845   x 1821                   |    364000 |     364000
      30390  x 68                     |    220000 |     221000
      3307   x 3    (discontiguous)   |      1195 |       1200
      221    x 16058 (discontiguous)  |    384000 |     380000
      6      x 1196                   |       878 |        881
      700    x 108                    |      8170 |       8200
      3070   x 185  (discontiguous)   |     60700 |      60800
      14     x 6676 (discontiguous)   |     10100 |      10100
      49     x 280                    |      1590 |       1600
      207    x 23597 (discontiguous)  |    530000 |     533000
      3327   x 1                      |       464 |        464
      1280   x 769                    |    105000 |     105000
      46888  x 1174                   |   5986000 |    5992000
      890    x 4473                   |    430000 |     430700
      97     x 11                     |       219 |        220
      283    x 13                     |       502 |        503
      1      x 148                    |       117 |        118
      3031   x 77                     |     24900 |      24900
      2494   x 185                    |     49300 |      49300
      12     x 2067                   |      2790 |       2795
      1      x 6192                   |       772 |        774
      82     x 874  (discontiguous)   |      7780 |       7790
      23080  x 23   (discontiguous)   |     56700 |      57000
      11     x 54557 (discontiguous)  |     64200 |      64200
      37     x 61270                  |    245000 |     246000
      3613   x 5    (discontiguous)   |      2078 |       2082
      10911  x 5437 (discontiguous)   |   6447000 |    6461000
      4      x 3839                   |      1770 |       1770
      671    x 8                      |       682 |        685
      275    x 59   (discontiguous)   |      1880 |       1880
      5432   x 2                      |      1280 |       1290
      5575   x 4554                   |   2762000 |    2764000
      11     x 27   (discontiguous)   |       146 |        147
      44426  x 2236 (discontiguous)   |  10800000 |   10820000
      3292   x 8                      |      2960 |       2970
      798    x 8663 (discontiguous)   |    753000 |     754000
      64246  x 14   (discontiguous)   |     96300 |      96400
      3782   x 269  (discontiguous)   |    109000 |     109000
      29     x 17567 (discontiguous)  |     54400 |      54500
      246    x 306  (discontiguous)   |      8160 |       8180
      112    x 138                    |      1780 |       1790
      101    x 4    (discontiguous)   |       156 |        158
      7417   x 46                     |     36400 |      36400
      108    x 14                     |       265 |        267
      1      x 4245                   |       563 |        567
      7126   x 5                      |      3950 |       3960
      2      x 2292                   |       598 |        602
      126    x 612  (discontiguous)   |      8352 |       8380
      911    x 241                    |     23400 |      23500
      196    x 419                    |      8860 |       8890
      18330  x 146  (discontiguous)   |    290000 |     290000
      2      x 1996                   |       533 |        538
      17233  x 2551                   |   4777000 |    4788000
      28856  x 4    (discontiguous)   |     12400 |      12400
      3766   x 284                    |    114000 |     115000
      25914  x 276                    |    773000 |     774000
      7      x 237  (discontiguous)   |       294 |        297
      479    x 2552                   |    131000 |     131000
      4      x 91   (discontiguous)   |       152 |        154
      33106  x 523                    |   1883000 |    1886000
      25716  x 196  (discontiguous)   |    545000 |     545000
      2      x 135                    |       133 |        134
      796    x 3812                   |    329000 |     329000
      4539   x 6836                   |   3375000 |    3380000


Times are in microseconds (us).

@thenumberouscode (Contributor Author)

@ngimel Does my current PR meet the requirements?

@ngimel (Collaborator) commented Aug 4, 2025

The benchmark looks off, 3.4 s for 4539 x 6836 tensor? That can't be true.

@thenumberouscode (Contributor Author)

The benchmark looks off, 3.4 s for 4539 x 6836 tensor? That can't be true.

@ngimel Got it, I'll check the result.

Labels: open source, release notes: cuda, triaged
Successfully merging this pull request may close these issues:

Inconsistent torch.absolute results on complex128 between CPU and CUDA