Skip to content

Commit

Permalink
feat: add support for Databricks dialect
Browse files Browse the repository at this point in the history
  • Loading branch information
retro committed May 1, 2024
1 parent 89dbeeb commit bcb0af2
Show file tree
Hide file tree
Showing 14 changed files with 131 additions and 70 deletions.
34 changes: 34 additions & 0 deletions src/__tests__/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2464,5 +2464,39 @@ await describe("semantic layer", async () => {

assert.deepEqual(query.bindings, [5000]);
});

await it("can build SQL for ANSI", () => {
const ansiQueryBuilder = repository.build("ansi");
const query = ansiQueryBuilder.buildQuery(
{
members: ["invoices.invoice_id"],
},
{ schema: "public" },
);

assert.equal(
query.sql,
'select "q0"."invoices___invoice_id" as "invoices___invoice_id" from (select "invoices_query"."invoices___invoice_id" as "invoices___invoice_id" from (select distinct "invoices"."InvoiceId" as "invoices___invoice_id" from (select * from "public"."Invoice") as "invoices") as "invoices_query") as "q0" order by "invoices___invoice_id" asc limit ?',
);

assert.deepEqual(query.bindings, [5000]);
});

await it("can build SQL for Databricks", () => {
const ansiQueryBuilder = repository.build("databricks");
const query = ansiQueryBuilder.buildQuery(
{
members: ["invoices.invoice_id"],
},
{ schema: "public" },
);

assert.equal(
query.sql,
"select `q0`.`invoices___invoice_id` as `invoices___invoice_id` from (select `invoices_query`.`invoices___invoice_id` as `invoices___invoice_id` from (select distinct `invoices`.`InvoiceId` as `invoices___invoice_id` from (select * from `public`.`Invoice`) as `invoices`) as `invoices_query`) as `q0` order by `invoices___invoice_id` asc limit ?",
);

assert.deepEqual(query.bindings, [5000]);
});
});
});
36 changes: 23 additions & 13 deletions src/lib/dialect/base.ts → src/lib/dialect/ansi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { From, SqlFragment, SqlQueryBuilder } from "./sql-query-builder.js";

import { Granularity } from "../types.js";

