Skip to content

[dynamo][pytree][1/N] make CXX pytree traceable: tree_iter / tree_leaves #137397

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 41 commits into from

Conversation

[ghstack-poisoned]
Copy link

pytorch-bot bot commented Oct 5, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137397

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (4 Unrelated Failures)

As of commit af679d5 with merge base 9012e7a (image):

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Oct 5, 2024
@XuehaiPan XuehaiPan added ciflow/trunk Trigger trunk jobs on your pull request module: pytree release notes: dynamo labels Oct 5, 2024
@XuehaiPan XuehaiPan requested review from zou3519 and jansel October 5, 2024 11:53
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Oct 5, 2024
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Oct 5, 2024
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Oct 5, 2024
[ghstack-poisoned]
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Oct 5, 2024
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Oct 5, 2024
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Oct 5, 2024
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Oct 5, 2024
Copy link
Contributor

@jansel jansel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we use the python version of pytree that already exists rather than having a new implementation in polyfils?

[ghstack-poisoned]
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Oct 6, 2024
@XuehaiPan
Copy link
Collaborator Author

XuehaiPan commented Oct 6, 2024

Could we use the python version of pytree that already exists rather than having a new implementation in polyfils?

@jansel We can do this for functions that do not return the treespec variable. E.g.:

generator = tree_iter(tree)
leaves    = tree_leaves(tree)
newtree   = tree_map(func, tree)
tree_map_(func, tree)

I'd implement new polyfills to ensure behavior consistency between polyfill and C++ for the following functions that return/access the treespec variable. See:

leaves, treespec = tree_flatten(tree)
treespec         = tree_structure(tree)
tree             = tree_unflatten(leaves, treespec)

In most cases, the new polyfill is identical to the already existing Python pytree. For historical reasons, the Python pytree has existed for years and I think splitting the implement makes polyfill maintenance easier. cc @zou3519 for thoughts.

@atalman
Copy link
Contributor

atalman commented Dec 2, 2024

@pytorchmergebot revert -c ghfirst -m "Failing internal test"

@pytorchmergebot
Copy link
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Copy link
Collaborator

@XuehaiPan your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Dec 2, 2024
…/ `tree_leaves` (#137397)"

This reverts commit 07850bb.

Reverted #137397 on behalf of https://github.com/atalman due to Failing internal test ([comment](#137397 (comment)))
@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Dec 2, 2024
@atalman
Copy link
Contributor

atalman commented Dec 2, 2024

XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Dec 2, 2024
[ghstack-poisoned]
@XuehaiPan
Copy link
Collaborator Author

HI @XuehaiPan looks like its failing on: main/test/dynamo/test_trace_rules.py#L328

Updated.

@XuehaiPan
Copy link
Collaborator Author

@jansel @zou3519 Could you take a look at this and the follow-up PRs?

XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Dec 2, 2024
[ghstack-poisoned]
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Dec 2, 2024
@XuehaiPan
Copy link
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…/ `tree_leaves` (pytorch#137397)"

This reverts commit 07850bb.

Reverted pytorch#137397 on behalf of https://github.com/atalman due to Failing internal test ([comment](pytorch#137397 (comment)))
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
@weifengpy
Copy link
Contributor

weifengpy commented Dec 6, 2024

I am seeing error in latest trunk, likely due to this PR. Is it possible to take a 2nd look if this breaks the trunk?

error about is_dict_insertion_ordered:

ERROR test/distributed/_tensor/test_math_ops.py - ValueError: Duplicate dispatch rule for <built-in method is_dict_insertion_ordered of PyCapsule object at 0x7f343d2b2910>: already registered in VariableBuilder's id dispatch map

how to repro:

from torch.testing._internal.common_utils import run_tests

full error

Traceback (most recent call last):
  File "/home/local/miniconda3/envs/torchtitan/lib/python3.10/site-packages/_pytest/runner.py", line 341, in from_call
    result: Optional[TResult] = func()
  File "/home/local/miniconda3/envs/torchtitan/lib/python3.10/site-packages/_pytest/runner.py", line 372, in <lambda>
    call = CallInfo.from_call(lambda: list(collector.collect()), "collect")
  File "/home/local/miniconda3/envs/torchtitan/lib/python3.10/site-packages/_pytest/python.py", line 531, in collect
    self._inject_setup_module_fixture()
  File "/home/local/miniconda3/envs/torchtitan/lib/python3.10/site-packages/_pytest/python.py", line 545, in _inject_setup_module_fixture
    self.obj, ("setUpModule", "setup_module")
  File "/home/local/miniconda3/envs/torchtitan/lib/python3.10/site-packages/_pytest/python.py", line 310, in obj
    self._obj = obj = self._getobj()
  File "/home/local/miniconda3/envs/torchtitan/lib/python3.10/site-packages/_pytest/python.py", line 528, in _getobj
    return self._importtestmodule()
  File "/home/local/miniconda3/envs/torchtitan/lib/python3.10/site-packages/_pytest/python.py", line 617, in _importtestmodule
    mod = import_path(self.path, mode=importmode, root=self.config.rootpath)
  File "/home/local/miniconda3/envs/torchtitan/lib/python3.10/site-packages/_pytest/pathlib.py", line 565, in import_path
    importlib.import_module(module_name)
  File "/home/local/miniconda3/envs/torchtitan/lib/python3.10/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1050, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 688, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 883, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/data/users/oss/pytorch/test/distributed/_tensor/test_math_ops.py", line 25, in <module>
    from torch.testing._internal.common_utils import run_tests
  File "/data/users/oss/pytorch/torch/testing/_internal/common_utils.py", line 1814, in <module>
    condition=torch._dynamo.config.inline_inbuilt_nn_modules,
  File "/data/users/oss/pytorch/torch/__init__.py", line 2647, in __getattr__
    return importlib.import_module(f".{name}", __name__)
  File "/home/local/miniconda3/envs/torchtitan/lib/python3.10/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "/data/users/oss/pytorch/torch/_dynamo/__init__.py", line 42, in <module>
    from .polyfills import loader as _  # usort: skip # noqa: F401
  File "/data/users/oss/pytorch/torch/_dynamo/polyfills/loader.py", line 24, in <module>
    POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple(
  File "/data/users/oss/pytorch/torch/_dynamo/polyfills/loader.py", line 25, in <genexpr>
    importlib.import_module(f".{submodule}", package=polyfills.__name__)
  File "/home/local/miniconda3/envs/torchtitan/lib/python3.10/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "/data/users/oss/pytorch/torch/_dynamo/polyfills/pytree.py", line 31, in <module>
    def _(*args: Any, **kwargs: Any) -> bool:
  File "/data/users/oss/pytorch/torch/_dynamo/decorators.py", line 356, in wrapper
    raise ValueError(
ValueError: Duplicate dispatch rule for <built-in method is_dict_insertion_ordered of PyCapsule object at 0x7f343d2b2910>: already registered in VariableBuilder's id dispatch map
=========================================================================================================== short test summary info ===========================================================================================================
ERROR test/distributed/_tensor/test_math_ops.py - ValueError: Duplicate dispatch rule for <built-in method is_dict_insertion_ordered of PyCapsule object at 0x7f343d2b2910>: already registered in VariableBuilder's id dispatch map

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants