From 3cef66aa93bfdbfc66ff3f96a739323e55be132a Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Fri, 22 May 2026 22:36:03 +0000 Subject: [PATCH] feat(schema): add support for database views and serial types Co-authored-by: cungminh2710 <8063319+cungminh2710@users.noreply.github.com> --- pkg/dialect/dialect.go | 4 ++++ pkg/dialect/mysql.go | 4 ++++ pkg/dialect/postgres.go | 4 ++++ pkg/dialect/sqlite.go | 2 +- pkg/migrator/diff.go | 15 +++++++++++++ pkg/migrator/snapshot.go | 2 ++ pkg/rain/ddl.go | 37 +++++++++++++++++++++++++++++-- pkg/rain/model_binding.go | 4 ++-- pkg/rain/query_compile.go | 35 ++++++++++++++++++++++------- pkg/rain/query_delete.go | 4 ++++ pkg/rain/query_insert.go | 3 +++ pkg/rain/query_update.go | 4 ++++ pkg/schema/schema.go | 46 ++++++++++++++++++++++++++++++++++++++- 13 files changed, 150 insertions(+), 14 deletions(-) diff --git a/pkg/dialect/dialect.go b/pkg/dialect/dialect.go index eab97e6..97744c9 100644 --- a/pkg/dialect/dialect.go +++ b/pkg/dialect/dialect.go @@ -61,6 +61,10 @@ func (d *BaseDialect) DataType(columnType schema.ColumnType) string { typ := normalizeType(columnType.DataType) switch typ { + case "smallserial": + return "SMALLINT" + case "serial": + return "INTEGER" case "bigserial": return "BIGSERIAL" case "smallint": diff --git a/pkg/dialect/mysql.go b/pkg/dialect/mysql.go index 78195f8..6a9bb83 100644 --- a/pkg/dialect/mysql.go +++ b/pkg/dialect/mysql.go @@ -38,6 +38,10 @@ func (d *MySQLDialect) DataType(columnType schema.ColumnType) string { typ := normalizeType(columnType.DataType) switch typ { + case "smallserial": + return "SMALLINT" + case "serial": + return "INT" case "bigserial": return "BIGINT" case "smallint": diff --git a/pkg/dialect/postgres.go b/pkg/dialect/postgres.go index cf0bf09..e84129a 100644 --- a/pkg/dialect/postgres.go +++ b/pkg/dialect/postgres.go @@ -49,6 +49,10 @@ func (d *PostgresDialect) DataType(columnType schema.ColumnType) string { typ := normalizeType(columnType.DataType) switch typ { + case "smallserial": + return "SMALLSERIAL" + case "serial": + return "SERIAL" case "bigserial": return "BIGSERIAL" case "smallint": diff --git a/pkg/dialect/sqlite.go b/pkg/dialect/sqlite.go index cb119f6..612440e 100644 --- a/pkg/dialect/sqlite.go +++ b/pkg/dialect/sqlite.go @@ -44,7 +44,7 @@ func (d *SQLiteDialect) DataType(columnType schema.ColumnType) string { typ := normalizeType(columnType.DataType) switch typ { - case "bigserial": + case "smallserial", "serial", "bigserial": return "INTEGER" case "string", "varchar", "text": return "TEXT" diff --git a/pkg/migrator/diff.go b/pkg/migrator/diff.go index 3d7aa62..871d830 100644 --- a/pkg/migrator/diff.go +++ b/pkg/migrator/diff.go @@ -82,6 +82,21 @@ func planCreateAll(snapshot Snapshot) Plan { } func diffTable(previous, current TableSnapshot, dialectName string) ([]string, error) { + if previous.IsView != current.IsView { + return nil, fmt.Errorf("migrator: changing %q from view=%v to view=%v is not supported", current.Name, previous.IsView, current.IsView) + } + + if current.IsView { + if normalizeSQL(previous.CreateTableSQL) == normalizeSQL(current.CreateTableSQL) { + return nil, nil + } + // View changed - drop and recreate + return []string{ + "DROP VIEW " + quoteIdentifier(dialectName, current.Name), + current.CreateTableSQL, + }, nil + } + var statements []string previousColumns := make(map[string]ColumnSnapshot, len(previous.Columns)) diff --git a/pkg/migrator/snapshot.go b/pkg/migrator/snapshot.go index 12a69e8..f75dbc8 100644 --- a/pkg/migrator/snapshot.go +++ b/pkg/migrator/snapshot.go @@ -21,6 +21,7 @@ type Snapshot struct { // TableSnapshot stores a portable, deterministic representation of one table. type TableSnapshot struct { Name string `json:"name"` + IsView bool `json:"is_view,omitempty"` CreateTableSQL string `json:"create_table_sql"` Columns []ColumnSnapshot `json:"columns"` Constraints []ConstraintSnapshot `json:"constraints"` @@ -167,6 +168,7 @@ func BuildSnapshot(dialectName string, tables []schema.TableReference) (Snapshot tableSnapshots = append(tableSnapshots, TableSnapshot{ Name: tableDef.Name, + IsView: tableDef.IsView, CreateTableSQL: createTableSQL, Columns: columnSnapshots, Constraints: constraintSnapshots, diff --git a/pkg/rain/ddl.go b/pkg/rain/ddl.go index b17b0fb..e3671e5 100644 --- a/pkg/rain/ddl.go +++ b/pkg/rain/ddl.go @@ -140,6 +140,10 @@ func createTableSQL(d dialect.Dialect, table *schema.TableDef) (string, error) { return "", errors.New("rain: create table requires a non-nil table") } + if table.IsView { + return createViewSQL(d, table) + } + var definitions []string tablePrimaryKey, err := tablePrimaryKeyConstraint(table) if err != nil { @@ -349,7 +353,7 @@ func columnTypeSQL(d dialect.Dialect, column *schema.ColumnDef) string { typeSQL = fmt.Sprintf("%s(%d)", typeSQL, column.Type.TimePrecision) } - if column.AutoIncrement && d.Name() == "sqlite" && column.Type.DataType == schema.TypeBigSerial { + if column.AutoIncrement && d.Name() == "sqlite" && (column.Type.DataType == schema.TypeSmallSerial || column.Type.DataType == schema.TypeSerial || column.Type.DataType == schema.TypeBigSerial) { return "INTEGER" } @@ -363,7 +367,7 @@ func shouldEmitAutoIncrementKeyword(d dialect.Dialect, column *schema.ColumnDef, if !inlinePrimaryKey { return false } - if column.Type.DataType != schema.TypeBigSerial { + if column.Type.DataType != schema.TypeSmallSerial && column.Type.DataType != schema.TypeSerial && column.Type.DataType != schema.TypeBigSerial { return true } @@ -568,8 +572,37 @@ func predicateDDLSQL(d dialect.Dialect, table *schema.TableDef, predicate schema return expressionDDLSQL(d, table, predicate) } +func createViewSQL(d dialect.Dialect, table *schema.TableDef) (string, error) { + if table.ViewQuery == nil { + return "", fmt.Errorf("rain: view %q requires a defining query", table.Name) + } + + ctx := newCompileContext(d) + ctx.useLiterals = true + // Views usually don't support or need parentheses around the entire SELECT + // across all dialects, and SQLite specifically rejects them. + if selectQuery, ok := table.ViewQuery.(*SelectQuery); ok { + if err := selectQuery.writeSQL(ctx); err != nil { + return "", err + } + } else { + if err := ctx.writeExpression(table.ViewQuery); err != nil { + return "", err + } + } + + return "CREATE VIEW " + d.QuoteIdentifier(table.Name) + " AS " + ctx.String(), nil +} + func expressionDDLSQL(d dialect.Dialect, table *schema.TableDef, expr schema.Expression) (string, error) { switch value := expr.(type) { + case *SelectQuery: + ctx := newCompileContext(d) + ctx.useLiterals = true + if err := value.writeSQL(ctx); err != nil { + return "", err + } + return ctx.String(), nil case schema.ColumnReference: column := value.ColumnDef() if column == nil { diff --git a/pkg/rain/model_binding.go b/pkg/rain/model_binding.go index 21c6de5..0282a97 100644 --- a/pkg/rain/model_binding.go +++ b/pkg/rain/model_binding.go @@ -207,7 +207,7 @@ func supportsScanForColumn(column *schema.ColumnDef, fieldType reflect.Type) boo baseType, _ := unwrapFieldType(fieldType) switch column.Type.DataType { - case schema.TypeBigSerial, schema.TypeSmallInt, schema.TypeInteger, schema.TypeBigInt: + case schema.TypeSmallSerial, schema.TypeSerial, schema.TypeBigSerial, schema.TypeSmallInt, schema.TypeInteger, schema.TypeBigInt: return isIntegerKind(baseType.Kind()) case schema.TypeReal, schema.TypeDouble: return baseType.Kind() == reflect.Float32 || baseType.Kind() == reflect.Float64 @@ -242,7 +242,7 @@ func supportsWriteForColumn(column *schema.ColumnDef, fieldType reflect.Type) bo baseType, _ := unwrapFieldType(fieldType) switch column.Type.DataType { - case schema.TypeBigSerial, schema.TypeSmallInt, schema.TypeInteger, schema.TypeBigInt: + case schema.TypeSmallSerial, schema.TypeSerial, schema.TypeBigSerial, schema.TypeSmallInt, schema.TypeInteger, schema.TypeBigInt: return isIntegerKind(baseType.Kind()) case schema.TypeReal, schema.TypeDouble: return baseType.Kind() == reflect.Float32 || baseType.Kind() == reflect.Float64 diff --git a/pkg/rain/query_compile.go b/pkg/rain/query_compile.go index 173f7f3..e737616 100644 --- a/pkg/rain/query_compile.go +++ b/pkg/rain/query_compile.go @@ -76,11 +76,12 @@ func (q compiledQuery) bind(args PreparedArgs) ([]any, error) { } type compileContext struct { - builder strings.Builder - dialect dialect.Dialect - argPlan []compiledArg - err error - skipCTEs bool + builder strings.Builder + dialect dialect.Dialect + argPlan []compiledArg + err error + skipCTEs bool + useLiterals bool } func newCompileContext(d dialect.Dialect) *compileContext { @@ -137,6 +138,15 @@ func (c *compileContext) writeTable(table *schema.TableDef) { } } +func (c *compileContext) writeLiteral(value any) error { + literal, err := literalDDLSQL(c.dialect, value) + if err != nil { + return err + } + c.writeString(literal) + return nil +} + func (c *compileContext) writeReturning(exprs []schema.Expression, clause returningClause) error { if len(exprs) == 0 { return nil @@ -180,9 +190,15 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex case schema.ColumnReference: c.writeColumn(value) case schema.ValueExpr: - index := c.nextPlaceholderIndex() - c.argPlan = append(c.argPlan, compiledArg{kind: compiledArgLiteral, value: value.Value}) - c.writeString(c.dialect.Placeholder(index)) + if c.useLiterals { + if err := c.writeLiteral(value.Value); err != nil { + return err + } + } else { + index := c.nextPlaceholderIndex() + c.argPlan = append(c.argPlan, compiledArg{kind: compiledArgLiteral, value: value.Value}) + c.writeString(c.dialect.Placeholder(index)) + } case schema.PlaceholderExpr: if strings.TrimSpace(value.Name) == "" { return errors.New("rain: placeholder name cannot be empty") @@ -260,9 +276,12 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex if !context.noParens { c.writeByte('(') } + prevSkip := c.skipCTEs + c.skipCTEs = true if err := value.writeSQL(c); err != nil { return err } + c.skipCTEs = prevSkip if !context.noParens { c.writeByte(')') } diff --git a/pkg/rain/query_delete.go b/pkg/rain/query_delete.go index 8728fb9..75681b3 100644 --- a/pkg/rain/query_delete.go +++ b/pkg/rain/query_delete.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "errors" + "fmt" "github.com/hyperlocalise/rain-orm/pkg/dialect" "github.com/hyperlocalise/rain-orm/pkg/schema" @@ -48,6 +49,9 @@ func (q *DeleteQuery) ToSQL() (string, []any, error) { if q.table == nil { return "", nil, errors.New("rain: delete query requires a table") } + if q.table.IsView { + return "", nil, fmt.Errorf("rain: cannot delete from view %q", q.table.Name) + } if len(q.where) == 0 && !q.unbounded { return "", nil, errors.New("rain: delete query requires at least one WHERE predicate; call Unbounded() to allow all rows") } diff --git a/pkg/rain/query_insert.go b/pkg/rain/query_insert.go index dff13c9..3b00254 100644 --- a/pkg/rain/query_insert.go +++ b/pkg/rain/query_insert.go @@ -284,6 +284,9 @@ func (q *InsertQuery) validateSources() error { if q.table == nil { return errors.New("rain: insert query requires a table") } + if q.table.IsView { + return fmt.Errorf("rain: cannot insert into view %q", q.table.Name) + } sources := 0 if q.model != nil || len(q.values) > 0 { diff --git a/pkg/rain/query_update.go b/pkg/rain/query_update.go index 7744a52..f1b260a 100644 --- a/pkg/rain/query_update.go +++ b/pkg/rain/query_update.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "errors" + "fmt" "github.com/hyperlocalise/rain-orm/pkg/dialect" "github.com/hyperlocalise/rain-orm/pkg/schema" @@ -62,6 +63,9 @@ func (q *UpdateQuery) ToSQL() (string, []any, error) { if q.table == nil { return "", nil, errors.New("rain: update query requires a table") } + if q.table.IsView { + return "", nil, fmt.Errorf("rain: cannot update view %q", q.table.Name) + } if len(q.values) == 0 { return "", nil, errors.New("rain: update query requires at least one assignment") } diff --git a/pkg/schema/schema.go b/pkg/schema/schema.go index b6637c2..cad3bf7 100644 --- a/pkg/schema/schema.go +++ b/pkg/schema/schema.go @@ -19,6 +19,8 @@ type TimestampKind string // Supported schema data types. const ( + TypeSmallSerial DataType = "SMALLSERIAL" + TypeSerial DataType = "SERIAL" TypeBigSerial DataType = "BIGSERIAL" TypeSmallInt DataType = "SMALLINT" TypeInteger DataType = "INTEGER" @@ -111,6 +113,8 @@ type ColumnType struct { type TableDef struct { Name string Alias string + IsView bool + ViewQuery Expression Columns []*ColumnDef Indexes []IndexDef Constraints []ConstraintDef @@ -258,6 +262,16 @@ func (t *TableModel) C(name string) *AnyColumn { return &AnyColumn{def: col} } +// SmallSerial adds a SMALLSERIAL column. +func (t *TableModel) SmallSerial(name string) *Column[int16] { + return addColumn[int16](t.def, name, ColumnType{DataType: TypeSmallSerial}, false, true) +} + +// Serial adds a SERIAL column. +func (t *TableModel) Serial(name string) *Column[int32] { + return addColumn[int32](t.def, name, ColumnType{DataType: TypeSerial}, false, true) +} + // BigSerial adds a BIGSERIAL column. func (t *TableModel) BigSerial(name string) *Column[int64] { return addColumn[int64](t.def, name, ColumnType{DataType: TypeBigSerial}, false, true) @@ -483,6 +497,31 @@ func Define[T any](name string, fn func(*T)) *T { return handle } +// DefineView creates a typed view handle backed by schema metadata and a defining query. +func DefineView[T any](name string, query Expression, fn func(*T)) *T { + if query == nil { + panic("schema: DefineView requires a query") + } + + handle := new(T) + def := &TableDef{ + Name: name, + IsView: true, + ViewQuery: query, + Columns: make([]*ColumnDef, 0, 8), + Indexes: make([]IndexDef, 0), + Constraints: make([]ConstraintDef, 0), + ForeignKeys: make([]ForeignKeyDef, 0), + Relations: make([]RelationDef, 0, 4), + columnsByName: make(map[string]*ColumnDef, 8), + relationsByName: make(map[string]RelationDef, 4), + } + bindTableModel(handle, def) + fn(handle) + + return handle +} + // Alias clones a typed table handle with a SQL alias. func Alias[T any](src *T, alias string) *T { clone := new(T) @@ -578,7 +617,7 @@ func (c *Column[T]) ColumnDef() *ColumnDef { func (c *Column[T]) PrimaryKey() *Column[T] { c.def.PrimaryKey = true c.def.Nullable = false - if c.def.Type.DataType == TypeBigSerial { + if c.def.Type.DataType == TypeSmallSerial || c.def.Type.DataType == TypeSerial || c.def.Type.DataType == TypeBigSerial { c.def.AutoIncrement = true } @@ -1384,6 +1423,11 @@ func cloneTableDef(src *TableDef, alias string) *TableDef { relationsByName: make(map[string]RelationDef, len(src.Relations)), } + cloned.IsView = src.IsView + if src.ViewQuery != nil { + cloned.ViewQuery = cloneExpressionForTable(src.ViewQuery, cloned) + } + for _, column := range src.Columns { copyColumn := *column copyColumn.Type.EnumValues = append([]string(nil), column.Type.EnumValues...)