
FAISS embedding example

Leo · 2025-05-09

Introduction

This post walks through a complete, self-contained example of a small retrieval-augmented knowledge base built on FAISS and the OpenAI API: text is split into overlapping chunks, embedded with text-embedding-ada-002, indexed in FAISS, and retrieved at query time to ground the model's answers.
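
The script assumes the pre-1.0 openai Python SDK (it calls openai.Embedding.create and openai.ChatCompletion.create, both removed in the 1.x SDK), plus faiss-cpu (or faiss-gpu), numpy, and python-dotenv, with an OPENAI_API_KEY entry in a local .env file. A sketch of the 1.x equivalents follows the listing.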

 

import os
import json
import numpy as np
import faiss
import openai
from dotenv import load_dotenv
from typing import List, Dict, Any, Optional, Tuple
import re

# Load environment variables
load_dotenv()

# Configure the OpenAI API key
openai.api_key = os.getenv("OPENAI_API_KEY")

class FAISSIndex:
    """Thin wrapper around a FAISS vector index."""
    
    def __init__(self, storage_path: str = "./index", index_type: str = "flat_l2", dimension: int = 1536):
        """Initialize the index (1536 is the dimension of text-embedding-ada-002)."""
        self.storage_path = storage_path
        self.index_type = index_type
        self.dimension = dimension
        self.index = None
        
        # Make sure the storage directory exists
        os.makedirs(storage_path, exist_ok=True)
        
        # Load an existing index from disk, or create a new one
        self._load_existing_index()
    
    def _load_existing_index(self):
        """Load an existing FAISS index from disk, creating one if necessary."""
        index_path = os.path.join(self.storage_path, "index.faiss")
        
        if os.path.exists(index_path):
            self.index = faiss.read_index(index_path)
            print(f"Loaded existing FAISS index: {self.index.ntotal} records")
        else:
            # Build a fresh index
            self._create_index()
            print("Created a new FAISS index")
    
    def _create_index(self):
        """Create a FAISS index of the requested type."""
        if self.index_type == "flat_l2":
            # Exact search with (squared) L2 distance
            self.index = faiss.IndexFlatL2(self.dimension)
        elif self.index_type == "hnsw":
            # Graph-based approximate search; much faster at query time
            self.index = faiss.IndexHNSWFlat(self.dimension, 32)
            self.index.hnsw.efConstruction = 40
        elif self.index_type == "ivfflat":
            # Inverted-file index, suited to large collections;
            # must be trained before vectors can be added
            nlist = 100  # number of coarse clusters
            quantizer = faiss.IndexFlatL2(self.dimension)
            self.index = faiss.IndexIVFFlat(quantizer, self.dimension, nlist)
        else:
            raise ValueError(f"Unsupported index type: {self.index_type}")
    
    def add_vectors(self, vectors: np.ndarray):
        """Add vectors to the index."""
        # FAISS expects a contiguous float32 array
        vectors = np.ascontiguousarray(vectors, dtype=np.float32)
        
        # IVF-style indexes must be trained before vectors can be added
        if not self.index.is_trained:
            self.index.train(vectors)
        self.index.add(vectors)
        
        # Persist the index to disk
        self._save_index()
    
    def search(self, query_vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, np.ndarray]:
        """Search for similar vectors; returns (distances, indices)."""
        if self.index.ntotal == 0:
            return np.array([]), np.array([])
        
        query_vector = np.ascontiguousarray(query_vector, dtype=np.float32)
        
        # FAISS expects a 2-D batch of query vectors
        if query_vector.ndim == 1:
            query_vector = query_vector.reshape(1, -1)
        
        distances, indices = self.index.search(query_vector, top_k)
        return distances, indices
    
    def _save_index(self):
        """Write the FAISS index to disk."""
        index_path = os.path.join(self.storage_path, "index.faiss")
        faiss.write_index(self.index, index_path)
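
# IndexFlatL2 returns squared L2 distances. If cosine similarity is preferred,
# an equivalent setup is an inner-product index over L2-normalized vectors;
# a minimal sketch:
#
#     index = faiss.IndexFlatIP(dimension)
#     faiss.normalize_L2(vectors)                 # in-place; expects float32 rows
#     index.add(vectors)
#     faiss.normalize_L2(query)                   # query shaped (n, dimension)
#     scores, ids = index.search(query, top_k)    # inner product == cosine here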

class OpenAIAgent:
    """Wrapper around the OpenAI API (pre-1.0 SDK)."""
    
    def __init__(self, api_key: Optional[str] = None, embedding_model: str = "text-embedding-ada-002", 
                 chat_model: str = "gpt-3.5-turbo"):
        """Initialize the OpenAI agent."""
        if api_key:
            openai.api_key = api_key
        elif not openai.api_key:
            raise ValueError("An OpenAI API key must be provided")
            
        self.embedding_model = embedding_model
        self.chat_model = chat_model
    
    def get_embedding(self, text: str) -> np.ndarray:
        """Return the embedding vector for a piece of text."""
        try:
            response = openai.Embedding.create(
                input=text,
                model=self.embedding_model
            )
            return np.array(response['data'][0]['embedding'], dtype=np.float32)
        except Exception as e:
            print(f"Failed to get embedding: {e}")
            raise
    
    def generate_answer(self, query: str, context: List[str], system_prompt: Optional[str] = None) -> str:
        """Generate an answer grounded in the retrieved context."""
        if not system_prompt:
            system_prompt = "You are a knowledgeable assistant. Answer the user's question using the provided knowledge base."
        
        context_text = "\n\n".join([f"Knowledge base entry {i+1}: {item}" for i, item in enumerate(context)])
        
        # Put the retrieved context in the user message ahead of the question;
        # sending it as an assistant turn would make the model treat it as its own prior reply
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Reference material:\n{context_text}\n\nQuestion: {query}"}
        ]
        
        try:
            response = openai.ChatCompletion.create(
                model=self.chat_model,
                messages=messages,
                temperature=0.2
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            print(f"Failed to generate answer: {e}")
            return "Sorry, I cannot answer that question."

class FileManager:
    """File read/write helpers."""
    
    def __init__(self, storage_path: str = "./data"):
        """Initialize the file manager."""
        self.storage_path = storage_path
        
        # Make sure the storage directory exists
        os.makedirs(storage_path, exist_ok=True)
    
    def read_text_file(self, file_path: str) -> str:
        """Read a text file; note the path is used as given, not joined with storage_path."""
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File does not exist: {file_path}")
        
        with open(file_path, "r", encoding="utf-8") as f:
            return f.read()
    
    def write_text_file(self, file_path: str, content: str):
        """Write a text file under storage_path."""
        full_path = os.path.join(self.storage_path, file_path)
        dir_path = os.path.dirname(full_path)
        
        # Make sure the target directory exists
        os.makedirs(dir_path, exist_ok=True)
        
        with open(full_path, "w", encoding="utf-8") as f:
            f.write(content)
    
    def read_json_file(self, file_path: str) -> Any:
        """Read a JSON file under storage_path; returns [] if it does not exist."""
        full_path = os.path.join(self.storage_path, file_path)
        
        if not os.path.exists(full_path):
            return []
        
        with open(full_path, "r", encoding="utf-8") as f:
            return json.load(f)
    
    def write_json_file(self, file_path: str, data: Any):
        """Write a JSON file under storage_path."""
        full_path = os.path.join(self.storage_path, file_path)
        dir_path = os.path.dirname(full_path)
        
        # Make sure the target directory exists
        os.makedirs(dir_path, exist_ok=True)
        
        with open(full_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
    
    def split_text_into_chunks(self, text: str, chunk_size: int = 500, overlap: int = 100) -> List[str]:
        """Split text into overlapping chunks."""
        if not text:
            return []
        
        # Split on blank lines (paragraph boundaries)
        paragraphs = re.split(r'\n\s*\n', text)
        chunks = []
        current_chunk = ""
        
        for para in paragraphs:
            # If adding this paragraph would exceed the chunk size, close the current chunk
            if len(current_chunk) + len(para) > chunk_size:
                if current_chunk:
                    chunks.append(current_chunk.strip())
                    current_chunk = ""  # reset, otherwise the same text would be emitted twice
                
                # If the paragraph itself is longer than a chunk, split it further
                if len(para) > chunk_size:
                    # Try to split at sentence boundaries
                    sentences = re.split(r'(?<=[.!?])\s+', para)
                    sub_chunk = ""
                    for sent in sentences:
                        if len(sub_chunk) + len(sent) > chunk_size:
                            if sub_chunk:
                                chunks.append(sub_chunk.strip())
                            sub_chunk = sent
                        else:
                            sub_chunk += " " + sent if sub_chunk else sent
                    
                    if sub_chunk:
                        chunks.append(sub_chunk.strip())
                else:
                    # Start a new chunk with this paragraph
                    current_chunk = para
            else:
                current_chunk += " " + para if current_chunk else para
        
        # Flush the last chunk
        if current_chunk:
            chunks.append(current_chunk.strip())
        
        # Apply overlap: prefix each chunk (except the first) with the tail of the previous one
        final_chunks = []
        for i in range(len(chunks)):
            if i == 0 or overlap <= 0:
                final_chunks.append(chunks[i])
            else:
                prev_chunk = chunks[i-1]
                overlap_text = prev_chunk[-overlap:] if len(prev_chunk) > overlap else prev_chunk
                final_chunks.append(overlap_text + " " + chunks[i])
        
        return final_chunks
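
    # Quick illustration of the chunker (hypothetical input and sizes):
    #
    #     fm = FileManager("./data")
    #     chunks = fm.split_text_into_chunks(long_text, chunk_size=500, overlap=100)
    #     # every chunk after the first starts with the last 100 characters
    #     # of its predecessor, so context carries across chunk boundaries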

class KnowledgeBase:
    """Main knowledge-base class; coordinates the other modules."""
    
    def __init__(self, storage_path: str = "./knowledge_base", index_type: str = "flat_l2"):
        """Initialize the knowledge base."""
        self.storage_path = storage_path
        
        # Initialize the component modules
        self.file_manager = FileManager(os.path.join(storage_path, "data"))
        self.faiss_index = FAISSIndex(os.path.join(storage_path, "index"), index_type)
        self.openai_agent = OpenAIAgent()
        
        # Load previously stored texts and metadata
        self.texts = self.file_manager.read_json_file("texts.json")
        self.metadata = self.file_manager.read_json_file("metadata.json")
        
        print(f"Knowledge base initialized: {len(self.texts)} records")
    
    def add_knowledge(self, texts: List[str], metadata: Optional[List[Dict[str, Any]]] = None):
        """Add knowledge entries to the knowledge base."""
        if not texts:
            return
        
        # Make sure the metadata lines up with the texts
        if metadata is None:
            metadata = [{} for _ in texts]
        elif len(metadata) != len(texts):
            raise ValueError("Number of texts and metadata entries does not match")
        
        # Embed the texts (one API call per text)
        embeddings = np.array([self.openai_agent.get_embedding(text) for text in texts])
        
        # Add the vectors to the FAISS index
        self.faiss_index.add_vectors(embeddings)
        
        # Update the in-memory texts and metadata
        self.texts.extend(texts)
        self.metadata.extend(metadata)
        
        # Persist to disk
        self.file_manager.write_json_file("texts.json", self.texts)
        self.file_manager.write_json_file("metadata.json", self.metadata)
        
        print(f"Added {len(texts)} knowledge entries")
    
    def import_from_text_file(self, file_path: str, chunk_size: int = 500, overlap: int = 100, 
                              metadata: Optional[Dict[str, Any]] = None) -> int:
        """Import knowledge from a text file; returns the number of chunks added."""
        # Read the file contents
        content = self.file_manager.read_text_file(file_path)
        
        # Split the text into chunks
        chunks = self.file_manager.split_text_into_chunks(content, chunk_size, overlap)
        
        # Record the source file in the metadata
        if metadata is None:
            metadata = {}
        
        file_metadata = metadata.copy()
        file_metadata["source"] = file_path
        
        # Build per-chunk metadata
        chunk_metadata = [
            {**file_metadata, "chunk_index": i, "total_chunks": len(chunks)}
            for i in range(len(chunks))
        ]
        
        # Add everything to the knowledge base
        self.add_knowledge(chunks, chunk_metadata)
        return len(chunks)
    
    def search_knowledge(self, query: str, top_k: int = 5, threshold: float = 0.8) -> List[Dict[str, Any]]:
        """Search for relevant knowledge entries."""
        # Embed the query
        query_embedding = self.openai_agent.get_embedding(query)
        
        # Search for similar vectors
        distances, indices = self.faiss_index.search(query_embedding, top_k)
        
        # Guard against an empty index
        if indices.size == 0:
            return []
        
        results = []
        for i, idx in enumerate(indices[0]):
            if idx == -1 or idx >= len(self.texts):  # invalid index
                continue
                
            distance = distances[0][i]
            # Map (squared) L2 distance to a similarity in (0, 1]; this is a heuristic,
            # not cosine similarity, and a threshold of 0.8 requires distance <= 0.25
            similarity = 1 / (1 + distance)
            
            if similarity >= threshold:
                results.append({
                    "text": self.texts[idx],
                    "metadata": self.metadata[idx],
                    "similarity": similarity,
                    "index": idx
                })
        
        # Sort by similarity, highest first
        return sorted(results, key=lambda x: x["similarity"], reverse=True)
    
    def generate_answer(self, query: str, top_k: int = 3) -> str:
        """Answer a query using retrieved context."""
        # Retrieve relevant knowledge
        results = self.search_knowledge(query, top_k)
        context = [item["text"] for item in results]
        
        # Generate the answer
        return self.openai_agent.generate_answer(query, context)

# Usage example
if __name__ == "__main__":
    # Initialize the knowledge base
    kb = KnowledgeBase(storage_path="./my_knowledge", index_type="flat_l2")
    
    # Import knowledge from a text file
    file_path = "example.txt"  # replace with the path to your own text file
    
    if os.path.exists(file_path):
        print(f"Importing knowledge from {file_path}...")
        kb.import_from_text_file(
            file_path, 
            chunk_size=500, 
            overlap=100,
            metadata={"category": "document"}
        )
    
    # Knowledge can also be added directly
    example_knowledge = [
        "Python is a high-level programming language created by Guido van Rossum in 1991.",
        "Artificial intelligence refers to technology that enables computer systems to perform tasks that normally require human intelligence.",
        "Machine learning is a subset of artificial intelligence focused on enabling computers to learn from data without being explicitly programmed.",
        "A vector database is a database specialized for storing and querying vector data, commonly used for similarity search.",
        "FAISS is a library developed by Facebook for efficient similarity search and clustering of dense vectors."
    ]
    
    kb.add_knowledge(example_knowledge)
    
    # User query
    user_query = "What is FAISS?"
    
    # Generate an answer
    answer = kb.generate_answer(user_query)
    
    print(f"Question: {user_query}")
    print(f"Answer: {answer}")
    
    # Show the supporting references
    results = kb.search_knowledge(user_query, top_k=3)
    if results:
        print("\nReferences:")
        for i, result in enumerate(results):
            print(f"{i+1}. {result['text']} (similarity: {result['similarity']:.4f})")