Fix tensor validation gaps in C and Rust APIs #2238
📝 Walkthrough
Two independent change sets: (1) C API hardening and dtype validation with added tests for matrix row-slicing and pairwise distance; (2) Rust refactor introducing lifetime-parameterized ManagedTensor, to_device/to_host fixes, dlpack byte-size and validation updates, and propagation of lifetime changes across indices, examples, docs, and tests.
Changes
- Matrix Slice Rows API Hardening
- Pairwise distance dtype validation and tests
- Rust ManagedTensor lifetime refactor & API updates
🧹 Nitpick comments (1)
c/tests/core/c_api.c (1)
71-73: ⚡ Quick win
Add a `start > end` negative test case.
Lines 71-73 cover two invalid ranges, but not the explicit `start > end` branch now enforced in `cuvsMatrixSliceRows`. Add one assertion to lock in that contract path.
Suggested test addition
      expect_matrix_slice_error(res, &src_2d, -1, 1);
      expect_matrix_slice_error(res, &src_2d, 0, 4);
    + expect_matrix_slice_error(res, &src_2d, 2, 1);
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@c/tests/core/c_api.c` around lines 71 - 73, Add a negative test for the start > end case by invoking expect_matrix_slice_error with the same result matrix and source (&src_2d) but with start greater than end (e.g., start=2, end=1) so the test asserts the new contract in cuvsMatrixSliceRows; place this new assertion alongside the existing invalid-range checks near expect_matrix_slice_error(res, &src_2d, -1, 1) and expect_matrix_slice_error(res, &src_2d, 0, 4) to cover the explicit start > end branch.
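As an illustration of the contract this comment wants locked in, here is a minimal sketch of the validate-first, publish-last pattern. It is written in plain Rust rather than C, and `MatrixMeta` and `slice_rows` are hypothetical names invented for this example; it is not the `cuvsMatrixSliceRows` implementation. It also assumes that an empty range (`start == end`) is acceptable, which the PR text does not state.

```rust
/// Hypothetical row-major matrix descriptor; not the cuVS or DLPack type.
#[derive(Debug, Clone, PartialEq)]
pub struct MatrixMeta {
    pub rows: i64,
    pub cols: i64,
    /// Element offset of the first row inside the underlying buffer.
    pub offset: i64,
}

/// Describes rows `start..end` of `src` in `out`.
/// Every check runs before `out` is touched, so a failed call leaves the
/// caller's output descriptor exactly as it was.
pub fn slice_rows(
    src: &MatrixMeta,
    start: i64,
    end: i64,
    out: &mut MatrixMeta,
) -> Result<(), String> {
    if start < 0 || end > src.rows {
        return Err(format!(
            "range {start}..{end} is outside 0..{}",
            src.rows
        ));
    }
    if start > end {
        return Err(format!("start ({start}) must not exceed end ({end})"));
    }

    // All checks passed: build the result and publish it in one assignment.
    *out = MatrixMeta {
        rows: end - start,
        cols: src.cols,
        offset: src.offset + start * src.cols,
    };
    Ok(())
}
```

With a 3-row source, the three ranges from the suggested test (`-1..1`, `0..4`, `2..1`) would all return an error under this sketch and leave `out` unmodified.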
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 86ba6ebd-3d99-42e6-9852-03a57d734c8a
📒 Files selected for processing (3)
- c/include/cuvs/core/c_api.h
- c/src/core/c_api.cpp
- c/tests/core/c_api.c
🧹 Nitpick comments (2)
c/tests/distance/pairwise_distance_c.cu (1)
134-141: 💤 Low value
Consider a more specific error substring for clarity.
The test uses float32 inputs but checks for an error substring mentioning "float16 inputs". While the full error message does contain this substring (the complete message is "...for float16 inputs and match the input dtype otherwise"), checking for a substring like "match the input dtype" would more clearly reflect the actual failure mode for this test case.
📝 Suggested alternative substring
      expect_pairwise_distance_error_contains(
        float_dtype(32),
        float_dtype(32),
        float_dtype(64),
    -   "distances output to cuvsPairwiseDistance must have dtype float32 for float16 inputs");
    +   "match the input dtype");
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@c/tests/distance/pairwise_distance_c.cu` around lines 134 - 141, The test PairwiseDistanceC::FailsWithMismatchedFloatOutputDtype is asserting an error message that references "float16 inputs" even though the inputs used are float32; update the expected substring in expect_pairwise_distance_error_contains to a more accurate and specific phrase such as "match the input dtype" (or "must match the input dtype") so the assertion reflects the actual failure mode for float_dtype(32) inputs and mismatched float_dtype(64) output; locate the call to expect_pairwise_distance_error_contains in the TEST and replace the current substring accordingly.
rust/cuvs/src/dlpack.rs (1)
212-218: 💤 Low value
Panic risk in `rmm_free_tensor` deleter callback.
`Resources::new().unwrap()` can panic if resource creation fails. Since this runs inside `Drop`, a panic here could cause double-panic aborts. Consider handling the error gracefully (e.g., logging and continuing) or caching the `Resources` handle.
Note: This appears to be pre-existing behavior, so addressing it could be deferred.
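The "handle the error gracefully" option could look roughly like the sketch below. It uses only the Rust standard library, and `Handle`, `Buffer`, and `SKIPPED_FREES` are hypothetical stand-ins made up for this example; none of them are cuVS types, and the real deleter would call into the C API where the comment is marked.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

/// Counts frees that had to be skipped; a stand-in for real logging or metrics.
pub static SKIPPED_FREES: AtomicUsize = AtomicUsize::new(0);

/// Hypothetical resources handle whose creation can fail.
pub struct Handle;

impl Handle {
    pub fn new(available: bool) -> Result<Self, String> {
        if available {
            Ok(Handle)
        } else {
            Err("resource creation failed".to_string())
        }
    }
}

/// Hypothetical owner of a device allocation of `bytes` bytes.
pub struct Buffer {
    pub bytes: usize,
    /// Simulates whether the handle can be created at drop time.
    pub backend_available: bool,
}

impl Drop for Buffer {
    fn drop(&mut self) {
        // Match on the Result instead of unwrapping, so a failure during
        // cleanup is reported rather than turned into a panic.
        match Handle::new(self.backend_available) {
            Ok(_handle) => {
                // A real deleter would release the allocation here.
            }
            Err(err) => {
                eprintln!("skipping free of {} bytes: {err}", self.bytes);
                SKIPPED_FREES.fetch_add(1, Ordering::Relaxed);
            }
        }
    }
}
```

The trade-off is that skipping the free leaks the allocation, so this only swaps a possible abort for a reported leak. Caching a handle, the other option mentioned above, would avoid creating resources in the deleter at all.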
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@rust/cuvs/src/dlpack.rs` around lines 212 - 218, rmm_free_tensor currently calls Resources::new().unwrap() inside the deleter which can panic during Drop; change this to avoid panicking by replacing the unwrap with fallible handling (e.g., call Resources::new().map(|res| { let bytes = dl_tensor_bytes(&(*self_).dl_tensor); let _ = ffi::cuvsRMMFree(res.0, (*self_).dl_tensor.data as *mut _, bytes); }).unwrap_or_else(|err| { log the error via your logger and skip the free })) or alternatively cache a Resources handle for reuse so the deleter never creates resources; ensure you reference rmm_free_tensor, Resources::new(), dl_tensor_bytes, and ffi::cuvsRMMFree when making the change.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: b9c71ae5-b00c-4f30-9d47-c12f03247074
📒 Files selected for processing (15)
- c/include/cuvs/distance/pairwise_distance.h
- c/src/distance/pairwise_distance.cpp
- c/tests/distance/pairwise_distance_c.cu
- rust/cuvs/examples/cagra.rs
- rust/cuvs/src/brute_force.rs
- rust/cuvs/src/cagra/index.rs
- rust/cuvs/src/cagra/mod.rs
- rust/cuvs/src/cluster/kmeans/mod.rs
- rust/cuvs/src/distance/mod.rs
- rust/cuvs/src/dlpack.rs
- rust/cuvs/src/ivf_flat/index.rs
- rust/cuvs/src/ivf_flat/mod.rs
- rust/cuvs/src/ivf_pq/index.rs
- rust/cuvs/src/ivf_pq/mod.rs
- rust/cuvs/src/vamana/index.rs
✅ Files skipped from review due to trivial changes (2)
- c/include/cuvs/distance/pairwise_distance.h
- rust/cuvs/src/ivf_pq/mod.rs
Summary
- Prevent `cuvsMatrixSliceRows` from publishing partially initialized output metadata on validation failures, with C API coverage for invalid ranges and ranks.
- Validate `cuvsPairwiseDistance` dtypes against the actual `x`, `y`, and `dist` tensors, including the supported `float16` input to `float32` output case.
- Make `ManagedTensor` carry the ndarray borrow lifetime, own its DLPack shape metadata, reject non-standard ndarray layouts, and validate `to_host()` destination shape/dtype before copying.
Why
The C API paths were trusting or exposing metadata too early in a few validation cases. On the Rust side, `ManagedTensor` looked owning but stored borrowed ndarray data and shape pointers without a lifetime, which allowed safe Rust to build tensors with dangling metadata. `to_host()` also reused source tensor metadata for the destination, bypassing the C copy shape and dtype checks for the actual output buffer.
Validation
- `pre-commit run clang-format --files c/src/distance/pairwise_distance.cpp c/tests/distance/pairwise_distance_c.cu c/include/cuvs/distance/pairwise_distance.h`
- `cargo fmt --all`
- `cargo check -p cuvs --features doc-only --all-targets`
- `git diff --check`
I also tried `cargo test -p cuvs --features doc-only dlpack -- --nocapture`, but `doc-only` test binaries still need the cuVS C symbols at link time on this macOS ARM machine. CUDA C tests were not run locally because `nvcc` and `nvidia-smi` are not available here.
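For readers unfamiliar with the lifetime issue described under Why, the following is a minimal sketch of the general technique, assuming a hypothetical `TensorView` type over a plain `f32` slice. It is not the cuVS `ManagedTensor` code and does not use ndarray or DLPack; it only shows how a `PhantomData` lifetime ties a raw data pointer to its source borrow while the shape metadata is owned by the wrapper.

```rust
use std::marker::PhantomData;

/// Hypothetical 2-D tensor descriptor; not the cuVS `ManagedTensor`.
pub struct TensorView<'a> {
    data: *const f32,
    // The shape is owned here, so a pointer to it stays valid for as long
    // as the view itself is alive.
    shape: Vec<i64>,
    // Ties the raw pointer to the borrow of the source buffer.
    _borrow: PhantomData<&'a [f32]>,
}

impl<'a> TensorView<'a> {
    /// Builds a row-major view, rejecting shapes that do not match the buffer.
    pub fn from_slice(data: &'a [f32], rows: usize, cols: usize) -> Result<Self, String> {
        let expected = rows.checked_mul(cols).ok_or("shape overflows usize")?;
        if data.len() != expected {
            return Err(format!(
                "buffer has {} elements, shape {rows}x{cols} needs {expected}",
                data.len()
            ));
        }
        Ok(Self {
            data: data.as_ptr(),
            shape: vec![rows as i64, cols as i64],
            _borrow: PhantomData,
        })
    }

    pub fn shape(&self) -> &[i64] {
        &self.shape
    }

    /// Bounds-checked element read through the raw pointer.
    pub fn get(&self, row: usize, col: usize) -> Option<f32> {
        let rows = self.shape[0] as usize;
        let cols = self.shape[1] as usize;
        if row >= rows || col >= cols {
            return None;
        }
        // SAFETY: the index was checked above, and `'a` guarantees the
        // source buffer outlives this view.
        Some(unsafe { *self.data.add(row * cols + col) })
    }
}
```

Because the view carries `'a`, safe code that tries to return a `TensorView` built from a local buffer is rejected by the borrow checker, which is the class of dangling-metadata bug the description refers to.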