diff --git a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go index 995e0c2..76cebe8 100644 --- a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go +++ b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go @@ -55,8 +55,12 @@ func (m *integrationTesterModule) handleRequest(authReq *authRequest, r *Request } var args []reflect.Value - for _, arg := range r.ActionArgs { - args = append(args, reflect.ValueOf(arg)) + for i, arg := range r.ActionArgs { + if arg == nil { + args = append(args, reflect.Zero(method.Type().In(i))) + } else { + args = append(args, reflect.ValueOf(arg)) + } } res = &Result{Action: "return"} diff --git a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go index 38e95c3..d17ba51 100644 --- a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go +++ b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go @@ -542,6 +542,94 @@ func Test_Moduler_IntegrationTesterModule(t *testing.T) { return ensureUser(tx, "setup-user") }, }, + "get-data-not-available": { + expectedError: pam.ErrNoModuleData, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetData", "some-data"), + exp: []interface{}{nil, pam.ErrNoModuleData}, + }}, + }, + "set-data-empty-nil": { + expectedError: pam.ErrNoModuleData, + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "", nil), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", ""), + exp: []interface{}{nil, pam.ErrNoModuleData}, + }, + }, + }, + "set-data-empty-to-value": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "", []string{"hello", "world"}), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", ""), + exp: []interface{}{[]string{"hello", "world"}, nil}, + }, + }, + }, + "set-data-to-value": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "some-error-data", + utils.SerializableError{Msg: "An error"}), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", "some-error-data"), + exp: []interface{}{utils.SerializableError{Msg: "An error"}, nil}, + }, + }, + }, + "set-data-to-value-replacing": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "some-data", + utils.SerializableError{Msg: "An error"}), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", "some-data"), + exp: []interface{}{utils.SerializableError{Msg: "An error"}, nil}, + }, + { + r: NewRequest("SetData", "some-data", "Hello"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", "some-data"), + exp: []interface{}{"Hello", nil}, + }, + }, + }, + "set-data-to-value-unset": { + expectedError: pam.ErrNoModuleData, + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "some-data", + utils.SerializableError{Msg: "An error"}), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", "some-data"), + exp: []interface{}{utils.SerializableError{Msg: "An error"}, nil}, + }, + { + r: NewRequest("SetData", "some-data", nil), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", "some-data"), + exp: []interface{}{nil, pam.ErrNoModuleData}, + }, + }, + }, } for name, tc := range tests { @@ -774,6 +862,24 @@ func Test_Moduler_IntegrationTesterModule_Authenticate(t *testing.T) { }, }, }, + "SetData-nil": { + expectedError: pam.ErrSystem, + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "some-data", nil), + exp: []interface{}{pam.ErrSystem}, + }, + }, + }, + "SetData": { + expectedError: pam.ErrSystem, + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "some-data", true), + exp: []interface{}{pam.ErrSystem}, + }, + }, + }, } for name, tc := range tests { diff --git a/module-transaction-mock.go b/module-transaction-mock.go index f00202e..968026a 100644 --- a/module-transaction-mock.go +++ b/module-transaction-mock.go @@ -6,6 +6,7 @@ package pam #cgo CFLAGS: -Wall -std=c99 #include #include +#include */ import "C" @@ -19,6 +20,7 @@ import ( type mockModuleTransactionExpectations struct { UserPrompt string + DataKey string } type mockModuleTransactionReturnedData struct { @@ -33,14 +35,19 @@ type mockModuleTransaction struct { Expectations mockModuleTransactionExpectations RetData mockModuleTransactionReturnedData ConversationHandler ConversationHandler + moduleData map[string]uintptr allocatedData []unsafe.Pointer } func newMockModuleTransaction(m *mockModuleTransaction) *mockModuleTransaction { + m.moduleData = make(map[string]uintptr) runtime.SetFinalizer(m, func(m *mockModuleTransaction) { for _, ptr := range m.allocatedData { C.free(ptr) } + for _, handle := range m.moduleData { + _go_pam_data_cleanup(nil, C.uintptr_t(handle), C.PAM_DATA_SILENT) + } }) return m } @@ -76,6 +83,33 @@ func (m *mockModuleTransaction) getUser(outUser **C.char, prompt *C.char) C.int return C.int(m.RetData.Status) } +func (m *mockModuleTransaction) getData(key *C.char, outHandle *C.uintptr_t) C.int { + goKey := C.GoString(key) + if m.Expectations.DataKey != "" && goKey != m.Expectations.DataKey { + m.T.Fatalf("data key mismatch: %#v vs %#v", goKey, m.Expectations.DataKey) + } + if handle, ok := m.moduleData[goKey]; ok { + *outHandle = C.uintptr_t(handle) + } else { + *outHandle = 0 + } + return C.int(m.RetData.Status) +} + +func (m *mockModuleTransaction) setData(key *C.char, handle C.uintptr_t) C.int { + goKey := C.GoString(key) + if m.Expectations.DataKey != "" && goKey != m.Expectations.DataKey { + m.T.Fatalf("data key mismatch: %#v vs %#v", goKey, m.Expectations.DataKey) + } + if oldHandle, ok := m.moduleData[goKey]; ok { + _go_pam_data_cleanup(nil, C.uintptr_t(oldHandle), C.PAM_DATA_REPLACE) + } + if handle != 0 { + m.moduleData[goKey] = uintptr(handle) + } + return C.int(m.RetData.Status) +} + type mockConversationHandler struct { User string ExpectedMessage string diff --git a/module-transaction.go b/module-transaction.go index a698f42..71419e0 100644 --- a/module-transaction.go +++ b/module-transaction.go @@ -2,18 +2,14 @@ package pam /* -#cgo CFLAGS: -Wall -std=c99 -#cgo LDFLAGS: -lpam - -#include -#include -#include +#include "transaction.h" */ import "C" import ( "errors" "fmt" + "runtime/cgo" "unsafe" ) @@ -26,6 +22,8 @@ type ModuleTransaction interface { GetEnv(name string) string GetEnvList() (map[string]string, error) GetUser(prompt string) (string, error) + SetData(key string, data any) error + GetData(key string) (any, error) } // ModuleHandlerFunc is a function type used by the ModuleHandler. @@ -102,6 +100,8 @@ func (m *moduleTransaction) InvokeHandler(handler ModuleHandlerFunc, type moduleTransactionIface interface { getUser(outUser **C.char, prompt *C.char) C.int + setData(key *C.char, handle C.uintptr_t) C.int + getData(key *C.char, outHandle *C.uintptr_t) C.int } func (m *moduleTransaction) getUser(outUser **C.char, prompt *C.char) C.int { @@ -127,3 +127,56 @@ func (m *moduleTransaction) getUserImpl(iface moduleTransactionIface, func (m *moduleTransaction) GetUser(prompt string) (string, error) { return m.getUserImpl(m, prompt) } + +// SetData allows to save any value in the module data that is preserved +// during the whole time the module is loaded. +func (m *moduleTransaction) SetData(key string, data any) error { + return m.setDataImpl(m, key, data) +} + +func (m *moduleTransaction) setData(key *C.char, handle C.uintptr_t) C.int { + return C.set_data(m.handle, key, handle) +} + +// setDataImpl is the implementation for SetData for testing purposes. +func (m *moduleTransaction) setDataImpl(iface moduleTransactionIface, + key string, data any) error { + var cKey = C.CString(key) + defer C.free(unsafe.Pointer(cKey)) + var handle cgo.Handle + if data != nil { + handle = cgo.NewHandle(data) + } + return m.handlePamStatus(iface.setData(cKey, C.uintptr_t(handle))) +} + +//export _go_pam_data_cleanup +func _go_pam_data_cleanup(h NativeHandle, handle C.uintptr_t, status C.int) { + cgo.Handle(handle).Delete() +} + +// GetData allows to get any value from the module data saved using SetData +// that is preserved across the whole time the module is loaded. +func (m *moduleTransaction) GetData(key string) (any, error) { + return m.getDataImpl(m, key) +} + +func (m *moduleTransaction) getData(key *C.char, outHandle *C.uintptr_t) C.int { + return C.get_data(m.handle, key, outHandle) +} + +// getDataImpl is the implementation for GetData for testing purposes. +func (m *moduleTransaction) getDataImpl(iface moduleTransactionIface, + key string) (any, error) { + var cKey = C.CString(key) + defer C.free(unsafe.Pointer(cKey)) + var handle C.uintptr_t + if err := m.handlePamStatus(iface.getData(cKey, &handle)); err != nil { + return nil, err + } + if goHandle := cgo.Handle(handle); goHandle != cgo.Handle(0) { + return goHandle.Value(), nil + } + + return nil, m.handlePamStatus(C.int(ErrNoModuleData)) +} diff --git a/module-transaction_test.go b/module-transaction_test.go index 7a44fd3..fa4c1be 100644 --- a/module-transaction_test.go +++ b/module-transaction_test.go @@ -8,6 +8,13 @@ import ( "testing" ) +func ensureNoError(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatalf("unexpected error %v", err) + } +} + func Test_NewNullModuleTransaction(t *testing.T) { t.Parallel() mt := moduleTransaction{} @@ -68,6 +75,24 @@ func Test_NewNullModuleTransaction(t *testing.T) { return mt.GetUser("prompt") }, }, + "GetData": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return mt.GetData("some-data") + }, + }, + "SetData": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return nil, mt.SetData("foo", []interface{}{}) + }, + }, + "SetData-nil": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return nil, mt.SetData("foo", nil) + }, + }, } for name, tc := range tests { @@ -313,6 +338,63 @@ func Test_MockModuleTransaction(t *testing.T) { return mt.getUserImpl(mock, "who are you?") }, }, + "GetData-not-available": { + expectedError: ErrNoModuleData, + mockExpectations: mockModuleTransactionExpectations{ + DataKey: "not-available-data"}, + expectedValue: nil, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.getDataImpl(mock, "not-available-data") + }, + }, + "GetData-not-available-other-failure": { + expectedError: ErrBuf, + mockExpectations: mockModuleTransactionExpectations{ + DataKey: "not-available-data"}, + mockRetData: mockModuleTransactionReturnedData{Status: ErrBuf}, + expectedValue: nil, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.getDataImpl(mock, "not-available-data") + }, + }, + "SetData-empty-nil": { + expectedError: ErrNoModuleData, + expectedValue: nil, + testFunc: func(mock *mockModuleTransaction) (any, error) { + ensureNoError(mock.T, mt.setDataImpl(mock, "", nil)) + return mt.getDataImpl(mock, "") + }, + }, + "SetData-empty-to-value": { + expectedValue: []string{"hello", "world"}, + testFunc: func(mock *mockModuleTransaction) (any, error) { + ensureNoError(mock.T, mt.setDataImpl(mock, "", + []string{"hello", "world"})) + return mt.getDataImpl(mock, "") + }, + }, + "SetData-to-value": { + expectedValue: []interface{}{"a string", true, 0.55, errors.New("oh no")}, + mockExpectations: mockModuleTransactionExpectations{ + DataKey: "some-data"}, + testFunc: func(mock *mockModuleTransaction) (any, error) { + ensureNoError(mock.T, mt.setDataImpl(mock, "some-data", + []interface{}{"a string", true, 0.55, errors.New("oh no")})) + return mt.getDataImpl(mock, "some-data") + }, + }, + "SetData-to-value-replacing": { + expectedValue: "just a value", + mockExpectations: mockModuleTransactionExpectations{ + DataKey: "replaced-data"}, + testFunc: func(mock *mockModuleTransaction) (any, error) { + ensureNoError(mock.T, mt.setDataImpl(mock, "replaced-data", + []interface{}{"a string", true, 0.55, errors.New("oh no")})) + ensureNoError(mock.T, mt.setDataImpl(mock, "replaced-data", + "just a value")) + return mt.getDataImpl(mock, "replaced-data") + }, + }, } for name, tc := range tests { diff --git a/transaction.h b/transaction.h index 88d2766..b19ce3e 100644 --- a/transaction.h +++ b/transaction.h @@ -1,4 +1,7 @@ +#pragma once + #include +#include #include #include #include @@ -18,6 +21,7 @@ #endif extern int _go_pam_conv_handler(struct pam_message *, uintptr_t, char **reply); +extern void _go_pam_data_cleanup(pam_handle_t *, uintptr_t, int status); static inline int cb_pam_conv(int num_msg, PAM_CONST struct pam_message **msg, struct pam_response **resp, void *appdata_ptr) { @@ -67,3 +71,21 @@ static inline int check_pam_start_confdir(void) return 0; } + +static inline void data_cleanup(pam_handle_t *pamh, void *data, int error_status) +{ + _go_pam_data_cleanup(pamh, (uintptr_t)data, error_status); +} + +static inline int set_data(pam_handle_t *pamh, const char *name, uintptr_t handle) +{ + if (handle) + return pam_set_data(pamh, name, (void *)handle, data_cleanup); + + return pam_set_data(pamh, name, NULL, NULL); +} + +static inline int get_data(pam_handle_t *pamh, const char *name, uintptr_t *out_handle) +{ + return pam_get_data(pamh, name, (const void **)out_handle); +}