moduler: Move module transaction invoke handling to transaction itself
So we can reduce the generated code and add more unit tests
This commit is contained in:
@@ -257,14 +257,14 @@ func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int,
|
|||||||
return C.int(pam.ErrIgnore)
|
return C.int(pam.ErrIgnore)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := moduleFunc(pam.NewModuleTransaction(pam.NativeHandle(pamh)),
|
mt := pam.NewModuleTransactionInvoker(pam.NativeHandle(pamh))
|
||||||
pam.Flags(flags), sliceFromArgv(argc, argv))
|
err := mt.InvokeHandler(moduleFunc, pam.Flags(flags),
|
||||||
|
sliceFromArgv(argc, argv))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return 0;
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
if (pam.Flags(flags) & pam.Silent) == 0 {
|
if (pam.Flags(flags) & pam.Silent) == 0 && !errors.Is(err, pam.ErrIgnore) {
|
||||||
fmt.Fprintf(os.Stderr, "module returned error: %%v\n", err)
|
fmt.Fprintf(os.Stderr, "module returned error: %%v\n", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -44,14 +44,14 @@ func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int,
|
|||||||
return C.int(pam.ErrIgnore)
|
return C.int(pam.ErrIgnore)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := moduleFunc(pam.NewModuleTransaction(pam.NativeHandle(pamh)),
|
mt := pam.NewModuleTransactionInvoker(pam.NativeHandle(pamh))
|
||||||
pam.Flags(flags), sliceFromArgv(argc, argv))
|
err := mt.InvokeHandler(moduleFunc, pam.Flags(flags),
|
||||||
|
sliceFromArgv(argc, argv))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
if (pam.Flags(flags) & pam.Silent) == 0 {
|
if (pam.Flags(flags)&pam.Silent) == 0 && !errors.Is(err, pam.ErrIgnore) {
|
||||||
fmt.Fprintf(os.Stderr, "module returned error: %v\n", err)
|
fmt.Fprintf(os.Stderr, "module returned error: %v\n", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,13 @@
|
|||||||
// Package pam provides a wrapper for the PAM application API.
|
// Package pam provides a wrapper for the PAM application API.
|
||||||
package pam
|
package pam
|
||||||
|
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
// ModuleTransaction is an interface that a pam module transaction
|
// ModuleTransaction is an interface that a pam module transaction
|
||||||
// should implement.
|
// should implement.
|
||||||
type ModuleTransaction interface {
|
type ModuleTransaction interface {
|
||||||
@@ -30,8 +37,55 @@ type ModuleHandler interface {
|
|||||||
SetCred(ModuleTransaction, Flags, []string) error
|
SetCred(ModuleTransaction, Flags, []string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewModuleTransaction allows initializing a transaction invoker from
|
// ModuleTransactionInvoker is an interface that a pam module transaction
|
||||||
|
// should implement to redirect requests from C handlers to go,
|
||||||
|
type ModuleTransactionInvoker interface {
|
||||||
|
ModuleTransaction
|
||||||
|
InvokeHandler(handler ModuleHandlerFunc, flags Flags, args []string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewModuleTransactionInvoker allows initializing a transaction invoker from
|
||||||
// the module side.
|
// the module side.
|
||||||
func NewModuleTransaction(handle NativeHandle) ModuleTransaction {
|
func NewModuleTransactionInvoker(handle NativeHandle) ModuleTransactionInvoker {
|
||||||
return &moduleTransaction{transactionBase{handle: handle}}
|
return &moduleTransaction{transactionBase{handle: handle}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *moduleTransaction) InvokeHandler(handler ModuleHandlerFunc,
|
||||||
|
flags Flags, args []string) error {
|
||||||
|
invoker := func() error {
|
||||||
|
if handler == nil {
|
||||||
|
return ErrIgnore
|
||||||
|
}
|
||||||
|
err := handler(m, flags, args)
|
||||||
|
if err != nil {
|
||||||
|
service, _ := m.GetItem(Service)
|
||||||
|
|
||||||
|
var pamErr Error
|
||||||
|
if !errors.As(err, &pamErr) {
|
||||||
|
err = ErrSystem
|
||||||
|
}
|
||||||
|
|
||||||
|
if pamErr == ErrIgnore || service == "" {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("%s failed: %w", service, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
err := invoker()
|
||||||
|
if errors.Is(err, Error(0)) {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
var status int32
|
||||||
|
if err != nil {
|
||||||
|
status = int32(ErrSystem)
|
||||||
|
|
||||||
|
var pamErr Error
|
||||||
|
if errors.As(err, &pamErr) {
|
||||||
|
status = int32(pamErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.lastStatus.Store(status)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ package pam
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -129,3 +130,108 @@ func Test_NewNullModuleTransaction(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_ModuleTransaction_InvokeHandler(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
mt := &moduleTransaction{}
|
||||||
|
|
||||||
|
err := mt.InvokeHandler(nil, 0, nil)
|
||||||
|
if !errors.Is(err, ErrIgnore) {
|
||||||
|
t.Fatalf("unexpected err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := map[string]struct {
|
||||||
|
flags Flags
|
||||||
|
args []string
|
||||||
|
returnedError error
|
||||||
|
expectedError error
|
||||||
|
expectedErrorMsg string
|
||||||
|
}{
|
||||||
|
"success": {
|
||||||
|
expectedError: nil,
|
||||||
|
},
|
||||||
|
"success-with-flags": {
|
||||||
|
expectedError: nil,
|
||||||
|
flags: Silent | RefreshCred,
|
||||||
|
},
|
||||||
|
"success-with-args": {
|
||||||
|
expectedError: nil,
|
||||||
|
args: []string{"foo", "bar"},
|
||||||
|
},
|
||||||
|
"success-with-args-and-flags": {
|
||||||
|
expectedError: nil,
|
||||||
|
flags: Silent | RefreshCred,
|
||||||
|
args: []string{"foo", "bar"},
|
||||||
|
},
|
||||||
|
"ignore": {
|
||||||
|
expectedError: ErrIgnore,
|
||||||
|
returnedError: ErrIgnore,
|
||||||
|
},
|
||||||
|
"ignore-with-args-and-flags": {
|
||||||
|
expectedError: ErrIgnore,
|
||||||
|
returnedError: ErrIgnore,
|
||||||
|
args: []string{"foo", "bar"},
|
||||||
|
},
|
||||||
|
"generic-error": {
|
||||||
|
expectedError: ErrSystem,
|
||||||
|
returnedError: errors.New("this is a generic go error"),
|
||||||
|
expectedErrorMsg: "this is a generic go error",
|
||||||
|
},
|
||||||
|
"transaction-error-service-error": {
|
||||||
|
expectedError: ErrService,
|
||||||
|
returnedError: errors.Join(ErrService, errors.New("ErrService")),
|
||||||
|
expectedErrorMsg: ErrService.Error(),
|
||||||
|
},
|
||||||
|
"return-type-as-error-success": {
|
||||||
|
expectedError: nil,
|
||||||
|
returnedError: Error(0),
|
||||||
|
},
|
||||||
|
"return-type-as-error": {
|
||||||
|
expectedError: ErrNoModuleData,
|
||||||
|
returnedError: ErrNoModuleData,
|
||||||
|
expectedErrorMsg: ErrNoModuleData.Error(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range tests {
|
||||||
|
tc := tc
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
err := mt.InvokeHandler(func(handlerMt ModuleTransaction,
|
||||||
|
handlerFlags Flags, handlerArgs []string) error {
|
||||||
|
if handlerMt != mt {
|
||||||
|
t.Fatalf("unexpected mt: %#v vs %#v", mt, handlerMt)
|
||||||
|
}
|
||||||
|
if handlerFlags != tc.flags {
|
||||||
|
t.Fatalf("unexpected mt: %#v vs %#v", tc.flags, handlerFlags)
|
||||||
|
}
|
||||||
|
if strings.Join(handlerArgs, "") != strings.Join(tc.args, "") {
|
||||||
|
t.Fatalf("unexpected mt: %#v vs %#v", tc.args, handlerArgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tc.returnedError
|
||||||
|
}, tc.flags, tc.args)
|
||||||
|
|
||||||
|
status := Error(mt.lastStatus.Load())
|
||||||
|
|
||||||
|
if !errors.Is(err, tc.expectedError) {
|
||||||
|
t.Fatalf("unexpected err: %#v vs %#v", err, tc.expectedError)
|
||||||
|
}
|
||||||
|
|
||||||
|
var expectedStatus Error
|
||||||
|
if err != nil {
|
||||||
|
var pamErr Error
|
||||||
|
if errors.As(err, &pamErr) {
|
||||||
|
expectedStatus = pamErr
|
||||||
|
} else {
|
||||||
|
expectedStatus = ErrSystem
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if status != expectedStatus {
|
||||||
|
t.Fatalf("unexpected status: %#v vs %#v", status, expectedStatus)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user