@@ -192,7 +192,14 @@ class GpuBfloat16Support : public BFloat16Support {
   explicit GpuBfloat16Support(bool supports_matrix_multiplication,
                               se::StreamExecutor* stream_exec)
       : supports_matrix_multiplication_(supports_matrix_multiplication),
-        stream_exec_(stream_exec) {}
+        is_conv_bf16_supported_(IsConvBf16Supported(stream_exec)) {}
+
+  explicit GpuBfloat16Support(bool supports_matrix_multiplication,
+                              se::dnn::VersionInfo cudnn_version,
+                              se::CudaComputeCapability cuda_compute_capability)
+      : supports_matrix_multiplication_(supports_matrix_multiplication),
+        is_conv_bf16_supported_(
+            IsConvBf16Supported(cudnn_version, cuda_compute_capability)) {}
 
   bool SupportsBF16Operand(const HloInstruction& hlo,
                            int64_t operand_index) const override {
@@ -235,30 +242,41 @@ class GpuBfloat16Support : public BFloat16Support {
       case HloOpcode::kBitcast:
         return true;
       case HloOpcode::kConvolution:
-        return IsConvBF16Supported();
+        return is_conv_bf16_supported_;
       default:
         return supports_matrix_multiplication_ &&
                gpu::IsMatrixMultiplication(hlo);
     }
   }
 
-  bool IsConvBF16Supported() const {
-    if (se::dnn::DnnSupport* dnn = stream_exec_->AsDnn()) {
+  static bool IsConvBf16Supported(se::StreamExecutor* stream_exec) {
+    if (se::dnn::DnnSupport* dnn = stream_exec->AsDnn()) {
       se::port::StatusOr<se::dnn::VersionInfo> cudnn_version =
           dnn->GetVersion();
-      return cudnn_version.ok() &&
-             (cudnn_version->major_version() > 8 ||
-              (cudnn_version->major_version() == 8 &&
-               cudnn_version->minor_version() >= 2)) &&
-             stream_exec_->GetDeviceDescription()
-                 .cuda_compute_capability()
-                 .IsAtLeast(se::CudaComputeCapability::AMPERE);
+      if (cudnn_version.ok()) {
+        auto cuda_compute_capability =
+            stream_exec->GetDeviceDescription().cuda_compute_capability();
+        return (cudnn_version->major_version() > 8 ||
+                (cudnn_version->major_version() == 8 &&
+                 cudnn_version->minor_version() >= 2)) &&
+               cuda_compute_capability.IsAtLeast(
+                   se::CudaComputeCapability::AMPERE);
+      }
     }
     return false;
   }
 
+  static bool IsConvBf16Supported(
+      se::dnn::VersionInfo cudnn_version,
+      se::CudaComputeCapability cuda_compute_capability) {
+    return (cudnn_version.major_version() > 8 ||
+            (cudnn_version.major_version() == 8 &&
+             cudnn_version.minor_version() >= 2)) &&
+           cuda_compute_capability.IsAtLeast(se::CudaComputeCapability::AMPERE);
+  }
+
   bool supports_matrix_multiplication_;
-  se::StreamExecutor* stream_exec_;
+  bool is_conv_bf16_supported_;
 };
 
 int64_t GetSizeOfShape(const Shape& shape, int pointer_size) {
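A minimal usage sketch of the new two-argument overload (illustrative only, not part of the change above): it assumes the caller already knows the target cuDNN version and compute capability, for example in a deviceless or ahead-of-time compilation path where no StreamExecutor is at hand. The variable names and the concrete version numbers below are hypothetical.

  // Hypothetical call site: cuDNN 8.2 targeting an Ampere (sm_80) GPU.
  se::dnn::VersionInfo cudnn_version(8, 2, 0);
  se::CudaComputeCapability ampere(8, 0);
  GpuBfloat16Support bf16_support(/*supports_matrix_multiplication=*/true,
                                  cudnn_version, ampere);
  // The conv-BF16 decision is computed once in the constructor and cached in
  // is_conv_bf16_supported_, so later opcode checks (e.g. for kConvolution)
  // never have to query the device or the DNN library.

With the StreamExecutor-based constructor the decision is likewise made eagerly, so either overload avoids re-querying the cuDNN version for every convolution instruction.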