From 6f3af6e9b27cdb447a4e6800a2bc731316a60c4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Mon, 25 Sep 2023 23:08:08 +0200 Subject: [PATCH] moduler: Move module transaction invoke handling to transaction itself So we can reduce the generated code and add more unit tests --- cmd/pam-moduler/moduler.go | 10 ++-- example-module/pam_module.go | 8 +-- module-transaction.go | 58 ++++++++++++++++++- module-transaction_test.go | 106 +++++++++++++++++++++++++++++++++++ 4 files changed, 171 insertions(+), 11 deletions(-) diff --git a/cmd/pam-moduler/moduler.go b/cmd/pam-moduler/moduler.go index 94298dc..68f4852 100644 --- a/cmd/pam-moduler/moduler.go +++ b/cmd/pam-moduler/moduler.go @@ -257,14 +257,14 @@ func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int, return C.int(pam.ErrIgnore) } - err := moduleFunc(pam.NewModuleTransaction(pam.NativeHandle(pamh)), - pam.Flags(flags), sliceFromArgv(argc, argv)) - + mt := pam.NewModuleTransactionInvoker(pam.NativeHandle(pamh)) + err := mt.InvokeHandler(moduleFunc, pam.Flags(flags), + sliceFromArgv(argc, argv)) if err == nil { - return 0; + return 0 } - if (pam.Flags(flags) & pam.Silent) == 0 { + if (pam.Flags(flags) & pam.Silent) == 0 && !errors.Is(err, pam.ErrIgnore) { fmt.Fprintf(os.Stderr, "module returned error: %%v\n", err) } diff --git a/example-module/pam_module.go b/example-module/pam_module.go index b13924e..080e97c 100644 --- a/example-module/pam_module.go +++ b/example-module/pam_module.go @@ -44,14 +44,14 @@ func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int, return C.int(pam.ErrIgnore) } - err := moduleFunc(pam.NewModuleTransaction(pam.NativeHandle(pamh)), - pam.Flags(flags), sliceFromArgv(argc, argv)) - + mt := pam.NewModuleTransactionInvoker(pam.NativeHandle(pamh)) + err := mt.InvokeHandler(moduleFunc, pam.Flags(flags), + sliceFromArgv(argc, argv)) if err == nil { return 0 } - if (pam.Flags(flags) & pam.Silent) == 0 { + if (pam.Flags(flags)&pam.Silent) == 0 && !errors.Is(err, pam.ErrIgnore) { fmt.Fprintf(os.Stderr, "module returned error: %v\n", err) } diff --git a/module-transaction.go b/module-transaction.go index 12b3a40..0e87fe5 100644 --- a/module-transaction.go +++ b/module-transaction.go @@ -1,6 +1,13 @@ // Package pam provides a wrapper for the PAM application API. package pam +import "C" + +import ( + "errors" + "fmt" +) + // ModuleTransaction is an interface that a pam module transaction // should implement. type ModuleTransaction interface { @@ -30,8 +37,55 @@ type ModuleHandler interface { SetCred(ModuleTransaction, Flags, []string) error } -// NewModuleTransaction allows initializing a transaction invoker from +// ModuleTransactionInvoker is an interface that a pam module transaction +// should implement to redirect requests from C handlers to go, +type ModuleTransactionInvoker interface { + ModuleTransaction + InvokeHandler(handler ModuleHandlerFunc, flags Flags, args []string) error +} + +// NewModuleTransactionInvoker allows initializing a transaction invoker from // the module side. -func NewModuleTransaction(handle NativeHandle) ModuleTransaction { +func NewModuleTransactionInvoker(handle NativeHandle) ModuleTransactionInvoker { return &moduleTransaction{transactionBase{handle: handle}} } + +func (m *moduleTransaction) InvokeHandler(handler ModuleHandlerFunc, + flags Flags, args []string) error { + invoker := func() error { + if handler == nil { + return ErrIgnore + } + err := handler(m, flags, args) + if err != nil { + service, _ := m.GetItem(Service) + + var pamErr Error + if !errors.As(err, &pamErr) { + err = ErrSystem + } + + if pamErr == ErrIgnore || service == "" { + return err + } + + return fmt.Errorf("%s failed: %w", service, err) + } + return nil + } + err := invoker() + if errors.Is(err, Error(0)) { + err = nil + } + var status int32 + if err != nil { + status = int32(ErrSystem) + + var pamErr Error + if errors.As(err, &pamErr) { + status = int32(pamErr) + } + } + m.lastStatus.Store(status) + return err +} diff --git a/module-transaction_test.go b/module-transaction_test.go index d5c7533..8661f68 100644 --- a/module-transaction_test.go +++ b/module-transaction_test.go @@ -4,6 +4,7 @@ package pam import ( "errors" "reflect" + "strings" "testing" ) @@ -129,3 +130,108 @@ func Test_NewNullModuleTransaction(t *testing.T) { }) } } + +func Test_ModuleTransaction_InvokeHandler(t *testing.T) { + t.Parallel() + mt := &moduleTransaction{} + + err := mt.InvokeHandler(nil, 0, nil) + if !errors.Is(err, ErrIgnore) { + t.Fatalf("unexpected err: %v", err) + } + + tests := map[string]struct { + flags Flags + args []string + returnedError error + expectedError error + expectedErrorMsg string + }{ + "success": { + expectedError: nil, + }, + "success-with-flags": { + expectedError: nil, + flags: Silent | RefreshCred, + }, + "success-with-args": { + expectedError: nil, + args: []string{"foo", "bar"}, + }, + "success-with-args-and-flags": { + expectedError: nil, + flags: Silent | RefreshCred, + args: []string{"foo", "bar"}, + }, + "ignore": { + expectedError: ErrIgnore, + returnedError: ErrIgnore, + }, + "ignore-with-args-and-flags": { + expectedError: ErrIgnore, + returnedError: ErrIgnore, + args: []string{"foo", "bar"}, + }, + "generic-error": { + expectedError: ErrSystem, + returnedError: errors.New("this is a generic go error"), + expectedErrorMsg: "this is a generic go error", + }, + "transaction-error-service-error": { + expectedError: ErrService, + returnedError: errors.Join(ErrService, errors.New("ErrService")), + expectedErrorMsg: ErrService.Error(), + }, + "return-type-as-error-success": { + expectedError: nil, + returnedError: Error(0), + }, + "return-type-as-error": { + expectedError: ErrNoModuleData, + returnedError: ErrNoModuleData, + expectedErrorMsg: ErrNoModuleData.Error(), + }, + } + + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + err := mt.InvokeHandler(func(handlerMt ModuleTransaction, + handlerFlags Flags, handlerArgs []string) error { + if handlerMt != mt { + t.Fatalf("unexpected mt: %#v vs %#v", mt, handlerMt) + } + if handlerFlags != tc.flags { + t.Fatalf("unexpected mt: %#v vs %#v", tc.flags, handlerFlags) + } + if strings.Join(handlerArgs, "") != strings.Join(tc.args, "") { + t.Fatalf("unexpected mt: %#v vs %#v", tc.args, handlerArgs) + } + + return tc.returnedError + }, tc.flags, tc.args) + + status := Error(mt.lastStatus.Load()) + + if !errors.Is(err, tc.expectedError) { + t.Fatalf("unexpected err: %#v vs %#v", err, tc.expectedError) + } + + var expectedStatus Error + if err != nil { + var pamErr Error + if errors.As(err, &pamErr) { + expectedStatus = pamErr + } else { + expectedStatus = ErrSystem + } + } + + if status != expectedStatus { + t.Fatalf("unexpected status: %#v vs %#v", status, expectedStatus) + } + }) + } +}