export class BaseDialect {
export class AnsiDialect {
withGranularity(granularity: Granularity, sql: string) {
switch (granularity) {
case "time":
Expand Down Expand Up @@ -38,10 +38,12 @@ export class BaseDialect {
throw new Error(`Unrecognized granularity: ${granularity}`);
}
}

asIdentifier(value: string) {
if (value === "*") return value;
return `"${value}"`;
}

aggregate(aggregateWith: string, sql: string) {
if (aggregateWith === "sum") {
return `COALESCE(SUM(${sql}), 0)`;
Expand All @@ -50,6 +52,25 @@ export class BaseDialect {
return `${aggregateWith.toUpperCase()}(${sql})`;
}

ilike(
startsWith: boolean,
endsWith: boolean,
negation: boolean,
memberSql: string,
) {
let like = "?";
if (startsWith) {
like = `'%' || ${like}`;
}
if (endsWith) {
like = `${like} || '%'`;
}
if (negation) {
return `${memberSql} not ilike ${like}`;
}
return `${memberSql} ilike ${like}`;
}

from(from: From) {
return new SqlQueryBuilder(this, from);
}
Expand All @@ -59,17 +80,6 @@ export class BaseDialect {
}

sqlToNative(sql: string) {
return this.positionBindings(sql);
}

positionBindings(sql: string) {
let questionCount = 0;
return sql.replace(/(\\*)(\?)/g, (_match, escapes) => {
if (escapes.length % 2) {
return "?";
}
questionCount++;
return `$${questionCount}`;
});
return sql;
}
}
8 changes: 8 additions & 0 deletions src/lib/dialect/databricks.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import { AnsiDialect } from "./ansi.js";

export class DatabricksDialect extends AnsiDialect {
asIdentifier(value: string) {
if (value === "*") return value;
return `\`${value}\``;
}
}
17 changes: 17 additions & 0 deletions src/lib/dialect/postgresql.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import { AnsiDialect } from "./ansi.js";

export class PostgresqlDialect extends AnsiDialect {
sqlToNative(sql: string) {
return this.positionBindings(sql);
}
positionBindings(sql: string) {
let questionCount = 0;
return sql.replace(/(\\*)(\?)/g, (_match, escapes) => {
if (escapes.length % 2) {
return "?";
}
questionCount++;
return `$${questionCount}`;
});
}
}
4 changes: 2 additions & 2 deletions src/lib/dialect/sql-query-builder.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { SqlQuery, toSQL } from "./sql-query-builder/to-sql.js";

import { BaseDialect } from "./base.js";
import { AnsiDialect } from "./ansi.js";

export interface QueryJoin {
table: string | SqlFragment | SqlQueryBuilder;
Expand Down Expand Up @@ -35,7 +35,7 @@ export class SqlQueryBuilder {
joins: [],
};
constructor(
public readonly dialect: BaseDialect,
public readonly dialect: AnsiDialect,
public readonly from: From,
) {}

Expand Down
4 changes: 2 additions & 2 deletions src/lib/dialect/sql-query-builder/to-sql.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import { SqlFragment, SqlQueryBuilder } from "../sql-query-builder.js";

import { BaseDialect } from "../base.js";
import { AnsiDialect } from "../ansi.js";

export class SqlQuery {
constructor(
private readonly dialect: BaseDialect,
private readonly dialect: AnsiDialect,
public readonly sql: string,
public readonly bindings: unknown[],
) {}
Expand Down
12 changes: 6 additions & 6 deletions src/lib/join.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import invariant from "tiny-invariant";
import type { BaseDialect } from "./dialect/base.js";
import type { AnsiDialect } from "./dialect/ansi.js";
import { AnyModel } from "./model.js";
import type { AnyRepository } from "./repository.js";
import { SqlWithBindings } from "./types.js";

export abstract class JoinRef {
public abstract render(
repository: AnyRepository,
dialect: BaseDialect,
dialect: AnsiDialect,
): SqlWithBindings;
}

Expand All @@ -22,7 +22,7 @@ export class JoinDimensionRef<
) {
super();
}
render(repository: AnyRepository, dialect: BaseDialect) {
render(repository: AnyRepository, dialect: AnsiDialect) {
return repository
.getModel(this.model)
.getDimension(this.dimension)
Expand All @@ -38,7 +38,7 @@ export class JoinColumnRef<N extends string> extends JoinRef {
) {
super();
}
render(repository: AnyRepository, dialect: BaseDialect) {
render(repository: AnyRepository, dialect: AnsiDialect) {
const model = repository.getModel(this.model);
const { sql: asSql, bindings } = model.getAs(dialect, this.context);
return {
Expand All @@ -52,7 +52,7 @@ export class JoinIdentifierRef extends JoinRef {
constructor(private readonly identifier: string) {
super();
}
render(_repository: AnyRepository, dialect: BaseDialect) {
render(_repository: AnyRepository, dialect: AnsiDialect) {
return {
sql: dialect.asIdentifier(this.identifier),
bindings: [],
Expand All @@ -79,7 +79,7 @@ export class JoinOnDef {
private readonly strings: string[],
private readonly values: unknown[],
) {}
render(repository: AnyRepository, dialect: BaseDialect) {
render(repository: AnyRepository, dialect: AnsiDialect) {
const sql: string[] = [];
const bindings: unknown[] = [];
for (let i = 0; i < this.strings.length; i++) {
Expand Down
32 changes: 16 additions & 16 deletions src/lib/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ import {
} from "./types.js";

import { Simplify } from "type-fest";
import { BaseDialect } from "./dialect/base.js";
import { AnsiDialect } from "./dialect/ansi.js";
import { sqlAsSqlWithBindings } from "./query-builder/util.js";

export abstract class ModelRef {
public abstract render(
dialect: BaseDialect,
dialect: AnsiDialect,
context: unknown,
): SqlWithBindings;
}
Expand All @@ -28,7 +28,7 @@ export class ColumnRef extends ModelRef {
) {
super();
}
render(dialect: BaseDialect, context: unknown) {
render(dialect: AnsiDialect, context: unknown) {
const { sql: asSql, bindings } = this.model.getAs(dialect, context);
const sql = `${asSql}.${dialect.asIdentifier(this.name)}`;
return {
Expand All @@ -42,7 +42,7 @@ export class IdentifierRef extends ModelRef {
constructor(private readonly identifier: string) {
super();
}
render(dialect: BaseDialect, _context: unknown) {
render(dialect: AnsiDialect, _context: unknown) {
return {
sql: dialect.asIdentifier(this.identifier),
bindings: [],
Expand All @@ -54,7 +54,7 @@ export class DimensionRef extends ModelRef {
constructor(private readonly dimension: Dimension) {
super();
}
render(dialect: BaseDialect, context: unknown) {
render(dialect: AnsiDialect, context: unknown) {
return this.dimension.getSql(dialect, context);
}
}
Expand All @@ -66,7 +66,7 @@ export class SqlWithRefs extends ModelRef {
) {
super();
}
render(dialect: BaseDialect, context: unknown) {
render(dialect: AnsiDialect, context: unknown) {
const sql: string[] = [];
const bindings: unknown[] = [];
for (let i = 0; i < this.strings.length; i++) {
Expand Down Expand Up @@ -147,14 +147,14 @@ export abstract class Member {
public abstract props: AnyDimensionProps | AnyMetricProps;

abstract getSql(
dialect: BaseDialect,
dialect: AnsiDialect,
context: unknown,
modelAlias?: string,
): SqlWithBindings;
abstract isMetric(): this is Metric;
abstract isDimension(): this is Dimension;

getAlias(dialect: BaseDialect) {
getAlias(dialect: AnsiDialect) {
return dialect.asIdentifier(
`${this.model.name}___${this.name.replaceAll(".", "___")}`,
);
Expand All @@ -163,7 +163,7 @@ export abstract class Member {
return `${this.model.name}.${this.name}`;
}
renderSql(
dialect: BaseDialect,
dialect: AnsiDialect,
context: unknown,
): SqlWithBindings | undefined {
if (this.props.sql) {
Expand Down Expand Up @@ -200,7 +200,7 @@ export class Dimension extends Member {
) {
super();
}
getSql(dialect: BaseDialect, context: unknown, modelAlias?: string) {
getSql(dialect: AnsiDialect, context: unknown, modelAlias?: string) {
if (modelAlias) {
return sqlAsSqlWithBindings(
`${dialect.asIdentifier(modelAlias)}.${this.getAlias(dialect)}`,
Expand All @@ -216,7 +216,7 @@ export class Dimension extends Member {
}
return result;
}
getSqlWithoutGranularity(dialect: BaseDialect, context: unknown) {
getSqlWithoutGranularity(dialect: AnsiDialect, context: unknown) {
const result = this.renderSql(dialect, context);

if (result) {
Expand Down Expand Up @@ -255,7 +255,7 @@ export class Metric extends Member {
) {
super();
}
getSql(dialect: BaseDialect, context: unknown, modelAlias?: string) {
getSql(dialect: AnsiDialect, context: unknown, modelAlias?: string) {
if (modelAlias) {
return sqlAsSqlWithBindings(
`${dialect.asIdentifier(modelAlias)}.${this.getAlias(dialect)}`,
Expand All @@ -276,7 +276,7 @@ export class Metric extends Member {
bindings,
};
}
getAggregateSql(dialect: BaseDialect, context: unknown, modelAlias?: string) {
getAggregateSql(dialect: AnsiDialect, context: unknown, modelAlias?: string) {
const { sql, bindings } = this.getSql(dialect, context, modelAlias);
return {
sql: dialect.aggregate(this.props.aggregateWith, sql),
Expand Down Expand Up @@ -372,7 +372,7 @@ export class Model<
getMetrics() {
return Object.values(this.metrics);
}
getTableName(dialect: BaseDialect, context: C) {
getTableName(dialect: AnsiDialect, context: C) {
if (this.config.type === "table") {
if (typeof this.config.name === "string") {
return {
Expand All @@ -396,14 +396,14 @@ export class Model<

throw new Error("Model is not a table");
}
getAs(dialect: BaseDialect, context: C) {
getAs(dialect: AnsiDialect, context: C) {
if (this.config.type === "sqlQuery") {
return { sql: dialect.asIdentifier(this.config.alias), bindings: [] };
}

return this.getTableName(dialect, context);
}
getSql(dialect: BaseDialect, context: C) {
getSql(dialect: AnsiDialect, context: C) {
if (this.config.type === "sqlQuery") {
const result = this.config.sql({
identifier: (name: string) => new IdentifierRef(name),
Expand Down
4 changes: 2 additions & 2 deletions src/lib/query-builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import {
} from "./types.js";

import { Simplify } from "type-fest";
import { BaseDialect } from "./dialect/base.js";
import { AnsiDialect } from "./dialect/ansi.js";
import { SqlQuery } from "./dialect/sql-query-builder/to-sql.js";
import { buildQuery } from "./query-builder/build-query.js";
import { FilterBuilder } from "./query-builder/filter-builder.js";
Expand Down Expand Up @@ -64,7 +64,7 @@ export class QueryBuilder<
public readonly querySchema: QuerySchema;
constructor(
public readonly repository: AnyRepository,
public readonly dialect: BaseDialect,
public readonly dialect: AnsiDialect,
) {
this.querySchema = buildQuerySchema(this);
}
Expand Down
6 changes: 3 additions & 3 deletions src/lib/query-builder/build-query.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {
} from "../types.js";

import invariant from "tiny-invariant";
import { BaseDialect } from "../dialect/base.js";
import { AnsiDialect } from "../dialect/ansi.js";
import type { AnyJoin } from "../join.js";
import { AnyModel } from "../model.js";
import { AnyQueryBuilder } from "../query-builder.js";
Expand Down Expand Up @@ -57,7 +57,7 @@ function getDefaultOrderBy(repository: AnyRepository, query: Query) {
}

function initializeQuerySegment(
dialect: BaseDialect,
dialect: AnsiDialect,
context: unknown,
model: AnyModel,
) {
Expand All @@ -75,7 +75,7 @@ function initializeQuerySegment(
}

function getJoinSubject(
dialect: BaseDialect,
dialect: AnsiDialect,
context: unknown,
model: AnyModel,
) {
Expand Down
Loading

0 comments on commit bcb0af2

Please sign in to comment.