Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 44 additions & 45 deletions pkg/rag/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ import (
"os"
"path/filepath"
"slices"
"sync/atomic"
"time"

"github.com/docker/docker-agent/pkg/modelerrors"
"github.com/docker/docker-agent/pkg/rag/database"
"github.com/docker/docker-agent/pkg/rag/fusion"
"github.com/docker/docker-agent/pkg/rag/rerank"
Expand Down Expand Up @@ -60,6 +62,7 @@ type Manager struct {
strategyConfigs map[string]strategy.Config // Store configs for per-strategy operations
fusion fusion.Fusion // Fusion strategy for combining multi-strategy results
reranker rerank.Reranker // Optional reranker for result re-scoring
rerankDisabled atomic.Bool // Set after a non-retryable reranking error to stop doomed requests
events <-chan types.Event // Shared event channel from strategies and other RAG operations
}

Expand Down Expand Up @@ -243,30 +246,7 @@ func (m *Manager) Query(ctx context.Context, query string) ([]database.SearchRes
"num_results", len(results))

// Apply reranking if configured
if m.reranker != nil {
beforeCount := len(results)
slog.DebugContext(ctx, "[RAG Manager] Applying reranking to single-strategy results",
"rag_name", m.name,
"strategy", strategyName,
"result_count_before", beforeCount)

rerankedResults, rerankErr := m.reranker.Rerank(ctx, query, results)
if rerankErr != nil {
slog.WarnContext(ctx, "[RAG Manager] Reranking failed, using original results",
"rag_name", m.name,
"strategy", strategyName,
"error", rerankErr)
// Continue with original results rather than failing completely
} else {
results = rerankedResults
slog.DebugContext(ctx, "[RAG Manager] Reranked single-strategy results",
"rag_name", m.name,
"strategy", strategyName,
"result_count_before", beforeCount,
"result_count_after", len(results),
"filtered", beforeCount-len(results))
}
}
results = m.rerank(ctx, query, results)

if limit := m.config.Results.Limit; limit > 0 && len(results) > limit {
slog.DebugContext(ctx, "[RAG Manager] Truncating to global result limit",
Expand Down Expand Up @@ -373,27 +353,7 @@ func (m *Manager) Query(ctx context.Context, query string) ([]database.SearchRes
"result_limit", m.config.Results.Limit)

// Apply reranking if configured (before limit and deduplication)
if m.reranker != nil {
beforeCount := len(fusedResults)
slog.DebugContext(ctx, "[RAG Manager] Applying reranking to fused results",
"rag_name", m.name,
"result_count_before", beforeCount)

rerankedResults, rerankErr := m.reranker.Rerank(ctx, query, fusedResults)
if rerankErr != nil {
slog.WarnContext(ctx, "[RAG Manager] Reranking failed, using original fused results",
"rag_name", m.name,
"error", rerankErr)
// Continue with original fused results rather than failing completely
} else {
fusedResults = rerankedResults
slog.DebugContext(ctx, "[RAG Manager] Reranked fused results",
"rag_name", m.name,
"result_count_before", beforeCount,
"result_count_after", len(fusedResults),
"filtered", beforeCount-len(fusedResults))
}
}
fusedResults = m.rerank(ctx, query, fusedResults)

// Apply result limit if configured
if limit := m.config.Results.Limit; limit > 0 && len(fusedResults) > limit {
Expand Down Expand Up @@ -430,6 +390,45 @@ func getStrategyNames(stratMap map[string]strategy.Strategy) []string {
return slices.Collect(maps.Keys(stratMap))
}

// rerank passes results through the configured reranker and returns the
// reranked list. Whenever reranking cannot be performed or fails, the input
// results are returned untouched so the query itself never fails because of
// reranking.
//
// Failure handling:
//   - context cancellation / deadline: the input is returned without logging;
//     the reranker stays enabled.
//   - retryable model error: a warning is logged; the reranker stays enabled.
//   - non-retryable model error (e.g. an invalid reranking model name): the
//     reranker is disabled for the lifetime of the manager so subsequent
//     queries do not issue another request that is guaranteed to fail
//     (see issue #3082).
func (m *Manager) rerank(ctx context.Context, query string, results []database.SearchResult) []database.SearchResult {
	// Nothing to do when there is no input, no reranker, or the reranker has
	// already been switched off by an earlier non-retryable failure.
	if len(results) == 0 || m.reranker == nil || m.rerankDisabled.Load() {
		return results
	}

	inputCount := len(results)
	reranked, rerankErr := m.reranker.Rerank(ctx, query, results)

	switch {
	case rerankErr == nil:
		slog.DebugContext(ctx, "[RAG Manager] Reranked results",
			"rag_name", m.name,
			"result_count_before", inputCount,
			"result_count_after", len(reranked),
			"filtered", inputCount-len(reranked))
		return reranked

	case errors.Is(rerankErr, context.Canceled), errors.Is(rerankErr, context.DeadlineExceeded):
		// The request was cut short by its context; that says nothing about
		// the reranker configuration, so it is neither logged nor disabled.
		return results
	}

	retryable, _, _ := modelerrors.ClassifyModelError(rerankErr)
	if retryable {
		slog.WarnContext(ctx, "[RAG Manager] Reranking failed, using original results",
			"rag_name", m.name,
			"error", rerankErr)
		return results
	}

	// Non-retryable: flip the flag first so later queries skip the reranker.
	m.rerankDisabled.Store(true)
	slog.ErrorContext(ctx, "[RAG Manager] Disabling reranking after non-retryable error; check the reranking model configuration",
		"rag_name", m.name,
		"error", rerankErr)
	return results
}

// CheckAndReindexChangedFiles checks for file changes and re-indexes if needed
func (m *Manager) CheckAndReindexChangedFiles(ctx context.Context) error {
for strategyName, strategyImpl := range m.strategies {
Expand Down
120 changes: 120 additions & 0 deletions pkg/rag/manager_rerank_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package rag

import (
"context"
"errors"
"net/http"
"sync/atomic"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/docker/docker-agent/pkg/modelerrors"
"github.com/docker/docker-agent/pkg/rag/database"
"github.com/docker/docker-agent/pkg/rag/strategy"
)

// failingReranker is a test double for the reranker: every call is counted
// and fails with the preconfigured error.
type failingReranker struct {
	calls atomic.Int64 // number of Rerank invocations so far
	err   error        // error returned by every Rerank call
}

// Rerank records the invocation and returns no results together with the
// configured error. All arguments are ignored.
func (f *failingReranker) Rerank(_ context.Context, _ string, _ []database.SearchResult) ([]database.SearchResult, error) {
	f.calls.Add(1)
	return nil, f.err
}

// staticStrategy is a test double for a retrieval strategy: Query always
// yields the same preset results and every other operation is a no-op.
type staticStrategy struct {
	results []database.SearchResult // returned verbatim by Query
}

// Initialize does nothing and always succeeds.
func (st *staticStrategy) Initialize(_ context.Context, _ []string, _ strategy.ChunkingConfig) error {
	return nil
}

// Query ignores its arguments and returns the preset results with a nil error.
func (st *staticStrategy) Query(_ context.Context, _ string, _ int, _ float64) ([]database.SearchResult, error) {
	return st.results, nil
}

// CheckAndReindexChangedFiles does nothing and always succeeds.
func (st *staticStrategy) CheckAndReindexChangedFiles(_ context.Context, _ []string, _ strategy.ChunkingConfig) error {
	return nil
}

// StartFileWatcher does nothing and always succeeds.
func (st *staticStrategy) StartFileWatcher(_ context.Context, _ []string, _ strategy.ChunkingConfig) error {
	return nil
}

// Close does nothing and always succeeds.
func (st *staticStrategy) Close() error {
	return nil
}

// newRerankTestManager builds a Manager wired to a single staticStrategy that
// always returns two documents, and to a failingReranker that always fails
// with rerankErr. The reranker is returned alongside the manager so tests can
// inspect its call counter. The test fails immediately if the manager cannot
// be constructed.
func newRerankTestManager(t *testing.T, rerankErr error) (*Manager, *failingReranker) {
	t.Helper()

	stub := &failingReranker{err: rerankErr}
	source := &staticStrategy{
		results: []database.SearchResult{
			{Document: database.Document{ID: "1", Content: "doc one"}, Similarity: 0.9},
			{Document: database.Document{ID: "2", Content: "doc two"}, Similarity: 0.8},
		},
	}

	mgr, err := New(t.Context(), "test", Config{
		StrategyConfigs: []strategy.Config{
			{Name: "static", Strategy: source, Limit: 5},
		},
		Results: ResultsConfig{
			RerankingConfig: &RerankingConfig{Reranker: stub},
		},
	}, nil)
	require.NoError(t, err)

	return mgr, stub
}

// TestQueryDisablesRerankerAfterNonRetryableError verifies that after the
// reranker fails with a 404 status error (expected to be classified as
// non-retryable), later queries still succeed with the original results but
// no longer call the reranker.
func TestQueryDisablesRerankerAfterNonRetryableError(t *testing.T) {
	notFound := &modelerrors.StatusError{
		StatusCode: http.StatusNotFound,
		Err:        errors.New("not_found_error: model: claude-sonnet-4-7"),
	}
	mgr, stub := newRerankTestManager(t, notFound)

	const queries = 3
	for i := 0; i < queries; i++ {
		got, queryErr := mgr.Query(t.Context(), "some query")
		require.NoError(t, queryErr, "rerank failures must not fail the query")
		assert.Len(t, got, 2, "original results are returned as fallback")
	}

	// Only the first query may have reached the reranker.
	assert.Equal(t, int64(1), stub.calls.Load(),
		"reranker must be disabled after the first non-retryable error instead of being called on every query")
}

// TestQueryKeepsRerankerOnTransientError verifies that a 500 status error
// (expected to be classified as retryable) does not disable the reranker:
// every query falls back to the original results and the reranker is invoked
// again each time.
func TestQueryKeepsRerankerOnTransientError(t *testing.T) {
	serverErr := &modelerrors.StatusError{
		StatusCode: http.StatusInternalServerError,
		Err:        errors.New("server error"),
	}
	mgr, stub := newRerankTestManager(t, serverErr)

	const queries = 3
	for i := 0; i < queries; i++ {
		got, queryErr := mgr.Query(t.Context(), "some query")
		require.NoError(t, queryErr)
		assert.Len(t, got, 2)
	}

	// One reranker call per query: the failure did not stick.
	assert.Equal(t, int64(3), stub.calls.Load(),
		"transient errors should not disable the reranker")
}

// TestQueryKeepsRerankerOnContextCancellation verifies that a reranker
// failure of context.Canceled falls back to the original results without
// setting the manager's rerankDisabled flag.
func TestQueryKeepsRerankerOnContextCancellation(t *testing.T) {
	mgr, stub := newRerankTestManager(t, context.Canceled)

	got, queryErr := mgr.Query(t.Context(), "some query")

	require.NoError(t, queryErr)
	assert.Len(t, got, 2)
	assert.Equal(t, int64(1), stub.calls.Load())
	assert.False(t, mgr.rerankDisabled.Load(),
		"context cancellation must not permanently disable the reranker")
}
42 changes: 42 additions & 0 deletions pkg/rag/strategy/indexing_errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package strategy

import (
"context"
"errors"
"fmt"

"github.com/docker/docker-agent/pkg/modelerrors"
)

// errIndexingAborted is the sentinel that marks a permanent model/provider
// failure (e.g. invalid model name, authentication failure, rate limit)
// encountered during indexing. classifyModelCallError wraps such failures
// with it, and isIndexingAborted detects it via errors.Is.
//
// When such an error occurs, the whole indexing run must stop immediately:
// every remaining file/chunk would trigger the same failing request, flooding
// the provider (see https://github.com/docker/docker-agent/issues/3082).
var errIndexingAborted = errors.New("indexing aborted due to non-retryable model error")

// classifyModelCallError decides how an error from an embedding or LLM call
// made during indexing should be treated:
//
//   - nil stays nil;
//   - context cancellation and deadline errors are returned unchanged;
//   - errors that modelerrors.ClassifyModelError reports as retryable
//     (transient failures such as 5xx or timeouts) are returned unchanged, so
//     callers can skip the current file and continue;
//   - everything else is wrapped with errIndexingAborted so callers can abort
//     the whole run. Rate-limited (429) errors fall in this group as well:
//     continuing to index would keep hammering a provider that asked us to
//     back off.
//
// The original error always remains reachable through errors.Is / errors.As.
func classifyModelCallError(err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
		return err
	}

	if retryable, _, _ := modelerrors.ClassifyModelError(err); retryable {
		return err
	}
	return fmt.Errorf("%w: %w", errIndexingAborted, err)
}

// isIndexingAborted reports whether err carries the errIndexingAborted marker
// anywhere in its wrap chain (the check uses errors.Is). A nil err yields
// false.
func isIndexingAborted(err error) bool {
	return errors.Is(err, errIndexingAborted)
}
34 changes: 32 additions & 2 deletions pkg/rag/strategy/vector_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,16 @@ func (s *VectorStore) Initialize(ctx context.Context, docPaths []string, chunkin

// Index the file
if err := s.indexFile(gctx, status.path); err != nil {
// Permanent model errors (invalid model, auth failure, rate limit)
// abort the whole run: every remaining file would trigger the
// same failing requests. Returning the error cancels gctx, which
// stops the other indexing goroutines.
if isIndexingAborted(err) || gctx.Err() != nil {
slog.ErrorContext(ctx, "Aborting indexing", "path", status.path, "error", err)
return err
}
slog.ErrorContext(ctx, "Failed to index file", "path", status.path, "error", err)
// Don't return error - continue indexing other files
// Transient/local failure - continue indexing other files
return nil
}

Expand All @@ -356,6 +364,7 @@ func (s *VectorStore) Initialize(ctx context.Context, docPaths []string, chunkin

// Wait for all files to be indexed
if err := g.Wait(); err != nil {
s.emitEvent(types.Event{Type: types.EventTypeError, Error: err})
return err
}

Expand Down Expand Up @@ -429,6 +438,9 @@ func (s *VectorStore) CheckAndReindexChangedFiles(ctx context.Context, docPaths
if needsIndexing {
slog.InfoContext(ctx, "File changed, re-indexing", "path", filePath)
if err := s.indexFile(ctx, filePath); err != nil {
if isIndexingAborted(err) {
return fmt.Errorf("failed to re-index file %s: %w", filePath, err)
}
slog.ErrorContext(ctx, "Failed to re-index file", "path", filePath, "error", err)
}
}
Expand Down Expand Up @@ -597,7 +609,7 @@ func (s *VectorStore) indexFile(ctx context.Context, filePath string) error {

embeddings, err := s.embedder.EmbedBatch(ctx, chunkContents)
if err != nil {
return fmt.Errorf("failed to generate embeddings: %w", err)
return fmt.Errorf("failed to generate embeddings: %w", classifyModelCallError(err))
}

if len(embeddings) != len(validChunks) {
Expand Down Expand Up @@ -675,6 +687,15 @@ func (s *VectorStore) buildEmbeddingInputs(ctx context.Context, filePath string,
}

text, berr := s.embeddingInputBuilder.BuildEmbeddingInput(gctx, filePath, ch)
if berr != nil {
// Permanent model errors abort the run instead of silently
// falling back: every remaining chunk would issue the same
// failing LLM request. Returning cancels gctx, stopping
// the other chunk builds.
if cerr := classifyModelCallError(berr); isIndexingAborted(cerr) {
return fmt.Errorf("failed to build embedding input for chunk %d of %s: %w", ch.Index, filePath, cerr)
}
}
if berr != nil || strings.TrimSpace(text) == "" {
slog.WarnContext(ctx, "Embedding input builder failed; falling back to raw chunk content",
"strategy", s.name,
Expand All @@ -701,6 +722,11 @@ func (s *VectorStore) buildEmbeddingInputs(ctx context.Context, filePath string,
}

text, berr := s.embeddingInputBuilder.BuildEmbeddingInput(ctx, filePath, ch)
if berr != nil {
if cerr := classifyModelCallError(berr); isIndexingAborted(cerr) {
return nil, fmt.Errorf("failed to build embedding input for chunk %d of %s: %w", ch.Index, filePath, cerr)
}
}
if berr != nil || strings.TrimSpace(text) == "" {
slog.WarnContext(ctx, "Embedding input builder failed; falling back to raw chunk content",
"strategy", s.name,
Expand Down Expand Up @@ -915,6 +941,10 @@ func (s *VectorStore) watchLoop(ctx context.Context, docPaths []string) {
Message: "Failed to re-index: " + filepath.Base(file),
Error: err,
})
if isIndexingAborted(err) {
slog.ErrorContext(ctx, "Stopping re-indexing due to non-retryable model error", "strategy", s.name, "error", err)
break

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[LOW] watchLoop keeps running after abort — re-index will be retried on every subsequent filesystem event

The `break` on line 946 exits the inner `for i, file := range filesToReindex` loop inside the `processChanges` closure, which stops the current batch. However, `watchLoop`'s outer `for { select { ... } }` loop keeps running. On the next filesystem write/create event, `processChanges` is called again, attempts to re-index with the same broken model, hits the same non-retryable error, emits another error event, and breaks again.

This partially defeats the flood-prevention goal for the watcher path: the number of doomed requests is reduced from all files per event to one request per event, but remains unbounded over time. The Initialize and CheckAndReindexChangedFiles paths are correctly fixed (abort propagates to the caller), but the watcher path has no persistent "disabled" state equivalent to rerankDisabled.Store(true).

Suggested fix: add a persistent `indexingDisabled atomic.Bool` to `VectorStore` (similar to `Manager.rerankDisabled`), set it to true when an abort occurs, and check it at the top of `processChanges` (or at the start of `watchLoop`) so that later filesystem events return without attempting to re-index. Alternatively, cancel the watcher's context on abort.

}
}
}

Expand Down
Loading
Loading