Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions python/ray/_common/runtime_env_uri.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import enum
import pathlib
import urllib.parse
from typing import Tuple
from urllib.parse import urlparse

_REMOTE_PROTOCOLS = ("http", "https", "s3", "gs", "azure", "abfss", "file")


class Protocol(enum.Enum):
# For packages dynamically uploaded and managed by the GCS.
GCS = "gcs"
# For conda environments installed locally on each node.
CONDA = "conda"
# For pip environments installed locally on each node.
PIP = "pip"
# For uv environments installed locally on each node.
UV = "uv"
# Remote http path, assumes everything packed in one zip file.
HTTP = "http"
# Remote https path, assumes everything packed in one zip file.
HTTPS = "https"
# Remote s3 path, assumes everything packed in one zip file.
S3 = "s3"
# Remote google storage path, assumes everything packed in one zip file.
GS = "gs"
# Remote azure blob storage path, assumes everything packed in one zip file.
AZURE = "azure"
# Remote Azure Blob File System Secure path, assumes everything packed in one zip file.
ABFSS = "abfss"
# File storage path, assumes everything packed in one zip file.
FILE = "file"

@classmethod
def remote_protocols(cls):
# Returns a list of protocols that support remote storage.
# These protocols should only be used with paths that end in
# ".zip", ".whl", ".tar.gz", or ".tgz".
return [cls[protocol.upper()] for protocol in _REMOTE_PROTOCOLS]


def _is_path(path_or_uri: str) -> bool:
"""Returns True if path_or_uri is a path and False otherwise."""
if not isinstance(path_or_uri, str):
raise TypeError(f" path_or_uri must be a string, got {type(path_or_uri)}.")

parsed_path = pathlib.Path(path_or_uri)
parsed_uri = urllib.parse.urlparse(path_or_uri)

if isinstance(parsed_path, pathlib.PurePosixPath):
return not parsed_uri.scheme
elif isinstance(parsed_path, pathlib.PureWindowsPath):
return parsed_uri.scheme == parsed_path.drive.strip(":").lower()
else:
# this should never happen.
raise TypeError(f"Unsupported path type: {type(parsed_path).__name__}")


def parse_uri(pkg_uri: str) -> Tuple[Protocol, str]:
    """Parse a package URI into its protocol and a package name.

    Note that the output of this function is not for handling actual IO, it's
    only for setting up local directory folders by using package name as path.

    >>> parse_uri("https://test.com/file.zip")
    (<Protocol.HTTPS: 'https'>, 'https_test_com_file.zip')

    >>> parse_uri("https://test.com/file.whl")
    (<Protocol.HTTPS: 'https'>, 'file.whl')

    >>> parse_uri("s3://bucket/archive.tar.gz")
    (<Protocol.S3: 's3'>, 's3_bucket_archive.tar.gz')

    Args:
        pkg_uri: The URI to parse. Its scheme must be the value of one of the
            ``Protocol`` members.

    Returns:
        A ``(protocol, package_name)`` tuple. For non-remote protocols the
        package name is the URI's netloc, unmodified. For remote protocols a
        ``.whl`` URI yields the bare wheel file name; any other URI yields
        ``"<scheme>_<netloc><path>"`` with disallowed characters and all
        periods except those of the file extension replaced by underscores.

    Raises:
        ValueError: If ``pkg_uri`` is a local path rather than a URI, or if
            its scheme is not a supported protocol.
        TypeError: If ``pkg_uri`` is not a string.
    """
    if _is_path(pkg_uri):
        raise ValueError(f"Expected URI but received path {pkg_uri}")

    uri = urlparse(pkg_uri)
    try:
        protocol = Protocol(uri.scheme)
    except ValueError as e:
        # Public iteration yields the same names as the private
        # ``_member_names_`` attribute (the enum declares no aliases).
        supported = [member.name for member in Protocol]
        raise ValueError(
            f'Invalid protocol for runtime_env URI "{pkg_uri}". '
            f"Supported protocols: {supported}. Original error: {e}"
        ) from e

    if protocol not in Protocol.remote_protocols():
        return (protocol, uri.netloc)

    if uri.path.endswith(".whl"):
        # Don't modify the .whl filename. See
        # https://peps.python.org/pep-0427/#file-name-convention
        # for more information.
        return (protocol, uri.path.split("/")[-1])

    # Replace every character that is disallowed in the package name with an
    # underscore in a single pass.
    disallowed_chars = str.maketrans(dict.fromkeys("/:@+ ()", "_"))
    package_name = f"{protocol.value}_{uri.netloc}{uri.path}".translate(
        disallowed_chars
    )

    # Preserve compound extensions like .tar.gz: every period in front of the
    # extension is replaced, the extension itself is kept verbatim.
    for compound_ext in (".tar.gz", ".tar.bz2"):
        if package_name.endswith(compound_ext):
            stem = package_name[: -len(compound_ext)]
            return (protocol, stem.replace(".", "_") + compound_ext)

    # Remove all periods except the last, which is part of the file
    # extension. With no period at all the count is -1, which str.replace
    # treats as "replace all" -- a no-op here.
    package_name = package_name.replace(".", "_", package_name.count(".") - 1)
    return (protocol, package_name)
