A Simple Off-Topic Guardrail

Author: Gabriel Chua
Published: August 17, 2024

Off-Topic Guardrails: Embeddings and KNN

In this post, we’ll explore how to build a straightforward off-topic detector using embeddings and K-Nearest Neighbors (KNN).

An off-topic guardrail is essential for filtering out queries that don’t align with the intended purpose of an LLM application. For instance, in an LLM application focused on car sales, we wouldn’t want it to start discussing Python code.

Since this guardrail sits in the request path of the main LLM application, it must be fast and simple to implement, so that it doesn’t add significant latency or complexity. (In practice, you would embed the seed texts and build the index once at startup and reuse them across requests; the script below recomputes them on each run only to stay self-contained.)

The key lies in using embeddings to capture the semantic meaning of the input text. By converting both the input query and a set of predefined on-topic and off-topic texts into embeddings, we can represent them in a high-dimensional space and analyze their relative positions. In this example, we use OpenAI’s text-embedding-3-small, though any reasonably performant embedding model can be substituted.
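
To build some intuition first, here is a minimal sketch of this idea (it assumes the OPENAI_API_KEY environment variable is set, as does the full script below): embed one on-topic and one off-topic query with the same model, then measure how aligned they are using cosine similarity.

import numpy as np
from openai import OpenAI

client = OpenAI()

# Embed an on-topic and an off-topic query with the same model
response = client.embeddings.create(
    model="text-embedding-3-small",
    input=["What is the capital of China?", "Write a python code"],
)
a = np.array(response.data[0].embedding)
b = np.array(response.data[1].embedding)

# Cosine similarity: close to 1 for semantically similar texts,
# closer to 0 for unrelated ones
similarity = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))
print(f"Cosine similarity: {similarity:.3f}")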

The goal is to determine whether the input query closely relates to the topics of interest or if it ventures into “off-topic” territory.

This process essentially becomes a binary classification problem, where we leverage K-Nearest Neighbors (KNN) to classify the input text. Using the Annoy library, we can efficiently find the closest embeddings in our dataset. By analyzing the nearest neighbors and their associated labels, we calculate weighted probabilities to determine whether the input text is likely on-topic or off-topic.
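
To make the weighting concrete (with made-up distances purely for illustration): suppose the k = 3 nearest neighbors of a query have angular distances of 0.4, 0.5, and 1.0, labeled off_topic, off_topic, and on_topic. The inverse-distance weights are 1/0.4 = 2.5, 1/0.5 = 2.0, and 1/1.0 = 1.0, so the weighted probability of off_topic is (2.5 + 2.0) / (2.5 + 2.0 + 1.0) ≈ 0.82, and of on_topic is 1.0 / 5.5 ≈ 0.18.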

Code Implementation

import argparse

import numpy as np
from annoy import AnnoyIndex
from openai import OpenAI

# Seed examples of on-topic and off-topic queries
ON_TOPIC_TEXTS = [
    "What is the capital of China?",
    "What is the currency of UK?",
    "Timezone for New York?",
    "Which country has the largest population?",
    "What is the largest island in the world?",
]

OFF_TOPIC_TEXTS = [
    "Write a python code",
    "Explain the meaning of life",
    "Why is the sky blue?",
]

# Parse command-line arguments
parser = argparse.ArgumentParser(description="Classify whether text is off-topic.")
parser.add_argument("input_text", type=str, help="Input text to classify.")
args = parser.parse_args()

# Initialize OpenAI client
client = OpenAI()

# Input text from CLI
input_text = args.input_text

# Combine all texts for embedding
all_texts = [input_text] + ON_TOPIC_TEXTS + OFF_TOPIC_TEXTS

# Get embeddings
response = client.embeddings.create(
    model="text-embedding-3-small",
    input=all_texts,
)

# Extract embeddings
input_embedding = response.data[0].embedding
on_topic_embeddings = [response.data[i].embedding for i in range(1, len(ON_TOPIC_TEXTS) + 1)]
off_topic_embeddings = [response.data[i].embedding for i in range(len(ON_TOPIC_TEXTS) + 1, len(all_texts))]

# Prepare data for Annoy
embeddings = on_topic_embeddings + off_topic_embeddings
labels = ["on_topic"] * len(on_topic_embeddings) + ["off_topic"] * len(off_topic_embeddings)

# Initialize Annoy index
f = len(input_embedding)
annoy_index = AnnoyIndex(f, 'angular')  # Annoy's 'angular' metric is sqrt(2 * (1 - cosine similarity))

# Add items to Annoy index
for i, emb in enumerate(embeddings):
    annoy_index.add_item(i, emb)

# Build Annoy index
annoy_index.build(10)  # Number of trees can be adjusted

# Classify by a weighted vote over the k nearest neighbors
def get_weighted_probabilities(annoy_index, labels, query_vector, k=3):
    """Return {label: probability}, weighting each neighbor's vote by inverse distance."""
    # Get k nearest neighbors with distances
    neighbors, distances = annoy_index.get_nns_by_vector(query_vector, k, include_distances=True)
    
    # Inverse distance weights (adding small epsilon to avoid division by zero)
    weights = 1 / (np.array(distances) + 1e-8)
    
    # Calculate weighted probabilities
    class_weights = {}
    for i, label in enumerate([labels[n] for n in neighbors]):
        if label in class_weights:
            class_weights[label] += weights[i]
        else:
            class_weights[label] = weights[i]
    
    total_weight = sum(class_weights.values())
    probabilities = {label: weight / total_weight for label, weight in class_weights.items()}
    return probabilities

# Classify input text and calculate weighted probability
probabilities = get_weighted_probabilities(annoy_index, labels, input_embedding, k=3)
classification = max(probabilities, key=probabilities.get)

# Output classification and probability
print(f"Input text classification: {classification}")
print(f"Probability off-topic: {probabilities.get('off_topic', 0)}")