Off-Topic Guardrails: Embeddings and KNN
In this post, we’ll explore how to build a straightforward off-topic detector using embeddings and K-Nearest Neighbors (KNN).
An off-topic guardrail is essential for filtering out queries that don’t align with the intended purpose of an LLM application. For instance, in an LLM application focused on car sales, we wouldn’t want it to start discussing Python code.
Since this guardrail will be integrated into the main LLM application, it must be fast and simple to implement, ensuring it doesn’t introduce significant latency or complexity.
The key lies in using embeddings to capture the semantic meaning of the input text. By converting both the input query and a set of predefined on-topic and off-topic texts into embeddings, we can represent them in a high-dimensional space and analyze their relative positions. In this example, we use OpenAI's text-embedding-3-small, though any reasonably performant embedding model can be substituted.
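To make this concrete, here is a minimal sketch of the embedding step (a sketch only, assuming the openai v1 Python SDK with an OPENAI_API_KEY in the environment; the two example strings are arbitrary). It embeds two texts in a single API call and compares them with cosine similarity:

import numpy as np
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

# Embed two texts in a single API call
resp = client.embeddings.create(
    model="text-embedding-3-small",
    input=["What is the currency of UK?", "Write a python code"],
)
a = np.array(resp.data[0].embedding)
b = np.array(resp.data[1].embedding)

# Cosine similarity: values closer to 1.0 mean the texts are semantically closer
cosine = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
print(f"cosine similarity: {cosine:.3f}")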
The goal is to determine whether the input query closely relates to the topics of interest or if it ventures into “off-topic” territory.
This turns the guardrail into a binary classification problem, which we solve with KNN. Using the Annoy library, we can efficiently find the reference embeddings closest to the input query. From those nearest neighbors and their labels, we compute inverse-distance-weighted probabilities to decide whether the input text is on-topic or off-topic.
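As a toy example of the weighting (the distances and labels below are made up for illustration): each neighbor votes with weight 1/distance, and the accumulated votes are normalized into probabilities.

# Hypothetical 3 nearest neighbors: angular distances and their labels
distances = [0.3, 0.5, 1.2]
labels = ["on_topic", "on_topic", "off_topic"]

weights = [1 / d for d in distances]  # [3.33, 2.0, 0.83]
total = sum(weights)                  # ~6.17
p_on_topic = sum(w for w, lab in zip(weights, labels) if lab == "on_topic") / total
print(f"P(on_topic) = {p_on_topic:.2f}")  # ~0.86, so P(off_topic) is ~0.14

The nearest neighbors dominate the vote, which is what we want: one very close on-topic example should outweigh two distant off-topic ones.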
Code implementation
import numpy as np
import argparse
from openai import OpenAI
from annoy import AnnoyIndex
# Categories for on_topic and off_topic texts
ON_TOPIC_TEXTS = [
    "What is the capital of China?",
    "What is the currency of UK?",
    "Timezone for New York?",
    "Which country has the largest population?",
    "What is the largest island in the world?",
]
OFF_TOPIC_TEXTS = [
    "Write a python code",
    "Explain the meaning of life",
    "Why is the sky blue?",
]

# Parse command-line arguments
parser = argparse.ArgumentParser(description="Classify whether text is off-topic.")
parser.add_argument("input_text", type=str, help="Input text to classify.")
args = parser.parse_args()

# Initialize OpenAI client
client = OpenAI()

# Input text from CLI
input_text = args.input_text

# Combine all texts so they are embedded in a single API call
all_texts = [input_text] + ON_TOPIC_TEXTS + OFF_TOPIC_TEXTS

# Get embeddings
response = client.embeddings.create(
    model="text-embedding-3-small",
    input=all_texts,
)

# Extract embeddings: index 0 is the query, then the on-topic and off-topic texts
input_embedding = response.data[0].embedding
on_topic_embeddings = [response.data[i].embedding for i in range(1, len(ON_TOPIC_TEXTS) + 1)]
off_topic_embeddings = [response.data[i].embedding for i in range(len(ON_TOPIC_TEXTS) + 1, len(all_texts))]

# Prepare data for Annoy
embeddings = on_topic_embeddings + off_topic_embeddings
labels = ["on_topic"] * len(on_topic_embeddings) + ["off_topic"] * len(off_topic_embeddings)

# Initialize Annoy index
f = len(input_embedding)  # embedding dimensionality
annoy_index = AnnoyIndex(f, "angular")  # 'angular' distance corresponds to cosine similarity

# Add items to Annoy index
for i, emb in enumerate(embeddings):
    annoy_index.add_item(i, emb)

# Build Annoy index
annoy_index.build(10)  # number of trees; more trees trades build time for accuracy

# Calculate inverse-distance-weighted class probabilities over the k nearest neighbors
def get_weighted_probabilities(annoy_index, labels, query_vector, k=3):
    # Get k nearest neighbors with distances
    neighbors, distances = annoy_index.get_nns_by_vector(query_vector, k, include_distances=True)

    # Inverse distance weights (small epsilon avoids division by zero)
    weights = 1 / (np.array(distances) + 1e-8)

    # Accumulate the weight of each class across the neighbors
    class_weights = {}
    for i, label in enumerate([labels[n] for n in neighbors]):
        if label in class_weights:
            class_weights[label] += weights[i]
        else:
            class_weights[label] = weights[i]

    # Normalize the weights into probabilities
    total_weight = sum(class_weights.values())
    probabilities = {label: weight / total_weight for label, weight in class_weights.items()}
    return probabilities

# Classify input text and calculate weighted probability
probabilities = get_weighted_probabilities(annoy_index, labels, input_embedding, k=3)
classification = max(probabilities, key=probabilities.get)

# Output classification and probability
print(f"Input text classification: {classification}")
print(f"Probability off-topic: {probabilities.get('off_topic', 0)}")