Skip to content

WIP WASM UDF branch #32561

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
734 changes: 719 additions & 15 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions misc/helm-charts/operator/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -271,10 +271,10 @@ telemetry:
# Network policies configuration
networkPolicies:
# -- Whether to enable network policies for securing communication between pods
enabled: false
enabled: true
# -- Whether to enable internal communication between Materialize pods
internal:
enabled: false
enabled: true
# -- Whether to enable ingress to the SQL and HTTP interfaces
# on environmentd or balancerd
ingress:
Expand Down
1 change: 1 addition & 0 deletions src/expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ uncased = "0.9.7"
uuid = { version = "1.16.0", features = ["v5"] }
proptest = { version = "1.6.0", default-features = false, features = ["std"] }
proptest-derive = { version = "0.5.1", features = ["boxed_union"] }
wasmtime = "32.0.0"
workspace-hack = { version = "0.0.0", path = "../workspace-hack", optional = true }
smallvec = { version = "1.15.0" }

Expand Down
2 changes: 2 additions & 0 deletions src/expr/src/scalar.proto
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ message ProtoUnaryFunc {
}
reserved 5, 6, 15, 104, 111, 115, 212, 306, 313, 321;
oneof kind {
string unary_wasm = 123456;
google.protobuf.Empty not = 1;
google.protobuf.Empty is_null = 2;
google.protobuf.Empty is_true = 121;
Expand Down Expand Up @@ -670,6 +671,7 @@ message ProtoBinaryFunc {
bool array_contains_array = 194;
google.protobuf.Empty starts_with = 195;
google.protobuf.Empty get_bit = 196;
google.protobuf.Empty wasm = 200;
}
}

