
torchdim Python port #160236


Open
wants to merge 3 commits into base: gh/ezyang/3127/base
Changes from 1 commit
Update
[ghstack-poisoned]
ezyang committed Aug 9, 2025
commit 7bf28faa567505b0cf3842a72a40cc0fbe19d30a
767 changes: 724 additions & 43 deletions functorch/dim/__init__.py

Large diffs are not rendered by default.

120 changes: 120 additions & 0 deletions functorch/dim/_dim_entry.py
@@ -0,0 +1,120 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, Union

import torch


if TYPE_CHECKING:
    from . import Dim


# NB: The old code represented the dimension a level came from as a negative
# number, so we follow this convention even though it shouldn't be necessary now.
class DimEntry:
    # Either a first-class Dim object, or a positional dimension counted from
    # the rhs as a negative integer.
    data: Union[Dim, int]

    def __init__(self, data=None):
        if type(data) is int:
            assert data < 0
        if data is None:
            data = 0
        self.data = data

    def __eq__(self, other):
        if not isinstance(other, DimEntry):
            return False
        # Use 'is' for Dim objects to avoid triggering __torch_function__;
        # use '==' only for positional (int) comparisons.
        if self.is_positional() and other.is_positional():
            # Both are positional (ints)
            return self.data == other.data
        elif not self.is_positional() and not other.is_positional():
            # Both are Dim objects - use 'is' to avoid __eq__
            return self.data is other.data
        else:
            # One is positional, one is Dim - they can't be equal
            return False

    def is_positional(self):
        return type(self.data) is int and self.data < 0

    def is_none(self):
        # Use isinstance to check for Dim objects, avoiding __torch_function__
        from . import Dim

        if isinstance(self.data, Dim):
            # This is a Dim object, so it can't be "none" (which is represented by 0)
            return False
        else:
            # This is an int or other type
            return self.data == 0

    def position(self) -> int:
        assert isinstance(self.data, int)
        return self.data

    def dim(self) -> Dim:
        assert not isinstance(self.data, int)
        return self.data

    def __repr__(self):
        return repr(self.data)


def ndim_of_levels(levels: Sequence[DimEntry]) -> int:
    r = 0
    for l in levels:
        if l.is_positional():
            r += 1
    return r


def _match_levels(
    tensor: torch.Tensor,
    from_levels: list[DimEntry],
    to_levels: list[DimEntry],
    drop_levels: bool = False,
) -> torch.Tensor:
    """
    Reshape a tensor to match target levels using as_strided.

    This corresponds to the _match_levels function in the C++ code.

    Args:
        tensor: Input tensor to reshape
        from_levels: Current levels of the tensor
        to_levels: Target levels to match
        drop_levels: If True, from_levels may contain levels that are absent
            from to_levels; those dimensions are silently dropped. Levels in
            to_levels that are missing from from_levels are always materialized
            with stride 0.

    Returns:
        Reshaped tensor
    """
    if from_levels == to_levels:
        return tensor

    sizes = tensor.size()
    strides = tensor.stride()

    if not drop_levels:
        assert len(from_levels) <= len(to_levels), (
            "Cannot expand dimensions without drop_levels"
        )

    new_sizes = []
    new_strides = []

    for level in to_levels:
        # Find index of this level in from_levels
        try:
            idx = from_levels.index(level)
        except ValueError:
            # Level not found in from_levels: broadcast it with stride 0,
            # using size 1 for positional levels and the Dim's size otherwise.
            if level.is_positional():
                new_sizes.append(1)
            else:
                new_sizes.append(level.dim().size)
            new_strides.append(0)
        else:
            new_sizes.append(sizes[idx])
            new_strides.append(strides[idx])

    return tensor.as_strided(new_sizes, new_strides, tensor.storage_offset())
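
For orientation, here is a minimal usage sketch of DimEntry and _match_levels (illustrative only, not part of the diff). It assumes the existing functorch.dim.dims() factory with sizes= to create bound first-class dims:

import torch
from functorch.dim import dims
from functorch.dim._dim_entry import DimEntry, _match_levels, ndim_of_levels

# Two first-class dims whose sizes are bound at creation time.
i, j = dims(sizes=[3, 4])

t = torch.randn(3, 5)

# Current layout of t: level i, then the single positional (rhs) dimension.
from_levels = [DimEntry(i), DimEntry(-1)]

# Target layout (i, j, positional): j is missing from from_levels, so it is
# materialized as a broadcast dimension of size j.size == 4 with stride 0.
to_levels = [DimEntry(i), DimEntry(j), DimEntry(-1)]

out = _match_levels(t, from_levels, to_levels)
assert out.shape == (3, 4, 5)
assert ndim_of_levels(to_levels) == 1  # only one positional level

The sketch relies on DimEntry.__eq__ comparing Dim objects by identity, which is what from_levels.index(level) uses above.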
120 changes: 120 additions & 0 deletions functorch/dim/_enable_all_layers.py
@@ -0,0 +1,120 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import torch


if TYPE_CHECKING:
    from . import Dim, Tensor
    from ._dim_entry import DimEntry


class EnableAllLayers:
    """
    RAII-style context manager for enabling functorch vmap layers.

    This corresponds to the EnableAllLayers struct in the C++ code.
    It manages the creation and cleanup of functorch dynamic layers.
    """

    levels_start: int
    levels_to_dim: list[Dim]

    def __init__(self, levels: list[DimEntry]):
        """
        Record the first-class dimensions to push dynamic layers for; the
        layers themselves are pushed on __enter__.

        Args:
            levels: List of dimension entries to create layers for
        """
        self.levels_start = 0
        self.levels_to_dim = []

        for l in levels:
            if not l.is_positional():
                d = l.dim()
                self.levels_to_dim.append(d)

        # Sort by level for stable ordering
        self.levels_to_dim.sort(key=lambda d: d._level)

    def __enter__(self):
        # Create functorch dynamic layers, one per first-class dim
        for i, dim in enumerate(self.levels_to_dim):
            batch_size = dim.size
            level = torch._C._functorch._vmap_increment_nesting(batch_size, "different")
            if i == 0:
                self.levels_start = level
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Clean up dynamic layers in reverse order."""
        to_remove = self.levels_start + len(self.levels_to_dim) - 1
        for i in range(len(self.levels_to_dim)):
            popped = torch._C._functorch._vmap_decrement_nesting()
            assert popped == to_remove - i, (
                f"Expected layer {to_remove - i}, got {popped}"
            )

    def from_batched(self, batchedtensor: torch.Tensor, has_device: bool) -> Tensor:
        """
        Create a Tensor from a batched tensor by unwrapping functorch layers.

        This corresponds to EnableAllLayers::from_batched in C++.

        Args:
            batchedtensor: Batched tensor from functorch operation
            has_device: Whether tensor has device info

        Returns:
            Tensor with appropriate levels
        """
        # Create positional levels for base dimensions
        levels = []
        for i in range(-batchedtensor.dim(), 0):
            levels.append(i)

        tensor = batchedtensor

        while torch._C._functorch.is_batchedtensor(tensor):
            level = torch._C._functorch.maybe_get_level(tensor)
            assert level is not None
            assert (
                self.levels_start <= level < self.levels_start + len(self.levels_to_dim)
            )
            dim = self.levels_to_dim[level - self.levels_start]
            bdim = torch._C._functorch.maybe_get_bdim(tensor)
            assert bdim is not None
            levels.insert(bdim, dim)
            tensor = torch._C._functorch.get_unwrapped(tensor)

        from . import Tensor

        result = Tensor()
        result._tensor = tensor
        result._batchtensor = batchedtensor
        result._has_device = has_device
        result._levels = levels
        return result

    def inplace_update_layers(self, batchtensor: torch.Tensor, levels: list[DimEntry]):
        """
        Update the levels of a batched tensor in place.

        This corresponds to EnableAllLayers::inplace_update_layers in C++.
        This requires the _unsafe_set_level binding that we'll add to functorch.

        Args:
            batchtensor: Batched tensor to update
            levels: New levels to set
        """
        # Check if tensor is batched
        if not torch._C._functorch.is_batchedtensor(batchtensor):
            return

        # Walk the chain of batched wrappers from the outside in
        impl = batchtensor

        for i in reversed(range(len(self.levels_to_dim))):
            if not torch._C._functorch.is_batchedtensor(impl):
                break

            # levels holds DimEntry wrappers while levels_to_dim holds Dim
            # objects, so compare the wrapped Dim by identity rather than
            # relying on DimEntry.__eq__ (which only matches other DimEntry).
            if any(
                not l.is_positional() and l.dim() is self.levels_to_dim[i]
                for l in levels
            ):
                torch._C._functorch._unsafe_set_level(impl, self.levels_start + i)
                impl = torch._C._functorch.get_unwrapped(impl)
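
As a rough sketch of how the port is expected to drive this context manager (illustrative only, not part of the diff), assuming dims(sizes=...) from functorch.dim and that the ported Dim exposes the size and _level attributes the class above relies on:

import torch
from functorch.dim import dims
from functorch.dim._dim_entry import DimEntry
from functorch.dim._enable_all_layers import EnableAllLayers

i, j = dims(sizes=[3, 4])

# One level per first-class dim, plus a single positional level.
levels = [DimEntry(i), DimEntry(j), DimEntry(-1)]

with EnableAllLayers(levels) as guard:
    # Inside the context there is one vmap dynamic layer per first-class dim,
    # occupying nesting levels guard.levels_start and guard.levels_start + 1.
    # A batched tensor produced at these levels could be rewrapped with
    # guard.from_batched(batched, has_device=True).
    assert len(guard.levels_to_dim) == 2
# On exit the layers are popped in reverse order, with each popped layer id
# checked against the expected nesting level.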