Skip to content

Commit

Permalink
Use regex to parse tables from SQL (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
kektobiologist authored Oct 26, 2024
1 parent a8dbbd7 commit b6447a0
Show file tree
Hide file tree
Showing 7 changed files with 1,032 additions and 19 deletions.
7 changes: 7 additions & 0 deletions apps/jest.config.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
/**
 * Jest configuration: run tests in a Node environment and compile
 * TypeScript sources (.ts / .tsx) through ts-jest.
 *
 * @type {import('ts-jest').JestConfigWithTsJest}
 **/
const config = {
  testEnvironment: "node",
  transform: {
    // The dot before the extension is escaped so that only real `.ts` / `.tsx`
    // files are handed to ts-jest; an unescaped `.` would also match any path
    // that merely ends in "ts" (for example an extensionless `constants`).
    "^.+\\.tsx?$": ["ts-jest", {}],
  },
};

export default config;
4 changes: 3 additions & 1 deletion apps/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
},
"packageManager": "[email protected]",
"devDependencies": {
"@types/uuid": "^10"
"@types/uuid": "^10",
"jest": "^29.7.0",
"ts-jest": "^29.2.5"
}
}
2 changes: 0 additions & 2 deletions apps/src/metabase/helpers/DOMToState.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import { isDashboardPage } from './dashboard/util';
import { DashboardInfo } from './dashboard/types';
import { getDashboardAppState } from './dashboard/appState';
import { visualizationSettings, Card, ParameterValues, FormattedTable } from './types';
import { getTablesFromSql } from './parseSql';
const { getMetabaseState, queryURL } = RPCs;

