Skip to content

Commit

Permalink
md5
Browse files Browse the repository at this point in the history
  • Loading branch information
Ami11111 committed Oct 15, 2024
1 parent ea06ce9 commit db0e6b2
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/function/builtin_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import pow;
import substring;
import substract;
import char_length;
import md5;
import default_values;
import special_function;
import internal_types;
Expand Down Expand Up @@ -104,7 +105,6 @@ void BuiltinFunctions::RegisterScalarFunction() {
RegisterLessEqualsFunction(catalog_ptr_);
RegisterGreaterFunction(catalog_ptr_);
RegisterGreaterEqualsFunction(catalog_ptr_);
RegisterCharLengthFunction(catalog_ptr_);

// like function
RegisterLikeFunction(catalog_ptr_);
Expand All @@ -115,6 +115,8 @@ void BuiltinFunctions::RegisterScalarFunction() {

// string functions
RegisterSubstringFunction(catalog_ptr_);
RegisterCharLengthFunction(catalog_ptr_);
RegisterMd5Function(catalog_ptr_);
}

void BuiltinFunctions::RegisterTableFunction() {}
Expand Down
69 changes: 69 additions & 0 deletions src/function/scalar/md5.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
module;

#include <openssl/md5.h>

module md5;

import stl;
import catalog;
import status;
import infinity_exception;
import scalar_function;
import scalar_function_set;

import third_party;
import logical_type;
import internal_types;
import data_type;
import logger;
import column_vector;

namespace infinity {

struct Md5Function {
template <typename TA, typename TB>
static inline void Run(TA &left, TB &result) {
Status status = Status::NotSupport("Not implemented");
RecoverableError(status);
}
};

template <>
inline void Md5Function::Run(VarcharT &left, VarcharT &result) {
unsigned char digest[MD5_DIGEST_LENGTH];
const char *input = nullptr;
SizeT input_len = 0;
//GetReaderValue(left, input, input_len);
MD5(reinterpret_cast<const unsigned char *>(input), input_len, digest);

}


struct ColumnValueReaderMd5Function {
template <typename TA, typename TB>
static inline void Run(TA &left, TB &result) {
unsigned char digest[MD5_DIGEST_LENGTH];
const char *input = nullptr;
SizeT input_len = 0;
GetReaderValue(left, input, input_len);
MD5(reinterpret_cast<const unsigned char *>(input), input_len, digest);
SetReaderValue(result, reinterpret_cast<const char*>(digest), MD5_DIGEST_LENGTH);
}
};


void RegisterMd5Function(const UniquePtr<Catalog> &catalog_ptr){
String func_name = "md5";

SharedPtr<ScalarFunctionSet> function_set_ptr = MakeShared<ScalarFunctionSet>(func_name);

ScalarFunction md5_function(func_name,
{DataType(LogicalType::kVarchar)},
{DataType(LogicalType::kVarchar)},
&ScalarFunction::UnaryFunction<VarcharT, VarcharT, ColumnValueReaderMd5Function>);
function_set_ptr->AddFunction(md5_function);

Catalog::AddFunctionSet(catalog_ptr.get(), function_set_ptr);
}

}
13 changes: 13 additions & 0 deletions src/function/scalar/md5.cppm
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module;

import stl;

export module md5;

namespace infinity {

class Catalog;

export void RegisterMd5Function(const UniquePtr<Catalog> &catalog_ptr);

}
11 changes: 11 additions & 0 deletions src/storage/column_vector/column_vector.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -1313,6 +1313,17 @@ public:
return left_v.size() == right_v.size() && std::strncmp(left_v.data(), right_v.data(), left_v.size()) == 0;
}

friend void GetReaderValue(const IteratorType &left, const char* &dst, SizeT &dst_len) {
Span<const char> left_v = left.col_->GetVarchar(left.idx_);
dst = left_v.data();
dst_len = left_v.size();
}

friend void SetReaderValue(IteratorType &left, const char* dst, int dst_len) {
Span<const char> span(dst, dst_len);
left.col_->AppendVarchar(span);
}

