Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ issues:
linters:
- funlen
- noctx
- path: mocktail.go
linters:
- gocyclo
text: "cyclomatic complexity 16 of func `processSingleFile` is high" # The complexity is expected.

output:
show-stats: true
105 changes: 102 additions & 3 deletions mocktail.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ func main() {
}

var exported bool
var sourceFile string
flag.BoolVar(&exported, "e", false, "generate exported mocks")
flag.StringVar(&sourceFile, "source", "", "source file containing interfaces to mock")
flag.Parse()

root := info.Dir
Expand All @@ -58,9 +60,17 @@ func main() {
log.Fatalf("Chdir: %v", err)
}

model, err := walk(root, info.Path)
if err != nil {
log.Fatalf("walk: %v", err)
var model map[string]PackageDesc
if sourceFile != "" {
model, err = processSingleFile(sourceFile, root, info.Path)
if err != nil {
log.Fatalf("process single file: %v", err)
}
} else {
model, err = walk(root, info.Path)
if err != nil {
log.Fatalf("walk: %v", err)
}
}

if len(model) == 0 {
Expand All @@ -73,6 +83,95 @@ func main() {
}
}

// processSingleFile processes a single source file to extract interfaces for mocking.
func processSingleFile(sourceFile, root, moduleName string) (map[string]PackageDesc, error) {
model := make(map[string]PackageDesc)

// Convert to absolute path if relative
if !filepath.IsAbs(sourceFile) {
sourceFile = filepath.Join(os.Getenv("PWD"), sourceFile)
}

// Check if file exists
if _, err := os.Stat(sourceFile); os.IsNotExist(err) {
return nil, fmt.Errorf("source file does not exist: %s", sourceFile)
}

// Get the package path for this file
fileDir := filepath.Dir(sourceFile)

// Load the package
pkgs, err := packages.Load(
&packages.Config{
Mode: packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesInfo | packages.NeedFiles,
Dir: fileDir,
},
".",
)
if err != nil {
return nil, fmt.Errorf("load package: %w", err)
}

if len(pkgs) == 0 {
return model, nil // Return empty model when no packages found
}

pkg := pkgs[0]
if pkg.Types == nil {
relDir, err := filepath.Rel(root, fileDir)
if err != nil {
return nil, fmt.Errorf("get relative directory: %w", err)
}

return nil, fmt.Errorf("package %q has no type information", path.Join(moduleName, relDir))
}

// Find all interfaces in the package (both exported and unexported)
packageDesc := PackageDesc{
Pkg: pkg.Types,
Imports: map[string]struct{}{},
}

scope := pkg.Types.Scope()
for _, name := range scope.Names() {
obj := scope.Lookup(name)
if obj == nil {
continue
}

// Check if it's an interface
if named, ok := obj.Type().(*types.Named); ok {
if interfaceType, ok := named.Underlying().(*types.Interface); ok {
interfaceDesc := InterfaceDesc{Name: name}

// Get all methods from the interface
for i := range interfaceType.NumMethods() {
method := interfaceType.Method(i)
interfaceDesc.Methods = append(interfaceDesc.Methods, method)

// Collect imports needed for this method
for _, imp := range getMethodImports(method, pkg.Types.Path()) {
packageDesc.Imports[imp] = struct{}{}
}
}

if len(interfaceDesc.Methods) > 0 {
packageDesc.Interfaces = append(packageDesc.Interfaces, interfaceDesc)
}
}
}
}

if len(packageDesc.Interfaces) > 0 {
// Use the source file path as the key, but change the filename to match expected output location
outputDir := filepath.Dir(sourceFile)
outputKey := filepath.Join(outputDir, srcMockFile)
model[outputKey] = packageDesc
}

return model, nil
}

//nolint:gocognit,gocyclo // The complexity is expected.
func walk(root, moduleName string) (map[string]PackageDesc, error) {
model := make(map[string]PackageDesc)
Expand Down
221 changes: 219 additions & 2 deletions mocktail_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ import (
"github.com/stretchr/testify/require"
)

const goosWindows = "windows"

func TestMocktail(t *testing.T) {
const testRoot = "./testdata/src"

if runtime.GOOS == "windows" {
if runtime.GOOS == goosWindows {
t.Skip(runtime.GOOS)
}

Expand Down Expand Up @@ -74,7 +76,7 @@ func TestMocktail(t *testing.T) {
func TestMocktail_exported(t *testing.T) {
const testRoot = "./testdata/exported"

if runtime.GOOS == "windows" {
if runtime.GOOS == goosWindows {
t.Skip(runtime.GOOS)
}

Expand Down Expand Up @@ -129,3 +131,218 @@ func TestMocktail_exported(t *testing.T) {
require.NoError(t, err)
}
}

func TestMocktail_source(t *testing.T) {
const testRoot = "./testdata/source"

if runtime.GOOS == goosWindows {
t.Skip(runtime.GOOS)
}

testCases := []struct {
name string
expectedOutput string
extraArgs []string
}{
{
name: "a",
expectedOutput: outputMockFile,
extraArgs: nil,
},
{
name: "b",
expectedOutput: outputExportedMockFile,
extraArgs: []string{"-e"},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
testDir := filepath.Join(testRoot, tc.name)
interfacesFile := filepath.Join(testDir, "interfaces.go")

// Convert to absolute path to avoid path duplication issues
absTestDir, err := filepath.Abs(testDir)
require.NoError(t, err)
absInterfacesFile, err := filepath.Abs(interfacesFile)
require.NoError(t, err)

// Set up environment
t.Setenv("MOCKTAIL_TEST_PATH", absTestDir)

// Build command args
args := []string{"run", "."}
args = append(args, tc.extraArgs...)
args = append(args, "-source="+absInterfacesFile)

// Run mocktail with source parameter
output, err := exec.Command("go", args...).CombinedOutput()
t.Log(string(output))
require.NoError(t, err)

// Check generated file matches golden file
genPath := filepath.Join(testDir, tc.expectedOutput)
t.Cleanup(func() {
_ = os.Remove(genPath)
})

goldenPath := genPath + ".golden"

genBytes, err := os.ReadFile(genPath)
require.NoError(t, err)

goldenBytes, err := os.ReadFile(goldenPath)
require.NoError(t, err)

assert.Equal(t, string(goldenBytes), string(genBytes))

cmd := exec.Command("go", "test", "-v", "./...")
cmd.Dir = testDir

output, err = cmd.CombinedOutput()
t.Log(string(output))
require.NoError(t, err)
})
}
}

func TestProcessSingleFile(t *testing.T) {
t.Parallel()

if runtime.GOOS == goosWindows {
t.Skip(runtime.GOOS)
}

tests := []struct {
name string
sourceFile string
expectedErr bool
expectedIntf int // expected number of interfaces
}{
{
name: "valid_basic_file",
sourceFile: "testdata/source/a/interfaces.go",
expectedErr: false,
expectedIntf: 2, // PiniaColada, shirleyTemple
},
{
name: "valid_exported_file",
sourceFile: "testdata/source/b/interfaces.go",
expectedErr: false,
expectedIntf: 1, // PiniaColada
},
{
name: "nonexistent_file",
sourceFile: "testdata/source/nonexistent.go",
expectedErr: true,
},
{
name: "relative_path",
sourceFile: "./testdata/source/a/interfaces.go",
expectedErr: false,
expectedIntf: 2, // PiniaColada, shirleyTemple
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Convert source file to absolute path to avoid path issues
absSourceFile, err := filepath.Abs(tt.sourceFile)
if !tt.expectedErr {
require.NoError(t, err)
}

// Get the module info for the specific test directory
testDir := filepath.Dir(absSourceFile)
info, err := getModuleInfo(testDir)
if !tt.expectedErr {
require.NoError(t, err)
}

// Test processSingleFile function
model, err := processSingleFile(absSourceFile, info.Dir, info.Path)

if tt.expectedErr {
assert.Error(t, err)
return
}

require.NoError(t, err)

// Should have exactly one entry in the model
assert.Len(t, model, 1)

// Check the number of interfaces found
var totalInterfaces int
for _, pkgDesc := range model {
totalInterfaces += len(pkgDesc.Interfaces)
}
assert.Equal(t, tt.expectedIntf, totalInterfaces)

// Verify interfaces have methods
for _, pkgDesc := range model {
for _, intf := range pkgDesc.Interfaces {
assert.NotEmpty(t, intf.Methods, "Interface %s should have methods", intf.Name)
}
}
})
}
}

func TestProcessSingleFile_InvalidPackage(t *testing.T) {
t.Parallel()

if runtime.GOOS == goosWindows {
t.Skip(runtime.GOOS)
}

// Create a temporary file with invalid Go code
tmpFile, err := os.CreateTemp(t.TempDir(), "invalid_*.go")
require.NoError(t, err)
t.Cleanup(func() {
_ = os.Remove(tmpFile.Name())
})

_, err = tmpFile.WriteString("package invalid\n\n// This is not a valid interface\ntype NotAnInterface struct{}\n")
require.NoError(t, err)
_ = tmpFile.Close()

// Use current directory for temporary file test
cwd, err := os.Getwd()
require.NoError(t, err)
info, err := getModuleInfo(cwd)
require.NoError(t, err)

// Test processSingleFile with file containing no interfaces
model, err := processSingleFile(tmpFile.Name(), info.Dir, info.Path)
require.NoError(t, err)
assert.Empty(t, model, "Should return empty model when no interfaces found")
}

func TestProcessSingleFile_AbsolutePath(t *testing.T) {
t.Parallel()

if runtime.GOOS == goosWindows {
t.Skip(runtime.GOOS)
}

// Test with absolute path
absPath, err := filepath.Abs("testdata/source/a/interfaces.go")
require.NoError(t, err)

// Get module info from the test directory
testDir := filepath.Dir(absPath)
info, err := getModuleInfo(testDir)
require.NoError(t, err)

model, err := processSingleFile(absPath, info.Dir, info.Path)
require.NoError(t, err)

assert.Len(t, model, 1)

var totalInterfaces int
for _, pkgDesc := range model {
totalInterfaces += len(pkgDesc.Interfaces)
}
assert.Equal(t, 2, totalInterfaces)
}
Loading