Skip to content

Commit 766e505

Browse files
authored
Merge pull request #271 from leeeizhang/lei/memory-cli
[MRG] add mle memory crud api
2 parents b9765e7 + bd5fe58 commit 766e505

File tree

2 files changed

+113
-15
lines changed

2 files changed

+113
-15
lines changed

mle/cli.py

+67
Original file line numberDiff line numberDiff line change
@@ -353,3 +353,70 @@ def integrate(reset):
353353
"token": pickle.dumps(token, fix_imports=False),
354354
}
355355
write_config(config)
356+
357+
358+
@cli.command()
359+
@click.option('--add', default=None, help='Add files or directories into the local memory.')
360+
@click.option('--rm', default=None, help='Remove files or directories into the local memory.')
361+
@click.option('--update', default=None, help='Update files or directories into the local memory.')
362+
def memory(add, rm, update):
363+
memory = LanceDBMemory(os.getcwd())
364+
path = add or rm or update
365+
if path is None:
366+
return
367+
368+
source_files = []
369+
if os.path.isdir(path):
370+
source_files = list_files(path, ['*.py'])
371+
else:
372+
source_files = [path]
373+
374+
working_dir = os.getcwd()
375+
table_name = 'mle_chat_' + working_dir.split('/')[-1]
376+
chunker = CodeChunker(os.path.join(working_dir, '.mle', 'cache'), 'py')
377+
with Progress(
378+
SpinnerColumn(),
379+
TextColumn("[progress.description]{task.description}"),
380+
BarColumn(),
381+
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
382+
TimeElapsedColumn(),
383+
console=console,
384+
) as progress:
385+
process_task = progress.add_task("Processing files...", total=len(source_files))
386+
387+
for file_path in source_files:
388+
raw_code = read_file(file_path)
389+
progress.update(
390+
process_task,
391+
advance=1,
392+
description=f"Process {os.path.basename(file_path)} for memory..."
393+
)
394+
395+
if add:
396+
# add file into memory
397+
chunks = chunker.chunk(raw_code, token_limit=100)
398+
memory.add(
399+
texts=list(chunks.values()),
400+
table_name=table_name,
401+
metadata=[{'file': file_path, 'chunk_key': k} for k, _ in chunks.items()]
402+
)
403+
elif rm:
404+
# remove file from memory
405+
memory.delete_by_metadata(
406+
key="file",
407+
value=file_path,
408+
table_name=table_name,
409+
)
410+
elif update:
411+
# update file into memory
412+
chunks = chunker.chunk(raw_code, token_limit=100)
413+
memory.delete_by_metadata(
414+
key="file",
415+
value=file_path,
416+
table_name=table_name,
417+
)
418+
memory.add(
419+
texts=list(chunks.values()),
420+
table_name=table_name,
421+
metadata=[{'file': file_path, 'chunk_key': k} for k, _ in chunks.items()]
422+
)

mle/utils/memory.py

+46-15
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,19 @@ def __init__(self, project_path: str):
2525
else:
2626
raise NotImplementedError
2727

