
[Profiler] Fix Empty C Call Queue #150370

Closed · wants to merge 2 commits

Conversation

sraikund16
Contributor

@sraikund16 sraikund16 commented Apr 1, 2025

Summary:
My commandeer of #150102

Based on the description of that PR, it seems we need to add a C call for each starting Python event with a callable, so that when tracing exits there is a matching enter for any given exit. At worst this adds some unnecessary events, but it prevents segfaults/failures. My PR just cleans up some refcount implementation and logging.
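The enter/exit balancing described above can be sketched with a toy model (hypothetical simplified Python, not PyTorch's actual C++ tracer; `replay` and the event tuples are illustrative names): seed one synthetic enter event per Python frame that is already live when tracing starts, so every exit seen later has a matching enter.

```python
# Toy model of the balancing idea: seed a synthetic enter for every
# frame already on the stack when tracing begins.

def replay(events, frames_live_at_start):
    stack = []
    for frame in frames_live_at_start:
        # Synthetic enter: adds an unnecessary event at worst, but
        # guarantees the stack is non-empty when this frame's exit arrives.
        stack.append(("synthetic_enter", frame))
    for kind, name in events:
        if kind == "enter":
            stack.append((kind, name))
        elif kind == "exit":
            # Without seeding, an exit for a pre-existing frame would
            # underflow here ("Python replay stack is empty").
            assert stack, "Python replay stack is empty"
            stack.pop()
    return stack

# An exit whose enter happened before tracing began is now matched:
leftover = replay([("exit", "f"), ("enter", "g"), ("exit", "g")], ["f"])
```

The seeded entries pair off against the otherwise-unmatched exits, which is why the worst case is only a few extra events rather than a stack underflow.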

Contributors: @arjun-choudhry

Test Plan: Ran resnet test internally. Will check CI and ask reviewers to make sure it resolves their issues.

Differential Revision: D72207570


pytorch-bot bot commented Apr 1, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/150370

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit f7b4515 with merge base 783f045:

UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D72207570

@sraikund16
Contributor Author

@pramodk @oraluben do you mind checking if this PR fixes the issue?

@sraikund16 sraikund16 requested a review from briancoutinho April 1, 2025 00:07
@sraikund16 sraikund16 assigned ngimel and unassigned ngimel Apr 1, 2025
@sraikund16 sraikund16 requested a review from ngimel April 1, 2025 00:07
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D72207570

Summary:
Pull Request resolved: pytorch#150370

My commandeer of pytorch#150102

Based on description of PR it seems that we need to add C calls for each starting python event with a callable such that when the tracing exits we will have a matching enter for any given exit. My diff just cleans up some refcount impl and logging.

Test Plan: Ran resnet test internally. Will check CI and ask reviewers to make sure it resolves their issues.

Differential Revision: D72207570
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D72207570

@sraikund16 sraikund16 added the `release notes: profiler` and `topic: bug fixes` labels Apr 1, 2025
@pytorch-bot pytorch-bot bot added the `ciflow/trunk` (Trigger trunk jobs on your pull request) label Apr 1, 2025
@pramodk

pramodk commented Apr 1, 2025

@sraikund16, thanks! I am able to build with the latest change, but I won't be able to test this at the moment / today. I am attaching another simple test I saw failing with the nemo:25.02 container, in case you could test it with a local build (it's standalone):

import torch

def get_profiler():
    return torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        with_stack=True,
    )

def profile_tensor_ops():
    device = "cuda"
    with get_profiler() as prof:
        for _ in range(5):
            x = torch.randn(1000, 1000, device=device)
            y = torch.randn(1000, 1000, device=device)
            z = torch.matmul(x, y)
            torch.cuda.synchronize()
            prof.step()

    print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=10))


if __name__ == "__main__":
    profile_tensor_ops()

and running as:

WORKSPACE_PATH=$(pwd)
docker run \
  --gpus all \
  -it \
  --rm \
  --ipc=host \
  --network=host \
  -v $WORKSPACE_PATH:$WORKSPACE_PATH \
  nvcr.io/nvidia/nemo:25.02 \
  python $WORKSPACE_PATH/test.py

It was producing:

   ....
  File "/usr/local/lib/python3.12/dist-packages/torch/profiler/profiler.py", line 777, in __exit__
    self.stop()
  File "/usr/local/lib/python3.12/dist-packages/torch/profiler/profiler.py", line 793, in stop
    self._transit_action(self.current_action, None)
  File "/usr/local/lib/python3.12/dist-packages/torch/profiler/profiler.py", line 836, in _transit_action
    action()
  File "/usr/local/lib/python3.12/dist-packages/torch/profiler/profiler.py", line 239, in stop_trace
    self.profiler.__exit__(None, None, None)
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/profiler.py", line 378, in __exit__
    self.kineto_results = _disable_profiler()
                          ^^^^^^^^^^^^^^^^^^^
RuntimeError: !stack.empty() INTERNAL ASSERT FAILED at "/opt/pytorch/pytorch/torch/csrc/autograd/profiler_python.cpp":982, please report a bug to PyTorch. Python replay stack is empty.

Just one note: if you test in a local build, it would be good to verify that this test fails without this PR, just for cross-checking, as I was wondering whether the behavior changes a bit based on the different build configs. But I didn't get time to verify this thoroughly...

Edit 1: By the way, the reason for the above note is that I was further cross-checking the above simple example with PyTorch containers:

  • nvcr.io/nvidia/pytorch:25.01-py3, 2.6.0a0+ecf3bae40a.nv25.0
  • nvcr.io/nvidia/pytorch:25.02-py3, 2.7.0a0+ecf3bae40a.nv25.02
  • nvcr.io/nvidia/pytorch:25.02-py3, 2.7.0a0+7c8ec84dab.nv25.03

they all use Python 3.12.3, but only 25.01-py3 (torch.__version__ -> 2.6.0a0+ecf3bae40a.nv25.01) fails with the assert. Hence, the initial conclusion that it's "only" Python-version related no longer holds (?).

Edit 2: one common thing between the failing pytorch:25.01 and nemo:25.02 is that sys.version is identical for both, '3.12.3 (main, Nov 6 2024, 18:32:19) [GCC 13.2.0]', i.e. GCC 13.2, whereas the newer PyTorch containers report 3.12.3 (main, Feb 4 2025, 14:48:35) [GCC 13.3.0]!
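The container comparison above hinges on two fields, sys.version and torch.__version__. A small hypothetical helper (the `env_report` name is illustrative, not part of any library) can collect both in one place for cross-checking:

```python
import sys

def env_report():
    """Collect the fields compared above: the Python build string
    (which includes the compiler version) and the torch version."""
    info = {"python": sys.version}
    try:
        import torch  # may be absent outside the containers discussed
        info["torch"] = torch.__version__
    except ImportError:
        info["torch"] = None
    return info

print(env_report())
```

Running this inside each container makes it easy to line up the GCC build strings against which images reproduce the assert.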

@sraikund16
Contributor Author

> (quoting @pramodk's comment above)

I don't have a machine with those virtual environments set up, so I won't be able to test it myself. Since this PR is supposed to fix multiple issues, let's get it in and then follow up on these other potential issues later.

@sraikund16
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

@arjun-choudhry

arjun-choudhry commented Apr 2, 2025

@sraikund16 Can you please add me as a contributor in the PR? Thanks

@clee2000
Contributor

clee2000 commented Apr 2, 2025

@pytorchbot revert -m "broke some profiler tests when building with debug asserts profiler/test_memory_profiler.py::TestMemoryProfiler::test_config_check GH job link HUD commit link" -c nosignal

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Apr 2, 2025
This reverts commit 5734909.

Reverted #150370 on behalf of https://github.com/clee2000 due to broke some profiler tests when building with debug asserts profiler/test_memory_profiler.py::TestMemoryProfiler::test_config_check [GH job link](https://github.com/pytorch/pytorch/actions/runs/14211763078/job/39822158330) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/3ac5a499ddac701f607a9f7206f9bec8871e1cbb) ([comment](#150370 (comment)))
@pytorchmergebot
Collaborator

@sraikund16 your PR has been successfully reverted.

@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Apr 2, 2025
@facebook-github-bot
Contributor

@sraikund16 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@sraikund16
Contributor Author

> @sraikund16 Can you please add me as a contributor in the PR? Thanks

Added you to the PR description. I think you need to push to the branch itself to be considered a contributor on GH, though. Let me know if there is another way.

@sraikund16
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge failed

Reason: This PR has internal changes and must be landed via Phabricator! Please try reimporting/rexporting the PR!

Details for Dev Infra team · Raised by workflow job

@sraikund16
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
(same commit message as the PR summary above)

Pull Request resolved: pytorch#150370
Approved by: https://github.com/aaronenyeshi
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
This reverts commit 5734909.

(same revert message as above)
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
(same commit message as the PR summary above)
pytorchmergebot pushed a commit that referenced this pull request Jul 23, 2025
…55446)

Hi team,

Please help review this patch.

This PR #150370 tried to fix the "Empty C Call Queue" problem on Python 3.12. It added C calls for each starting Python event with a callable.

I found that the root cause is not that we cannot get C function frames via `PyFrame_GetBack` while PythonTracer is filling start frames, but a C call event loss bug in Python 3.12.0-3.12.4, which was fixed by python/cpython@257c413 in 3.12.5.

So I think #150370 cannot fix the problem; this patch reverts its change.

There are solutions that would fix the problem correctly, such as adding a new monitoring callback to compensate for call events of methods with C functions, or overriding the callback registered by `PyEval_SetProfile`. These solutions may make the code hard to maintain.

~~Since upgrading the micro version of Python is not difficult for users, we can just ignore C functions and suggest user upgrade.~~
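Given the affected range stated above (C call events lost on Python 3.12.0-3.12.4, fixed in 3.12.5), a minimal check for whether the running interpreter falls in that range might look like the following sketch (`in_affected_range` is an illustrative name, not a PyTorch or CPython API):

```python
import sys

def in_affected_range(version_info=sys.version_info):
    """True if this Python is 3.12.0-3.12.4, the range where C call
    events can be lost (fixed by cpython commit 257c413 in 3.12.5)."""
    return tuple(version_info[:2]) == (3, 12) and version_info[2] <= 4

print(in_affected_range((3, 12, 3)))  # -> True (affected)
print(in_affected_range((3, 12, 5)))  # -> False (fixed)
```

Such a check only identifies the micro-version range; it does not work around the event loss itself.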

Pull Request resolved: #155446
Approved by: https://github.com/sraikund16, https://github.com/cyyever
pytorchmergebot pushed a commit that referenced this pull request Jul 25, 2025
(same commit message as the Jul 23 entry above)
pytorchmergebot pushed a commit that referenced this pull request Jul 30, 2025
(same commit message as the Jul 23 entry above)
yangw-dev pushed a commit that referenced this pull request Aug 1, 2025
(same commit message as the Jul 23 entry above)
yangw-dev pushed a commit that referenced this pull request Aug 1, 2025
(same commit message as the Jul 23 entry above)
Labels
ci-no-td (Do not run TD on this PR) · ciflow/trunk (Trigger trunk jobs on your pull request) · fb-exported · Merged · release notes: profiler (release notes category) · Reverted · topic: bug fixes (topic category)
8 participants