6
6
from typing import List
7
7
8
8
import gzip
9
- import lance
10
- import numpy as np
11
- import pyarrow as pa
12
9
import requests
13
10
14
11
@@ -33,15 +30,15 @@ def cosine(X, Y):
33
30
def knn (
34
31
query : np .ndarray ,
35
32
data : np .ndarray ,
36
- metric : Literal ['L2' , ' cosine' ],
33
+ metric : Literal ["L2" , " cosine" ],
37
34
k : int ,
38
35
) -> np .ndarray :
39
- if metric == 'L2' :
36
+ if metric == "L2" :
40
37
dist = l2
41
- elif metric == ' cosine' :
38
+ elif metric == " cosine" :
42
39
dist = cosine
43
40
else :
44
- raise ValueError (' Invalid metric' )
41
+ raise ValueError (" Invalid metric" )
45
42
return np .argpartition (dist (query , data ), k , axis = 1 )[:, 0 :k ]
46
43
47
44
@@ -51,10 +48,12 @@ def write_lance(
51
48
):
52
49
dims = data .shape [1 ]
53
50
54
- schema = pa .schema ([
55
- pa .field ("vec" , pa .list_ (pa .float32 (), dims )),
56
- pa .field ("id" , pa .uint32 (), False ),
57
- ])
51
+ schema = pa .schema (
52
+ [
53
+ pa .field ("vec" , pa .list_ (pa .float32 (), dims )),
54
+ pa .field ("id" , pa .uint32 (), False ),
55
+ ]
56
+ )
58
57
59
58
fsl = pa .FixedSizeListArray .from_arrays (
60
59
pa .array (data .reshape (- 1 ).astype (np .float32 ), type = pa .float32 ()),
@@ -65,6 +64,7 @@ def write_lance(
65
64
66
65
lance .write_dataset (t , path )
67
66
67
+
68
68
# NYT
69
69
70
70
_DATA_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/bag-of-words/docword.nytimes.txt.gz"
@@ -112,7 +112,8 @@ def _get_nyt_vectors(
112
112
tfidf = TfidfTransformer ().fit_transform (freq )
113
113
print ("computing dense projection" )
114
114
dense_projection = random_projection .GaussianRandomProjection (
115
- n_components = output_dims , random_state = 42 ,
115
+ n_components = output_dims ,
116
+ random_state = 42 ,
116
117
).fit_transform (tfidf )
117
118
dense_projection = dense_projection .astype (np .float32 )
118
119
np .save (_CACHE_PATH , dense_projection )
0 commit comments