transaction: Add ModuleTransaction type and ModuleHandler interface

This allows to easily define go-handlers for module operations.

We need to expose few more types externally so that it's possible to
create the module transaction handler and return specific transaction
errors
This commit is contained in:
Marco Trevisan (Treviño)
2023-09-25 18:52:56 +02:00
parent bbc25e137c
commit 11daf4a88d
3 changed files with 165 additions and 1 deletions

29
module-transaction.go Normal file
View File

@@ -0,0 +1,29 @@
// Package pam provides a wrapper for the PAM application API.
package pam
// ModuleTransaction is an interface that a pam module transaction
// should implement.
type ModuleTransaction interface {
SetItem(Item, string) error
GetItem(Item) (string, error)
PutEnv(nameVal string) error
GetEnv(name string) string
GetEnvList() (map[string]string, error)
}
// ModuleHandlerFunc is a function type used by the ModuleHandler.
type ModuleHandlerFunc func(ModuleTransaction, Flags, []string) error
// ModuleTransaction is the module-side handle for a PAM transaction.
type moduleTransaction = transactionBase
// ModuleHandler is an interface for objects that can be used to create
// PAM modules from go.
type ModuleHandler interface {
AcctMgmt(ModuleTransaction, Flags, []string) error
Authenticate(ModuleTransaction, Flags, []string) error
ChangeAuthTok(ModuleTransaction, Flags, []string) error
CloseSession(ModuleTransaction, Flags, []string) error
OpenSession(ModuleTransaction, Flags, []string) error
SetCred(ModuleTransaction, Flags, []string) error
}

131
module-transaction_test.go Normal file
View File

@@ -0,0 +1,131 @@
// Package pam provides a wrapper for the PAM application API.
package pam
import (
"errors"
"reflect"
"testing"
)
func Test_NewNullModuleTransaction(t *testing.T) {
t.Parallel()
mt := moduleTransaction{}
if mt.handle != nil {
t.Fatalf("unexpected handle value: %v", mt.handle)
}
if s := Error(mt.lastStatus.Load()); s != success {
t.Fatalf("unexpected status: %v", s)
}
tests := map[string]struct {
testFunc func(t *testing.T) (any, error)
expectedError error
ignoreError bool
}{
"GetItem": {
testFunc: func(t *testing.T) (any, error) {
t.Helper()
return mt.GetItem(Service)
},
},
"SetItem": {
testFunc: func(t *testing.T) (any, error) {
t.Helper()
return nil, mt.SetItem(Service, "foo")
},
},
"GetEnv": {
ignoreError: true,
testFunc: func(t *testing.T) (any, error) {
t.Helper()
return mt.GetEnv("foo"), nil
},
},
"PutEnv": {
expectedError: ErrAbort,
testFunc: func(t *testing.T) (any, error) {
t.Helper()
return nil, mt.PutEnv("foo=bar")
},
},
"GetEnvList": {
expectedError: ErrBuf,
testFunc: func(t *testing.T) (any, error) {
t.Helper()
list, err := mt.GetEnvList()
if len(list) > 0 {
t.Fatalf("unexpected list: %v", list)
}
return nil, err
},
},
}
for name, tc := range tests {
tc := tc
t.Run(name+"-error-check", func(t *testing.T) {
t.Parallel()
data, err := tc.testFunc(t)
switch d := data.(type) {
case string:
if d != "" {
t.Fatalf("empty value was expected, got %s", d)
}
case interface{}:
if !reflect.ValueOf(d).IsNil() {
t.Fatalf("nil value was expected, got %v", d)
}
default:
if d != nil {
t.Fatalf("nil value was expected, got %v", d)
}
}
if tc.ignoreError {
return
}
if err == nil {
t.Fatal("error was expected, but got none")
}
var expectedError error = ErrSystem
if tc.expectedError != nil {
expectedError = tc.expectedError
}
if !errors.Is(err, expectedError) {
t.Fatalf("status %v was expected, but got %v",
expectedError, err)
}
})
}
for name, tc := range tests {
// These can't be parallel - we test a private value that is not thread safe
t.Run(name+"-lastStatus-check", func(t *testing.T) {
mt.lastStatus.Store(99999)
_, err := tc.testFunc(t)
status := Error(mt.lastStatus.Load())
if tc.ignoreError {
return
}
if err == nil {
t.Fatal("error was expected, but got none")
}
expectedStatus := ErrSystem
if tc.expectedError != nil {
errors.As(err, &expectedStatus)
}
if status != expectedStatus {
t.Fatalf("status %v was expected, but got %d",
expectedStatus, status)
}
})
}
}

View File

@@ -124,11 +124,15 @@ func cbPAMConv(s C.int, msg *C.char, c C.uintptr_t) (*C.char, C.int) {
return C.CString(r), success
}
// NativeHandle is the type of the native PAM handle for a transaction so that
// it can be exported
type NativeHandle = *C.pam_handle_t
// transactionBase is a handler for a PAM transaction that can be used to
// group the operations that can be performed both by the application and the
// module side
type transactionBase struct {
handle *C.pam_handle_t
handle NativeHandle
lastStatus atomic.Int32
}