diff --git a/pkg/rag/manager.go b/pkg/rag/manager.go index 3a68f6e8d..bbd918fb0 100644 --- a/pkg/rag/manager.go +++ b/pkg/rag/manager.go @@ -9,8 +9,10 @@ import ( "os" "path/filepath" "slices" + "sync/atomic" "time" + "github.com/docker/docker-agent/pkg/modelerrors" "github.com/docker/docker-agent/pkg/rag/database" "github.com/docker/docker-agent/pkg/rag/fusion" "github.com/docker/docker-agent/pkg/rag/rerank" @@ -60,6 +62,7 @@ type Manager struct { strategyConfigs map[string]strategy.Config // Store configs for per-strategy operations fusion fusion.Fusion // Fusion strategy for combining multi-strategy results reranker rerank.Reranker // Optional reranker for result re-scoring + rerankDisabled atomic.Bool // Set after a non-retryable reranking error to stop doomed requests events <-chan types.Event // Shared event channel from strategies and other RAG operations } @@ -243,30 +246,7 @@ func (m *Manager) Query(ctx context.Context, query string) ([]database.SearchRes "num_results", len(results)) // Apply reranking if configured - if m.reranker != nil { - beforeCount := len(results) - slog.DebugContext(ctx, "[RAG Manager] Applying reranking to single-strategy results", - "rag_name", m.name, - "strategy", strategyName, - "result_count_before", beforeCount) - - rerankedResults, rerankErr := m.reranker.Rerank(ctx, query, results) - if rerankErr != nil { - slog.WarnContext(ctx, "[RAG Manager] Reranking failed, using original results", - "rag_name", m.name, - "strategy", strategyName, - "error", rerankErr) - // Continue with original results rather than failing completely - } else { - results = rerankedResults - slog.DebugContext(ctx, "[RAG Manager] Reranked single-strategy results", - "rag_name", m.name, - "strategy", strategyName, - "result_count_before", beforeCount, - "result_count_after", len(results), - "filtered", beforeCount-len(results)) - } - } + results = m.rerank(ctx, query, results) if limit := m.config.Results.Limit; limit > 0 && len(results) > limit { 
slog.DebugContext(ctx, "[RAG Manager] Truncating to global result limit", @@ -373,27 +353,7 @@ func (m *Manager) Query(ctx context.Context, query string) ([]database.SearchRes "result_limit", m.config.Results.Limit) // Apply reranking if configured (before limit and deduplication) - if m.reranker != nil { - beforeCount := len(fusedResults) - slog.DebugContext(ctx, "[RAG Manager] Applying reranking to fused results", - "rag_name", m.name, - "result_count_before", beforeCount) - - rerankedResults, rerankErr := m.reranker.Rerank(ctx, query, fusedResults) - if rerankErr != nil { - slog.WarnContext(ctx, "[RAG Manager] Reranking failed, using original fused results", - "rag_name", m.name, - "error", rerankErr) - // Continue with original fused results rather than failing completely - } else { - fusedResults = rerankedResults - slog.DebugContext(ctx, "[RAG Manager] Reranked fused results", - "rag_name", m.name, - "result_count_before", beforeCount, - "result_count_after", len(fusedResults), - "filtered", beforeCount-len(fusedResults)) - } - } + fusedResults = m.rerank(ctx, query, fusedResults) // Apply result limit if configured if limit := m.config.Results.Limit; limit > 0 && len(fusedResults) > limit { @@ -430,6 +390,45 @@ func getStrategyNames(stratMap map[string]strategy.Strategy) []string { return slices.Collect(maps.Keys(stratMap)) } +// rerank applies the configured reranker to results, falling back to the +// original results on failure. After a non-retryable model error (e.g. an +// invalid reranking model name), the reranker is disabled for the lifetime +// of the manager so every subsequent query doesn't issue another request +// that is guaranteed to fail (see issue #3082). 
+func (m *Manager) rerank(ctx context.Context, query string, results []database.SearchResult) []database.SearchResult { + if m.reranker == nil || m.rerankDisabled.Load() || len(results) == 0 { + return results + } + + beforeCount := len(results) + rerankedResults, err := m.reranker.Rerank(ctx, query, results) + if err == nil { + slog.DebugContext(ctx, "[RAG Manager] Reranked results", + "rag_name", m.name, + "result_count_before", beforeCount, + "result_count_after", len(rerankedResults), + "filtered", beforeCount-len(rerankedResults)) + return rerankedResults + } + + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return results + } + + if retryable, _, _ := modelerrors.ClassifyModelError(err); !retryable { + m.rerankDisabled.Store(true) + slog.ErrorContext(ctx, "[RAG Manager] Disabling reranking after non-retryable error; check the reranking model configuration", + "rag_name", m.name, + "error", err) + return results + } + + slog.WarnContext(ctx, "[RAG Manager] Reranking failed, using original results", + "rag_name", m.name, + "error", err) + return results +} + // CheckAndReindexChangedFiles checks for file changes and re-indexes if needed func (m *Manager) CheckAndReindexChangedFiles(ctx context.Context) error { for strategyName, strategyImpl := range m.strategies { diff --git a/pkg/rag/manager_rerank_test.go b/pkg/rag/manager_rerank_test.go new file mode 100644 index 000000000..14a414344 --- /dev/null +++ b/pkg/rag/manager_rerank_test.go @@ -0,0 +1,120 @@ +package rag + +import ( + "context" + "errors" + "net/http" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/modelerrors" + "github.com/docker/docker-agent/pkg/rag/database" + "github.com/docker/docker-agent/pkg/rag/strategy" +) + +// failingReranker counts calls and always fails with a fixed error. 
+type failingReranker struct { + calls atomic.Int64 + err error +} + +func (r *failingReranker) Rerank(context.Context, string, []database.SearchResult) ([]database.SearchResult, error) { + r.calls.Add(1) + return nil, r.err +} + +// staticStrategy returns fixed results for every query. +type staticStrategy struct { + results []database.SearchResult +} + +func (s *staticStrategy) Initialize(context.Context, []string, strategy.ChunkingConfig) error { + return nil +} + +func (s *staticStrategy) Query(context.Context, string, int, float64) ([]database.SearchResult, error) { + return s.results, nil +} + +func (s *staticStrategy) CheckAndReindexChangedFiles(context.Context, []string, strategy.ChunkingConfig) error { + return nil +} + +func (s *staticStrategy) StartFileWatcher(context.Context, []string, strategy.ChunkingConfig) error { + return nil +} + +func (s *staticStrategy) Close() error { return nil } + +func newRerankTestManager(t *testing.T, rerankErr error) (*Manager, *failingReranker) { + t.Helper() + + results := []database.SearchResult{ + {Document: database.Document{ID: "1", Content: "doc one"}, Similarity: 0.9}, + {Document: database.Document{ID: "2", Content: "doc two"}, Similarity: 0.8}, + } + + reranker := &failingReranker{err: rerankErr} + cfg := Config{ + StrategyConfigs: []strategy.Config{{ + Name: "static", + Strategy: &staticStrategy{results: results}, + Limit: 5, + }}, + Results: ResultsConfig{ + RerankingConfig: &RerankingConfig{Reranker: reranker}, + }, + } + + m, err := New(t.Context(), "test", cfg, nil) + require.NoError(t, err) + return m, reranker +} + +func TestQueryDisablesRerankerAfterNonRetryableError(t *testing.T) { + rerankErr := &modelerrors.StatusError{ + StatusCode: http.StatusNotFound, + Err: errors.New("not_found_error: model: claude-sonnet-4-7"), + } + m, reranker := newRerankTestManager(t, rerankErr) + + for range 3 { + results, err := m.Query(t.Context(), "some query") + require.NoError(t, err, "rerank failures must not fail the 
query") + assert.Len(t, results, 2, "original results are returned as fallback") + } + + assert.Equal(t, int64(1), reranker.calls.Load(), + "reranker must be disabled after the first non-retryable error instead of being called on every query") +} + +func TestQueryKeepsRerankerOnTransientError(t *testing.T) { + rerankErr := &modelerrors.StatusError{ + StatusCode: http.StatusInternalServerError, + Err:        errors.New("server error"), + } + m, reranker := newRerankTestManager(t, rerankErr) + + for range 3 { + results, err := m.Query(t.Context(), "some query") + require.NoError(t, err) + assert.Len(t, results, 2) + } + + assert.Equal(t, int64(3), reranker.calls.Load(), + "transient errors should not disable the reranker") +} + +func TestQueryKeepsRerankerOnContextCancellation(t *testing.T) { + m, reranker := newRerankTestManager(t, context.Canceled) + + results, err := m.Query(t.Context(), "some query") + require.NoError(t, err) + assert.Len(t, results, 2) + assert.Equal(t, int64(1), reranker.calls.Load()) + assert.False(t, m.rerankDisabled.Load(), + "context cancellation must not permanently disable the reranker") +} diff --git a/pkg/rag/strategy/indexing_errors.go b/pkg/rag/strategy/indexing_errors.go new file mode 100644 index 000000000..50fd7ae29 --- /dev/null +++ b/pkg/rag/strategy/indexing_errors.go @@ -0,0 +1,42 @@ +package strategy + +import ( + "context" + "errors" + "fmt" + + "github.com/docker/docker-agent/pkg/modelerrors" +) + +// errIndexingAborted marks a permanent model/provider failure (e.g. invalid +// model name, authentication failure, rate limit) encountered during indexing. +// When such an error occurs, the whole indexing run must stop immediately: +// every remaining file/chunk would trigger the same failing request, flooding +// the provider (see https://github.com/docker/docker-agent/issues/3082).
+var errIndexingAborted = errors.New("indexing aborted due to non-retryable model error") + +// classifyModelCallError inspects an error returned by an embedding or LLM +// call made during indexing. Permanent failures are wrapped with +// errIndexingAborted so callers can abort the run; transient failures (5xx, +// timeouts) and context cancellation are returned unchanged so callers can +// skip the current file and continue. +func classifyModelCallError(err error) error { + if err == nil { + return nil + } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return err + } + // Rate-limited (429) errors are also non-retryable here: continuing to + // index would keep hammering a provider that asked us to back off. + retryable, _, _ := modelerrors.ClassifyModelError(err) + if !retryable { + return fmt.Errorf("%w: %w", errIndexingAborted, err) + } + return err +} + +// isIndexingAborted reports whether err carries the errIndexingAborted marker. +func isIndexingAborted(err error) bool { + return errors.Is(err, errIndexingAborted) +} diff --git a/pkg/rag/strategy/vector_store.go b/pkg/rag/strategy/vector_store.go index 9d4935b92..344c6a0cd 100644 --- a/pkg/rag/strategy/vector_store.go +++ b/pkg/rag/strategy/vector_store.go @@ -330,8 +330,16 @@ func (s *VectorStore) Initialize(ctx context.Context, docPaths []string, chunkin // Index the file if err := s.indexFile(gctx, status.path); err != nil { + // Permanent model errors (invalid model, auth failure, rate limit) + // abort the whole run: every remaining file would trigger the + // same failing requests. Returning the error cancels gctx, which + // stops the other indexing goroutines. 
+ if isIndexingAborted(err) || gctx.Err() != nil { + slog.ErrorContext(ctx, "Aborting indexing", "path", status.path, "error", err) + return err + } slog.ErrorContext(ctx, "Failed to index file", "path", status.path, "error", err) - // Don't return error - continue indexing other files + // Transient/local failure - continue indexing other files return nil } @@ -356,6 +364,7 @@ func (s *VectorStore) Initialize(ctx context.Context, docPaths []string, chunkin // Wait for all files to be indexed if err := g.Wait(); err != nil { + s.emitEvent(types.Event{Type: types.EventTypeError, Error: err}) return err } @@ -429,6 +438,9 @@ func (s *VectorStore) CheckAndReindexChangedFiles(ctx context.Context, docPaths if needsIndexing { slog.InfoContext(ctx, "File changed, re-indexing", "path", filePath) if err := s.indexFile(ctx, filePath); err != nil { + if isIndexingAborted(err) { + return fmt.Errorf("failed to re-index file %s: %w", filePath, err) + } slog.ErrorContext(ctx, "Failed to re-index file", "path", filePath, "error", err) } } @@ -597,7 +609,7 @@ func (s *VectorStore) indexFile(ctx context.Context, filePath string) error { embeddings, err := s.embedder.EmbedBatch(ctx, chunkContents) if err != nil { - return fmt.Errorf("failed to generate embeddings: %w", err) + return fmt.Errorf("failed to generate embeddings: %w", classifyModelCallError(err)) } if len(embeddings) != len(validChunks) { @@ -675,6 +687,15 @@ func (s *VectorStore) buildEmbeddingInputs(ctx context.Context, filePath string, } text, berr := s.embeddingInputBuilder.BuildEmbeddingInput(gctx, filePath, ch) + if berr != nil { + // Permanent model errors abort the run instead of silently + // falling back: every remaining chunk would issue the same + // failing LLM request. Returning cancels gctx, stopping + // the other chunk builds. 
+ if cerr := classifyModelCallError(berr); isIndexingAborted(cerr) { + return fmt.Errorf("failed to build embedding input for chunk %d of %s: %w", ch.Index, filePath, cerr) + } + } if berr != nil || strings.TrimSpace(text) == "" { slog.WarnContext(ctx, "Embedding input builder failed; falling back to raw chunk content", "strategy", s.name, @@ -701,6 +722,11 @@ func (s *VectorStore) buildEmbeddingInputs(ctx context.Context, filePath string, } text, berr := s.embeddingInputBuilder.BuildEmbeddingInput(ctx, filePath, ch) + if berr != nil { + if cerr := classifyModelCallError(berr); isIndexingAborted(cerr) { + return nil, fmt.Errorf("failed to build embedding input for chunk %d of %s: %w", ch.Index, filePath, cerr) + } + } if berr != nil || strings.TrimSpace(text) == "" { slog.WarnContext(ctx, "Embedding input builder failed; falling back to raw chunk content", "strategy", s.name, @@ -915,6 +941,10 @@ func (s *VectorStore) watchLoop(ctx context.Context, docPaths []string) { Message: "Failed to re-index: " + filepath.Base(file), Error: err, }) + if isIndexingAborted(err) { + slog.ErrorContext(ctx, "Stopping re-indexing due to non-retryable model error", "strategy", s.name, "error", err) + break + } } } diff --git a/pkg/rag/strategy/vector_store_test.go b/pkg/rag/strategy/vector_store_test.go new file mode 100644 index 000000000..4c0a69030 --- /dev/null +++ b/pkg/rag/strategy/vector_store_test.go @@ -0,0 +1,231 @@ +package strategy + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "path/filepath" + "sync" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/model/provider/base" + "github.com/docker/docker-agent/pkg/modelerrors" + "github.com/docker/docker-agent/pkg/modelsdev" + "github.com/docker/docker-agent/pkg/rag/chunk" + "github.com/docker/docker-agent/pkg/rag/database" + 
"github.com/docker/docker-agent/pkg/rag/embed" + "github.com/docker/docker-agent/pkg/tools" +) + +func TestClassifyModelCallError(t *testing.T) { + tests := []struct { + name string + err error + wantAborted bool + }{ + {name: "nil", err: nil, wantAborted: false}, + {name: "404 model not found", err: &modelerrors.StatusError{StatusCode: http.StatusNotFound, Err: errors.New("not_found_error: model: claude-sonnet-4-7")}, wantAborted: true}, + {name: "401 unauthorized", err: &modelerrors.StatusError{StatusCode: http.StatusUnauthorized, Err: errors.New("unauthorized")}, wantAborted: true}, + {name: "429 rate limited", err: &modelerrors.StatusError{StatusCode: http.StatusTooManyRequests, Err: errors.New("too many requests")}, wantAborted: true}, + {name: "500 server error", err: &modelerrors.StatusError{StatusCode: http.StatusInternalServerError, Err: errors.New("server error")}, wantAborted: false}, + {name: "timeout message", err: errors.New("request timeout"), wantAborted: false}, + {name: "context canceled", err: context.Canceled, wantAborted: false}, + {name: "context deadline", err: context.DeadlineExceeded, wantAborted: false}, + {name: "wrapped 404", err: fmt.Errorf("batch 1 failed: %w", &modelerrors.StatusError{StatusCode: http.StatusNotFound, Err: errors.New("no such model")}), wantAborted: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := classifyModelCallError(tt.err) + assert.Equal(t, tt.wantAborted, isIndexingAborted(got)) + if tt.err != nil { + assert.ErrorIs(t, got, tt.err) + } + }) + } +} + +// fakeEmbeddingProvider counts embedding calls and always fails with a fixed error. 
+type fakeEmbeddingProvider struct { + calls atomic.Int64 + err error +} + +func (f *fakeEmbeddingProvider) ID() modelsdev.ID { return modelsdev.NewID("test", "fake-embed") } +func (f *fakeEmbeddingProvider) BaseConfig() base.Config { return base.Config{} } + +func (f *fakeEmbeddingProvider) CreateChatCompletionStream(context.Context, []chat.Message, []tools.Tool) (chat.MessageStream, error) { + return nil, errors.New("not implemented") +} + +func (f *fakeEmbeddingProvider) CreateEmbedding(context.Context, string) (*base.EmbeddingResult, error) { + f.calls.Add(1) + if f.err != nil { + return nil, f.err + } + return &base.EmbeddingResult{Embedding: []float64{0.1, 0.2}, TotalTokens: 1}, nil +} + +// fakeVectorDB is an in-memory vectorStoreDB used to drive indexing in tests. +type fakeVectorDB struct { + mu sync.Mutex + metadata map[string]database.FileMetadata +} + +func newFakeVectorDB() *fakeVectorDB { + return &fakeVectorDB{metadata: make(map[string]database.FileMetadata)} +} + +func (db *fakeVectorDB) AddDocumentWithEmbedding(context.Context, database.Document, []float64, string) error { + return nil +} + +func (db *fakeVectorDB) SearchSimilarVectors(context.Context, []float64, int) ([]VectorSearchResultData, error) { + return nil, nil +} + +func (db *fakeVectorDB) DeleteDocumentsByPath(context.Context, string) error { return nil } + +func (db *fakeVectorDB) GetFileMetadata(_ context.Context, sourcePath string) (*database.FileMetadata, error) { + db.mu.Lock() + defer db.mu.Unlock() + if meta, ok := db.metadata[sourcePath]; ok { + return &meta, nil + } + return nil, nil +} + +func (db *fakeVectorDB) SetFileMetadata(_ context.Context, meta database.FileMetadata) error { + db.mu.Lock() + defer db.mu.Unlock() + db.metadata[meta.SourcePath] = meta + return nil +} + +func (db *fakeVectorDB) GetAllFileMetadata(context.Context) ([]database.FileMetadata, error) { + db.mu.Lock() + defer db.mu.Unlock() + all := make([]database.FileMetadata, 0, len(db.metadata)) + for _, 
meta := range db.metadata { + all = append(all, meta) + } + return all, nil +} + +func (db *fakeVectorDB) DeleteFileMetadata(_ context.Context, sourcePath string) error { + db.mu.Lock() + defer db.mu.Unlock() + delete(db.metadata, sourcePath) + return nil +} + +func (db *fakeVectorDB) Close() error { return nil } + +func newTestVectorStore(t *testing.T, embedErr error) (*VectorStore, *fakeEmbeddingProvider, []string) { + t.Helper() + + dir := t.TempDir() + const fileCount = 5 + docPaths := make([]string, 0, fileCount) + for i := range fileCount { + path := filepath.Join(dir, fmt.Sprintf("doc%d.txt", i)) + require.NoError(t, os.WriteFile(path, fmt.Appendf(nil, "document %d content", i), 0o644)) + docPaths = append(docPaths, path) + } + + fake := &fakeEmbeddingProvider{err: embedErr} + store := NewVectorStore(VectorStoreConfig{ + Name: "test", + Database: newFakeVectorDB(), + Embedder: embed.New(fake), + EmbeddingConcurrency: 1, + FileIndexConcurrency: 1, + Chunking: ChunkingConfig{Size: 1024, Overlap: 0}, + }) + + return store, fake, docPaths +} + +func TestInitializeAbortsOnNonRetryableModelError(t *testing.T) { + embedErr := &modelerrors.StatusError{ + StatusCode: http.StatusNotFound, + Err: errors.New("not_found_error: model: claude-sonnet-4-7"), + } + store, fake, docPaths := newTestVectorStore(t, embedErr) + + err := store.Initialize(t.Context(), docPaths, ChunkingConfig{Size: 1024}) + require.Error(t, err) + assert.True(t, isIndexingAborted(err), "error should carry the abort marker") + assert.Equal(t, int64(1), fake.calls.Load(), + "indexing must stop after the first non-retryable model error instead of trying every file") +} + +func TestInitializeContinuesOnTransientModelError(t *testing.T) { + embedErr := &modelerrors.StatusError{ + StatusCode: http.StatusInternalServerError, + Err: errors.New("internal server error"), + } + store, fake, docPaths := newTestVectorStore(t, embedErr) + + err := store.Initialize(t.Context(), docPaths, ChunkingConfig{Size: 
1024}) + require.NoError(t, err, "transient errors skip the file and keep indexing") + assert.Equal(t, int64(len(docPaths)), fake.calls.Load(), + "every file should still be attempted on transient errors") +} + +func TestCheckAndReindexAbortsOnNonRetryableModelError(t *testing.T) { + embedErr := &modelerrors.StatusError{ + StatusCode: http.StatusTooManyRequests, + Err: errors.New("too many requests"), + } + store, fake, docPaths := newTestVectorStore(t, embedErr) + + err := store.CheckAndReindexChangedFiles(t.Context(), docPaths, ChunkingConfig{Size: 1024}) + require.Error(t, err) + assert.True(t, isIndexingAborted(err)) + assert.Equal(t, int64(1), fake.calls.Load()) +} + +// abortingInputBuilder simulates a semantic-embeddings chat model that fails +// permanently (e.g. invalid chat_model name). +type abortingInputBuilder struct { + calls atomic.Int64 + err error +} + +func (b *abortingInputBuilder) BuildEmbeddingInput(context.Context, string, chunk.Chunk) (string, error) { + b.calls.Add(1) + return "", b.err +} + +func TestBuildEmbeddingInputsAbortsOnNonRetryableModelError(t *testing.T) { + store, fake, docPaths := newTestVectorStore(t, nil) + builder := &abortingInputBuilder{err: &modelerrors.StatusError{ + StatusCode: http.StatusNotFound, + Err: errors.New("not_found_error: model: claude-sonnet-4-7"), + }} + store.SetEmbeddingInputBuilder(builder) + + err := store.Initialize(t.Context(), docPaths, ChunkingConfig{Size: 1024}) + require.Error(t, err) + assert.True(t, isIndexingAborted(err), "permanent LLM errors must abort instead of falling back per chunk") + assert.Equal(t, int64(0), fake.calls.Load(), "no embedding requests should be sent after the abort") + assert.Equal(t, int64(1), builder.calls.Load()) +} + +func TestBuildEmbeddingInputsFallsBackOnTransientError(t *testing.T) { + store, fake, docPaths := newTestVectorStore(t, nil) + builder := &abortingInputBuilder{err: errors.New("request timeout")} + store.SetEmbeddingInputBuilder(builder) + + err := 
store.Initialize(t.Context(), docPaths, ChunkingConfig{Size: 1024}) + require.NoError(t, err, "transient LLM errors keep the raw-content fallback behavior") + assert.Equal(t, int64(len(docPaths)), fake.calls.Load()) +}