Expand Down
6 changes: 6 additions & 0 deletions src/expr/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,12 @@ impl MirScalarExpr {
Err(err.clone()),
e.typ(column_types).scalar_type,
);
} else if let BinaryFunc::Wasm = func {
if expr2.is_literal() {
// The module text is a literal, so switch to the unary form, which builds the WASM state once and caches it.
let text = expr2.as_literal_str().unwrap().to_owned();
*e = expr1.take().call_unary(UnaryFunc::UnaryWasm2(crate::func::UnaryWasm2 { text, data: Default::default() }));
}
} else if let BinaryFunc::IsLikeMatch { case_insensitive } = func {
if expr2.is_literal() {
// We can at least precompile the regex.
Expand Down
123 changes: 120 additions & 3 deletions src/expr/src/scalar/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3320,6 +3320,7 @@ fn starts_with<'a>(a: Datum<'a>, b: Datum<'a>) -> Datum<'a> {

#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash, MzReflect)]
pub enum BinaryFunc {
Wasm,
AddInt16,
AddInt32,
AddInt64,
Expand Down Expand Up @@ -3516,6 +3517,111 @@ pub enum BinaryFunc {
StartsWith,
}

fn wasm<'a>(a: Datum<'a>, b: Datum<'a>) -> Result<Datum<'a>, EvalError> {
let mut wasm_state = WasmState::new(b.unwrap_str(), "logic");
Ok(Datum::Int64(wasm_state.eval(a.unwrap_int64())))
}

pub struct WasmState {
engine: wasmtime::Engine,
module: wasmtime::Module,
store: wasmtime::Store<()>,
instance: wasmtime::Instance,
func: wasmtime::TypedFunc<i64, i64>,
}

impl WasmState {
fn eval(&mut self, input: i64) -> i64 {
self.func.call(&mut self.store, input).unwrap()
}

fn new(text: &str, name: &str) -> Self {
use wasmtime::*;
let engine = Engine::default();
let module = Module::new(&engine, text).unwrap();
let mut store = Store::new(&engine, ());
let instance = Instance::new(&mut store, &module, &[]).unwrap();
let func = instance
.get_typed_func::<i64, i64>(&mut store, name)
.unwrap();

Self { engine, module, store, instance, func }
}
}

#[derive(Serialize, Deserialize)]
pub struct UnaryWasm2 {
pub text: String,
#[serde(skip)]
pub data: std::sync::Mutex<Option<WasmState>>,
}

impl<'a> EagerUnaryFunc<'a> for UnaryWasm2 {
type Input = i64;
type Output = Result<i64, EvalError>;

fn call(&self, a: i64) -> Result<i64, EvalError> {
let mut lock = self.data.lock().unwrap();
if lock.is_none() {
*lock = Some(WasmState::new(&self.text, "logic"));
}

if let Some(wasm_state) = &mut *lock {
Ok(wasm_state.eval(a))
} else {
unimplemented!()
}
}

fn output_type(&self, input: ColumnType) -> ColumnType {
ScalarType::Int64.nullable(input.nullable)
}
}

/// Total order over `UnaryWasm2` defined solely by the module source text;
/// the `data` field is not consulted.
impl Ord for UnaryWasm2 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let (lhs, rhs) = (&self.text, &other.text);
        lhs.cmp(rhs)
    }
}
/// Partial order that always agrees with the total order from `Ord`.
impl PartialOrd for UnaryWasm2 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
/// Two `UnaryWasm2` values are equal when their module source text is equal;
/// the `data` field is not compared.
impl PartialEq for UnaryWasm2 {
    fn eq(&self, other: &Self) -> bool {
        self.text == other.text
    }
}
/// Marker impl: equality on `UnaryWasm2` is a full equivalence relation.
impl Eq for UnaryWasm2 {}
/// Cloning copies only the module source text. The clone's `data` is set to
/// its default value rather than being copied from the original.
impl Clone for UnaryWasm2 {
    fn clone(&self) -> Self {
        let text = self.text.clone();
        Self {
            text,
            data: Default::default(),
        }
    }
}


/// User-facing rendering: a fixed placeholder that omits the module source.
impl fmt::Display for UnaryWasm2 {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("wasm(...)")
    }
}
/// Diagnostic rendering: the same fixed placeholder as `Display`, so neither
/// the module source nor the `data` field is printed.
impl fmt::Debug for UnaryWasm2 {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("wasm(...)")
    }
}
/// Hashes only the module source text, consistent with `PartialEq`.
impl std::hash::Hash for UnaryWasm2 {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&self.text, state);
    }
}
// `MzReflect` implementation that registers nothing for `UnaryWasm2`.
// NOTE(review): presumably this means the type cannot be constructed from
// test specs via `mz_lowertest` — confirm whether that is intended.
impl MzReflect for UnaryWasm2 {
// Intentionally empty: no type information is added to `_rti`.
fn add_to_reflected_type_info(_rti: &mut mz_lowertest::ReflectedTypeInfo) {
}
}


