diff --git a/module-transaction.go b/module-transaction.go new file mode 100644 index 0000000..9a3aae6 --- /dev/null +++ b/module-transaction.go @@ -0,0 +1,29 @@ +// Package pam provides a wrapper for the PAM application API. +package pam + +// ModuleTransaction is an interface that a pam module transaction +// should implement. +type ModuleTransaction interface { + SetItem(Item, string) error + GetItem(Item) (string, error) + PutEnv(nameVal string) error + GetEnv(name string) string + GetEnvList() (map[string]string, error) +} + +// ModuleHandlerFunc is a function type used by the ModuleHandler. +type ModuleHandlerFunc func(ModuleTransaction, Flags, []string) error + +// ModuleTransaction is the module-side handle for a PAM transaction. +type moduleTransaction = transactionBase + +// ModuleHandler is an interface for objects that can be used to create +// PAM modules from go. +type ModuleHandler interface { + AcctMgmt(ModuleTransaction, Flags, []string) error + Authenticate(ModuleTransaction, Flags, []string) error + ChangeAuthTok(ModuleTransaction, Flags, []string) error + CloseSession(ModuleTransaction, Flags, []string) error + OpenSession(ModuleTransaction, Flags, []string) error + SetCred(ModuleTransaction, Flags, []string) error +} diff --git a/module-transaction_test.go b/module-transaction_test.go new file mode 100644 index 0000000..d5c7533 --- /dev/null +++ b/module-transaction_test.go @@ -0,0 +1,131 @@ +// Package pam provides a wrapper for the PAM application API. +package pam + +import ( + "errors" + "reflect" + "testing" +) + +func Test_NewNullModuleTransaction(t *testing.T) { + t.Parallel() + mt := moduleTransaction{} + + if mt.handle != nil { + t.Fatalf("unexpected handle value: %v", mt.handle) + } + + if s := Error(mt.lastStatus.Load()); s != success { + t.Fatalf("unexpected status: %v", s) + } + + tests := map[string]struct { + testFunc func(t *testing.T) (any, error) + expectedError error + ignoreError bool + }{ + "GetItem": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return mt.GetItem(Service) + }, + }, + "SetItem": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return nil, mt.SetItem(Service, "foo") + }, + }, + "GetEnv": { + ignoreError: true, + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return mt.GetEnv("foo"), nil + }, + }, + "PutEnv": { + expectedError: ErrAbort, + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return nil, mt.PutEnv("foo=bar") + }, + }, + "GetEnvList": { + expectedError: ErrBuf, + testFunc: func(t *testing.T) (any, error) { + t.Helper() + list, err := mt.GetEnvList() + if len(list) > 0 { + t.Fatalf("unexpected list: %v", list) + } + return nil, err + }, + }, + } + + for name, tc := range tests { + tc := tc + t.Run(name+"-error-check", func(t *testing.T) { + t.Parallel() + data, err := tc.testFunc(t) + + switch d := data.(type) { + case string: + if d != "" { + t.Fatalf("empty value was expected, got %s", d) + } + case interface{}: + if !reflect.ValueOf(d).IsNil() { + t.Fatalf("nil value was expected, got %v", d) + } + default: + if d != nil { + t.Fatalf("nil value was expected, got %v", d) + } + } + + if tc.ignoreError { + return + } + if err == nil { + t.Fatal("error was expected, but got none") + } + + var expectedError error = ErrSystem + if tc.expectedError != nil { + expectedError = tc.expectedError + } + + if !errors.Is(err, expectedError) { + t.Fatalf("status %v was expected, but got %v", + expectedError, err) + } + }) + } + + for name, tc := range tests { + // These can't be parallel - we test a private value that is not thread safe + t.Run(name+"-lastStatus-check", func(t *testing.T) { + mt.lastStatus.Store(99999) + _, err := tc.testFunc(t) + status := Error(mt.lastStatus.Load()) + + if tc.ignoreError { + return + } + if err == nil { + t.Fatal("error was expected, but got none") + } + + expectedStatus := ErrSystem + if tc.expectedError != nil { + errors.As(err, &expectedStatus) + } + + if status != expectedStatus { + t.Fatalf("status %v was expected, but got %d", + expectedStatus, status) + } + }) + } +} diff --git a/transaction.go b/transaction.go index bba4152..7f5b3e5 100644 --- a/transaction.go +++ b/transaction.go @@ -124,11 +124,15 @@ func cbPAMConv(s C.int, msg *C.char, c C.uintptr_t) (*C.char, C.int) { return C.CString(r), success } +// NativeHandle is the type of the native PAM handle for a transaction so that +// it can be exported +type NativeHandle = *C.pam_handle_t + // transactionBase is a handler for a PAM transaction that can be used to // group the operations that can be performed both by the application and the // module side type transactionBase struct { - handle *C.pam_handle_t + handle NativeHandle lastStatus atomic.Int32 }