diff --git a/README.md b/README.md index 5242f07..76cb374 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,8 @@ This library allows you to define models and their respective fields, including import * as semanticLayer from "@verybigthings/semantic-layer"; const customersModel = semanticLayer - .model("customers") + .model() + .withName("customers") .fromTable("Customer") .withDimension("customer_id", { type: "number", @@ -53,7 +54,8 @@ const customersModel = semanticLayer }); const invoicesModel = semanticLayer - .model("invoices") + .model() + .withName("invoices") .fromTable("Invoice") .withDimension("invoice_id", { type: "number", @@ -68,7 +70,8 @@ const invoicesModel = semanticLayer }); const invoiceLinesModel = semanticLayer - .model("invoice_lines") + .model() + .withName("invoice_lines") .fromTable("InvoiceLine") .withDimension("invoice_line_id", { type: "number", diff --git a/src/__tests__/index.test.ts b/src/__tests__/index.test.ts index 9454ab9..cfb7b8a 100644 --- a/src/__tests__/index.test.ts +++ b/src/__tests__/index.test.ts @@ -13,50 +13,10 @@ import path from "node:path"; import pg from "pg"; import { zodToJsonSchema } from "zod-to-json-schema"; -//import { format as sqlFormat } from "sql-formatter"; +// import { format as sqlFormat } from "sql-formatter"; const __dirname = path.dirname(new URL(import.meta.url).pathname); -/*const query = built.query({ - dimensions: [ - "customers.customer_id", - //'invoice_lines.invoice_line_id', - //'invoices.invoice_id', - //'Track.track_id,', - //'albums.title', - ], - metrics: ["invoice_lines.total_unit_price", "invoices.total"], - filters: [ - { - operator: "inDateRange", - member: "invoices.invoice_date", - value: "from Jan 1st 2011 at 00:00 to Dec 31th 2012 23:00", - }, - { operator: 'set', member: 'customers.customer_id' }, - { - operator: 'notContains', - member: 'invoice_lines.total_unit_price', - value: ['0.99', '1'], - }, - { operator: 'notEquals', member: 'invoices.total', value: ['0.99'] }, - { - operator: 'or', - filters: [ - { operator: 'notEquals', member: 'invoices.invoice_id', value: ['1'] }, - { - operator: 'notEquals', - member: 'invoice_lines.invoice_line_id', - value: ['3'], - }, - ], - }, - ], - order: { - // 'invoice_lines.unit_price': 'asc', 'customers.customer_id': 'asc', - "invoices.invoice_date.year": "desc", - }, -});*/ - await describe("semantic layer", async () => { let container: StartedPostgreSqlContainer; let client: pg.Client; @@ -88,7 +48,8 @@ await describe("semantic layer", async () => { await describe("models from tables", async () => { const customersModel = semanticLayer - .model("customers") + .model() + .withName("customers") .fromTable("Customer") .withDimension("customer_id", { type: "number", @@ -116,7 +77,8 @@ await describe("semantic layer", async () => { }); const invoicesModel = semanticLayer - .model("invoices") + .model() + .withName("invoices") .fromTable("Invoice") .withDimension("invoice_id", { type: "number", @@ -138,7 +100,8 @@ await describe("semantic layer", async () => { }); const invoiceLinesModel = semanticLayer - .model("invoice_lines") + .model() + .withName("invoice_lines") .fromTable("InvoiceLine") .withDimension("invoice_line_id", { type: "number", @@ -165,7 +128,8 @@ await describe("semantic layer", async () => { }); const tracksModel = semanticLayer - .model("tracks") + .model() + .withName("tracks") .fromTable("Track") .withDimension("track_id", { type: "number", @@ -182,7 +146,8 @@ await describe("semantic layer", async () => { }); const albumsModel = semanticLayer - .model("albums") + .model() + .withName("albums") .fromTable("Album") .withDimension("album_id", { type: "number", @@ -204,26 +169,34 @@ await describe("semantic layer", async () => { .joinOneToMany( "customers", "invoices", - ({ sql, dimensions }) => - sql`${dimensions.customers.customer_id} = ${dimensions.invoices.customer_id}`, + ({ sql, models }) => + sql`${models.customers.dimension( + "customer_id", + )} = ${models.invoices.dimension("customer_id")}`, ) .joinOneToMany( "invoices", "invoice_lines", - ({ sql, dimensions }) => - sql`${dimensions.invoices.invoice_id} = ${dimensions.invoice_lines.invoice_id}`, + ({ sql, models, getContext }) => + sql`${models.invoices.dimension( + "invoice_id", + )} = ${models.invoice_lines.dimension("invoice_id")} ${getContext()}`, ) .joinOneToMany( "invoice_lines", "tracks", - ({ sql, dimensions }) => - sql`${dimensions.invoice_lines.track_id} = ${dimensions.tracks.track_id}`, + ({ sql, models }) => + sql`${models.invoice_lines.dimension( + "track_id", + )} = ${models.tracks.dimension("track_id")}`, ) .joinManyToMany( "tracks", "albums", - ({ sql, dimensions }) => - sql`${dimensions.tracks.album_id} = ${dimensions.albums.album_id}`, + ({ sql, models }) => + sql`${models.tracks.dimension( + "album_id", + )} = ${models.albums.dimension("album_id")}`, ); const queryBuilder = repository.build("postgresql"); @@ -452,8 +425,11 @@ await describe("semantic layer", async () => { await describe("models from sql queries", async () => { const customersModel = semanticLayer - .model("customers") - .fromSqlQuery('select * from "Customer"') + .model() + .withName("customers") + .fromSqlQuery( + ({ sql, identifier }) => sql`select * from ${identifier("Customer")}`, + ) .withDimension("customer_id", { type: "number", primaryKey: true, @@ -461,8 +437,11 @@ await describe("semantic layer", async () => { }); const invoicesModel = semanticLayer - .model("invoices") - .fromSqlQuery('select * from "Invoice"') + .model() + .withName("invoices") + .fromSqlQuery( + ({ sql, identifier }) => sql`select * from ${identifier("Invoice")}`, + ) .withDimension("invoice_id", { type: "number", primaryKey: true, @@ -485,8 +464,10 @@ await describe("semantic layer", async () => { .joinOneToMany( "customers", "invoices", - ({ sql, dimensions }) => - sql`${dimensions.customers.customer_id} = ${dimensions.invoices.customer_id}`, + ({ sql, models }) => + sql`${models.customers.dimension( + "customer_id", + )} = ${models.invoices.dimension("customer_id")}`, ); const queryBuilder = repository.build("postgresql"); @@ -552,8 +533,11 @@ await describe("semantic layer", async () => { await describe("query schema", async () => { await it("can parse a valid query", () => { const customersModel = semanticLayer - .model("customers") - .fromSqlQuery('select * from "Customer"') + .model() + .withName("customers") + .fromSqlQuery( + ({ sql, identifier }) => sql`select * from ${identifier("Customer")}`, + ) .withDimension("customer_id", { type: "number", primaryKey: true, @@ -561,8 +545,11 @@ await describe("semantic layer", async () => { }); const invoicesModel = semanticLayer - .model("invoices") - .fromSqlQuery('select * from "Invoice"') + .model() + .withName("invoices") + .fromSqlQuery( + ({ sql, identifier }) => sql`select * from ${identifier("Invoice")}`, + ) .withDimension("invoice_id", { type: "number", primaryKey: true, @@ -585,8 +572,10 @@ await describe("semantic layer", async () => { .joinOneToMany( "customers", "invoices", - ({ sql, dimensions }) => - sql`${dimensions.customers.customer_id} = ${dimensions.invoices.customer_id}`, + ({ sql, models }) => + sql`${models.customers.dimension( + "customer_id", + )} = ${models.invoices.dimension("customer_id")}`, ); const queryBuilder = repository.build("postgresql"); @@ -1097,8 +1086,9 @@ await describe("semantic layer", async () => { await describe("model descriptions and query introspection", async () => { const customersModel = semanticLayer - .model("customers") - .fromSqlQuery('select * from "Customer"') + .model() + .withName("customers") + .fromTable("Customer") .withDimension("customer_id", { type: "number", primaryKey: true, @@ -1107,8 +1097,9 @@ await describe("semantic layer", async () => { }); const invoicesModel = semanticLayer - .model("invoices") - .fromSqlQuery('select * from "Invoice"') + .model() + .withName("invoices") + .fromTable("Invoice") .withDimension("invoice_id", { type: "number", primaryKey: true, @@ -1134,8 +1125,10 @@ await describe("semantic layer", async () => { .joinOneToMany( "customers", "invoices", - ({ sql, dimensions }) => - sql`${dimensions.customers.customer_id} = ${dimensions.invoices.customer_id}`, + ({ sql, models }) => + sql`${models.customers.dimension( + "customer_id", + )} = ${models.invoices.dimension("customer_id")}`, ); const queryBuilder = repository.build("postgresql"); @@ -1214,7 +1207,8 @@ await describe("semantic layer", async () => { await describe("full repository", async () => { const customersModel = semanticLayer - .model("customers") + .model() + .withName("customers") .fromTable("Customer") .withDimension("customer_id", { type: "number", @@ -1274,7 +1268,8 @@ await describe("semantic layer", async () => { }); const invoicesModel = semanticLayer - .model("invoices") + .model() + .withName("invoices") .fromTable("Invoice") .withDimension("invoice_id", { type: "number", @@ -1316,12 +1311,13 @@ await describe("semantic layer", async () => { .withMetric("sum_total", { type: "number", aggregateWith: "sum", - description: "Sum of the invoice totals across dimensions.", + description: "Sum of the invoice totals across models.", sql: ({ model }) => model.dimension("total"), }); const invoiceLinesModel = semanticLayer - .model("invoice_lines") + .model() + .withName("invoice_lines") .fromTable("InvoiceLine") .withDimension("invoice_line_id", { type: "number", @@ -1347,18 +1343,19 @@ await describe("semantic layer", async () => { .withMetric("sum_quantity", { type: "number", aggregateWith: "sum", - description: "Sum of the track quantities across dimensions.", + description: "Sum of the track quantities across models.", sql: ({ model }) => model.dimension("quantity"), }) .withMetric("sum_unit_price", { type: "number", aggregateWith: "sum", - description: "Sum of the track unit prices across dimensions.", + description: "Sum of the track unit prices across models.", sql: ({ model }) => model.dimension("unit_price"), }); const tracksModel = semanticLayer - .model("tracks") + .model() + .withName("tracks") .fromTable("Track") .withDimension("track_id", { type: "number", @@ -1400,12 +1397,13 @@ await describe("semantic layer", async () => { .withMetric("sum_unit_price", { type: "number", aggregateWith: "sum", - description: "Sum of the track unit prices across dimensions.", + description: "Sum of the track unit prices across models.", sql: ({ model }) => model.dimension("unit_price"), }); const albumsModel = semanticLayer - .model("albums") + .model() + .withName("albums") .fromTable("Album") .withDimension("album_id", { type: "number", @@ -1422,7 +1420,8 @@ await describe("semantic layer", async () => { }); const artistModel = semanticLayer - .model("artists") + .model() + .withName("artists") .fromTable("Artist") .withDimension("artist_id", { type: "number", @@ -1435,7 +1434,8 @@ await describe("semantic layer", async () => { }); const mediaTypeModel = semanticLayer - .model("media_types") + .model() + .withName("media_types") .fromTable("MediaType") .withDimension("media_type_id", { type: "number", @@ -1448,7 +1448,8 @@ await describe("semantic layer", async () => { }); const genreModel = semanticLayer - .model("genres") + .model() + .withName("genres") .fromTable("Genre") .withDimension("name", { type: "string", @@ -1461,7 +1462,8 @@ await describe("semantic layer", async () => { }); const playlistModel = semanticLayer - .model("playlists") + .model() + .withName("playlists") .fromTable("Playlist") .withDimension("playlist_id", { type: "number", @@ -1474,7 +1476,8 @@ await describe("semantic layer", async () => { }); const playlistTrackModel = semanticLayer - .model("playlist_tracks") + .model() + .withName("playlist_tracks") .fromTable("PlaylistTrack") .withDimension("playlist_id", { type: "number", @@ -1500,56 +1503,74 @@ await describe("semantic layer", async () => { .joinOneToMany( "customers", "invoices", - ({ sql, dimensions }) => - sql`${dimensions.customers.customer_id} = ${dimensions.invoices.customer_id}`, + ({ sql, models }) => + sql`${models.customers.dimension( + "customer_id", + )} = ${models.invoices.dimension("customer_id")}`, ) .joinOneToMany( "invoices", "invoice_lines", - ({ sql, dimensions }) => - sql`${dimensions.invoices.invoice_id} = ${dimensions.invoice_lines.invoice_id}`, + ({ sql, models }) => + sql`${models.invoices.dimension( + "invoice_id", + )} = ${models.invoice_lines.dimension("invoice_id")}`, ) .joinManyToOne( "invoice_lines", "tracks", - ({ sql, dimensions }) => - sql`${dimensions.invoice_lines.track_id} = ${dimensions.tracks.track_id}`, + ({ sql, models }) => + sql`${models.invoice_lines.dimension( + "track_id", + )} = ${models.tracks.dimension("track_id")}`, ) .joinOneToMany( "albums", "tracks", - ({ sql, dimensions }) => - sql`${dimensions.tracks.album_id} = ${dimensions.albums.album_id}`, + ({ sql, models }) => + sql`${models.tracks.dimension( + "album_id", + )} = ${models.albums.dimension("album_id")}`, ) .joinManyToOne( "albums", "artists", - ({ sql, dimensions }) => - sql`${dimensions.albums.artist_id} = ${dimensions.artists.artist_id}`, + ({ sql, models }) => + sql`${models.albums.dimension( + "artist_id", + )} = ${models.artists.dimension("artist_id")}`, ) .joinOneToOne( "tracks", "media_types", - ({ sql, dimensions }) => - sql`${dimensions.tracks.media_type_id} = ${dimensions.media_types.media_type_id}`, + ({ sql, models }) => + sql`${models.tracks.dimension( + "media_type_id", + )} = ${models.media_types.dimension("media_type_id")}`, ) .joinOneToOne( "tracks", "genres", - ({ sql, dimensions }) => - sql`${dimensions.tracks.genre_id} = ${dimensions.genres.genre_id}`, + ({ sql, models }) => + sql`${models.tracks.dimension( + "genre_id", + )} = ${models.genres.dimension("genre_id")}`, ) .joinManyToMany( "playlists", "playlist_tracks", - ({ sql, dimensions }) => - sql`${dimensions.playlists.playlist_id} = ${dimensions.playlist_tracks.playlist_id}`, + ({ sql, models }) => + sql`${models.playlists.dimension( + "playlist_id", + )} = ${models.playlist_tracks.dimension("playlist_id")}`, ) .joinManyToMany( "playlist_tracks", "tracks", - ({ sql, dimensions }) => - sql`${dimensions.playlist_tracks.track_id} = ${dimensions.tracks.track_id}`, + ({ sql, models }) => + sql`${models.playlist_tracks.dimension( + "track_id", + )} = ${models.tracks.dimension("track_id")}`, ); const queryBuilder = repository.build("postgresql"); @@ -1678,4 +1699,79 @@ await describe("semantic layer", async () => { assert.deepEqual(query, parsedQuery); }); }); + + describe("repository with context", async () => { + await it("propagates context to all sql functions", async () => { + type QueryContext = { + customerId: number; + }; + + const customersModel = semanticLayer + .model() + .withName("customers") + .fromSqlQuery( + ({ sql, identifier, getContext }) => + sql`select * from ${identifier("Customer")} where ${identifier( + "CustomerId", + )} = ${getContext().customerId}`, + ) + .withDimension("customer_id", { + type: "number", + primaryKey: true, + sql: ({ model, sql, getContext }) => + sql`${model.column("CustomerId")} || cast(${ + getContext().customerId + } as text)`, + }) + .withDimension("first_name", { + type: "string", + sql: ({ model }) => model.column("FirstName"), + }); + + const invoicesModel = semanticLayer + .model() + .withName("invoices") + .fromTable("Invoice") + .withDimension("invoice_id", { + type: "number", + primaryKey: true, + sql: ({ model }) => model.column("InvoiceId"), + }) + .withDimension("customer_id", { + type: "number", + sql: ({ model }) => model.column("CustomerId"), + }); + + const repository = semanticLayer + .repository() + .withModel(customersModel) + .withModel(invoicesModel) + .joinOneToMany( + "customers", + "invoices", + ({ sql, models, getContext }) => + sql`${models.customers.dimension( + "customer_id", + )} = ${models.invoices.dimension("customer_id")} and ${ + getContext().customerId + } = ${getContext().customerId}`, + ); + + const queryBuilder = repository.build("postgresql"); + const query = queryBuilder.buildQuery( + { + dimensions: ["customers.customer_id", "invoices.invoice_id"], + }, + { customerId: 1 }, + ); + + assert.equal( + query.sql, + 'select "q0"."customers___customer_id" as "customers___customer_id", "q0"."invoices___invoice_id" as "invoices___invoice_id" from (select "invoices_query"."customers___customer_id" as "customers___customer_id", "invoices_query"."invoices___invoice_id" as "invoices___invoice_id" from (select distinct "Invoice"."InvoiceId" as "invoices___invoice_id", "customers"."CustomerId" || cast($1 as text) as "customers___customer_id" from "Invoice" right join (select * from "Customer" where "CustomerId" = $2) as customers on "customers"."CustomerId" || cast($3 as text) = "Invoice"."CustomerId" and $4 = $5) as "invoices_query") as "q0" order by "customers___customer_id" asc limit $6', + ); + + // First 5 bindings are for the customerId, last one is for the limit + assert.deepEqual(query.bindings, [1, 1, 1, 1, 1, 5000]); + }); + }); }); diff --git a/src/lib/dialect/base.ts b/src/lib/dialect/base.ts index b532b39..88f2717 100644 --- a/src/lib/dialect/base.ts +++ b/src/lib/dialect/base.ts @@ -1,8 +1,6 @@ -import knex from "knex"; import { Granularity } from "../types.js"; export class BaseDialect { - constructor(private sqlQuery: knex.Knex.QueryBuilder) {} withGranularity(granularity: Granularity, sql: string) { switch (granularity) { case "day": @@ -28,8 +26,7 @@ export class BaseDialect { } } asIdentifier(value: string) { - return this.sqlQuery.client - .wrapIdentifier(value, this.sqlQuery.queryContext()) - .trim(); + if (value === "*") return value; + return `"${value}"`; } } diff --git a/src/lib/join.ts b/src/lib/join.ts index 29f2d8e..de953ec 100644 --- a/src/lib/join.ts +++ b/src/lib/join.ts @@ -1,31 +1,92 @@ +import invariant from "tiny-invariant"; import type { BaseDialect } from "./dialect/base.js"; -import type { Repository } from "./repository.js"; +import { AnyModel } from "./model.js"; +import type { AnyRepository } from "./repository.js"; +import { SqlWithBindings } from "./types.js"; -export class JoinDimensionRef { +export abstract class JoinRef { + public abstract render( + repository: AnyRepository, + dialect: BaseDialect, + ): SqlWithBindings; +} + +export class JoinDimensionRef< + N extends string, + DN extends string, +> extends JoinRef { constructor( private readonly model: N, private readonly dimension: DN, - ) {} - render(repository: Repository, dialect: BaseDialect) { + private readonly context: unknown, + ) { + super(); + } + render(repository: AnyRepository, dialect: BaseDialect) { return repository .getModel(this.model) .getDimension(this.dimension) - .getSql(dialect); + .getSql(dialect, this.context); } } + +export class JoinColumnRef extends JoinRef { + constructor( + private readonly model: N, + private readonly column: string, + ) { + super(); + } + render(repository: AnyRepository, dialect: BaseDialect) { + const model = repository.getModel(this.model); + return { + sql: `${dialect.asIdentifier(model.getAs())}.${dialect.asIdentifier( + this.column, + )}`, + bindings: [], + }; + } +} + +export class JoinIdentifierRef extends JoinRef { + constructor(private readonly identifier: string) { + super(); + } + render(_repository: AnyRepository, dialect: BaseDialect) { + return { + sql: dialect.asIdentifier(this.identifier), + bindings: [], + }; + } +} + +export function makeModelJoinPayload(model: AnyModel, context: unknown) { + return { + dimension: (name: string) => { + const dimension = model.getDimension(name); + invariant( + dimension, + `Dimension ${name} not found in model ${model.name}`, + ); + return new JoinDimensionRef(model.name, name, context); + }, + column: (name: string) => new JoinColumnRef(model.name, name), + }; +} + export class JoinOnDef { constructor( private readonly strings: string[], private readonly values: unknown[], ) {} - render(repository: Repository, dialect: BaseDialect) { + render(repository: AnyRepository, dialect: BaseDialect) { const sql: string[] = []; const bindings: unknown[] = []; for (let i = 0; i < this.strings.length; i++) { sql.push(this.strings[i]!); if (this.values[i]) { const value = this.values[i]; - if (value instanceof JoinDimensionRef) { + if (value instanceof JoinRef) { const result = value.render(repository, dialect); sql.push(result.sql); bindings.push(...result.bindings); @@ -35,6 +96,7 @@ export class JoinOnDef { } } } + return { sql: sql.join(""), bindings, @@ -42,21 +104,26 @@ export class JoinOnDef { } } -export interface Join { +export interface Join { left: string; right: string; - joinOnDef: JoinOnDef; + joinOnDef: (context: C) => JoinOnDef; reversed: boolean; type: "oneToOne" | "oneToMany" | "manyToOne" | "manyToMany"; } +// biome-ignore lint/suspicious/noExplicitAny: +export type AnyJoin = Join; export type JoinFn< + C, DN extends string, N1 extends string, N2 extends string, > = (args: { sql: (strings: TemplateStringsArray, ...values: unknown[]) => JoinOnDef; - dimensions: JoinDimensions; + models: JoinDimensions; + identifier: (name: string) => JoinIdentifierRef; + getContext: () => C; }) => JoinOnDef; export type ModelDimensionsWithoutModelPrefix< @@ -70,22 +137,28 @@ export type JoinDimensions< N2 extends string, > = { [TK in N1]: { - [DK in ModelDimensionsWithoutModelPrefix]: JoinDimensionRef; + dimension: ( + name: ModelDimensionsWithoutModelPrefix, + ) => JoinDimensionRef>; + column: (name: string) => JoinColumnRef; }; } & { [TK in N2]: { - [DK in ModelDimensionsWithoutModelPrefix]: JoinDimensionRef; + dimension: ( + name: ModelDimensionsWithoutModelPrefix, + ) => JoinDimensionRef>; + column: (name: string) => JoinColumnRef; }; }; -export const JOIN_WEIGHTS: Record = { +export const JOIN_WEIGHTS: Record = { oneToOne: 1, oneToMany: 3, manyToOne: 2, manyToMany: 4, }; -export const REVERSED_JOIN: Record = { +export const REVERSED_JOIN: Record = { oneToOne: "oneToOne", oneToMany: "manyToOne", manyToOne: "oneToMany", diff --git a/src/lib/model.ts b/src/lib/model.ts index 44b9d06..ec03398 100644 --- a/src/lib/model.ts +++ b/src/lib/model.ts @@ -7,21 +7,25 @@ import { SqlWithBindings, } from "./types.js"; +import { Simplify } from "type-fest"; import { BaseDialect } from "./dialect/base.js"; import { sqlAsSqlWithBindings } from "./query-builder/util.js"; -export abstract class Ref { - public abstract render(dialect: BaseDialect): SqlWithBindings; +export abstract class ModelRef { + public abstract render( + dialect: BaseDialect, + context: unknown, + ): SqlWithBindings; } -export class ColumnRef extends Ref { +export class ColumnRef extends ModelRef { constructor( public readonly model: AnyModel, public readonly name: string, ) { super(); } - render(dialect: BaseDialect) { + render(dialect: BaseDialect, _context: unknown) { const sql = `${dialect.asIdentifier( this.model.getAs(), )}.${dialect.asIdentifier(this.name)}`; @@ -32,31 +36,43 @@ export class ColumnRef extends Ref { } } -export class DimensionRef extends Ref { +export class IdentifierRef extends ModelRef { + constructor(private readonly identifier: string) { + super(); + } + render(dialect: BaseDialect, _context: unknown) { + return { + sql: dialect.asIdentifier(this.identifier), + bindings: [], + }; + } +} + +export class DimensionRef extends ModelRef { constructor(private readonly dimension: Dimension) { super(); } - render(dialect: BaseDialect) { - return this.dimension.getSql(dialect); + render(dialect: BaseDialect, context: unknown) { + return this.dimension.getSql(dialect, context); } } -export class SqlWithRefs extends Ref { +export class SqlWithRefs extends ModelRef { constructor( public readonly strings: string[], public readonly values: unknown[], ) { super(); } - render(dialect: BaseDialect) { + render(dialect: BaseDialect, context: unknown) { const sql: string[] = []; const bindings: unknown[] = []; for (let i = 0; i < this.strings.length; i++) { sql.push(this.strings[i]!); const nextValue = this.values[i]; if (nextValue) { - if (nextValue instanceof Ref) { - const result = nextValue.render(dialect); + if (nextValue instanceof ModelRef) { + const result = nextValue.render(dialect, context); sql.push(result.sql); bindings.push(...result.bindings); } else { @@ -72,13 +88,21 @@ export class SqlWithRefs extends Ref { } } -export type SqlFn = (args: { +export type MemberSqlFn = (args: { + identifier: (name: string) => IdentifierRef; model: { column: (name: string) => ColumnRef; dimension: (name: DN) => DimensionRef; }; sql: (strings: TemplateStringsArray, ...values: unknown[]) => SqlWithRefs; -}) => Ref; + getContext: () => C; +}) => ModelRef; + +export type ModelSqlFn = (args: { + identifier: (name: string) => IdentifierRef; + sql: (strings: TemplateStringsArray, ...values: unknown[]) => SqlWithRefs; + getContext: () => C; +}) => ModelRef; function typeHasGranularity( type: string, @@ -95,30 +119,40 @@ export type WithGranularityDimensions< } : { [k in N]: T }; -export interface DimensionProps { +export interface DimensionProps { type: MemberType; - sql?: SqlFn; + sql?: MemberSqlFn; format?: MemberFormat; primaryKey?: boolean; description?: string; } +// biome-ignore lint/suspicious/noExplicitAny: +export type AnyDimensionProps = DimensionProps; + export type MetricType = "count" | "sum" | "avg" | "min" | "max"; -export interface MetricProps { +export interface MetricProps { type: MemberType; // TODO: allow custom aggregate functions: ({sql: SqlFn, metric: MetricRef}) => SqlWithRefs aggregateWith: MetricType; - sql?: SqlFn; + sql?: MemberSqlFn; format?: MemberFormat; description?: string; } +// biome-ignore lint/suspicious/noExplicitAny: +export type AnyMetricProps = MetricProps; + export abstract class Member { public abstract readonly name: string; public abstract readonly model: AnyModel; - public abstract props: DimensionProps | MetricProps; + public abstract props: AnyDimensionProps | AnyMetricProps; - abstract getSql(dialect: BaseDialect, modelAlias?: string): SqlWithBindings; + abstract getSql( + dialect: BaseDialect, + context: unknown, + modelAlias?: string, + ): SqlWithBindings; abstract isMetric(): this is Metric; abstract isDimension(): this is Dimension; @@ -130,17 +164,22 @@ export abstract class Member { getPath() { return `${this.model.name}.${this.name}`; } - renderSql(dialect: BaseDialect): SqlWithBindings | undefined { + renderSql( + dialect: BaseDialect, + context: unknown, + ): SqlWithBindings | undefined { if (this.props.sql) { const result = this.props.sql({ + identifier: (name: string) => new IdentifierRef(name), model: { column: (name: string) => new ColumnRef(this.model, name), dimension: (name: string) => new DimensionRef(this.model.getDimension(name)), }, sql: (strings, ...values) => new SqlWithRefs([...strings], values), + getContext: () => context, }); - return result.render(dialect); + return result.render(dialect, context); } } getDescription() { @@ -158,19 +197,19 @@ export class Dimension extends Member { constructor( public readonly model: AnyModel, public readonly name: string, - public readonly props: DimensionProps, + public readonly props: AnyDimensionProps, public readonly granularity?: Granularity, ) { super(); } - getSql(dialect: BaseDialect, modelAlias?: string) { + getSql(dialect: BaseDialect, context: unknown, modelAlias?: string) { if (modelAlias) { return sqlAsSqlWithBindings( `${dialect.asIdentifier(modelAlias)}.${this.getAlias(dialect)}`, ); } const result = - this.renderSql(dialect) ?? + this.renderSql(dialect, context) ?? sqlAsSqlWithBindings( `${dialect.asIdentifier(this.model.getAs())}.${dialect.asIdentifier( this.name, @@ -199,18 +238,18 @@ export class Metric extends Member { constructor( public readonly model: AnyModel, public readonly name: string, - public readonly props: MetricProps, + public readonly props: AnyMetricProps, ) { super(); } - getSql(dialect: BaseDialect, modelAlias?: string) { + getSql(dialect: BaseDialect, context: unknown, modelAlias?: string) { if (modelAlias) { return sqlAsSqlWithBindings( `${dialect.asIdentifier(modelAlias)}.${this.getAlias(dialect)}`, ); } return ( - this.renderSql(dialect) ?? + this.renderSql(dialect, context) ?? sqlAsSqlWithBindings( `${dialect.asIdentifier(this.model.getAs())}.${dialect.asIdentifier( this.name, @@ -218,8 +257,8 @@ export class Metric extends Member { ) ); } - getAggregateSql(dialect: BaseDialect, modelAlias?: string) { - const { sql, bindings } = this.getSql(dialect, modelAlias); + getAggregateSql(dialect: BaseDialect, context: unknown, modelAlias?: string) { + const { sql, bindings } = this.getSql(dialect, context, modelAlias); return { sql: `${this.props.aggregateWith.toUpperCase()}(${sql})`, bindings, @@ -234,12 +273,13 @@ export class Metric extends Member { } // biome-ignore lint/suspicious/noExplicitAny: -export type AnyModel = Model; -export type ModelConfig = +export type AnyModel = Model; +export type ModelConfig = | { type: "table"; name: string } - | { type: "sqlQuery"; alias: string; sql: string }; + | { type: "sqlQuery"; alias: string; sql: ModelSqlFn }; export class Model< + C, N extends string, D extends MemberNameToType = MemberNameToType, M extends MemberNameToType = MemberNameToType, @@ -249,17 +289,17 @@ export class Model< constructor( public readonly name: N, - public readonly config: ModelConfig, + public readonly config: ModelConfig, ) { this.name = name; } withDimension< DN1 extends string, - DP extends DimensionProps, + DP extends DimensionProps, >( name: DN1, dimension: DP, - ): Model, M> { + ): Model>, M> { this.dimensions[name] = new Dimension(this, name, dimension); if (typeHasGranularity(dimension.type)) { const granularity = GranularityByDimensionType[dimension.type]; @@ -274,10 +314,10 @@ export class Model< } return this; } - withMetric>( + withMetric>( name: MN1, metric: MP, - ): Model { + ): Model> { this.metrics[name] = new Metric(this, name, metric); return this; } @@ -316,21 +356,40 @@ export class Model< ? this.config.alias : this.config.name; } + getSql(dialect: BaseDialect, context: C) { + if (this.config.type === "sqlQuery") { + const result = this.config.sql({ + identifier: (name: string) => new IdentifierRef(name), + sql: (strings: TemplateStringsArray, ...values: unknown[]) => + new SqlWithRefs([...strings], values), + getContext: () => context, + }); + return result.render(dialect, context); + } + throw new Error("Model is not a SQL query"); + } } const VALID_NAME_RE = /^[a-zA-Z_][a-zA-Z0-9_]*$/; -export function model(name: N) { - if (!VALID_NAME_RE.test(name)) { - throw new Error(`Invalid model name: ${name}`); - } - +export function model() { return { - fromTable: (tableName: string) => { - return new Model(name, { type: "table", name: tableName }); - }, - fromSqlQuery: (sql: string) => { - return new Model(name, { type: "sqlQuery", alias: name, sql }); + withName: (name: N) => { + if (!VALID_NAME_RE.test(name)) { + throw new Error(`Invalid model name: ${name}`); + } + + return { + fromTable: (tableName?: string) => { + return new Model(name, { + type: "table", + name: tableName ?? name, + }); + }, + fromSqlQuery: (sql: ModelSqlFn) => { + return new Model(name, { type: "sqlQuery", alias: name, sql }); + }, + }; }, }; } diff --git a/src/lib/query-builder.ts b/src/lib/query-builder.ts index fb4266a..cdfb38a 100644 --- a/src/lib/query-builder.ts +++ b/src/lib/query-builder.ts @@ -79,6 +79,7 @@ export function buildQuerySchema(repository: AnyRepository) { } export class QueryBuilder< + C, D extends MemberNameToType, M extends MemberNameToType, F, @@ -86,13 +87,13 @@ export class QueryBuilder< public readonly querySchema: ReturnType; constructor( private readonly repository: AnyRepository, - private readonly Dialect: typeof BaseDialect, + private readonly dialect: BaseDialect, private readonly client: knex.Knex, ) { this.querySchema = buildQuerySchema(repository); } - unsafeBuildQuery(payload: unknown) { + unsafeBuildQuery(payload: unknown, context: unknown) { const query: AnyQuery = this.querySchema.parse(payload); const { referencedModels, segments } = expandQueryToSegments( @@ -108,7 +109,8 @@ export class QueryBuilder< const sqlQuery = buildQuery( this.client, this.repository, - this.Dialect, + this.dialect, + context, query, referencedModels, joinGraph, @@ -130,8 +132,10 @@ export class QueryBuilder< string & keyof M, F & { member: string & (keyof D | keyof M) } >, + ...rest: C extends undefined ? [] : [C] ) { - const { sql, bindings } = this.unsafeBuildQuery(query); + const [context] = rest; + const { sql, bindings } = this.unsafeBuildQuery(query, context); const result: SqlQueryResult< Simplify< @@ -180,6 +184,8 @@ export class QueryBuilder< } export type QueryBuilderQuery = Q extends QueryBuilder< + // biome-ignore lint/suspicious/noExplicitAny: + any, infer D, infer M, infer F diff --git a/src/lib/query-builder/build-query.ts b/src/lib/query-builder/build-query.ts index 1ac68a7..cb68558 100644 --- a/src/lib/query-builder/build-query.ts +++ b/src/lib/query-builder/build-query.ts @@ -5,7 +5,8 @@ import { AnyQuery, ModelQuery, QuerySegment } from "../types.js"; import knex from "knex"; import invariant from "tiny-invariant"; import { BaseDialect } from "../dialect/base.js"; -import type { Join } from "../join.js"; +import type { AnyJoin } from "../join.js"; +import { AnyModel } from "../model.js"; import type { AnyRepository } from "../repository.js"; interface ReferencedModels { @@ -36,24 +37,52 @@ function getDefaultOrderBy(repository: AnyRepository, query: AnyQuery) { return {}; } +function initializeQuerySegment( + knex: knex.Knex, + dialect: BaseDialect, + context: unknown, + model: AnyModel, +) { + if (model.config.type === "table") { + return knex(model.config.name); + } + const modelSql = model.getSql(dialect, context); + return knex( + knex.raw(`(${modelSql.sql}) as ${model.config.alias}`, modelSql.bindings), + ); +} + +function getJoinSubject( + knex: knex.Knex, + dialect: BaseDialect, + context: unknown, + model: AnyModel, +) { + if (model.config.type === "table") { + return model.config.name; + } + const modelSql = model.getSql(dialect, context); + return knex.raw( + `(${modelSql.sql}) as ${model.config.alias}`, + modelSql.bindings, + ); +} + // biome-ignore lint/complexity/noExcessiveCognitiveComplexity: function buildQuerySegmentJoinQuery( knex: knex.Knex, repository: AnyRepository, - Dialect: typeof BaseDialect, + dialect: BaseDialect, + context: unknown, joinGraph: graphlib.Graph, modelQueries: Record, source: string, ) { const visitedModels = new Set(); const model = repository.getModel(source); - const sqlQuery = - model.config.type === "table" - ? knex(model.config.name) - : knex(knex.raw(`(${model.config.sql}) as ${model.config.alias}`)); - const dialect = new Dialect(sqlQuery); + const sqlQuery = initializeQuerySegment(knex, dialect, context, model); - const modelStack: { modelName: string; join?: Join }[] = [ + const modelStack: { modelName: string; join?: AnyJoin }[] = [ { modelName: source }, ]; @@ -80,14 +109,9 @@ function buildQuerySegmentJoinQuery( if (join) { const joinType = join.reversed ? "rightJoin" : "leftJoin"; - const joinOn = join.joinOnDef.render(repository, dialect); + const joinOn = join.joinOnDef(context).render(repository, dialect); const rightModel = repository.getModel(join.right); - const joinSubject = - rightModel.config.type === "table" - ? rightModel.config.name - : knex.raw( - `(${rightModel.config.sql}) as ${rightModel.config.alias}`, - ); + const joinSubject = getJoinSubject(knex, dialect, context, rightModel); sqlQuery[joinType](joinSubject, knex.raw(joinOn.sql, joinOn.bindings)); @@ -99,7 +123,7 @@ function buildQuerySegmentJoinQuery( for (const metricName of modelQuery?.metrics || []) { const metric = repository.getMetric(metricName); - const { sql, bindings } = metric.getSql(dialect); + const { sql, bindings } = metric.getSql(dialect, context); sqlQuery.select( knex.raw(`${sql} as ${metric.getAlias(dialect)}`, bindings), ); @@ -107,7 +131,7 @@ function buildQuerySegmentJoinQuery( for (const dimensionName of dimensionNames) { const dimension = repository.getDimension(dimensionName); - const { sql, bindings } = dimension.getSql(dialect); + const { sql, bindings } = dimension.getSql(dialect, context); sqlQuery.select( knex.raw(`${sql} as ${dimension.getAlias(dialect)}`, bindings), @@ -128,7 +152,8 @@ function buildQuerySegmentJoinQuery( function buildQuerySegment( knex: knex.Knex, repository: AnyRepository, - Dialect: typeof BaseDialect, + dialect: BaseDialect, + context: unknown, joinGraph: graphlib.Graph, segment: QuerySegment, ) { @@ -144,12 +169,12 @@ function buildQuerySegment( const initialSqlQuery = buildQuerySegmentJoinQuery( knex, repository, - Dialect, + dialect, + context, joinGraph, segment.modelQueries, source, ); - const dialect = new Dialect(initialSqlQuery); // If there are no metrics, we need to use DISTINCT to avoid multiplying rows // otherwise GROUP BY will take care of it @@ -165,7 +190,7 @@ function buildQuerySegment( "dimension", segment.referencedModels.all, ) - .buildFilters(segment.query.filters, "and"); + .buildFilters(segment.query.filters, "and", context); if (filter) { initialSqlQuery.where(knex.raw(filter.sql, filter.bindings)); @@ -196,7 +221,7 @@ function buildQuerySegment( for (const metricName of segment.query.metrics || []) { const metric = repository.getMetric(metricName); - const { sql, bindings } = metric.getAggregateSql(dialect, alias); + const { sql, bindings } = metric.getAggregateSql(dialect, context, alias); sqlQuery.select( knex.raw(`${sql} as ${metric.getAlias(dialect)}`, bindings), @@ -214,14 +239,15 @@ function getAlias(index: number) { export function buildQuery( knex: knex.Knex, repository: AnyRepository, - Dialect: typeof BaseDialect, + dialect: BaseDialect, + context: unknown, query: AnyQuery, referencedModels: ReferencedModels, joinGraph: graphlib.Graph, segments: QuerySegment[], ) { const sqlQuerySegments = segments.map((segment) => - buildQuerySegment(knex, repository, Dialect, joinGraph, segment), + buildQuerySegment(knex, repository, dialect, context, joinGraph, segment), ); const [initialSqlQuerySegment, ...restSqlQuerySegments] = sqlQuerySegments; @@ -232,7 +258,6 @@ export function buildQuery( ); const rootAlias = getAlias(0); const rootSqlQuery = knex(initialSqlQuerySegment.sqlQuery.as(rootAlias)); - const dialect = new Dialect(rootSqlQuery); for (const dimensionName of initialSqlQuerySegment.projectedQuery .dimensions || []) { @@ -310,7 +335,7 @@ export function buildQuery( referencedModels.metrics, metricPrefixes, ) - .buildFilters(query.filters, "and"); + .buildFilters(query.filters, "and", context); if (filter) { rootSqlQuery.where(knex.raw(filter.sql, filter.bindings)); } diff --git a/src/lib/query-builder/expand-query.ts b/src/lib/query-builder/expand-query.ts index 1a30815..eb9adda 100644 --- a/src/lib/query-builder/expand-query.ts +++ b/src/lib/query-builder/expand-query.ts @@ -5,9 +5,9 @@ import { QuerySegment, } from "../types.js"; -import { Repository } from "../repository.js"; +import { AnyRepository } from "../repository.js"; -function analyzeQuery(repository: Repository, query: AnyQuery) { +function analyzeQuery(repository: AnyRepository, query: AnyQuery) { const allModels = new Set(); const dimensionModels = new Set(); const metricModels = new Set(); @@ -84,7 +84,7 @@ interface PreparedQuery { // biome-ignore lint/complexity/noExcessiveCognitiveComplexity: function getQuerySegment( - repository: Repository, + repository: AnyRepository, queryAnalysis: ReturnType, metricModel: string | null, index: number, @@ -209,7 +209,10 @@ function mergeQuerySegmentWithFilters( }; } -export function expandQueryToSegments(repository: Repository, query: AnyQuery) { +export function expandQueryToSegments( + repository: AnyRepository, + query: AnyQuery, +) { const queryAnalysis = analyzeQuery(repository, query); const metricModels = Object.keys(queryAnalysis.metricsByModel); const segments = diff --git a/src/lib/query-builder/filter-builder.ts b/src/lib/query-builder/filter-builder.ts index 4d923d6..9b35507 100644 --- a/src/lib/query-builder/filter-builder.ts +++ b/src/lib/query-builder/filter-builder.ts @@ -37,7 +37,7 @@ import { } from "./filter-builder/number-comparison-filter-builder.js"; import { BaseDialect } from "../dialect/base.js"; -import type { Repository } from "../repository.js"; +import type { AnyRepository } from "../repository.js"; import { equals as filterEquals } from "./filter-builder/equals.js"; import { notEquals as filterNotEquals } from "./filter-builder/not-equals.js"; import { sqlAsSqlWithBindings } from "./util.js"; @@ -51,18 +51,21 @@ export class FilterBuilder { AnyFilterFragmentBuilder >, private readonly dialect: BaseDialect, - private readonly repository: Repository, + private readonly repository: AnyRepository, private readonly filterType: FilterType, referencedModels: string[], private readonly metricPrefixes?: Record, ) { this.referencedModels = new Set(referencedModels); } - getMemberSql(memberName: string): SqlWithBindings | undefined { + getMemberSql( + memberName: string, + context: unknown, + ): SqlWithBindings | undefined { const member = this.repository.getMember(memberName); if (this.referencedModels.has(member.model.name)) { if (this.filterType === "dimension" && member.isDimension()) { - return member.getSql(this.dialect); + return member.getSql(this.dialect, context); } if (this.filterType === "metric" && member.isMetric()) { const prefix = this.metricPrefixes?.[member.model.name]; @@ -74,20 +77,26 @@ export class FilterBuilder { } } - buildOr(filter: OrConnective): SqlWithBindings | undefined { - return this.buildFilters(filter.filters, "or"); + buildOr(filter: OrConnective, context: unknown): SqlWithBindings | undefined { + return this.buildFilters(filter.filters, "or", context); } - buildAnd(filter: AndConnective): SqlWithBindings | undefined { - return this.buildFilters(filter.filters, "and"); + buildAnd( + filter: AndConnective, + context: unknown, + ): SqlWithBindings | undefined { + return this.buildFilters(filter.filters, "and", context); } - buildFilter(filter: AnyQueryFilter): SqlWithBindings | undefined { + buildFilter( + filter: AnyQueryFilter, + context: unknown, + ): SqlWithBindings | undefined { if (filter.operator === "and") { - return this.buildAnd(filter); + return this.buildAnd(filter, context); } if (filter.operator === "or") { - return this.buildOr(filter); + return this.buildOr(filter, context); } - const memberSql = this.getMemberSql(filter.member); + const memberSql = this.getMemberSql(filter.member, context); if (memberSql) { const builder = this.filterFragmentBuilders[filter.operator]; if (builder) { @@ -99,10 +108,11 @@ export class FilterBuilder { buildFilters( filters: AnyQueryFilter[], connective: "and" | "or", + context: unknown, ): SqlWithBindings | undefined { const result = filters.reduce<{ sqls: string[]; bindings: unknown[] }>( (acc, filter) => { - const result = this.buildFilter(filter); + const result = this.buildFilter(filter, context); if (result) { acc.sqls.push(result.sql); acc.bindings.push(...result.bindings); @@ -149,7 +159,7 @@ export class FilterFragmentBuilderRegistry { return Object.values(this.filterFragmentBuilders); } getFilterBuilder( - repository: Repository, + repository: AnyRepository, dialect: BaseDialect, filterType: FilterType, referencedModels: string[], diff --git a/src/lib/repository.ts b/src/lib/repository.ts index 7f0c03d..8dc0446 100644 --- a/src/lib/repository.ts +++ b/src/lib/repository.ts @@ -1,11 +1,12 @@ import { + AnyJoin, JOIN_WEIGHTS, - Join, - JoinDimensionRef, JoinDimensions, JoinFn, + JoinIdentifierRef, JoinOnDef, REVERSED_JOIN, + makeModelJoinPayload, } from "./join.js"; import { AnyModel, Model } from "./model.js"; import { @@ -22,16 +23,25 @@ import { BaseDialect } from "./dialect/base.js"; import { QueryBuilder } from "./query-builder.js"; // biome-ignore lint/suspicious/noExplicitAny: Using any for inference -export type ModelN = T extends Model ? N : never; +export type ModelC = T extends Model ? C : never; + +// biome-ignore lint/suspicious/noExplicitAny: Using any for inference +export type ModelN = T extends Model ? N : never; // biome-ignore lint/suspicious/noExplicitAny: Using any for inference -export type ModelD = T extends Model +export type ModelD = T extends Model ? { [K in string & keyof D as `${N}.${K}`]: D[K] } : never; // biome-ignore lint/suspicious/noExplicitAny: Using any for inference -export type ModelM = T extends Model +export type ModelM = T extends Model ? { [K in string & keyof M as `${N}.${K}`]: M[K] } : never; +export type ModelWithMatchingContext = [C] extends [ + ModelC, +] + ? T + : never; + // biome-ignore lint/suspicious/noExplicitAny: Using any for inference export type AnyRepository = Repository; @@ -50,6 +60,7 @@ function getClientAndDialect(dialect: AvailableDialects): { } export class Repository< + C, N extends string = never, D extends MemberNameToType = MemberNameToType, M extends MemberNameToType = MemberNameToType, @@ -60,7 +71,7 @@ export class Repository< private readonly models: Record = {}; private filterFragmentBuilderRegistry: AnyFilterFragmentBuilderRegistry = defaultFilterFragmentBuilderRegistry(); - readonly joins: Record> = {}; + readonly joins: Record> = {}; readonly graph: graphlib.Graph = new graphlib.Graph(); readonly dimensionsIndex: Record< string, @@ -69,7 +80,7 @@ export class Repository< readonly metricsIndex: Record = {} as Record; - public withModel(model: T) { + withModel(model: ModelWithMatchingContext) { this.models[model.name] = model; for (const dimension in model.dimensions) { this.dimensionsIndex[`${model.name}.${dimension}`] = { @@ -83,14 +94,22 @@ export class Repository< metric, }; } - return this as Repository, D & ModelD, M & ModelM, F>; + + return this as unknown as Repository< + C, + N | ModelN, + D & ModelD, + M & ModelM, + F + >; } - public withFilterFragmentBuilderRegistry< - T extends AnyFilterFragmentBuilderRegistry, - >(filterFragmentBuilderRegistry: T) { + withFilterFragmentBuilderRegistry( + filterFragmentBuilderRegistry: T, + ) { this.filterFragmentBuilderRegistry = filterFragmentBuilderRegistry; return this as Repository< + C, N, D, M, @@ -98,12 +117,12 @@ export class Repository< >; } - public getFilterFragmentBuilderRegistry() { + getFilterFragmentBuilderRegistry() { return this.filterFragmentBuilderRegistry; } getFilterBuilder( - repository: Repository, + repository: AnyRepository, dialect: BaseDialect, filterType: FilterType, referencedModels: string[], @@ -119,34 +138,30 @@ export class Repository< } join( - type: Join["type"], + type: AnyJoin["type"], modelName1: N1, modelName2: N2, - joinSqlDefFn: JoinFn, + joinSqlDefFn: JoinFn, ) { const model1 = this.models[modelName1]; const model2 = this.models[modelName2]; + invariant(model1, `Model ${model1} not found in repository`); invariant(model2, `Model ${model2} not found in repository`); - const dimensions = { - [model1.name]: Object.keys(model1.dimensions).reduce< - Record> - >((acc, dimension) => { - acc[dimension] = new JoinDimensionRef(modelName1, dimension); - return acc; - }, {}), - [model2.name]: Object.keys(model2.dimensions).reduce< - Record> - >((acc, dimension) => { - acc[dimension] = new JoinDimensionRef(model2.name, dimension); - return acc; - }, {}), - } as JoinDimensions; - - const joinSqlDef = joinSqlDefFn({ - sql: (strings, ...values) => new JoinOnDef([...strings], values), - dimensions, - }); + + const joinSqlDef = (context: C) => { + const models = { + [model1.name]: makeModelJoinPayload(model1, context), + [model2.name]: makeModelJoinPayload(model2, context), + } as JoinDimensions; + + return joinSqlDefFn({ + sql: (strings, ...values) => new JoinOnDef([...strings], values), + identifier: (name) => new JoinIdentifierRef(name), + models, + getContext: () => context, + }); + }; const reversedType = REVERSED_JOIN[type]; @@ -175,7 +190,7 @@ export class Repository< joinOneToOne( model1: N1, model2: N2, - joinSqlDefFn: JoinFn, + joinSqlDefFn: JoinFn, ) { return this.join("oneToOne", model1, model2, joinSqlDefFn); } @@ -183,7 +198,7 @@ export class Repository< joinOneToMany( model1: N1, model2: N2, - joinSqlDefFn: JoinFn, + joinSqlDefFn: JoinFn, ) { return this.join("oneToMany", model1, model2, joinSqlDefFn); } @@ -191,7 +206,7 @@ export class Repository< joinManyToOne( model1: N1, model2: N2, - joinSqlDefFn: JoinFn, + joinSqlDefFn: JoinFn, ) { return this.join("manyToOne", model1, model2, joinSqlDefFn); } @@ -199,7 +214,7 @@ export class Repository< joinManyToMany( model1: N1, model2: N2, - joinSqlDefFn: JoinFn, + joinSqlDefFn: JoinFn, ) { return this.join("manyToMany", model1, model2, joinSqlDefFn); } @@ -264,10 +279,10 @@ export class Repository< build(dialectName: AvailableDialects) { const { client, Dialect } = getClientAndDialect(dialectName); - return new QueryBuilder(this, Dialect, client); + return new QueryBuilder(this, new Dialect(), client); } } -export function repository() { - return new Repository(); +export function repository() { + return new Repository(); }