@@ -595,64 +595,6 @@ def test_serialization_array_with_storage(self):
595
595
q_copy [1 ].fill_ (10 )
596
596
self .assertEqual (q_copy [3 ], torch .cuda .IntStorage (10 ).fill_ (10 ))
597
597
598
- @setBlasBackendsToDefaultFinally
599
- def test_preferred_blas_library_settings (self ):
600
- def _check_default ():
601
- default = torch .backends .cuda .preferred_blas_library ()
602
- if torch .version .cuda :
603
- # CUDA logic is easy, it's always cublas
604
- self .assertTrue (default == torch ._C ._BlasBackend .Cublas )
605
- else :
606
- # ROCm logic is less so, it's cublaslt for some Instinct, cublas for all else
607
- gcn_arch = str (
608
- torch .cuda .get_device_properties (0 ).gcnArchName .split (":" , 1 )[0 ]
609
- )
610
- if gcn_arch in ["gfx90a" , "gfx942" , "gfx950" ]:
611
- self .assertTrue (default == torch ._C ._BlasBackend .Cublaslt )
612
- else :
613
- self .assertTrue (default == torch ._C ._BlasBackend .Cublas )
614
-
615
- _check_default ()
616
- # "Default" can be set but is immediately reset internally to the actual default value.
617
- self .assertTrue (
618
- torch .backends .cuda .preferred_blas_library ("default" )
619
- != torch ._C ._BlasBackend .Default
620
- )
621
- _check_default ()
622
- self .assertTrue (
623
- torch .backends .cuda .preferred_blas_library ("cublas" )
624
- == torch ._C ._BlasBackend .Cublas
625
- )
626
- self .assertTrue (
627
- torch .backends .cuda .preferred_blas_library ("hipblas" )
628
- == torch ._C ._BlasBackend .Cublas
629
- )
630
- # check bad strings
631
- with self .assertRaisesRegex (
632
- RuntimeError ,
633
- "Unknown input value. Choose from: default, cublas, hipblas, cublaslt, hipblaslt, ck." ,
634
- ):
635
- torch .backends .cuda .preferred_blas_library ("unknown" )
636
- # check bad input type
637
- with self .assertRaisesRegex (RuntimeError , "Unknown input value type." ):
638
- torch .backends .cuda .preferred_blas_library (1.0 )
639
- # check env var override
640
- custom_envs = [
641
- {"TORCH_BLAS_PREFER_CUBLASLT" : "1" },
642
- {"TORCH_BLAS_PREFER_HIPBLASLT" : "1" },
643
- ]
644
- test_script = "import torch;print(torch.backends.cuda.preferred_blas_library())"
645
- for env_config in custom_envs :
646
- env = os .environ .copy ()
647
- for key , value in env_config .items ():
648
- env [key ] = value
649
- r = (
650
- subprocess .check_output ([sys .executable , "-c" , test_script ], env = env )
651
- .decode ("ascii" )
652
- .strip ()
653
- )
654
- self .assertEqual ("_BlasBackend.Cublaslt" , r )
655
-
656
598
@unittest .skipIf (TEST_CUDAMALLOCASYNC , "temporarily disabled for async" )
657
599
@setBlasBackendsToDefaultFinally
658
600
def test_cublas_workspace_explicit_allocation (self ):
0 commit comments