
laife.embed.sentence_transformer_embeddings

Port of the LangChain HuggingFace sentence-transformers embeddings wrapper.

The wrapper is copied into this package to avoid dependency issues with ollama.

Original file: https://github.com/langchain-ai/langchain/blob/master/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py

Classes:

SentenceTransformersEmbeddings

SentenceTransformersEmbeddings(**kwargs: Any)

Bases: BaseModel, Embeddings

HuggingFace sentence_transformers embedding models.

To use it, the sentence_transformers Python package must be installed (pip install sentence-transformers).
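Because __init__ raises ImportError when the package is missing, a caller may prefer to check for it up front. The following is a minimal sketch using only the standard library; the helper name has_sentence_transformers is hypothetical and not part of laife.

```python
import importlib.util


def has_sentence_transformers() -> bool:
    """Hypothetical helper (not part of laife): is sentence_transformers importable?

    find_spec only locates the top-level package on sys.path; it does not
    import it, so this check does not pull in torch or load any model.
    """
    return importlib.util.find_spec("sentence_transformers") is not None


if not has_sentence_transformers():
    print("Install it with: pip install sentence-transformers")
```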

Example


from laife.embed.sentence_transformer_embeddings import SentenceTransformersEmbeddings

model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': False}
hf = SentenceTransformersEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs
)

Initialize the wrapper and load the SentenceTransformer model (stored as client).

Methods:

  • embed_documents

    Compute doc embeddings using a HuggingFace transformer model.

  • embed_query

    Compute query embeddings using a HuggingFace transformer model.

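Attributes:

  • cache_folder

    Path to store models.

  • encode_kwargs

    Keyword arguments to pass when calling the encode method of the Sentence Transformer model.

  • model_kwargs

    Keyword arguments to pass to the Sentence Transformer model.

  • model_name

    Model name to use.

  • multi_process

    Whether to run encode() in a multi-process pool.

  • show_progress

    Whether to show a progress bar.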

Source code in src/laife/embed/sentence_transformer_embeddings.py
def __init__(self, **kwargs: Any) -> None:  # noqa: ANN401
    """Initialize the sentence_transformer."""
    super().__init__(**kwargs)
    try:
        import sentence_transformers  # noqa: PLC0415
    except ImportError as exc:
        msg = (
            "Could not import sentence_transformers python package. "
            "Please install it with `pip install sentence-transformers`."
        )
        raise ImportError(msg) from exc

    self.client = sentence_transformers.SentenceTransformer(
        self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
    )

cache_folder class-attribute instance-attribute

cache_folder: str | None = None

Path to store models. It can also be set via the SENTENCE_TRANSFORMERS_HOME environment variable.

encode_kwargs class-attribute instance-attribute

encode_kwargs: dict[str, Any] = Field(default_factory=dict)

Keyword arguments to pass when calling the encode method of the Sentence Transformer model, such as prompt_name, prompt, batch_size, precision, normalize_embeddings, and more. See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode
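One common entry is normalize_embeddings. With encode_kwargs={"normalize_embeddings": True}, the returned vectors have unit L2 length, so a plain dot product between two of them equals their cosine similarity. A pure-Python sketch, with toy two-dimensional vectors standing in for real embed_documents() output:

```python
import math


def dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def l2_normalize(v: list[float]) -> list[float]:
    """Scale v to unit length, which is what normalize_embeddings=True produces."""
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v]


def cosine_similarity(a: list[float], b: list[float]) -> float:
    return dot(a, b) / math.sqrt(dot(a, a) * dot(b, b))


# Toy vectors standing in for two document embeddings.
a, b = [3.0, 4.0], [4.0, 3.0]

# For unit-length vectors the dot product already is the cosine similarity.
assert math.isclose(dot(l2_normalize(a), l2_normalize(b)), cosine_similarity(a, b))
```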

model_kwargs class-attribute instance-attribute

model_kwargs: dict[str, Any] = Field(default_factory=dict)

Keyword arguments to pass to the Sentence Transformer model, such as device, prompts, default_prompt_name, revision, trust_remote_code, or token. See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer

model_name class-attribute instance-attribute

model_name: str = DEFAULT_MODEL_NAME

Model name to use.

multi_process class-attribute instance-attribute

multi_process: bool = False

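Run encode() in a sentence-transformers multi-process pool, for example to spread encoding over multiple GPUs. In this mode, encode_kwargs and show_progress are not applied.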
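With multi_process=True, every embed_documents call starts a worker pool, encodes, and then stops the pool. The sketch below uses a hypothetical FakeClient in place of SentenceTransformer (no model or GPU needed) to show that lifecycle, using try/finally so the pool is released even when encoding raises. In the real library stop_multi_process_pool is a static method; the fake uses an instance method only so it can count calls.

```python
from typing import Any


class FakeClient:
    """Hypothetical stand-in for SentenceTransformer's multi-process API."""

    def __init__(self) -> None:
        self.started = 0
        self.stopped = 0

    def start_multi_process_pool(self) -> dict[str, Any]:
        self.started += 1
        return {"pool_id": self.started}

    def encode_multi_process(self, texts: list[str], pool: dict[str, Any]) -> list[list[float]]:
        if not texts:
            raise ValueError("no texts to encode")
        # Toy "embedding": character count and word count of each text.
        return [[float(len(t)), float(len(t.split()))] for t in texts]

    def stop_multi_process_pool(self, pool: dict[str, Any]) -> None:
        # Instance method here only so the fake can count stops.
        self.stopped += 1


def encode_with_pool(client: FakeClient, texts: list[str]) -> list[list[float]]:
    """One pool per call: start it, encode, and always stop it, even if encoding raises."""
    pool = client.start_multi_process_pool()
    try:
        return client.encode_multi_process(texts, pool)
    finally:
        client.stop_multi_process_pool(pool)
```

Since a new pool is started on every call, batching many texts into a single embed_documents call is likely to amortize the pool start-up cost better than many small calls.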

show_progress class-attribute instance-attribute

show_progress: bool = False

Whether to show a progress bar.

embed_documents

embed_documents(texts: list[str]) -> list[list[float]]

Compute doc embeddings using a HuggingFace transformer model.

Parameters:

  • texts (list[str]) –

    The list of texts to embed.

Returns:

  • list[list[float]] –

    List of embeddings, one for each text.

Source code in src/laife/embed/sentence_transformer_embeddings.py
def embed_documents(self, texts: list[str]) -> list[list[float]]:
    """Compute doc embeddings using a HuggingFace transformer model.

    Args:
        texts: The list of texts to embed.

    Returns:
        List of embeddings, one for each text.
    """
    texts = [x.replace("\n", " ") for x in texts]
    if self.multi_process:
        pool = self.client.start_multi_process_pool()
        try:
            embeddings = self.client.encode_multi_process(texts, pool)
        finally:
            # sentence_transformers is imported only locally in __init__, so the
            # module name is not in scope here (it would raise NameError).
            # stop_multi_process_pool is a staticmethod, so call it via the client.
            self.client.stop_multi_process_pool(pool)
    else:
        embeddings = self.client.encode(
            texts, show_progress_bar=self.show_progress, **self.encode_kwargs
        )

    return embeddings.tolist()

embed_query

embed_query(text: str) -> list[float]

Compute query embeddings using a HuggingFace transformer model.

Parameters:

  • text (str) –

    The text to embed.

Returns:

  • list[float] –

    Embeddings for the text.

Source code in src/laife/embed/sentence_transformer_embeddings.py
def embed_query(self, text: str) -> list[float]:
    """Compute query embeddings using a HuggingFace transformer model.

    Args:
        text: The text to embed.

    Returns:
        Embeddings for the text.
    """
    return self.embed_documents([text])[0]
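Since embed_query wraps its text in a one-element list and calls embed_documents, embedding many queries one at a time costs one encode() call per query, while a single embed_documents call encodes the whole batch at once. The sketch below uses a hypothetical CountingEmbedder with the same delegation to make that visible. With a deterministic toy embedding the vectors match exactly; a real model may differ slightly in the last floating-point digits because of batching.

```python
class CountingEmbedder:
    """Hypothetical embedder with the same embed_query -> embed_documents delegation."""

    def __init__(self) -> None:
        self.encode_calls = 0

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        self.encode_calls += 1  # stands in for one model.encode() call
        # Toy embedding: character count and word count.
        return [[float(len(t)), float(len(t.split()))] for t in texts]

    def embed_query(self, text: str) -> list[float]:
        return self.embed_documents([text])[0]


queries = ["first question", "second question", "third"]

one_by_one = CountingEmbedder()
singles = [one_by_one.embed_query(q) for q in queries]

batched = CountingEmbedder()
batch = batched.embed_documents(queries)

print(one_by_one.encode_calls, batched.encode_calls)  # 3 1
```

With multi_process=True the difference is larger, because each embed_query call also starts and stops its own worker pool.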