Skip to content

Commit

Permalink
Fix bvecs import use int8 rather than uint8
Browse files Browse the repository at this point in the history
  • Loading branch information
Ami11111 committed Aug 2, 2024
1 parent 692815d commit a69395d
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions src/executor/operator/physical_import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,8 @@ void PhysicalImport::ImportBVECS(QueryContext *query_context, ImportOperatorStat
RecoverableError(status);
}
auto embedding_info = static_cast<EmbeddingInfo *>(column_type->type_info().get());
if (embedding_info->Type() != kElemInt8) {
Status status = Status::ImportFileFormatError("BVECS file must have only one embedding column with int8 element.");
if (embedding_info->Type() != kElemUInt8) {
Status status = Status::ImportFileFormatError("BVECS file must have only one embedding column with uint8 element.");
RecoverableError(status);
}

Expand Down Expand Up @@ -264,7 +264,7 @@ void PhysicalImport::ImportBVECS(QueryContext *query_context, ImportOperatorStat
RecoverableError(status);
}
SizeT file_size = fs.GetFileSize(*file_handler);
SizeT row_size = dimension * sizeof(i8) + sizeof(dimension);
SizeT row_size = dimension * sizeof(u8) + sizeof(dimension);
if (file_size % row_size != 0) {
String error_message = "Weird file size.";
UnrecoverableError(error_message);
Expand All @@ -281,7 +281,7 @@ void PhysicalImport::ImportBVECS(QueryContext *query_context, ImportOperatorStat
SizeT row_idx = 0;
auto buf_ptr = static_cast<ptr_t>(buffer_handle.GetDataMut());

UniquePtr<i8[]> i8_buffer = MakeUniqueForOverwrite<i8[]>(sizeof(i8) * dimension);
UniquePtr<u8[]> u8_buffer = MakeUniqueForOverwrite<u8[]>(sizeof(u8) * dimension);
while (true) {
i32 dim;
nbytes = fs.Read(*file_handler, &dim, sizeof(dimension));
Expand All @@ -290,12 +290,12 @@ void PhysicalImport::ImportBVECS(QueryContext *query_context, ImportOperatorStat
Status::ImportFileFormatError(fmt::format("Dimension in file ({}) doesn't match with table definition ({}).", dim, dimension));
RecoverableError(status);
}
fs.Read(*file_handler, i8_buffer.get(), sizeof(i8) * dimension);
fs.Read(*file_handler, u8_buffer.get(), sizeof(u8) * dimension);

i8 *dst_ptr = reinterpret_cast<i8 *>(buf_ptr + block_entry->row_count() * sizeof(i8) * dimension);
u8 *dst_ptr = reinterpret_cast<u8 *>(buf_ptr + block_entry->row_count() * sizeof(u8) * dimension);
for (i32 i = 0; i < dimension; ++i) {
i8 value = (i8_buffer.get())[i];
dst_ptr[i] = static_cast<i8>(value);
u8 value = (u8_buffer.get())[i];
dst_ptr[i] = static_cast<u8>(value);
}

block_entry->IncreaseRowCount(1);
Expand Down

0 comments on commit a69395d

Please sign in to comment.