Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion src/modelinfo/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ def _positive_int(value: str) -> int:
return ivalue


def _positive_float(value: str) -> float:
fvalue = float(value)
if fvalue <= 0:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_positive_float allows NaN because fvalue <= 0 is false for NaN, so invalid --timeout nan passes parsing and can break request timeout handling.

Details

✨ AI Reasoning
​The new timeout validator is meant to enforce a strictly positive value, but it only checks whether the parsed float is less than or equal to zero. A NaN value bypasses that condition because NaN is neither <= 0 nor > 0. That means an invalid timeout can pass argument parsing and propagate into request logic, where timeout handling may raise runtime errors. This is a control-flow validation bug in the new logic.

🔧 How do I fix it?
Trace execution paths carefully. Ensure precondition checks happen before using values, validate ranges before checking impossible conditions, and don't check for states that the code has already ruled out.

Reply @AikidoSec feedback: [FEEDBACK] to get better review comments in the future.
Reply @AikidoSec ignore: [REASON] to ignore this issue.
More info

raise argparse.ArgumentTypeError("timeout must be greater than 0")
return fvalue
Comment on lines +44 to +48

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Reject non-finite timeout values in CLI validation.

_positive_float currently allows nan/inf, which can escape argument validation and fail later in networking code.

Suggested fix
 def _positive_float(value: str) -> float:
     fvalue = float(value)
