Skip to content

Commit d6f5d78

Browse files
anlunx authored and tensorflower-gardener committed
[XLA:GPU] Make GpuBfloat16Support compatible with AOT compilation
GpuBfloat16Support should work without stream executor.

PiperOrigin-RevId: 495661838
1 parent 5133fce commit d6f5d78

File tree

1 file changed

+30
-12
lines changed

1 file changed

+30
-12
lines changed

tensorflow/compiler/xla/service/gpu/gpu_compiler.cc

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,14 @@ class GpuBfloat16Support : public BFloat16Support {
192192
explicit GpuBfloat16Support(bool supports_matrix_multiplication,
193193
se::StreamExecutor* stream_exec)
194194
: supports_matrix_multiplication_(supports_matrix_multiplication),
195-
stream_exec_(stream_exec) {}
195+
is_conv_bf16_supported_(IsConvBf16Supported(stream_exec)) {}
196+
197+
explicit GpuBfloat16Support(bool supports_matrix_multiplication,
198+
se::dnn::VersionInfo cudnn_version,
199+
se::CudaComputeCapability cuda_compute_capability)
200+
: supports_matrix_multiplication_(supports_matrix_multiplication),
201+
is_conv_bf16_supported_(
202+
IsConvBf16Supported(cudnn_version, cuda_compute_capability)) {}
196203

197204
bool SupportsBF16Operand(const HloInstruction& hlo,
198205
int64_t operand_index) const override {
@@ -235,30 +242,41 @@ class GpuBfloat16Support : public BFloat16Support {
235242
case HloOpcode::kBitcast:
236243
return true;
237244
case HloOpcode::kConvolution:
238-
return IsConvBF16Supported();
245+
return is_conv_bf16_supported_;
239246
default:
240247
return supports_matrix_multiplication_ &&
241248
gpu::IsMatrixMultiplication(hlo);
242249
}
243250
}
244251

245-
bool IsConvBF16Supported() const {
246-
if (se::dnn::DnnSupport* dnn = stream_exec_->AsDnn()) {
252+
static bool IsConvBf16Supported(se::StreamExecutor* stream_exec) {
253+
if (se::dnn::DnnSupport* dnn = stream_exec->AsDnn()) {
247254
se::port::StatusOr<se::dnn::VersionInfo> cudnn_version =
248255
dnn->GetVersion();
249-
return cudnn_version.ok() &&
250-
(cudnn_version->major_version() > 8 ||
251-
(cudnn_version->major_version() == 8 &&
252-
cudnn_version->minor_version() >= 2)) &&
253-
stream_exec_->GetDeviceDescription()
254-
.cuda_compute_capability()
255-
.IsAtLeast(se::CudaComputeCapability::AMPERE);
256+
if (cudnn_version.ok()) {
257+
auto cuda_compute_capability =
258+
stream_exec->GetDeviceDescription().cuda_compute_capability();
259+
return (cudnn_version->major_version() > 8 ||
260+
(cudnn_version->major_version() == 8 &&
261+
cudnn_version->minor_version() >= 2)) &&
262+
cuda_compute_capability.IsAtLeast(
263+
se::CudaComputeCapability::AMPERE);
264+
}
256265
}
257266
return false;
258267
}
259268

269+
static bool IsConvBf16Supported(
270+
se::dnn::VersionInfo cudnn_version,
271+
se::CudaComputeCapability cuda_compute_capability) {
272+
return (cudnn_version.major_version() > 8 ||
273+
(cudnn_version.major_version() == 8 &&
274+
cudnn_version.minor_version() >= 2)) &&
275+
cuda_compute_capability.IsAtLeast(se::CudaComputeCapability::AMPERE);
276+
}
277+
260278
bool supports_matrix_multiplication_;
261-
se::StreamExecutor* stream_exec_;
279+
bool is_conv_bf16_supported_;
262280
};
263281

264282
int64_t GetSizeOfShape(const Shape& shape, int pointer_size) {

0 commit comments

Comments
 (0)