moduler: Move module transaction invoke handling to transaction itself

So we can reduce the generated code and add more unit tests
This commit is contained in:
Marco Trevisan (Treviño)
2023-09-25 23:08:08 +02:00
parent e0e1d2de2c
commit 6f3af6e9b2
4 changed files with 171 additions and 11 deletions

View File

@@ -257,14 +257,14 @@ func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int,
return C.int(pam.ErrIgnore)
}
err := moduleFunc(pam.NewModuleTransaction(pam.NativeHandle(pamh)),
pam.Flags(flags), sliceFromArgv(argc, argv))
mt := pam.NewModuleTransactionInvoker(pam.NativeHandle(pamh))
err := mt.InvokeHandler(moduleFunc, pam.Flags(flags),
sliceFromArgv(argc, argv))
if err == nil {
return 0;
return 0
}
if (pam.Flags(flags) & pam.Silent) == 0 {
if (pam.Flags(flags) & pam.Silent) == 0 && !errors.Is(err, pam.ErrIgnore) {
fmt.Fprintf(os.Stderr, "module returned error: %%v\n", err)
}

View File

@@ -44,14 +44,14 @@ func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int,
return C.int(pam.ErrIgnore)
}
err := moduleFunc(pam.NewModuleTransaction(pam.NativeHandle(pamh)),
pam.Flags(flags), sliceFromArgv(argc, argv))
mt := pam.NewModuleTransactionInvoker(pam.NativeHandle(pamh))
err := mt.InvokeHandler(moduleFunc, pam.Flags(flags),
sliceFromArgv(argc, argv))
if err == nil {
return 0
}
if (pam.Flags(flags) & pam.Silent) == 0 {
if (pam.Flags(flags)&pam.Silent) == 0 && !errors.Is(err, pam.ErrIgnore) {
fmt.Fprintf(os.Stderr, "module returned error: %v\n", err)
}

View File

@@ -1,6 +1,13 @@
// Package pam provides a wrapper for the PAM application API.
package pam
import "C"
import (
"errors"
"fmt"
)
// ModuleTransaction is an interface that a pam module transaction
// should implement.
type ModuleTransaction interface {
@@ -30,8 +37,55 @@ type ModuleHandler interface {
SetCred(ModuleTransaction, Flags, []string) error
}
// NewModuleTransaction allows initializing a transaction invoker from
// ModuleTransactionInvoker is an interface that a pam module transaction
// should implement to redirect requests from C handlers to go,
type ModuleTransactionInvoker interface {
ModuleTransaction
InvokeHandler(handler ModuleHandlerFunc, flags Flags, args []string) error
}
// NewModuleTransactionInvoker allows initializing a transaction invoker from
// the module side.
func NewModuleTransaction(handle NativeHandle) ModuleTransaction {
func NewModuleTransactionInvoker(handle NativeHandle) ModuleTransactionInvoker {
return &moduleTransaction{transactionBase{handle: handle}}
}
func (m *moduleTransaction) InvokeHandler(handler ModuleHandlerFunc,
flags Flags, args []string) error {
invoker := func() error {
if handler == nil {
return ErrIgnore
}
err := handler(m, flags, args)
if err != nil {
service, _ := m.GetItem(Service)
var pamErr Error
if !errors.As(err, &pamErr) {
err = ErrSystem
}
if pamErr == ErrIgnore || service == "" {
return err
}
return fmt.Errorf("%s failed: %w", service, err)
}
return nil
}
err := invoker()
if errors.Is(err, Error(0)) {
err = nil
}
var status int32
if err != nil {
status = int32(ErrSystem)
var pamErr Error
if errors.As(err, &pamErr) {
status = int32(pamErr)
}
}
m.lastStatus.Store(status)
return err
}

View File

@@ -4,6 +4,7 @@ package pam
import (
"errors"
"reflect"
"strings"
"testing"
)
@@ -129,3 +130,108 @@ func Test_NewNullModuleTransaction(t *testing.T) {
})
}
}
func Test_ModuleTransaction_InvokeHandler(t *testing.T) {
t.Parallel()
mt := &moduleTransaction{}
err := mt.InvokeHandler(nil, 0, nil)
if !errors.Is(err, ErrIgnore) {
t.Fatalf("unexpected err: %v", err)
}
tests := map[string]struct {
flags Flags
args []string
returnedError error
expectedError error
expectedErrorMsg string
}{
"success": {
expectedError: nil,
},
"success-with-flags": {
expectedError: nil,
flags: Silent | RefreshCred,
},
"success-with-args": {
expectedError: nil,
args: []string{"foo", "bar"},
},
"success-with-args-and-flags": {
expectedError: nil,
flags: Silent | RefreshCred,
args: []string{"foo", "bar"},
},
"ignore": {
expectedError: ErrIgnore,
returnedError: ErrIgnore,
},
"ignore-with-args-and-flags": {
expectedError: ErrIgnore,
returnedError: ErrIgnore,
args: []string{"foo", "bar"},
},
"generic-error": {
expectedError: ErrSystem,
returnedError: errors.New("this is a generic go error"),
expectedErrorMsg: "this is a generic go error",
},
"transaction-error-service-error": {
expectedError: ErrService,
returnedError: errors.Join(ErrService, errors.New("ErrService")),
expectedErrorMsg: ErrService.Error(),
},
"return-type-as-error-success": {
expectedError: nil,
returnedError: Error(0),
},
"return-type-as-error": {
expectedError: ErrNoModuleData,
returnedError: ErrNoModuleData,
expectedErrorMsg: ErrNoModuleData.Error(),
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
err := mt.InvokeHandler(func(handlerMt ModuleTransaction,
handlerFlags Flags, handlerArgs []string) error {
if handlerMt != mt {
t.Fatalf("unexpected mt: %#v vs %#v", mt, handlerMt)
}
if handlerFlags != tc.flags {
t.Fatalf("unexpected mt: %#v vs %#v", tc.flags, handlerFlags)
}
if strings.Join(handlerArgs, "") != strings.Join(tc.args, "") {
t.Fatalf("unexpected mt: %#v vs %#v", tc.args, handlerArgs)
}
return tc.returnedError
}, tc.flags, tc.args)
status := Error(mt.lastStatus.Load())
if !errors.Is(err, tc.expectedError) {
t.Fatalf("unexpected err: %#v vs %#v", err, tc.expectedError)
}
var expectedStatus Error
if err != nil {
var pamErr Error
if errors.As(err, &pamErr) {
expectedStatus = pamErr
} else {
expectedStatus = ErrSystem
}
}
if status != expectedStatus {
t.Fatalf("unexpected status: %#v vs %#v", status, expectedStatus)
}
})
}
}