diff --git a/postgres.go b/postgres.go index dade1b5..94c00c3 100644 --- a/postgres.go +++ b/postgres.go @@ -208,33 +208,41 @@ func (dialector Dialector) Explain(sql string, vars ...interface{}) string { } func (dialector Dialector) DataTypeOf(field *schema.Field) string { + // PostgreSQL 10+ generated columns. The value-generation strategy is carried + // by the `generated` tag, intentionally kept separate from the column `type`: + // + // `gorm:"generated:identity"` -> GENERATED BY DEFAULT AS IDENTITY + // `gorm:"generated:identity always"` -> GENERATED ALWAYS AS IDENTITY + // `gorm:"generated:price * quantity"` -> GENERATED ALWAYS AS (price * quantity) STORED + // + // https://github.com/go-gorm/gorm/issues/7191 + if gen, ok := generatedColumnOf(field); ok { + if gen.identity { + return dialector.getSchemaIntType(field) + " GENERATED " + gen.mode + " AS IDENTITY" + } + return dialector.getSchemaBaseType(field) + " GENERATED ALWAYS AS (" + gen.expr + ") STORED" + } + + return dialector.getSchemaBaseType(field) +} + +func (dialector Dialector) getSchemaBaseType(field *schema.Field) string { switch field.DataType { case schema.Bool: return "boolean" case schema.Int, schema.Uint: - size := field.Size - if field.DataType == schema.Uint { - size++ - } + intType := dialector.getSchemaIntType(field) if field.AutoIncrement { - switch { - case size <= 16: + switch intType { + case "smallint": return "smallserial" - case size <= 32: + case "integer": return "serial" default: return "bigserial" } - } else { - switch { - case size <= 16: - return "smallint" - case size <= 32: - return "integer" - default: - return "bigint" - } } + return intType case schema.Float: if field.Precision > 0 { if field.Scale > 0 { @@ -260,6 +268,22 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { } } +func (dialector Dialector) getSchemaIntType(field *schema.Field) string { + size := field.Size + if field.DataType == schema.Uint { + size++ + } + + switch { + case size <= 16: + return "smallint" + case size <= 32: + return "integer" + default: + return "bigint" + } +} + func (dialector Dialector) getSchemaCustomType(field *schema.Field) string { sqlType := string(field.DataType) @@ -303,3 +327,55 @@ func getSerialDatabaseType(s string) (dbType string, ok bool) { return "", false } } + +// generatedColumn describes a PostgreSQL generated column parsed from a +// `generated` tag: either an identity column or a STORED computed column. +type generatedColumn struct { + identity bool // identity column: GENERATED { mode } AS IDENTITY + mode string // identity generation mode: "BY DEFAULT" or "ALWAYS" + expr string // computed column expression: GENERATED ALWAYS AS (expr) STORED +} + +// generatedColumnOf parses the `generated` tag. The value is either the keyword +// `identity` (optionally combined with the mode `always` / `by default`) for an +// identity column, or any other value, which is taken verbatim as the expression +// of a STORED computed column. +func generatedColumnOf(field *schema.Field) (generatedColumn, bool) { + value, ok := field.TagSettings["GENERATED"] + if !ok { + return generatedColumn{}, false + } + + // Ignore an empty value or a bare `generated` tag, which the tag parser + // stores as the upper-cased key, rather than treating it as an expression. + if value = strings.TrimSpace(value); value == "" || value == "GENERATED" { + return generatedColumn{}, false + } + + if mode, isIdentity := identityMode(value); isIdentity { + return generatedColumn{identity: true, mode: mode}, true + } + + return generatedColumn{expr: value}, true +} + +// identityMode reports whether value describes an identity column and, if so, +// its generation mode. The recognized keywords are `identity`, `always` and +// `by default`, in any order; any other token means value is a computed +// expression rather than an identity specification. +func identityMode(value string) (mode string, ok bool) { + mode = "BY DEFAULT" + for _, token := range strings.Fields(strings.ToLower(value)) { + switch token { + case "identity": + ok = true + case "always": + mode = "ALWAYS" + case "by", "default": + // part of the "by default" mode, which is the default; ignore + default: + return "", false + } + } + return mode, ok +} diff --git a/postgres_test.go b/postgres_test.go index 4f28726..dfc67da 100644 --- a/postgres_test.go +++ b/postgres_test.go @@ -2,7 +2,8 @@ package postgres import ( "testing" -"gorm.io/gorm/schema" + + "gorm.io/gorm/schema" ) func Test_DataTypeOf(t *testing.T) { @@ -13,11 +14,11 @@ func Test_DataTypeOf(t *testing.T) { field *schema.Field } tests := []struct { - name string + name string fields fields - args args - want string - } { + args args + want string + }{ { name: "it should return boolean", args: args{field: &schema.Field{DataType: schema.Bool}}, @@ -43,6 +44,97 @@ func Test_DataTypeOf(t *testing.T) { args: args{field: &schema.Field{DataType: schema.String, Size: 10485760}}, want: "varchar(10485760)", }, + { + // https://github.com/go-gorm/gorm/issues/7191 + name: "generated:identity renders an identity column (BY DEFAULT)", + args: args{field: &schema.Field{ + DataType: schema.Uint, + GORMDataType: schema.Uint, + Size: 64, + PrimaryKey: true, + AutoIncrement: true, + TagSettings: map[string]string{"GENERATED": "identity"}, + }}, + want: "bigint GENERATED BY DEFAULT AS IDENTITY", + }, + { + name: "generated:identity always renders an ALWAYS identity column", + args: args{field: &schema.Field{ + DataType: schema.Uint, + GORMDataType: schema.Uint, + Size: 64, + PrimaryKey: true, + AutoIncrement: true, + TagSettings: map[string]string{"GENERATED": "identity always"}, + }}, + want: "bigint GENERATED ALWAYS AS IDENTITY", + }, + { + name: "generated:always identity is order independent", + args: args{field: &schema.Field{ + DataType: schema.Int, + GORMDataType: schema.Int, + Size: 32, + TagSettings: map[string]string{"GENERATED": "always identity"}, + }}, + want: "integer GENERATED ALWAYS AS IDENTITY", + }, + { + name: "generated: renders a STORED computed column", + args: args{field: &schema.Field{ + DataType: "numeric", + TagSettings: map[string]string{"GENERATED": "price * quantity"}, + }}, + want: "numeric GENERATED ALWAYS AS (price * quantity) STORED", + }, + { + name: "generated: keeps commas inside the expression", + args: args{field: &schema.Field{ + DataType: schema.String, + Size: -1, + TagSettings: map[string]string{"GENERATED": "coalesce(first_name, last_name)"}, + }}, + want: "text GENERATED ALWAYS AS (coalesce(first_name, last_name)) STORED", + }, + { + name: "a bare generated tag is ignored", + args: args{field: &schema.Field{ + DataType: schema.Uint, + GORMDataType: schema.Uint, + Size: 64, + AutoIncrement: true, + TagSettings: map[string]string{"GENERATED": "GENERATED"}, + }}, + want: "bigserial", + }, + { + name: "a lowercase generated expression is not mistaken for a bare tag", + args: args{field: &schema.Field{ + DataType: "numeric", + TagSettings: map[string]string{"GENERATED": "generated"}, + }}, + want: "numeric GENERATED ALWAYS AS (generated) STORED", + }, + { + name: "it should still convert a plain custom type to bigserial for an auto increment field", + args: args{field: &schema.Field{ + DataType: "bigint", + GORMDataType: schema.Uint, + AutoIncrement: true, + Size: 64, + }}, + want: "bigserial", + }, + { + name: "it should keep an explicit bigserial type for an auto increment field", + args: args{field: &schema.Field{ + DataType: "bigserial", + GORMDataType: schema.Uint, + AutoIncrement: true, + Size: 64, + }}, + want: "bigserial", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -54,4 +146,4 @@ func Test_DataTypeOf(t *testing.T) { } }) } -} \ No newline at end of file +}