
Detect torch function in lists as well #160256


Open

wants to merge 4 commits into base: gh/ezyang/3128/base

Conversation

@ezyang ezyang (Contributor) commented Aug 9, 2025

Stack from ghstack (oldest at bottom):

We basically follow the same pattern we use for tensor arguments. The major downside is that we now have to traverse the entirety of the int list (etc.) arguments, where previously we didn't have to. Benchmarking suggests a ~2% regression for the affected calls.
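
For context, a minimal sketch of that pattern, assuming a plain attribute probe in place of torch's internal overload check (the function name and the probe are illustrative, not the PR's actual code):

```cpp
#include <Python.h>
#include <vector>

// Illustrative sketch: scan a list/tuple argument and record any element
// carrying __torch_function__, mirroring how tensor arguments are handled.
// PyObject_HasAttrString is a stand-in for torch's real overload check.
static bool list_has_torch_function(
    PyObject* obj,
    std::vector<PyObject*>* overloaded_args = nullptr) {
  const bool is_tuple = PyTuple_Check(obj);
  const Py_ssize_t size =
      is_tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
  bool has_torch_func = false;
  for (Py_ssize_t idx = 0; idx < size; idx++) {
    PyObject* item = is_tuple ? PyTuple_GET_ITEM(obj, idx)
                              : PyList_GET_ITEM(obj, idx);
    // One extra CPython probe per element -- the source of the regression.
    if (PyObject_HasAttrString(item, "__torch_function__")) {
      has_torch_func = true;
      if (overloaded_args != nullptr) {
        overloaded_args->push_back(item);
      }
    }
  }
  return has_torch_func;
}
```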

Signed-off-by: Edward Yang ezyang@meta.com

[ghstack-poisoned]
pytorch-bot bot commented Aug 9, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160256

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit c8f54e7 with merge base 842cc77:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ezyang added a commit that referenced this pull request Aug 9, 2025
This was done exclusively with Claude Code and I haven't reviewed it yet

Signed-off-by: Edward Yang <ezyang@meta.com>
ghstack-source-id: 0a57888
Pull-Request: #160256
[ghstack-poisoned]
@ezyang ezyang marked this pull request as ready for review August 10, 2025 04:28
@ezyang ezyang (Contributor, Author) commented Aug 10, 2025

I have reviewed it; some of the code is bad, but it "works". I need to improve some of its performance characteristics.

return false;
bool has_torch_func = false;

for (long idx = 0; idx < size; idx++) {
ezyang (author):

The iteration here is the perf problem. Ideally we would delay checking the insides until we are parsing, but that may require a more involved change upstream, since we typically assume that by the time we parse, a torch function override can no longer occur.
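
A hypothetical sketch of what that deferral could look like (all names illustrative; the real arg parser is considerably more involved): the match step stays O(1), and the per-element probe piggybacks on the conversion pass that touches every element anyway.

```cpp
#include <Python.h>
#include <vector>

// Match time: cheap structural check only, no per-element probes.
static bool matches_int_list(PyObject* obj) {
  return PyTuple_Check(obj) || PyList_Check(obj);
}

// Conversion time: we iterate the elements to convert them anyway, so the
// torch-function probe folds into the same pass instead of a separate scan.
static bool convert_int_list(
    PyObject* obj,
    std::vector<long long>& out,
    std::vector<PyObject*>& overloaded_args) {
  const bool is_tuple = PyTuple_Check(obj);
  const Py_ssize_t size =
      is_tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
  for (Py_ssize_t idx = 0; idx < size; idx++) {
    PyObject* item = is_tuple ? PyTuple_GET_ITEM(obj, idx)
                              : PyList_GET_ITEM(obj, idx);
    if (PyObject_HasAttrString(item, "__torch_function__")) {
      overloaded_args.push_back(item);
      return false;  // caller must re-dispatch through __torch_function__
    }
    out.push_back(PyLong_AsLongLong(item));
  }
  return true;
}
```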

Collaborator:

You should use c10::irange here, right?
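
For reference, the suggested spelling (a sketch with a stand-in body):

```cpp
#include <c10/util/irange.h>

// Same indices as `for (long idx = 0; idx < size; idx++)`, in the
// codebase's preferred range-based form.
void scan(long size) {
  for (const auto idx : c10::irange(size)) {
    (void)idx;  // ... loop body unchanged ...
  }
}
```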

ezyang (author):

That's just the color of the shed; the real problem is I'm adding O(n) extra CPython probes for int list arguments. I need to check to see if the overhead is perceptible.

@@ -905,26 +936,52 @@ static bool is_int_or_symint(PyObject* obj) {
 static bool is_int_or_symint_list(
     PyObject* obj,
     int broadcast_size,
-    int64_t* failed_idx = nullptr) {
+    int64_t* failed_idx = nullptr,
Collaborator:

Any reason we want pointers here instead of an optional reference? Nullptr seems more error-prone, especially when wrapping an integer type. We can guard against invalid std::optional accesses.
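
A toy illustration of the two styles being compared, using hypothetical helpers rather than the parser's code:

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Out-pointer style (the pre-existing pattern this PR extends): the caller
// may pass nullptr, and the callee must remember to check for it.
static bool all_nonnegative(
    const std::vector<int64_t>& v,
    int64_t* failed_idx = nullptr) {
  for (size_t i = 0; i < v.size(); i++) {
    if (v[i] < 0) {
      if (failed_idx != nullptr) {
        *failed_idx = static_cast<int64_t>(i);
      }
      return false;
    }
  }
  return true;
}

// std::optional style: the absence of a failure index is encoded in the
// return type, and value() throws on a bad access instead of dereferencing
// a null pointer.
static std::optional<int64_t> first_negative(const std::vector<int64_t>& v) {
  for (size_t i = 0; i < v.size(); i++) {
    if (v[i] < 0) {
      return static_cast<int64_t>(i);
    }
  }
  return std::nullopt;
}
```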

ezyang (author):

Pre-existing condition.

-static bool is_scalar_list(PyObject* obj) {
+static bool is_scalar_list(
+    PyObject* obj,
+    std::vector<PyObject*>* overloaded_args = nullptr) {
   auto tuple = six::isTuple(obj);
Collaborator:

Six? Uh, we missed this in the upgrade, didn't we... just use the pybind11 handle APIs.
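
A sketch of what the suggested cleanup might look like, assuming nothing beyond the pybind11 handle API (six::isTuple is the old Python 2/3 compatibility shim):

```cpp
#include <pybind11/pybind11.h>

namespace py = pybind11;

// pybind11 handle API in place of the six::isTuple compatibility shim.
static bool is_tuple(PyObject* obj) {
  return py::isinstance<py::tuple>(py::handle(obj));
  // equivalently, with the raw CPython API: return PyTuple_Check(obj) != 0;
}
```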

ezyang (author):

Better to do this separately

@ezyang ezyang (Contributor, Author) commented Aug 10, 2025

Some not-very-scientific benchmarking suggests this adds something like 40ns of overhead per call, where the calls take 2000ns end to end (so roughly a 2% regression).
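
The arithmetic: 40ns / 2000ns = 2%. A sketch of the kind of harness this implies, hypothetical and not the benchmark actually used:

```cpp
#include <chrono>

// Hypothetical micro-benchmark helper: average nanoseconds per invocation
// of f over n calls.
template <typename F>
double ns_per_call(F&& f, int n = 1000000) {
  auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < n; i++) {
    f();
  }
  auto end = std::chrono::steady_clock::now();
  return std::chrono::duration<double, std::nano>(end - start).count() / n;
}
// Measuring the op with and without the list scan, the difference of the two
// averages gives the per-call overhead quoted above.
```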

ezyang added 2 commits August 10, 2025 16:48
[ghstack-poisoned]
[ghstack-poisoned]
ezyang added a commit that referenced this pull request Aug 11, 2025
This was done exclusively with Claude Code and I haven't reviewed it yet

Signed-off-by: Edward Yang <ezyang@meta.com>
ghstack-source-id: 2b5d285
Pull-Request: #160256