diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu
index f052b858c556..53a4b0b65d8c 100644
--- a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu
+++ b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu
@@ -126,6 +126,7 @@ void f8f8bf16_rowwise_impl(
     std::optional<at::Tensor> bias,
     at::Tensor out,
     const int swizzle) {
+#if (defined(__CUDA_ARCH__)) && (__CUDA_ARCH__ == 900)
   int M = XQ.size(0);
   int N = WQ.size(1);
   int K = XQ.size(1);
@@ -298,6 +299,7 @@ void f8f8bf16_rowwise_impl(
         cutlass::cutlassGetStatusString(status));
   }
   C10_CUDA_KERNEL_LAUNCH_CHECK();
+#endif
 }
 
@@ -318,6 +320,7 @@ void f8f8bf16_rowwise_impl_sm100(
     std::optional<at::Tensor> bias,
     at::Tensor out,
     const int swizzle) {
+#if (defined(__CUDA_ARCH__)) && (__CUDA_ARCH__ == 1000 || __CUDA_ARCH__ == 1030)
   int M = XQ.size(0);
   int N = WQ.size(1);
   int K = XQ.size(1);
@@ -480,6 +483,7 @@ void f8f8bf16_rowwise_impl_sm100(
         cutlass::cutlassGetStatusString(status));
   }
   C10_CUDA_KERNEL_LAUNCH_CHECK();
+#endif
 }
 
 // Cutlass rowwise kernel for SM89
@@ -498,6 +502,7 @@ void f8f8bf16_rowwise_impl_sm89(
     at::Tensor w_scale,
     std::optional<at::Tensor> bias,
     at::Tensor out) {
+#if (defined(__CUDA_ARCH__)) && (__CUDA_ARCH__ == 890)
   int M = XQ.size(0);
   int N = WQ.size(1);
   int K = XQ.size(1);
@@ -695,6 +700,7 @@ void f8f8bf16_rowwise_impl_sm89(
         cutlass::cutlassGetStatusString(status));
   }
   C10_CUDA_KERNEL_LAUNCH_CHECK();
+#endif
 }
 
 template