28+
def _open_table(self, table_name: str = None):
29+
"""
30+
Open a LanceDB table by table name. (Return None if not exists)
31+
Args:
32+
table_name (Optional[str]): The name of the table. Defaults to self.table_name.
33+
"""
34+
table_name = table_name or self.table_name
35+
try:
36+
table = self.client.open_table(table_name)
37+
except FileNotFoundError:
38+
return None
39+
return table
40+
2841
def add(
2942
self,
3043
texts: List[str],
@@ -73,7 +86,7 @@ def add(
7386
table = self.client.create_table(table_name, data=data)
7487
table.create_fts_index("id")
7588
else:
76-
self.client.open_table(table_name).add(data=data)
89+
self._open_table(table_name).add(data=data)
7790

7891
return ids
7992

@@ -90,8 +103,10 @@ def query(self, query_texts: List[str], table_name: Optional[str] = None, n_resu
90103
List[List[dict]]: A list of results for each query text, each result being a dictionary with
91104
keys such as "vector", "text", and "id".
92105
"""
93-
table_name = table_name or self.table_name
94-
table = self.client.open_table(table_name)
106+
table = self._open_table(table_name)
107+
if table is None:
108+
return []
109+
95110
query_embeds = self.text_embedding.compute_source_embeddings(query_texts)
96111

97112
results = [table.search(query).limit(n_results).to_list() for query in query_embeds]
@@ -107,8 +122,10 @@ def list_all_keys(self, table_name: Optional[str] = None):
107122
Returns:
108123
List[str]: A list of all IDs in the table.
109124
"""
110-
table_name = table_name or self.table_name
111-
table = self.client.open_table(table_name)
125+
table = self._open_table(table_name)
126+
if table is None:
127+
return []
128+
112129
return [item["id"] for item in table.search(query_type="fts").to_list()]
113130

114131
def get(self, record_id: str, table_name: Optional[str] = None):
@@ -122,8 +139,10 @@ def get(self, record_id: str, table_name: Optional[str] = None):
122139
Returns:
123140
List[dict]: A list containing the matching record, or an empty list if not found.
124141
"""
125-
table_name = table_name or self.table_name
126-
table = self.client.open_table(table_name)
142+
table = self._open_table(table_name)
143+
if table is None:
144+
return []
145+
127146
return table.search(query_type="fts") \
128147
.where(f"id = '{record_id}'") \
129148
.limit(1).to_list()
@@ -141,8 +160,10 @@ def get_by_metadata(self, key: str, value: str, table_name: Optional[str] = None
141160
Returns:
142161
List[dict]: A list of records matching the metadata criteria.
143162
"""
144-
table_name = table_name or self.table_name
145-
table = self.client.open_table(table_name)
163+
table = self._open_table(table_name)
164+
if table is None:
165+
return []
166+
146167
return table.search(query_type="fts") \
147168
.where(f"metadata.{key} = '{value}'") \
148169
.limit(n_results).to_list()
@@ -158,8 +179,10 @@ def delete(self, record_id: str, table_name: Optional[str] = None) -> bool:
158179
Returns:
159180
bool: True if the deletion was successful, False otherwise.
160181
"""
161-
table_name = table_name or self.table_name
162-
table = self.client.open_table(table_name)
182+
table = self._open_table(table_name)
183+
if table is None:
184+
return True
185+
163186
return table.delete(f"id = '{record_id}'")
164187

165188
def delete_by_metadata(self, key: str, value: str, table_name: Optional[str] = None):
@@ -174,8 +197,10 @@ def delete_by_metadata(self, key: str, value: str, table_name: Optional[str] = N
174197
Returns:
175198
bool: True if deletion was successful, False otherwise.
176199
"""
177-
table_name = table_name or self.table_name
178-
table = self.client.open_table(table_name)
200+
table = self._open_table(table_name)
201+
if table is None:
202+
return True
203+
179204
return table.delete(f"metadata.{key} = '{value}'")
180205

181206
def drop(self, table_name: Optional[str] = None) -> bool:
@@ -189,6 +214,10 @@ def drop(self, table_name: Optional[str] = None) -> bool:
189214
bool: True if the table was successfully dropped, False otherwise.
190215
"""
191216
table_name = table_name or self.table_name
217+
table = self._open_table(table_name)
218+
if table is None:
219+
return True
220+
192221
return self.client.drop_table(table_name)
193222

194223
def count(self, table_name: Optional[str] = None) -> int:
@@ -201,8 +230,10 @@ def count(self, table_name: Optional[str] = None) -> int:
201230
Returns:
202231
int: The number of records in the table.
203232
"""
204-
table_name = table_name or self.table_name
205-
table = self.client.open_table(table_name)
233+
table = self._open_table(table_name)
234+
if table is None:
235+
return 0
236+
206237
return table.count_rows()
207238

208239
def reset(self) -> None:

0 commit comments

Comments
 (0)