-    if fvalue <= 0:
-        raise argparse.ArgumentTypeError("timeout must be greater than 0")
+    if fvalue <= 0 or fvalue != fvalue or fvalue == float("inf"):
+        raise argparse.ArgumentTypeError(
+            "timeout must be a finite number greater than 0"
+        )
     return fvalue
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/modelinfo/cli.py` around lines 44 - 48, The _positive_float function in
src/modelinfo/cli.py currently accepts non-finite values like nan and inf, which
can cause issues downstream. Add a check after converting the string to float
using math.isfinite() to validate that the value is finite, and raise
argparse.ArgumentTypeError with an appropriate message if it is not (for
example, "timeout must be a finite number"). This validation should occur
alongside the existing check for positive values to ensure all invalid timeout
values are rejected during argument parsing.



def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
parser = argparse.ArgumentParser(
prog="modelinfo",
Expand Down Expand Up @@ -82,6 +89,12 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
action="store_true",
help="Deep dive: Fetch all remote tensor shards to display the exact tensor size breakdown.",
)
parser.add_argument(
"--timeout",
type=_positive_float,
default=10.0,
help="Network timeout in seconds for remote Hugging Face fetches.",
)
parser.add_argument(
"--topology",
type=str,
Expand Down Expand Up @@ -122,6 +135,7 @@ def analyze_model(
gpu_count: int = 1,
batch_size: int = 1,
fetch_tensors: bool = False,
timeout: float = 10.0,
topology: str = "pcie4",
strategy: str = "tp",
is_vllm: bool = False,
Expand All @@ -136,7 +150,9 @@ def analyze_model(

if not os.path.exists(file_path) and not file_path_lower.endswith((".safetensors", ".gguf", ".pt", ".bin", ".index.json")):
from modelinfo.parsers.huggingface import fetch_huggingface_repo
tensors, config, format_name, disk_size = fetch_huggingface_repo(file_path, fetch_tensors=fetch_tensors)
tensors, config, format_name, disk_size = fetch_huggingface_repo(
file_path, fetch_tensors=fetch_tensors, timeout=timeout
)
elif file_path_lower.endswith(".safetensors") or file_path_lower.endswith(".index.json"):
tensors = parse_safetensors_header(file_path)
format_name = "SafeTensors"
Expand Down Expand Up @@ -240,6 +256,7 @@ def main(argv: Sequence[str] | None = None) -> int:
gpu_count=gpu_count,
batch_size=args.batch_size,
fetch_tensors=args.tensors,
timeout=args.timeout,
topology=args.topology,
strategy=args.strategy,
is_vllm=args.vllm,
Expand All @@ -259,6 +276,7 @@ def main(argv: Sequence[str] | None = None) -> int:
gpu_count=gpu_count,
batch_size=args.batch_size,
fetch_tensors=args.tensors,
timeout=args.timeout,
topology=args.topology,
strategy=args.strategy,
is_vllm=args.vllm,
Expand Down
31 changes: 18 additions & 13 deletions src/modelinfo/parsers/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,12 @@ def _get_hf_token() -> str | None:

return None

def _make_request(url: str, headers: Dict[str, str] = None, limit: int | None = None) -> bytes:
def _make_request(
url: str,
headers: Dict[str, str] = None,
limit: int | None = None,
timeout: float = 10.0,
) -> bytes:
if headers is None:
headers = {}

Expand All @@ -57,7 +62,7 @@ def _make_request(url: str, headers: Dict[str, str] = None, limit: int | None =

req = urllib.request.Request(url, headers=headers)
try:
with urllib.request.urlopen(req, timeout=10) as response:
with urllib.request.urlopen(req, timeout=timeout) as response:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential user input in HTTP request may allow SSRF attack - medium severity
If an attacker can control the URL input leading into this HTTP request, the attacker might be able to perform an SSRF attack. This kind of attack is even more dangerous if the application returns the response of the request to the user. It could allow them to retrieve information from higher privileged services within the network (such as the metadata service, which is commonly available in cloud services, and could allow them to retrieve credentials).

Show fix

Remediation: If possible, only allow requests to allowlisted domains. If not, consult the article linked above to learn about other mitigating techniques such as disabling redirects, blocking private IPs and making sure private services have internal authentication. If you return data coming from the request to the user, validate the data before returning it to make sure you don't return random data.

Reply @AikidoSec ignore: [REASON] to ignore this issue.
More info

if limit is not None:
return response.read(limit)
return response.read()
Expand All @@ -68,16 +73,16 @@ def _make_request(url: str, headers: Dict[str, str] = None, limit: int | None =
raise FileNotFoundError(f"Could not find repository or file on Hugging Face (404 Not Found): {url}")
raise

def _fetch_safetensors_header(repo_id: str, filename: str) -> Dict[str, Any]:
def _fetch_safetensors_header(repo_id: str, filename: str, timeout: float = 10.0) -> Dict[str, Any]:
url = f"{_get_hf_endpoint()}/{repo_id}/resolve/main/{filename}"

# 1. Fetch the first 500KB in a single roundtrip
headers = {"Range": "bytes=0-500000"}
try:
chunk = _make_request(url, headers=headers, limit=500000)
chunk = _make_request(url, headers=headers, limit=500000, timeout=timeout)
except urllib.error.HTTPError as e:
if e.code == 416: # Range Not Satisfiable (file is smaller than 500KB)
chunk = _make_request(url, limit=500000)
chunk = _make_request(url, limit=500000, timeout=timeout)
else:
raise

Expand All @@ -92,18 +97,18 @@ def _fetch_safetensors_header(repo_id: str, filename: str) -> Dict[str, Any]:
else:
# 3. Double-roundtrip only if the header is massive (>500KB)
headers = {"Range": f"bytes=8-{8+header_size-1}"}
json_bytes = _make_request(url, headers=headers, limit=header_size)
json_bytes = _make_request(url, headers=headers, limit=header_size, timeout=timeout)

return json.loads(json_bytes)

def fetch_huggingface_repo(repo_id: str, fetch_tensors: bool = False) -> Tuple[Dict[str, Any], Dict[str, Any] | None, str, float]:
def fetch_huggingface_repo(repo_id: str, fetch_tensors: bool = False, timeout: float = 10.0) -> Tuple[Dict[str, Any], Dict[str, Any] | None, str, float]:
"""
Fetches the metadata directly from the Hugging Face Hub over the network.
Returns: (tensors, config, format_name, disk_size)
"""
api_url = f"{_get_hf_endpoint()}/api/models/{repo_id}"
try:
api_data = json.loads(_make_request(api_url).decode("utf-8"))
api_data = json.loads(_make_request(api_url, timeout=timeout).decode("utf-8"))
except urllib.error.HTTPError as e:
if e.code == 401:
raise PermissionError(f"Gated/Private Model (401 Unauthorized). Set the HF_TOKEN environment variable to access {repo_id}")
Expand All @@ -117,15 +122,15 @@ def fetch_huggingface_repo(repo_id: str, fetch_tensors: bool = False) -> Tuple[D
config = None
if "config.json" in filenames:
config_url = f"{_get_hf_endpoint()}/{repo_id}/resolve/main/config.json"
config = json.loads(_make_request(config_url).decode("utf-8"))
config = json.loads(_make_request(config_url, timeout=timeout).decode("utf-8"))

tensors = {}
total_size = 0.0

if "model.safetensors.index.json" in filenames:
# Sharded SafeTensors
index_url = f"{_get_hf_endpoint()}/{repo_id}/resolve/main/model.safetensors.index.json"
index_data = json.loads(_make_request(index_url).decode("utf-8"))
index_data = json.loads(_make_request(index_url, timeout=timeout).decode("utf-8"))

weight_map = index_data.get("weight_map", {})
unique_shards = list(set(weight_map.values()))
Expand All @@ -146,7 +151,7 @@ def fetch_huggingface_repo(repo_id: str, fetch_tensors: bool = False) -> Tuple[D
}
else:
def fetch_shard(shard: str):
return shard, _fetch_safetensors_header(repo_id, shard)
return shard, _fetch_safetensors_header(repo_id, shard, timeout=timeout)

with concurrent.futures.ThreadPoolExecutor(max_workers=max(1, min(8, len(unique_shards)))) as executor:
future_to_shard = {executor.submit(fetch_shard, shard): shard for shard in unique_shards}
Expand All @@ -172,12 +177,12 @@ def fetch_shard(shard: str):
if token:
req.add_header("Authorization", f"Bearer {token}")
try:
with urllib.request.urlopen(req) as response:
with urllib.request.urlopen(req, timeout=timeout) as response:
total_size = int(response.headers.get("Content-Length", 0))
except Exception:
pass

header = _fetch_safetensors_header(repo_id, "model.safetensors")
header = _fetch_safetensors_header(repo_id, "model.safetensors", timeout=timeout)
tensors = header

format_name = "SafeTensors"
Expand Down
86 changes: 86 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,32 @@ def test_batch_size_flag_rejects_negative():
assert exc_info.value.code == 2


def test_timeout_flag_defaults_to_ten_seconds():
    """Without ``--timeout`` on the command line, the parsed value is 10.0."""
    argv = ["model.gguf"]

    namespace = parse_args(argv)

    assert namespace.timeout == 10.0


def test_timeout_flag_accepts_float():
    """A fractional ``--timeout`` value is parsed into the matching float."""
    argv = ["--timeout", "30.5", "model.gguf"]

    namespace = parse_args(argv)

    assert namespace.timeout == 30.5


def test_timeout_flag_rejects_zero():
    """``--timeout 0`` makes the parser exit with status code 2."""
    argv = ["--timeout", "0", "model.gguf"]

    with pytest.raises(SystemExit) as raised:
        parse_args(argv)

    assert raised.value.code == 2


def test_timeout_flag_rejects_negative():
    """A negative ``--timeout`` makes the parser exit with status code 2."""
    argv = ["--timeout", "-1", "model.gguf"]

    with pytest.raises(SystemExit) as raised:
        parse_args(argv)

    assert raised.value.code == 2


def test_analyze_model_passes_batch_size_to_footprint(monkeypatch, tmp_path):
model_path = tmp_path / "model.gguf"
model_path.write_bytes(b"mock")
Expand Down Expand Up @@ -77,3 +103,63 @@ def fake_calculate_footprint(tensors, *, context_length, batch_size, **kwargs):

assert captured == {"batch_size": 4, "context_length": 128}
assert info["footprint"]["kv_cache_bytes"] == 4.0


def test_analyze_model_passes_timeout_to_huggingface(monkeypatch):
    """``analyze_model`` forwards its ``timeout`` to ``fetch_huggingface_repo``.

    The path check, the Hugging Face fetch, the footprint calculation and the
    architecture lookup are all replaced with stubs, so no network or disk
    access takes place; only the arguments received by the fetch stub are
    inspected.
    """
    from modelinfo.parsers import huggingface

    seen = {}

    def fake_fetch(repo_id, *, fetch_tensors, timeout):
        # Record exactly what analyze_model passed through.
        seen.update(repo_id=repo_id, fetch_tensors=fetch_tensors, timeout=timeout)
        tensors = {
            "model.layers.0.self_attn.k_proj.weight": {
                "shape": [1, 1],
                "dtype": "F16",
            }
        }
        return tensors, None, "SafeTensors", 7.0

    def fake_calculate_footprint(tensors, *, context_length, batch_size, **kwargs):
        return dict(
            total_params=1,
            base_memory_bytes=2.0,
            kv_cache_bytes=1.0,
            overhead_bytes=0.0,
            total_memory_bytes=3.0,
            num_layers=1,
            kv_dim=1,
            primary_dtype="F16",
            kv_is_estimate=False,
            penalty_percentage=0.0,
            vllm_metrics={},
        )

    # Report every path as missing so "org/model" takes the remote branch.
    monkeypatch.setattr(cli.os.path, "exists", lambda path: False)
    monkeypatch.setattr(huggingface, "fetch_huggingface_repo", fake_fetch)
    monkeypatch.setattr(cli, "calculate_footprint", fake_calculate_footprint)
    monkeypatch.setattr(
        cli, "identify_architecture_name", lambda tensors, num_layers, config: "Mock"
    )

    cli.analyze_model(
        "org/model",
        context_override=128,
        fetch_tensors=True,
        timeout=22.5,
    )

    expected = {
        "repo_id": "org/model",
        "fetch_tensors": True,
        "timeout": 22.5,
    }
    assert seen == expected
Loading