sql/postgres: fixed inspection/migration for enums and indexes #711

Open · wants to merge 3 commits into master
2 changes: 1 addition & 1 deletion cmd/action/mux.go
@@ -109,7 +109,7 @@ func SchemaNameFromURL(ctx context.Context, url string) (string, error) {
}
return cfg.DBName, err
case "postgres":
return postgresSchema(dsn)
return postgresSchema(url)
case "sqlite":
return schemaName(ctx, dsn)
default:
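Note on this change: `SchemaNameFromURL` now hands the original `url`, rather than the converted `dsn`, to `postgresSchema`, so the schema is resolved from the URL form itself; the new case in mux_test.go below covers a URL whose user name contains an underscore and no `search_path`. A minimal sketch of reading the schema from such a URL follows; the helper name is hypothetical and not the project's actual implementation.

```go
package main

import (
	"fmt"
	"net/url"
)

// schemaFromPostgresURL reads the schema from the search_path query
// parameter of a postgres:// URL. Hypothetical helper for illustration;
// the real postgresSchema in cmd/action may differ.
func schemaFromPostgresURL(raw string) (string, error) {
	u, err := url.Parse(raw)
	if err != nil {
		return "", err
	}
	// An absent or empty search_path means "no specific schema".
	return u.Query().Get("search_path"), nil
}

func main() {
	s, err := schemaFromPostgresURL("postgres://us_er:password@localhost:5432/dbname?sslmode=disable")
	if err != nil {
		panic(err)
	}
	fmt.Printf("%q\n", s) // "" since the URL carries no search_path
}
```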
4 changes: 4 additions & 0 deletions cmd/action/mux_test.go
@@ -150,6 +150,10 @@ func Test_PostgresSchemaDSN(t *testing.T) {
url: "postgres://localhost:5432/dbname?search_path=",
expected: "",
},
{
url: "postgres://us_er:password@localhost:5432/dbname?sslmode=disable",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.url, func(t *testing.T) {
12 changes: 6 additions & 6 deletions internal/integration/testdata/postgres/column-enum.txt
@@ -23,9 +23,9 @@ table "users" {

-- 1.sql --
Table "script_column_enum.users"
Column | Type | Collation | Nullable | Default
--------+--------+-----------+----------+------------------
type | status | | not null | 'active'::status
Column | Type | Collation | Nullable | Default
--------+---------------------------+-----------+----------+-------------------------------------
type | script_column_enum.status | | not null | 'active'::script_column_enum.status


-- 2.hcl --
@@ -46,6 +46,6 @@ table "users" {

-- 2.sql --
Table "script_column_enum.users"
Column | Type | Collation | Nullable | Default
--------+--------+-----------+----------+--------------------
type | status | | not null | 'inactive'::status
Column | Type | Collation | Nullable | Default
--------+---------------------------+-----------+----------+---------------------------------------
type | script_column_enum.status | | not null | 'inactive'::script_column_enum.status
10 changes: 10 additions & 0 deletions sql/internal/sqlx/dev.go
@@ -60,6 +60,16 @@ func (d *DevDriver) NormalizeRealm(ctx context.Context, r *schema.Realm) (nr *sc
t.Schema = s
}
changes = append(changes, &schema.AddTable{T: t})

for _, c := range t.Columns {
e, ok := c.Type.Type.(*schema.EnumType)
if !ok {
continue
}
if e.Schema != s {
e.Schema = s
}
}
}
}
patch := func(r *schema.Realm) {
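The loop added to `NormalizeRealm` binds each enum column's type to the schema its table belongs to on the dev database, so later formatting and DDL generation can qualify the type name. A small sketch of the effect, assuming the `ariga.io/atlas/sql/schema` import path and only the fields that appear in the diff above:

```go
package main

import (
	"fmt"

	"ariga.io/atlas/sql/schema"
)

func main() {
	// An enum column whose type is not yet bound to any schema.
	dev := &schema.Schema{Name: "script_column_enum"}
	enum := &schema.EnumType{T: "status", Values: []string{"active", "inactive"}}
	users := &schema.Table{
		Name:    "users",
		Schema:  dev,
		Columns: []*schema.Column{{Name: "type", Type: &schema.ColumnType{Type: enum}}},
	}
	dev.Tables = append(dev.Tables, users)

	// What the new loop does for every normalized table: attach enum
	// column types to the table's (possibly rewritten) schema.
	for _, c := range users.Columns {
		if e, ok := c.Type.Type.(*schema.EnumType); ok && e.Schema != dev {
			e.Schema = dev
		}
	}
	fmt.Println(enum.Schema.Name) // script_column_enum
}
```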
11 changes: 11 additions & 0 deletions sql/internal/sqlx/sqlx.go
@@ -257,6 +257,17 @@ func (b *Builder) Table(t *schema.Table) *Builder {
return b
}

// EnumType writes the enum identifier to the builder, prefixed
// with the schema name, if one exists.
func (b *Builder) EnumType(t *schema.EnumType) *Builder {
if t.Schema != nil {
b.Ident(t.Schema.Name)
b.rewriteLastByte('.')
}
b.Ident(t.T)
return b
}

// Comma writes a comma in case the buffer is not empty, or
// replaces the last char if it is a whitespace.
func (b *Builder) Comma() *Builder {
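`Builder.EnumType` is what lets the Postgres planner emit schema-qualified `CREATE TYPE` and `DROP TYPE` statements (see migrate.go below). Because `sql/internal/sqlx` cannot be imported from outside the module, here is a stand-alone sketch of the identifier it is expected to produce, using the double-quote quoting the Postgres builder applies through `Ident`:

```go
package main

import "fmt"

// enumIdent mimics Builder.EnumType: the type name, prefixed with its
// schema when one is attached, each part quoted the way Builder.Ident
// quotes identifiers for Postgres. Simplified stand-in for illustration.
func enumIdent(schemaName, typeName string) string {
	if schemaName != "" {
		return fmt.Sprintf("%q.%q", schemaName, typeName)
	}
	return fmt.Sprintf("%q", typeName)
}

func main() {
	fmt.Println("CREATE TYPE " + enumIdent("script_column_enum", "status") + " AS ENUM ('active', 'inactive')")
	// CREATE TYPE "script_column_enum"."status" AS ENUM ('active', 'inactive')
}
```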
5 changes: 4 additions & 1 deletion sql/postgres/convert.go
@@ -40,7 +40,10 @@ func FormatType(t schema.Type) (string, error) {
if t.T == "" {
return "", errors.New("postgres: missing enum type name")
}
f = t.T
if t.Schema != nil && t.Schema.Name != "" {
f = t.Schema.Name + "."
}
f += t.T
case *schema.IntegerType:
switch f = strings.ToLower(t.T); f {
case TypeSmallInt, TypeInteger, TypeBigInt:
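With this change, `FormatType` returns a schema-qualified name for enums that carry a schema, which is what the `script_column_enum.status` output in the testdata relies on. A short usage sketch, assuming the `ariga.io/atlas/sql/postgres` and `ariga.io/atlas/sql/schema` import paths:

```go
package main

import (
	"fmt"

	"ariga.io/atlas/sql/postgres"
	"ariga.io/atlas/sql/schema"
)

func main() {
	e := &schema.EnumType{
		T:      "status",
		Schema: &schema.Schema{Name: "script_column_enum"},
	}
	f, err := postgres.FormatType(e)
	if err != nil {
		panic(err)
	}
	fmt.Println(f) // script_column_enum.status
}
```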
19 changes: 19 additions & 0 deletions sql/postgres/driver.go
@@ -111,6 +111,25 @@ func (d *Driver) Lock(ctx context.Context, name string, timeout time.Duration) (
}, nil
}

// IsClean returns true if the realm is clean. A Postgres database
// is considered clean if there are no tables in any existing schema.
func (d *Driver) IsClean(ctx context.Context) (bool, error) {
realm, err := d.InspectRealm(ctx, nil)
if err != nil {
return false, err
}

clean := true
for _, s := range realm.Schemas {
if len(s.Tables) > 0 {
clean = false
break
}
}

return clean, nil
}

func acquire(ctx context.Context, conn schema.ExecQuerier, id uint32, timeout time.Duration) error {
switch {
// With timeout (context-based).
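`IsClean` treats a database as clean when no schema in the inspected realm contains a table, so empty schemas such as a bare `public` do not count against it. The check is a pure function of the realm, sketched here with the `ariga.io/atlas/sql/schema` types:

```go
package main

import (
	"fmt"

	"ariga.io/atlas/sql/schema"
)

// isClean mirrors the logic of Driver.IsClean: a realm is clean when
// none of its schemas holds a table.
func isClean(realm *schema.Realm) bool {
	for _, s := range realm.Schemas {
		if len(s.Tables) > 0 {
			return false
		}
	}
	return true
}

func main() {
	r := &schema.Realm{Schemas: []*schema.Schema{{Name: "public"}}}
	fmt.Println(isClean(r)) // true: "public" exists but holds no tables

	r.Schemas[0].Tables = append(r.Schemas[0].Tables, &schema.Table{Name: "users"})
	fmt.Println(isClean(r)) // false
}
```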
50 changes: 33 additions & 17 deletions sql/postgres/inspect.go
@@ -283,25 +283,35 @@ func columnType(c *columnDesc) schema.Type {
func (i *inspect) enumValues(ctx context.Context, s *schema.Schema) error {
var (
args []interface{}
ids = make(map[int64][]*schema.EnumType)
enums = make(map[int64]*schema.EnumType)
query = "SELECT enumtypid, enumlabel FROM pg_enum WHERE enumtypid IN (%s)"
)
for _, t := range s.Tables {
for _, c := range t.Columns {
if enum, ok := c.Type.Type.(*enumType); ok {
if _, ok := ids[enum.ID]; !ok {
e, ok := enums[enum.ID]
if !ok {
e = &schema.EnumType{T: enum.T, Schema: s}
enums[enum.ID] = e
args = append(args, enum.ID)
}

// Convert the intermediate type to the
// standard schema.EnumType.
e := &schema.EnumType{T: enum.T}
c.Type.Type = e
c.Type.Raw = enum.T
ids[enum.ID] = append(ids[enum.ID], e)

if expr, ok := c.Default.(*schema.RawExpr); ok {
parts := strings.Split(expr.X, "::")
schemaQualifiedName := fmt.Sprintf("%s.%s", s.Name, e.T)
if len(parts) == 2 && (parts[1] == e.T || parts[1] == schemaQualifiedName) {
c.Default = &schema.Literal{V: parts[0]}
}
}
}
}
}
if len(ids) == 0 {
if len(enums) == 0 {
return nil
}
rows, err := i.QueryContext(ctx, fmt.Sprintf(query, nArgs(0, len(args))), args...)
@@ -317,7 +327,7 @@ func (i *inspect) enumValues(ctx context.Context, s *schema.Schema) error {
if err := rows.Scan(&id, &v); err != nil {
return fmt.Errorf("postgres: scanning enum label: %w", err)
}
for _, enum := range ids[id] {
if enum, found := enums[id]; found {
enum.Values = append(enum.Values, v)
}
}
@@ -387,18 +397,17 @@ func (i *inspect) addIndexes(s *schema.Schema, rows *sql.Rows) error {
NullsLast: nullslast.Bool,
})
}
switch {
case sqlx.ValidString(expr):
part.X = &schema.RawExpr{
X: expr.String,
}
case sqlx.ValidString(column):
if sqlx.ValidString(column) {
part.C, ok = t.Column(column.String)
if !ok {
return fmt.Errorf("postgres: column %q was not found for index %q", column.String, idx.Name)
}
part.C.Indexes = append(part.C.Indexes, idx)
default:
} else if sqlx.ValidString(expr) {
part.X = &schema.RawExpr{
X: expr.String,
}
} else {
return fmt.Errorf("postgres: invalid part for index %q", idx.Name)
}
idx.Parts = append(idx.Parts, part)
@@ -767,6 +776,7 @@ ORDER BY
`

// Query to list table indexes.
// column name/expr ref.: https://gist.github.com/akki/6a64075ad3c50f3bcb4926dc49a06939
indexesQuery = `
SELECT
t.relname AS table_name,
@@ -777,25 +787,31 @@
idx.indisunique AS unique,
c.contype AS constraint_type,
pg_get_expr(idx.indpred, idx.indrelid) AS predicate,
pg_get_expr(idx.indexprs, idx.indrelid) AS expression,
pg_get_indexdef(idx.indexrelid, idx.ord, false) AS expression,
pg_index_column_has_property(idx.indexrelid, a.attnum, 'desc') AS desc,
pg_index_column_has_property(idx.indexrelid, a.attnum, 'nulls_first') AS nulls_first,
pg_index_column_has_property(idx.indexrelid, a.attnum, 'nulls_last') AS nulls_last,
obj_description(to_regclass($1 || i.relname)::oid) AS comment
FROM
pg_index idx
(
select
*,
generate_series(1,array_length(i.indkey,1)) as ord,
unnest(i.indkey) AS key
from pg_index i
) idx
JOIN pg_class i ON i.oid = idx.indexrelid
JOIN pg_class t ON t.oid = idx.indrelid
JOIN pg_namespace n ON n.oid = t.relnamespace
LEFT JOIN pg_constraint c ON idx.indexrelid = c.conindid
LEFT JOIN pg_attribute a ON a.attrelid = idx.indexrelid
LEFT JOIN pg_attribute a ON (a.attrelid, a.attnum) = (idx.indrelid, idx.key)
JOIN pg_am am ON am.oid = i.relam
WHERE
n.nspname = $1
AND t.relname IN (%s)
AND COALESCE(c.contype, '') <> 'f'
ORDER BY
table_name, index_name, a.attnum
table_name, index_name, idx.ord
`
fksQuery = `
SELECT
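The rewritten `indexesQuery` unnests `pg_index.indkey` together with a generated ordinal, so every index key arrives as its own row in key order, joins `pg_attribute` against the table (not the index) to resolve plain columns, and uses `pg_get_indexdef(indexrelid, ord, false)` to recover the expression for a given position; `addIndexes` then prefers a resolved column and falls back to the expression. Expression keys have `attnum = 0`, so the `pg_attribute` join yields NULL for them, which is how the updated inspect_test.go attaches `coalesce(parent_id, 0)` as the second part of `idx5`. Below is a trimmed-down sketch of the same technique against a live database; the DSN and the `lib/pq` driver are placeholders:

```go
package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/lib/pq" // placeholder driver; any Postgres driver works
)

// A reduced variant of the new indexesQuery: indkey is unnested with a
// matching ordinal, so each index key (column or expression) comes back
// as its own row, in index order.
const perKeyIndexes = `
SELECT
	i.relname AS index_name,
	idx.ord,
	a.attname AS column_name,
	pg_get_indexdef(idx.indexrelid, idx.ord, false) AS expression
FROM (
	SELECT *,
	       generate_series(1, array_length(i.indkey, 1)) AS ord,
	       unnest(i.indkey) AS key
	FROM pg_index i
) idx
JOIN pg_class i ON i.oid = idx.indexrelid
JOIN pg_class t ON t.oid = idx.indrelid
LEFT JOIN pg_attribute a ON (a.attrelid, a.attnum) = (idx.indrelid, idx.key)
WHERE t.relname = $1
ORDER BY index_name, idx.ord
`

func main() {
	// Placeholder DSN; point it at a database that has the inspected table.
	db, err := sql.Open("postgres", "postgres://localhost:5432/dev?sslmode=disable")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()
	rows, err := db.Query(perKeyIndexes, "users")
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()
	for rows.Next() {
		var (
			name   string
			ord    int
			column sql.NullString // NULL for expression keys
			expr   sql.NullString
		)
		if err := rows.Scan(&name, &ord, &column, &expr); err != nil {
			log.Fatal(err)
		}
		fmt.Println(name, ord, column.String, expr.String)
	}
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}
}
```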
17 changes: 11 additions & 6 deletions sql/postgres/inspect_test.go
@@ -38,7 +38,7 @@ func TestDriver_InspectTable(t *testing.T) {
m.ExpectQuery(queryColumns).
WithArgs("public", "users").
WillReturnRows(sqltest.Rows(`
table_name | column_name | data_type | is_nullable | column_default | character_maximum_length | numeric_precision | datetime_precision | numeric_scale | character_set_name | collation_name | udt_name | is_identity | identity_start | identity_increment | identity_generation | comment | typtype | oid
table_name | column_name | data_type | is_nullable | column_default | character_maximum_length | numeric_precision | datetime_precision | numeric_scale | character_set_name | collation_name | udt_name | is_identity | identity_start | identity_increment | identity_generation | comment | typtype | oid
-------------+--------------+-----------------------------+-------------+---------------------------------+--------------------------+-------------------+--------------------+---------------+--------------------+----------------+-------------+-------------+----------------+--------------------+---------------------+---------+---------+-------
users | id | bigint | NO | | | 64 | | 0 | | | int8 | YES | 100 | 1 | BY DEFAULT | | b | 20
users | rank | integer | YES | | | 32 | | 0 | | | int4 | NO | | | | rank | b | 23
@@ -112,7 +112,7 @@ func TestDriver_InspectTable(t *testing.T) {
{Name: "c21", Type: &schema.ColumnType{Raw: "xml", Type: &XMLType{T: "xml"}}},
{Name: "c22", Type: &schema.ColumnType{Raw: "ARRAY", Null: true, Type: &ArrayType{T: "int4[]"}}},
{Name: "c23", Type: &schema.ColumnType{Raw: "USER-DEFINED", Null: true, Type: &UserDefinedType{T: "ltree"}}},
{Name: "c24", Type: &schema.ColumnType{Raw: "state", Type: &schema.EnumType{T: "state", Values: []string{"on", "off"}}}},
{Name: "c24", Type: &schema.ColumnType{Raw: "state", Type: &schema.EnumType{T: "state", Schema: t.Schema, Values: []string{"on", "off"}}}},
{Name: "c25", Type: &schema.ColumnType{Raw: "timestamp without time zone", Type: &schema.TimeType{T: "timestamp without time zone", Precision: p(4)}}, Default: &schema.RawExpr{X: "now()"}},
{Name: "c26", Type: &schema.ColumnType{Raw: "timestamp with time zone", Type: &schema.TimeType{T: "timestamp with time zone", Precision: p(6)}}},
{Name: "c27", Type: &schema.ColumnType{Raw: "time without time zone", Type: &schema.TimeType{T: "time without time zone", Precision: p(6)}}},
@@ -130,18 +130,21 @@ table_name | column_name | data_type | is_nullable | column_de
-----------+-------------+---------------------+-------------+---------------------------------+--------------------------+-------------------+--------------------+---------------+--------------------+----------------+----------+-------------+----------------+--------------------+---------------------+---------+---------+-------
users | id | bigint | NO | | | 64 | | 0 | | | int8 | NO | | | | | b | 20
users | c1 | smallint | NO | | | 16 | | 0 | | | int2 | NO | | | | | b | 21
users | parent_id | bigint | YES | | | 64 | | 0 | | | int8 | NO | | | | | b | 22
`))
m.ExpectQuery(queryIndexes).
WithArgs("public", "users").
WillReturnRows(sqltest.Rows(`
table_name | index_name | index_type | column_name | primary | unique | constraint_type | predicate | expression | desc | nulls_first | nulls_last | comment
----------------+-----------------+-------------+-------------+---------+--------+-----------------+-----------------------+---------------------------+------+-------------+------------+-----------
users | idx | hash | left | f | f | | | "left"((c11)::text, 100) | t | t | f | boring
users | idx1 | btree | left | f | f | | (id <> NULL::integer) | "left"((c11)::text, 100) | t | t | f |
users | idx | hash | | f | f | | | "left"((c11)::text, 100) | t | t | f | boring
users | idx1 | btree | | f | f | | (id <> NULL::integer) | "left"((c11)::text, 100) | t | t | f |
users | t1_c1_key | btree | c1 | f | t | u | | | t | t | f |
users | t1_pkey | btree | id | t | t | p | | | t | f | f |
users | idx4 | btree | c1 | f | t | | | | f | f | f |
users | idx4 | btree | id | f | t | | | | f | f | t |
users | idx4 | btree | id | f | t | | | | f | f | t |
users | idx5 | btree | c1 | f | t | | | c1 | f | f | f |
users | idx5 | btree | | f | t | | | coalesce(parent_id, 0) | f | f | f |
`))
m.noFKs()
m.noChecks()
Expand All @@ -152,12 +155,14 @@ users | idx4 | btree | id | f | t
columns := []*schema.Column{
{Name: "id", Type: &schema.ColumnType{Raw: "bigint", Type: &schema.IntegerType{T: "bigint"}}},
{Name: "c1", Type: &schema.ColumnType{Raw: "smallint", Type: &schema.IntegerType{T: "smallint"}}},
{Name: "parent_id", Type: &schema.ColumnType{Raw: "bigint", Null: true, Type: &schema.IntegerType{T: "bigint"}}},
}
indexes := []*schema.Index{
{Name: "idx", Table: t, Attrs: []schema.Attr{&IndexType{T: "hash"}, &schema.Comment{Text: "boring"}}, Parts: []*schema.IndexPart{{SeqNo: 1, X: &schema.RawExpr{X: `"left"((c11)::text, 100)`}, Desc: true, Attrs: []schema.Attr{&IndexColumnProperty{NullsFirst: true}}}}},
{Name: "idx1", Table: t, Attrs: []schema.Attr{&IndexType{T: "btree"}, &IndexPredicate{P: `(id <> NULL::integer)`}}, Parts: []*schema.IndexPart{{SeqNo: 1, X: &schema.RawExpr{X: `"left"((c11)::text, 100)`}, Desc: true, Attrs: []schema.Attr{&IndexColumnProperty{NullsFirst: true}}}}},
{Name: "t1_c1_key", Unique: true, Table: t, Attrs: []schema.Attr{&IndexType{T: "btree"}, &ConType{T: "u"}}, Parts: []*schema.IndexPart{{SeqNo: 1, C: columns[1], Desc: true, Attrs: []schema.Attr{&IndexColumnProperty{NullsFirst: true}}}}},
{Name: "idx4", Unique: true, Table: t, Attrs: []schema.Attr{&IndexType{T: "btree"}}, Parts: []*schema.IndexPart{{SeqNo: 1, C: columns[1]}, {SeqNo: 2, C: columns[0], Attrs: []schema.Attr{&IndexColumnProperty{NullsLast: true}}}}},
{Name: "idx5", Unique: true, Table: t, Attrs: []schema.Attr{&IndexType{T: "btree"}}, Parts: []*schema.IndexPart{{SeqNo: 1, C: columns[1]}, {SeqNo: 2, X: &schema.RawExpr{X: `coalesce(parent_id, 0)`}}}},
}
pk := &schema.Index{
Name: "t1_pkey",
@@ -497,7 +502,7 @@ type mock struct {
func (m mock) version(version string) {
m.ExpectQuery(sqltest.Escape(paramsQuery)).
WillReturnRows(sqltest.Rows(`
setting
setting
------------
en_US.utf8
en_US.utf8
26 changes: 20 additions & 6 deletions sql/postgres/migrate.go
@@ -27,6 +27,7 @@ func (p *planApply) PlanChanges(ctx context.Context, name string, changes []sche
Reversible: true,
Transactional: true,
},
enums: make(map[string]struct{}),
}
if err := s.plan(ctx, changes); err != nil {
return nil, err
@@ -52,6 +53,8 @@ func (p *planApply) ApplyChanges(ctx context.Context, changes []schema.Change) e
type state struct {
conn
migrate.Plan

enums map[string]struct{}
}

// Exec executes the changes on the database. An error is returned
@@ -423,13 +426,18 @@ func (s *state) addTypes(ctx context.Context, columns ...*schema.Column) error {
if e.T == "" {
return fmt.Errorf("missing enum name for column %q", c.Name)
}
schemaName := "public" // This should never happen! Enums should always have a schema
if e.Schema != nil {
schemaName = e.Schema.Name
}
c.Type.Raw = e.T
if exists, err := s.enumExists(ctx, e.T); err != nil {
if exists, err := s.enumExists(ctx, schemaName, e.T); err != nil {
return err
} else if exists {
continue
}
b := Build("CREATE TYPE").Ident(e.T).P("AS ENUM")
s.enums[schemaName+"."+e.T] = struct{}{}
b := Build("CREATE TYPE").EnumType(e).P("AS ENUM")
b.Wrap(func(b *sqlx.Builder) {
b.MapComma(e.Values, func(i int, b *sqlx.Builder) {
b.WriteString("'" + e.Values[i] + "'")
@@ -438,7 +446,7 @@ func (s *state) addTypes(ctx context.Context, columns ...*schema.Column) error {
s.append(&migrate.Change{
Cmd: b.String(),
Comment: fmt.Sprintf("create enum type %q", e.T),
Reverse: Build("DROP TYPE").Ident(e.T).String(),
Reverse: Build("DROP TYPE").EnumType(e).String(),
})
}
return nil
@@ -462,8 +470,12 @@ func (s *state) alterType(from, to *schema.EnumType) error {
return nil
}

func (s *state) enumExists(ctx context.Context, name string) (bool, error) {
rows, err := s.QueryContext(ctx, "SELECT * FROM pg_type WHERE typname = $1 AND typtype = 'e'", name)
func (s *state) enumExists(ctx context.Context, schema, name string) (bool, error) {
if _, found := s.enums[schema+"."+name]; found {
return true, nil
}

rows, err := s.QueryContext(ctx, "SELECT ns.nspname, t.typname FROM pg_type t INNER JOIN pg_namespace ns ON ns.oid = t.typnamespace WHERE ns.nspname = $1 AND typname = $2 AND typtype = 'e'", schema, name)
if err != nil {
return false, fmt.Errorf("check index existence: %w", err)
}
@@ -545,7 +557,9 @@ func (s *state) columnDefault(b *sqlx.Builder, c *schema.Column) {
switch x := c.Default.(type) {
case *schema.Literal:
v := x.V
switch c.Type.Type.(type) {
switch t := c.Type.Type.(type) {
case *schema.EnumType:
v = quote(v) + "::" + t.Schema.Name + "." + t.T
case *schema.BoolType, *schema.DecimalType, *schema.IntegerType, *schema.FloatType:
default:
v = quote(v)
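Taken together, the migrate.go changes mean that planning a table with an enum column emits a schema-qualified `CREATE TYPE` (skipping types already queued in the same plan or already present according to `enumExists`), and that literal defaults are cast to the qualified type. A hedged sketch of the statements this is expected to produce for the column-enum testdata; the output below is illustrative, not captured from a run:

```go
package main

import "fmt"

// enumDefault mirrors the new enum branch of columnDefault: a literal
// default is cast to the schema-qualified enum type so it resolves
// regardless of the session search_path. Simplified stand-in.
func enumDefault(literal, schemaName, typeName string) string {
	return fmt.Sprintf("'%s'::%s.%s", literal, schemaName, typeName)
}

func main() {
	// For schema script_column_enum, enum status, default 'active':
	fmt.Println(`CREATE TYPE "script_column_enum"."status" AS ENUM ('active', 'inactive')`)
	fmt.Println(`"type" script_column_enum.status NOT NULL DEFAULT ` +
		enumDefault("active", "script_column_enum", "status"))
	// Reverse of the type creation:
	fmt.Println(`DROP TYPE "script_column_enum"."status"`)
}
```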