Skip to content

Commit 0516a50

Browse files
committed
Update
[ghstack-poisoned]
2 parents 5bfce27 + 7f7d905 commit 0516a50

File tree

73 files changed

+585
-1314
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the searchbox below for content that may be hidden.

73 files changed

+585
-1314
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ git submodule update --init --recursive
243243

244244
```bash
245245
conda install cmake ninja
246-
# Run this command from the PyTorch directory after cloning the source code using the “Get the PyTorch Source“ section below
246+
# Run this command from the PyTorch directory after cloning the source code using the “Get the PyTorch Source” section above
247247
pip install -r requirements.txt
248248
```
249249

@@ -560,7 +560,7 @@ To learn more about making a contribution to Pytorch, please see our [Contributi
560560
561561
PyTorch is a community-driven project with several skillful engineers and researchers contributing to it.
562562
563-
PyTorch is currently maintained by [Soumith Chintala](http://soumith.ch), [Gregory Chanan](https://github.com/gchanan), [Dmytro Dzhulgakov](https://github.com/dzhulgakov), [Edward Yang](https://github.com/ezyang), and [Nikita Shulga](https://github.com/malfet) with major contributions coming from hundreds of talented individuals in various forms and means.
563+
PyTorch is currently maintained by [Soumith Chintala](http://soumith.ch), [Gregory Chanan](https://github.com/gchanan), [Dmytro Dzhulgakov](https://github.com/dzhulgakov), [Edward Yang](https://github.com/ezyang), [Alban Desmaison](https://github.com/albanD), [Piotr Bialecki](https://github.com/ptrblck) and [Nikita Shulga](https://github.com/malfet) with major contributions coming from hundreds of talented individuals in various forms and means.
564564
A non-exhaustive but growing list needs to mention: [Trevor Killeen](https://github.com/killeent), [Sasank Chilamkurthy](https://github.com/chsasank), [Sergey Zagoruyko](https://github.com/szagoruyko), [Adam Lerer](https://github.com/adamlerer), [Francisco Massa](https://github.com/fmassa), [Alykhan Tejani](https://github.com/alykhantejani), [Luca Antiga](https://github.com/lantiga), [Alban Desmaison](https://github.com/albanD), [Andreas Koepf](https://github.com/andreaskoepf), [James Bradbury](https://github.com/jekbradbury), [Zeming Lin](https://github.com/ebetica), [Yuandong Tian](https://github.com/yuandong-tian), [Guillaume Lample](https://github.com/glample), [Marat Dukhan](https://github.com/Maratyszcza), [Natalia Gimelshein](https://github.com/ngimel), [Christian Sarofeen](https://github.com/csarofeen), [Martin Raison](https://github.com/martinraison), [Edward Yang](https://github.com/ezyang), [Zachary Devito](https://github.com/zdevito). <!-- codespell:ignore -->
565565
566566
Note: This project is unrelated to [hughperkins/pytorch](https://github.com/hughperkins/pytorch) with the same name. Hugh is a valuable contributor to the Torch community and has helped with many things Torch and PyTorch.

aten/src/ATen/core/ivalue.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ c10::TypePtr IValue::TagType<c10::Type>::get(const IValue& v) {
9797
return ComplexType::get();
9898
case Tag::Int:
9999
return IntType::get();
100+
case Tag::UInt:
101+
return IntType::get();
100102
case Tag::SymInt:
101103
return c10::SymIntType::get();
102104
case Tag::SymFloat:
@@ -320,6 +322,8 @@ IValue IValue::equals(const IValue& rhs) const {
320322
return rhs.isComplexDouble() && lhs.toComplexDouble() == rhs.toComplexDouble();
321323
case Tag::Int:
322324
return rhs.isInt() && lhs.toInt() == rhs.toInt();
325+
case Tag::UInt:
326+
return rhs.isUnsigned() && lhs.toUInt() == rhs.toUInt();
323327
case Tag::SymInt:
324328
return rhs.isSymInt() && lhs.toSymInt() == rhs.toSymInt();
325329
case Tag::SymFloat:
@@ -379,6 +383,8 @@ size_t IValue::hash(const IValue& v) {
379383
case Tag::Int:
380384
return c10::get_hash(v.payload.u.as_int);
381385
// NB: these are technically strict aliasing violations
386+
case Tag::UInt:
387+
return c10::get_hash(v.payload.u.as_int);
382388
case Tag::SymInt:
383389
return c10::get_hash(v.payload.u.as_int);
384390
case Tag::SymFloat:
@@ -806,6 +812,8 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
806812
return printComplex(out, v);
807813
} case IValue::Tag::Int:
808814
return out << v.toInt();
815+
case IValue::Tag::UInt:
816+
return out << v.toUInt();
809817
case IValue::Tag::SymInt:
810818
return out << v.toSymInt();
811819
case IValue::Tag::SymFloat:

aten/src/ATen/core/ivalue.h

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <c10/macros/Export.h>
1313
#include <c10/util/MaybeOwned.h>
1414
#include <c10/util/intrusive_ptr.h>
15+
#include <limits>
1516
#include <type_traits>
1617
#include <unordered_map>
1718
#include <unordered_set>
@@ -160,6 +161,7 @@ struct Capsule {
160161
_(Double) \
161162
_(ComplexDouble) \
162163
_(Int) \
164+
_(UInt) \
163165
_(SymInt) \
164166
_(SymFloat) \
165167
_(SymBool) \
@@ -653,6 +655,29 @@ struct TORCH_API IValue final {
653655
}
654656
}
655657

658+
// Unsigned
659+
IValue(uint64_t u) : tag( u <= std::numeric_limits<int64_t>::max() ? Tag::Int : Tag::UInt) {
660+
payload.u.as_uint = u;
661+
}
662+
663+
664+
// See Note [Meaning of HAS_u]
665+
// IValue type model closely follows that of c10::Scalar
666+
// Where all integers are upcast to 64-bit representation, and `as_int` is used as default
667+
// representation unless value could not be represented as signed int
668+
bool isUnsigned() const {
669+
return Tag::UInt == tag || (Tag::Int == tag && payload.u.as_int >= 0);
670+
}
671+
672+
uint64_t toUInt() const {
673+
if (isUnsigned()) {
674+
return payload.u.as_uint;
675+
} else {
676+
TORCH_INTERNAL_ASSERT(0, "expected unsigned int");
677+
}
678+
}
679+
680+
656681
// Bool
657682
IValue(bool b) : tag(Tag::Bool) {
658683
#if defined(__clang__) && defined(__x86_64__)
@@ -893,8 +918,14 @@ struct TORCH_API IValue final {
893918
} else {
894919
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
895920
s.isIntegral(false), "Unknown type in Scalar");
896-
tag = Tag::Int;
897-
payload.u.as_int = s.toLong();
921+
if (s.isUnsigned()) {
922+
const auto val = s.toUInt64();
923+
payload.u.as_uint = val;
924+
tag = val <= std::numeric_limits<int64_t>::max() ? Tag::Int : Tag::UInt;
925+
} else {
926+
payload.u.as_int = s.toLong();
927+
tag = Tag::Int;
928+
}
898929
}
899930
}
900931

@@ -918,6 +949,8 @@ struct TORCH_API IValue final {
918949
return toSymFloat();
919950
else if (isSymBool())
920951
return toSymBool();
952+
else if (isUnsigned())
953+
return toUInt();
921954
TORCH_CHECK(false, "IValue is not a Scalar");
922955
}
923956

@@ -1247,6 +1280,8 @@ struct TORCH_API IValue final {
12471280
return true;
12481281
case Tag::Int:
12491282
return false;
1283+
case Tag::UInt:
1284+
return false;
12501285
case Tag::SymInt:
12511286
return true;
12521287
case Tag::SymFloat:
@@ -1343,6 +1378,8 @@ struct TORCH_API IValue final {
13431378
union TriviallyCopyablePayload {
13441379
TriviallyCopyablePayload() : as_int(0) {}
13451380
int64_t as_int;
1381+
// See Note [Meaning of HAS_u]
1382+
uint64_t as_uint;
13461383
double as_double;
13471384
bool as_bool;
13481385
// Invariant: never nullptr; null state is represented as

aten/src/ATen/native/mps/kernels/LinearAlgebra.metal

Lines changed: 60 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,37 @@ kernel void matmul(
6868
}
6969
}
7070

71+
template <typename T>
72+
kernel void addmm(
73+
constant T* mat1Data [[buffer(0)]],
74+
constant T* mat2Data [[buffer(1)]],
75+
device T* outputData [[buffer(2)]],
76+
constant T* biasData [[buffer(3)]],
77+
constant array<c10::metal::opmath_t<T>, 2>& alpha_beta [[buffer(4)]],
78+
constant array<ulong2, 4>& strides [[buffer(5)]],
79+
constant uint3& sizes [[buffer(6)]],
80+
uint2 tid [[thread_position_in_threadgroup]],
81+
uint2 thread_id [[thread_position_in_grid]]) {
82+
threadgroup T A_tile[TILE_DIM][TILE_DIM];
83+
threadgroup T B_tile[TILE_DIM][TILE_DIM];
84+
85+
auto sum = matmul_inner<T>(
86+
mat1Data,
87+
mat2Data,
88+
reinterpret_cast<constant array<ulong2, 3>&>(strides),
89+
sizes,
90+
A_tile,
91+
B_tile,
92+
tid,
93+
thread_id);
94+
if (thread_id.y < sizes.x && thread_id.x < sizes.z) {
95+
auto bias =
96+
biasData[thread_id.y * strides[3].x + thread_id.x * strides[3].y];
97+
outputData[thread_id.y * strides[2].x + thread_id.x * strides[2].y] =
98+
static_cast<T>(alpha_beta[0] * sum + alpha_beta[1] * bias);
99+
}
100+
}
101+
71102
template <typename T>
72103
kernel void naive_bmm(
73104
constant T* mat1Data [[buffer(0)]],
@@ -613,38 +644,42 @@ kernel void applyPivots(
613644
}
614645
}
615646

616-
#define INSTANTIATE_NAIVE_MM(DTYPE) \
617-
template [[host_name("matmul_" #DTYPE)]] kernel void matmul<DTYPE>( \
618-
constant DTYPE * mat1Data [[buffer(0)]], \
619-
constant DTYPE * mat2Data [[buffer(1)]], \
620-
device DTYPE * outputData [[buffer(2)]], \
621-
constant array<ulong2, 3> & strides [[buffer(3)]], \
622-
constant uint3 & sizes [[buffer(4)]], \
623-
uint2 tid [[thread_position_in_threadgroup]], \
624-
uint2 group_id [[threadgroup_position_in_grid]])
625-
626-
#define INSTANTIATE_NAIVE_BMM(DTYPE) \
647+
#define INSTANTIATE_MM_OPS(DTYPE) \
648+
template [[host_name("matmul_" #DTYPE)]] kernel void matmul<DTYPE>( \
649+
constant DTYPE * mat1Data [[buffer(0)]], \
650+
constant DTYPE * mat2Data [[buffer(1)]], \
651+
device DTYPE * outputData [[buffer(2)]], \
652+
constant array<ulong2, 3> & strides [[buffer(3)]], \
653+
constant uint3 & sizes [[buffer(4)]], \
654+
uint2 tid [[thread_position_in_threadgroup]], \
655+
uint2 group_id [[threadgroup_position_in_grid]]); \
627656
template [[host_name("naive_bmm_" #DTYPE)]] kernel void naive_bmm<DTYPE>( \
628657
constant DTYPE * mat1Data [[buffer(0)]], \
629658
constant DTYPE * mat2Data [[buffer(1)]], \
630659
device DTYPE * outputData [[buffer(2)]], \
631660
constant array<ulong, 9> & strides [[buffer(3)]], \
632661
constant uint4 & sizes [[buffer(4)]], \
633662
uint3 tid [[thread_position_in_threadgroup]], \
634-
uint3 group_id [[threadgroup_position_in_grid]])
663+
uint3 group_id [[threadgroup_position_in_grid]]); \
664+
template [[host_name("addmm_" #DTYPE)]] kernel void addmm<DTYPE>( \
665+
constant DTYPE * mat1Data [[buffer(0)]], \
666+
constant DTYPE * mat2Data [[buffer(1)]], \
667+
device DTYPE * outputData [[buffer(2)]], \
668+
constant DTYPE * biasData [[buffer(3)]], \
669+
constant array<c10::metal::opmath_t<DTYPE>, 2> & \
670+
alpha_beta [[buffer(4)]], \
671+
constant array<ulong2, 4> & strides [[buffer(5)]], \
672+
constant uint3 & sizes [[buffer(6)]], \
673+
uint2 tid [[thread_position_in_threadgroup]], \
674+
uint2 group_id [[threadgroup_position_in_grid]])
635675

636-
INSTANTIATE_NAIVE_MM(float);
637-
INSTANTIATE_NAIVE_MM(half);
638-
INSTANTIATE_NAIVE_MM(bfloat);
676+
INSTANTIATE_MM_OPS(float);
677+
INSTANTIATE_MM_OPS(half);
678+
INSTANTIATE_MM_OPS(bfloat);
639679

640680
// Integral MM
641-
INSTANTIATE_NAIVE_MM(short);
642-
INSTANTIATE_NAIVE_MM(int);
643-
INSTANTIATE_NAIVE_MM(long);
644-
INSTANTIATE_NAIVE_MM(char);
645-
INSTANTIATE_NAIVE_MM(uchar);
646-
INSTANTIATE_NAIVE_BMM(short);
647-
INSTANTIATE_NAIVE_BMM(int);
648-
INSTANTIATE_NAIVE_BMM(long);
649-
INSTANTIATE_NAIVE_BMM(char);
650-
INSTANTIATE_NAIVE_BMM(uchar);
681+
INSTANTIATE_MM_OPS(long);
682+
INSTANTIATE_MM_OPS(int);
683+
INSTANTIATE_MM_OPS(short);
684+
INSTANTIATE_MM_OPS(char);
685+
INSTANTIATE_MM_OPS(uchar);

aten/src/ATen/native/mps/operations/LinearAlgebra.mm

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,61 @@
112112
return output;
113113
}
114114

115+
Tensor& do_metal_addmm(const Tensor& self,
116+
const Tensor& other,
117+
Tensor& output,
118+
const Scalar& alpha,
119+
const Scalar& beta,
120+
const Tensor& bias) {
121+
if (beta.toDouble() == 0 && alpha.toDouble() == 1) {
122+
return do_metal_mm(self, other, output);
123+
}
124+
auto stream = getCurrentMPSStream();
125+
auto device = MPSDevice::getInstance()->device();
126+
auto matmulPSO = lib.getPipelineStateForFunc("addmm_" + mps::scalarToMetalTypeString(output));
127+
dispatch_sync_with_rethrow(stream->queue(), ^() {
128+
@autoreleasepool {
129+
getMPSProfiler().beginProfileKernel(matmulPSO, "addmm", {self, other});
130+
auto computeEncoder = stream->commandEncoder();
131+
[computeEncoder setComputePipelineState:matmulPSO];
132+
std::array<uint32_t, 3> sizes = {static_cast<uint32_t>(self.size(0)),
133+
static_cast<uint32_t>(self.size(1)),
134+
static_cast<uint32_t>(output.size(1))};
135+
std::array<int64_t, 8> strides = {self.stride(0),
136+
self.stride(1),
137+
other.stride(0),
138+
other.stride(1),
139+
output.stride(0),
140+
output.stride(1),
141+
bias.stride(0),
142+
bias.stride(1)};
143+
union {
144+
std::array<int64_t, 2> i64;
145+
std::array<int32_t, 2> i32;
146+
std::array<float, 2> f32;
147+
} alpha_beta;
148+
if (output.scalar_type() == kLong) {
149+
alpha_beta.i64 = {alpha.toLong(), beta.toLong()};
150+
} else if (c10::isIntegralType(output.scalar_type(), true)) {
151+
alpha_beta.i32 = {alpha.toInt(), beta.toInt()};
152+
} else {
153+
TORCH_INTERNAL_ASSERT(c10::isFloatingType(output.scalar_type()));
154+
alpha_beta.f32 = {alpha.toFloat(), beta.toFloat()};
155+
}
156+
constexpr uint32_t TILE_DIM = 16; // fastest performance from tests on multiple macs
157+
uint32_t gridSizeX = (output.size(1) + TILE_DIM - 1) / TILE_DIM;
158+
uint32_t gridSizeY = (self.size(0) + TILE_DIM - 1) / TILE_DIM;
159+
160+
MTLSize threadsPerThreadgroup = MTLSizeMake(TILE_DIM, TILE_DIM, 1);
161+
MTLSize threadgroupsPerGrid = MTLSizeMake(gridSizeX, gridSizeY, 1);
162+
mtl_setArgs(computeEncoder, self, other, output, bias, alpha_beta.i64, strides, sizes);
163+
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadsPerThreadgroup];
164+
getMPSProfiler().endProfileKernel(matmulPSO);
165+
}
166+
});
167+
return output;
168+
}
169+
115170
std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*> do_mm(MPSGraph* graph,
116171
const Tensor& self,
117172
const Tensor& other) {
@@ -644,7 +699,6 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const
644699

645700
TORCH_CHECK(output.is_mps());
646701
TORCH_CHECK(self.dim() == 2 && other.dim() == 2, "tensors must be 2-D");
647-
TORCH_CHECK(supportedFloatingOrComplexType(self), "MPS device does not support addmm for non-float input");
648702

649703
TensorArg args[]{{output, "out", 0}, {bias, "self", 1}, {self, "mat1", 2}, {other, "mat2", 3}};
650704
checkAllSameGPU(__func__, args);
@@ -671,6 +725,10 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const
671725
return output;
672726
}
673727

728+
if (use_metal_mm(self, other, output)) {
729+
return do_metal_addmm(self, other, output, alpha, beta, *bias_);
730+
}
731+
674732
bool is_beta_non_zero = beta.toDouble() != 0.0;
675733

676734
struct CachedGraph : public mps::MPSCachedGraph {

aten/src/ATen/native/sparse/mps/kernels/Sparse.metal

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,4 +120,9 @@ kernel void coalesce_with_positions_kernel(
120120
INSTANTIATE_COALESCE_WITH_POSITIONS(float);
121121
INSTANTIATE_COALESCE_WITH_POSITIONS(half);
122122
INSTANTIATE_COALESCE_WITH_POSITIONS(bfloat);
123-
INSTANTIATE_COALESCE_WITH_POSITIONS(bool);
123+
INSTANTIATE_COALESCE_WITH_POSITIONS(bool);
124+
INSTANTIATE_COALESCE_WITH_POSITIONS(long);
125+
INSTANTIATE_COALESCE_WITH_POSITIONS(char);
126+
INSTANTIATE_COALESCE_WITH_POSITIONS(uchar);
127+
INSTANTIATE_COALESCE_WITH_POSITIONS(short);
128+
INSTANTIATE_COALESCE_WITH_POSITIONS(int);

c10/cuda/CUDAStream.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -216,9 +216,6 @@ static void initSingleStream(int p, DeviceIndex device_index, int i) {
216216
// Creates the low and high priority stream pools for the specified device
217217
// Warning: only call once per device!
218218
static void initDeviceStreamState(DeviceIndex device_index) {
219-
// Switches to the requested device so streams are properly associated
220-
// with it.
221-
CUDAGuard device_guard{device_index};
222219
for (const auto i : c10::irange(kStreamsPerPool)) {
223220
for (const auto p : c10::irange(max_stream_priorities)) {
224221
initSingleStream(p, device_index, i);

0 commit comments

Comments
 (0)