diff --git a/pkg/rain/model.go b/pkg/rain/model.go index 09be31a..0cd2c4a 100644 --- a/pkg/rain/model.go +++ b/pkg/rain/model.go @@ -4,10 +4,12 @@ import ( "database/sql" "errors" "fmt" + "math" "reflect" "strings" "sync" "time" + "unsafe" "github.com/hyperlocalise/rain-orm/pkg/schema" ) @@ -38,6 +40,11 @@ type scanColumnPlan struct { isDirect bool columnDef *schema.ColumnDef fieldType reflect.Type + + // OPTIMIZATION: Bypassing reflection in the hot loop using unsafe offsets. + offset uintptr + kind reflect.Kind + canUseOffset bool } type rowScanScratch struct { @@ -293,18 +300,17 @@ func scanRowsAgainstTableDirect(rows *sql.Rows, dest any, table *schema.TableDef } item := items.Index(n) - var scanTarget reflect.Value if pointerElems { item.Set(reflect.New(structType)) - scanTarget = item.Elem() + if err := scanDirectRow(item.Elem(), plan, scratch); err != nil { + return err + } } else { // Reset existing element to its zero state before reuse to avoid data carry-over. item.Set(zeroElem) - scanTarget = item - } - - if err := scanDirectRow(scanTarget, plan, scratch); err != nil { - return err + if err := scanDirectRowAddr(item.Addr().UnsafePointer(), item, plan, scratch); err != nil { + return err + } } } if err := rows.Err(); err != nil { @@ -330,34 +336,112 @@ func newScanTargets(cols []string) ([]any, []any) { } func scanDirectRow(target reflect.Value, plan *rowScanPlan, scratch *rowScanScratch) error { + baseAddr := target.Addr().UnsafePointer() + return scanDirectRowAddr(baseAddr, target, plan, scratch) +} + +func scanDirectRowAddr(baseAddr unsafe.Pointer, target reflect.Value, plan *rowScanPlan, scratch *rowScanScratch) error { for i := range plan.int64ValueCols { col := &plan.int64ValueCols[i] v := &scratch.ints[col.scratchIndex] if !v.Valid { return fmt.Errorf("rain: cannot assign NULL to non-pointer field %s", col.fieldType) } - var field reflect.Value - if col.isComplex { - var err error - field, err = fieldByIndexAlloc(target, col.fieldIndex) - if err != nil { - return err + + if col.canUseOffset { + ptr := unsafe.Add(baseAddr, col.offset) + switch col.kind { + case reflect.Int64: + *(*int64)(ptr) = v.Int64 + case reflect.Int32: + val := int32(v.Int64) + if int64(val) != v.Int64 { + return fmt.Errorf("rain: value %d overflows field %s", v.Int64, col.fieldType) + } + *(*int32)(ptr) = val + case reflect.Int16: + val := int16(v.Int64) + if int64(val) != v.Int64 { + return fmt.Errorf("rain: value %d overflows field %s", v.Int64, col.fieldType) + } + *(*int16)(ptr) = val + case reflect.Int8: + val := int8(v.Int64) + if int64(val) != v.Int64 { + return fmt.Errorf("rain: value %d overflows field %s", v.Int64, col.fieldType) + } + *(*int8)(ptr) = val + case reflect.Int: + val := int(v.Int64) + if int64(val) != v.Int64 { + return fmt.Errorf("rain: value %d overflows field %s", v.Int64, col.fieldType) + } + *(*int)(ptr) = val + case reflect.Uint64: + if v.Int64 < 0 { + return fmt.Errorf("rain: value %d overflows field %s", v.Int64, col.fieldType) + } + *(*uint64)(ptr) = uint64(v.Int64) + case reflect.Uint32: + if v.Int64 < 0 { + return fmt.Errorf("rain: value %d overflows field %s", v.Int64, col.fieldType) + } + val := uint32(v.Int64) + if uint64(val) != uint64(v.Int64) { + return fmt.Errorf("rain: value %d overflows field %s", v.Int64, col.fieldType) + } + *(*uint32)(ptr) = val + case reflect.Uint16: + if v.Int64 < 0 { + return fmt.Errorf("rain: value %d overflows field %s", v.Int64, col.fieldType) + } + val := uint16(v.Int64) + if uint64(val) != uint64(v.Int64) { + return fmt.Errorf("rain: value %d overflows field %s", v.Int64, col.fieldType) + } + *(*uint16)(ptr) = val + case reflect.Uint8: + if v.Int64 < 0 { + return fmt.Errorf("rain: value %d overflows field %s", v.Int64, col.fieldType) + } + val := uint8(v.Int64) + if uint64(val) != uint64(v.Int64) { + return fmt.Errorf("rain: value %d overflows field %s", v.Int64, col.fieldType) + } + *(*uint8)(ptr) = val + case reflect.Uint: + if v.Int64 < 0 { + return fmt.Errorf("rain: value %d overflows field %s", v.Int64, col.fieldType) + } + val := uint(v.Int64) + if uint64(val) != uint64(v.Int64) { + return fmt.Errorf("rain: value %d overflows field %s", v.Int64, col.fieldType) + } + *(*uint)(ptr) = val + default: + field, err := fieldByIndexAlloc(target, col.fieldIndex) + if err != nil { + return err + } + if err := assignRawValueToField(field, v.Int64); err != nil { + return err + } } - } else { - field = target.Field(col.index0) + continue } - kind := field.Kind() - // OPTIMIZATION: Fast-path Int64 (the most common DB int type) and use range - // checks for other integer types to minimize branch overhead and avoid - // redundant overflow checks for the common path. - if kind == reflect.Int64 { + + field, err := fieldByIndexAlloc(target, col.fieldIndex) + if err != nil { + return err + } + if field.Kind() == reflect.Int64 { field.SetInt(v.Int64) - } else if kind >= reflect.Int && kind <= reflect.Int32 { + } else if field.Kind() >= reflect.Int && field.Kind() <= reflect.Int32 { if field.OverflowInt(v.Int64) { return fmt.Errorf("rain: value %d overflows field %s", v.Int64, field.Type()) } field.SetInt(v.Int64) - } else if kind >= reflect.Uint && kind <= reflect.Uint64 { + } else if field.Kind() >= reflect.Uint && field.Kind() <= reflect.Uint64 { if v.Int64 < 0 || field.OverflowUint(uint64(v.Int64)) { return fmt.Errorf("rain: value %d overflows field %s", v.Int64, field.Type()) } @@ -371,15 +455,9 @@ func scanDirectRow(target reflect.Value, plan *rowScanPlan, scratch *rowScanScra for i := range plan.int64PointerCols { col := &plan.int64PointerCols[i] v := &scratch.ints[col.scratchIndex] - var field reflect.Value - if col.isComplex { - var err error - field, err = fieldByIndexAlloc(target, col.fieldIndex) - if err != nil { - return err - } - } else { - field = target.Field(col.index0) + field, err := fieldByIndexAlloc(target, col.fieldIndex) + if err != nil { + return err } if !v.Valid { field.SetZero() @@ -389,16 +467,14 @@ func scanDirectRow(target reflect.Value, plan *rowScanPlan, scratch *rowScanScra field.Set(reflect.New(field.Type().Elem())) } field = field.Elem() - kind := field.Kind() - // OPTIMIZATION: Fast-path Int64 and use range checks to reduce overhead in the hot loop. - if kind == reflect.Int64 { + if field.Kind() == reflect.Int64 { field.SetInt(v.Int64) - } else if kind >= reflect.Int && kind <= reflect.Int32 { + } else if field.Kind() >= reflect.Int && field.Kind() <= reflect.Int32 { if field.OverflowInt(v.Int64) { return fmt.Errorf("rain: value %d overflows field %s", v.Int64, field.Type()) } field.SetInt(v.Int64) - } else if kind >= reflect.Uint && kind <= reflect.Uint64 { + } else if field.Kind() >= reflect.Uint && field.Kind() <= reflect.Uint64 { if v.Int64 < 0 || field.OverflowUint(uint64(v.Int64)) { return fmt.Errorf("rain: value %d overflows field %s", v.Int64, field.Type()) } @@ -415,15 +491,13 @@ func scanDirectRow(target reflect.Value, plan *rowScanPlan, scratch *rowScanScra if !v.Valid { return fmt.Errorf("rain: cannot assign NULL to non-pointer field %s", col.fieldType) } - var field reflect.Value - if col.isComplex { - var err error - field, err = fieldByIndexAlloc(target, col.fieldIndex) - if err != nil { - return err - } - } else { - field = target.Field(col.index0) + if col.canUseOffset { + *(*string)(unsafe.Add(baseAddr, col.offset)) = v.String + continue + } + field, err := fieldByIndexAlloc(target, col.fieldIndex) + if err != nil { + return err } if field.Kind() == reflect.String { field.SetString(v.String) @@ -436,15 +510,9 @@ func scanDirectRow(target reflect.Value, plan *rowScanPlan, scratch *rowScanScra for i := range plan.stringPointerCols { col := &plan.stringPointerCols[i] v := &scratch.strings[col.scratchIndex] - var field reflect.Value - if col.isComplex { - var err error - field, err = fieldByIndexAlloc(target, col.fieldIndex) - if err != nil { - return err - } - } else { - field = target.Field(col.index0) + field, err := fieldByIndexAlloc(target, col.fieldIndex) + if err != nil { + return err } if !v.Valid { field.SetZero() @@ -468,15 +536,13 @@ func scanDirectRow(target reflect.Value, plan *rowScanPlan, scratch *rowScanScra if !v.Valid { return fmt.Errorf("rain: cannot assign NULL to non-pointer field %s", col.fieldType) } - var field reflect.Value - if col.isComplex { - var err error - field, err = fieldByIndexAlloc(target, col.fieldIndex) - if err != nil { - return err - } - } else { - field = target.Field(col.index0) + if col.canUseOffset { + *(*bool)(unsafe.Add(baseAddr, col.offset)) = v.Bool + continue + } + field, err := fieldByIndexAlloc(target, col.fieldIndex) + if err != nil { + return err } if field.Kind() == reflect.Bool { field.SetBool(v.Bool) @@ -489,15 +555,9 @@ func scanDirectRow(target reflect.Value, plan *rowScanPlan, scratch *rowScanScra for i := range plan.boolPointerCols { col := &plan.boolPointerCols[i] v := &scratch.bools[col.scratchIndex] - var field reflect.Value - if col.isComplex { - var err error - field, err = fieldByIndexAlloc(target, col.fieldIndex) - if err != nil { - return err - } - } else { - field = target.Field(col.index0) + field, err := fieldByIndexAlloc(target, col.fieldIndex) + if err != nil { + return err } if !v.Valid { field.SetZero() @@ -521,15 +581,26 @@ func scanDirectRow(target reflect.Value, plan *rowScanPlan, scratch *rowScanScra if !v.Valid { return fmt.Errorf("rain: cannot assign NULL to non-pointer field %s", col.fieldType) } - var field reflect.Value - if col.isComplex { - var err error - field, err = fieldByIndexAlloc(target, col.fieldIndex) - if err != nil { - return err + if col.canUseOffset { + ptr := unsafe.Add(baseAddr, col.offset) + switch col.kind { + case reflect.Float64: + *(*float64)(ptr) = v.Float64 + case reflect.Float32: + f64 := v.Float64 + if f64 < 0 { + f64 = -f64 + } + if math.MaxFloat32 < f64 && f64 <= math.MaxFloat64 { + return fmt.Errorf("rain: value %f overflows field %s", v.Float64, col.fieldType) + } + *(*float32)(ptr) = float32(v.Float64) } - } else { - field = target.Field(col.index0) + continue + } + field, err := fieldByIndexAlloc(target, col.fieldIndex) + if err != nil { + return err } if field.Kind() == reflect.Float32 || field.Kind() == reflect.Float64 { if field.OverflowFloat(v.Float64) { @@ -545,15 +616,9 @@ func scanDirectRow(target reflect.Value, plan *rowScanPlan, scratch *rowScanScra for i := range plan.float64PointerCols { col := &plan.float64PointerCols[i] v := &scratch.floats[col.scratchIndex] - var field reflect.Value - if col.isComplex { - var err error - field, err = fieldByIndexAlloc(target, col.fieldIndex) - if err != nil { - return err - } - } else { - field = target.Field(col.index0) + field, err := fieldByIndexAlloc(target, col.fieldIndex) + if err != nil { + return err } if !v.Valid { field.SetZero() @@ -580,15 +645,13 @@ func scanDirectRow(target reflect.Value, plan *rowScanPlan, scratch *rowScanScra if !v.Valid { return fmt.Errorf("rain: cannot assign NULL to non-pointer field %s", col.fieldType) } - var field reflect.Value - if col.isComplex { - var err error - field, err = fieldByIndexAlloc(target, col.fieldIndex) - if err != nil { - return err - } - } else { - field = target.Field(col.index0) + if col.canUseOffset { + *(*time.Time)(unsafe.Add(baseAddr, col.offset)) = v.Time + continue + } + field, err := fieldByIndexAlloc(target, col.fieldIndex) + if err != nil { + return err } if field.Type() == reflect.TypeFor[time.Time]() { *field.Addr().Interface().(*time.Time) = v.Time @@ -601,15 +664,9 @@ func scanDirectRow(target reflect.Value, plan *rowScanPlan, scratch *rowScanScra for i := range plan.timePointerCols { col := &plan.timePointerCols[i] v := &scratch.times[col.scratchIndex] - var field reflect.Value - if col.isComplex { - var err error - field, err = fieldByIndexAlloc(target, col.fieldIndex) - if err != nil { - return err - } - } else { - field = target.Field(col.index0) + field, err := fieldByIndexAlloc(target, col.fieldIndex) + if err != nil { + return err } if !v.Valid { field.SetZero() @@ -785,16 +842,25 @@ func newRowScanPlanForColumns(cols []string, modelType reflect.Type, table *sche } var fieldType reflect.Type + var offset uintptr + canUseOffset := true if isComplex { fieldType = modelType for _, i := range fieldInfo.index { if fieldType.Kind() == reflect.Pointer { fieldType = fieldType.Elem() + canUseOffset = false + } + f := fieldType.Field(i) + if canUseOffset { + offset += f.Offset } - fieldType = fieldType.Field(i).Type + fieldType = f.Type } } else { - fieldType = modelType.Field(index0).Type + f := modelType.Field(index0) + offset = f.Offset + fieldType = f.Type } isDirect := !isJSON && isSimpleDirectType(fieldType) @@ -803,15 +869,18 @@ func newRowScanPlanForColumns(cols []string, modelType reflect.Type, table *sche } colPlan := scanColumnPlan{ - columnName: name, - scanIndex: idx, - fieldIndex: fieldInfo.index, - index0: index0, - isComplex: isComplex, - isJSON: isJSON, - isDirect: isDirect, - columnDef: columnDef, - fieldType: fieldType, + columnName: name, + scanIndex: idx, + fieldIndex: fieldInfo.index, + index0: index0, + isComplex: isComplex, + isJSON: isJSON, + isDirect: isDirect, + columnDef: columnDef, + fieldType: fieldType, + offset: offset, + kind: fieldType.Kind(), + canUseOffset: canUseOffset, } if isDirect {