From 07587423e427a02a0ef0146c0c5b6674154945e4 Mon Sep 17 00:00:00 2001 From: Ilya Priven <1186084+ikonst@users.noreply.github.com> Date: Sat, 15 Mar 2025 02:15:25 +0000 Subject: [PATCH] Call MethodByName with constants --- callback_create.go | 8 +++---- callback_delete.go | 4 ++-- callback_query.go | 2 +- callback_update.go | 8 +++---- scope.go | 59 +++++++++++++++++++++++++++++++++++++++++----- 5 files changed, 64 insertions(+), 17 deletions(-) diff --git a/callback_create.go b/callback_create.go index 59840f863b..de661758b9 100644 --- a/callback_create.go +++ b/callback_create.go @@ -21,10 +21,10 @@ func init() { // beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating func beforeCreateCallback(scope *Scope) { if !scope.HasError() { - scope.CallMethod("BeforeSave") + scope.CallMethod(callbackTypeBeforeSave) } if !scope.HasError() { - scope.CallMethod("BeforeCreate") + scope.CallMethod(callbackTypeBeforeCreate) } } @@ -196,9 +196,9 @@ func forceReloadAfterCreateCallback(scope *Scope) { // afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating func afterCreateCallback(scope *Scope) { if !scope.HasError() { - scope.CallMethod("AfterCreate") + scope.CallMethod(callbackTypeAfterCreate) } if !scope.HasError() { - scope.CallMethod("AfterSave") + scope.CallMethod(callbackTypeAfterSave) } } diff --git a/callback_delete.go b/callback_delete.go index 48b97acbfb..823d846048 100644 --- a/callback_delete.go +++ b/callback_delete.go @@ -21,7 +21,7 @@ func beforeDeleteCallback(scope *Scope) { return } if !scope.HasError() { - scope.CallMethod("BeforeDelete") + scope.CallMethod(callbackTypeBeforeDelete) } } @@ -58,6 +58,6 @@ func deleteCallback(scope *Scope) { // afterDeleteCallback will invoke `AfterDelete` method after deleting func afterDeleteCallback(scope *Scope) { if !scope.HasError() { - scope.CallMethod("AfterDelete") + scope.CallMethod(callbackTypeAfterDelete) } } diff --git a/callback_query.go b/callback_query.go index f756271527..3dd5eb1865 100644 --- a/callback_query.go +++ b/callback_query.go @@ -100,6 +100,6 @@ func queryCallback(scope *Scope) { // afterQueryCallback will invoke `AfterFind` method after querying func afterQueryCallback(scope *Scope) { if !scope.HasError() { - scope.CallMethod("AfterFind") + scope.CallMethod(callbackTypeAfterFind) } } diff --git a/callback_update.go b/callback_update.go index 699e534b96..4b5c22ad46 100644 --- a/callback_update.go +++ b/callback_update.go @@ -39,10 +39,10 @@ func beforeUpdateCallback(scope *Scope) { } if _, ok := scope.Get("gorm:update_column"); !ok { if !scope.HasError() { - scope.CallMethod("BeforeSave") + scope.CallMethod(callbackTypeBeforeSave) } if !scope.HasError() { - scope.CallMethod("BeforeUpdate") + scope.CallMethod(callbackTypeBeforeUpdate) } } } @@ -112,10 +112,10 @@ func updateCallback(scope *Scope) { func afterUpdateCallback(scope *Scope) { if _, ok := scope.Get("gorm:update_column"); !ok { if !scope.HasError() { - scope.CallMethod("AfterUpdate") + scope.CallMethod(callbackTypeAfterUpdate) } if !scope.HasError() { - scope.CallMethod("AfterSave") + scope.CallMethod(callbackTypeAfterSave) } } } diff --git a/scope.go b/scope.go index ea12ee2f17..115f484b10 100644 --- a/scope.go +++ b/scope.go @@ -12,6 +12,20 @@ import ( "time" ) +type callbackType string + +const ( + callbackTypeBeforeCreate callbackType = "BeforeCreate" + callbackTypeBeforeUpdate callbackType = "BeforeUpdate" + callbackTypeAfterCreate callbackType = "AfterCreate" + callbackTypeAfterUpdate callbackType = "AfterUpdate" + callbackTypeBeforeSave callbackType = "BeforeSave" + callbackTypeAfterSave callbackType = "AfterSave" + callbackTypeBeforeDelete callbackType = "BeforeDelete" + callbackTypeAfterDelete callbackType = "AfterDelete" + callbackTypeAfterFind callbackType = "AfterFind" +) + // Scope contain current operation's information when you perform any operation on the database type Scope struct { Search *search @@ -239,17 +253,17 @@ func (scope *Scope) SetColumn(column interface{}, value interface{}) error { } // CallMethod call scope value's method, if it is a slice, will call its element's method one by one -func (scope *Scope) CallMethod(methodName string) { +func (scope *Scope) CallMethod(callback callbackType) { if scope.Value == nil { return } if indirectScopeValue := scope.IndirectValue(); indirectScopeValue.Kind() == reflect.Slice { for i := 0; i < indirectScopeValue.Len(); i++ { - scope.callMethod(methodName, indirectScopeValue.Index(i)) + scope.callMethod(callback, indirectScopeValue.Index(i)) } } else { - scope.callMethod(methodName, indirectScopeValue) + scope.callMethod(callback, indirectScopeValue) } } @@ -429,13 +443,46 @@ func (scope *Scope) CommitOrRollback() *Scope { // Private Methods For *gorm.Scope //////////////////////////////////////////////////////////////////////////////// -func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) { +// This unrolling is needed to show to the compiler the exact set of methods +// that can be used on the modelType. +// Prior to go1.22 any use of MethodByName would cause the linker to +// abandon dead code elimination for the entire binary. +// As of go1.22 the compiler supports one special case of a string constant +// being passed to MethodByName. For enterprise customers or those building +// large binaries, this gives a significant reduction in binary size. +// https://github.com/golang/go/issues/62257 +func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value { + switch cbType { + case callbackTypeBeforeCreate: + return modelType.MethodByName(string(callbackTypeBeforeCreate)) + case callbackTypeAfterCreate: + return modelType.MethodByName(string(callbackTypeAfterCreate)) + case callbackTypeBeforeUpdate: + return modelType.MethodByName(string(callbackTypeBeforeUpdate)) + case callbackTypeAfterUpdate: + return modelType.MethodByName(string(callbackTypeAfterUpdate)) + case callbackTypeBeforeSave: + return modelType.MethodByName(string(callbackTypeBeforeSave)) + case callbackTypeAfterSave: + return modelType.MethodByName(string(callbackTypeAfterSave)) + case callbackTypeBeforeDelete: + return modelType.MethodByName(string(callbackTypeBeforeDelete)) + case callbackTypeAfterDelete: + return modelType.MethodByName(string(callbackTypeAfterDelete)) + case callbackTypeAfterFind: + return modelType.MethodByName(string(callbackTypeAfterFind)) + default: + return reflect.ValueOf(nil) + } +} + +func (scope *Scope) callMethod(callback callbackType, reflectValue reflect.Value) { // Only get address from non-pointer if reflectValue.CanAddr() && reflectValue.Kind() != reflect.Ptr { reflectValue = reflectValue.Addr() } - if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() { + if methodValue := callBackToMethodValue(reflectValue, callback); methodValue.IsValid() { switch method := methodValue.Interface().(type) { case func(): method() @@ -454,7 +501,7 @@ func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) { scope.Err(method(newDB)) scope.Err(newDB.Error) default: - scope.Err(fmt.Errorf("unsupported function %v", methodName)) + scope.Err(fmt.Errorf("unsupported function %v", callback)) } } }