impl BinaryFunc {
pub fn eval<'a>(
&'a self,
Expand All @@ -3530,6 +3636,7 @@ impl BinaryFunc {
return Ok(Datum::Null);
}
match self {
BinaryFunc::Wasm => wasm(a, b),
BinaryFunc::AddInt16 => add_int16(a, b),
BinaryFunc::AddInt32 => add_int32(a, b),
BinaryFunc::AddInt64 => add_int64(a, b),
Expand Down Expand Up @@ -3823,6 +3930,7 @@ impl BinaryFunc {
| EncodedBytesCharLength
| SubDate => ScalarType::Int32.nullable(in_nullable),

Wasm |
AddInt64 | SubInt64 | MulInt64 | DivInt64 | ModInt64 | BitAndInt64 | BitOrInt64
| BitXorInt64 | BitShiftLeftInt64 | BitShiftRightInt64 => {
ScalarType::Int64.nullable(in_nullable)
Expand Down Expand Up @@ -4004,7 +4112,8 @@ impl BinaryFunc {
pub fn introduces_nulls(&self) -> bool {
use BinaryFunc::*;
match self {
AddInt16
Wasm
| AddInt16
| AddInt32
| AddInt64
| AddUInt16
Expand Down Expand Up @@ -4343,7 +4452,8 @@ impl BinaryFunc {
| RangeUnion
| RangeIntersection
| RangeDifference => true,
ToCharTimestamp
Wasm
| ToCharTimestamp
| ToCharTimestampTz
| AgeTimestamp
| AgeTimestampTz
Expand Down Expand Up @@ -4715,6 +4825,7 @@ impl BinaryFunc {
BinaryFunc::PrettySql => (false, false),
BinaryFunc::RegexpReplace { .. } => (false, false),
BinaryFunc::StartsWith => (false, false),
BinaryFunc::Wasm => (false, false),
}
}
}
Expand Down Expand Up @@ -4933,6 +5044,7 @@ impl fmt::Display for BinaryFunc {
limit
),
BinaryFunc::StartsWith => f.write_str("starts_with"),
BinaryFunc::Wasm => f.write_str("wasm"),
}
}
}
Expand Down Expand Up @@ -5356,6 +5468,7 @@ impl RustType<ProtoBinaryFunc> for BinaryFunc {
})
}
BinaryFunc::StartsWith => StartsWith(()),
BinaryFunc::Wasm => Wasm(()),
};
ProtoBinaryFunc { kind: Some(kind) }
}
Expand Down Expand Up @@ -5570,6 +5683,7 @@ impl RustType<ProtoBinaryFunc> for BinaryFunc {
limit: inner.limit.into_rust()?,
}),
StartsWith(()) => Ok(BinaryFunc::StartsWith),
Wasm(()) => Ok(BinaryFunc::Wasm),
}
} else {
Err(TryFromProtoError::missing_field("ProtoBinaryFunc::kind"))
Expand Down Expand Up @@ -6068,7 +6182,8 @@ derive_unary!(
KafkaMurmur2String,
SeahashBytes,
SeahashString,
Reverse
Reverse,
UnaryWasm2
);

impl UnaryFunc {
Expand Down Expand Up @@ -6490,6 +6605,7 @@ impl RustType<ProtoUnaryFunc> for UnaryFunc {
use crate::scalar::proto_unary_func::Kind::*;
use crate::scalar::proto_unary_func::*;
let kind = match self {
UnaryFunc::UnaryWasm2(stuff) => UnaryWasm(stuff.text.clone()),
UnaryFunc::Not(_) => Not(()),
UnaryFunc::IsNull(_) => IsNull(()),
UnaryFunc::IsTrue(_) => IsTrue(()),
Expand Down Expand Up @@ -6903,6 +7019,7 @@ impl RustType<ProtoUnaryFunc> for UnaryFunc {
use crate::scalar::proto_unary_func::Kind::*;
if let Some(kind) = proto.kind {
match kind {
UnaryWasm(text) => Ok(UnaryFunc::UnaryWasm2(UnaryWasm2 { text: text.clone(), data: Default::default() })),
Not(()) => Ok(impls::Not.into()),
IsNull(()) => Ok(impls::IsNull.into()),
IsTrue(()) => Ok(impls::IsTrue.into()),
Expand Down
1 change: 1 addition & 0 deletions src/pgrepr-consts/src/oid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -777,3 +777,4 @@ pub const VIEW_MZ_WALLCLOCK_GLOBAL_LAG_OID: u32 = 17054;
pub const SOURCE_MZ_WALLCLOCK_GLOBAL_LAG_HISTOGRAM_RAW_OID: u32 = 17055;
pub const VIEW_MZ_WALLCLOCK_GLOBAL_LAG_HISTOGRAM_OID: u32 = 17056;
pub const TABLE_MZ_SQL_SERVER_SOURCE_TABLES_OID: u32 = 17057;
pub const FUNC_WASM_OID: u32 = 17058;
3 changes: 3 additions & 0 deletions src/sql/src/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3852,6 +3852,9 @@ pub static MZ_CATALOG_BUILTINS: LazyLock<BTreeMap<&'static str, Func>> = LazyLoc
"starts_with" => Scalar {
params!(String, String) => BinaryFunc::StartsWith => Bool, 3696;
},
"wasm" => Scalar {
params!(Int64, String) => BinaryFunc::Wasm => Int64, oid::FUNC_WASM_OID;
},
"timezone_offset" => Scalar {
params!(String, TimestampTz) => BinaryFunc::TimezoneOffset => RecordAny, oid::FUNC_TIMEZONE_OFFSET;
},
Expand Down