Skip to content

Commit

Permalink
sql/postgres: fixed enum default values
Browse files Browse the repository at this point in the history
  • Loading branch information
svstanev committed Apr 17, 2022
1 parent a668f73 commit cb20358
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 27 deletions.
12 changes: 6 additions & 6 deletions internal/integration/testdata/postgres/column-enum.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ table "users" {

-- 1.sql --
Table "script_column_enum.users"
Column | Type | Collation | Nullable | Default
--------+--------+-----------+----------+------------------
type | status | | not null | 'active'::status
Column | Type | Collation | Nullable | Default
--------+---------------------------+-----------+----------+-------------------------------------
type | script_column_enum.status | | not null | 'active'::script_column_enum.status


-- 2.hcl --
Expand All @@ -46,6 +46,6 @@ table "users" {

-- 2.sql --
Table "script_column_enum.users"
Column | Type | Collation | Nullable | Default
--------+--------+-----------+----------+--------------------
type | status | | not null | 'inactive'::status
Column | Type | Collation | Nullable | Default
--------+---------------------------+-----------+----------+---------------------------------------
type | script_column_enum.status | | not null | 'inactive'::script_column_enum.status
7 changes: 7 additions & 0 deletions sql/postgres/inspect.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,13 @@ func (i *inspect) enumValues(ctx context.Context, s *schema.Schema) error {
e, ok := enums[enum.ID]
if !ok {
e = &schema.EnumType{T: enum.T, Schema: s}
if expr, ok := c.Default.(*schema.RawExpr); ok {
parts := strings.Split(expr.X, "::")
name := fmt.Sprintf("%s.%s", s.Name, e.T)
if len(parts) == 2 && parts[1] == name {
c.Default = &schema.Literal{V: parts[0]}
}
}
enums[enum.ID] = e
args = append(args, enum.ID)
}
Expand Down
33 changes: 12 additions & 21 deletions sql/postgres/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (s *state) topLevel(changes []schema.Change) []schema.Change {
// addTable builds and executes the query for creating a table in a schema.
func (s *state) addTable(ctx context.Context, add *schema.AddTable) error {
// Create enum types before using them in the `CREATE TABLE` statement.
if err := s.addTypes(ctx, add.T.Schema, add.T.Columns...); err != nil {
if err := s.addTypes(ctx, add.T.Columns...); err != nil {
return err
}
b := Build("CREATE TABLE")
Expand Down Expand Up @@ -222,7 +222,7 @@ func (s *state) modifyTable(ctx context.Context, modify *schema.ModifyTable) err
F: change.To,
})
case *schema.AddColumn:
if err := s.addTypes(ctx, modify.T.Schema, change.C); err != nil {
if err := s.addTypes(ctx, change.C); err != nil {
return err
}
if c := (schema.Comment{}); sqlx.Has(change.C.Attrs, &c) {
Expand Down Expand Up @@ -257,7 +257,7 @@ func (s *state) modifyTable(ctx context.Context, modify *schema.ModifyTable) err
}
// Enum was added (and column type was changed).
case !ok1 && ok2:
if err := s.addTypes(ctx, modify.T.Schema, change.To); err != nil {
if err := s.addTypes(ctx, change.To); err != nil {
return err
}
}
Expand Down Expand Up @@ -417,11 +417,7 @@ func (s *state) dropIndexes(t *schema.Table, indexes ...*schema.Index) {
}
}

func (s *state) addTypes(ctx context.Context, ns *schema.Schema, columns ...*schema.Column) error {
schemaName := "public"
if ns != nil {
schemaName = ns.Name
}
func (s *state) addTypes(ctx context.Context, columns ...*schema.Column) error {
for _, c := range columns {
e, ok := c.Type.Type.(*schema.EnumType)
if !ok {
Expand All @@ -430,6 +426,10 @@ func (s *state) addTypes(ctx context.Context, ns *schema.Schema, columns ...*sch
if e.T == "" {
return fmt.Errorf("missing enum name for column %q", c.Name)
}
schemaName := "public" // This should never happen! Enums should always have a schema
if e.Schema != nil {
schemaName = e.Schema.Name
}
c.Type.Raw = e.T
if exists, err := s.enumExists(ctx, schemaName, e.T); err != nil {
return err
Expand Down Expand Up @@ -557,26 +557,17 @@ func (s *state) columnDefault(b *sqlx.Builder, c *schema.Column) {
switch x := c.Default.(type) {
case *schema.Literal:
v := x.V
switch c.Type.Type.(type) {
switch t := c.Type.Type.(type) {
case *schema.EnumType:
v = quote(v) + "::" + t.Schema.Name + "." + t.T
case *schema.BoolType, *schema.DecimalType, *schema.IntegerType, *schema.FloatType:
default:
v = quote(v)
}
b.P("DEFAULT", v)
case *schema.RawExpr:
// Ignore identity functions added by the differ.
if t, ok := c.Type.Type.(*schema.EnumType); ok {
parts := strings.Split(x.X, "::")
if len(parts) == 2 {
var s string
if t.Schema != nil {
s += t.Schema.Name + "."
}
parts[1] = s + t.T
}

b.P("DEFAULT", strings.Join(parts, "::"))
} else if _, ok := c.Type.Type.(*SerialType); !ok {
if _, ok := c.Type.Type.(*SerialType); !ok {
b.P("DEFAULT", x.X)
}
}
Expand Down
1 change: 1 addition & 0 deletions sql/postgres/sqlspec.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ func convertEnums(tbls []*sqlspec.Table, enums []*Enum, sch *schema.Schema) erro
return fmt.Errorf("postgrs: column %q not found in table %q", col.Name, t.Name)
}
c.Type.Type = &schema.EnumType{
Schema: sch,
T: e.Name,
Values: e.Values,
}
Expand Down

0 comments on commit cb20358

Please sign in to comment.