Skip to content

Commit 1120993

Browse files
committed
testhelpers: create MockTraceLogger
Create a 'MockTraceLogger' struct that, like other mocks in 'mock.go', embeds a 'mock.Mock' struct. However, unique to the 'TraceLogger' instance in tests (namely the daemon tests) is that we usually don't want to assert on the logger's behavior; in most cases, we just want it to "do nothing" or pass through. Implement the 'TraceLogger' interface with checks that only invoke 'Mock.Called()' if the method has been mocked at least once (even if it's for args not matching the current ones). If the method is not mocked, fall back on "passthrough" defaults (usually returning the input 'context.Context' and/or the input args). Finally, add an input assertion to 'MockTraceLogger.Error()', to serve as an additional validation in unit tests and mirror the behavior of an actual 'TraceLogger' instance. Signed-off-by: Victoria Dye <[email protected]>
1 parent ce15989 commit 1120993

File tree

3 files changed

+118
-8
lines changed

3 files changed

+118
-8
lines changed

internal/daemon/launchd_test.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ var launchdCreatePlistTests = []struct {
226226

227227
func TestLaunchd_Create(t *testing.T) {
228228
// Set up mocks
229+
testLogger := &MockTraceLogger{}
229230
testUser := &user.User{
230231
Uid: "123",
231232
Username: "testuser",
@@ -240,7 +241,7 @@ func TestLaunchd_Create(t *testing.T) {
240241

241242
ctx := context.Background()
242243

243-
launchd := daemon.NewLaunchdProvider(nil, testUserProvider, testCommandExecutor, testFileSystem)
244+
launchd := daemon.NewLaunchdProvider(testLogger, testUserProvider, testCommandExecutor, testFileSystem)
244245

245246
// Verify launchd commands called
246247
for _, tt := range launchdCreateBehaviorTests {
@@ -356,6 +357,7 @@ func TestLaunchd_Create(t *testing.T) {
356357

357358
func TestLaunchd_Start(t *testing.T) {
358359
// Set up mocks
360+
testLogger := &MockTraceLogger{}
359361
testUser := &user.User{
360362
Uid: "123",
361363
Username: "testuser",
@@ -367,7 +369,7 @@ func TestLaunchd_Start(t *testing.T) {
367369

368370
ctx := context.Background()
369371

370-
launchd := daemon.NewLaunchdProvider(nil, testUserProvider, testCommandExecutor, nil)
372+
launchd := daemon.NewLaunchdProvider(testLogger, testUserProvider, testCommandExecutor, nil)
371373

372374
// Test #1: launchctl succeeds
373375
t.Run("Calls correct launchctl command", func(t *testing.T) {
@@ -437,6 +439,7 @@ var launchdStopTests = []struct {
437439

438440
func TestLaunchd_Stop(t *testing.T) {
439441
// Set up mocks
442+
testLogger := &MockTraceLogger{}
440443
testUser := &user.User{
441444
Uid: "123",
442445
Username: "testuser",
@@ -448,7 +451,7 @@ func TestLaunchd_Stop(t *testing.T) {
448451

449452
ctx := context.Background()
450453

451-
launchd := daemon.NewLaunchdProvider(nil, testUserProvider, testCommandExecutor, nil)
454+
launchd := daemon.NewLaunchdProvider(testLogger, testUserProvider, testCommandExecutor, nil)
452455

453456
for _, tt := range launchdStopTests {
454457
t.Run(tt.title, func(t *testing.T) {
@@ -527,6 +530,7 @@ var launchdRemoveTests = []struct {
527530

528531
func TestLaunchd_Remove(t *testing.T) {
529532
// Set up mocks
533+
testLogger := &MockTraceLogger{}
530534
testUser := &user.User{
531535
Uid: "123",
532536
Username: "testuser",
@@ -540,7 +544,7 @@ func TestLaunchd_Remove(t *testing.T) {
540544

541545
ctx := context.Background()
542546

543-
launchd := daemon.NewLaunchdProvider(nil, testUserProvider, testCommandExecutor, testFileSystem)
547+
launchd := daemon.NewLaunchdProvider(testLogger, testUserProvider, testCommandExecutor, testFileSystem)
544548

545549
for _, tt := range launchdRemoveTests {
546550
t.Run(tt.title, func(t *testing.T) {

internal/daemon/systemd_test.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ var systemdCreateServiceUnitTests = []struct {
126126

127127
func TestSystemd_Create(t *testing.T) {
128128
// Set up mocks
129+
testLogger := &MockTraceLogger{}
129130
testUser := &user.User{
130131
Uid: "123",
131132
Username: "testuser",
@@ -140,7 +141,7 @@ func TestSystemd_Create(t *testing.T) {
140141

141142
ctx := context.Background()
142143

143-
systemd := daemon.NewSystemdProvider(nil, testUserProvider, testCommandExecutor, testFileSystem)
144+
systemd := daemon.NewSystemdProvider(testLogger, testUserProvider, testCommandExecutor, testFileSystem)
144145

145146
for _, tt := range systemdCreateBehaviorTests {
146147
forceArg := tt.force.ToBoolList()
@@ -236,6 +237,7 @@ func TestSystemd_Create(t *testing.T) {
236237

237238
func TestSystemd_Start(t *testing.T) {
238239
// Set up mocks
240+
testLogger := &MockTraceLogger{}
239241
testUser := &user.User{
240242
Uid: "123",
241243
Username: "testuser",
@@ -248,7 +250,7 @@ func TestSystemd_Start(t *testing.T) {
248250

249251
ctx := context.Background()
250252

251-
systemd := daemon.NewSystemdProvider(nil, testUserProvider, testCommandExecutor, nil)
253+
systemd := daemon.NewSystemdProvider(testLogger, testUserProvider, testCommandExecutor, nil)
252254

253255
// Test #1: systemctl succeeds
254256
t.Run("Calls correct systemctl command", func(t *testing.T) {
@@ -280,6 +282,7 @@ func TestSystemd_Start(t *testing.T) {
280282

281283
func TestSystemd_Stop(t *testing.T) {
282284
// Set up mocks
285+
testLogger := &MockTraceLogger{}
283286
testUser := &user.User{
284287
Uid: "123",
285288
Username: "testuser",
@@ -292,7 +295,7 @@ func TestSystemd_Stop(t *testing.T) {
292295

293296
ctx := context.Background()
294297

295-
systemd := daemon.NewSystemdProvider(nil, testUserProvider, testCommandExecutor, nil)
298+
systemd := daemon.NewSystemdProvider(testLogger, testUserProvider, testCommandExecutor, nil)
296299

297300
// Test #1: systemctl succeeds
298301
t.Run("Calls correct systemctl command", func(t *testing.T) {
@@ -375,6 +378,7 @@ var systemdRemoveTests = []struct {
375378

376379
func TestSystemd_Remove(t *testing.T) {
377380
// Set up mocks
381+
testLogger := &MockTraceLogger{}
378382
testUser := &user.User{
379383
Uid: "123",
380384
Username: "testuser",
@@ -388,7 +392,7 @@ func TestSystemd_Remove(t *testing.T) {
388392

389393
ctx := context.Background()
390394

391-
systemd := daemon.NewSystemdProvider(nil, testUserProvider, testCommandExecutor, testFileSystem)
395+
systemd := daemon.NewSystemdProvider(testLogger, testUserProvider, testCommandExecutor, testFileSystem)
392396

393397
for _, tt := range systemdRemoveTests {
394398
t.Run(tt.title, func(t *testing.T) {

internal/testhelpers/mocks.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,113 @@
11
package testhelpers
22

33
import (
4+
"context"
5+
"fmt"
46
"os/user"
7+
"runtime"
58

69
"github.com/stretchr/testify/mock"
710
)
811

12+
func methodIsMocked(m *mock.Mock) bool {
13+
// Get the calling method name
14+
pc := make([]uintptr, 1)
15+
n := runtime.Callers(1, pc)
16+
if n == 0 {
17+
// No caller found - fall back on "not mocked"
18+
return false
19+
}
20+
caller := runtime.FuncForPC(pc[0] - 1)
21+
if caller == nil {
22+
// Caller not found - fall back on "not mocked"
23+
return false
24+
}
25+
26+
for _, call := range m.ExpectedCalls {
27+
if call.Method == caller.Name() {
28+
return true
29+
}
30+
}
31+
32+
return false
33+
}
34+
35+
type notMocked struct{}
36+
37+
var NotMockedValue notMocked = notMocked{}
38+
39+
func mockWithDefault[T any](args mock.Arguments, index int, defaultValue T) T {
40+
if len(args) <= index {
41+
return defaultValue
42+
}
43+
44+
mockedValue := args.Get(index)
45+
if _, ok := mockedValue.(notMocked); ok {
46+
return defaultValue
47+
}
48+
49+
return mockedValue.(T)
50+
}
51+
52+
type MockTraceLogger struct {
53+
mock.Mock
54+
}
55+
56+
func (l *MockTraceLogger) Region(ctx context.Context, category string, label string) (context.Context, func()) {
57+
fnArgs := mock.Arguments{}
58+
if methodIsMocked(&l.Mock) {
59+
fnArgs = l.Called(ctx, category, label)
60+
}
61+
return mockWithDefault(fnArgs, 0, ctx), mockWithDefault(fnArgs, 1, func() {})
62+
}
63+
64+
func (l *MockTraceLogger) LogCommand(ctx context.Context, commandName string) context.Context {
65+
fnArgs := mock.Arguments{}
66+
if methodIsMocked(&l.Mock) {
67+
fnArgs = l.Called(ctx, commandName)
68+
}
69+
return mockWithDefault(fnArgs, 0, ctx)
70+
}
71+
72+
func (l *MockTraceLogger) Error(ctx context.Context, err error) error {
73+
// Input validation
74+
if err == nil {
75+
panic("err must be nil")
76+
}
77+
78+
fnArgs := mock.Arguments{}
79+
if methodIsMocked(&l.Mock) {
80+
fnArgs = l.Called(ctx, err)
81+
}
82+
return mockWithDefault(fnArgs, 0, err)
83+
}
84+
85+
func (l *MockTraceLogger) Errorf(ctx context.Context, format string, a ...any) error {
86+
fnArgs := mock.Arguments{}
87+
if methodIsMocked(&l.Mock) {
88+
fnArgs = l.Called(ctx, format, a)
89+
}
90+
return mockWithDefault(fnArgs, 0, fmt.Errorf(format, a...))
91+
}
92+
93+
func (l *MockTraceLogger) Exit(ctx context.Context, exitCode int) {
94+
if methodIsMocked(&l.Mock) {
95+
l.Called(ctx, exitCode)
96+
}
97+
}
98+
99+
func (l *MockTraceLogger) Fatal(ctx context.Context, err error) {
100+
if methodIsMocked(&l.Mock) {
101+
l.Called(ctx, err)
102+
}
103+
}
104+
105+
func (l *MockTraceLogger) Fatalf(ctx context.Context, format string, a ...any) {
106+
if methodIsMocked(&l.Mock) {
107+
l.Called(ctx, format, a)
108+
}
109+
}
110+
9111
type MockUserProvider struct {
10112
mock.Mock
11113
}

0 commit comments

Comments
 (0)