Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions callback_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ func init() {
// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
func beforeCreateCallback(scope *Scope) {
if !scope.HasError() {
scope.CallMethod("BeforeSave")
scope.CallMethod(callbackTypeBeforeSave)
}
if !scope.HasError() {
scope.CallMethod("BeforeCreate")
scope.CallMethod(callbackTypeBeforeCreate)
}
}

Expand Down Expand Up @@ -196,9 +196,9 @@ func forceReloadAfterCreateCallback(scope *Scope) {
// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating
func afterCreateCallback(scope *Scope) {
if !scope.HasError() {
scope.CallMethod("AfterCreate")
scope.CallMethod(callbackTypeAfterCreate)
}
if !scope.HasError() {
scope.CallMethod("AfterSave")
scope.CallMethod(callbackTypeAfterSave)
}
}
4 changes: 2 additions & 2 deletions callback_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func beforeDeleteCallback(scope *Scope) {
return
}
if !scope.HasError() {
scope.CallMethod("BeforeDelete")
scope.CallMethod(callbackTypeBeforeDelete)
}
}

Expand Down Expand Up @@ -58,6 +58,6 @@ func deleteCallback(scope *Scope) {
// afterDeleteCallback will invoke `AfterDelete` method after deleting
func afterDeleteCallback(scope *Scope) {
if !scope.HasError() {
scope.CallMethod("AfterDelete")
scope.CallMethod(callbackTypeAfterDelete)
}
}
2 changes: 1 addition & 1 deletion callback_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,6 @@ func queryCallback(scope *Scope) {
// afterQueryCallback will invoke `AfterFind` method after querying
func afterQueryCallback(scope *Scope) {
if !scope.HasError() {
scope.CallMethod("AfterFind")
scope.CallMethod(callbackTypeAfterFind)
}
}
8 changes: 4 additions & 4 deletions callback_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ func beforeUpdateCallback(scope *Scope) {
}
if _, ok := scope.Get("gorm:update_column"); !ok {
if !scope.HasError() {
scope.CallMethod("BeforeSave")
scope.CallMethod(callbackTypeBeforeSave)
}
if !scope.HasError() {
scope.CallMethod("BeforeUpdate")
scope.CallMethod(callbackTypeBeforeUpdate)
}
}
}
Expand Down Expand Up @@ -112,10 +112,10 @@ func updateCallback(scope *Scope) {
func afterUpdateCallback(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok {
if !scope.HasError() {
scope.CallMethod("AfterUpdate")
scope.CallMethod(callbackTypeAfterUpdate)
}
if !scope.HasError() {
scope.CallMethod("AfterSave")
scope.CallMethod(callbackTypeAfterSave)
}
}
}
59 changes: 53 additions & 6 deletions scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,20 @@ import (
"time"
)

type callbackType string

const (
callbackTypeBeforeCreate callbackType = "BeforeCreate"
callbackTypeBeforeUpdate callbackType = "BeforeUpdate"
callbackTypeAfterCreate callbackType = "AfterCreate"
callbackTypeAfterUpdate callbackType = "AfterUpdate"
callbackTypeBeforeSave callbackType = "BeforeSave"
callbackTypeAfterSave callbackType = "AfterSave"
callbackTypeBeforeDelete callbackType = "BeforeDelete"
callbackTypeAfterDelete callbackType = "AfterDelete"
callbackTypeAfterFind callbackType = "AfterFind"
)

// Scope contain current operation's information when you perform any operation on the database
type Scope struct {
Search *search
Expand Down Expand Up @@ -239,17 +253,17 @@ func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
}

// CallMethod call scope value's method, if it is a slice, will call its element's method one by one
func (scope *Scope) CallMethod(methodName string) {
func (scope *Scope) CallMethod(callback callbackType) {
if scope.Value == nil {
return
}

if indirectScopeValue := scope.IndirectValue(); indirectScopeValue.Kind() == reflect.Slice {
for i := 0; i < indirectScopeValue.Len(); i++ {
scope.callMethod(methodName, indirectScopeValue.Index(i))
scope.callMethod(callback, indirectScopeValue.Index(i))
}
} else {
scope.callMethod(methodName, indirectScopeValue)
scope.callMethod(callback, indirectScopeValue)
}
}

Expand Down Expand Up @@ -429,13 +443,46 @@ func (scope *Scope) CommitOrRollback() *Scope {
// Private Methods For *gorm.Scope
////////////////////////////////////////////////////////////////////////////////

func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) {
// This unrolling is needed to show to the compiler the exact set of methods
// that can be used on the modelType.
// Prior to go1.22 any use of MethodByName would cause the linker to
// abandon dead code elimination for the entire binary.
// As of go1.22 the compiler supports one special case of a string constant
// being passed to MethodByName. For enterprise customers or those building
// large binaries, this gives a significant reduction in binary size.
// https://github.com/golang/go/issues/62257
func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value {
switch cbType {
case callbackTypeBeforeCreate:
return modelType.MethodByName(string(callbackTypeBeforeCreate))
case callbackTypeAfterCreate:
return modelType.MethodByName(string(callbackTypeAfterCreate))
case callbackTypeBeforeUpdate:
return modelType.MethodByName(string(callbackTypeBeforeUpdate))
case callbackTypeAfterUpdate:
return modelType.MethodByName(string(callbackTypeAfterUpdate))
case callbackTypeBeforeSave:
return modelType.MethodByName(string(callbackTypeBeforeSave))
case callbackTypeAfterSave:
return modelType.MethodByName(string(callbackTypeAfterSave))
case callbackTypeBeforeDelete:
return modelType.MethodByName(string(callbackTypeBeforeDelete))
case callbackTypeAfterDelete:
return modelType.MethodByName(string(callbackTypeAfterDelete))
case callbackTypeAfterFind:
return modelType.MethodByName(string(callbackTypeAfterFind))
default:
return reflect.ValueOf(nil)
}
}

func (scope *Scope) callMethod(callback callbackType, reflectValue reflect.Value) {
// Only get address from non-pointer
if reflectValue.CanAddr() && reflectValue.Kind() != reflect.Ptr {
reflectValue = reflectValue.Addr()
}

if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() {
if methodValue := callBackToMethodValue(reflectValue, callback); methodValue.IsValid() {
switch method := methodValue.Interface().(type) {
case func():
method()
Expand All @@ -454,7 +501,7 @@ func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) {
scope.Err(method(newDB))
scope.Err(newDB.Error)
default:
scope.Err(fmt.Errorf("unsupported function %v", methodName))
scope.Err(fmt.Errorf("unsupported function %v", callback))
}
}
}
Expand Down