diff --git a/docs/api.md b/docs/api.md index f63d4f0..08d0b15 100644 --- a/docs/api.md +++ b/docs/api.md @@ -32,6 +32,16 @@ Rolls back the current transaction and starts a new one. Closes the database connection. +### `with` statement + +Connection objects can be used as context managers to ensure that transactions are properly committed or rolled back. When entering the context, the connection object is returned. When exiting: +- Without exception: automatically commits the transaction +- With exception: automatically rolls back the transaction + +This behavior is compatible with Python's `sqlite3` module. Context managers work correctly in both transactional and autocommit modes. + +When mixing manual transaction control with context managers, the context manager's commit/rollback will apply to any active transaction at the time of exit. Manual calls to `commit()` or `rollback()` within the context are allowed and will start a new transaction as usual. + ### execute(sql, parameters=()) Create a new cursor object and executes the SQL statement. diff --git a/src/lib.rs b/src/lib.rs index ed7d444..08db248 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,7 @@ use pyo3::create_exception; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::{PyList, PyTuple}; -use std::cell::{OnceCell, RefCell}; +use std::cell::RefCell; use std::sync::{Arc, OnceLock}; use std::time::Duration; use tokio::runtime::{Handle, Runtime}; @@ -38,14 +38,14 @@ fn is_remote_path(path: &str) -> bool { #[pyfunction] #[cfg(not(Py_3_12))] -#[pyo3(signature = (database, timeout=5.0, isolation_level="DEFERRED".to_string(), check_same_thread=true, uri=false, sync_url=None, sync_interval=None, auth_token="", encryption_key=None))] +#[pyo3(signature = (database, timeout=5.0, isolation_level="DEFERRED".to_string(), _check_same_thread=true, _uri=false, sync_url=None, sync_interval=None, auth_token="", encryption_key=None))] fn connect( py: Python<'_>, database: String, timeout: f64, isolation_level: Option, - check_same_thread: bool, - uri: bool, + _check_same_thread: bool, + _uri: bool, sync_url: Option, sync_interval: Option, auth_token: &str, @@ -56,8 +56,8 @@ fn connect( database, timeout, isolation_level, - check_same_thread, - uri, + _check_same_thread, + _uri, sync_url, sync_interval, auth_token, @@ -68,14 +68,14 @@ fn connect( #[pyfunction] #[cfg(Py_3_12)] -#[pyo3(signature = (database, timeout=5.0, isolation_level="DEFERRED".to_string(), check_same_thread=true, uri=false, sync_url=None, sync_interval=None, auth_token="", encryption_key=None, autocommit = LEGACY_TRANSACTION_CONTROL))] +#[pyo3(signature = (database, timeout=5.0, isolation_level="DEFERRED".to_string(), _check_same_thread=true, _uri=false, sync_url=None, sync_interval=None, auth_token="", encryption_key=None, autocommit = LEGACY_TRANSACTION_CONTROL))] fn connect( py: Python<'_>, database: String, timeout: f64, isolation_level: Option, - check_same_thread: bool, - uri: bool, + _check_same_thread: bool, + _uri: bool, sync_url: Option, sync_interval: Option, auth_token: &str, @@ -87,8 +87,8 @@ fn connect( database, timeout, isolation_level.clone(), - check_same_thread, - uri, + _check_same_thread, + _uri, sync_url, sync_interval, auth_token, @@ -111,8 +111,8 @@ fn _connect_core( database: String, timeout: f64, isolation_level: Option, - check_same_thread: bool, - uri: bool, + _check_same_thread: bool, + _uri: bool, sync_url: Option, sync_interval: Option, auth_token: &str, @@ -220,7 +220,7 @@ unsafe impl Send for Connection {} #[pymethods] impl Connection { - fn close(self_: PyRef<'_, Self>, py: Python<'_>) -> PyResult<()> { + fn close(self_: PyRef<'_, Self>, _py: Python<'_>) -> PyResult<()> { self_.conn.replace(None); Ok(()) } @@ -330,11 +330,14 @@ impl Connection { fn in_transaction(self_: PyRef<'_, Self>) -> PyResult { #[cfg(Py_3_12)] { - return Ok( + Ok( !self_.conn.borrow().as_ref().unwrap().is_autocommit() || self_.autocommit == 0 - ); + ) + } + #[cfg(not(Py_3_12))] + { + Ok(!self_.conn.borrow().as_ref().unwrap().is_autocommit()) } - Ok(!self_.conn.borrow().as_ref().unwrap().is_autocommit()) } #[getter] @@ -354,6 +357,26 @@ impl Connection { self_.autocommit = autocommit; Ok(()) } + + fn __enter__(slf: PyRef<'_, Self>) -> PyResult> { + Ok(slf) + } + + fn __exit__( + self_: PyRef<'_, Self>, + exc_type: Option<&PyAny>, + _exc_val: Option<&PyAny>, + _exc_tb: Option<&PyAny>, + ) -> PyResult { + if exc_type.is_none() { + // Commit on clean exit + Connection::commit(self_)?; + } else { + // Rollback on error + Connection::rollback(self_)?; + } + Ok(false) // Always propagate exceptions + } } #[pyclass] diff --git a/tests/test_suite.py b/tests/test_suite.py index 428f314..041226f 100644 --- a/tests/test_suite.py +++ b/tests/test_suite.py @@ -6,16 +6,19 @@ import pytest import tempfile + @pytest.mark.parametrize("provider", ["libsql", "sqlite"]) def test_connection_timeout(provider): conn = connect(provider, ":memory:", timeout=1.0) conn.close() + @pytest.mark.parametrize("provider", ["libsql", "sqlite"]) def test_connection_close(provider): conn = connect(provider, ":memory:") conn.close() + @pytest.mark.parametrize("provider", ["libsql", "sqlite"]) def test_execute(provider): conn = connect(provider, ":memory:") @@ -34,6 +37,7 @@ def test_cursor_execute(provider): res = cur.execute("SELECT * FROM users") assert (1, "alice@example.com") == res.fetchone() + @pytest.mark.parametrize("provider", ["libsql", "sqlite"]) def test_cursor_close(provider): conn = connect(provider, ":memory:") @@ -47,6 +51,7 @@ def test_cursor_close(provider): with pytest.raises(Exception): cur.execute("SELECT * FROM users") + @pytest.mark.parametrize("provider", ["libsql", "sqlite"]) def test_executemany(provider): conn = connect(provider, ":memory:") @@ -198,7 +203,9 @@ def test_connection_autocommit(provider): res = cur.execute("SELECT * FROM users") assert (1, "alice@example.com") == res.fetchone() - conn = connect(provider, ":memory:", timeout=5, isolation_level="DEFERRED", autocommit=-1) + conn = connect( + provider, ":memory:", timeout=5, isolation_level="DEFERRED", autocommit=-1 + ) assert conn.isolation_level == "DEFERRED" assert conn.autocommit == -1 cur = conn.cursor() @@ -210,7 +217,9 @@ def test_connection_autocommit(provider): assert (1, "alice@example.com") == res.fetchone() # Test autocommit Enabled (True) - conn = connect(provider, ":memory:", timeout=5, isolation_level=None, autocommit=True) + conn = connect( + provider, ":memory:", timeout=5, isolation_level=None, autocommit=True + ) assert conn.isolation_level == None assert conn.autocommit == True cur = conn.cursor() @@ -221,7 +230,9 @@ def test_connection_autocommit(provider): res = cur.execute("SELECT * FROM users") assert (1, "bob@example.com") == res.fetchone() - conn = connect(provider, ":memory:", timeout=5, isolation_level="DEFERRED", autocommit=True) + conn = connect( + provider, ":memory:", timeout=5, isolation_level="DEFERRED", autocommit=True + ) assert conn.isolation_level == "DEFERRED" assert conn.autocommit == True cur = conn.cursor() @@ -233,7 +244,9 @@ def test_connection_autocommit(provider): assert (1, "bob@example.com") == res.fetchone() # Test autocommit Disabled (False) - conn = connect(provider, ":memory:", timeout=5, isolation_level="DEFERRED", autocommit=False) + conn = connect( + provider, ":memory:", timeout=5, isolation_level="DEFERRED", autocommit=False + ) assert conn.isolation_level == "DEFERRED" assert conn.autocommit == False cur = conn.cursor() @@ -260,6 +273,7 @@ def test_params(provider): res = cur.execute("SELECT * FROM users") assert (1, "alice@example.com") == res.fetchone() + @pytest.mark.parametrize("provider", ["libsql", "sqlite"]) def test_none_param(provider): conn = connect(provider, ":memory:") @@ -272,6 +286,7 @@ def test_none_param(provider): assert results[0] == (1, None) assert results[1] == (2, "alice@example.com") + @pytest.mark.parametrize("provider", ["libsql", "sqlite"]) def test_fetchmany(provider): conn = connect(provider, ":memory:") @@ -321,6 +336,194 @@ def test_int64(provider): assert [(1, 1099511627776)] == res.fetchall() +@pytest.mark.parametrize("provider", ["libsql", "sqlite"]) +def test_context_manager_commit(provider): + """Test that context manager commits on clean exit""" + conn = connect(provider, ":memory:") + with conn as c: + c.execute("CREATE TABLE t(x)") + c.execute("INSERT INTO t VALUES (1)") + # Changes should be committed + cur = conn.cursor() + cur.execute("SELECT COUNT(*) FROM t") + assert cur.fetchone()[0] == 1 + + +@pytest.mark.parametrize("provider", ["libsql", "sqlite"]) +def test_context_manager_rollback(provider): + """Test that context manager rolls back on exception""" + conn = connect(provider, ":memory:") + try: + with conn as c: + c.execute("CREATE TABLE t(x)") + c.execute("INSERT INTO t VALUES (1)") + raise ValueError("Test exception") + except ValueError: + pass + # Changes should be rolled back + cur = conn.cursor() + try: + cur.execute("SELECT COUNT(*) FROM t") + # If we get here, the table exists (rollback didn't work) + assert False, "Table should not exist after rollback" + except Exception: + # Table doesn't exist, which is what we expect after rollback + pass + + +@pytest.mark.parametrize("provider", ["libsql", "sqlite"]) +def test_context_manager_autocommit(provider): + """Test that context manager works correctly with autocommit mode""" + conn = connect(provider, ":memory:", isolation_level=None) # autocommit mode + with conn as c: + c.execute("CREATE TABLE t(x)") + c.execute("INSERT INTO t VALUES (1)") + # In autocommit mode, changes are committed immediately + cur = conn.cursor() + cur.execute("SELECT COUNT(*) FROM t") + assert cur.fetchone()[0] == 1 + + +@pytest.mark.parametrize("provider", ["libsql", "sqlite"]) +def test_context_manager_nested(provider): + """Test nested context managers""" + conn = connect(provider, ":memory:") + with conn as c1: + c1.execute("CREATE TABLE t(x)") + c1.execute("INSERT INTO t VALUES (1)") + with conn as c2: + c2.execute("INSERT INTO t VALUES (2)") + # Inner context commits + cur = conn.cursor() + cur.execute("SELECT COUNT(*) FROM t") + assert cur.fetchone()[0] == 2 + # Outer context also commits + cur = conn.cursor() + cur.execute("SELECT COUNT(*) FROM t") + assert cur.fetchone()[0] == 2 + + +@pytest.mark.parametrize("provider", ["libsql", "sqlite"]) +def test_context_manager_connection_reuse(provider): + """Test that connection remains usable after context manager exit""" + conn = connect(provider, ":memory:") + + # First use with context manager + with conn as c: + c.execute("CREATE TABLE t(x)") + c.execute("INSERT INTO t VALUES (1)") + + # Connection should still be valid + cur = conn.cursor() + cur.execute("INSERT INTO t VALUES (2)") + conn.commit() + + # Verify both inserts worked + cur.execute("SELECT COUNT(*) FROM t") + assert cur.fetchone()[0] == 2 + + # Use context manager again + with conn as c: + c.execute("INSERT INTO t VALUES (3)") + + # Final verification + cur.execute("SELECT COUNT(*) FROM t") + assert cur.fetchone()[0] == 3 + + conn.close() + + +@pytest.mark.parametrize("provider", ["libsql", "sqlite"]) +def test_context_manager_nested_exception(provider): + """Test exception handling in nested context managers""" + conn = connect(provider, ":memory:") + + # Create table outside context + conn.execute("CREATE TABLE t(x)") + conn.commit() + + # Test that nested context managers share the same transaction + # An exception in an inner context will roll back the entire transaction + try: + with conn as c1: + c1.execute("INSERT INTO t VALUES (1)") + try: + with conn as c2: + c2.execute("INSERT INTO t VALUES (2)") + raise ValueError("Inner exception") + except ValueError: + pass + # The inner rollback affects the entire transaction + # So value 1 is also rolled back + c1.execute("INSERT INTO t VALUES (3)") + except: + pass + + # Only value 3 should be committed (1 and 2 were rolled back together) + cur = conn.cursor() + cur.execute("SELECT x FROM t ORDER BY x") + results = cur.fetchall() + assert results == [(3,)] + + # Test outer exception after nested context commits + conn.execute("DROP TABLE t") + conn.execute("CREATE TABLE t(x)") + conn.commit() + + try: + with conn as c1: + c1.execute("INSERT INTO t VALUES (10)") + with conn as c2: + c2.execute("INSERT INTO t VALUES (20)") + # Inner context will commit both values + # This will cause outer rollback but values are already committed + raise RuntimeError("Outer exception") + except RuntimeError: + pass + + # Values 10 and 20 should be committed by inner context + cur.execute("SELECT COUNT(*) FROM t") + assert cur.fetchone()[0] == 2 + + +@pytest.mark.parametrize("provider", ["libsql", "sqlite"]) +def test_context_manager_manual_transaction_control(provider): + """Test mixing manual transaction control with context managers""" + conn = connect(provider, ":memory:") + + with conn as c: + c.execute("CREATE TABLE t(x)") + c.execute("INSERT INTO t VALUES (1)") + + # Manual commit within context + c.commit() + + # Start new transaction + c.execute("INSERT INTO t VALUES (2)") + # This will be committed by context manager + + # Both values should be present + cur = conn.cursor() + cur.execute("SELECT COUNT(*) FROM t") + assert cur.fetchone()[0] == 2 + + # Test manual rollback within context + with conn as c: + c.execute("INSERT INTO t VALUES (3)") + + # Manual rollback + c.rollback() + + # New transaction + c.execute("INSERT INTO t VALUES (4)") + # This will be committed by context manager + + # Should have values 1, 2, and 4 (not 3) + cur.execute("SELECT x FROM t ORDER BY x") + results = cur.fetchall() + assert results == [(1,), (2,), (4,)] + + def connect(provider, database, timeout=5, isolation_level="DEFERRED", autocommit=-1): if provider == "libsql-remote": from urllib import request @@ -331,9 +534,7 @@ def connect(provider, database, timeout=5, isolation_level="DEFERRED", autocommi raise Exception("libsql-remote server is not running") if res.getcode() != 200: raise Exception("libsql-remote server is not running") - return libsql.connect( - database, sync_url="http://localhost:8080", auth_token="" - ) + return libsql.connect(database, sync_url="http://localhost:8080", auth_token="") if provider == "libsql": if sys.version_info < (3, 12): return libsql.connect( @@ -343,15 +544,23 @@ def connect(provider, database, timeout=5, isolation_level="DEFERRED", autocommi if autocommit == -1: autocommit = libsql.LEGACY_TRANSACTION_CONTROL return libsql.connect( - database, timeout=timeout, isolation_level=isolation_level, autocommit=autocommit + database, + timeout=timeout, + isolation_level=isolation_level, + autocommit=autocommit, ) if provider == "sqlite": if sys.version_info < (3, 12): - return sqlite3.connect(database, timeout=timeout, isolation_level=isolation_level) + return sqlite3.connect( + database, timeout=timeout, isolation_level=isolation_level + ) else: if autocommit == -1: autocommit = sqlite3.LEGACY_TRANSACTION_CONTROL return sqlite3.connect( - database, timeout=timeout, isolation_level=isolation_level, autocommit=autocommit + database, + timeout=timeout, + isolation_level=isolation_level, + autocommit=autocommit, ) raise Exception(f"Provider `{provider}` is not supported")