Skip to content

Commit

Permalink
Merge pull request #5 from wdoppenberg/feat/glowrs-update
Browse files Browse the repository at this point in the history
Added normalization; updated glowrs; added pre-commit hooks
  • Loading branch information
wdoppenberg authored Apr 22, 2024
2 parents dcbff2f + ec2ad95 commit 46793b1
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 20 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,6 @@ docs/_build/

# Polars Extension
.so
.dll
.dll

notebooks/
23 changes: 23 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
repos:
- repo: local
hooks:
- id: rustfmt
name: rustfmt
entry: cargo fmt -- --check
language: system
types: [rust]
pass_filenames: false

- id: clippy
name: clippy
entry: cargo clippy
language: system
types: [rust]
pass_filenames: false

- id: cargo-check
name: cargo-check
entry: cargo check
language: system
types: [rust]
pass_filenames: false
7 changes: 3 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "polars-candle"
version = "0.1.3"
version = "0.1.4"
edition = "2021"
authors = ["Wouter Doppenberg <[email protected]>"]
build = "src/build.rs"
Expand All @@ -15,7 +15,7 @@ crate-type = ["cdylib"]

[dependencies]
anyhow = "1.0.79"
glowrs = "0.2.1"
glowrs = "0.3.0"
chrono = "0.4.35"
ndarray = "0.15.6"
polars = { version = "0.39.2", features = ["lazy", "dtype-array", "dtype-categorical", "ndarray", "log"] }
Expand Down
24 changes: 18 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@ import polars_candle # noqa: F401

df = pl.DataFrame({"s": ["This is a sentence", "This is another sentence"]})

embed_kwargs = {
"model_repo": "Snowflake/snowflake-arctic-embed-xs",
"pooling": "mean",
}

df = df.with_columns(
pl.col("s").candle.embed_text("Snowflake/snowflake-arctic-embed-xs").alias("s_embedding")
pl.col("s").candle.embed_text(**embed_kwargs).alias("s_embedding")
)
print(df)
# ┌──────────────────────────┬───────────────────────────────────┐
Expand All @@ -34,16 +39,23 @@ implementation for sentence embedding.

# Installation

Clone the repository and install the package using:
Make sure you have `polars` installed. If not, install it using `pip install polars`. Then, install `polars-candle` using:

```bash
pip install polars-candle
```

If you want to install the latest version from the repository, you can use:

```bash
pip install .
pip install git+https://github.com/wdoppenberg/polars-candle.git
```

_Note:_ PyPI package is not available yet, will be in the future.
_Note:_ You need to have the Rust toolchain installed on your system to compile the library. See
[here](https://www.rust-lang.org/tools/install) for instructions on how to install Rust.

If you're on a Mac with an ARM processor, the library will install with Metal acceleration by default.
Should you want more control over the installation, you can install the package using:
If you're on a Mac with an ARM processor, the library will compile with Metal acceleration by default.
Should you want more control over the installation, you can set build features using `maturin`:

```bash
maturin develop --release -F <feature>
Expand Down
11 changes: 9 additions & 2 deletions polars_candle/candle_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ class CandleExt:
def __init__(self, expr: pl.Expr) -> None:
self._expr = expr

def embed_text(self, model_repo: str, pooling: Literal["max", "sum", "mean"] = "mean") -> pl.Expr:
def embed_text(
self,
model_repo: str,
pooling: Literal["max", "sum", "mean"] = "mean",
normalize: bool = False,
) -> pl.Expr:
"""
Embed text using a pre-trained model.
Expand All @@ -40,6 +45,8 @@ def embed_text(self, model_repo: str, pooling: Literal["max", "sum", "mean"] = "
The repository name of the text embedding model to use. E.g. "sentence-transformers/all-MiniLM-L6-v2".
pooling
The pooling strategy to use. One of "max", "sum", or "mean".
normalize
Whether to L2-normalize the embeddings, so that every embedding vector has unit length (an L2 norm of 1).
Returns
-------
Expand All @@ -51,6 +58,6 @@ def embed_text(self, model_repo: str, pooling: Literal["max", "sum", "mean"] = "
plugin_path=Path(__file__).parent,
function_name="embed_text",
args=[self._expr],
kwargs={"model_repo": model_repo, "pooling": pooling},
kwargs={"model_repo": model_repo, "pooling": pooling, "normalize": normalize},
is_elementwise=True,
)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ maturin = "*"

[tool.poetry]
name = "polars_candle"
version = "0.1.3"
version = "0.1.4"
description = "A text embedding extension for the Polars Dataframe library."
keywords = ["polars", "dataframe", "embedding", "nlp", "candle"]
authors = ["Wouter Doppenberg <[email protected]>"]
Expand Down
7 changes: 6 additions & 1 deletion src/candle_ext/embed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ pub struct EmbeddingKwargs {

/// Pooling strategy
pub pooling: PoolingStrategy,

/// Normalize embeddings
pub normalize: bool,
}

#[polars_expr(output_type_func=array_f32_output)]
Expand All @@ -47,8 +50,10 @@ pub fn embed_text(s: &[Series], kwargs: EmbeddingKwargs) -> PolarsResult<Series>
.iter()
.filter_map(|(_, sentence)| **sentence)
.collect();

// Embed the sentences
let embeddings = model
.encode_batch(some_sentences, false)
.encode_batch(some_sentences, kwargs.normalize)
.map_err(|e| polars_err!(ComputeError: "Encoding failed with error:\n{}", e))?;

let (_, emb_dim) = embeddings.dims2().map_err(
Expand Down
26 changes: 23 additions & 3 deletions tests/test_candle.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,16 @@ def test_basic_with_none():
pl.col("s").candle.embed_text("Snowflake/snowflake-arctic-embed-xs").alias("s_embedding")
)
print(df)
assert df["s_embedding"].dtype == pl.Array

df = df.explode("s_embedding")
# Check if the None values are still there
assert df["s_embedding"].null_count() == 2

assert df["s_embedding"].dtype == pl.Float32
# Check if the None values are in the correct position
df_check = df.with_columns(
pl.col("s").is_null().alias("is_null"),
pl.col("s_embedding").is_null().alias("is_null_embedding")
)
assert df_check.select(pl.col("is_null").eq(pl.col("is_null_embedding")).all()).item()


def test_pooling():
Expand All @@ -46,3 +51,18 @@ def test_pooling():

assert df["s_embedding"].dtype == pl.Float32
assert df["s_embedding"].max() > 0.0


def test_normalize():
    """Embeddings requested with ``normalize=True`` must have an L2 norm of 1."""
    df = pl.DataFrame({"s": ["This is a sentence"]})

    df = df.with_columns(
        pl.col("s")
        .candle.embed_text("Snowflake/snowflake-arctic-embed-xs", normalize=True)
        .alias("s_embedding")
    )

    # One row per embedding component, so the column can be reduced directly.
    df = df.explode("s_embedding")

    # The squared L2 norm of a unit vector is 1. Compare with abs() so that a
    # norm that is too small (e.g. an all-zero embedding) fails as well as one
    # that is too large; a bare `x - 1 < tol` accepts every x below 1.
    squared_norm = df.select(pl.col("s_embedding").pow(2)).sum().item()
    assert abs(squared_norm - 1.0) < 1e-5

0 comments on commit 46793b1

Please sign in to comment.