private:
const VarcharT *data_ptr_ = nullptr;
// VectorBuffer *vec_buffer_ = nullptr;
Expand Down
74 changes: 74 additions & 0 deletions src/storage/column_vector/operator/unary_operator.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

module;

#include <concepts>
#include <type_traits>

export module unary_operator;
Expand All @@ -29,8 +30,81 @@ namespace infinity {

struct ColumnVectorCastData;

template <typename InputType, typename Operator>
class VarcharResultUnaryOperator {
private:
static void inline VarcharResultExecuteFlatWithNull(const SharedPtr<ColumnVector> &input,
SharedPtr<ColumnVector> &result,
SizeT count,
void *state_ptr) {
const auto &input_null = input->nulls_ptr_;
auto &result_null = result->nulls_ptr_;
*result_null = *input_null;
ColumnValueReader<InputType> input_ptr(input);
ColumnValueReader<VarcharT> result_ptr(result);
result_null->RoaringBitmapApplyFunc([&](u32 row_index) -> bool {
if (row_index >= count) {
return false;
}
Operator::template Execute(input_ptr[row_index], result_ptr[row_index], result_null.get(), row_index, state_ptr);
return row_index + 1 < count;
});
}

public:
static void inline Execute(const SharedPtr<ColumnVector> &input,
SharedPtr<ColumnVector> &result,
SizeT count,
void *state_ptr,
bool nullable){
const SharedPtr<Bitmask> &input_null = input->nulls_ptr_;
SharedPtr<Bitmask> &result_null = result->nulls_ptr_;

switch (input->vector_type()) {
case ColumnVectorType::kFlat: {
if (nullable) {
VarcharResultExecuteFlatWithNull(input, result, count, state_ptr);

} else {
auto input_ptr = ColumnValueReader<InputType>(input);
auto result_ptr = ColumnValueReader<VarcharT>(result);
for (SizeT i = 0; i < count; ++i) {
Operator::template Execute(input_ptr[i], result_ptr[i], result_null.get(), 0, state_ptr);
}
}
// Result tail_index need to update.
result->Finalize(count);
return;
}
case ColumnVectorType::kConstant: {
if (count != 1) {
String error_message = "Attempting to execute more than one row of the constant column vector.";
UnrecoverableError(error_message);
}
if (nullable && !(input_null->IsAllTrue())) {
result_null->SetFalse(0);
} else {
result_null->SetAllTrue();
auto input_ptr = ColumnValueReader<InputType>(input);
auto result_ptr = ColumnValueReader<VarcharT>(result);
Operator::template Execute(input_ptr[0], result_ptr[0], result_null.get(), 0, state_ptr);
}
result->Finalize(1);
return;
}
default:
String error_message = "Invalid input ColumnVectorType. Support only kFlat and kConstant.";
UnrecoverableError(error_message);
}
}
};
export class UnaryOperator {
public:
template <std::same_as<VarcharT> InputType, std::same_as<VarcharT> ResultType, typename Operator>
static void inline Execute(const SharedPtr<ColumnVector> &input, SharedPtr<ColumnVector> &result, SizeT count, void *state_ptr, bool nullable) {
return VarcharResultUnaryOperator<InputType, Operator>::Execute(input, result, count, state_ptr, nullable);
}

template <typename InputType, typename ResultType, typename Operator>
static void inline Execute(const SharedPtr<ColumnVector> &input, SharedPtr<ColumnVector> &result, SizeT count, void *state_ptr, bool nullable) {
const auto *input_ptr = (const InputType *)(input->data());
Expand Down
12 changes: 12 additions & 0 deletions test/sql/dql/type/varchar.slt
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,17 @@ SELECT * FROM test_varchar_filter where char_length(c1) = 3;
----
abc abcd 5

query VIII
SELECT * FROM test_varchar_filter where md5(c1) = md5(c2);
----
abcddddd abcddddd 1
abcdddde abcdddde 4

query VIIII
SELECT * FROM test_varchar_filter where md5(c1) = md5('abcdddde');
----
abcdddde abcddddd 3
abcdddde abcdddde 4

statement ok
DROP TABLE test_varchar_filter;

0 comments on commit db0e6b2

Please sign in to comment.