diff --git a/cmd/pam-moduler/tests/integration-tester-module/communication_test.go b/cmd/pam-moduler/tests/integration-tester-module/communication_test.go index 7abc2e3..7ef01f7 100644 --- a/cmd/pam-moduler/tests/integration-tester-module/communication_test.go +++ b/cmd/pam-moduler/tests/integration-tester-module/communication_test.go @@ -29,7 +29,7 @@ func ensureError(t *testing.T, err error, expected error) { func ensureEqual(t *testing.T, a any, b any) { t.Helper() if !reflect.DeepEqual(a, b) { - t.Fatalf("values mismatch %v vs %v", a, b) + t.Fatalf("values mismatch %#v vs %#v", a, b) } } diff --git a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go index ecde5ce..38e95c3 100644 --- a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go +++ b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go @@ -54,7 +54,8 @@ func (cr *checkedRequest) check(res *Result) error { return cr.r.check(res, cr.exp) } -func ensureItem(tx *pam.Transaction, item pam.Item, expected string) error { +func ensureUser(tx *pam.Transaction, expected string) error { + item := pam.User if value, err := tx.GetItem(item); err != nil { return err } else if value != expected { @@ -152,7 +153,7 @@ func Test_Moduler_IntegrationTesterModule(t *testing.T) { exp: []interface{}{"an-user", nil}, }}, finish: func(tx *pam.Transaction, l *Listener, ts testState) error { - return ensureItem(tx, pam.User, "an-user") + return ensureUser(tx, "an-user") }, }, "set-item-User-preset": { @@ -167,7 +168,7 @@ func Test_Moduler_IntegrationTesterModule(t *testing.T) { exp: []interface{}{"an-user", nil}, }}, finish: func(tx *pam.Transaction, l *Listener, ts testState) error { - return ensureItem(tx, pam.User, "an-user") + return ensureUser(tx, "an-user") }, }, "set-get-item-User-empty": { @@ -488,6 +489,59 @@ func Test_Moduler_IntegrationTesterModule(t *testing.T) { return nil }, }, + "get-user-empty-no-conv-set": { + expectedError: pam.ErrConv, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetUser", "who are you? "), + exp: []interface{}{"", pam.ErrConv}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureUser(tx, "") + }, + }, + "get-user-empty-with-conv": { + credentials: utils.Credentials{ + User: "replying-user", + ExpectedMessage: "who are you? ", + ExpectedStyle: pam.PromptEchoOn, + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetUser", "who are you? "), + exp: []interface{}{"replying-user", nil}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureUser(tx, "replying-user") + }, + }, + "get-user-preset-without-conv": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + return tx.SetItem(pam.User, "setup-user") + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetUser", "who are you? "), + exp: []interface{}{"setup-user", nil}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureUser(tx, "setup-user") + }, + }, + "get-user-preset-with-conv": { + credentials: utils.Credentials{ + User: "replying-user", + ExpectedMessage: "No message should have been shown!", + ExpectedStyle: pam.PromptEchoOn, + }, + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + return tx.SetItem(pam.User, "setup-user") + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetUser", "who are you? "), + exp: []interface{}{"setup-user", nil}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureUser(tx, "setup-user") + }, + }, } for name, tc := range tests { diff --git a/cmd/pam-moduler/tests/internal/utils/test-utils.go b/cmd/pam-moduler/tests/internal/utils/test-utils.go index 3fc6b0c..095994b 100644 --- a/cmd/pam-moduler/tests/internal/utils/test-utils.go +++ b/cmd/pam-moduler/tests/internal/utils/test-utils.go @@ -1,6 +1,13 @@ // Package utils contains the internal test utils package utils +import ( + "errors" + "fmt" + + "github.com/msteinert/pam/v2" +) + // Action represents a PAM action to perform. type Action int @@ -106,3 +113,43 @@ type SerializableError struct { func (e *SerializableError) Error() string { return e.Msg } + +// Credentials is a test [pam.ConversationHandler] implementation. +type Credentials struct { + User string + Password string + ExpectedMessage string + CheckEmptyMessage bool + ExpectedStyle pam.Style + CheckZeroStyle bool + Context interface{} +} + +// RespondPAM handles PAM string conversations. +func (c Credentials) RespondPAM(s pam.Style, msg string) (string, error) { + if (c.ExpectedMessage != "" || c.CheckEmptyMessage) && + msg != c.ExpectedMessage { + return "", errors.Join(pam.ErrConv, + &SerializableError{ + fmt.Sprintf("unexpected prompt: %s vs %s", msg, c.ExpectedMessage), + }) + } + + if (c.ExpectedStyle != 0 || c.CheckZeroStyle) && + s != c.ExpectedStyle { + return "", errors.Join(pam.ErrConv, + &SerializableError{ + fmt.Sprintf("unexpected style: %#v vs %#v", s, c.ExpectedStyle), + }) + } + + switch s { + case pam.PromptEchoOn: + return c.User, nil + case pam.PromptEchoOff: + return c.Password, nil + } + + return "", errors.Join(pam.ErrConv, + &SerializableError{fmt.Sprintf("unhandled style: %v", s)}) +} diff --git a/module-transaction-mock.go b/module-transaction-mock.go new file mode 100644 index 0000000..f00202e --- /dev/null +++ b/module-transaction-mock.go @@ -0,0 +1,106 @@ +//go:build !go_pam_module + +package pam + +/* +#cgo CFLAGS: -Wall -std=c99 +#include +#include +*/ +import "C" + +import ( + "errors" + "fmt" + "runtime" + "testing" + "unsafe" +) + +type mockModuleTransactionExpectations struct { + UserPrompt string +} + +type mockModuleTransactionReturnedData struct { + User string + InteractiveUser bool + Status Error +} + +type mockModuleTransaction struct { + moduleTransaction + T *testing.T + Expectations mockModuleTransactionExpectations + RetData mockModuleTransactionReturnedData + ConversationHandler ConversationHandler + allocatedData []unsafe.Pointer +} + +func newMockModuleTransaction(m *mockModuleTransaction) *mockModuleTransaction { + runtime.SetFinalizer(m, func(m *mockModuleTransaction) { + for _, ptr := range m.allocatedData { + C.free(ptr) + } + }) + return m +} + +func (m *mockModuleTransaction) getUser(outUser **C.char, prompt *C.char) C.int { + goPrompt := C.GoString(prompt) + if goPrompt != m.Expectations.UserPrompt { + m.T.Fatalf("unexpected prompt: %s vs %s", goPrompt, m.Expectations.UserPrompt) + return C.int(ErrAbort) + } + + user := m.RetData.User + if m.RetData.InteractiveUser || (m.RetData.User == "" && m.ConversationHandler != nil) { + if m.ConversationHandler == nil { + m.T.Fatalf("no conversation handler provided") + } + u, err := m.ConversationHandler.RespondPAM(PromptEchoOn, goPrompt) + user = u + + if err != nil { + var pamErr Error + if errors.As(err, &pamErr) { + return C.int(pamErr) + } + return C.int(ErrAbort) + } + } + + cUser := C.CString(user) + m.allocatedData = append(m.allocatedData, unsafe.Pointer(cUser)) + + *outUser = cUser + return C.int(m.RetData.Status) +} + +type mockConversationHandler struct { + User string + ExpectedMessage string + CheckEmptyMessage bool + ExpectedStyle Style + CheckZeroStyle bool +} + +func (c mockConversationHandler) RespondPAM(s Style, msg string) (string, error) { + if (c.ExpectedMessage != "" || c.CheckEmptyMessage) && + msg != c.ExpectedMessage { + return "", fmt.Errorf("%w: unexpected prompt: %s vs %s", + ErrConv, msg, c.ExpectedMessage) + } + + if (c.ExpectedStyle != 0 || c.CheckZeroStyle) && + s != c.ExpectedStyle { + return "", fmt.Errorf("%w: unexpected style: %#v vs %#v", + ErrConv, s, c.ExpectedStyle) + } + + switch s { + case PromptEchoOn: + return c.User, nil + } + + return "", fmt.Errorf("%w: unhandled style: %v", ErrConv, s) +} diff --git a/module-transaction.go b/module-transaction.go index 0e87fe5..a698f42 100644 --- a/module-transaction.go +++ b/module-transaction.go @@ -1,11 +1,20 @@ // Package pam provides a wrapper for the PAM application API. package pam +/* +#cgo CFLAGS: -Wall -std=c99 +#cgo LDFLAGS: -lpam + +#include +#include +#include +*/ import "C" import ( "errors" "fmt" + "unsafe" ) // ModuleTransaction is an interface that a pam module transaction @@ -16,6 +25,7 @@ type ModuleTransaction interface { PutEnv(nameVal string) error GetEnv(name string) string GetEnvList() (map[string]string, error) + GetUser(prompt string) (string, error) } // ModuleHandlerFunc is a function type used by the ModuleHandler. @@ -89,3 +99,31 @@ func (m *moduleTransaction) InvokeHandler(handler ModuleHandlerFunc, m.lastStatus.Store(status) return err } + +type moduleTransactionIface interface { + getUser(outUser **C.char, prompt *C.char) C.int +} + +func (m *moduleTransaction) getUser(outUser **C.char, prompt *C.char) C.int { + return C.pam_get_user(m.handle, outUser, prompt) +} + +// getUserImpl is the default implementation for GetUser, but kept as private so +// that can be used to test the pam package +func (m *moduleTransaction) getUserImpl(iface moduleTransactionIface, + prompt string) (string, error) { + var user *C.char + var cPrompt = C.CString(prompt) + defer C.free(unsafe.Pointer(cPrompt)) + err := m.handlePamStatus(iface.getUser(&user, cPrompt)) + if err != nil { + return "", err + } + return C.GoString(user), nil +} + +// GetUser is similar to GetItem(User), but it would start a conversation if +// no user is currently set in PAM. +func (m *moduleTransaction) GetUser(prompt string) (string, error) { + return m.getUserImpl(m, prompt) +} diff --git a/module-transaction_test.go b/module-transaction_test.go index 8661f68..7a44fd3 100644 --- a/module-transaction_test.go +++ b/module-transaction_test.go @@ -62,6 +62,12 @@ func Test_NewNullModuleTransaction(t *testing.T) { return nil, err }, }, + "GetUser": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return mt.GetUser("prompt") + }, + }, } for name, tc := range tests { @@ -235,3 +241,96 @@ func Test_ModuleTransaction_InvokeHandler(t *testing.T) { }) } } + +func Test_MockModuleTransaction(t *testing.T) { + t.Parallel() + + mt, _ := NewModuleTransactionInvoker(nil).(*moduleTransaction) + + tests := map[string]struct { + testFunc func(mock *mockModuleTransaction) (any, error) + mockExpectations mockModuleTransactionExpectations + mockRetData mockModuleTransactionReturnedData + conversationHandler ConversationHandler + + expectedError error + expectedValue any + ignoreError bool + }{ + "GetUser-empty": { + mockExpectations: mockModuleTransactionExpectations{ + UserPrompt: "who are you?"}, + expectedValue: "", + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.getUserImpl(mock, "who are you?") + }, + }, + "GetUser-preset-value": { + mockExpectations: mockModuleTransactionExpectations{ + UserPrompt: "who are you?"}, + mockRetData: mockModuleTransactionReturnedData{User: "dummy-user"}, + expectedValue: "dummy-user", + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.getUserImpl(mock, "who are you?") + }, + }, + "GetUser-conversation-value": { + mockExpectations: mockModuleTransactionExpectations{ + UserPrompt: "who are you?"}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: PromptEchoOn, + ExpectedMessage: "who are you?", + User: "returned-dummy-user", + }, + expectedValue: "returned-dummy-user", + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.getUserImpl(mock, "who are you?") + }, + }, + "GetUser-conversation-error-prompt": { + expectedError: ErrConv, + mockExpectations: mockModuleTransactionExpectations{ + UserPrompt: "who are you?"}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: PromptEchoOn, + ExpectedMessage: "who are you???", + }, + expectedValue: "", + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.getUserImpl(mock, "who are you?") + }, + }, + "GetUser-conversation-error-style": { + expectedError: ErrConv, + mockExpectations: mockModuleTransactionExpectations{ + UserPrompt: "who are you?"}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: PromptEchoOff, + ExpectedMessage: "who are you?", + }, + expectedValue: "", + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.getUserImpl(mock, "who are you?") + }, + }, + } + + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + mock := newMockModuleTransaction(&mockModuleTransaction{T: t, + Expectations: tc.mockExpectations, RetData: tc.mockRetData, + ConversationHandler: tc.conversationHandler}) + data, err := tc.testFunc(mock) + + if !tc.ignoreError && !errors.Is(err, tc.expectedError) { + t.Fatalf("unexpected err: %#v vs %#v", err, tc.expectedError) + } + + if !reflect.DeepEqual(data, tc.expectedValue) { + t.Fatalf("data mismatch, %#v vs %#v", data, tc.expectedValue) + } + }) + } +}