diff --git a/python/private/pypi/parse_requirements.bzl b/python/private/pypi/parse_requirements.bzl index 5633328cf9..1583c89199 100644 --- a/python/private/pypi/parse_requirements.bzl +++ b/python/private/pypi/parse_requirements.bzl @@ -285,12 +285,17 @@ def _add_dists(*, requirement, index_urls, logger = None): if requirement.srcs.url: url = requirement.srcs.url _, _, filename = url.rpartition("/") + filename, _, _ = filename.partition("#sha256=") if "." not in filename: # detected filename has no extension, it might be an sdist ref # TODO @aignas 2025-04-03: should be handled if the following is fixed: # https://github.com/bazel-contrib/rules_python/issues/2363 return [], None + if "@" in filename: + # this is most likely foo.git@git_sha, skip special handling of these + return [], None + direct_url_dist = struct( url = url, filename = filename, diff --git a/tests/pypi/index_sources/index_sources_tests.bzl b/tests/pypi/index_sources/index_sources_tests.bzl index ffeed87a7b..9d12bc6399 100644 --- a/tests/pypi/index_sources/index_sources_tests.bzl +++ b/tests/pypi/index_sources/index_sources_tests.bzl @@ -21,38 +21,50 @@ _tests = [] def _test_no_simple_api_sources(env): inputs = { + "foo @ git+https://github.com/org/foo.git@deadbeef": struct( + requirement = "foo @ git+https://github.com/org/foo.git@deadbeef", + marker = "", + url = "git+https://github.com/org/foo.git@deadbeef", + shas = [], + version = "", + ), "foo==0.0.1": struct( requirement = "foo==0.0.1", marker = "", url = "", + version = "0.0.1", ), "foo==0.0.1 @ https://someurl.org": struct( requirement = "foo==0.0.1 @ https://someurl.org", marker = "", url = "https://someurl.org", + version = "0.0.1", ), "foo==0.0.1 @ https://someurl.org/package.whl": struct( requirement = "foo==0.0.1 @ https://someurl.org/package.whl", marker = "", url = "https://someurl.org/package.whl", + version = "0.0.1", ), "foo==0.0.1 @ https://someurl.org/package.whl --hash=sha256:deadbeef": struct( requirement = "foo==0.0.1 @ https://someurl.org/package.whl --hash=sha256:deadbeef", marker = "", url = "https://someurl.org/package.whl", shas = ["deadbeef"], + version = "0.0.1", ), "foo==0.0.1 @ https://someurl.org/package.whl; python_version < \"2.7\"\\ --hash=sha256:deadbeef": struct( requirement = "foo==0.0.1 @ https://someurl.org/package.whl --hash=sha256:deadbeef", marker = "python_version < \"2.7\"", url = "https://someurl.org/package.whl", shas = ["deadbeef"], + version = "0.0.1", ), } for input, want in inputs.items(): got = index_sources(input) env.expect.that_collection(got.shas).contains_exactly(want.shas if hasattr(want, "shas") else []) - env.expect.that_str(got.version).equals("0.0.1") + env.expect.that_str(got.version).equals(want.version) env.expect.that_str(got.requirement).equals(want.requirement) env.expect.that_str(got.requirement_line).equals(got.requirement) env.expect.that_str(got.marker).equals(want.marker) diff --git a/tests/pypi/parse_requirements/parse_requirements_tests.bzl b/tests/pypi/parse_requirements/parse_requirements_tests.bzl index 723bb605ce..c5b24870ea 100644 --- a/tests/pypi/parse_requirements/parse_requirements_tests.bzl +++ b/tests/pypi/parse_requirements/parse_requirements_tests.bzl @@ -30,12 +30,16 @@ foo[extra] @ https://some-url/package.whl bar @ https://example.org/bar-1.0.whl --hash=sha256:deadbeef baz @ https://test.com/baz-2.0.whl; python_version < "3.8" --hash=sha256:deadb00f qux @ https://example.org/qux-1.0.tar.gz --hash=sha256:deadbe0f +torch @ https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl#sha256=5b6ae523bfb67088a17ca7734d131548a2e60346c622621e4248ed09dd0790cc """, "requirements_extra_args": """\ --index-url=example.org foo[extra]==0.0.1 \ --hash=sha256:deadbeef +""", + "requirements_git": """ +foo @ git+https://github.com/org/foo.git@deadbeef """, "requirements_linux": """\ foo==0.0.3 --hash=sha256:deadbaaf @@ -232,6 +236,31 @@ def _test_direct_urls(env): whls = [], ), ], + "torch": [ + struct( + distribution = "torch", + extra_pip_args = [], + is_exposed = True, + sdist = None, + srcs = struct( + marker = "", + requirement = "torch @ https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl#sha256=5b6ae523bfb67088a17ca7734d131548a2e60346c622621e4248ed09dd0790cc", + requirement_line = "torch @ https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl#sha256=5b6ae523bfb67088a17ca7734d131548a2e60346c622621e4248ed09dd0790cc", + shas = [], + url = "https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl#sha256=5b6ae523bfb67088a17ca7734d131548a2e60346c622621e4248ed09dd0790cc", + version = "", + ), + target_platforms = ["linux_x86_64"], + whls = [ + struct( + filename = "torch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl", + sha256 = "", + url = "https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl#sha256=5b6ae523bfb67088a17ca7734d131548a2e60346c622621e4248ed09dd0790cc", + yanked = False, + ), + ], + ), + ], }) _tests.append(_test_direct_urls) @@ -623,6 +652,36 @@ def _test_optional_hash(env): _tests.append(_test_optional_hash) +def _test_git_sources(env): + got = parse_requirements( + ctx = _mock_ctx(), + requirements_by_platform = { + "requirements_git": ["linux_x86_64"], + }, + ) + env.expect.that_dict(got).contains_exactly({ + "foo": [ + struct( + distribution = "foo", + extra_pip_args = [], + is_exposed = True, + sdist = None, + srcs = struct( + marker = "", + requirement = "foo @ git+https://github.com/org/foo.git@deadbeef", + requirement_line = "foo @ git+https://github.com/org/foo.git@deadbeef", + shas = [], + url = "git+https://github.com/org/foo.git@deadbeef", + version = "", + ), + target_platforms = ["linux_x86_64"], + whls = [], + ), + ], + }) + +_tests.append(_test_git_sources) + def parse_requirements_test_suite(name): """Create the test suite.