moduler: Move module transaction invoke handling to transaction itself

So we can reduce the generated code and add more unit tests
This commit is contained in:
Marco Trevisan (Treviño)
2023-09-25 23:08:08 +02:00
parent e0e1d2de2c
commit 6f3af6e9b2
4 changed files with 171 additions and 11 deletions

View File

@@ -4,6 +4,7 @@ package pam
import (
"errors"
"reflect"
"strings"
"testing"
)
@@ -129,3 +130,108 @@ func Test_NewNullModuleTransaction(t *testing.T) {
})
}
}
func Test_ModuleTransaction_InvokeHandler(t *testing.T) {
t.Parallel()
mt := &moduleTransaction{}
err := mt.InvokeHandler(nil, 0, nil)
if !errors.Is(err, ErrIgnore) {
t.Fatalf("unexpected err: %v", err)
}
tests := map[string]struct {
flags Flags
args []string
returnedError error
expectedError error
expectedErrorMsg string
}{
"success": {
expectedError: nil,
},
"success-with-flags": {
expectedError: nil,
flags: Silent | RefreshCred,
},
"success-with-args": {
expectedError: nil,
args: []string{"foo", "bar"},
},
"success-with-args-and-flags": {
expectedError: nil,
flags: Silent | RefreshCred,
args: []string{"foo", "bar"},
},
"ignore": {
expectedError: ErrIgnore,
returnedError: ErrIgnore,
},
"ignore-with-args-and-flags": {
expectedError: ErrIgnore,
returnedError: ErrIgnore,
args: []string{"foo", "bar"},
},
"generic-error": {
expectedError: ErrSystem,
returnedError: errors.New("this is a generic go error"),
expectedErrorMsg: "this is a generic go error",
},
"transaction-error-service-error": {
expectedError: ErrService,
returnedError: errors.Join(ErrService, errors.New("ErrService")),
expectedErrorMsg: ErrService.Error(),
},
"return-type-as-error-success": {
expectedError: nil,
returnedError: Error(0),
},
"return-type-as-error": {
expectedError: ErrNoModuleData,
returnedError: ErrNoModuleData,
expectedErrorMsg: ErrNoModuleData.Error(),
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
err := mt.InvokeHandler(func(handlerMt ModuleTransaction,
handlerFlags Flags, handlerArgs []string) error {
if handlerMt != mt {
t.Fatalf("unexpected mt: %#v vs %#v", mt, handlerMt)
}
if handlerFlags != tc.flags {
t.Fatalf("unexpected mt: %#v vs %#v", tc.flags, handlerFlags)
}
if strings.Join(handlerArgs, "") != strings.Join(tc.args, "") {
t.Fatalf("unexpected mt: %#v vs %#v", tc.args, handlerArgs)
}
return tc.returnedError
}, tc.flags, tc.args)
status := Error(mt.lastStatus.Load())
if !errors.Is(err, tc.expectedError) {
t.Fatalf("unexpected err: %#v vs %#v", err, tc.expectedError)
}
var expectedStatus Error
if err != nil {
var pamErr Error
if errors.As(err, &pamErr) {
expectedStatus = pamErr
} else {
expectedStatus = ErrSystem
}
}
if status != expectedStatus {
t.Fatalf("unexpected status: %#v vs %#v", status, expectedStatus)
}
})
}
}