-
Notifications
You must be signed in to change notification settings - Fork 7
feat: add Hugging Face fetch timeout flag #33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -41,6 +41,13 @@ def _positive_int(value: str) -> int: | |
| return ivalue | ||
|
|
||
|
|
||
| def _positive_float(value: str) -> float: | ||
| fvalue = float(value) | ||
| if fvalue <= 0: | ||
| raise argparse.ArgumentTypeError("timeout must be greater than 0") | ||
| return fvalue | ||
|
Comment on lines
+44
to
+48
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reject non-finite timeout values in CLI validation.
Suggested fix def _positive_float(value: str) -> float:
fvalue = float(value)
- if fvalue <= 0:
- raise argparse.ArgumentTypeError("timeout must be greater than 0")
+ if fvalue <= 0 or fvalue != fvalue or fvalue == float("inf"):
+ raise argparse.ArgumentTypeError(
+ "timeout must be a finite number greater than 0"
+ )
return fvalue🤖 Prompt for AI Agents |
||
|
|
||
|
|
||
| def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: | ||
| parser = argparse.ArgumentParser( | ||
| prog="modelinfo", | ||
|
|
@@ -82,6 +89,12 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: | |
| action="store_true", | ||
| help="Deep dive: Fetch all remote tensor shards to display the exact tensor size breakdown.", | ||
| ) | ||
| parser.add_argument( | ||
| "--timeout", | ||
| type=_positive_float, | ||
| default=10.0, | ||
| help="Network timeout in seconds for remote Hugging Face fetches.", | ||
| ) | ||
| parser.add_argument( | ||
| "--topology", | ||
| type=str, | ||
|
|
@@ -122,6 +135,7 @@ def analyze_model( | |
| gpu_count: int = 1, | ||
| batch_size: int = 1, | ||
| fetch_tensors: bool = False, | ||
| timeout: float = 10.0, | ||
| topology: str = "pcie4", | ||
| strategy: str = "tp", | ||
| is_vllm: bool = False, | ||
|
|
@@ -136,7 +150,9 @@ def analyze_model( | |
|
|
||
| if not os.path.exists(file_path) and not file_path_lower.endswith((".safetensors", ".gguf", ".pt", ".bin", ".index.json")): | ||
| from modelinfo.parsers.huggingface import fetch_huggingface_repo | ||
| tensors, config, format_name, disk_size = fetch_huggingface_repo(file_path, fetch_tensors=fetch_tensors) | ||
| tensors, config, format_name, disk_size = fetch_huggingface_repo( | ||
| file_path, fetch_tensors=fetch_tensors, timeout=timeout | ||
| ) | ||
| elif file_path_lower.endswith(".safetensors") or file_path_lower.endswith(".index.json"): | ||
| tensors = parse_safetensors_header(file_path) | ||
| format_name = "SafeTensors" | ||
|
|
@@ -240,6 +256,7 @@ def main(argv: Sequence[str] | None = None) -> int: | |
| gpu_count=gpu_count, | ||
| batch_size=args.batch_size, | ||
| fetch_tensors=args.tensors, | ||
| timeout=args.timeout, | ||
| topology=args.topology, | ||
| strategy=args.strategy, | ||
| is_vllm=args.vllm, | ||
|
|
@@ -259,6 +276,7 @@ def main(argv: Sequence[str] | None = None) -> int: | |
| gpu_count=gpu_count, | ||
| batch_size=args.batch_size, | ||
| fetch_tensors=args.tensors, | ||
| timeout=args.timeout, | ||
| topology=args.topology, | ||
| strategy=args.strategy, | ||
| is_vllm=args.vllm, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -47,7 +47,12 @@ def _get_hf_token() -> str | None: | |
|
|
||
| return None | ||
|
|
||
| def _make_request(url: str, headers: Dict[str, str] = None, limit: int | None = None) -> bytes: | ||
| def _make_request( | ||
| url: str, | ||
| headers: Dict[str, str] = None, | ||
| limit: int | None = None, | ||
| timeout: float = 10.0, | ||
| ) -> bytes: | ||
| if headers is None: | ||
| headers = {} | ||
|
|
||
|
|
@@ -57,7 +62,7 @@ def _make_request(url: str, headers: Dict[str, str] = None, limit: int | None = | |
|
|
||
| req = urllib.request.Request(url, headers=headers) | ||
| try: | ||
| with urllib.request.urlopen(req, timeout=10) as response: | ||
| with urllib.request.urlopen(req, timeout=timeout) as response: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Potential user input in HTTP request may allow SSRF attack - medium severity Show fixRemediation: If possible, only allow requests to allowlisting domains. If not, consult the article linked above to learn about other mitigating techniques such as disabling redirects, blocking private IPs and making sure private services have internal authentication. If you return data coming from the request to the user, validate the data before returning it to make sure you don't return random data. Reply |
||
| if limit is not None: | ||
| return response.read(limit) | ||
| return response.read() | ||
|
|
@@ -68,16 +73,16 @@ def _make_request(url: str, headers: Dict[str, str] = None, limit: int | None = | |
| raise FileNotFoundError(f"Could not find repository or file on Hugging Face (404 Not Found): {url}") | ||
| raise | ||
|
|
||
| def _fetch_safetensors_header(repo_id: str, filename: str) -> Dict[str, Any]: | ||
| def _fetch_safetensors_header(repo_id: str, filename: str, timeout: float = 10.0) -> Dict[str, Any]: | ||
| url = f"{_get_hf_endpoint()}/{repo_id}/resolve/main/{filename}" | ||
|
|
||
| # 1. Fetch the first 500KB in a single roundtrip | ||
| headers = {"Range": "bytes=0-500000"} | ||
| try: | ||
| chunk = _make_request(url, headers=headers, limit=500000) | ||
| chunk = _make_request(url, headers=headers, limit=500000, timeout=timeout) | ||
| except urllib.error.HTTPError as e: | ||
| if e.code == 416: # Range Not Satisfiable (file is smaller than 500KB) | ||
| chunk = _make_request(url, limit=500000) | ||
| chunk = _make_request(url, limit=500000, timeout=timeout) | ||
| else: | ||
| raise | ||
|
|
||
|
|
@@ -92,18 +97,18 @@ def _fetch_safetensors_header(repo_id: str, filename: str) -> Dict[str, Any]: | |
| else: | ||
| # 3. Double-roundtrip only if the header is massive (>500KB) | ||
| headers = {"Range": f"bytes=8-{8+header_size-1}"} | ||
| json_bytes = _make_request(url, headers=headers, limit=header_size) | ||
| json_bytes = _make_request(url, headers=headers, limit=header_size, timeout=timeout) | ||
|
|
||
| return json.loads(json_bytes) | ||
|
|
||
| def fetch_huggingface_repo(repo_id: str, fetch_tensors: bool = False) -> Tuple[Dict[str, Any], Dict[str, Any] | None, str, float]: | ||
| def fetch_huggingface_repo(repo_id: str, fetch_tensors: bool = False, timeout: float = 10.0) -> Tuple[Dict[str, Any], Dict[str, Any] | None, str, float]: | ||
| """ | ||
| Fetches the metadata directly from the Hugging Face Hub over the network. | ||
| Returns: (tensors, config, format_name, disk_size) | ||
| """ | ||
| api_url = f"{_get_hf_endpoint()}/api/models/{repo_id}" | ||
| try: | ||
| api_data = json.loads(_make_request(api_url).decode("utf-8")) | ||
| api_data = json.loads(_make_request(api_url, timeout=timeout).decode("utf-8")) | ||
| except urllib.error.HTTPError as e: | ||
| if e.code == 401: | ||
| raise PermissionError(f"Gated/Private Model (401 Unauthorized). Set the HF_TOKEN environment variable to access {repo_id}") | ||
|
|
@@ -117,15 +122,15 @@ def fetch_huggingface_repo(repo_id: str, fetch_tensors: bool = False) -> Tuple[D | |
| config = None | ||
| if "config.json" in filenames: | ||
| config_url = f"{_get_hf_endpoint()}/{repo_id}/resolve/main/config.json" | ||
| config = json.loads(_make_request(config_url).decode("utf-8")) | ||
| config = json.loads(_make_request(config_url, timeout=timeout).decode("utf-8")) | ||
|
|
||
| tensors = {} | ||
| total_size = 0.0 | ||
|
|
||
| if "model.safetensors.index.json" in filenames: | ||
| # Sharded SafeTensors | ||
| index_url = f"{_get_hf_endpoint()}/{repo_id}/resolve/main/model.safetensors.index.json" | ||
| index_data = json.loads(_make_request(index_url).decode("utf-8")) | ||
| index_data = json.loads(_make_request(index_url, timeout=timeout).decode("utf-8")) | ||
|
|
||
| weight_map = index_data.get("weight_map", {}) | ||
| unique_shards = list(set(weight_map.values())) | ||
|
|
@@ -146,7 +151,7 @@ def fetch_huggingface_repo(repo_id: str, fetch_tensors: bool = False) -> Tuple[D | |
| } | ||
| else: | ||
| def fetch_shard(shard: str): | ||
| return shard, _fetch_safetensors_header(repo_id, shard) | ||
| return shard, _fetch_safetensors_header(repo_id, shard, timeout=timeout) | ||
|
|
||
| with concurrent.futures.ThreadPoolExecutor(max_workers=max(1, min(8, len(unique_shards)))) as executor: | ||
| future_to_shard = {executor.submit(fetch_shard, shard): shard for shard in unique_shards} | ||
|
|
@@ -172,12 +177,12 @@ def fetch_shard(shard: str): | |
| if token: | ||
| req.add_header("Authorization", f"Bearer {token}") | ||
| try: | ||
| with urllib.request.urlopen(req) as response: | ||
| with urllib.request.urlopen(req, timeout=timeout) as response: | ||
| total_size = int(response.headers.get("Content-Length", 0)) | ||
| except Exception: | ||
| pass | ||
|
|
||
| header = _fetch_safetensors_header(repo_id, "model.safetensors") | ||
| header = _fetch_safetensors_header(repo_id, "model.safetensors", timeout=timeout) | ||
| tensors = header | ||
|
|
||
| format_name = "SafeTensors" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_positive_float allows NaN because
fvalue <= 0is false for NaN, so invalid--timeout nanpasses parsing and can break request timeout handling.Details
✨ AI Reasoning
The new timeout validator is meant to enforce a strictly positive value, but it only checks whether the parsed float is less than or equal to zero. A NaN value bypasses that condition because NaN is neither <= 0 nor > 0. That means an invalid timeout can pass argument parsing and propagate into request logic, where timeout handling may raise runtime errors. This is a control-flow validation bug in the new logic.
🔧 How do I fix it?
Trace execution paths carefully. Ensure precondition checks happen before using values, validate ranges before checking impossible conditions, and don't check for states that the code has already ruled out.
Reply
@AikidoSec feedback: [FEEDBACK]to get better review comments in the future.Reply
@AikidoSec ignore: [REASON]to ignore this issue.More info