[MPS] Extend addmm to integral types #160270
base: gh/malfet/479/base
Conversation
Fixes #154901 [ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160270
Note: Links to docs will display an error until the docs builds have been completed.
⏳ No Failures, 16 Pending
As of commit cdfaf5a with merge base 86eb65f.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
}
});
return output;
return output;
Double return? Surprised this didn't error
Yeah, I hoped some of the linters would be triggered by it, but it seems this is fine...
constant uint3& sizes [[buffer(6)]],
uint2 tid [[thread_position_in_threadgroup]],
uint2 thread_id [[thread_position_in_grid]]) {
threadgroup T A_tile[TILE_DIM][TILE_DIM];
It's ugly, but can this be rewritten as a `std::array` too?
Threadgroups are a bit weird (i.e. this statement affects GPU occupancy); let me give it a try in a separate PR and make sure it does not regress the perf...
Stack from ghstack (oldest at bottom):

By adding an `addmm` kernel, which is a logical continuation of the `mm` one. The only tricky part is how the alpha and beta constants are handled: they are passed as `optmath_t`, i.e. they could be int64, int32 or float. Unified all MM flavor instantiations through `INSTANTIATE_MM_OPS` and tested that the `addmm` metal kernel works as expected for floating types as well by testing it via

Fixes #154901