
Commit 83ecc01

feat: update merge_insert to add statistics for inserted, updated, deleted rows (lancedb#2357)
Addresses lancedb#2019
1 parent e310ab4 commit 83ecc01
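
In effect, MergeInsertBuilder.execute() now reports what the merge did. A minimal sketch of the new behavior, mirroring the updated doctest in python/python/lance/dataset.py (the dataset path and column names are illustrative):

import lance
import pyarrow as pa

# Create a small dataset, then upsert into it on key column "a".
table = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]})
dataset = lance.write_dataset(table, "example")

new_table = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]})
stats = (
    dataset.merge_insert("a")
    .when_matched_update_all()
    .when_not_matched_insert_all()
    .execute(new_table)
)
# execute() previously returned nothing; with this change it returns
# merge statistics, e.g.:
# {'num_inserted_rows': 1, 'num_updated_rows': 2, 'num_deleted_rows': 0}
print(stats)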

File tree

12 files changed, +249 -104 lines changed


benchmarks/dbpedia-openai/benchmarks.py (+2 -1)

@@ -48,7 +48,8 @@ def ground_truth(
 
 def compute_recall(gt: np.ndarray, result: np.ndarray) -> float:
     recalls = [
-        np.isin(rst, gt_vector).sum() / rst.shape[0] for (rst, gt_vector) in zip(result, gt)
+        np.isin(rst, gt_vector).sum() / rst.shape[0]
+        for (rst, gt_vector) in zip(result, gt)
     ]
     return np.mean(recalls)
 

benchmarks/flat/benchmark.py (-1)

@@ -17,7 +17,6 @@
 import time
 
 import lance
-import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import pyarrow as pa

benchmarks/full_report/_lib.py (+13 -12)

@@ -6,9 +6,6 @@
 from typing import List
 
 import gzip
-import lance
-import numpy as np
-import pyarrow as pa
 import requests
 
 
@@ -33,15 +30,15 @@ def cosine(X, Y):
 def knn(
     query: np.ndarray,
     data: np.ndarray,
-    metric: Literal['L2', 'cosine'],
+    metric: Literal["L2", "cosine"],
     k: int,
 ) -> np.ndarray:
-    if metric == 'L2':
+    if metric == "L2":
         dist = l2
-    elif metric == 'cosine':
+    elif metric == "cosine":
         dist = cosine
     else:
-        raise ValueError('Invalid metric')
+        raise ValueError("Invalid metric")
     return np.argpartition(dist(query, data), k, axis=1)[:, 0:k]
 
 
@@ -51,10 +48,12 @@ def write_lance(
 ):
     dims = data.shape[1]
 
-    schema = pa.schema([
-        pa.field("vec", pa.list_(pa.float32(), dims)),
-        pa.field("id", pa.uint32(), False),
-    ])
+    schema = pa.schema(
+        [
+            pa.field("vec", pa.list_(pa.float32(), dims)),
+            pa.field("id", pa.uint32(), False),
+        ]
+    )
 
     fsl = pa.FixedSizeListArray.from_arrays(
         pa.array(data.reshape(-1).astype(np.float32), type=pa.float32()),
@@ -65,6 +64,7 @@ def write_lance(
 
     lance.write_dataset(t, path)
 
+
 # NYT
 
 _DATA_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/bag-of-words/docword.nytimes.txt.gz"
@@ -112,7 +112,8 @@ def _get_nyt_vectors(
     tfidf = TfidfTransformer().fit_transform(freq)
     print("computing dense projection")
     dense_projection = random_projection.GaussianRandomProjection(
-        n_components=output_dims, random_state=42,
+        n_components=output_dims,
+        random_state=42,
     ).fit_transform(tfidf)
     dense_projection = dense_projection.astype(np.float32)
     np.save(_CACHE_PATH, dense_projection)
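
For context, the knn helper touched above picks the k nearest rows with np.argpartition, so the returned indices are the top-k but not sorted by distance. A hypothetical call, assuming _lib.py's l2/cosine helpers return an (n_queries, n_rows) distance matrix:

import numpy as np

# 5 query vectors against 100 database vectors, 16 dimensions each.
data = np.random.rand(100, 16).astype(np.float32)
query = np.random.rand(5, 16).astype(np.float32)

# Indices of the 10 nearest rows per query (unordered within the top 10).
neighbors = knn(query, data, metric="L2", k=10)
print(neighbors.shape)  # (5, 10)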

benchmarks/sift/index.py (-1)

@@ -20,7 +20,6 @@
 from subprocess import check_output
 
 import lance
-import pyarrow as pa
 
 
 def main():

benchmarks/tpch/benchmark.py (+8 -9)

@@ -1,7 +1,5 @@
 # Benchmark performance Lance vs Parquet w/ Tpch Q1 and Q6
 import lance
-import pandas as pd
-import pyarrow as pa
 import duckdb
 
 import sys
@@ -46,10 +44,10 @@
 num_args = len(sys.argv)
 assert num_args == 2
 
-query = ''
-if sys.argv[1] == 'q1':
+query = ""
+if sys.argv[1] == "q1":
     query = Q1
-elif sys.argv[1] == 'q6':
+elif sys.argv[1] == "q6":
     query = Q6
 else:
     sys.exit("We only support Q1 and Q6 for now")
@@ -62,17 +60,18 @@
 res1 = duckdb.sql(query).df()
 end1 = time.time()
 
-print("Lance Latency: ",str(round(end1 - start1, 3)) + 's')
+print("Lance Latency: ", str(round(end1 - start1, 3)) + "s")
 print(res1)
 
 ##### Parquet #####
 lineitem = None
 start2 = time.time()
 # read from parquet and create a view instead of table from it
-duckdb.sql("CREATE VIEW lineitem AS SELECT * FROM read_parquet('./dataset/lineitem_sf1.parquet');")
+duckdb.sql(
+    "CREATE VIEW lineitem AS SELECT * FROM read_parquet('./dataset/lineitem_sf1.parquet');"
+)
 res2 = duckdb.sql(query).df()
 end2 = time.time()
 
-print("Parquet Latency: ",str(round(end2 - start2, 3)) + 's')
+print("Parquet Latency: ", str(round(end2 - start2, 3)) + "s")
 print(res2)
-
docs/conf.py (-1)

@@ -1,7 +1,6 @@
 # Configuration file for the Sphinx documentation builder.
 
 import shutil
-import subprocess
 
 
 def run_apidoc(_):

docs/examples/gcs_example.py (+10 -6)

@@ -1,25 +1,29 @@
-#
+#
 # Lance example loading a dataset from Google Cloud Storage
 #
 # You need to set one of the following environment variables in order to authenticate with GS
 # - GOOGLE_SERVICE_ACCOUNT: location of service account file
 # - GOOGLE_SERVICE_ACCOUNT_KEY: JSON serialized service account key
 #
-# Follow this doc in order to create an service key: https://cloud.google.com/iam/docs/keys-create-delete
+# Follow this doc in order to create an service key: https://cloud.google.com/iam/docs/keys-create-delete
 #
 
 import lance
+import pandas as pd
 
 ds = lance.dataset("gs://eto-public/datasets/oxford_pet/oxford_pet.lance")
 count = ds.count_rows()
 print(f"There are {count} pets")
 
 # You can also write to GCS
-import pandas as pd
+
 uri = "gs://eto-public/datasets/oxford_pet/example.lance"
-lance.write_dataset(pd.DataFrame({"a": pd.array([10], dtype="Int32")}), uri, mode='create')
+lance.write_dataset(
+    pd.DataFrame({"a": pd.array([10], dtype="Int32")}), uri, mode="create"
+)
 assert lance.dataset(uri).version == 1
 
-lance.write_dataset(pd.DataFrame({"a": pd.array([5], dtype="Int32")}), uri, mode='append')
+lance.write_dataset(
+    pd.DataFrame({"a": pd.array([5], dtype="Int32")}), uri, mode="append"
+)
 assert lance.dataset(uri).version == 2
-
python/python/lance/dataset.py (+10 -6)

@@ -83,7 +83,9 @@ class MergeInsertBuilder(_MergeInsertBuilder):
     def execute(self, data_obj: ReaderLike, *, schema: Optional[pa.Schema] = None):
         """Executes the merge insert operation
 
-        There is no return value but the original dataset will be updated.
+        This function updates the original dataset and returns a dictionary with
+        information about merge statistics - i.e. the number of inserted, updated,
+        and deleted rows.
 
         Parameters
         ----------
@@ -97,7 +99,8 @@ def execute(self, data_obj: ReaderLike, *, schema: Optional[pa.Schema] = None):
             source is some kind of generator.
         """
         reader = _coerce_reader(data_obj, schema)
-        super(MergeInsertBuilder, self).execute(reader)
+
+        return super(MergeInsertBuilder, self).execute(reader)
 
     # These next three overrides exist only to document the methods
 
@@ -945,10 +948,11 @@ def merge_insert(
         >>> dataset = lance.write_dataset(table, "example")
         >>> new_table = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]})
         >>> # Perform a "upsert" operation
-        >>> dataset.merge_insert("a") \\
-        ...     .when_matched_update_all() \\
-        ...     .when_not_matched_insert_all() \\
-        ...     .execute(new_table)
+        >>> dataset.merge_insert("a") \\
+        ...     .when_matched_update_all() \\
+        ...     .when_not_matched_insert_all() \\
+        ...     .execute(new_table)
+        {'num_inserted_rows': 1, 'num_updated_rows': 2, 'num_deleted_rows': 0}
         >>> dataset.to_table().sort_by("a").to_pandas()
            a  b
         0  1  b
