diff --git a/constraint/pkg/client/drivers/rego/schema/schema.go b/constraint/pkg/client/drivers/rego/schema/schema.go index dbb119dd9..440a6f7e1 100644 --- a/constraint/pkg/client/drivers/rego/schema/schema.go +++ b/constraint/pkg/client/drivers/rego/schema/schema.go @@ -54,6 +54,10 @@ func (in *Source) ToUnstructured() map[string]interface{} { // GetSource extracts Source from a templates.Code object. func GetSource(code templates.Code) (*Source, error) { rawCode := code.Source + if rawCode == nil || rawCode.Value == nil { + return nil, fmt.Errorf("%w: source", ErrMissingField) + } + v, ok := rawCode.Value.(map[string]interface{}) if !ok { return nil, ErrBadType diff --git a/constraint/pkg/client/drivers/rego/schema/schema_test.go b/constraint/pkg/client/drivers/rego/schema/schema_test.go index 55bf0f0f0..3f9f915e1 100644 --- a/constraint/pkg/client/drivers/rego/schema/schema_test.go +++ b/constraint/pkg/client/drivers/rego/schema/schema_test.go @@ -1,6 +1,7 @@ package schema import ( + "errors" "testing" "github.com/open-policy-agent/frameworks/constraint/pkg/core/templates" @@ -77,3 +78,25 @@ func TestGetSourceVersions(t *testing.T) { }) } } + +func TestGetSourceMissingSource(t *testing.T) { + testCases := map[string]templates.Code{ + "nil source": { + Engine: Name, + Source: nil, + }, + "nil source value": { + Engine: Name, + Source: &templates.Anything{}, + }, + } + + for name, code := range testCases { + t.Run(name, func(t *testing.T) { + _, err := GetSource(code) + if !errors.Is(err, ErrMissingField) { + t.Fatalf("expected %v, got %v", ErrMissingField, err) + } + }) + } +}