moduler: Move module transaction invoke handling to transaction itself
So we can reduce the generated code and add more unit tests
This commit is contained in:
@@ -257,14 +257,14 @@ func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int,
|
||||
return C.int(pam.ErrIgnore)
|
||||
}
|
||||
|
||||
err := moduleFunc(pam.NewModuleTransaction(pam.NativeHandle(pamh)),
|
||||
pam.Flags(flags), sliceFromArgv(argc, argv))
|
||||
|
||||
mt := pam.NewModuleTransactionInvoker(pam.NativeHandle(pamh))
|
||||
err := mt.InvokeHandler(moduleFunc, pam.Flags(flags),
|
||||
sliceFromArgv(argc, argv))
|
||||
if err == nil {
|
||||
return 0;
|
||||
return 0
|
||||
}
|
||||
|
||||
if (pam.Flags(flags) & pam.Silent) == 0 {
|
||||
if (pam.Flags(flags) & pam.Silent) == 0 && !errors.Is(err, pam.ErrIgnore) {
|
||||
fmt.Fprintf(os.Stderr, "module returned error: %%v\n", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -44,14 +44,14 @@ func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int,
|
||||
return C.int(pam.ErrIgnore)
|
||||
}
|
||||
|
||||
err := moduleFunc(pam.NewModuleTransaction(pam.NativeHandle(pamh)),
|
||||
pam.Flags(flags), sliceFromArgv(argc, argv))
|
||||
|
||||
mt := pam.NewModuleTransactionInvoker(pam.NativeHandle(pamh))
|
||||
err := mt.InvokeHandler(moduleFunc, pam.Flags(flags),
|
||||
sliceFromArgv(argc, argv))
|
||||
if err == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if (pam.Flags(flags) & pam.Silent) == 0 {
|
||||
if (pam.Flags(flags)&pam.Silent) == 0 && !errors.Is(err, pam.ErrIgnore) {
|
||||
fmt.Fprintf(os.Stderr, "module returned error: %v\n", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
// Package pam provides a wrapper for the PAM application API.
|
||||
package pam
|
||||
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ModuleTransaction is an interface that a pam module transaction
|
||||
// should implement.
|
||||
type ModuleTransaction interface {
|
||||
@@ -30,8 +37,55 @@ type ModuleHandler interface {
|
||||
SetCred(ModuleTransaction, Flags, []string) error
|
||||
}
|
||||
|
||||
// NewModuleTransaction allows initializing a transaction invoker from
|
||||
// ModuleTransactionInvoker is an interface that a pam module transaction
|
||||
// should implement to redirect requests from C handlers to go,
|
||||
type ModuleTransactionInvoker interface {
|
||||
ModuleTransaction
|
||||
InvokeHandler(handler ModuleHandlerFunc, flags Flags, args []string) error
|
||||
}
|
||||
|
||||
// NewModuleTransactionInvoker allows initializing a transaction invoker from
|
||||
// the module side.
|
||||
func NewModuleTransaction(handle NativeHandle) ModuleTransaction {
|
||||
func NewModuleTransactionInvoker(handle NativeHandle) ModuleTransactionInvoker {
|
||||
return &moduleTransaction{transactionBase{handle: handle}}
|
||||
}
|
||||
|
||||
func (m *moduleTransaction) InvokeHandler(handler ModuleHandlerFunc,
|
||||
flags Flags, args []string) error {
|
||||
invoker := func() error {
|
||||
if handler == nil {
|
||||
return ErrIgnore
|
||||
}
|
||||
err := handler(m, flags, args)
|
||||
if err != nil {
|
||||
service, _ := m.GetItem(Service)
|
||||
|
||||
var pamErr Error
|
||||
if !errors.As(err, &pamErr) {
|
||||
err = ErrSystem
|
||||
}
|
||||
|
||||
if pamErr == ErrIgnore || service == "" {
|
||||
return err
|
||||
}
|
||||
|
||||
return fmt.Errorf("%s failed: %w", service, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
err := invoker()
|
||||
if errors.Is(err, Error(0)) {
|
||||
err = nil
|
||||
}
|
||||
var status int32
|
||||
if err != nil {
|
||||
status = int32(ErrSystem)
|
||||
|
||||
var pamErr Error
|
||||
if errors.As(err, &pamErr) {
|
||||
status = int32(pamErr)
|
||||
}
|
||||
}
|
||||
m.lastStatus.Store(status)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ package pam
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -129,3 +130,108 @@ func Test_NewNullModuleTransaction(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ModuleTransaction_InvokeHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
mt := &moduleTransaction{}
|
||||
|
||||
err := mt.InvokeHandler(nil, 0, nil)
|
||||
if !errors.Is(err, ErrIgnore) {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
|
||||
tests := map[string]struct {
|
||||
flags Flags
|
||||
args []string
|
||||
returnedError error
|
||||
expectedError error
|
||||
expectedErrorMsg string
|
||||
}{
|
||||
"success": {
|
||||
expectedError: nil,
|
||||
},
|
||||
"success-with-flags": {
|
||||
expectedError: nil,
|
||||
flags: Silent | RefreshCred,
|
||||
},
|
||||
"success-with-args": {
|
||||
expectedError: nil,
|
||||
args: []string{"foo", "bar"},
|
||||
},
|
||||
"success-with-args-and-flags": {
|
||||
expectedError: nil,
|
||||
flags: Silent | RefreshCred,
|
||||
args: []string{"foo", "bar"},
|
||||
},
|
||||
"ignore": {
|
||||
expectedError: ErrIgnore,
|
||||
returnedError: ErrIgnore,
|
||||
},
|
||||
"ignore-with-args-and-flags": {
|
||||
expectedError: ErrIgnore,
|
||||
returnedError: ErrIgnore,
|
||||
args: []string{"foo", "bar"},
|
||||
},
|
||||
"generic-error": {
|
||||
expectedError: ErrSystem,
|
||||
returnedError: errors.New("this is a generic go error"),
|
||||
expectedErrorMsg: "this is a generic go error",
|
||||
},
|
||||
"transaction-error-service-error": {
|
||||
expectedError: ErrService,
|
||||
returnedError: errors.Join(ErrService, errors.New("ErrService")),
|
||||
expectedErrorMsg: ErrService.Error(),
|
||||
},
|
||||
"return-type-as-error-success": {
|
||||
expectedError: nil,
|
||||
returnedError: Error(0),
|
||||
},
|
||||
"return-type-as-error": {
|
||||
expectedError: ErrNoModuleData,
|
||||
returnedError: ErrNoModuleData,
|
||||
expectedErrorMsg: ErrNoModuleData.Error(),
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := mt.InvokeHandler(func(handlerMt ModuleTransaction,
|
||||
handlerFlags Flags, handlerArgs []string) error {
|
||||
if handlerMt != mt {
|
||||
t.Fatalf("unexpected mt: %#v vs %#v", mt, handlerMt)
|
||||
}
|
||||
if handlerFlags != tc.flags {
|
||||
t.Fatalf("unexpected mt: %#v vs %#v", tc.flags, handlerFlags)
|
||||
}
|
||||
if strings.Join(handlerArgs, "") != strings.Join(tc.args, "") {
|
||||
t.Fatalf("unexpected mt: %#v vs %#v", tc.args, handlerArgs)
|
||||
}
|
||||
|
||||
return tc.returnedError
|
||||
}, tc.flags, tc.args)
|
||||
|
||||
status := Error(mt.lastStatus.Load())
|
||||
|
||||
if !errors.Is(err, tc.expectedError) {
|
||||
t.Fatalf("unexpected err: %#v vs %#v", err, tc.expectedError)
|
||||
}
|
||||
|
||||
var expectedStatus Error
|
||||
if err != nil {
|
||||
var pamErr Error
|
||||
if errors.As(err, &pamErr) {
|
||||
expectedStatus = pamErr
|
||||
} else {
|
||||
expectedStatus = ErrSystem
|
||||
}
|
||||
}
|
||||
|
||||
if status != expectedStatus {
|
||||
t.Fatalf("unexpected status: %#v vs %#v", status, expectedStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user