|
14 | 14 | from langchain_community.document_transformers import EmbeddingsRedundantFilter
|
15 | 15 | from langchain_community.vectorstores import FAISS
|
16 | 16 |
|
| 17 | +from langchain_community.chat_models import ChatOllama |
| 18 | +from langchain_aws import BedrockEmbeddings, ChatBedrock |
| 19 | +from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings |
| 20 | +from langchain_community.embeddings import OllamaEmbeddings |
| 21 | +from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI |
| 22 | +from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings |
| 23 | +from langchain_fireworks import FireworksEmbeddings, ChatFireworks |
| 24 | +from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI |
| 25 | +from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA |
| 26 | + |
17 | 27 | from ..utils.logging import get_logger
|
18 | 28 | from .base_node import BaseNode
|
| 29 | +from ..helpers import models_tokens |
| 30 | +from ..models import DeepSeek |
19 | 31 |
|
20 | 32 |
|
21 | 33 | class RAGNode(BaseNode):
|
@@ -95,10 +107,21 @@ def execute(self, state: dict) -> dict:
|
95 | 107 | self.logger.info("--- (updated chunks metadata) ---")
|
96 | 108 |
|
97 | 109 | # check if embedder_model is provided, if not use llm_model
|
98 | | - self.embedder_model = ( |
99 | | - self.embedder_model if self.embedder_model else self.llm_model |
100 | | - ) |
101 | | - embeddings = self.embedder_model |
| 110 | + if self.embedder_model is not None: |
| 111 | + embeddings = self.embedder_model |
| 112 | + elif 'embedder_config' in self.node_config: |
| 113 | + try: |
| 114 | + embeddings = self._create_embedder(self.node_config['embedder_config']) |
| 115 | + except Exception: |
| 116 | + try: |
| 117 | + embeddings = self._create_default_embedder() |
| 118 | + self.embedder_model = embeddings |
| 119 | + except ValueError: |
| 120 | + embeddings = self.llm_model |
| 121 | + self.embedder_model = self.llm_model |
| 122 | + else: |
| 123 | + embeddings = self.llm_model |
| 124 | + self.embedder_model = self.llm_model |
102 | 125 |
|
103 | 126 | folder_name = self.node_config.get("cache_path", "cache")
|
104 | 127 |
|
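Note on the hunk above: the embedder is now resolved by precedence rather than a single ternary — an explicitly provided `embedder_model` wins, then one built from the node's embedder config, and the chat model itself is the last resort. A minimal sketch of configs that would exercise each branch (the key names, model string, and API key below are illustrative placeholders, not values taken from this commit):

```python
# Hypothetical node_config values and the branch of the new fallback they exercise.
# 1) self.embedder_model already set by the graph -> used as-is
# 2) an embedder config key is present            -> _create_embedder(...) builds the client
# 3) neither                                      -> the chat model doubles as the embedder
node_config_with_embedder = {
    "embedder_config": {"model": "openai/text-embedding-3-small", "api_key": "sk-..."},
    "cache_path": "cache",
}
node_config_without_embedder = {"cache_path": "cache"}  # falls back to self.llm_model
```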
@@ -141,3 +164,116 @@ def execute(self, state: dict) -> dict:
|
141 | 164 |
|
142 | 165 | state.update({self.output[0]: compressed_docs})
|
143 | 166 | return state
|
| 167 | + |
| 168 | + |
| 169 | + def _create_default_embedder(self, llm_config=None) -> object: |
| 170 | + """ |
| 171 | + Create an embedding model instance based on the chosen llm model. |
| 172 | +
|
| 173 | + Returns: |
| 174 | + object: An instance of the embedding model client. |
| 175 | +
|
| 176 | + Raises: |
| 177 | + ValueError: If the model is not supported. |
| 178 | + """ |
| 179 | + if isinstance(self.llm_model, ChatGoogleGenerativeAI): |
| 180 | + return GoogleGenerativeAIEmbeddings( |
| 181 | + google_api_key=(llm_config or {}).get("api_key", self.llm_model.google_api_key), model="models/embedding-001" |
| 182 | + ) |
| 183 | + if isinstance(self.llm_model, ChatOpenAI): |
| 184 | + return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key, |
| 185 | + base_url=self.llm_model.openai_api_base) |
| 186 | + elif isinstance(self.llm_model, DeepSeek): |
| 187 | + return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key) |
| 188 | + elif isinstance(self.llm_model, ChatVertexAI): |
| 189 | + return VertexAIEmbeddings() |
| 190 | + elif isinstance(self.llm_model, AzureOpenAIEmbeddings): |
| 191 | + return self.llm_model |
| 192 | + elif isinstance(self.llm_model, AzureChatOpenAI): |
| 193 | + return AzureOpenAIEmbeddings() |
| 194 | + elif isinstance(self.llm_model, ChatFireworks): |
| 195 | + return FireworksEmbeddings(model=self.llm_model.model_name) |
| 196 | + elif isinstance(self.llm_model, ChatNVIDIA): |
| 197 | + return NVIDIAEmbeddings(model=self.llm_model.model_name) |
| 198 | + elif isinstance(self.llm_model, ChatOllama): |
| 199 | + # unwrap the kwargs from the model, which is a dict |
| 200 | + params = self.llm_model._lc_kwargs |
| 201 | + # remove streaming and temperature |
| 202 | + params.pop("streaming", None) |
| 203 | + params.pop("temperature", None) |
| 204 | + |
| 205 | + return OllamaEmbeddings(**params) |
| 206 | + elif isinstance(self.llm_model, ChatHuggingFace): |
| 207 | + return HuggingFaceEmbeddings(model_name=self.llm_model.model_id) |
| 208 | + elif isinstance(self.llm_model, ChatBedrock): |
| 209 | + return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id) |
| 210 | + else: |
| 211 | + raise ValueError("Embedding Model missing or not supported") |
| 212 | + |
| 213 | + |
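The default path above mirrors the provider of the active chat model and, where possible, reuses its credentials. As a rough sketch of what that means for an OpenAI chat model (the model name and key are placeholders, not part of this change):

```python
# Sketch only: with a ChatOpenAI llm_model, _create_default_embedder() builds a
# matching OpenAIEmbeddings client from the chat model's own key and base URL.
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

llm = ChatOpenAI(model="gpt-4o-mini", api_key="sk-...")  # placeholder credentials
# inside RAGNode: self.llm_model = llm
embeddings = OpenAIEmbeddings(api_key=llm.openai_api_key, base_url=llm.openai_api_base)
```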
| 214 | + def _create_embedder(self, embedder_config: dict) -> object: |
| 215 | + """ |
| 216 | + Create an embedding model instance based on the configuration provided. |
| 217 | +
|
| 218 | + Args: |
| 219 | + embedder_config (dict): Configuration parameters for the embedding model. |
| 220 | +
|
| 221 | + Returns: |
| 222 | + object: An instance of the embedding model client. |
| 223 | +
|
| 224 | + Raises: |
| 225 | + KeyError: If the model is not supported. |
| 226 | + """ |
| 227 | + embedder_params = {**embedder_config} |
| 228 | + if "model_instance" in embedder_config: |
| 229 | + return embedder_params["model_instance"] |
| 230 | + # Instantiate the embedding model based on the model name |
| 231 | + if "openai" in embedder_params["model"]: |
| 232 | + return OpenAIEmbeddings(api_key=embedder_params["api_key"]) |
| 233 | + if "azure" in embedder_params["model"]: |
| 234 | + return AzureOpenAIEmbeddings() |
| 235 | + if "nvidia" in embedder_params["model"]: |
| 236 | + embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:]) |
| 237 | + try: |
| 238 | + models_tokens["nvidia"][embedder_params["model"]] |
| 239 | + except KeyError as exc: |
| 240 | + raise KeyError("Model not supported") from exc |
| 241 | + return NVIDIAEmbeddings(model=embedder_params["model"], |
| 242 | + nvidia_api_key=embedder_params["api_key"]) |
| 243 | + if "ollama" in embedder_params["model"]: |
| 244 | + embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:]) |
| 245 | + try: |
| 246 | + models_tokens["ollama"][embedder_params["model"]] |
| 247 | + except KeyError as exc: |
| 248 | + raise KeyError("Model not supported") from exc |
| 249 | + return OllamaEmbeddings(**embedder_params) |
| 250 | + if "hugging_face" in embedder_params["model"]: |
| 251 | + embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:]) |
| 252 | + try: |
| 253 | + models_tokens["hugging_face"][embedder_params["model"]] |
| 254 | + except KeyError as exc: |
| 255 | + raise KeyError("Model not supported") from exc |
| 256 | + return HuggingFaceEmbeddings(model_name=embedder_params["model"]) |
| 257 | + if "fireworks" in embedder_params["model"]: |
| 258 | + embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:]) |
| 259 | + try: |
| 260 | + models_tokens["fireworks"][embedder_params["model"]] |
| 261 | + except KeyError as exc: |
| 262 | + raise KeyError("Model not supported") from exc |
| 263 | + return FireworksEmbeddings(model=embedder_params["model"]) |
| 264 | + if "gemini" in embedder_params["model"]: |
| 265 | + try: |
| 266 | + models_tokens["gemini"][embedder_params["model"]] |
| 267 | + except KeyError as exc: |
| 268 | + raise KeyError("Model not supported") from exc |
| 269 | + return GoogleGenerativeAIEmbeddings(model=embedder_params["model"]) |
| 270 | + if "bedrock" in embedder_params["model"]: |
| 271 | + embedder_params["model"] = embedder_params["model"].split("/")[-1] |
| 272 | + client = embedder_params.get("client", None) |
| 273 | + try: |
| 274 | + models_tokens["bedrock"][embedder_params["model"]] |
| 275 | + except KeyError as exc: |
| 276 | + raise KeyError("Model not supported") from exc |
| 277 | + return BedrockEmbeddings(client=client, model_id=embedder_params["model"]) |
| 278 | + |
| 279 | + raise ValueError("Model provided by the configuration not supported") |
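For reviewers, a sketch of the config shapes `_create_embedder` appears to expect: for most providers the prefix in `model` selects the branch and is stripped before the `models_tokens` lookup, so the concrete model names below are assumptions and must exist in `models_tokens` for the branch not to raise `KeyError`. A pre-built client can also be passed through untouched via `model_instance`.

```python
# Hypothetical embedder_config dicts, one per branch of _create_embedder.
openai_cfg  = {"model": "openai/text-embedding-3-small", "api_key": "sk-..."}  # -> OpenAIEmbeddings(api_key=...)
ollama_cfg  = {"model": "ollama/nomic-embed-text"}                             # -> OllamaEmbeddings(model="nomic-embed-text")
bedrock_cfg = {"model": "bedrock/amazon.titan-embed-text-v2:0"}                # -> BedrockEmbeddings(model_id=...)

# e.g. inside execute():
#     embeddings = self._create_embedder(ollama_cfg)
```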