transaction: Add ModuleTransaction type and ModuleHandler interface
This allows to easily define go-handlers for module operations. We need to expose few more types externally so that it's possible to create the module transaction handler and return specific transaction errors
This commit is contained in:
29
module-transaction.go
Normal file
29
module-transaction.go
Normal file
@@ -0,0 +1,29 @@
|
||||
// Package pam provides a wrapper for the PAM application API.
|
||||
package pam
|
||||
|
||||
// ModuleTransaction is an interface that a pam module transaction
|
||||
// should implement.
|
||||
type ModuleTransaction interface {
|
||||
SetItem(Item, string) error
|
||||
GetItem(Item) (string, error)
|
||||
PutEnv(nameVal string) error
|
||||
GetEnv(name string) string
|
||||
GetEnvList() (map[string]string, error)
|
||||
}
|
||||
|
||||
// ModuleHandlerFunc is a function type used by the ModuleHandler.
|
||||
type ModuleHandlerFunc func(ModuleTransaction, Flags, []string) error
|
||||
|
||||
// ModuleTransaction is the module-side handle for a PAM transaction.
|
||||
type moduleTransaction = transactionBase
|
||||
|
||||
// ModuleHandler is an interface for objects that can be used to create
|
||||
// PAM modules from go.
|
||||
type ModuleHandler interface {
|
||||
AcctMgmt(ModuleTransaction, Flags, []string) error
|
||||
Authenticate(ModuleTransaction, Flags, []string) error
|
||||
ChangeAuthTok(ModuleTransaction, Flags, []string) error
|
||||
CloseSession(ModuleTransaction, Flags, []string) error
|
||||
OpenSession(ModuleTransaction, Flags, []string) error
|
||||
SetCred(ModuleTransaction, Flags, []string) error
|
||||
}
|
||||
131
module-transaction_test.go
Normal file
131
module-transaction_test.go
Normal file
@@ -0,0 +1,131 @@
|
||||
// Package pam provides a wrapper for the PAM application API.
|
||||
package pam
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_NewNullModuleTransaction(t *testing.T) {
|
||||
t.Parallel()
|
||||
mt := moduleTransaction{}
|
||||
|
||||
if mt.handle != nil {
|
||||
t.Fatalf("unexpected handle value: %v", mt.handle)
|
||||
}
|
||||
|
||||
if s := Error(mt.lastStatus.Load()); s != success {
|
||||
t.Fatalf("unexpected status: %v", s)
|
||||
}
|
||||
|
||||
tests := map[string]struct {
|
||||
testFunc func(t *testing.T) (any, error)
|
||||
expectedError error
|
||||
ignoreError bool
|
||||
}{
|
||||
"GetItem": {
|
||||
testFunc: func(t *testing.T) (any, error) {
|
||||
t.Helper()
|
||||
return mt.GetItem(Service)
|
||||
},
|
||||
},
|
||||
"SetItem": {
|
||||
testFunc: func(t *testing.T) (any, error) {
|
||||
t.Helper()
|
||||
return nil, mt.SetItem(Service, "foo")
|
||||
},
|
||||
},
|
||||
"GetEnv": {
|
||||
ignoreError: true,
|
||||
testFunc: func(t *testing.T) (any, error) {
|
||||
t.Helper()
|
||||
return mt.GetEnv("foo"), nil
|
||||
},
|
||||
},
|
||||
"PutEnv": {
|
||||
expectedError: ErrAbort,
|
||||
testFunc: func(t *testing.T) (any, error) {
|
||||
t.Helper()
|
||||
return nil, mt.PutEnv("foo=bar")
|
||||
},
|
||||
},
|
||||
"GetEnvList": {
|
||||
expectedError: ErrBuf,
|
||||
testFunc: func(t *testing.T) (any, error) {
|
||||
t.Helper()
|
||||
list, err := mt.GetEnvList()
|
||||
if len(list) > 0 {
|
||||
t.Fatalf("unexpected list: %v", list)
|
||||
}
|
||||
return nil, err
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name+"-error-check", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
data, err := tc.testFunc(t)
|
||||
|
||||
switch d := data.(type) {
|
||||
case string:
|
||||
if d != "" {
|
||||
t.Fatalf("empty value was expected, got %s", d)
|
||||
}
|
||||
case interface{}:
|
||||
if !reflect.ValueOf(d).IsNil() {
|
||||
t.Fatalf("nil value was expected, got %v", d)
|
||||
}
|
||||
default:
|
||||
if d != nil {
|
||||
t.Fatalf("nil value was expected, got %v", d)
|
||||
}
|
||||
}
|
||||
|
||||
if tc.ignoreError {
|
||||
return
|
||||
}
|
||||
if err == nil {
|
||||
t.Fatal("error was expected, but got none")
|
||||
}
|
||||
|
||||
var expectedError error = ErrSystem
|
||||
if tc.expectedError != nil {
|
||||
expectedError = tc.expectedError
|
||||
}
|
||||
|
||||
if !errors.Is(err, expectedError) {
|
||||
t.Fatalf("status %v was expected, but got %v",
|
||||
expectedError, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
for name, tc := range tests {
|
||||
// These can't be parallel - we test a private value that is not thread safe
|
||||
t.Run(name+"-lastStatus-check", func(t *testing.T) {
|
||||
mt.lastStatus.Store(99999)
|
||||
_, err := tc.testFunc(t)
|
||||
status := Error(mt.lastStatus.Load())
|
||||
|
||||
if tc.ignoreError {
|
||||
return
|
||||
}
|
||||
if err == nil {
|
||||
t.Fatal("error was expected, but got none")
|
||||
}
|
||||
|
||||
expectedStatus := ErrSystem
|
||||
if tc.expectedError != nil {
|
||||
errors.As(err, &expectedStatus)
|
||||
}
|
||||
|
||||
if status != expectedStatus {
|
||||
t.Fatalf("status %v was expected, but got %d",
|
||||
expectedStatus, status)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -124,11 +124,15 @@ func cbPAMConv(s C.int, msg *C.char, c C.uintptr_t) (*C.char, C.int) {
|
||||
return C.CString(r), success
|
||||
}
|
||||
|
||||
// NativeHandle is the type of the native PAM handle for a transaction so that
|
||||
// it can be exported
|
||||
type NativeHandle = *C.pam_handle_t
|
||||
|
||||
// transactionBase is a handler for a PAM transaction that can be used to
|
||||
// group the operations that can be performed both by the application and the
|
||||
// module side
|
||||
type transactionBase struct {
|
||||
handle *C.pam_handle_t
|
||||
handle NativeHandle
|
||||
lastStatus atomic.Int32
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user