Detect torch function in lists as well #160256
base: gh/ezyang/3128/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160256
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure as of commit c8f54e7 with merge base 842cc77. The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
I have reviewed it and some of the code is bad but it "works". Need to improve some performance characteristics for it.
return false;
bool has_torch_func = false;

for (long idx = 0; idx < size; idx++) {
The iteration here is the perf problem. Ideally we delay checking the insides until we are parsing. But this may result in a more involved change upstream as we typically assume by the time we parse TF cannot occur.
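For context, a minimal sketch of the kind of eager scan being discussed, assuming the existing `torch::check_has_torch_function` helper (the header path and exact signature are assumptions here, not taken from this diff):

```cpp
#include <Python.h>
// Assumed location of torch::check_has_torch_function; adjust if it differs.
#include <torch/csrc/utils/disable_torch_function.h>

// Sketch only, not the PR's exact code: eagerly probe every element of a
// Python list for a __torch_function__ override before argument parsing.
static bool list_has_torch_function(PyObject* obj) {
  const Py_ssize_t size = PyList_GET_SIZE(obj);
  for (Py_ssize_t idx = 0; idx < size; idx++) {
    PyObject* item = PyList_GET_ITEM(obj, idx);  // borrowed reference
    if (torch::check_has_torch_function(item)) {
      return true;  // this per-element probe is the O(n) cost under discussion
    }
  }
  return false;
}
```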
You should use c10::irange here, right?
That's just the color of the shed; the real problem is I'm adding O(n) extra CPython probes for int list arguments. I need to check to see if the overhead is perceptible.
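For completeness, the `c10::irange` spelling being suggested would look roughly like this (a sketch reusing `obj`, `size`, and `has_torch_func` from the diff above, and the same assumed torch-function probe):

```cpp
#include <c10/util/irange.h>

// Same element scan, written with c10::irange as suggested (sketch only).
for (const auto idx : c10::irange(size)) {
  PyObject* item = PyList_GET_ITEM(obj, idx);  // borrowed reference
  if (torch::check_has_torch_function(item)) {
    has_torch_func = true;
    break;
  }
}
```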
@@ -905,26 +936,52 @@ static bool is_int_or_symint(PyObject* obj) {
static bool is_int_or_symint_list(
    PyObject* obj,
    int broadcast_size,
    int64_t* failed_idx = nullptr) {
    int64_t* failed_idx = nullptr,
Any reason we want pointers here instead of an optional reference? nullptr seems more error-prone, especially when wrapping an integer type. We can statically guard against invalid std::optional accesses.
Pre-existing condition.
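For illustration only, one way to read the optional suggestion (hypothetical signature, not part of this diff): carry the failing index in a `std::optional`, so an unset value cannot be silently dereferenced the way a raw pointer can:

```cpp
#include <Python.h>
#include <cstdint>
#include <optional>

// Hypothetical alternative to the int64_t* out-parameter: the failing index
// lives in a std::optional, so callers must check has_value() before use.
static bool is_int_or_symint_list_alt(
    PyObject* obj,
    int broadcast_size,
    std::optional<int64_t>& failed_idx) {
  // ... on a failing element at position idx:
  //     failed_idx = idx;
  //     return false;
  return true;
}
```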
static bool is_scalar_list(PyObject* obj) {
static bool is_scalar_list(
    PyObject* obj,
    std::vector<PyObject*>* overloaded_args = nullptr) {
  auto tuple = six::isTuple(obj);
Six? Uh, we missed this in the upgrade, didn't we... just use the pybind11 handle APIs.
Better to do this separately
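A pybind11-handle version of that check might look like the following (sketch only; `obj` is the borrowed `PyObject*` from the diff, and as noted above this cleanup would land separately):

```cpp
#include <pybind11/pybind11.h>
namespace py = pybind11;

// Sketch: replace six::isTuple(obj) with pybind11 handle APIs. py::handle
// does not touch the reference count, matching the borrowed pointer the
// parser already holds.
const bool is_tuple = py::isinstance<py::tuple>(py::handle(obj));
// The raw CPython equivalent would be:
const bool is_tuple_raw = PyTuple_Check(obj);
```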
Some not very scientific benchmarking suggests this is something like 40ns of overhead per call, where the calls take 2000ns end to end (so a ~2% regression or something).
Stack from ghstack (oldest at bottom):
We basically follow the same pattern we do for tensor arguments. The major downside is that we now have to traverse the entirety of the int list (etc.), where previously we didn't have to. A benchmark suggests a ~2% regression for the relevant calls.
Signed-off-by: Edward Yang <ezyang@meta.com>
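To make the "same pattern as tensor arguments" concrete, here is a rough sketch of what collecting overrides from a list argument could look like. `torch::check_has_torch_function` and `torch::append_overloaded_arg` are assumed to be the existing helpers and may differ in name or signature from what the PR actually uses:

```cpp
#include <Python.h>
#include <vector>
// Assumed headers for the helpers used below; adjust if they differ.
#include <torch/csrc/utils/disable_torch_function.h>
#include <torch/csrc/utils/python_arg_parser.h>

// Sketch only: scan a list argument and, when an element defines
// __torch_function__, record it in overloaded_args just as overloaded
// tensor arguments are recorded today.
static bool scan_list_for_overrides(
    PyObject* obj,
    std::vector<PyObject*>* overloaded_args) {
  bool has_torch_func = false;
  const Py_ssize_t size = PyList_GET_SIZE(obj);
  for (Py_ssize_t idx = 0; idx < size; idx++) {
    PyObject* item = PyList_GET_ITEM(obj, idx);  // borrowed reference
    if (torch::check_has_torch_function(item)) {
      has_torch_func = true;
      if (overloaded_args != nullptr) {
        torch::append_overloaded_arg(overloaded_args, item);
      }
    }
  }
  return has_torch_func;
}
```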