Cluster-Aware Retrieval for RAG Systems

November 15, 2024

Most RAG systems treat embedding spaces as flat, uniform distributions. They’re not. Real knowledge bases contain distinct semantic clusters: database docs, frontend frameworks, DevOps practices, each with different internal structure. Ignoring this wastes retrieval precision.

The Problem with Flat Retrieval

A query about “React hooks optimization” should pull from the frontend cluster rather than give equal weight to database or infrastructure docs that happen to overlap semantically. Standard cosine similarity is blind to topical boundaries, so you get results that are individually relevant but collectively unfocused.

Modeling Clusters with GMM

Gaussian Mixture Models assume your embeddings arise from \(K\) underlying Gaussian distributions:

$$p(v) = \sum_{k=1}^K \pi_k \mathcal{N}(v \mid \mu_k, \Sigma_k)$$

For a query \(q\), compute the posterior probability of each cluster:

$$p(k \mid q) = \frac{\pi_k \mathcal{N}(q \mid \mu_k, \Sigma_k)}{\sum_{j=1}^K \pi_j \mathcal{N}(q \mid \mu_j, \Sigma_j)}$$

This gives you soft assignments: the probability that a query belongs to each semantic cluster.
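To make the posterior formula concrete, here is a minimal NumPy sketch that evaluates \(p(k \mid q)\) directly from mixture parameters. The function name `gmm_posterior` and the toy two-component mixture are illustrative assumptions, not part of any library; in practice `GaussianMixture.predict_proba` does this for you.

```python
import numpy as np

def gmm_posterior(q, weights, means, covs):
    """Return p(k | q) for each of the K components of a Gaussian mixture."""
    d = q.shape[0]
    log_joint = np.empty(len(weights))
    for k, (pi_k, mu_k, sigma_k) in enumerate(zip(weights, means, covs)):
        diff = q - mu_k
        # Solve a linear system instead of inverting Sigma_k explicitly.
        maha_sq = diff @ np.linalg.solve(sigma_k, diff)
        _, logdet = np.linalg.slogdet(sigma_k)
        log_pdf = -0.5 * (d * np.log(2 * np.pi) + logdet + maha_sq)
        log_joint[k] = np.log(pi_k) + log_pdf  # log(pi_k * N(q | mu_k, Sigma_k))
    # Normalize in log space (log-sum-exp trick) to avoid underflow.
    log_joint -= log_joint.max()
    post = np.exp(log_joint)
    return post / post.sum()

# Toy 2-D mixture with two well-separated components (made-up numbers).
weights = np.array([0.5, 0.5])
means = np.array([[0.0, 0.0], [5.0, 5.0]])
covs = np.array([np.eye(2), np.eye(2)])
posterior = gmm_posterior(np.array([0.2, -0.1]), weights, means, covs)
```

A query sitting next to the first mean gets almost all of its posterior mass on the first component, which is the soft assignment the routing step relies on.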

Two-Stage Retrieval

  1. Cluster selection: Pick the cluster(s) with the highest \(p(k \mid q)\). Take the top 2 for ambiguous queries, for example when no single cluster’s posterior clearly dominates.
  2. Intra-cluster retrieval: Run k-NN within selected clusters.

The cluster boundaries act as a soft filter, avoiding the “dilution effect” where off-topic documents dominate results.
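The two stages can be sketched in a few lines of NumPy. Everything here is an assumption made for illustration: the function name `two_stage_retrieve`, the `ambiguity_threshold` rule for deciding when to take a second cluster, and the brute-force cosine k-NN standing in for a real vector index.

```python
import numpy as np

def two_stage_retrieve(q, doc_embeddings, doc_clusters, cluster_probs,
                       top_k=5, ambiguity_threshold=0.8):
    """Stage 1: pick clusters from p(k | q). Stage 2: cosine k-NN inside them."""
    order = np.argsort(cluster_probs)[::-1]  # most probable cluster first
    # Assumed rule: the query is "ambiguous" if the best posterior is below the threshold.
    n_selected = 1 if cluster_probs[order[0]] >= ambiguity_threshold else 2
    selected = order[:n_selected]

    # Restrict the candidate set to documents assigned to the selected clusters.
    candidate_ids = np.flatnonzero(np.isin(doc_clusters, selected))
    candidates = doc_embeddings[candidate_ids]

    # Brute-force cosine similarity against the candidates only.
    sims = candidates @ q / (np.linalg.norm(candidates, axis=1) * np.linalg.norm(q))
    best = np.argsort(sims)[::-1][:top_k]
    return candidate_ids[best], sims[best]

# Toy corpus: two documents per cluster (made-up 2-D embeddings).
docs = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
doc_clusters = np.array([0, 0, 1, 1])
q = np.array([1.0, 0.0])

confident_ids, _ = two_stage_retrieve(q, docs, doc_clusters, np.array([0.95, 0.05]), top_k=3)
ambiguous_ids, _ = two_stage_retrieve(q, docs, doc_clusters, np.array([0.55, 0.45]), top_k=3)
```

With a confident posterior only cluster 0 is searched, so at most two documents come back even though `top_k=3`. With an ambiguous posterior both clusters are searched and the third slot is filled from cluster 1.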

Mahalanobis Distance Per Cluster

Here’s the underexplored part: different clusters can use different distance metrics. For a cluster modeled as \(\mathcal{N}(\mu_k, \Sigma_k)\), the Mahalanobis distance accounts for the cluster’s shape:

$$d_{\text{Mah}}(q, v) = \sqrt{(q - v)^T \Sigma_k^{-1} (q - v)}$$

When a cluster is elongated along some semantic direction, distances along that direction are scaled down in proportion to the cluster’s variance there. Cosine similarity treats all directions equally; Mahalanobis adapts to each cluster’s shape.
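A small numeric sketch shows the effect. The helper name `mahalanobis` and the ridge term `reg` are assumptions for this example; the covariance is a made-up cluster that is three times wider along the first axis than the second.

```python
import numpy as np

def mahalanobis(q, v, cov, reg=1e-6):
    """Mahalanobis distance between q and v under one cluster's covariance."""
    diff = q - v
    # A small ridge (reg * I) keeps the solve stable for near-singular covariances.
    cov_reg = cov + reg * np.eye(cov.shape[0])
    return float(np.sqrt(diff @ np.linalg.solve(cov_reg, diff)))

# Cluster elongated along the first axis: variance 9 there, variance 1 on the second.
cov = np.diag([9.0, 1.0])
q = np.zeros(2)

along = mahalanobis(q, np.array([3.0, 0.0]), cov)   # offset along the long axis
across = mahalanobis(q, np.array([0.0, 3.0]), cov)  # same Euclidean offset, short axis
```

Both documents are the same Euclidean distance from the query, but the one displaced along the cluster’s long axis is about three times closer in Mahalanobis terms.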

Clusters as Agent Tools

In agentic RAG, each cluster becomes a tool the agent can invoke:

# topic_names[k] is a human-readable label for cluster k (you supply these)
tools = [
    ClusterRetrievalTool(cluster_id=k, name=f"Search {topic_names[k]}")
    for k in range(K)
]

The agent decides which clusters to search and in what order:

  • Query: “How does React’s context API compare to Redux?”
  • Agent plan:
    1. Search frontend cluster for React context
    2. Search state management cluster for Redux patterns
    3. Synthesize comparison

This beats flat retrieval for cross-topic synthesis.
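`ClusterRetrievalTool` is not defined in this post, so here is one hypothetical shape it could take, along with the plan from the example executed step by step. The `search_fn` callback, the `run` method, and the `fake_search` backend are all assumptions for illustration; a real agent framework will have its own tool interface.

```python
from dataclasses import dataclass
from typing import Callable, List

@dataclass
class ClusterRetrievalTool:
    """Hypothetical tool wrapper that searches a single cluster."""
    cluster_id: int
    name: str
    search_fn: Callable[[str, int], List[str]]  # (query, cluster_id) -> documents

    def run(self, query: str) -> List[str]:
        return self.search_fn(query, self.cluster_id)

def fake_search(query: str, cluster_id: int) -> List[str]:
    # Stand-in backend; a real one would run filtered k-NN in the vector DB.
    return [f"cluster {cluster_id}: result for '{query}'"]

topic_names = ["frontend", "state management"]
tools = [
    ClusterRetrievalTool(cluster_id=k, name=f"Search {topic}", search_fn=fake_search)
    for k, topic in enumerate(topic_names)
]

# The agent's plan as (cluster, sub-query) steps; synthesis would follow.
plan = [(0, "React context"), (1, "Redux patterns")]
evidence = [doc for k, sub_query in plan for doc in tools[k].run(sub_query)]
```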

Implementation

Fit GMM offline on document embeddings:

from sklearn.mixture import GaussianMixture

# document_embeddings: array of shape (n_docs, dim); K: number of clusters
gmm = GaussianMixture(n_components=K, covariance_type='full')
gmm.fit(document_embeddings)

# For query q (1-D array of length dim):
cluster_probs = gmm.predict_proba(q.reshape(1, -1))[0]
selected_clusters = cluster_probs.argsort()[-2:][::-1]  # top-2, most probable first

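Store cluster assignments as metadata in your vector DB, then filter on them at query time:

# Filter syntax varies by vector DB; this shape is illustrative.
results = vector_db.query(
    query_embedding=q,
    # Convert the NumPy array to a plain list so the filter is serializable.
    filter={"cluster_id": {"$in": selected_clusters.tolist()}},
    top_k=20
)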

Key decisions:

  • Number of clusters: Use BIC/AIC or domain knowledge
  • Regularization: Add \(\lambda I\) to covariance matrices to prevent singularities
  • Initialization: k-means++ for better convergence
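These three decisions map onto scikit-learn parameters fairly directly. The sketch below picks \(K\) by minimum BIC, sets `reg_covar` as the \(\lambda\) added to the covariance diagonals, and uses `init_params="kmeans"`, which seeds via `KMeans` (whose default seeding is k-means++); scikit-learn 1.1+ also accepts `init_params="k-means++"` directly. The function name `fit_gmm_by_bic`, the candidate range, and the toy 2-D data are assumptions for illustration.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

def fit_gmm_by_bic(embeddings, candidate_ks, reg=1e-4, seed=0):
    """Fit one GMM per candidate K and keep the one with the lowest BIC."""
    best_gmm, best_bic = None, np.inf
    for k in candidate_ks:
        gmm = GaussianMixture(
            n_components=k,
            covariance_type="full",
            reg_covar=reg,         # lambda added to each covariance diagonal
            init_params="kmeans",  # KMeans seeding (k-means++ by default)
            random_state=seed,
        ).fit(embeddings)
        bic = gmm.bic(embeddings)
        if bic < best_bic:
            best_gmm, best_bic = gmm, bic
    return best_gmm

# Toy data: three well-separated 2-D blobs standing in for document embeddings.
rng = np.random.default_rng(0)
centers = np.array([[0.0, 0.0], [5.0, 5.0], [-5.0, 5.0]])
toy_embeddings = np.vstack([c + 0.3 * rng.standard_normal((100, 2)) for c in centers])

best = fit_gmm_by_bic(toy_embeddings, candidate_ks=range(1, 6))
```

Swapping `gmm.bic` for `gmm.aic` gives the AIC variant; AIC penalizes extra components less, so it tends to choose a larger \(K\).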

When It Helps

  • Topically diverse corpora: Multi-product docs, cross-domain papers
  • Single-topic queries: Clear primary topic to route to
  • Noise reduction: Corpora where content from distant topics looks superficially similar to the query and dilutes flat retrieval results

When it doesn’t:
