diff --git a/litequery/migrations.py b/litequery/migrations.py index 0331895..26115af 100644 --- a/litequery/migrations.py +++ b/litequery/migrations.py @@ -2,6 +2,7 @@ import os import re import sqlite3 +import textwrap def migrate(db_path, migrations_dir): @@ -12,16 +13,14 @@ def migrate(db_path, migrations_dir): filenames = sort_migration_filenames( os.path.basename(p) for p in glob.glob(f"{migrations_dir}/*.sql") ) - c.execute( - """ - --sql + query = """ create table if not exists migrations ( id integer primary key autoincrement, filename text not null, - run_at datetime not null default current_timestamp + run_at text not null default current_timestamp ); - """ - ) + """ + c.execute(textwrap.dedent(query).strip()) migrations = c.execute("select * from migrations order by run_at asc").fetchall() migrations = {m["filename"] for m in migrations} unapplied = sort_migration_filenames(set(filenames) - migrations) @@ -42,6 +41,13 @@ def migrate(db_path, migrations_dir): ) conn.commit() + with open(os.path.join(os.path.dirname(db_path), "schema.sql"), "w") as f: + statements = c.execute( + "select sql from sqlite_master where sql is not null" + ).fetchall() + for (statement,) in statements: + f.write(f"{statement};\n") + conn.close() diff --git a/tests/test_migrations.py b/tests/test_migrations.py index f2e3bb6..ba4cf7c 100644 --- a/tests/test_migrations.py +++ b/tests/test_migrations.py @@ -1,6 +1,7 @@ import os import sqlite3 import tempfile +import textwrap from pathlib import Path import pytest @@ -56,3 +57,25 @@ def test_migrate(temp_db, temp_migrations_dir): assert table_info[1][1] == "name" conn.close() + + +def test_creates_schema(temp_db, temp_migrations_dir): + create_migration_file( + temp_migrations_dir, + "001_initial.sql", + "create table users (id integer primary key autoincrement);", + ) + + migrate(temp_db, temp_migrations_dir) + + with open(os.path.join(os.path.dirname(temp_db), "schema.sql")) as f: + schema_content = """ + CREATE TABLE migrations ( + id integer primary key autoincrement, + filename text not null, + run_at text not null default current_timestamp + ); + CREATE TABLE sqlite_sequence(name,seq); + CREATE TABLE users (id integer primary key autoincrement); + """ + assert f.read().strip() == textwrap.dedent(schema_content).strip()