5 | 5 | from abc import ABC, abstractmethod
6 | 6 | from typing import Optional
7 | 7 |
8 | | -from ..models import OpenAI, Gemini, Ollama, AzureOpenAI, HuggingFace, Groq, Bedrock
| 8 | +from langchain_aws.embeddings.bedrock import BedrockEmbeddings
| 9 | +from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
| 10 | +from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
| 11 | +
9 | 12 | from ..helpers import models_tokens
| 13 | +from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI
10 | 14 |
11 | 15 |
12 | 16 | class AbstractGraph(ABC):
|
@@ -43,7 +47,8 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
43 | 47 |         self.source = source
44 | 48 |         self.config = config
45 | 49 |         self.llm_model = self._create_llm(config["llm"], chat=True)
46 | | -        self.embedder_model = self.llm_model if "embeddings" not in config else self._create_llm(
| 50 | +        self.embedder_model = self._create_default_embedder(
| 51 | +        ) if "embeddings" not in config else self._create_embedder(
47 | 52 |             config["embeddings"])
48 | 53 |
49 | 54 |         # Set common configuration parameters
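Note for reviewers: the conditional above now keys the embedder off a separate top-level "embeddings" entry instead of reusing the LLM client. A minimal sketch of the two config shapes this distinguishes (the key names come from the diff; the model strings and API key are illustrative placeholders):

    # Illustrative only: no "embeddings" key -> _create_default_embedder()
    config_default = {
        "llm": {"model": "gpt-3.5-turbo", "api_key": "sk-..."},  # placeholder key
    }

    # Explicit "embeddings" key -> _create_embedder(config["embeddings"])
    config_explicit = {
        "llm": {"model": "gpt-3.5-turbo", "api_key": "sk-..."},
        "embeddings": {"model": "ollama/nomic-embed-text"},  # illustrative model
    }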
|
@@ -165,6 +170,85 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
165 | 170 |         else:
166 | 171 |             raise ValueError(
167 | 172 |                 "Model provided by the configuration not supported")
| 173 | +
| 174 | +    def _create_default_embedder(self) -> object:
| 175 | +        """
| 176 | +        Create an embedding model instance based on the chosen LLM model.
| 177 | +
| 178 | +        Returns:
| 179 | +            object: An instance of the embedding model client.
| 180 | +
| 181 | +        Raises:
| 182 | +            ValueError: If the model is not supported.
| 183 | +        """
| 184 | +
| 185 | +        if isinstance(self.llm_model, OpenAI):
| 186 | +            return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
| 187 | +        elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
| 188 | +            return self.llm_model
| 189 | +        elif isinstance(self.llm_model, AzureOpenAI):
| 190 | +            return AzureOpenAIEmbeddings()
| 191 | +        elif isinstance(self.llm_model, Ollama):
| 192 | +            # unwrap the kwargs from the model, which is a dict
| 193 | +            params = self.llm_model._lc_kwargs
| 194 | +            # remove streaming and temperature
| 195 | +            params.pop("streaming", None)
| 196 | +            params.pop("temperature", None)
| 197 | +
| 198 | +            return OllamaEmbeddings(**params)
| 199 | +        elif isinstance(self.llm_model, HuggingFace):
| 200 | +            return HuggingFaceHubEmbeddings(model=self.llm_model.model)
| 201 | +        elif isinstance(self.llm_model, Bedrock):
| 202 | +            return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
| 203 | +        else:
| 204 | +            raise ValueError("Embedding Model missing or not supported")
| 205 | +
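The Ollama branch above is the subtle one: it reuses the chat model's own kwargs to build the embedder. A hedged sketch of that pattern in isolation (the literal values are assumptions; in the diff the dict comes from self.llm_model._lc_kwargs):

    # Assumed example of what _lc_kwargs might hold for an Ollama chat model
    llm_params = {
        "model": "llama2",
        "base_url": "http://localhost:11434",
        "temperature": 0.7,
        "streaming": True,
    }

    embed_params = dict(llm_params)
    embed_params.pop("streaming", None)    # generation-only option, meaningless for embeddings
    embed_params.pop("temperature", None)  # likewise: sampling does not apply to embeddings
    # OllamaEmbeddings(**embed_params) then receives only model and base_url

Copying before popping keeps the chat model's own kwargs intact; the diff pops on _lc_kwargs directly.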
| 206 | +    def _create_embedder(self, embedder_config: dict) -> object:
| 207 | +        """
| 208 | +        Create an embedding model instance based on the configuration provided.
| 209 | +
| 210 | +        Args:
| 211 | +            embedder_config (dict): Configuration parameters for the embedding model.
| 212 | +
| 213 | +        Returns:
| 214 | +            object: An instance of the embedding model client.
| 215 | +
| 216 | +        Raises:
| 217 | +            KeyError: If the model is not supported.
| 218 | +        """
| 219 | +
| 220 | +        # Instantiate the embedding model based on the model name
| 221 | +        if "openai" in embedder_config["model"]:
| 222 | +            return OpenAIEmbeddings(api_key=embedder_config["api_key"])
| 223 | +
| 224 | +        elif "azure" in embedder_config["model"]:
| 225 | +            return AzureOpenAIEmbeddings()
| 226 | +
| 227 | +        elif "ollama" in embedder_config["model"]:
| 228 | +            embedder_config["model"] = embedder_config["model"].split("/")[-1]
| 229 | +            try:
| 230 | +                models_tokens["ollama"][embedder_config["model"]]
| 231 | +            except KeyError:
| 232 | +                raise KeyError("Model not supported")
| 233 | +            return OllamaEmbeddings(**embedder_config)
| 234 | +
| 235 | +        elif "hugging_face" in embedder_config["model"]:
| 236 | +            try:
| 237 | +                models_tokens["hugging_face"][embedder_config["model"]]
| 238 | +            except KeyError:
| 239 | +                raise KeyError("Model not supported")
| 240 | +            return HuggingFaceHubEmbeddings(model=embedder_config["model"])
| 241 | +
| 242 | +        elif "bedrock" in embedder_config["model"]:
| 243 | +            embedder_config["model"] = embedder_config["model"].split("/")[-1]
| 244 | +            try:
| 245 | +                models_tokens["bedrock"][embedder_config["model"]]
| 246 | +            except KeyError:
| 247 | +                raise KeyError("Model not supported")
| 248 | +            return BedrockEmbeddings(client=None, model_id=embedder_config["model"])
| 249 | +        else:
| 250 | +            raise ValueError(
| 251 | +                "Model provided by the configuration not supported")
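For the provider-prefixed branches, the flow is: split the "provider/model" string, validate the bare name against models_tokens, then construct the client. A compact sketch of the same check (the model name is illustrative; models_tokens is the helper imported at the top of the file):

    embedder_config = {"model": "ollama/nomic-embed-text"}  # illustrative config

    model_name = embedder_config["model"].split("/")[-1]  # -> "nomic-embed-text"
    if model_name not in models_tokens["ollama"]:
        # equivalent to the try/except KeyError lookup in the diff
        raise KeyError("Model not supported")
    embedder_config["model"] = model_name
    # OllamaEmbeddings(**embedder_config) then builds the client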
168 | 252 |
169 | 253 |     def get_state(self, key=None) -> dict:
170 | 254 |         """
|
|