From 7a073f5ba0df7e934f6a3af36c21d8c8ae5042f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Tue, 3 Oct 2023 14:37:28 +0200 Subject: [PATCH] module-transaction: Add support for setting/getting module data Module data is data associated with a module handle that is available for the whole module loading time so it can be used also during different operations. We use cgo handles to preserve the life of the go objects so any value can be associated with a pam transaction. --- .../integration-tester-module.go | 8 +- .../integration-tester-module_test.go | 106 ++++++++++++++++++ module-transaction-mock.go | 34 ++++++ module-transaction.go | 65 ++++++++++- module-transaction_test.go | 82 ++++++++++++++ transaction.h | 22 ++++ 6 files changed, 309 insertions(+), 8 deletions(-) diff --git a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go index 995e0c2..76cebe8 100644 --- a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go +++ b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go @@ -55,8 +55,12 @@ func (m *integrationTesterModule) handleRequest(authReq *authRequest, r *Request } var args []reflect.Value - for _, arg := range r.ActionArgs { - args = append(args, reflect.ValueOf(arg)) + for i, arg := range r.ActionArgs { + if arg == nil { + args = append(args, reflect.Zero(method.Type().In(i))) + } else { + args = append(args, reflect.ValueOf(arg)) + } } res = &Result{Action: "return"} diff --git a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go index 38e95c3..d17ba51 100644 --- a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go +++ b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go @@ -542,6 +542,94 @@ func Test_Moduler_IntegrationTesterModule(t *testing.T) { return ensureUser(tx, "setup-user") }, }, + "get-data-not-available": { + expectedError: pam.ErrNoModuleData, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetData", "some-data"), + exp: []interface{}{nil, pam.ErrNoModuleData}, + }}, + }, + "set-data-empty-nil": { + expectedError: pam.ErrNoModuleData, + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "", nil), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", ""), + exp: []interface{}{nil, pam.ErrNoModuleData}, + }, + }, + }, + "set-data-empty-to-value": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "", []string{"hello", "world"}), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", ""), + exp: []interface{}{[]string{"hello", "world"}, nil}, + }, + }, + }, + "set-data-to-value": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "some-error-data", + utils.SerializableError{Msg: "An error"}), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", "some-error-data"), + exp: []interface{}{utils.SerializableError{Msg: "An error"}, nil}, + }, + }, + }, + "set-data-to-value-replacing": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "some-data", + utils.SerializableError{Msg: "An error"}), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", "some-data"), + exp: []interface{}{utils.SerializableError{Msg: "An error"}, nil}, + }, + { + r: NewRequest("SetData", "some-data", "Hello"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", "some-data"), + exp: []interface{}{"Hello", nil}, + }, + }, + }, + "set-data-to-value-unset": { + expectedError: pam.ErrNoModuleData, + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "some-data", + utils.SerializableError{Msg: "An error"}), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", "some-data"), + exp: []interface{}{utils.SerializableError{Msg: "An error"}, nil}, + }, + { + r: NewRequest("SetData", "some-data", nil), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", "some-data"), + exp: []interface{}{nil, pam.ErrNoModuleData}, + }, + }, + }, } for name, tc := range tests { @@ -774,6 +862,24 @@ func Test_Moduler_IntegrationTesterModule_Authenticate(t *testing.T) { }, }, }, + "SetData-nil": { + expectedError: pam.ErrSystem, + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "some-data", nil), + exp: []interface{}{pam.ErrSystem}, + }, + }, + }, + "SetData": { + expectedError: pam.ErrSystem, + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "some-data", true), + exp: []interface{}{pam.ErrSystem}, + }, + }, + }, } for name, tc := range tests { diff --git a/module-transaction-mock.go b/module-transaction-mock.go index f00202e..968026a 100644 --- a/module-transaction-mock.go +++ b/module-transaction-mock.go @@ -6,6 +6,7 @@ package pam #cgo CFLAGS: -Wall -std=c99 #include #include +#include */ import "C" @@ -19,6 +20,7 @@ import ( type mockModuleTransactionExpectations struct { UserPrompt string + DataKey string } type mockModuleTransactionReturnedData struct { @@ -33,14 +35,19 @@ type mockModuleTransaction struct { Expectations mockModuleTransactionExpectations RetData mockModuleTransactionReturnedData ConversationHandler ConversationHandler + moduleData map[string]uintptr allocatedData []unsafe.Pointer } func newMockModuleTransaction(m *mockModuleTransaction) *mockModuleTransaction { + m.moduleData = make(map[string]uintptr) runtime.SetFinalizer(m, func(m *mockModuleTransaction) { for _, ptr := range m.allocatedData { C.free(ptr) } + for _, handle := range m.moduleData { + _go_pam_data_cleanup(nil, C.uintptr_t(handle), C.PAM_DATA_SILENT) + } }) return m } @@ -76,6 +83,33 @@ func (m *mockModuleTransaction) getUser(outUser **C.char, prompt *C.char) C.int return C.int(m.RetData.Status) } +func (m *mockModuleTransaction) getData(key *C.char, outHandle *C.uintptr_t) C.int { + goKey := C.GoString(key) + if m.Expectations.DataKey != "" && goKey != m.Expectations.DataKey { + m.T.Fatalf("data key mismatch: %#v vs %#v", goKey, m.Expectations.DataKey) + } + if handle, ok := m.moduleData[goKey]; ok { + *outHandle = C.uintptr_t(handle) + } else { + *outHandle = 0 + } + return C.int(m.RetData.Status) +} + +func (m *mockModuleTransaction) setData(key *C.char, handle C.uintptr_t) C.int { + goKey := C.GoString(key) + if m.Expectations.DataKey != "" && goKey != m.Expectations.DataKey { + m.T.Fatalf("data key mismatch: %#v vs %#v", goKey, m.Expectations.DataKey) + } + if oldHandle, ok := m.moduleData[goKey]; ok { + _go_pam_data_cleanup(nil, C.uintptr_t(oldHandle), C.PAM_DATA_REPLACE) + } + if handle != 0 { + m.moduleData[goKey] = uintptr(handle) + } + return C.int(m.RetData.Status) +} + type mockConversationHandler struct { User string ExpectedMessage string diff --git a/module-transaction.go b/module-transaction.go index a698f42..71419e0 100644 --- a/module-transaction.go +++ b/module-transaction.go @@ -2,18 +2,14 @@ package pam /* -#cgo CFLAGS: -Wall -std=c99 -#cgo LDFLAGS: -lpam - -#include -#include -#include +#include "transaction.h" */ import "C" import ( "errors" "fmt" + "runtime/cgo" "unsafe" ) @@ -26,6 +22,8 @@ type ModuleTransaction interface { GetEnv(name string) string GetEnvList() (map[string]string, error) GetUser(prompt string) (string, error) + SetData(key string, data any) error + GetData(key string) (any, error) } // ModuleHandlerFunc is a function type used by the ModuleHandler. @@ -102,6 +100,8 @@ func (m *moduleTransaction) InvokeHandler(handler ModuleHandlerFunc, type moduleTransactionIface interface { getUser(outUser **C.char, prompt *C.char) C.int + setData(key *C.char, handle C.uintptr_t) C.int + getData(key *C.char, outHandle *C.uintptr_t) C.int } func (m *moduleTransaction) getUser(outUser **C.char, prompt *C.char) C.int { @@ -127,3 +127,56 @@ func (m *moduleTransaction) getUserImpl(iface moduleTransactionIface, func (m *moduleTransaction) GetUser(prompt string) (string, error) { return m.getUserImpl(m, prompt) } + +// SetData allows to save any value in the module data that is preserved +// during the whole time the module is loaded. +func (m *moduleTransaction) SetData(key string, data any) error { + return m.setDataImpl(m, key, data) +} + +func (m *moduleTransaction) setData(key *C.char, handle C.uintptr_t) C.int { + return C.set_data(m.handle, key, handle) +} + +// setDataImpl is the implementation for SetData for testing purposes. +func (m *moduleTransaction) setDataImpl(iface moduleTransactionIface, + key string, data any) error { + var cKey = C.CString(key) + defer C.free(unsafe.Pointer(cKey)) + var handle cgo.Handle + if data != nil { + handle = cgo.NewHandle(data) + } + return m.handlePamStatus(iface.setData(cKey, C.uintptr_t(handle))) +} + +//export _go_pam_data_cleanup +func _go_pam_data_cleanup(h NativeHandle, handle C.uintptr_t, status C.int) { + cgo.Handle(handle).Delete() +} + +// GetData allows to get any value from the module data saved using SetData +// that is preserved across the whole time the module is loaded. +func (m *moduleTransaction) GetData(key string) (any, error) { + return m.getDataImpl(m, key) +} + +func (m *moduleTransaction) getData(key *C.char, outHandle *C.uintptr_t) C.int { + return C.get_data(m.handle, key, outHandle) +} + +// getDataImpl is the implementation for GetData for testing purposes. +func (m *moduleTransaction) getDataImpl(iface moduleTransactionIface, + key string) (any, error) { + var cKey = C.CString(key) + defer C.free(unsafe.Pointer(cKey)) + var handle C.uintptr_t + if err := m.handlePamStatus(iface.getData(cKey, &handle)); err != nil { + return nil, err + } + if goHandle := cgo.Handle(handle); goHandle != cgo.Handle(0) { + return goHandle.Value(), nil + } + + return nil, m.handlePamStatus(C.int(ErrNoModuleData)) +} diff --git a/module-transaction_test.go b/module-transaction_test.go index 7a44fd3..fa4c1be 100644 --- a/module-transaction_test.go +++ b/module-transaction_test.go @@ -8,6 +8,13 @@ import ( "testing" ) +func ensureNoError(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatalf("unexpected error %v", err) + } +} + func Test_NewNullModuleTransaction(t *testing.T) { t.Parallel() mt := moduleTransaction{} @@ -68,6 +75,24 @@ func Test_NewNullModuleTransaction(t *testing.T) { return mt.GetUser("prompt") }, }, + "GetData": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return mt.GetData("some-data") + }, + }, + "SetData": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return nil, mt.SetData("foo", []interface{}{}) + }, + }, + "SetData-nil": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return nil, mt.SetData("foo", nil) + }, + }, } for name, tc := range tests { @@ -313,6 +338,63 @@ func Test_MockModuleTransaction(t *testing.T) { return mt.getUserImpl(mock, "who are you?") }, }, + "GetData-not-available": { + expectedError: ErrNoModuleData, + mockExpectations: mockModuleTransactionExpectations{ + DataKey: "not-available-data"}, + expectedValue: nil, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.getDataImpl(mock, "not-available-data") + }, + }, + "GetData-not-available-other-failure": { + expectedError: ErrBuf, + mockExpectations: mockModuleTransactionExpectations{ + DataKey: "not-available-data"}, + mockRetData: mockModuleTransactionReturnedData{Status: ErrBuf}, + expectedValue: nil, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.getDataImpl(mock, "not-available-data") + }, + }, + "SetData-empty-nil": { + expectedError: ErrNoModuleData, + expectedValue: nil, + testFunc: func(mock *mockModuleTransaction) (any, error) { + ensureNoError(mock.T, mt.setDataImpl(mock, "", nil)) + return mt.getDataImpl(mock, "") + }, + }, + "SetData-empty-to-value": { + expectedValue: []string{"hello", "world"}, + testFunc: func(mock *mockModuleTransaction) (any, error) { + ensureNoError(mock.T, mt.setDataImpl(mock, "", + []string{"hello", "world"})) + return mt.getDataImpl(mock, "") + }, + }, + "SetData-to-value": { + expectedValue: []interface{}{"a string", true, 0.55, errors.New("oh no")}, + mockExpectations: mockModuleTransactionExpectations{ + DataKey: "some-data"}, + testFunc: func(mock *mockModuleTransaction) (any, error) { + ensureNoError(mock.T, mt.setDataImpl(mock, "some-data", + []interface{}{"a string", true, 0.55, errors.New("oh no")})) + return mt.getDataImpl(mock, "some-data") + }, + }, + "SetData-to-value-replacing": { + expectedValue: "just a value", + mockExpectations: mockModuleTransactionExpectations{ + DataKey: "replaced-data"}, + testFunc: func(mock *mockModuleTransaction) (any, error) { + ensureNoError(mock.T, mt.setDataImpl(mock, "replaced-data", + []interface{}{"a string", true, 0.55, errors.New("oh no")})) + ensureNoError(mock.T, mt.setDataImpl(mock, "replaced-data", + "just a value")) + return mt.getDataImpl(mock, "replaced-data") + }, + }, } for name, tc := range tests { diff --git a/transaction.h b/transaction.h index 88d2766..b19ce3e 100644 --- a/transaction.h +++ b/transaction.h @@ -1,4 +1,7 @@ +#pragma once + #include +#include #include #include #include @@ -18,6 +21,7 @@ #endif extern int _go_pam_conv_handler(struct pam_message *, uintptr_t, char **reply); +extern void _go_pam_data_cleanup(pam_handle_t *, uintptr_t, int status); static inline int cb_pam_conv(int num_msg, PAM_CONST struct pam_message **msg, struct pam_response **resp, void *appdata_ptr) { @@ -67,3 +71,21 @@ static inline int check_pam_start_confdir(void) return 0; } + +static inline void data_cleanup(pam_handle_t *pamh, void *data, int error_status) +{ + _go_pam_data_cleanup(pamh, (uintptr_t)data, error_status); +} + +static inline int set_data(pam_handle_t *pamh, const char *name, uintptr_t handle) +{ + if (handle) + return pam_set_data(pamh, name, (void *)handle, data_cleanup); + + return pam_set_data(pamh, name, NULL, NULL); +} + +static inline int get_data(pam_handle_t *pamh, const char *name, uintptr_t *out_handle) +{ + return pam_get_data(pamh, name, (const void **)out_handle); +}