
[inductor] dont reuse buffers if it affects peak (#145883) #159530


Open
wants to merge 10 commits into base: gh/v0i0/4/base

Conversation

v0i0 (Contributor) commented Jul 30, 2025


pytorch-bot bot commented Jul 30, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159530

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 45698e2 with merge base 2507ae6:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

  • pull / linux-jammy-py3_9-clang9-xla / test (xla, 1, 1, lf.linux.12xlarge, unstable) (gh) (#158876)
    /var/lib/jenkins/workspace/xla/torch_xla/csrc/runtime/BUILD:476:14: Compiling torch_xla/csrc/runtime/xla_util_test.cpp failed: (Exit 1): gcc failed: error executing CppCompile command (from target //torch_xla/csrc/runtime:xla_util_test) /usr/bin/gcc -U_FORTIFY_SOURCE -fstack-protector -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -fno-omit-frame-pointer -g0 -O2 '-D_FORTIFY_SOURCE=1' -DNDEBUG -ffunction-sections ... (remaining 229 arguments skipped)

This comment was automatically generated by Dr. CI and updates every 15 minutes.

v0i0 added a commit that referenced this pull request Jul 30, 2025
v0i0 (Contributor, Author) commented Jul 30, 2025

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Jul 30, 2025
@v0i0 v0i0 requested a review from eellison July 30, 2025 23:43
) * get_dtype_size(self.node.get_dtype())
if free_line_scheduler_node >= self_scheduler_node:
return False
peak_memory_in_range = max(
Contributor

After we reuse a buffer, we need to update the memory of the nodes in its reuse window.

Comment on lines 615 to 616
peak_memory_per_scheduler_node[free_line_scheduler_node:self_scheduler_node]
)
Contributor

This is potentially O(n^2), because at each node we iterate through O(n) nodes.

Is there an O(n log n) solution we could use? From looking around a bit, maybe https://en.wikipedia.org/wiki/Fenwick_tree? Note: I haven't looked especially closely at this yet.

If we can't figure out an O(n log n) solution, we could also do a sliding window, or add other heuristics, such as disallowing buffer reuse for tensors above a certain size.
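For illustration, a minimal sketch of the O(log n) alternative being discussed: a segment tree with lazy propagation supporting range-max queries and range "raise to at least v" updates. The class name `RangeMaxTree` and its API are assumptions for this sketch, not the PR's actual implementation.

```python
class RangeMaxTree:
    """Sketch: range-max queries and range chmax updates, both O(log n).

    Assumes non-negative values, so 0 works as the identity for max.
    """

    def __init__(self, values):
        self.n = len(values)
        self.size = 1
        while self.size < self.n:
            self.size *= 2
        self.tree = [0] * (2 * self.size)
        # Leaves live at tree[size .. size + n); internal nodes above them.
        self.tree[self.size:self.size + self.n] = values
        for i in range(self.size - 1, 0, -1):
            self.tree[i] = max(self.tree[2 * i], self.tree[2 * i + 1])
        self.lazy = [None] * (2 * self.size)  # pending chmax per subtree

    def _push(self, node):
        # Push a pending chmax down to both children.
        v = self.lazy[node]
        if v is None:
            return
        for child in (2 * node, 2 * node + 1):
            self.tree[child] = max(self.tree[child], v)
            self.lazy[child] = v if self.lazy[child] is None else max(self.lazy[child], v)
        self.lazy[node] = None

    def _update(self, node, start, end, l, r, value):
        if r < start or end < l:
            return
        if l <= start and end <= r:
            self.tree[node] = max(self.tree[node], value)
            self.lazy[node] = value if self.lazy[node] is None else max(self.lazy[node], value)
            return
        self._push(node)
        mid = (start + end) // 2
        self._update(2 * node, start, mid, l, r, value)
        self._update(2 * node + 1, mid + 1, end, l, r, value)
        self.tree[node] = max(self.tree[2 * node], self.tree[2 * node + 1])

    def update_range(self, l, r, value):
        """Raise every element in [l, r] to at least `value`."""
        self._update(1, 0, self.size - 1, l, r, value)

    def _query(self, node, start, end, l, r):
        if r < start or end < l:
            return float("-inf")
        if l <= start and end <= r:
            return self.tree[node]
        self._push(node)
        mid = (start + end) // 2
        return max(self._query(2 * node, start, mid, l, r),
                   self._query(2 * node + 1, mid + 1, end, l, r))

    def query_range(self, l, r):
        """Max over [l, r], inclusive."""
        return self._query(1, 0, self.size - 1, l, r)
```

With a structure like this, the per-node slice `max(...)` above becomes a `query_range` call, turning the overall pass from O(n^2) into O(n log n).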

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
v0i0 added a commit that referenced this pull request Jul 31, 2025
v0i0 added a commit that referenced this pull request Jul 31, 2025
v0i0 added a commit that referenced this pull request Aug 2, 2025
v0i0 added a commit that referenced this pull request Aug 5, 2025
v0i0 added a commit that referenced this pull request Aug 7, 2025
v0i0 added a commit that referenced this pull request Aug 7, 2025
v0i0 added a commit that referenced this pull request Aug 7, 2025
v0i0 added a commit that referenced this pull request Aug 7, 2025
@eellison eellison self-requested a review August 11, 2025 23:16
Contributor

eellison left a comment

Looks good, a couple of questions about the tree. Would you mind doing one dashboard run? I believe we expect to see memory improvements in the timm benchmark.

Comment on lines +85 to +86
if lazy_node is not None:
# Apply lazy update to current node
Contributor

nit: early return instead of nesting?

self.size = 1
while self.size < self.n:
self.size *= 2
self.size *= 2
Contributor

is it worth describing the data layout of the tree

[1 ... n ] base values then [1... n // 2, n // 2... n //4 ] ?

Contributor Author

the opposite. will add a comment
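For context, "the opposite" layout the author describes can be sketched like this (an illustrative example with assumed names, not the PR's exact comment or code): in a heap-ordered tree, summaries come first and base values last.

```python
# Hedged sketch of the implicit heap layout: tree[1] is the root,
# tree[2*i] and tree[2*i + 1] are the children of tree[i]. Internal
# summary nodes occupy tree[1 .. size - 1]; the base values live in the
# leaves tree[size .. size + n - 1], i.e. the opposite of
# "[base values] then [summaries]".
size = 4  # leaf count, padded to a power of two
values = [2, 7, 1, 5]
tree = [0] * (2 * size)
tree[size:size + len(values)] = values          # leaves: tree[4..7]
for i in range(size - 1, 0, -1):                # internal: tree[1..3]
    tree[i] = max(tree[2 * i], tree[2 * i + 1])
# tree is now [0, 7, 7, 5, 2, 7, 1, 5]: index 0 unused, root max at
# index 1, base values at the end.
```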


# Initialize tree and lazy arrays
self.tree = [identity_element] * self.size
self.lazy: list[Optional[T]] = [None] * self.size
Contributor

nit: describe what the lazy array will do?

Contributor Author

added a comment

self._build(values, right_child, mid + 1, end)

# Update current node with summary of children
self.tree[node] = self.summary_op(self.tree[left_child], self.tree[right_child])
Contributor

Nice to use the identity element instead of having to handle special cases.

Contributor Author

done

Contributor

Oh, I meant: I think this is what the code is already doing, right?

) * get_dtype_size(self.node.get_dtype())
if self.should_reuse_buffer(free_line, size):
free_line.is_reused = True
self.wrapper.estimate_peak.update_peak_between(free_line, self)
Contributor

So I guess the difference is: it used to be that queries are O(n), and now just updates are O(n). Is that correct?

Contributor Author

Both are O(log n): consider the binary tree of nodes and an interval between two leaves. To process the interval, we look at the paths from the tree root to the left and right boundaries of the interval (2 * log n nodes). Updates modify those nodes and their direct neighbors inside the interval; queries read those nodes and potentially push lazy updates down to the direct neighbors. So both are 2 * 2 * log2(n).

return

mid = (start + end) // 2
left_child = 2 * node
Contributor

Thinking aloud: do we need both _push_lazy and build? I guess build could be replaced by iteratively _push_lazy'ing each element in values?

Contributor Author

I'd keep it as is. Replacing build with update_range/push_lazy turns it from O(n) into O(n log n).
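The O(n) claim follows from the build doing one O(1) combine per tree node, and a tree over n leaves having O(n) nodes in total. A minimal sketch (illustrative names, not the PR's actual helpers):

```python
# Hedged sketch of a recursive O(n) build for a max segment tree.
# One O(1) combine per node; a tree over n leaves has O(n) nodes total,
# so the whole build is O(n). Building via n range updates instead
# would cost O(log n) per element, i.e. O(n log n) overall.
def build(tree, values, node, start, end):
    if start == end:
        tree[node] = values[start]
        return
    mid = (start + end) // 2
    build(tree, values, 2 * node, start, mid)
    build(tree, values, 2 * node + 1, mid + 1, end)
    tree[node] = max(tree[2 * node], tree[2 * node + 1])

values = [3, 1, 4, 1, 5]
n = len(values)
tree = [0] * (4 * n)  # loose upper bound on node count
build(tree, values, 1, 0, n - 1)
# tree[1] now holds the max over all values
```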

Comment on lines 90 to 104
left_child = 2 * node
right_child = 2 * node + 1

# Propagate to children
lazy_left_child = self.lazy[left_child]
if lazy_left_child is None:
self.lazy[left_child] = lazy_node
else:
self.lazy[left_child] = self.update_op(lazy_left_child, lazy_node)

lazy_right_child = self.lazy[right_child]
if lazy_right_child is None:
self.lazy[right_child] = lazy_node
else:
self.lazy[right_child] = self.update_op(lazy_right_child, lazy_node)
Contributor

nit: for loop over left/right?

Contributor Author

done

Comment on lines 140 to 150
lazy_left_child = self.lazy[left_child]
if lazy_left_child is None:
self.lazy[left_child] = value
else:
self.lazy[left_child] = self.update_op(lazy_left_child, value)

lazy_right_child = self.lazy[right_child]
if lazy_right_child is None:
self.lazy[right_child] = value
else:
self.lazy[right_child] = self.update_op(lazy_right_child, value)
Contributor

nit: for loop? Also, this seems the same as lines 87-104 above. Refactor?
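A refactor along the lines the reviewer suggests might look like this (a sketch with an assumed helper name, `_propagate_to_children`; not necessarily the PR's final code): both duplicated propagation sites collapse into one helper that loops over the two children.

```python
# Hedged sketch of the suggested refactor. Folds `value` into each
# child's pending lazy entry, replacing the two copy-pasted left/right
# blocks with a single loop.
def _propagate_to_children(lazy, node, value, update_op):
    for child in (2 * node, 2 * node + 1):
        lazy_child = lazy[child]
        lazy[child] = value if lazy_child is None else update_op(lazy_child, value)

# Usage: with max as the update op, a second smaller update is absorbed.
lazy = [None] * 8
_propagate_to_children(lazy, 1, 5, max)   # children 2 and 3 get 5
_propagate_to_children(lazy, 1, 3, max)   # max(5, 3) keeps 5
```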

v0i0 (Contributor, Author) commented Aug 12, 2025

v0i0 added a commit that referenced this pull request Aug 12, 2025

self.overall_peak_memory, peak_by_scheduler_node = estimate_peak_memory(
V.graph.scheduler.nodes,
{},
Contributor

Sorry, one last thing: names_to_freeable_bufs is important for the backward, when we need to know which activations will be deallocated in order to get an accurate memory estimate.
