Skip to content

Commit

Permalink
Merge pull request #14 from RazerM/feature/new-postgres
Browse files Browse the repository at this point in the history
Support latest PostgreSQL/Python, drop old versions
  • Loading branch information
RazerM authored Jan 6, 2024
2 parents 158f156 + 4ed94e4 commit 3acbdc9
Show file tree
Hide file tree
Showing 12 changed files with 255 additions and 31 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
postgres-version: ["11", "12", "13", "14", "15"]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
postgres-version: ["12", "13", "14", "15", "16"]

services:
postgres:
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ repos:
rev: v3.15.0
hooks:
- id: pyupgrade
args: [--py37-plus]
args: [--py38-plus]
- repo: https://github.com/timothycrosley/isort
rev: 5.13.2
hooks:
Expand Down
9 changes: 8 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,16 @@

**This release supports SQLAlchemy 2.0 or later**

### Added

- `get_all_parameter_acls`
- `get_parameter_acl`
- `PgOjectType.DOMAIN`
- `PgOjectType.PARAMETER`

### Changed

- Python 3.7 or later is required.
- Python 3.8 or later is required.
- The following arguments for `grant` and `revoke` are now keyword-only:
- `grant_option`
- `schema`
Expand Down
9 changes: 9 additions & 0 deletions src/pg_grant/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def get_default_privileges(type: PgObjectType, owner: str) -> List[Privileges]:
# default.
# https://stackoverflow.com/questions/46656644
public_privs = ["USAGE"]
elif type is PgObjectType.DOMAIN:
public_privs = ["USAGE"]

if public_privs:
priv_list.append(
Expand Down Expand Up @@ -203,6 +205,8 @@ def convert_priv(code: str, keyword: str) -> None:
convert_priv("C", "CREATE")
convert_priv("c", "CONNECT")
convert_priv("T", "TEMPORARY")
convert_priv("s", "SET")
convert_priv("A", "ALTER SYSTEM")

# Don't think anything can have all of them, but set all to False
# since we don't know type.
Expand Down Expand Up @@ -237,6 +241,8 @@ def convert_priv(code: str, keyword: str) -> None:
convert_priv("C", "CREATE")
elif type is PgObjectType.TYPE:
convert_priv("U", "USAGE")
elif type is PgObjectType.DOMAIN:
convert_priv("U", "USAGE")
elif type is PgObjectType.FOREIGN_DATA_WRAPPER:
convert_priv("U", "USAGE")
elif type is PgObjectType.FOREIGN_SERVER:
Expand All @@ -246,6 +252,9 @@ def convert_priv(code: str, keyword: str) -> None:
elif type is PgObjectType.LARGE_OBJECT:
convert_priv("r", "SELECT")
convert_priv("w", "UPDATE")
elif type is PgObjectType.PARAMETER:
convert_priv("s", "SET")
convert_priv("A", "ALTER SYSTEM")
else:
raise ValueError(f"Unknown type: {type}")

Expand Down
65 changes: 56 additions & 9 deletions src/pg_grant/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@

from ._typing_sqlalchemy import ArgTypesInput
from .exc import NoSuchObjectError
from .types import ColumnInfo, FunctionInfo, RelationInfo, SchemaRelationInfo
from .types import (
ColumnInfo,
FunctionInfo,
ParameterInfo,
RelationInfo,
SchemaRelationInfo,
)

if sys.version_info >= (3, 10):
from typing import TypeAlias
Expand All @@ -45,6 +51,8 @@
"get_tablespace_acl",
"get_all_type_acls",
"get_type_acl",
"get_all_parameter_acls",
"get_parameter_acl",
)

pg_table_is_visible = func.pg_catalog.pg_table_is_visible
Expand Down Expand Up @@ -114,6 +122,13 @@ class PgRelKind(Enum):
column("typacl"),
)

pg_parameter_acl = table(
"pg_parameter_acl",
column("oid"),
column("parname"),
column("paracl"),
)

pg_language = table(
"pg_language",
column("oid"),
Expand Down Expand Up @@ -267,6 +282,12 @@ class PgRelKind(Enum):
.outerjoin(pg_roles, pg_type.c.typowner == pg_roles.c.oid)
)

_pg_parameter_stmt = select(
pg_parameter_acl.c.oid,
pg_parameter_acl.c.parname.label("name"),
cast(pg_parameter_acl.c.paracl, ARRAY(Text)).label("acl"),
)