1 change: 1 addition & 0 deletions python/ray/_common/tests/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ py_test_module_list(
"test_network_utils.py",
"test_ray_option_utils.py",
"test_retry.py",
"test_runtime_env_uri.py",
"test_signal_semaphore_utils.py",
"test_signature.py",
"test_tls_utils.py",
Expand Down
182 changes: 182 additions & 0 deletions python/ray/_common/tests/test_runtime_env_uri.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
import sys

import pytest

from ray._common.runtime_env_uri import Protocol, parse_uri


class TestParseUri:
    """Parametrized checks of the (protocol, package name) pairs from parse_uri."""

    @pytest.mark.parametrize(
        ("uri", "protocol", "package_name"),
        [
            ("gcs://file.zip", Protocol.GCS, "file.zip"),
            ("s3://bucket/file.zip", Protocol.S3, "s3_bucket_file.zip"),
            ("http://test.com/file.zip", Protocol.HTTP, "http_test_com_file.zip"),
            ("https://test.com/file.zip", Protocol.HTTPS, "https_test_com_file.zip"),
            ("gs://bucket/file.zip", Protocol.GS, "gs_bucket_file.zip"),
            ("azure://container/file.zip", Protocol.AZURE, "azure_container_file.zip"),
            (
                "abfss://container@account.dfs.core.windows.net/file.zip",
                Protocol.ABFSS,
                "abfss_container_account_dfs_core_windows_net_file.zip",
            ),
            (
                "https://test.com/package-0.0.1-py2.py3-none-any.whl?param=value",
                Protocol.HTTPS,
                "package-0.0.1-py2.py3-none-any.whl",
            ),
            (
                "http://test.com/package-0.0.1-py2.py3-none-any.whl?param=value",
                Protocol.HTTP,
                "package-0.0.1-py2.py3-none-any.whl",
            ),
        ],
    )
    def test_parsing_remote_basic(self, uri, protocol, package_name):
        """Plain URIs of each scheme map to the expected protocol and name."""
        expected = (protocol, package_name)
        assert parse_uri(uri) == expected

    @pytest.mark.parametrize(
        ("uri", "package_name"),
        [
            (
                "https://username:PAT@github.com/repo/archive/commit_hash.zip",
                "https_username_PAT_github_com_repo_archive_commit_hash.zip",
            ),
            (
                (
                    "https://un:pwd@gitlab.com/user/repo/-/"
                    "archive/commit_hash/repo-commit_hash.zip"
                ),
                (
                    "https_un_pwd_gitlab_com_user_repo_-_"
                    "archive_commit_hash_repo-commit_hash.zip"
                ),
            ),
        ],
    )
    def test_parse_private_git_https_uris(self, uri, package_name):
        """HTTPS URIs carrying credentials are sanitized into the name."""
        expected = (Protocol.HTTPS, package_name)
        assert parse_uri(uri) == expected

    @pytest.mark.parametrize(
        ("uri", "protocol", "package_name"),
        [
            (
                "https://username:PAT@github.com/repo/archive:2/commit_hash.zip",
                Protocol.HTTPS,
                "https_username_PAT_github_com_repo_archive_2_commit_hash.zip",
            ),
            (
                "gs://fake/2022-10-21T13:11:35+00:00/package.zip",
                Protocol.GS,
                "gs_fake_2022-10-21T13_11_35_00_00_package.zip",
            ),
            (
                "s3://fake/2022-10-21T13:11:35+00:00/package.zip",
                Protocol.S3,
                "s3_fake_2022-10-21T13_11_35_00_00_package.zip",
            ),
            (
                "azure://fake/2022-10-21T13:11:35+00:00/package.zip",
                Protocol.AZURE,
                "azure_fake_2022-10-21T13_11_35_00_00_package.zip",
            ),
            (
                (
                    "abfss://container@account.dfs.core.windows.net/"
                    "2022-10-21T13:11:35+00:00/package.zip"
                ),
                Protocol.ABFSS,
                (
                    "abfss_container_account_dfs_core_windows_net_"
                    "2022-10-21T13_11_35_00_00_package.zip"
                ),
            ),
            (
                "file:///fake/2022-10-21T13:11:35+00:00/package.zip",
                Protocol.FILE,
                "file__fake_2022-10-21T13_11_35_00_00_package.zip",
            ),
            (
                "file:///fake/2022-10-21T13:11:35+00:00/(package).zip",
                Protocol.FILE,
                "file__fake_2022-10-21T13_11_35_00_00__package_.zip",
            ),
        ],
    )
    def test_parse_uris_with_disallowed_chars(self, uri, protocol, package_name):
        """Disallowed characters in the URI become underscores in the name."""
        expected = (protocol, package_name)
        assert parse_uri(uri) == expected

    @pytest.mark.parametrize(
        ("uri", "protocol", "package_name"),
        [
            (
                "https://username:PAT@github.com/repo/archive:2/commit_hash.whl",
                Protocol.HTTPS,
                "commit_hash.whl",
            ),
            (
                "gs://fake/2022-10-21T13:11:35+00:00/package.whl",
                Protocol.GS,
                "package.whl",
            ),
            (
                "s3://fake/2022-10-21T13:11:35+00:00/package.whl",
                Protocol.S3,
                "package.whl",
            ),
            (
                "azure://fake/2022-10-21T13:11:35+00:00/package.whl",
                Protocol.AZURE,
                "package.whl",
            ),
            (
                (
                    "abfss://container@account.dfs.core.windows.net/"
                    "2022-10-21T13:11:35+00:00/package.whl"
                ),
                Protocol.ABFSS,
                "package.whl",
            ),
            (
                "file:///fake/2022-10-21T13:11:35+00:00/package.whl",
                Protocol.FILE,
                "package.whl",
            ),
        ],
    )
    def test_parse_remote_whl_uris(self, uri, protocol, package_name):
        """Remote wheel URIs keep only the unmodified wheel file name."""
        expected = (protocol, package_name)
        assert parse_uri(uri) == expected

    @pytest.mark.parametrize(
        "gcs_uri",
        [
            "gcs://pip_install_test-0.5-py3-none-any.whl",
            "gcs://storing@here.zip",
        ],
    )
    def test_parse_gcs_uri(self, gcs_uri):
        """GCS URIs should not be modified in this function."""
        # The expected name is whatever follows the last slash of the URI.
        unmodified_name = gcs_uri.split("/")[-1]
        assert parse_uri(gcs_uri) == (Protocol.GCS, unmodified_name)


def test_parse_uri_tar_gz():
    """The ``.tar.gz`` suffix is kept intact while all other dots are replaced.

    Exact names are asserted: a bare ``"_" in package_name`` check would pass
    for any remote URI, because every sanitized name starts with
    ``"<scheme>_"``.
    """
    protocol, package_name = parse_uri("s3://bucket/archive.tar.gz")
    assert protocol == Protocol.S3
    assert package_name == "s3_bucket_archive.tar.gz"

    # Dots in the host and in the file stem are replaced; the compound
    # extension is not.
    protocol, package_name = parse_uri("https://example.com/path/my.pkg.tar.gz")
    assert protocol == Protocol.HTTPS
    assert package_name == "https_example_com_path_my_pkg.tar.gz"


def test_parse_uri_rejects_local_path():
    """A plain filesystem path is rejected with a ValueError."""
    expected_message = "Expected URI but received path"
    with pytest.raises(ValueError, match=expected_message):
        parse_uri("/tmp/file.zip")


def test_parse_uri_rejects_invalid_protocol():
    """A URI whose scheme is not a Protocol value is rejected with a ValueError."""
    expected_message = "Invalid protocol for runtime_env URI"
    with pytest.raises(ValueError, match=expected_message):
        parse_uri("unknown://file.zip")


if __name__ == "__main__":
    # Run this file's tests under pytest and propagate its exit status.
    exit_status = pytest.main(["-sv", __file__])
    sys.exit(exit_status)
3 changes: 2 additions & 1 deletion python/ray/_private/runtime_env/conda.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from filelock import FileLock

import ray
from ray._common.runtime_env_uri import parse_uri
from ray._common.utils import (
get_or_create_event_loop,
try_to_create_directory,
Expand All @@ -25,8 +26,8 @@
get_conda_info_json,
)
from ray._private.runtime_env.context import RuntimeEnvContext
from ray._private.runtime_env.packaging import Protocol, parse_uri
from ray._private.runtime_env.plugin import RuntimeEnvPlugin
from ray._private.runtime_env.protocol import Protocol
from ray._private.runtime_env.validation import parse_and_validate_conda
from ray._private.utils import (
get_directory_size_bytes,
Expand Down
Loading
Loading