diff --git a/src/modelinfo/parsers/huggingface.py b/src/modelinfo/parsers/huggingface.py index 713ce82..fe793c8 100644 --- a/src/modelinfo/parsers/huggingface.py +++ b/src/modelinfo/parsers/huggingface.py @@ -3,9 +3,27 @@ import os import struct import urllib.error +import urllib.parse import urllib.request from typing import Any, Dict, Tuple +def _get_hf_endpoint() -> str: + endpoint = os.environ.get("HF_ENDPOINT", "https://huggingface.co").strip() + if not endpoint: + raise ValueError("HF_ENDPOINT is set but empty; expected a valid HTTP(S) URL") + endpoint = endpoint.rstrip("/") + if not endpoint.startswith("https://"): + raise ValueError( + f"HF_ENDPOINT must use https:// scheme, got: {endpoint}" + ) + parsed = urllib.parse.urlparse(endpoint) + if not parsed.netloc: + raise ValueError( + f"HF_ENDPOINT must include a valid hostname, got: {endpoint}" + ) + return endpoint + + def _get_hf_token() -> str | None: token = os.environ.get("HF_TOKEN") if token: @@ -51,7 +69,7 @@ def _make_request(url: str, headers: Dict[str, str] = None, limit: int | None = raise def _fetch_safetensors_header(repo_id: str, filename: str) -> Dict[str, Any]: - url = f"https://huggingface.co/{repo_id}/resolve/main/{filename}" + url = f"{_get_hf_endpoint()}/{repo_id}/resolve/main/{filename}" # 1. Fetch the first 500KB in a single roundtrip headers = {"Range": "bytes=0-500000"} @@ -83,7 +101,7 @@ def fetch_huggingface_repo(repo_id: str, fetch_tensors: bool = False) -> Tuple[D Fetches the metadata directly from the Hugging Face Hub over the network. Returns: (tensors, config, format_name, disk_size) """ - api_url = f"https://huggingface.co/api/models/{repo_id}" + api_url = f"{_get_hf_endpoint()}/api/models/{repo_id}" try: api_data = json.loads(_make_request(api_url).decode("utf-8")) except urllib.error.HTTPError as e: @@ -98,7 +116,7 @@ def fetch_huggingface_repo(repo_id: str, fetch_tensors: bool = False) -> Tuple[D config = None if "config.json" in filenames: - config_url = f"https://huggingface.co/{repo_id}/resolve/main/config.json" + config_url = f"{_get_hf_endpoint()}/{repo_id}/resolve/main/config.json" config = json.loads(_make_request(config_url).decode("utf-8")) tensors = {} @@ -106,7 +124,7 @@ def fetch_huggingface_repo(repo_id: str, fetch_tensors: bool = False) -> Tuple[D if "model.safetensors.index.json" in filenames: # Sharded SafeTensors - index_url = f"https://huggingface.co/{repo_id}/resolve/main/model.safetensors.index.json" + index_url = f"{_get_hf_endpoint()}/{repo_id}/resolve/main/model.safetensors.index.json" index_data = json.loads(_make_request(index_url).decode("utf-8")) weight_map = index_data.get("weight_map", {}) @@ -149,7 +167,7 @@ def fetch_shard(shard: str): # Single SafeTensors # Determine total size first - req = urllib.request.Request(f"https://huggingface.co/{repo_id}/resolve/main/model.safetensors", method="HEAD") + req = urllib.request.Request(f"{_get_hf_endpoint()}/{repo_id}/resolve/main/model.safetensors", method="HEAD") token = _get_hf_token() if token: req.add_header("Authorization", f"Bearer {token}") diff --git a/tests/test_parsers.py b/tests/test_parsers.py index c867c5c..c8acb21 100644 --- a/tests/test_parsers.py +++ b/tests/test_parsers.py @@ -45,3 +45,42 @@ def test_gguf_parser_metadata(): # Verify the architecture bypass parses it to titlecase and prevents "Unknown Architecture" arch_name = identify_architecture_name(tensors, num_layers=1) assert arch_name == "Qwen2 (1 transformer layers)" + + +# --- HF_ENDPOINT validation --- + +from modelinfo.parsers.huggingface import _get_hf_endpoint + + +def test_hf_endpoint_valid_https(monkeypatch): + """Valid https:// endpoint is accepted.""" + monkeypatch.setenv("HF_ENDPOINT", "https://huggingface.co") + assert _get_hf_endpoint() == "https://huggingface.co" + + +def test_hf_endpoint_default_https(monkeypatch): + """Default endpoint when HF_ENDPOINT is not set.""" + monkeypatch.delenv("HF_ENDPOINT", raising=False) + endpoint = _get_hf_endpoint() + assert endpoint == "https://huggingface.co" + + +def test_hf_endpoint_rejects_http(monkeypatch): + """http:// scheme is rejected with ValueError.""" + monkeypatch.setenv("HF_ENDPOINT", "http://localhost:8080") + with pytest.raises(ValueError, match="must use https:// scheme"): + _get_hf_endpoint() + + +def test_hf_endpoint_rejects_empty(monkeypatch): + """Empty string is rejected with ValueError.""" + monkeypatch.setenv("HF_ENDPOINT", "") + with pytest.raises(ValueError): + _get_hf_endpoint() + + +def test_hf_endpoint_rejects_no_hostname(monkeypatch): + """URL without a hostname is rejected with ValueError.""" + monkeypatch.setenv("HF_ENDPOINT", "https:///repo") + with pytest.raises(ValueError, match="must include a valid hostname"): + _get_hf_endpoint()