interface ExtractedDataBase {
Expand Down Expand Up @@ -71,7 +70,6 @@ export async function convertDOMtoStateSQLQuery() {
const vizType = await getMetabaseState('qb.card.display') as string
const visualizationSettings = await getMetabaseState('qb.card.visualization_settings') as visualizationSettings
const sqlVariables = await getSqlVariables();
const tablesFromSql = getTablesFromSql(sqlQuery);
const metabaseAppStateSQLEditor: MetabaseAppStateSQLEditor = {
availableDatabases,
selectedDatabaseInfo,
Expand Down
24 changes: 17 additions & 7 deletions apps/src/metabase/helpers/getDatabaseSchema.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { memoize, RPCs } from 'web'
import { FormattedTable } from './types';
import { getTablesFromSql } from './parseSql';
import { getTablesFromSqlRegex } from './parseSql';
import _ from 'lodash';

const { getMetabaseState, fetchData } = RPCs;
Expand Down Expand Up @@ -155,21 +155,31 @@ export const getRelevantTablesForSelectedDb = async (sql: string): Promise<Forma
if (!dbId) {
return [];
}
const tablesFromSql = getTablesFromSql(sql);
const {tables: top200} = await memoizedGetTop200TablesWithoutFields(dbId);
const tablesFromSql = getTablesFromSqlRegex(sql);
let {tables: top200} = await memoizedGetTop200TablesWithoutFields(dbId);
const {tables: allTables} = await memoizedGetDatabaseTablesWithoutFields(dbId);
for (const table of tablesFromSql) {
// check if its already there in top200. if so don't do anything
const relevantTable = top200.find(tableInfo => tableInfo.name === table.table && tableInfo.schema === table.schema);
for (const tableInfo of tablesFromSql) {
// if schema is empty, assume its public
let {table, schema} = tableInfo;
if (schema === '' || schema === undefined) {
schema = 'public';
}
// lowercase everything
table = table.toLowerCase();
schema = schema.toLowerCase();
// check if its already there in top200. if so don't do anything. again lowercase everything
const relevantTable = top200.find(tableInfo => tableInfo.name.toLowerCase() === table && tableInfo.schema.toLowerCase() === schema);
if (!relevantTable) {
// check if there in allTables. if so, add it to top200
const relevantTable = allTables.find(tableInfo => tableInfo.name === table.table && tableInfo.schema === table.schema);
const relevantTable = allTables.find(tableInfo => tableInfo.name.toLowerCase() === table && tableInfo.schema.toLowerCase() === schema);
if (relevantTable) {
// insert at beginning
top200.unshift(relevantTable);
}
}
}
// dedupe (by schema.name)
top200 = _.uniqBy(top200, (tableInfo) => `${tableInfo.schema}.${tableInfo.name}`);
// trim to 200
return top200.slice(0, 200);
}
289 changes: 289 additions & 0 deletions apps/src/metabase/helpers/parseSql.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,289 @@
import { getTablesFromSqlRegex, type TableAndSchema } from "./parseSql";

/**
 * Table-driven tests for `getTablesFromSqlRegex`.
 *
 * Every case pairs a SQL string with the exact, ordered list of
 * `{ schema, table }` entries the test expects back. As encoded in the
 * expectations below: an unqualified name is expected with an empty schema,
 * CTE names are expected in the output when they are selected from, quoted
 * identifiers are expected without their quotes, and a table referenced twice
 * is expected twice.
 */
describe('getTablesFromSqlRegex', () => {
  const cases: { sql: string; results: TableAndSchema[] }[] = [
    // Plain SELECT from a schema-qualified table.
    {
      sql: `
SELECT *
FROM user_activity.events_log
WHERE event_status IS 'active';
`,
      results: [{ schema: 'user_activity', table: 'events_log' }],
    },
    // A second plain SELECT, this time without a WHERE clause.
    {
      sql: `
SELECT user_id
FROM platform_data.user_profiles;
`,
      results: [{ schema: 'platform_data', table: 'user_profiles' }],
    },
    // WITH clause: both the underlying table and the CTE name are expected.
    {
      sql: `
WITH recent_orders AS (
SELECT *
FROM sales_data.order_details
)
SELECT alias.order_id
FROM recent_orders alias;
`,
      results: [
        { schema: 'sales_data', table: 'order_details' },
        { schema: '', table: 'recent_orders' },
      ],
    },
    // Double-quoted identifiers, including a table name that contains a space.
    {
      sql: `
SELECT
"source"."event_id" AS "EventID",
"source"."user_id" AS "UserID",
"source"."event_type" AS "EventType",
"source"."event_timestamp" AS "EventTimestamp",
"source"."page_viewed" AS "PageViewed",
"source"."button_label" AS "ButtonLabel"
FROM
(
SELECT
"analytics"."event_tracking"."event_id" AS "event_id",
"analytics"."event_tracking"."user_id" AS "user_id",
"analytics"."event_tracking"."event_type" AS "event_type",
"analytics"."event_tracking"."event_timestamp" AS "event_timestamp",
"analytics"."event_tracking"."page_viewed" AS "page_viewed",
"analytics"."event_tracking"."button_label" AS "button_label"
FROM
"analytics"."event tracking"
) AS "source"
LIMIT 1048575;
`,
      results: [{ schema: 'analytics', table: 'event tracking' }],
    },
    // Many tables spread over nested subqueries and several kinds of JOIN.
    {
      sql: `
SELECT
monthly_period,
AVG(session_duration) AS avg_duration,
customer_segment,
activity_type
FROM (
SELECT
customer.customer_id,
session_data.session_start,
session_data.session_end,
TO_CHAR(session_data.session_start, 'YYYY-MM') AS start_period,
TO_CHAR(session_data.session_end, 'YYYY-MM') AS end_period,
session_data.duration_in_minutes,
NULL AS total_spend,
(SELECT segment_definitions.segment_name
FROM purchase_transactions transactions
INNER JOIN transaction_metadata metadata ON metadata.transaction_id = transactions.id
INNER JOIN segment_definitions ON segment_definitions.id = metadata.segment_id
WHERE transactions.type = 'purchase'
AND transactions.status != 'refunded'
AND transactions.customer_id = customer.customer_id
LIMIT 1) AS customer_segment,
CASE
WHEN session_histories.session_data_id IS NULL THEN
CASE
WHEN (EXTRACT(EPOCH FROM (session_data.session_end - session_data.session_start))/86400) > 2 THEN 2
ELSE (EXTRACT(EPOCH FROM (session_data.session_end - session_data.session_start))/86400)
END
ELSE durations.total_duration
END AS session_duration,
CASE
WHEN session_histories.session_data_id IS NULL THEN 'New_Session'
ELSE 'Returning_Session'
END AS activity_type
FROM customer_sessions customer
INNER JOIN session_details session_data ON session_data.customer_id = customer.customer_id
LEFT JOIN (
SELECT COUNT(1), session_data_id
FROM session_histories
WHERE reason_code IN ('error-101', 'error-102', 'timeout', 'disconnect')
GROUP BY session_data_id
) session_histories ON session_histories.session_data_id = session_data.id
LEFT JOIN session_durations durations ON durations.session_data_id = session_data.id
) as sessions_derived;
`,
      results: [
        { schema: '', table: 'purchase_transactions' },
        { schema: '', table: 'transaction_metadata' },
        { schema: '', table: 'segment_definitions' },
        { schema: '', table: 'customer_sessions' },
        { schema: '', table: 'session_details' },
        { schema: '', table: 'session_histories' },
        { schema: '', table: 'session_durations' },
      ],
    },
    // {{...}} filter placeholders in the WHERE clause.
    {
      sql: `
SELECT
created_at,
store_location AS location_name,
CASE
WHEN store_category = 0 THEN 'Retail'
WHEN store_category = 1 THEN 'Warehouse'
WHEN store_category = 2 THEN 'Distribution'
ELSE 'Other'
END AS store_type,
CASE
WHEN store_size = 0 THEN 'Small'
WHEN store_size = 1 THEN 'Medium'
WHEN store_size = 2 THEN 'Large'
END AS store_size,
CASE
WHEN region_code = 10 THEN 'North'
ELSE 'South'
END AS region,
address_line AS address,
unit_number AS unit,
floor_section AS section,
manager_first_name AS first_name,
manager_last_name AS last_name,
contact_number AS contact_number
FROM store_locations
WHERE {{store_type_filter}}
AND {{store_size_filter}}
AND {{created_at_filter}}
AND {{region_filter}}
ORDER BY created_at DESC;
`,
      results: [{ schema: '', table: 'store_locations' }],
    },
    // [[...]] optional clauses wrapping {{...}} placeholders.
    {
      sql: `
SELECT count(*)
FROM products
WHERE 1=1
[[AND id = {{id}}]]
[[AND category = {{category}}]]
`,
      results: [{ schema: '', table: 'products' }],
    },
    // Non-ASCII characters in the schema name.
    {
      sql: `
SELECT nombre, apellido
FROM públicó.usuarios
WHERE estado = 'activo';
`,
      results: [{ schema: 'públicó', table: 'usuarios' }],
    },
    // JOIN inside a CTE, followed by chained CTEs.
    {
      sql: `
WITH daily_messages AS (
SELECT
p.login_email_id as email,
ur.profile_id,
DATE(ur.created_at) as "day",
COUNT(*) as "daily_messages"
FROM
public.user_records ur
JOIN
public.profiles p ON ur.profile_id = p.id
WHERE
ur.created_at >= NOW() - INTERVAL '10 days'
AND ur.type = 'user_message'
AND ur.model != 'gpt-4o-mini'
GROUP BY
p.login_email_id,
ur.profile_id,
DATE(ur.created_at)
), ranked_messages AS (
SELECT
email,
profile_id,
"day",
daily_messages,
RANK() OVER (PARTITION BY "day" ORDER BY daily_messages DESC) as rank
FROM daily_messages
)
SELECT
CASE WHEN rank <= 20 THEN email ELSE 'others' END as email,
"day",
SUM(daily_messages) as daily_messages
FROM ranked_messages
GROUP BY "day", CASE WHEN rank <= 20 THEN email ELSE 'others' END
ORDER BY "day" DESC, daily_messages DESC;
`,
      results: [
        { schema: 'public', table: 'user_records' },
        { schema: 'public', table: 'profiles' },
        { schema: '', table: 'daily_messages' },
        { schema: '', table: 'ranked_messages' },
      ],
    },
    // The same table referenced twice: both occurrences are expected here
    // (the original author notes that deduplication happens later anyway).
    {
      sql: `
WITH dummy1 AS (
SELECT * from some_schema.some_table
WHERE some_column = 'some_value'
),
dummy2 AS (
SELECT * from some_schema.some_table
WHERE some_column = 'some_value'
)
SELECT * FROM dummy1 JOIN dummy2 ON dummy1.some_column = dummy2.some_column
ORDER BY some_column;
`,
      results: [
        { schema: 'some_schema', table: 'some_table' },
        { schema: 'some_schema', table: 'some_table' },
        { schema: '', table: 'dummy1' },
        { schema: '', table: 'dummy2' },
      ],
    },
    // Dashes inside quoted schema and table names.
    {
      sql: `SELECT * from "some-schema"."some-table";`,
      results: [{ schema: 'some-schema', table: 'some-table' }],
    },
    // Only the schema is quoted.
    {
      sql: `SELECT * from "some-schema".sometable;`,
      results: [{ schema: 'some-schema', table: 'sometable' }],
    },
    // Only the table is quoted.
    {
      sql: `SELECT * from someschema."some-table";`,
      results: [{ schema: 'someschema', table: 'some-table' }],
    },
  ];

  // Register one test per case; the full SQL text is part of the test title.
  cases.forEach(({ sql, results: expected }) => {
    it(`should get the correct tables from sql: ${sql}`, () => {
      expect(getTablesFromSqlRegex(sql)).toEqual(expected);
    });
  });
});
Loading

0 comments on commit b6447a0

Please sign in to comment.