diff --git a/.cspell.json b/.cspell.json index 511d14a..d4adb69 100644 --- a/.cspell.json +++ b/.cspell.json @@ -41,6 +41,7 @@ "outdir", "Paralamas", "Pista", + "remeda", "rmrf", "ryansonshine", "semantic-layer", diff --git a/src/__tests__/index.test.ts b/src/__tests__/index.test.ts index 7174677..687d7dc 100644 --- a/src/__tests__/index.test.ts +++ b/src/__tests__/index.test.ts @@ -741,6 +741,37 @@ await describe("semantic layer", async () => { }, ]); }); + + await it("can filter by results of another query", async () => { + const query = queryBuilder.buildQuery({ + dimensions: ["customers.country"], + order: { "customers.country": "asc" }, + filters: [ + { + operator: "inQuery", + member: "customers.country", + value: { + dimensions: ["customers.country"], + filters: [ + { + operator: "equals", + member: "customers.country", + value: ["Argentina"], + }, + ], + }, + }, + ], + limit: 10, + }); + + const result = await client.query>( + query.sql, + query.bindings, + ); + + assert.deepEqual(result.rows, [{ customers___country: "Argentina" }]); + }); }); await describe("models from sql queries", async () => { @@ -1242,6 +1273,26 @@ await describe("semantic layer", async () => { required: ["operator", "member", "value"], additionalProperties: false, }, + { + type: "object", + properties: { + operator: { type: "string", const: "inQuery" }, + member: { type: "string" }, + value: { $ref: "#" }, + }, + required: ["operator", "member", "value"], + additionalProperties: false, + }, + { + type: "object", + properties: { + operator: { type: "string", const: "notInQuery" }, + member: { type: "string" }, + value: { $ref: "#" }, + }, + required: ["operator", "member", "value"], + additionalProperties: false, + }, ], }, }, @@ -1889,62 +1940,62 @@ await describe("semantic layer", async () => { }); describe("repository with context", async () => { - await it("propagates context to all sql functions", async () => { - type QueryContext = { - customerId: number; - }; + type QueryContext = { + customerId: number; + }; - const customersModel = semanticLayer - .model() - .withName("customers") - .fromSqlQuery( - ({ sql, identifier, getContext }) => - sql`select * from ${identifier("Customer")} where ${identifier( - "CustomerId", - )} = ${getContext().customerId}`, - ) - .withDimension("customer_id", { - type: "number", - primaryKey: true, - sql: ({ model, sql, getContext }) => - sql`${model.column("CustomerId")} || cast(${ - getContext().customerId - } as text)`, - }) - .withDimension("first_name", { - type: "string", - sql: ({ model }) => model.column("FirstName"), - }); + const customersModel = semanticLayer + .model() + .withName("customers") + .fromSqlQuery( + ({ sql, identifier, getContext }) => + sql`select * from ${identifier("Customer")} where ${identifier( + "CustomerId", + )} = ${getContext().customerId}`, + ) + .withDimension("customer_id", { + type: "number", + primaryKey: true, + sql: ({ model, sql, getContext }) => + sql`${model.column("CustomerId")} || cast(${ + getContext().customerId + } as text)`, + }) + .withDimension("first_name", { + type: "string", + sql: ({ model }) => model.column("FirstName"), + }); - const invoicesModel = semanticLayer - .model() - .withName("invoices") - .fromTable("Invoice") - .withDimension("invoice_id", { - type: "number", - primaryKey: true, - sql: ({ model }) => model.column("InvoiceId"), - }) - .withDimension("customer_id", { - type: "number", - sql: ({ model }) => model.column("CustomerId"), - }); + const invoicesModel = semanticLayer + .model() + .withName("invoices") + .fromTable("Invoice") + .withDimension("invoice_id", { + type: "number", + primaryKey: true, + sql: ({ model }) => model.column("InvoiceId"), + }) + .withDimension("customer_id", { + type: "number", + sql: ({ model }) => model.column("CustomerId"), + }); - const repository = semanticLayer - .repository() - .withModel(customersModel) - .withModel(invoicesModel) - .joinOneToMany( - "customers", - "invoices", - ({ sql, models, getContext }) => - sql`${models.customers.dimension( - "customer_id", - )} = ${models.invoices.dimension("customer_id")} and ${ - getContext().customerId - } = ${getContext().customerId}`, - ); + const repository = semanticLayer + .repository() + .withModel(customersModel) + .withModel(invoicesModel) + .joinOneToMany( + "customers", + "invoices", + ({ sql, models, getContext }) => + sql`${models.customers.dimension( + "customer_id", + )} = ${models.invoices.dimension("customer_id")} and ${ + getContext().customerId + } = ${getContext().customerId}`, + ); + await it("propagates context to all sql functions", async () => { const queryBuilder = repository.build("postgresql"); const query = queryBuilder.buildQuery( { @@ -1961,5 +2012,41 @@ await describe("semantic layer", async () => { // First 5 bindings are for the customerId, last one is for the limit assert.deepEqual(query.bindings, [1, 1, 1, 1, 1, 5000]); }); + + await it("propagates context to query filters", async () => { + const queryBuilder = repository.build("postgresql"); + const query = queryBuilder.buildQuery( + { + dimensions: ["customers.customer_id", "invoices.invoice_id"], + filters: [ + { + operator: "inQuery", + member: "customers.customer_id", + value: { + dimensions: ["customers.customer_id"], + filters: [ + { + operator: "equals", + member: "customers.customer_id", + value: [1], + }, + ], + }, + }, + ], + }, + { customerId: 1 }, + ); + + assert.equal( + query.sql, + 'select "q0"."customers___customer_id" as "customers___customer_id", "q0"."invoices___invoice_id" as "invoices___invoice_id" from (select "invoices_query"."customers___customer_id" as "customers___customer_id", "invoices_query"."invoices___invoice_id" as "invoices___invoice_id" from (select distinct "Invoice"."InvoiceId" as "invoices___invoice_id", "customers"."CustomerId" || cast($1 as text) as "customers___customer_id" from "Invoice" right join (select * from "Customer" where "CustomerId" = $2) as customers on "customers"."CustomerId" || cast($3 as text) = "Invoice"."CustomerId" and $4 = $5 where "customers"."CustomerId" || cast($6 as text) in (select "q0"."customers___customer_id" as "customers___customer_id" from (select "customers_query"."customers___customer_id" as "customers___customer_id" from (select distinct "customers"."CustomerId" || cast($7 as text) as "customers___customer_id" from (select * from "Customer" where "CustomerId" = $8) as customers where "customers"."CustomerId" || cast($9 as text) = $10) as "customers_query") as "q0" order by "customers___customer_id" asc limit $11)) as "invoices_query") as "q0" order by "customers___customer_id" asc limit $12', + ); + + assert.deepEqual( + query.bindings, + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5000, 5000], + ); + }); }); }); diff --git a/src/lib/query-builder.ts b/src/lib/query-builder.ts index 48a0ebf..3066b73 100644 --- a/src/lib/query-builder.ts +++ b/src/lib/query-builder.ts @@ -1,6 +1,6 @@ import { AnyQuery, - AnyQueryFilter, + FilterType, IntrospectionResult, MemberNameToType, Query, @@ -14,111 +14,36 @@ import { } from "./types.js"; import knex from "knex"; -import { z } from "zod"; import { Simplify } from "type-fest"; import { BaseDialect } from "./dialect/base.js"; import { buildQuery } from "./query-builder/build-query.js"; +import { FilterBuilder } from "./query-builder/filter-builder.js"; import { findOptimalJoinGraph } from "./query-builder/optimal-join-graph.js"; import { processQueryAndExpandToSegments } from "./query-builder/process-query-and-expand-to-segments.js"; +import { QuerySchema, buildQuerySchema } from "./query-schema.js"; import type { AnyRepository } from "./repository.js"; import { getAdHocAlias, getAdHocPath } from "./util.js"; -function getDimensionNamesSchema(dimensionPaths: string[]) { - return z - .array( - z - .string() - .refine((arg) => dimensionPaths.includes(arg)) - .describe("Dimension name"), - ) - .optional(); -} - -function getMetricNamesSchema(metricPaths: string[], dimensionPaths: string[]) { - const adHocMetricSchema = z.object({ - aggregateWith: z.enum(["sum", "count", "min", "max", "avg"]), - dimension: z - .string() - .refine((arg) => dimensionPaths.includes(arg)) - .describe("Dimension name"), - }); - - return z - .array( - z - .string() - .refine((arg) => metricPaths.includes(arg)) - .describe("Metric name") - .or(adHocMetricSchema), - ) - .optional(); -} - -export function buildQuerySchema(repository: AnyRepository) { - const dimensionPaths = repository.getDimensions().map((d) => d.getPath()); - const metricPaths = repository.getMetrics().map((m) => m.getPath()); - const memberPaths = [...dimensionPaths, ...metricPaths]; - - const registeredFilterFragmentBuildersSchemas = repository - .getFilterFragmentBuilderRegistry() - .getFilterFragmentBuilders() - .map((builder) => builder.fragmentBuilderSchema); - - const filters: z.ZodType = z.array( - z.union([ - z.object({ - operator: z.literal("and"), - filters: z.lazy(() => filters), - }), - z.object({ - operator: z.literal("or"), - filters: z.lazy(() => filters), - }), - ...registeredFilterFragmentBuildersSchemas.map((schema) => - schema.refine((arg) => memberPaths.includes(arg.member), { - path: ["member"], - message: "Member not found", - }), - ), - ]), - ); - - const schema = z - .object({ - dimensions: getDimensionNamesSchema(dimensionPaths), - metrics: getMetricNamesSchema(metricPaths, dimensionPaths), - filters: filters.optional(), - limit: z.number().optional(), - offset: z.number().optional(), - order: z.record(z.string(), z.enum(["asc", "desc"])).optional(), - }) - .refine( - (arg) => (arg.dimensions?.length ?? 0) + (arg.metrics?.length ?? 0) > 0, - "At least one dimension or metric must be selected", - ); - - return schema; -} - export class QueryBuilder< C, D extends MemberNameToType, M extends MemberNameToType, F, > { - public readonly querySchema: ReturnType; + public readonly querySchema: QuerySchema; constructor( - private readonly repository: AnyRepository, - private readonly dialect: BaseDialect, - private readonly client: knex.Knex, + public readonly repository: AnyRepository, + public readonly dialect: BaseDialect, + public readonly client: knex.Knex, ) { - this.querySchema = buildQuerySchema(repository); + this.querySchema = buildQuerySchema(this); } - unsafeBuildQuery(payload: unknown, context: unknown) { - const parsedQuery: AnyQuery = this.querySchema.parse(payload); - + unsafeBuildGenericQueryWithoutSchemaParse( + parsedQuery: AnyQuery, + context: unknown, + ) { const { query, referencedModels, segments } = processQueryAndExpandToSegments(this.repository, parsedQuery); @@ -128,9 +53,7 @@ export class QueryBuilder< ); const sqlQuery = buildQuery( - this.client, - this.repository, - this.dialect, + this, context, query, referencedModels, @@ -138,8 +61,15 @@ export class QueryBuilder< segments, ); - const { sql, bindings } = sqlQuery.toSQL().toNative(); + return sqlQuery.toSQL(); + } + unsafeBuildQuery(payload: unknown, context: unknown) { + const parsedQuery: AnyQuery = this.querySchema.parse(payload); + const { sql, bindings } = this.unsafeBuildGenericQueryWithoutSchemaParse( + parsedQuery, + context, + ).toNative(); return { sql, bindings: bindings as unknown[], @@ -177,6 +107,16 @@ export class QueryBuilder< return result; } + getFilterBuilder( + filterType: FilterType, + referencedModels: string[], + metricPrefixes?: Record, + ): FilterBuilder { + return this.repository + .getFilterFragmentBuilderRegistry() + .getFilterBuilder(this, filterType, referencedModels, metricPrefixes); + } + introspect(query: AnyQuery): IntrospectionResult { const queryDimensions = query.dimensions ?? []; const queryMetrics = query.metrics ?? []; @@ -227,3 +167,6 @@ export type QueryBuilderQuery = Q extends QueryBuilder< > ? Query : never; + +// biome-ignore lint/suspicious/noExplicitAny: +export type AnyQueryBuilder = QueryBuilder; diff --git a/src/lib/query-builder/build-query.ts b/src/lib/query-builder/build-query.ts index 50fae97..37c682f 100644 --- a/src/lib/query-builder/build-query.ts +++ b/src/lib/query-builder/build-query.ts @@ -13,6 +13,7 @@ import invariant from "tiny-invariant"; import { BaseDialect } from "../dialect/base.js"; import type { AnyJoin } from "../join.js"; import { AnyModel } from "../model.js"; +import { AnyQueryBuilder } from "../query-builder.js"; import type { AnyRepository } from "../repository.js"; import { getAdHocAlias } from "../util.js"; @@ -57,22 +58,22 @@ function getDefaultOrderBy(repository: AnyRepository, query: AnyQuery) { } function initializeQuerySegment( - knex: knex.Knex, + client: knex.Knex, dialect: BaseDialect, context: unknown, model: AnyModel, ) { if (model.config.type === "table") { - return knex(model.config.name); + return client(model.config.name); } const modelSql = model.getSql(dialect, context); - return knex( - knex.raw(`(${modelSql.sql}) as ${model.config.alias}`, modelSql.bindings), + return client( + client.raw(`(${modelSql.sql}) as ${model.config.alias}`, modelSql.bindings), ); } function getJoinSubject( - knex: knex.Knex, + client: knex.Knex, dialect: BaseDialect, context: unknown, model: AnyModel, @@ -81,7 +82,7 @@ function getJoinSubject( return model.config.name; } const modelSql = model.getSql(dialect, context); - return knex.raw( + return client.raw( `(${modelSql.sql}) as ${model.config.alias}`, modelSql.bindings, ); @@ -89,17 +90,20 @@ function getJoinSubject( // biome-ignore lint/complexity/noExcessiveCognitiveComplexity: Essential complexity function buildQuerySegmentJoinQuery( - knex: knex.Knex, - repository: AnyRepository, - dialect: BaseDialect, + queryBuilder: AnyQueryBuilder, context: unknown, joinGraph: graphlib.Graph, modelQueries: Record, source: string, ) { const visitedModels = new Set(); - const model = repository.getModel(source); - const sqlQuery = initializeQuerySegment(knex, dialect, context, model); + const model = queryBuilder.repository.getModel(source); + const sqlQuery = initializeQuerySegment( + queryBuilder.client, + queryBuilder.dialect, + context, + model, + ); const modelStack: { modelName: string; join?: AnyJoin }[] = [ { modelName: source }, @@ -113,7 +117,7 @@ function buildQuerySegmentJoinQuery( visitedModels.add(modelName); const modelQuery = modelQueries[modelName]; - const model = repository.getModel(modelName); + const model = queryBuilder.repository.getModel(modelName); const hasMetrics = modelQuery?.metrics && modelQuery.metrics.size > 0; const hasAdHocMetrics = modelQuery?.adHocMetrics && modelQuery.adHocMetrics.size > 0; @@ -130,11 +134,21 @@ function buildQuerySegmentJoinQuery( if (join) { const joinType = join.reversed ? "rightJoin" : "leftJoin"; - const joinOn = join.joinOnDef(context).render(repository, dialect); - const rightModel = repository.getModel(join.right); - const joinSubject = getJoinSubject(knex, dialect, context, rightModel); + const joinOn = join + .joinOnDef(context) + .render(queryBuilder.repository, queryBuilder.dialect); + const rightModel = queryBuilder.repository.getModel(join.right); + const joinSubject = getJoinSubject( + queryBuilder.client, + queryBuilder.dialect, + context, + rightModel, + ); - sqlQuery[joinType](joinSubject, knex.raw(joinOn.sql, joinOn.bindings)); + sqlQuery[joinType]( + joinSubject, + queryBuilder.client.raw(joinOn.sql, joinOn.bindings), + ); // We have a join that is multiplying the rows, so we need to use DISTINCT if (join.type === "manyToMany" || join.type === "oneToMany") { @@ -143,37 +157,48 @@ function buildQuerySegmentJoinQuery( } for (const metricName of modelQuery?.metrics || []) { - const metric = repository.getMetric(metricName); - const { sql, bindings } = metric.getSql(dialect, context); + const metric = queryBuilder.repository.getMetric(metricName); + const { sql, bindings } = metric.getSql(queryBuilder.dialect, context); sqlQuery.select( - knex.raw(`${sql} as ${metric.getAlias(dialect)}`, bindings), + queryBuilder.client.raw( + `${sql} as ${metric.getAlias(queryBuilder.dialect)}`, + bindings, + ), ); } for (const adHocMetric of modelQuery?.adHocMetrics || []) { - const dimension = repository.getDimension(adHocMetric.dimension); + const dimension = queryBuilder.repository.getDimension( + adHocMetric.dimension, + ); const { sql, bindings } = dimension.getSqlWithoutGranularity( - dialect, + queryBuilder.dialect, context, ); sqlQuery.select( - knex.raw(`${sql} as ${getAdHocMetricAlias(adHocMetric)}`, bindings), + queryBuilder.client.raw( + `${sql} as ${getAdHocMetricAlias(adHocMetric)}`, + bindings, + ), ); } for (const dimensionName of dimensionNames) { - const dimension = repository.getDimension(dimensionName); - const { sql, bindings } = dimension.getSql(dialect, context); + const dimension = queryBuilder.repository.getDimension(dimensionName); + const { sql, bindings } = dimension.getSql(queryBuilder.dialect, context); sqlQuery.select( - knex.raw(`${sql} as ${dimension.getAlias(dialect)}`, bindings), + queryBuilder.client.raw( + `${sql} as ${dimension.getAlias(queryBuilder.dialect)}`, + bindings, + ), ); } modelStack.push( ...unvisitedNeighbors.map((unvisitedModelName) => ({ modelName: unvisitedModelName, - join: repository.getJoin(modelName, unvisitedModelName), + join: queryBuilder.repository.getJoin(modelName, unvisitedModelName), })), ); } @@ -183,9 +208,7 @@ function buildQuerySegmentJoinQuery( // biome-ignore lint/complexity/noExcessiveCognitiveComplexity: Essential complexity function buildQuerySegment( - knex: knex.Knex, - repository: AnyRepository, - dialect: BaseDialect, + queryBuilder: AnyQueryBuilder, context: unknown, joinGraph: graphlib.Graph, segment: QuerySegment, @@ -200,9 +223,7 @@ function buildQuerySegment( invariant(source, "No source found for segment"); const initialSqlQuery = buildQuerySegmentJoinQuery( - knex, - repository, - dialect, + queryBuilder, context, joinGraph, segment.modelQueries, @@ -219,65 +240,77 @@ function buildQuerySegment( } if (segment.query.filters) { - const filter = repository - .getFilterBuilder( - repository, - dialect, - "dimension", - segment.referencedModels.all, - ) + const filter = queryBuilder + .getFilterBuilder("dimension", segment.referencedModels.all) .buildFilters(segment.query.filters, "and", context); if (filter) { - initialSqlQuery.where(knex.raw(filter.sql, filter.bindings)); + initialSqlQuery.where( + queryBuilder.client.raw(filter.sql, filter.bindings), + ); } } const alias = `${source}_query`; - const sqlQuery = knex(initialSqlQuery.as(alias)); + const sqlQuery = queryBuilder.client(initialSqlQuery.as(alias)); const hasMetrics = (segment.query.metrics && segment.query.metrics.length > 0) || (segment.query.adHocMetrics && segment.query.adHocMetrics.length > 0); for (const dimensionName of segment.query.dimensions || []) { - const dimension = repository.getDimension(dimensionName); + const dimension = queryBuilder.repository.getDimension(dimensionName); sqlQuery.select( - knex.raw( - `${dialect.asIdentifier(alias)}.${dimension.getAlias( - dialect, - )} as ${dimension.getAlias(dialect)}`, + queryBuilder.client.raw( + `${queryBuilder.dialect.asIdentifier(alias)}.${dimension.getAlias( + queryBuilder.dialect, + )} as ${dimension.getAlias(queryBuilder.dialect)}`, ), ); if (hasMetrics) { sqlQuery.groupBy( - knex.raw( - `${dialect.asIdentifier(alias)}.${dimension.getAlias(dialect)}`, + queryBuilder.client.raw( + `${queryBuilder.dialect.asIdentifier(alias)}.${dimension.getAlias( + queryBuilder.dialect, + )}`, ), ); } } for (const metricName of segment.query.metrics || []) { - const metric = repository.getMetric(metricName); - const { sql, bindings } = metric.getAggregateSql(dialect, context, alias); + const metric = queryBuilder.repository.getMetric(metricName); + const { sql, bindings } = metric.getAggregateSql( + queryBuilder.dialect, + context, + alias, + ); sqlQuery.select( - knex.raw(`${sql} as ${metric.getAlias(dialect)}`, bindings), + queryBuilder.client.raw( + `${sql} as ${metric.getAlias(queryBuilder.dialect)}`, + bindings, + ), ); } for (const adHocMetric of segment.query.adHocMetrics || []) { - const dimension = repository.getDimension(adHocMetric.dimension); - const initialSql = dialect.aggregate( + const dimension = queryBuilder.repository.getDimension( + adHocMetric.dimension, + ); + const initialSql = queryBuilder.dialect.aggregate( adHocMetric.aggregateWith, - `${dialect.asIdentifier(alias)}.${getAdHocMetricAlias(adHocMetric)}`, + `${queryBuilder.dialect.asIdentifier(alias)}.${getAdHocMetricAlias( + adHocMetric, + )}`, ); const dimensionGranularity = dimension.getGranularity(); const sql = dimensionGranularity - ? dialect.withGranularity(dimensionGranularity, initialSql) + ? queryBuilder.dialect.withGranularity(dimensionGranularity, initialSql) : initialSql; - sqlQuery.select(knex.raw(`${sql} as ${getAdHocMetricAlias(adHocMetric)}`)); + sqlQuery.select( + queryBuilder.client.raw(`${sql} as ${getAdHocMetricAlias(adHocMetric)}`), + ); } return { ...segment, sqlQuery }; @@ -289,9 +322,7 @@ function getAlias(index: number) { // biome-ignore lint/complexity/noExcessiveCognitiveComplexity: export function buildQuery( - knex: knex.Knex, - repository: AnyRepository, - dialect: BaseDialect, + queryBuilder: AnyQueryBuilder, context: unknown, query: AnyQuery, referencedModels: ReferencedModels, @@ -299,7 +330,7 @@ export function buildQuery( segments: QuerySegment[], ) { const sqlQuerySegments = segments.map((segment) => - buildQuerySegment(knex, repository, dialect, context, joinGraph, segment), + buildQuerySegment(queryBuilder, context, joinGraph, segment), ); const [initialSqlQuerySegment, ...restSqlQuerySegments] = sqlQuerySegments; @@ -310,34 +341,36 @@ export function buildQuery( );*/ const joinOnDimensions = query.dimensions?.map((dimensionName) => { - return repository.getDimension(dimensionName); + return queryBuilder.repository.getDimension(dimensionName); }); const rootAlias = getAlias(0); - const rootSqlQuery = knex(initialSqlQuerySegment.sqlQuery.as(rootAlias)); + const rootSqlQuery = queryBuilder.client( + initialSqlQuerySegment.sqlQuery.as(rootAlias), + ); for (const dimensionName of initialSqlQuerySegment.projectedQuery .dimensions || []) { - const dimension = repository.getDimension(dimensionName); + const dimension = queryBuilder.repository.getDimension(dimensionName); rootSqlQuery.select( - knex.raw( - `${dialect.asIdentifier(rootAlias)}.${dimension.getAlias( - dialect, - )} as ${dimension.getAlias(dialect)}`, + queryBuilder.client.raw( + `${queryBuilder.dialect.asIdentifier(rootAlias)}.${dimension.getAlias( + queryBuilder.dialect, + )} as ${dimension.getAlias(queryBuilder.dialect)}`, ), ); } for (const metricName of initialSqlQuerySegment.projectedQuery.metrics || []) { - const metric = repository.getMetric(metricName); + const metric = queryBuilder.repository.getMetric(metricName); rootSqlQuery.select( - knex.raw( - `${dialect.asIdentifier(rootAlias)}.${metric.getAlias( - dialect, - )} as ${metric.getAlias(dialect)}`, + queryBuilder.client.raw( + `${queryBuilder.dialect.asIdentifier(rootAlias)}.${metric.getAlias( + queryBuilder.dialect, + )} as ${metric.getAlias(queryBuilder.dialect)}`, ), ); } @@ -345,8 +378,8 @@ export function buildQuery( for (const adHocMetric of initialSqlQuerySegment.projectedQuery .adHocMetrics || []) { rootSqlQuery.select( - knex.raw( - `${dialect.asIdentifier(rootAlias)}.${getAdHocMetricAlias( + queryBuilder.client.raw( + `${queryBuilder.dialect.asIdentifier(rootAlias)}.${getAdHocMetricAlias( adHocMetric, )} as ${getAdHocMetricAlias(adHocMetric)}`, ), @@ -360,25 +393,30 @@ export function buildQuery( joinOnDimensions && joinOnDimensions.length > 0 ? joinOnDimensions .map((dimension) => { - return `${dialect.asIdentifier(rootAlias)}.${dimension.getAlias( - dialect, - )} = ${dialect.asIdentifier(alias)}.${dimension.getAlias( - dialect, - )}`; + return `${queryBuilder.dialect.asIdentifier( + rootAlias, + )}.${dimension.getAlias( + queryBuilder.dialect, + )} = ${queryBuilder.dialect.asIdentifier( + alias, + )}.${dimension.getAlias(queryBuilder.dialect)}`; }) .join(" and ") : "1 = 1"; - rootSqlQuery.innerJoin(segment.sqlQuery.as(alias), knex.raw(joinOn)); + rootSqlQuery.innerJoin( + segment.sqlQuery.as(alias), + queryBuilder.client.raw(joinOn), + ); for (const metricName of segment.projectedQuery.metrics || []) { if ((query.metrics ?? []).includes(metricName)) { - const metric = repository.getMetric(metricName); + const metric = queryBuilder.repository.getMetric(metricName); rootSqlQuery.select( - knex.raw( - `${dialect.asIdentifier(alias)}.${metric.getAlias( - dialect, - )} as ${metric.getAlias(dialect)}`, + queryBuilder.client.raw( + `${queryBuilder.dialect.asIdentifier(alias)}.${metric.getAlias( + queryBuilder.dialect, + )} as ${metric.getAlias(queryBuilder.dialect)}`, ), ); } @@ -395,24 +433,20 @@ export function buildQuery( }, {}, ); - const filter = repository - .getFilterBuilder( - repository, - dialect, - "metric", - referencedModels.metrics, - metricPrefixes, - ) + const filter = queryBuilder + .getFilterBuilder("metric", referencedModels.metrics, metricPrefixes) .buildFilters(query.filters, "and", context); if (filter) { - rootSqlQuery.where(knex.raw(filter.sql, filter.bindings)); + rootSqlQuery.where(queryBuilder.client.raw(filter.sql, filter.bindings)); } } const orderBy = Object.entries( - query.order || getDefaultOrderBy(repository, query), + query.order || getDefaultOrderBy(queryBuilder.repository, query), ).map(([member, direction]) => { - const memberSql = repository.getMember(member).getAlias(dialect); + const memberSql = queryBuilder.repository + .getMember(member) + .getAlias(queryBuilder.dialect); return `${memberSql} ${direction}`; }); diff --git a/src/lib/query-builder/filter-builder.ts b/src/lib/query-builder/filter-builder.ts index 0a7fd8e..05753c9 100644 --- a/src/lib/query-builder/filter-builder.ts +++ b/src/lib/query-builder/filter-builder.ts @@ -40,9 +40,12 @@ import { lt as filterLt, lte as filterLte, } from "./filter-builder/number-comparison-filter-builder.js"; +import { + inQuery as filterInQuery, + notInQuery as filterNotInQuery, +} from "./filter-builder/query-filter-builder.js"; -import { BaseDialect } from "../dialect/base.js"; -import type { AnyRepository } from "../repository.js"; +import { AnyQueryBuilder } from "../query-builder.js"; import { sqlAsSqlWithBindings } from "./util.js"; export class FilterBuilder { @@ -53,8 +56,7 @@ export class FilterBuilder { string, AnyFilterFragmentBuilder >, - private readonly dialect: BaseDialect, - private readonly repository: AnyRepository, + public readonly queryBuilder: AnyQueryBuilder, private readonly filterType: FilterType, referencedModels: string[], private readonly metricPrefixes?: Record, @@ -65,16 +67,18 @@ export class FilterBuilder { memberName: string, context: unknown, ): SqlWithBindings | undefined { - const member = this.repository.getMember(memberName); + const member = this.queryBuilder.repository.getMember(memberName); if (this.referencedModels.has(member.model.name)) { if (this.filterType === "dimension" && member.isDimension()) { - return member.getSql(this.dialect, context); + return member.getSql(this.queryBuilder.dialect, context); } if (this.filterType === "metric" && member.isMetric()) { const prefix = this.metricPrefixes?.[member.model.name]; - const sql = member.getAlias(this.dialect); + const sql = member.getAlias(this.queryBuilder.dialect); return sqlAsSqlWithBindings( - prefix ? `${this.dialect.asIdentifier(prefix)}.${sql}` : sql, + prefix + ? `${this.queryBuilder.dialect.asIdentifier(prefix)}.${sql}` + : sql, ); } } @@ -103,7 +107,7 @@ export class FilterBuilder { if (memberSql) { const builder = this.filterFragmentBuilders[filter.operator]; if (builder) { - return builder.build(this, memberSql, filter); + return builder.build(this, context, memberSql, filter); } throw new Error(`Unknown filter operator: ${filter.operator}`); } @@ -162,16 +166,14 @@ export class FilterFragmentBuilderRegistry { return Object.values(this.filterFragmentBuilders); } getFilterBuilder( - repository: AnyRepository, - dialect: BaseDialect, + queryBuilder: AnyQueryBuilder, filterType: FilterType, referencedModels: string[], metricPrefixes?: Record, ): FilterBuilder { return new FilterBuilder( this.filterFragmentBuilders, - dialect, - repository, + queryBuilder, filterType, referencedModels, metricPrefixes, @@ -207,5 +209,7 @@ export function defaultFilterFragmentBuilderRegistry() { .register(filterInDateRange) .register(filterNotInDateRange) .register(filterBeforeDate) - .register(filterAfterDate); + .register(filterAfterDate) + .register(filterInQuery) + .register(filterNotInQuery); } diff --git a/src/lib/query-builder/filter-builder/date-filter-builder.ts b/src/lib/query-builder/filter-builder/date-filter-builder.ts index 7d1993c..3f3a18b 100644 --- a/src/lib/query-builder/filter-builder/date-filter-builder.ts +++ b/src/lib/query-builder/filter-builder/date-filter-builder.ts @@ -21,15 +21,19 @@ function parseDate(value: Schema) { } function makeDateFilterBuilder(name: T) { - return filterFragmentBuilder(name, Schema, (_builder, member, filter) => { - const date = parseDate(filter.value); - const sql = `${member.sql} ${name === "beforeDate" ? "<" : ">"} ?`; - const bindings: unknown[] = [...member.bindings, date]; - return { - sql, - bindings, - }; - }); + return filterFragmentBuilder( + name, + Schema, + (_builder, _context, member, filter) => { + const date = parseDate(filter.value); + const sql = `${member.sql} ${name === "beforeDate" ? "<" : ">"} ?`; + const bindings: unknown[] = [...member.bindings, date]; + return { + sql, + bindings, + }; + }, + ); } export const beforeDate = makeDateFilterBuilder("beforeDate" as const); diff --git a/src/lib/query-builder/filter-builder/date-range-filter-builder.ts b/src/lib/query-builder/filter-builder/date-range-filter-builder.ts index 15943c6..a9b4b6b 100644 --- a/src/lib/query-builder/filter-builder/date-range-filter-builder.ts +++ b/src/lib/query-builder/filter-builder/date-range-filter-builder.ts @@ -41,15 +41,19 @@ function parseDateRange(value: Schema) { } function makeDateRangeFilterBuilder(name: T, isNot: boolean) { - return filterFragmentBuilder(name, Schema, (_builder, member, filter) => { - const [firstDate, lastDate] = parseDateRange(filter.value); - const sql = `${member.sql} ${isNot ? "not between" : "between"} ? and ?`; - const bindings: unknown[] = [...member.bindings, firstDate, lastDate]; - return { - sql, - bindings, - }; - }); + return filterFragmentBuilder( + name, + Schema, + (_builder, _context, member, filter) => { + const [firstDate, lastDate] = parseDateRange(filter.value); + const sql = `${member.sql} ${isNot ? "not between" : "between"} ? and ?`; + const bindings: unknown[] = [...member.bindings, firstDate, lastDate]; + return { + sql, + bindings, + }; + }, + ); } export const inDateRange = makeDateRangeFilterBuilder( diff --git a/src/lib/query-builder/filter-builder/equals.ts b/src/lib/query-builder/filter-builder/equals.ts index 1e72a84..a26cb2b 100644 --- a/src/lib/query-builder/filter-builder/equals.ts +++ b/src/lib/query-builder/filter-builder/equals.ts @@ -7,7 +7,7 @@ function makeEqualsFilterFragmentBuilder(name: T) { z.array( z.union([z.string(), z.number(), z.bigint(), z.boolean(), z.date()]), ), - (_builder, member, filter) => { + (_builder, _context, member, filter) => { if (filter.value.length === 1) { return { sql: `${member.sql} = ?`, diff --git a/src/lib/query-builder/filter-builder/filter-fragment-builder.ts b/src/lib/query-builder/filter-builder/filter-fragment-builder.ts index 507f577..48c7551 100644 --- a/src/lib/query-builder/filter-builder/filter-fragment-builder.ts +++ b/src/lib/query-builder/filter-builder/filter-fragment-builder.ts @@ -1,25 +1,39 @@ import { ZodSchema, z } from "zod"; +import { AnyQueryBuilder } from "../../query-builder.js"; import { SqlWithBindings } from "../../types.js"; import type { FilterBuilder } from "../filter-builder.js"; export class FilterFragmentBuilder< N extends string, - Z extends ZodSchema | null, + Z extends ZodSchema | ((queryBuilder: AnyQueryBuilder) => ZodSchema) | null, T extends FilterFragmentBuilderPayload, > { - public readonly fragmentBuilderSchema: ZodSchema; + public readonly fragmentBuilderSchema: + | ZodSchema + | ((queryBuilder: AnyQueryBuilder) => ZodSchema); constructor( public readonly operator: string, valueSchema: Z, private readonly builder: FilterFragmentBuilderFn, ) { if (valueSchema) { - this.fragmentBuilderSchema = z.object({ - operator: z.literal(operator), - member: z.string(), - value: valueSchema, - }); + if (typeof valueSchema === "function") { + this.fragmentBuilderSchema = (queryBuilder: AnyQueryBuilder) => { + const resolvedValueSchema = valueSchema(queryBuilder); + return z.object({ + operator: z.literal(operator), + member: z.string(), + value: resolvedValueSchema, + }); + }; + } else { + this.fragmentBuilderSchema = z.object({ + operator: z.literal(operator), + member: z.string(), + value: valueSchema, + }); + } } else { this.fragmentBuilderSchema = z.object({ operator: z.literal(operator), @@ -27,21 +41,32 @@ export class FilterFragmentBuilder< }); } } - build(filterBuilder: FilterBuilder, member: SqlWithBindings, payload: T) { - const filter = this.fragmentBuilderSchema.parse(payload); - return this.builder(filterBuilder, member, filter); + getFilterFragmentBuilderSchema(queryBuilder: AnyQueryBuilder) { + return typeof this.fragmentBuilderSchema === "function" + ? this.fragmentBuilderSchema(queryBuilder) + : this.fragmentBuilderSchema; + } + build( + filterBuilder: FilterBuilder, + context: unknown, + member: SqlWithBindings, + payload: unknown, + ) { + // We can directly pass payload as T because the schema is already validated in the QueryBuilder + return this.builder(filterBuilder, context, member, payload as T); } } export type AnyFilterFragmentBuilder = FilterFragmentBuilder< string, - ZodSchema | null, + ZodSchema | ((queryBuilder: AnyQueryBuilder) => ZodSchema) | null, // biome-ignore lint/suspicious/noExplicitAny: any >; export type FilterFragmentBuilderFn = ( builder: FilterBuilder, + context: unknown, member: SqlWithBindings, filter: T, ) => SqlWithBindings; @@ -51,8 +76,12 @@ export type GetFilterFragmentBuilderPayload = export type FilterFragmentBuilderPayload< N extends string, - Z extends ZodSchema | null, - T = Z extends ZodSchema ? z.infer : null, + Z extends ZodSchema | ((queryBuilder: AnyQueryBuilder) => ZodSchema) | null, + T = Z extends ZodSchema + ? z.infer + : Z extends (queryBuilder: AnyQueryBuilder) => ZodSchema + ? z.infer> + : null, > = T extends null ? { operator: N; member: string } : { @@ -63,7 +92,7 @@ export type FilterFragmentBuilderPayload< export function filterFragmentBuilder< N extends string, - Z extends ZodSchema | null, + Z extends ZodSchema | ((queryBuilder: AnyQueryBuilder) => ZodSchema) | null, T extends FilterFragmentBuilderPayload, >(name: N, valueSchema: Z, builder: FilterFragmentBuilderFn) { return new FilterFragmentBuilder(name, valueSchema, builder); diff --git a/src/lib/query-builder/filter-builder/ilike-filter-builder.ts b/src/lib/query-builder/filter-builder/ilike-filter-builder.ts index 9467e77..8ba7a3a 100644 --- a/src/lib/query-builder/filter-builder/ilike-filter-builder.ts +++ b/src/lib/query-builder/filter-builder/ilike-filter-builder.ts @@ -29,7 +29,7 @@ function makeILikeFilterBuilder( return filterFragmentBuilder( name, z.array(z.string()), - (_filterBuilder, member, filter) => { + (_filterBuilder, _context, member, filter) => { const { sqls, bindings } = filter.value.reduce<{ sqls: string[]; bindings: unknown[]; diff --git a/src/lib/query-builder/filter-builder/not-equals.ts b/src/lib/query-builder/filter-builder/not-equals.ts index 1409b37..bbbd947 100644 --- a/src/lib/query-builder/filter-builder/not-equals.ts +++ b/src/lib/query-builder/filter-builder/not-equals.ts @@ -7,7 +7,7 @@ function makeNotEqualsFilterFragmentBuilder(name: T) { z.array( z.union([z.string(), z.number(), z.bigint(), z.boolean(), z.date()]), ), - (_builder, member, filter) => { + (_builder, _context, member, filter) => { if (filter.value.length === 1) { return { sql: `${member.sql} <> ?`, diff --git a/src/lib/query-builder/filter-builder/null-check-filter-builder.ts b/src/lib/query-builder/filter-builder/null-check-filter-builder.ts index a501f19..fcb8f3e 100644 --- a/src/lib/query-builder/filter-builder/null-check-filter-builder.ts +++ b/src/lib/query-builder/filter-builder/null-check-filter-builder.ts @@ -4,7 +4,7 @@ function makeNullCheckFilterBuilder( name: T, isNull: boolean, ) { - return filterFragmentBuilder(name, null, (_builder, member) => { + return filterFragmentBuilder(name, null, (_builder, _context, member) => { const sql = `${member.sql} is ${isNull ? "" : "not"} null`; return { sql, diff --git a/src/lib/query-builder/filter-builder/number-comparison-filter-builder.ts b/src/lib/query-builder/filter-builder/number-comparison-filter-builder.ts index c641993..9981dd7 100644 --- a/src/lib/query-builder/filter-builder/number-comparison-filter-builder.ts +++ b/src/lib/query-builder/filter-builder/number-comparison-filter-builder.ts @@ -14,7 +14,7 @@ function makeNumberComparisonFilterBuilder< return filterFragmentBuilder( operator, z.array(z.number()), - (_builder, member, filter) => { + (_builder, _context, member, filter) => { const { sqls, bindings } = filter.value.reduce<{ sqls: string[]; bindings: unknown[]; diff --git a/src/lib/query-builder/filter-builder/query-filter-builder.ts b/src/lib/query-builder/filter-builder/query-filter-builder.ts new file mode 100644 index 0000000..6daf7d0 --- /dev/null +++ b/src/lib/query-builder/filter-builder/query-filter-builder.ts @@ -0,0 +1,55 @@ +import { + FilterFragmentBuilder, + filterFragmentBuilder, +} from "./filter-fragment-builder.js"; + +import { z } from "zod"; +import { AnyQueryBuilder } from "../../query-builder.js"; + +export type InOrNotIn = "in" | "notIn"; + +const inOrNotInToSQL = { + in: "in", + notIn: "not in", +} as const; + +// Return type here is intentionally simplified, but we make it exact later in the QueryBuilder class +function makeQueryFilterFragmentBuilder( + name: T, + inOrNotIn: InOrNotIn, +): FilterFragmentBuilder< + T, + (queryBuilder: AnyQueryBuilder) => z.ZodType, + { + operator: T; + member: string; + value: object; + } +> { + return filterFragmentBuilder( + name, + (queryBuilder) => { + return z.lazy(() => queryBuilder.querySchema); + }, + ( + filterBuilder, + context, + member, + filter, + ): { sql: string; bindings: unknown[] } => { + const { sql, bindings } = + filterBuilder.queryBuilder.unsafeBuildGenericQueryWithoutSchemaParse( + filter.value, + context, + ); + + return { + sql: `${member.sql} ${inOrNotInToSQL[inOrNotIn]} (${sql})`, + bindings: [...member.bindings, ...bindings], + }; + }, + ); +} + +export const inQuery = makeQueryFilterFragmentBuilder("inQuery", "in"); +export const notInQuery = makeQueryFilterFragmentBuilder("notInQuery", "notIn"); diff --git a/src/lib/query-schema.ts b/src/lib/query-schema.ts new file mode 100644 index 0000000..628e014 --- /dev/null +++ b/src/lib/query-schema.ts @@ -0,0 +1,86 @@ +import { z } from "zod"; +import { AnyQueryBuilder } from "./query-builder.js"; +import { AnyQueryFilter } from "./types.js"; + +function getDimensionNamesSchema(dimensionPaths: string[]) { + return z + .array( + z + .string() + .refine((arg) => dimensionPaths.includes(arg)) + .describe("Dimension name"), + ) + .optional(); +} + +function getMetricNamesSchema(metricPaths: string[], dimensionPaths: string[]) { + const adHocMetricSchema = z.object({ + aggregateWith: z.enum(["sum", "count", "min", "max", "avg"]), + dimension: z + .string() + .refine((arg) => dimensionPaths.includes(arg)) + .describe("Dimension name"), + }); + + return z + .array( + z + .string() + .refine((arg) => metricPaths.includes(arg)) + .describe("Metric name") + .or(adHocMetricSchema), + ) + .optional(); +} + +export function buildQuerySchema(queryBuilder: AnyQueryBuilder) { + const dimensionPaths = queryBuilder.repository + .getDimensions() + .map((d) => d.getPath()); + const metricPaths = queryBuilder.repository + .getMetrics() + .map((m) => m.getPath()); + const memberPaths = [...dimensionPaths, ...metricPaths]; + + const registeredFilterFragmentBuildersSchemas = queryBuilder.repository + .getFilterFragmentBuilderRegistry() + .getFilterFragmentBuilders() + .map((builder) => builder.getFilterFragmentBuilderSchema(queryBuilder)); + + const filters: z.ZodType = z.array( + z.union([ + z.object({ + operator: z.literal("and"), + filters: z.lazy(() => filters), + }), + z.object({ + operator: z.literal("or"), + filters: z.lazy(() => filters), + }), + ...registeredFilterFragmentBuildersSchemas.map((schema) => + schema.refine((arg) => memberPaths.includes(arg.member), { + path: ["member"], + message: "Member not found", + }), + ), + ]), + ); + + const schema = z + .object({ + dimensions: getDimensionNamesSchema(dimensionPaths), + metrics: getMetricNamesSchema(metricPaths, dimensionPaths), + filters: filters.optional(), + limit: z.number().optional(), + offset: z.number().optional(), + order: z.record(z.string(), z.enum(["asc", "desc"])).optional(), + }) + .refine( + (arg) => (arg.dimensions?.length ?? 0) + (arg.metrics?.length ?? 0) > 0, + "At least one dimension or metric must be selected", + ); + + return schema; +} + +export type QuerySchema = ReturnType; diff --git a/src/lib/repository.ts b/src/lib/repository.ts index 8362ca3..9c06262 100644 --- a/src/lib/repository.ts +++ b/src/lib/repository.ts @@ -12,11 +12,10 @@ import { AnyModel, Model } from "./model.js"; import type { Dimension, Metric } from "./model.js"; import { AnyFilterFragmentBuilderRegistry, - FilterBuilder, GetFilterFragmentBuilderRegistryPayload, defaultFilterFragmentBuilderRegistry, } from "./query-builder/filter-builder.js"; -import { AvailableDialects, FilterType, MemberNameToType } from "./types.js"; +import { AvailableDialects, MemberNameToType } from "./types.js"; import graphlib from "@dagrejs/graphlib"; import knex from "knex"; @@ -123,22 +122,6 @@ export class Repository< return this.filterFragmentBuilderRegistry; } - getFilterBuilder( - repository: AnyRepository, - dialect: BaseDialect, - filterType: FilterType, - referencedModels: string[], - metricPrefixes?: Record, - ): FilterBuilder { - return this.filterFragmentBuilderRegistry.getFilterBuilder( - repository, - dialect, - filterType, - referencedModels, - metricPrefixes, - ); - } - join( type: AnyJoin["type"], modelName1: N1, diff --git a/src/lib/semantic-layer.ts b/src/lib/semantic-layer.ts index fc39ce7..acf10e9 100644 --- a/src/lib/semantic-layer.ts +++ b/src/lib/semantic-layer.ts @@ -1,6 +1,7 @@ export * from "./repository.js"; export * from "./model.js"; export * from "./join.js"; +export * from "./query-schema.js"; export * from "./query-builder.js"; export * from "./query-builder/filter-builder.js"; export { BaseDialect } from "./dialect/base.js"; diff --git a/src/lib/types.ts b/src/lib/types.ts index 4cddb3a..f93ee98 100644 --- a/src/lib/types.ts +++ b/src/lib/types.ts @@ -28,16 +28,41 @@ export type QueryMetric< DN extends string = string, > = MN | QueryAdHocMetric; +export type WithInQueryFilter = [ + Extract, +] extends [never] + ? F + : + | Exclude + | { operator: "inQuery"; member: QueryDN | QueryMN; value: Q }; + +export type WithNotInQueryFilter< + F extends AnyQueryFilter, + Q extends AnyQuery, +> = [Extract] extends [never] + ? F + : + | Exclude + | { operator: "notInQuery"; member: QueryDN | QueryMN; value: Q }; + export type Query = { dimensions?: DN[]; metrics?: QueryMetric[]; order?: { [K in DN | MN]?: "asc" | "desc" }; - filters?: QueryFilter[]; + filters?: WithNotInQueryFilter< + WithInQueryFilter, Query>, + Query + >[]; limit?: number; offset?: number; }; -// biome-ignore lint/suspicious/noExplicitAny: +// biome-ignore lint/suspicious/noExplicitAny: Any used for inference +export type QueryDN = Q extends Query ? DN : never; +// biome-ignore lint/suspicious/noExplicitAny: Any used for inference +export type QueryMN = Q extends Query ? MN : never; + +// biome-ignore lint/suspicious/noExplicitAny: Any used for inference export type AnyQuery = Query; export interface ModelQuery {