diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index da740e68de9c..4c7031d0a354 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -78,7 +78,9 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ package_exists = False elif pkg_name == "triton": try: - package_version = importlib.metadata.version("pytorch-triton") + # import triton works for both linux and windows + package = importlib.import_module(pkg_name) + package_version = getattr(package, "__version__", "N/A") except Exception: package_exists = False else: