RAG
Let's explore how to use the EmbeddingStreamer to manage embeddings, track metrics, and generate responses in a simple retrieval-augmented generation (RAG) pipeline.
1
Install the dependencies
pip install incribo faiss-cpu requests beautifulsoup4 sentence-transformers transformers torch numpy
2
Import all dependencies
rag.py
import time
import logging
from typing import List, Tuple
import faiss
import numpy as np
import torch
from incribo import EmbeddingStreamer, EmbeddingComparator, Embedding
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
3
Set up the embedding and language models
rag.py
# Initialize sentence transformer for embeddings
sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
# Initialize the language model used for generation (GPT-2 Medium)
tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
model = AutoModelForCausalLM.from_pretrained("gpt2-medium")
# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Embedding function
def create_embedding(text: str) -> np.ndarray:
    try:
        return sentence_model.encode(text, show_progress_bar=False)
    except Exception as e:
        logger.error(f"Error creating embedding: {e}")
        return np.zeros(384)  # Return a zero vector of matching dimension on error
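Before moving on, it can help to confirm the embedding dimension, since the Faiss index in the next step must be created with the same value. A minimal check using the create_embedding function above (all-MiniLM-L6-v2 produces 384-dimensional vectors):
# Sanity check: the vector length here is the dimension the Faiss index below expects
sample_vector = create_embedding("A short test sentence.")
print(sample_vector.shape)  # -> (384,)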
4
Initialize the EmbeddingStreamer and Faiss index
rag.py
# Initialize EmbeddingStreamer
streamer = EmbeddingStreamer(create_embedding)
# Function to measure latency
def measure_latency(func):
    start_time = time.time()
    result = func()
    end_time = time.time()
    return result, end_time - start_time
# Class for managing Faiss index
class FaissIndex:
    def __init__(self, dimension: int):
        self.dimension = dimension
        self.index = faiss.IndexFlatL2(dimension)
        self.id_to_text = {}

    def add_embedding(self, id: str, embedding: np.ndarray, text: str):
        # Faiss expects float32 input
        self.index.add(np.array([embedding], dtype=np.float32))
        self.id_to_text[self.index.ntotal - 1] = (id, text)

    def search(self, query_embedding: np.ndarray, k: int) -> List[Tuple[str, str, float]]:
        distances, indices = self.index.search(np.array([query_embedding], dtype=np.float32), k)
        results = []
        for i, idx in enumerate(indices[0]):
            if idx != -1:
                id, text = self.id_to_text[idx]
                results.append((id, text, distances[0][i]))
        return results
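To see how these pieces fit together, here is a small usage sketch (not part of the steps above) that times an embedding call with measure_latency and runs a nearest-neighbor search against the FaissIndex class; the example documents and the query string are placeholders:
# Usage sketch: time an embedding call, index two placeholder documents, then search
faiss_index = FaissIndex(dimension=384)  # all-MiniLM-L6-v2 vectors are 384-dimensional

example_docs = [
    "Incribo streams and tracks embeddings.",
    "Faiss performs efficient vector similarity search.",
]
for i, doc in enumerate(example_docs):
    vector, latency = measure_latency(lambda: create_embedding(doc))
    logger.info(f"Embedded doc{i} in {latency:.3f}s")
    faiss_index.add_embedding(f"doc{i}", vector, doc)

query_vector = create_embedding("How do I search vectors?")
for doc_id, text, distance in faiss_index.search(query_vector, k=2):
    logger.info(f"{doc_id} (distance {distance:.3f}): {text}")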
5
Generate a response
rag.py
# Function to generate response
def generate_response(context: str, query: str) -> str:
    try:
        input_text = f"Context: {context}\nQuery: {query}\nAnswer:"
        input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
        with torch.no_grad():
            output = model.generate(
                input_ids,
                max_new_tokens=200,  # limit new tokens so long contexts don't exceed the budget
                num_return_sequences=1,
                no_repeat_ngram_size=2,
                do_sample=True,  # sampling must be enabled for temperature to take effect
                temperature=0.7,
                pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token by default
            )
        return tokenizer.decode(output[0], skip_special_tokens=True)
    except Exception as e:
        logger.error(f"Error generating response: {e}")
        return "I'm sorry, but I couldn't generate a response at this time."
6
Optional: Process a larger dataset
rag.py
import requests
from bs4 import BeautifulSoup
def fetch_wikipedia_articles(num_articles: int) -> List[str]:
    articles = []
    while len(articles) < num_articles:
        response = requests.get("https://en.wikipedia.org/wiki/Special:Random")
        soup = BeautifulSoup(response.content, 'html.parser')
        paragraphs = soup.find_all('p')
        article_text = ' '.join([p.text for p in paragraphs if len(p.text) > 100])
        if article_text:
            articles.append(article_text[:1000])  # Limit to first 1000 characters
    return articles
logger.info("Fetching Wikipedia articles...")
documents = fetch_wikipedia_articles(1000)
logger.info(f"Fetched {len(documents)} articles")
7
Optional: Compare embeddings
rag.py
# Compare embeddings
logger.info("Comparing embeddings...")
comparator = EmbeddingComparator(None)
for i in range(min(100, len(documents))):  # Limit to the first 100 documents for efficiency
    embedding = streamer.get_embedding(f"doc{i}")['vector']
    comparator.add_embedding(Embedding(embedding, f"doc{i}"))
comparison_results = comparator.compare()
logger.info("Embedding Comparison Results:")
logger.info(comparison_results)
best_model = comparator.get_best_model()
logger.info(f"Best performing document: {best_model}")