def _filter_pg_class_stmt(
stmt: Select[TP], schema: Optional[str] = None, rel_name: Optional[str] = None
Expand Down Expand Up @@ -395,7 +416,7 @@ def get_table_acl(
:class:`~.types.SchemaRelationInfo`
"""
stmt = _table_stmt(schema=schema, table_name=name)
row = conn.execute(stmt).mappings().first()
row = conn.execute(stmt).mappings().one_or_none()
if row is None:
raise NoSuchObjectError(name)
return SchemaRelationInfo(**t.cast("Mapping[str, Any]", row))
Expand Down Expand Up @@ -463,7 +484,7 @@ def get_sequence_acl(
:class:`~.types.SchemaRelationInfo`
"""
stmt = _sequence_stmt(schema=schema, sequence_name=sequence)
row = conn.execute(stmt).mappings().first()
row = conn.execute(stmt).mappings().one_or_none()
if row is None:
raise NoSuchObjectError(sequence)
return SchemaRelationInfo(**t.cast("Mapping[str, Any]", row))
Expand Down Expand Up @@ -535,7 +556,7 @@ def get_function_acl(
raise TypeError("arg_types should be a sequence of strings, e.g. ['text']")

stmt = _filter_pg_proc_stmt(schema, function_name, arg_types)
row = conn.execute(stmt).mappings().first()
row = conn.execute(stmt).mappings().one_or_none()
if row is None:
raise NoSuchObjectError(function_name)
return FunctionInfo(**t.cast("Mapping[str, Any]", row))
Expand All @@ -558,7 +579,7 @@ def get_language_acl(conn: Connectable, language: str) -> RelationInfo:
:class:`~.types.RelationInfo`
"""
stmt = _pg_lang_stmt.where(pg_language.c.lanname == language)
row = conn.execute(stmt).mappings().first()
row = conn.execute(stmt).mappings().one_or_none()
if row is None:
raise NoSuchObjectError(language)
return RelationInfo(**t.cast("Mapping[str, Any]", row))
Expand All @@ -581,7 +602,7 @@ def get_schema_acl(conn: Connectable, schema: str) -> RelationInfo:
:class:`~.types.RelationInfo`
"""
stmt = _pg_schema_stmt.where(pg_namespace.c.nspname == schema)
row = conn.execute(stmt).mappings().first()
row = conn.execute(stmt).mappings().one_or_none()
if row is None:
raise NoSuchObjectError(schema)
return RelationInfo(**t.cast("Mapping[str, Any]", row))
Expand All @@ -604,7 +625,7 @@ def get_database_acl(conn: Connectable, database: str) -> RelationInfo:
:class:`~.types.RelationInfo`
"""
stmt = _pg_db_stmt.where(pg_database.c.datname == database)
row = conn.execute(stmt).mappings().first()
row = conn.execute(stmt).mappings().one_or_none()
if row is None:
raise NoSuchObjectError(database)
return RelationInfo(**t.cast("Mapping[str, Any]", row))
Expand All @@ -627,7 +648,7 @@ def get_tablespace_acl(conn: Connectable, tablespace: str) -> RelationInfo:
:class:`~.types.RelationInfo`
"""
stmt = _pg_tablespace_stmt.where(pg_tablespace.c.spcname == tablespace)
row = conn.execute(stmt).mappings().first()
row = conn.execute(stmt).mappings().one_or_none()
if row is None:
raise NoSuchObjectError(tablespace)
return RelationInfo(**t.cast("Mapping[str, Any]", row))
Expand Down Expand Up @@ -657,7 +678,33 @@ def get_type_acl(
:class:`~.types.SchemaRelationInfo`
"""
stmt = _filter_pg_type_stmt(schema=schema, type_name=type_name)
row = conn.execute(stmt).mappings().first()
row = conn.execute(stmt).mappings().one_or_none()
if row is None:
raise NoSuchObjectError(type_name)
return SchemaRelationInfo(**t.cast("Mapping[str, Any]", row))


def get_all_parameter_acls(conn: Connectable) -> List[ParameterInfo]:
"""Return all parameters which have non-default ACLs.
Returns:
List of :class:`~.types.ParameterInfo` objects
"""
return [
ParameterInfo(**t.cast("Mapping[str, Any]", row))
for row in conn.execute(_pg_parameter_stmt).mappings()
]


def get_parameter_acl(conn: Connectable, parameter: str) -> Optional[ParameterInfo]:
"""Return information of the given parameter.
Returns:
:class:`~.types.ParameterInfo` if the parameter exists and has
non-default privileges, otherwise ``None``.
"""
stmt = _pg_parameter_stmt.where(pg_parameter_acl.c.parname == parameter)
row = conn.execute(stmt).mappings().one_or_none()
if row is None:
return None
return ParameterInfo(**t.cast("Mapping[str, Any]", row))
13 changes: 6 additions & 7 deletions src/pg_grant/sql.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
import sys
from typing import Any, ClassVar, List, Optional, Tuple, Union, cast, overload
from typing import Any, ClassVar, List, Literal, Optional, Tuple, Union, cast, overload

from sqlalchemy import FromClause, Sequence, inspect
from sqlalchemy.ext.compiler import compiles
Expand All @@ -10,11 +10,6 @@
from ._typing_sqlalchemy import AnyTarget, ArgTypesInput, TableTarget
from .types import PgObjectType

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
Expand All @@ -27,7 +22,7 @@

_re_valid_priv = re.compile(
r"(SELECT|UPDATE|INSERT|DELETE|TRUNCATE|REFERENCES|TRIGGER|EXECUTE|USAGE"
r"|CREATE|CONNECT|TEMPORARY|ALL)(?:\s+\((.*)\))?"
r"|CREATE|CONNECT|TEMPORARY|SET|ALTER SYSTEM|ALL)(?:\s+\((.*)\))?"
)


Expand Down Expand Up @@ -202,10 +197,12 @@ def grant(
PgObjectType.DATABASE,
PgObjectType.TABLESPACE,
PgObjectType.TYPE,
PgObjectType.DOMAIN,
PgObjectType.FOREIGN_DATA_WRAPPER,
PgObjectType.FOREIGN_SERVER,
PgObjectType.FOREIGN_TABLE,
PgObjectType.LARGE_OBJECT,
PgObjectType.PARAMETER,
],
target: str,
grantee: str,
Expand Down Expand Up @@ -319,10 +316,12 @@ def revoke(
PgObjectType.DATABASE,
PgObjectType.TABLESPACE,
PgObjectType.TYPE,
PgObjectType.DOMAIN,
PgObjectType.FOREIGN_DATA_WRAPPER,
PgObjectType.FOREIGN_SERVER,
PgObjectType.FOREIGN_TABLE,
PgObjectType.LARGE_OBJECT,
PgObjectType.PARAMETER,
],
target: str,
grantee: str,
Expand Down
31 changes: 22 additions & 9 deletions src/pg_grant/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import sys
from enum import Enum
from typing import TYPE_CHECKING, Any, List, NoReturn, Optional, Tuple, overload

Expand All @@ -17,10 +16,7 @@
else:
HAVE_SQLALCHEMY = True

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal
from typing import Literal


class PgObjectType(Enum):
Expand All @@ -34,10 +30,12 @@ class PgObjectType(Enum):
DATABASE = "DATABASE"
TABLESPACE = "TABLESPACE"
TYPE = "TYPE"
DOMAIN = "DOMAIN"
FOREIGN_DATA_WRAPPER = "FOREIGN DATA WRAPPER"
FOREIGN_SERVER = "FOREIGN SERVER"
FOREIGN_TABLE = "FOREIGN TABLE"
LARGE_OBJECT = "LARGE OBJECT"
PARAMETER = "PARAMETER"


@define
Expand Down Expand Up @@ -299,7 +297,7 @@ def as_revoke_statements(
raise RuntimeError("Missing sqlalchemy extra")


@define
@define(kw_only=True)
class RelationInfo:
"""Holds object information and privileges as queried using the
:mod:`.query` submodule."""
Expand All @@ -317,7 +315,7 @@ class RelationInfo:
acl: Optional[Tuple[str, ...]] = field(converter=converters.optional(tuple))


@define
@define(kw_only=True)
class SchemaRelationInfo(RelationInfo):
"""Holds object information and privileges as queried using the
:mod:`.query` submodule."""
Expand All @@ -326,7 +324,7 @@ class SchemaRelationInfo(RelationInfo):
schema: str


@define
@define(kw_only=True)
class FunctionInfo(SchemaRelationInfo):
"""Holds object information and privileges as queried using the
:mod:`.query` submodule."""
Expand All @@ -335,7 +333,7 @@ class FunctionInfo(SchemaRelationInfo):
arg_types: Tuple[str, ...] = field(converter=tuple)


@define
@define(kw_only=True)
class ColumnInfo:
"""Holds object information and privileges as queried using the
:mod:`.query` submodule."""
Expand All @@ -357,3 +355,18 @@ class ColumnInfo:

#: Column access control list.
acl: Optional[Tuple[str, ...]] = field(converter=converters.optional(tuple))


@define(kw_only=True)
class ParameterInfo:
"""Holds object information and privileges as queried using the
:mod:`.query` submodule."""

#: Row identifier.
oid: int

#: Name of the table, sequence, etc.
name: str

#: Access control list.
acl: Optional[Tuple[str, ...]] = field(converter=converters.optional(tuple))
9 changes: 7 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

tests_dir = Path(__file__).parents[0].resolve()
test_schema_file = Path(tests_dir, "data", "test-schema.sql")
test_schema_15_file = Path(tests_dir, "data", "test-schema-15+.sql")

# This matches docker-compose.yml for easy local development
DEFAULT_DATABASE_URL = "postgresql://[email protected]:5440/postgres"
Expand Down Expand Up @@ -67,9 +68,13 @@ def engine(postgres_url):

@pytest.fixture(scope="session")
def pg_schema(engine):
with test_schema_file.open() as fp:
with engine.begin() as conn:
with engine.begin() as conn:
with test_schema_file.open() as fp:
conn.execute(text(fp.read()))
server_version = conn.connection.dbapi_connection.info.server_version
if server_version >= 150000:
with test_schema_15_file.open() as fp:
conn.execute(text(fp.read()))


@pytest.fixture
Expand Down
1 change: 1 addition & 0 deletions tests/data/test-schema-15+.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
GRANT ALL ON PARAMETER log_min_duration_statement TO alice;
Loading

0 comments on commit 3acbdc9

Please sign in to comment.