diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..4729015 --- /dev/null +++ b/.clang-format @@ -0,0 +1,111 @@ +# clang-format configuration file. For more information, see: +# +# https://clang.llvm.org/docs/ClangFormat.html +# https://clang.llvm.org/docs/ClangFormatStyleOptions.html +# +--- +AccessModifierOffset: -4 +AlignAfterOpenBracket: Align +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlines: Left +AlignOperands: true +AlignTrailingComments: false +AllowAllParametersOfDeclarationOnNextLine: false +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: None +AllowShortIfStatementsOnASingleLine: false +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: false +AlwaysBreakTemplateDeclarations: false +BinPackArguments: true +BinPackParameters: true +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: true + AfterNamespace: true + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + AfterExternBlock: false + BeforeCatch: false + BeforeElse: false + IndentBraces: false + SplitEmptyFunction: true + SplitEmptyRecord: true + SplitEmptyNamespace: true +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Custom +BreakBeforeInheritanceComma: false +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BreakConstructorInitializers: BeforeComma +BreakStringLiterals: false +ColumnLimit: 120 +CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: false +ConstructorInitializerIndentWidth: 8 +ContinuationIndentWidth: 8 +DerivePointerAlignment: false +DisableFormat: false +ExperimentalAutoDetectBinPacking: false +FixNamespaceComments: false +IncludeBlocks: Regroup +IncludeCategories: + - Regex: '^"(allez)/' + Priority: 2 + SortPriority: 2 + CaseSensitive: true + - Regex: '.*' + Priority: 1 + SortPriority: 0 +IndentCaseLabels: false +IndentGotoLabels: false +IndentPPDirectives: None +IndentWidth: 8 +IndentWrappedFunctionNames: false +KeepEmptyLinesAtTheStartOfBlocks: false +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBinPackProtocolList: Auto +ObjCBlockIndentWidth: 8 +ObjCSpaceAfterProperty: true +ObjCSpaceBeforeProtocolList: true + +# Taken from git's rules +PenaltyBreakAssignment: 10 +PenaltyBreakBeforeFirstCallParameter: 30 +PenaltyBreakComment: 10 +PenaltyBreakFirstLessLess: 0 +PenaltyBreakString: 10 +PenaltyExcessCharacter: 2 +PenaltyReturnTypeOnItsOwnLine: 60 + +PointerAlignment: Right +ReflowComments: false +SortIncludes: true +SortUsingDeclarations: false +SpaceAfterCStyleCast: false +SpaceAfterTemplateKeyword: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: ControlStatementsExceptForEachMacros +SpaceBeforeRangeBasedForLoopColon: true +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInContainerLiterals: false +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +TabWidth: 8 +UseTab: Always +... diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 0000000..5066aeb --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,3 @@ +ignore: + # Ignore pam-moduler generated files + - "**/pam_module.go" diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml new file mode 100644 index 0000000..771e735 --- /dev/null +++ b/.github/workflows/lint.yaml @@ -0,0 +1,22 @@ +on: [push, pull_request] +name: Lint + +permissions: + contents: read + +jobs: + golangci: + name: lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-go@v4 + with: + go-version: '1.21' + cache: false + - name: Install PAM + run: sudo apt install -y libpam-dev + - name: golangci-lint + uses: golangci/golangci-lint-action@v3 + with: + version: v1.54 diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 0dd0b77..a46db7e 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -4,19 +4,48 @@ jobs: test: strategy: matrix: - go-version: [1.19.x, 1.20.x] + go-version: [1.20.x, 1.21.x] os: [ubuntu-latest] runs-on: ${{ matrix.os }} steps: - name: Install Go - uses: actions/setup-go@v3 + uses: actions/setup-go@v4 with: go-version: ${{ matrix.go-version }} - name: Install PAM - run: sudo apt install -y libpam-dev + run: | + sudo apt update -y + sudo apt install -y libpam-dev + - name: Install Debug symbols + run: | + sudo apt install -y ubuntu-dev-tools + (cd /tmp && pull-lp-ddebs libpam0g $(lsb_release -c -s)) + (cd /tmp && pull-lp-ddebs libpam-modules $(lsb_release -c -s)) + sudo dpkg -i /tmp/libpam*-dbgsym_*.ddeb - name: Add a test user run: sudo useradd -d /tmp/test -p '$1$Qd8H95T5$RYSZQeoFbEB.gS19zS99A0' -s /bin/false test - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Test - run: sudo go test -v ./... + run: sudo go test -v -cover -coverprofile=coverage.out -coverpkg=./... ./... + - name: Test with Address Sanitizer + env: + GO_PAM_TEST_WITH_ASAN: true + CGO_CFLAGS: "-O0 -g3 -fno-omit-frame-pointer" + run: | + # Do not run sudo-requiring go tests because as PAM has some leaks in 22.04 + go test -v -asan -cover -coverprofile=coverage-asan-tx.out -coverpkg=./... -gcflags=all="-N -l" + + # Run the rest of tests normally + sudo go test -v -cover -coverprofile=coverage-asan-module.out -coverpkg=./... -asan -gcflags=all="-N -l" -run Module + sudo go test -C cmd -coverprofile=coverage-asan.out -v -coverpkg=./... -asan -gcflags=all="-N -l" ./... + - name: Generate example module + run: | + rm -f example-module/pam_go.so + go generate -C example-module -v + test -e example-module/pam_go.so + git diff --exit-code example-module + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v3 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8206d2f --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +coverage*.out +example-module/*.so +example-module/*.h +cmd/pam-moduler/tests/*/*.so +cmd/pam-moduler/tests/*/*.h diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..bbfa6b4 --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,61 @@ +# This is for linting. To run it, please use: +# golangci-lint run ${MODULE}/... [--fix] + +linters: + # linters to run in addition to default ones + enable: + - dupl + - durationcheck + - errname + - errorlint + - exportloopref + - forbidigo + - forcetypeassert + - gci + - godot + - gofmt + - gosec + - misspell + - nakedret + - nolintlint + - revive + - thelper + - tparallel + - unconvert + - unparam + - whitespace + +run: + timeout: 5m + +# Get all linter issues, even if duplicated +issues: + exclude-use-default: false + max-issues-per-linter: 0 + max-same-issues: 0 + fix: false # we don’t want this in CI + exclude: + # EXC0001 errcheck: most errors are in defer calls, which are safe to ignore and idiomatic Go (would be good to only ignore defer ones though) + - 'Error return value of .((os\.)?std(out|err)\..*|.*Close|.*Flush|os\.Remove(All)?|.*print(f|ln)?|os\.(Un)?Setenv|w\.Stop). is not checked' + # EXC0008 gosec: duplicated of errcheck + - (G104|G307) + # EXC0010 gosec: False positive is triggered by 'src, err := ioutil.ReadFile(filename)' + - Potential file inclusion via variable + # We want named parameters even if unused, as they help better document the function + - unused-parameter + # Sometimes it is more readable it do a `if err:=a(); err != nil` tha simpy `return a()` + - if-return + +nolintlint: + require-explanation: true + require-specific: true + +linters-settings: + # Forbid the usage of deprecated ioutil and debug prints + forbidigo: + forbid: + - ioutil\. + - ^print.*$ + # Never have naked return ever + nakedret: + max-func-lines: 1 diff --git a/README.md b/README.md index deb946b..60f1c0c 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,128 @@ -[![GoDoc](https://godoc.org/github.com/msteinert/pam?status.svg)](http://godoc.org/github.com/msteinert/pam) -[![Go Report Card](https://goreportcard.com/badge/github.com/msteinert/pam)](https://goreportcard.com/report/github.com/msteinert/pam) +[![GoDoc](https://godoc.org/github.com/msteinert/pam/v2?status.svg)](http://godoc.org/github.com/msteinert/pam/v2) +[![codecov](https://codecov.io/gh/msteinert/pam/graph/badge.svg?token=L1K3UTB065)](https://codecov.io/gh/msteinert/pam) +[![Go Report Card](https://goreportcard.com/badge/github.com/msteinert/pam/v2)](https://goreportcard.com/report/github.com/msteinert/pam/v2) # Go PAM This is a Go wrapper for the PAM application API. +## Module support + +Go PAM can also used to create PAM modules in a simple way, using the go. + +The code can be generated using [pam-moduler](cmd/pam-moduler/moduler.go) and +an example how to use it using `go generate` create them is available as an +[example module](example-module/module.go). + +### Modules and PAM applications + +The modules generated with go can be used by any PAM application, however there +are some caveats, in fact a Go shared library could misbehave when loaded +improperly. In particular if a Go shared library is loaded and then the program +`fork`s, the library will have an undefined behavior. + +This is the case of SSHd that loads a pam library before forking, making any +go PAM library to make it hang. + +To solve this case, we can use a little workaround: to ensure that the go +library is loaded only after the program has forked, we can just `dload` it once +a PAM library is called, in this way go code will be loaded only after that the +PAM application has `fork`'ed. + +To do this, we can use a very simple wrapper written in C: + +```c +#include +#include +#include +#include + +typedef int (*PamHandler)(pam_handle_t *, + int flags, + int argc, + const char **argv); + +static void +on_go_module_removed (pam_handle_t *pamh, + void *go_module, + int error_status) +{ + dlclose (go_module); +} + +static void * +load_module (pam_handle_t *pamh, + const char *module_path) +{ + void *go_module; + + if (pam_get_data (pamh, "go-module", (const void **) &go_module) == PAM_SUCCESS) + return go_module; + + go_module = dlopen (module_path, RTLD_LAZY); + if (!go_module) + return NULL; + + pam_set_data (pamh, "go-module", go_module, on_go_module_removed); + + return go_module; +} + +static inline int +call_pam_function (pam_handle_t *pamh, + const char *function, + int flags, + int argc, + const char **argv) +{ + char module_path[PATH_MAX] = {0}; + const char *sub_module; + PamHandler func; + void *go_module; + + if (argc < 1) + { + pam_error (pamh, "%s: no module provided", function); + return PAM_MODULE_UNKNOWN; + } + + sub_module = argv[0]; + argc -= 1; + argv = (argc == 0) ? NULL : &argv[1]; + + strncpy (module_path, sub_module, PATH_MAX - 1); + + go_module = load_module (pamh, module_path); + if (!go_module) + { + pam_error (pamh, "Impossible to load module %s", module_path); + return PAM_OPEN_ERR; + } + + *(void **) (&func) = dlsym (go_module, function); + if (!func) + { + pam_error (pamh, "Symbol %s not found in %s", function, module_path); + return PAM_OPEN_ERR; + } + + return func (pamh, flags, argc, argv); +} + +#define DEFINE_PAM_WRAPPER(name) \ + PAM_EXTERN int \ + (pam_sm_ ## name) (pam_handle_t * pamh, int flags, int argc, const char **argv) \ + { \ + return call_pam_function (pamh, "pam_sm_" #name, flags, argc, argv); \ + } + +DEFINE_PAM_WRAPPER (authenticate) +DEFINE_PAM_WRAPPER (chauthtok) +DEFINE_PAM_WRAPPER (close_session) +DEFINE_PAM_WRAPPER (open_session) +DEFINE_PAM_WRAPPER (setcred) +``` + ## Testing To run the full suite, the tests must be run as the root user. To setup your @@ -24,5 +142,8 @@ Then execute the tests: $ sudo GOPATH=$GOPATH $(which go) test -v ``` -[1]: http://godoc.org/github.com/msteinert/pam +Other tests can instead run as user without any setup with +normal `go test ./...` + +[1]: http://godoc.org/github.com/msteinert/pam/v2 [2]: http://www.linux-pam.org/Linux-PAM-html/Linux-PAM_ADG.html diff --git a/app-transaction.go b/app-transaction.go new file mode 100644 index 0000000..39a3cf4 --- /dev/null +++ b/app-transaction.go @@ -0,0 +1,292 @@ +//go:build !go_pam_module + +package pam + +/* +#include "transaction.h" +*/ +import "C" + +import ( + "fmt" + "runtime/cgo" + "sync/atomic" + "unsafe" +) + +// ConversationHandler is an interface for objects that can be used as +// conversation callbacks during PAM authentication. +type ConversationHandler interface { + // RespondPAM receives a message style and a message string. If the + // message Style is PromptEchoOff or PromptEchoOn then the function + // should return a response string. + RespondPAM(Style, string) (string, error) +} + +// BinaryConversationHandler is an interface for objects that can be used as +// conversation callbacks during PAM authentication if binary protocol is going +// to be supported. +type BinaryConversationHandler interface { + ConversationHandler + // RespondPAMBinary receives a pointer to the binary message. It's up to + // the receiver to parse it according to the protocol specifications. + // The function can return a byte array that will passed as pointer back + // to the module. + RespondPAMBinary(BinaryPointer) ([]byte, error) +} + +// BinaryPointerConversationHandler is an interface for objects that can be used as +// conversation callbacks during PAM authentication if binary protocol is going +// to be supported. +type BinaryPointerConversationHandler interface { + ConversationHandler + // RespondPAMBinary receives a pointer to the binary message. It's up to + // the receiver to parse it according to the protocol specifications. + // The function must return a pointer that is allocated via malloc or + // similar, as it's expected to be free'd by the conversation handler. + RespondPAMBinary(BinaryPointer) (BinaryPointer, error) +} + +// ConversationFunc is an adapter to allow the use of ordinary functions as +// conversation callbacks. +type ConversationFunc func(Style, string) (string, error) + +// RespondPAM is a conversation callback adapter. +func (f ConversationFunc) RespondPAM(s Style, msg string) (string, error) { + return f(s, msg) +} + +// BinaryConversationFunc is an adapter to allow the use of ordinary functions +// as binary (only) conversation callbacks. +type BinaryConversationFunc func(BinaryPointer) ([]byte, error) + +// RespondPAMBinary is a conversation callback adapter. +func (f BinaryConversationFunc) RespondPAMBinary(ptr BinaryPointer) ([]byte, error) { + return f(ptr) +} + +// RespondPAM is a dummy conversation callback adapter. +func (f BinaryConversationFunc) RespondPAM(Style, string) (string, error) { + return "", ErrConv +} + +// BinaryPointerConversationFunc is an adapter to allow the use of ordinary +// functions as binary pointer (only) conversation callbacks. +type BinaryPointerConversationFunc func(BinaryPointer) (BinaryPointer, error) + +// RespondPAMBinary is a conversation callback adapter. +func (f BinaryPointerConversationFunc) RespondPAMBinary(ptr BinaryPointer) (BinaryPointer, error) { + return f(ptr) +} + +// RespondPAM is a dummy conversation callback adapter. +func (f BinaryPointerConversationFunc) RespondPAM(Style, string) (string, error) { + return "", ErrConv +} + +// _go_pam_conv_handler is a C wrapper for the conversation callback function. +// +//export _go_pam_conv_handler +func _go_pam_conv_handler(msg *C.struct_pam_message, c C.uintptr_t, outMsg **C.char) C.int { + convHandler, ok := cgo.Handle(c).Value().(ConversationHandler) + if !ok || convHandler == nil { + return C.int(ErrConv) + } + replyMsg, r := pamConvHandler(Style(msg.msg_style), msg.msg, convHandler) + *outMsg = replyMsg + return r +} + +// pamConvHandler is a Go wrapper for the conversation callback function. +func pamConvHandler(style Style, msg *C.char, handler ConversationHandler) (*C.char, C.int) { + var r string + var err error + switch cb := handler.(type) { + case BinaryConversationHandler: + if style == BinaryPrompt { + bytes, err := cb.RespondPAMBinary(BinaryPointer(msg)) + if err != nil { + return nil, C.int(ErrConv) + } + if bytes == nil { + return nil, success + } + return (*C.char)(C.CBytes(bytes)), success + } + handler = cb + case BinaryPointerConversationHandler: + if style == BinaryPrompt { + ptr, err := cb.RespondPAMBinary(BinaryPointer(msg)) + if err != nil { + defer C.free(unsafe.Pointer(ptr)) + return nil, C.int(ErrConv) + } + return (*C.char)(ptr), success + } + handler = cb + case ConversationHandler: + if style == BinaryPrompt { + return nil, C.int(ErrConv) + } + handler = cb + default: + return nil, C.int(ErrConv) + } + r, err = handler.RespondPAM(style, C.GoString(msg)) + if err != nil { + return nil, C.int(ErrConv) + } + return C.CString(r), success +} + +// Transaction is the application's handle for a PAM transaction. +type Transaction struct { + transactionBase + + conv *C.struct_pam_conv + c cgo.Handle +} + +// Start initiates a new PAM transaction. Service is treated identically to +// how pam_start treats it internally. +// +// All application calls to PAM begin with Start*. The returned +// transaction provides an interface to the remainder of the API. +// +// It's responsibility of the Transaction owner to release all the resources +// allocated underneath by PAM by calling End() once done. +// +// It's not advised to End the transaction using a runtime.SetFinalizer unless +// you're absolutely sure that your stack is multi-thread friendly (normally it +// is not!) and using a LockOSThread/UnlockOSThread pair. +func Start(service, user string, handler ConversationHandler) (*Transaction, error) { + return start(service, user, handler, "") +} + +// StartFunc registers the handler func as a conversation handler and starts +// the transaction (see Start() documentation). +func StartFunc(service, user string, handler func(Style, string) (string, error)) (*Transaction, error) { + return start(service, user, ConversationFunc(handler), "") +} + +// StartConfDir initiates a new PAM transaction. Service is treated identically to +// how pam_start treats it internally. +// confdir allows to define where all pam services are defined. This is used to provide +// custom paths for tests. +// +// All application calls to PAM begin with Start*. The returned +// transaction provides an interface to the remainder of the API. +// +// It's responsibility of the Transaction owner to release all the resources +// allocated underneath by PAM by calling End() once done. +// +// It's not advised to End the transaction using a runtime.SetFinalizer unless +// you're absolutely sure that your stack is multi-thread friendly (normally it +// is not!) and using a LockOSThread/UnlockOSThread pair. +func StartConfDir(service, user string, handler ConversationHandler, confDir string) (*Transaction, error) { + if !CheckPamHasStartConfdir() { + return nil, fmt.Errorf( + "%w: StartConfDir was used, but the pam version on the system is not recent enough", + ErrSystem) + } + + return start(service, user, handler, confDir) +} + +func start(service, user string, handler ConversationHandler, confDir string) (*Transaction, error) { + switch handler.(type) { + case BinaryConversationHandler: + if !CheckPamHasBinaryProtocol() { + return nil, fmt.Errorf("%w: BinaryConversationHandler was used, but it is not supported by this platform", + ErrSystem) + } + case BinaryPointerConversationHandler: + if !CheckPamHasBinaryProtocol() { + return nil, fmt.Errorf( + "%w: BinaryPointerConversationHandler was used, but it is not supported by this platform", + ErrSystem) + } + } + t := &Transaction{ + conv: &C.struct_pam_conv{}, + c: cgo.NewHandle(handler), + } + + C.init_pam_conv(t.conv, C.uintptr_t(t.c)) + s := C.CString(service) + defer C.free(unsafe.Pointer(s)) + var u *C.char + if len(user) != 0 { + u = C.CString(user) + defer C.free(unsafe.Pointer(u)) + } + var err error + if confDir == "" { + err = t.handlePamStatus(C.pam_start(s, u, t.conv, &t.handle)) + } else { + c := C.CString(confDir) + defer C.free(unsafe.Pointer(c)) + err = t.handlePamStatus(C.pam_start_confdir(s, u, t.conv, c, &t.handle)) + } + if err != nil { + var _ = t.End() + return nil, err + } + return t, nil +} + +// Authenticate is used to authenticate the user. +// +// Valid flags: Silent, DisallowNullAuthtok +func (t *Transaction) Authenticate(f Flags) error { + return t.handlePamStatus(C.pam_authenticate(t.handle, C.int(f))) +} + +// SetCred is used to establish, maintain and delete the credentials of a +// user. +// +// Valid flags: EstablishCred, DeleteCred, ReinitializeCred, RefreshCred +func (t *Transaction) SetCred(f Flags) error { + return t.handlePamStatus(C.pam_setcred(t.handle, C.int(f))) +} + +// AcctMgmt is used to determine if the user's account is valid. +// +// Valid flags: Silent, DisallowNullAuthtok +func (t *Transaction) AcctMgmt(f Flags) error { + return t.handlePamStatus(C.pam_acct_mgmt(t.handle, C.int(f))) +} + +// ChangeAuthTok is used to change the authentication token. +// +// Valid flags: Silent, ChangeExpiredAuthtok +func (t *Transaction) ChangeAuthTok(f Flags) error { + return t.handlePamStatus(C.pam_chauthtok(t.handle, C.int(f))) +} + +// OpenSession sets up a user session for an authenticated user. +// +// Valid flags: Slient +func (t *Transaction) OpenSession(f Flags) error { + return t.handlePamStatus(C.pam_open_session(t.handle, C.int(f))) +} + +// CloseSession closes a previously opened session. +// +// Valid flags: Silent +func (t *Transaction) CloseSession(f Flags) error { + return t.handlePamStatus(C.pam_close_session(t.handle, C.int(f))) +} + +// End cleans up the PAM handle and deletes the callback function. +// It must be called when done with the transaction. +func (t *Transaction) End() error { + handle := atomic.SwapPointer((*unsafe.Pointer)(unsafe.Pointer(&t.handle)), nil) + if handle == nil { + return nil + } + + defer t.c.Delete() + return t.handlePamStatus(C.pam_end((*C.pam_handle_t)(handle), + C.int(t.lastStatus.Load()))) +} diff --git a/cmd/pam-moduler/moduler.go b/cmd/pam-moduler/moduler.go new file mode 100644 index 0000000..0a74125 --- /dev/null +++ b/cmd/pam-moduler/moduler.go @@ -0,0 +1,313 @@ +// pam-moduler is a tool to automate the creation of PAM Modules in go +// +// The file is created in the same package and directory as the package that +// creates the module +// +// The module implementation should define a pamModuleHandler object that +// implements the pam.ModuleHandler type and that will be used for each callback +// +// Otherwise it's possible to provide a typename from command line that will +// be used for this purpose +// +// For example: +// +// //go:generate go run github.com/msteinert/pam/v2/pam-moduler +// //go:generate go generate --skip="pam_module" +// package main +// +// import "github.com/msteinert/pam/v2" +// +// type ExampleHandler struct{} +// var pamModuleHandler pam.ModuleHandler = &ExampleHandler{} +// +// func (h *ExampleHandler) AcctMgmt(pam.ModuleTransaction, pam.Flags, []string) error { +// return nil +// } +// +// func (h *ExampleHandler) Authenticate(pam.ModuleTransaction, pam.Flags, []string) error { +// return nil +// } +// +// func (h *ExampleHandler) ChangeAuthTok(pam.ModuleTransaction, pam.Flags, []string) error { +// return nil +// } +// +// func (h *ExampleHandler) OpenSession(pam.ModuleTransaction, pam.Flags, []string) error { +// return nil +// } +// +// func (h *ExampleHandler) CloseSession(pam.ModuleTransaction, pam.Flags, []string) error { +// return nil +// } +// +// func (h *ExampleHandler) SetCred(pam.ModuleTransaction, pam.Flags, []string) error { +// return nil +// } + +// Package main provides the module shared library. +package main + +import ( + "bytes" + "flag" + "fmt" + "go/format" + "log" + "os" + "path/filepath" + "strings" +) + +const toolName = "pam-moduler" + +var ( + output = flag.String("output", "", "output file name; default srcdir/pam_module.go") + libName = flag.String("libname", "", "output library name; default pam_go.so") + typeName = flag.String("type", "", "type name to be used as pam.ModuleHandler") + buildTags = flag.String("tags", "", "build tags expression to append to use in the go:build directive") + skipGenerator = flag.Bool("no-generator", false, "whether to add go:generator directives to the generated source") + moduleBuildFlags = flag.String("build-flags", "", "comma-separated list of go build flags to use when generating the module") + moduleBuildTags = flag.String("build-tags", "", "comma-separated list of build tags to use when generating the module") + noMain = flag.Bool("no-main", false, "whether to add an empty main to generated file") + parallelConv = flag.Bool("parallel-conv", false, "whether to support performing PAM conversations in parallel") +) + +// Usage is a replacement usage function for the flags package. +func Usage() { + fmt.Fprintf(os.Stderr, "Usage of %s:\n", toolName) + fmt.Fprintf(os.Stderr, "\t%s [flags] [-output O] [-libname pam_go] [-type N]\n", toolName) + flag.PrintDefaults() +} + +func main() { + log.SetFlags(0) + log.SetPrefix(toolName + ": ") + flag.Usage = Usage + flag.Parse() + + if *skipGenerator { + if *libName != "" { + fmt.Fprintf(os.Stderr, + "Generator directives disabled, libname will have no effect\n") + } + if *moduleBuildTags != "" { + fmt.Fprintf(os.Stderr, + "Generator directives disabled, build-tags will have no effect\n") + } + if *moduleBuildFlags != "" { + fmt.Fprintf(os.Stderr, + "Generator directives disabled, build-flags will have no effect\n") + } + } + + lib := *libName + if lib == "" { + lib = "pam_go" + } else { + lib, _ = strings.CutSuffix(lib, ".so") + lib, _ = strings.CutPrefix(lib, "lib") + } + + outputName, _ := strings.CutSuffix(*output, ".go") + if outputName == "" { + baseName := "pam_module" + outputName = filepath.Join(".", strings.ToLower(baseName)) + } + outputName = outputName + ".go" + + var tags string + if *buildTags != "" { + tags = *buildTags + } + + generateTags := []string{"go_pam_module"} + if len(*moduleBuildTags) > 0 { + generateTags = append(generateTags, strings.Split(*moduleBuildTags, ",")...) + } + + var buildFlags []string + if *moduleBuildFlags != "" { + buildFlags = strings.Split(*moduleBuildFlags, ",") + } + + g := Generator{ + outputName: outputName, + libName: lib, + tags: tags, + buildFlags: buildFlags, + generateTags: generateTags, + noMain: *noMain, + typeName: *typeName, + parallelConv: *parallelConv, + } + + // Print the header and package clause. + g.printf("// Code generated by \"%s %s\"; DO NOT EDIT.\n", + toolName, strings.Join(os.Args[1:], " ")) + g.printf("\n") + + // Generate the code + g.generate() + + // Format the output. + src := g.format() + + // Write to file. + err := os.WriteFile(outputName, src, 0600) + if err != nil { + log.Fatalf("writing output: %s", err) + } +} + +// Generator holds the state of the analysis. Primarily used to buffer +// the output for format.Source. +type Generator struct { + buf bytes.Buffer // Accumulated output. + + libName string + outputName string + typeName string + tags string + generateTags []string + buildFlags []string + noMain bool + parallelConv bool +} + +func (g *Generator) printf(format string, args ...interface{}) { + fmt.Fprintf(&g.buf, format, args...) +} + +// generate produces the String method for the named type. +func (g *Generator) generate() { + if g.tags != "" { + g.printf("//go:build %s\n", g.tags) + } + + var buildTagsArg string + if len(g.generateTags) > 0 { + buildTagsArg = fmt.Sprintf("-tags %s", strings.Join(g.generateTags, ",")) + } + + var transactionCreator = "NewModuleTransactionInvoker" + if g.parallelConv { + transactionCreator = "NewModuleTransactionInvokerParallelConv" + } + + // We use a slice since we want to keep order, for reproducible builds. + vFuncs := []struct { + cName string + goName string + }{ + {"authenticate", "Authenticate"}, + {"setcred", "SetCred"}, + {"acct_mgmt", "AcctMgmt"}, + {"open_session", "OpenSession"}, + {"close_session", "CloseSession"}, + {"chauthtok", "ChangeAuthTok"}, + } + + g.printf(`//go:generate go build "-ldflags=-extldflags -Wl,-soname,%[2]s.so" `+ + `-buildmode=c-shared -o %[2]s.so %[3]s %[4]s +`, + g.outputName, g.libName, buildTagsArg, strings.Join(g.buildFlags, " ")) + + g.printf(` +// Package main is the package for the PAM module library. +package main + +/* +#cgo LDFLAGS: -lpam -fPIC +#include + +typedef const char _const_char_t; +*/ +import "C" + +import ( + "errors" + "fmt" + "os" + "unsafe" + "github.com/msteinert/pam/v2" +) +`) + + if g.typeName != "" { + g.printf(` +var pamModuleHandler pam.ModuleHandler = &%[1]s{} +`, g.typeName) + } else { + g.printf(` +// Do a typecheck at compile time +var _ pam.ModuleHandler = pamModuleHandler; +`) + } + + g.printf(` +// sliceFromArgv returns a slice of strings given to the PAM module. +func sliceFromArgv(argc C.int, argv **C._const_char_t) []string { + r := make([]string, 0, argc) + for _, s := range unsafe.Slice(argv, argc) { + r = append(r, C.GoString(s)) + } + return r +} + +// handlePamCall is the function that translates C pam requests to Go. +func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int, + argv **C._const_char_t, moduleFunc pam.ModuleHandlerFunc) C.int { + if pamModuleHandler == nil { + return C.int(pam.ErrNoModuleData) + } + + if moduleFunc == nil { + return C.int(pam.ErrIgnore) + } + + mt := pam.%s(pam.NativeHandle(pamh)) + err := mt.InvokeHandler(moduleFunc, pam.Flags(flags), + sliceFromArgv(argc, argv)) + if err == nil { + return 0 + } + + if (pam.Flags(flags) & pam.Silent) == 0 && !errors.Is(err, pam.ErrIgnore) { + fmt.Fprintf(os.Stderr, "module returned error: %%v\n", err) + } + + var pamErr pam.Error + if errors.As(err, &pamErr) { + return C.int(pamErr) + } + + return C.int(pam.ErrSystem) +} +`, transactionCreator) + + for _, f := range vFuncs { + g.printf(` +//export pam_sm_%[1]s +func pam_sm_%[1]s(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.%[2]s) +} +`, f.cName, f.goName) + } + + if !g.noMain { + g.printf("\nfunc main() {}\n") + } +} + +// format returns the gofmt-ed contents of the Generator's buffer. +func (g *Generator) format() []byte { + src, err := format.Source(g.buf.Bytes()) + if err != nil { + // Should never happen, but can arise when developing this code. + // The user can compile the output to see the error. + log.Printf("warning: internal error: invalid Go generated: %s", err) + log.Printf("warning: compile the package to analyze the error") + return g.buf.Bytes() + } + return src +} diff --git a/cmd/pam-moduler/tests/debug-module/debug-module.go b/cmd/pam-moduler/tests/debug-module/debug-module.go new file mode 100644 index 0000000..843b329 --- /dev/null +++ b/cmd/pam-moduler/tests/debug-module/debug-module.go @@ -0,0 +1,119 @@ +//go:generate go run github.com/msteinert/pam/v2/cmd/pam-moduler -libname "pam_godebug.so" +//go:generate go generate --skip="pam_module.go" + +// This is a similar implementation of pam_debug.so + +// Package main is the package for the debug PAM module library +package main + +import ( + "fmt" + "strings" + + "github.com/msteinert/pam/v2" + "github.com/msteinert/pam/v2/cmd/pam-moduler/tests/internal/utils" +) + +var pamModuleHandler pam.ModuleHandler = &DebugModule{} +var _ = pamModuleHandler + +var moduleArgsRetTypes = map[string]error{ + "success": nil, + "open_err": pam.ErrOpen, + "symbol_err": pam.ErrSymbol, + "service_err": pam.ErrService, + "system_err": pam.ErrSystem, + "buf_err": pam.ErrBuf, + "perm_denied": pam.ErrPermDenied, + "auth_err": pam.ErrAuth, + "cred_insufficient": pam.ErrCredInsufficient, + "authinfo_unavail": pam.ErrAuthinfoUnavail, + "user_unknown": pam.ErrUserUnknown, + "maxtries": pam.ErrMaxtries, + "new_authtok_reqd": pam.ErrNewAuthtokReqd, + "acct_expired": pam.ErrAcctExpired, + "session_err": pam.ErrSession, + "cred_unavail": pam.ErrCredUnavail, + "cred_expired": pam.ErrCredExpired, + "cred_err": pam.ErrCred, + "no_module_data": pam.ErrNoModuleData, + "conv_err": pam.ErrConv, + "authtok_err": pam.ErrAuthtok, + "authtok_recover_err": pam.ErrAuthtokRecovery, + "authtok_lock_busy": pam.ErrAuthtokLockBusy, + "authtok_disable_aging": pam.ErrAuthtokDisableAging, + "try_again": pam.ErrTryAgain, + "ignore": pam.ErrIgnore, + "abort": pam.ErrAbort, + "authtok_expired": pam.ErrAuthtokExpired, + "module_unknown": pam.ErrModuleUnknown, + "bad_item": pam.ErrBadItem, + "conv_again": pam.ErrConvAgain, + "incomplete": pam.ErrIncomplete, +} + +var debugModuleArgs = []string{"auth", "cred", "acct", "prechauthtok", + "chauthtok", "open_session", "close_session"} + +// DebugModule is the PAM module structure. +type DebugModule struct { + utils.BaseModule +} + +func (dm *DebugModule) getReturnType(args []string, key string) error { + var value string + for _, a := range args { + v, found := strings.CutPrefix(a, key+"=") + if found { + value = v + } + } + + if value == "" { + return fmt.Errorf("Value not found") + } + + if ret, found := moduleArgsRetTypes[value]; found { + return ret + } + return fmt.Errorf("Parameter %s not known", value) +} + +func (dm *DebugModule) handleCall(args []string, action string) error { + err := dm.getReturnType(args, action) + if err == nil { + return nil + } + + return fmt.Errorf("error %w", err) +} + +// AcctMgmt is a PAM handler. +func (dm *DebugModule) AcctMgmt(mt pam.ModuleTransaction, flags pam.Flags, args []string) error { + return dm.handleCall(args, "acct") +} + +// Authenticate is a PAM handler. +func (dm *DebugModule) Authenticate(mt pam.ModuleTransaction, flags pam.Flags, args []string) error { + return dm.handleCall(args, "auth") +} + +// ChangeAuthTok is a PAM handler. +func (dm *DebugModule) ChangeAuthTok(mt pam.ModuleTransaction, flags pam.Flags, args []string) error { + return dm.handleCall(args, "chauthtok") +} + +// OpenSession is a PAM handler. +func (dm *DebugModule) OpenSession(mt pam.ModuleTransaction, flags pam.Flags, args []string) error { + return dm.handleCall(args, "open_session") +} + +// CloseSession is a PAM handler. +func (dm *DebugModule) CloseSession(mt pam.ModuleTransaction, flags pam.Flags, args []string) error { + return dm.handleCall(args, "close_session") +} + +// SetCred is a PAM handler. +func (dm *DebugModule) SetCred(mt pam.ModuleTransaction, flags pam.Flags, args []string) error { + return dm.handleCall(args, "cred") +} diff --git a/cmd/pam-moduler/tests/debug-module/debug-module_test.go b/cmd/pam-moduler/tests/debug-module/debug-module_test.go new file mode 100644 index 0000000..8a5d58d --- /dev/null +++ b/cmd/pam-moduler/tests/debug-module/debug-module_test.go @@ -0,0 +1,120 @@ +package main + +import ( + "errors" + "fmt" + "testing" + + "github.com/msteinert/pam/v2" + "github.com/msteinert/pam/v2/cmd/pam-moduler/tests/internal/utils" +) + +func Test_DebugModule_ActionStatus(t *testing.T) { + t.Parallel() + + module := DebugModule{} + + for ret, expected := range moduleArgsRetTypes { + ret := ret + expected := expected + for actionName, action := range utils.Actions { + actionName := actionName + action := action + t.Run(fmt.Sprintf("%s %s", ret, actionName), func(t *testing.T) { + t.Parallel() + moduleArgs := make([]string, 0) + for _, a := range debugModuleArgs { + moduleArgs = append(moduleArgs, fmt.Sprintf("%s=%s", a, ret)) + } + + mt := pam.ModuleTransactionInvoker(nil) + var err error + + switch action { + case utils.Account: + err = module.AcctMgmt(mt, 0, moduleArgs) + case utils.Auth: + err = module.Authenticate(mt, 0, moduleArgs) + case utils.Password: + err = module.ChangeAuthTok(mt, 0, moduleArgs) + case utils.Session: + err = module.OpenSession(mt, 0, moduleArgs) + } + + if !errors.Is(err, expected) { + t.Fatalf("error #unexpected %#v vs %#v", expected, err) + } + }) + } + } +} + +func Test_DebugModuleTransaction_ActionStatus(t *testing.T) { + t.Parallel() + if !pam.CheckPamHasStartConfdir() { + t.Skip("this requires PAM with Conf dir support") + } + + ts := utils.NewTestSetup(t, utils.WithWorkDir()) + modulePath := ts.GenerateModule(".", "pam_godebug.so") + + for ret, expected := range moduleArgsRetTypes { + ret := ret + expected := expected + for actionName, action := range utils.Actions { + ret := ret + expected := expected + actionName := actionName + action := action + t.Run(fmt.Sprintf("%s %s", ret, actionName), func(t *testing.T) { + t.Parallel() + serviceName := ret + "-" + actionName + moduleArgs := make([]string, 0) + for _, a := range debugModuleArgs { + moduleArgs = append(moduleArgs, fmt.Sprintf("%s=%s", a, ret)) + } + control := utils.Requisite + fallbackModule := utils.Permit + if ret == "success" { + fallbackModule = utils.Deny + control = utils.Sufficient + } + ts.CreateService(serviceName, []utils.ServiceLine{ + {Action: action, Control: control, Module: modulePath, Args: moduleArgs}, + {Action: action, Control: control, Module: fallbackModule.String(), Args: []string{}}, + }) + + tx, err := pam.StartConfDir(serviceName, "user", nil, ts.WorkDir()) + if err != nil { + t.Fatalf("start #error: %v", err) + } + defer func() { + err := tx.End() + if err != nil { + t.Fatalf("end #error: %v", err) + } + }() + + switch action { + case utils.Account: + err = tx.AcctMgmt(pam.Silent) + case utils.Auth: + err = tx.Authenticate(pam.Silent) + case utils.Password: + err = tx.ChangeAuthTok(pam.Silent) + case utils.Session: + err = tx.OpenSession(pam.Silent) + } + + if errors.Is(expected, pam.ErrIgnore) { + // Ignore can't be returned + expected = nil + } + + if !errors.Is(err, expected) { + t.Fatalf("error #unexpected %#v vs %#v", expected, err) + } + }) + } + } +} diff --git a/cmd/pam-moduler/tests/debug-module/pam_module.go b/cmd/pam-moduler/tests/debug-module/pam_module.go new file mode 100644 index 0000000..837842e --- /dev/null +++ b/cmd/pam-moduler/tests/debug-module/pam_module.go @@ -0,0 +1,96 @@ +// Code generated by "pam-moduler -libname pam_godebug.so"; DO NOT EDIT. + +//go:generate go build "-ldflags=-extldflags -Wl,-soname,pam_godebug.so" -buildmode=c-shared -o pam_godebug.so -tags go_pam_module + +// Package main is the package for the PAM module library. +package main + +/* +#cgo LDFLAGS: -lpam -fPIC +#include + +typedef const char _const_char_t; +*/ +import "C" + +import ( + "errors" + "fmt" + "github.com/msteinert/pam/v2" + "os" + "unsafe" +) + +// Do a typecheck at compile time +var _ pam.ModuleHandler = pamModuleHandler + +// sliceFromArgv returns a slice of strings given to the PAM module. +func sliceFromArgv(argc C.int, argv **C._const_char_t) []string { + r := make([]string, 0, argc) + for _, s := range unsafe.Slice(argv, argc) { + r = append(r, C.GoString(s)) + } + return r +} + +// handlePamCall is the function that translates C pam requests to Go. +func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int, + argv **C._const_char_t, moduleFunc pam.ModuleHandlerFunc) C.int { + if pamModuleHandler == nil { + return C.int(pam.ErrNoModuleData) + } + + if moduleFunc == nil { + return C.int(pam.ErrIgnore) + } + + mt := pam.NewModuleTransactionInvoker(pam.NativeHandle(pamh)) + err := mt.InvokeHandler(moduleFunc, pam.Flags(flags), + sliceFromArgv(argc, argv)) + if err == nil { + return 0 + } + + if (pam.Flags(flags)&pam.Silent) == 0 && !errors.Is(err, pam.ErrIgnore) { + fmt.Fprintf(os.Stderr, "module returned error: %v\n", err) + } + + var pamErr pam.Error + if errors.As(err, &pamErr) { + return C.int(pamErr) + } + + return C.int(pam.ErrSystem) +} + +//export pam_sm_authenticate +func pam_sm_authenticate(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.Authenticate) +} + +//export pam_sm_setcred +func pam_sm_setcred(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.SetCred) +} + +//export pam_sm_acct_mgmt +func pam_sm_acct_mgmt(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.AcctMgmt) +} + +//export pam_sm_open_session +func pam_sm_open_session(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.OpenSession) +} + +//export pam_sm_close_session +func pam_sm_close_session(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.CloseSession) +} + +//export pam_sm_chauthtok +func pam_sm_chauthtok(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.ChangeAuthTok) +} + +func main() {} diff --git a/cmd/pam-moduler/tests/integration-tester-module/communication.go b/cmd/pam-moduler/tests/integration-tester-module/communication.go new file mode 100644 index 0000000..67bada4 --- /dev/null +++ b/cmd/pam-moduler/tests/integration-tester-module/communication.go @@ -0,0 +1,230 @@ +// Package main is the package for the integration tester module PAM shared library. +package main + +import ( + "bytes" + "encoding/gob" + "errors" + "fmt" + "io" + "net" + "runtime" +) + +// Request is a serializable integration module tester structure request. +type Request struct { + Action string + ActionArgs []interface{} +} + +// Result is a serializable integration module tester structure result. +type Result = Request + +// NewRequest returns a new Request. +func NewRequest(action string, actionArgs ...interface{}) Request { + return Request{action, actionArgs} +} + +// GOB serializes the request in binary format. +func (r *Request) GOB() ([]byte, error) { + b := bytes.Buffer{} + e := gob.NewEncoder(&b) + if err := e.Encode(r); err != nil { + return nil, err + } + return b.Bytes(), nil +} + +// NewRequestFromGOB gets a Request from a serialized binary. +func NewRequestFromGOB(data []byte) (*Request, error) { + b := bytes.Buffer{} + b.Write(data) + d := gob.NewDecoder(&b) + + var req Request + if err := d.Decode(&req); err != nil { + return nil, err + } + return &req, nil +} + +const bufSize = 1024 + +type connectionHandler struct { + inOutData chan []byte + outErr chan error + SocketPath string +} + +// Listener is a socket listener. +type Listener struct { + connectionHandler + listener net.Listener +} + +// NewListener creates a new Listener. +func NewListener(socketPath string) *Listener { + if len(socketPath) > 90 { + // See https://manpages.ubuntu.com/manpages/jammy/man7/sys_un.h.7posix.html#application%20usage + panic(fmt.Sprintf("Socket path %s too long", socketPath)) + } + return &Listener{connectionHandler{SocketPath: socketPath}, nil} +} + +// WaitForData waits for result data (or an error) on connection to be returned. +func (c *connectionHandler) WaitForData() (*Result, error) { + data, err := <-c.inOutData, <-c.outErr + if err != nil { + if errors.Is(err, io.EOF) { + return nil, nil + } + return nil, err + } + + req, err := NewRequestFromGOB(data) + if err != nil { + return nil, err + } + + return req, nil +} + +// SendRequest sends a request to the connection. +func (c *connectionHandler) SendRequest(req *Request) error { + bytes, err := req.GOB() + if err != nil { + return err + } + + c.inOutData <- bytes + return nil +} + +// SendResult sends the Result to the connection. +func (c *connectionHandler) SendResult(res *Result) error { + return c.SendRequest(res) +} + +// DoRequest performs a Request on the connection, waiting for data. +func (c *connectionHandler) DoRequest(req *Request) (*Result, error) { + if err := c.SendRequest(req); err != nil { + return nil, err + } + + return c.WaitForData() +} + +// Send performs a request. +func (r *Request) Send(c *connectionHandler) error { + return c.SendRequest(r) +} + +// ErrAlreadyListening is the error if a listener is already set. +var ErrAlreadyListening = errors.New("listener already set") + +// StartListening initiates the unix listener. +func (l *Listener) StartListening() error { + if l.listener != nil { + return ErrAlreadyListening + } + + listener, err := net.Listen("unix", l.SocketPath) + if err != nil { + return err + } + + l.listener = listener + l.inOutData, l.outErr = make(chan []byte), make(chan error) + + go func() { + bytes, err := func() ([]byte, error) { + for { + c, err := l.listener.Accept() + if err != nil { + return nil, err + } + + for { + buf := make([]byte, bufSize) + nr, err := c.Read(buf) + if err != nil { + return buf, err + } + + data := buf[0:nr] + l.inOutData <- data + l.outErr <- nil + + _, err = c.Write(<-l.inOutData) + if err != nil { + return nil, err + } + } + } + }() + + l.inOutData <- bytes + l.outErr <- err + }() + + return nil +} + +// Connector is a connection type. +type Connector struct { + connectionHandler + connection net.Conn +} + +// NewConnector creates a new connection. +func NewConnector(socketPath string) *Connector { + return &Connector{connectionHandler{SocketPath: socketPath}, nil} +} + +// ErrAlreadyConnected is the error if a connection is already set. +var ErrAlreadyConnected = errors.New("connection already set") + +// Connect connects to a listening unix socket. +func (c *Connector) Connect() error { + if c.connection != nil { + return ErrAlreadyConnected + } + + connection, err := net.Dial("unix", c.SocketPath) + if err != nil { + return err + } + + runtime.SetFinalizer(c, func(c *Connector) { + c.connection.Close() + }) + + c.connection = connection + c.inOutData, c.outErr = make(chan []byte), make(chan error) + + go func() { + buf := make([]byte, bufSize) + writeAndRead := func() ([]byte, error) { + data := <-c.inOutData + _, err := c.connection.Write(data) + if err != nil { + return nil, err + } + + n, err := c.connection.Read(buf[:]) + if err != nil { + return nil, err + } + + return buf[0:n], nil + } + + for { + bytes, err := writeAndRead() + c.inOutData <- bytes + c.outErr <- err + } + }() + + return nil +} diff --git a/cmd/pam-moduler/tests/integration-tester-module/communication_test.go b/cmd/pam-moduler/tests/integration-tester-module/communication_test.go new file mode 100644 index 0000000..7ef01f7 --- /dev/null +++ b/cmd/pam-moduler/tests/integration-tester-module/communication_test.go @@ -0,0 +1,107 @@ +package main + +import ( + "errors" + "path/filepath" + "reflect" + "testing" + + "github.com/msteinert/pam/v2/cmd/pam-moduler/tests/internal/utils" +) + +func ensureNoError(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatalf("unexpected error %v", err) + } +} + +func ensureError(t *testing.T, err error, expected error) { + t.Helper() + if err == nil { + t.Fatalf("error was expected, got none") + } + if !errors.Is(err, expected) { + t.Fatalf("error %v was expected, got %v", err, expected) + } +} + +func ensureEqual(t *testing.T, a any, b any) { + t.Helper() + if !reflect.DeepEqual(a, b) { + t.Fatalf("values mismatch %#v vs %#v", a, b) + } +} + +func Test_Communication(t *testing.T) { + t.Parallel() + + ts := utils.NewTestSetup(t, utils.WithWorkDir()) + + for _, name := range []string{"test-1", "test-2"} { + name := name + t.Run(name, func(t *testing.T) { + t.Parallel() + socketPath := filepath.Join(ts.WorkDir(), name+".socket") + + listener := NewListener(socketPath) + connector := NewConnector(socketPath) + + ensureNoError(t, listener.StartListening()) + ensureNoError(t, connector.Connect()) + + ensureError(t, listener.StartListening(), ErrAlreadyListening) + ensureError(t, connector.Connect(), ErrAlreadyConnected) + + resChan, errChan := make(chan *Result), make(chan error) + go func() { + res, err := listener.WaitForData() + resChan <- res + errChan <- err + }() + + req := NewRequest("A Request") + ensureNoError(t, connector.SendRequest(&req)) + + res, err := <-resChan, <-errChan + ensureNoError(t, err) + ensureEqual(t, *res, req) + + go func() { + res := NewRequest("Listener result") + ensureNoError(t, listener.SendResult(&res)) + }() + + res, err = connector.WaitForData() + ensureNoError(t, err) + ensureEqual(t, *res, NewRequest("Listener result")) + + go func() { + req, err := listener.WaitForData() + res := NewRequest("Response", *req) + + defer func() { + resChan <- &res + errChan <- err + }() + ensureNoError(t, listener.SendResult(&res)) + }() + + done := make(chan bool) + req = NewRequest("Requesting...") + go func() { + defer func() { + done <- true + }() + res, err := connector.DoRequest(&req) + ensureNoError(t, err) + ensureEqual(t, *res, NewRequest("Response", req)) + }() + + res, err = <-resChan, <-errChan + ensureNoError(t, err) + ensureEqual(t, *res, NewRequest("Response", req)) + <-done + }) + } +} diff --git a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go new file mode 100644 index 0000000..fcdeaa9 --- /dev/null +++ b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go @@ -0,0 +1,159 @@ +//go:generate go run github.com/msteinert/pam/v2/cmd/pam-moduler -type integrationTesterModule -parallel-conv +//go:generate go generate --skip="pam_module.go" + +// Package main is the package for the integration tester module PAM shared library. +package main + +import ( + "errors" + "fmt" + "reflect" + "strings" + + "github.com/msteinert/pam/v2" + "github.com/msteinert/pam/v2/cmd/pam-moduler/tests/internal/utils" +) + +type integrationTesterModule struct { + utils.BaseModule +} + +type authRequest struct { + mt pam.ModuleTransaction + lastError error +} + +func (m *integrationTesterModule) handleRequest(authReq *authRequest, r *Request) (res *Result, err error) { + switch r.Action { + case "bye": + return nil, authReq.lastError + } + + defer func() { + if p := recover(); p != nil { + if s, ok := p.(string); ok { + if strings.HasPrefix(s, "reflect:") { + res = nil + err = &utils.SerializableError{Msg: fmt.Sprintf( + "error on request %v: %v", *r, p)} + authReq.lastError = err + return + } + } + panic(p) + } + + if err != nil { + authReq.lastError = err + } + }() + + method := reflect.ValueOf(authReq.mt).MethodByName(r.Action) + if method == (reflect.Value{}) { + return nil, &utils.SerializableError{Msg: fmt.Sprintf( + "no method %s found", r.Action)} + } + + var args []reflect.Value + for i, arg := range r.ActionArgs { + switch v := arg.(type) { + case SerializableStringConvRequest: + args = append(args, reflect.ValueOf( + pam.NewStringConvRequest(v.Style, v.Request))) + case SerializableBinaryConvRequest: + args = append(args, reflect.ValueOf( + pam.NewBinaryConvRequestFromBytes(v.Request))) + default: + if arg == nil { + args = append(args, reflect.Zero(method.Type().In(i))) + } else { + args = append(args, reflect.ValueOf(arg)) + } + } + } + + res = &Result{Action: "return"} + for _, ret := range method.Call(args) { + iface := ret.Interface() + switch value := iface.(type) { + case pam.StringConvResponse: + res.ActionArgs = append(res.ActionArgs, + SerializableStringConvResponse{value.Style(), value.Response()}) + case pam.BinaryConvResponse: + data, err := value.Decode(utils.TestBinaryDataDecoder) + if err != nil { + return nil, err + } + res.ActionArgs = append(res.ActionArgs, SerializableBinaryConvResponse{data}) + case pam.Error: + authReq.lastError = value + res.ActionArgs = append(res.ActionArgs, value) + case error: + var pamError pam.Error + if errors.As(value, &pamError) { + retErr := &SerializablePamError{Msg: value.Error(), + RetStatus: pamError} + authReq.lastError = retErr + res.ActionArgs = append(res.ActionArgs, retErr) + return res, err + } + authReq.lastError = value + res.ActionArgs = append(res.ActionArgs, + &utils.SerializableError{Msg: value.Error()}) + default: + res.ActionArgs = append(res.ActionArgs, iface) + } + } + return res, err +} + +func (m *integrationTesterModule) handleError(err error) *Result { + return &Result{ + Action: "error", + ActionArgs: []interface{}{&utils.SerializableError{Msg: err.Error()}}, + } +} + +func (m *integrationTesterModule) Authenticate(mt pam.ModuleTransaction, _ pam.Flags, args []string) error { + if len(args) != 1 { + return errors.New("Invalid arguments") + } + + authRequest := authRequest{mt, nil} + connection := NewConnector(args[0]) + if err := connection.Connect(); err != nil { + return err + } + + connectionHandler := func() error { + if err := connection.SendRequest(&Request{Action: "hello"}); err != nil { + return err + } + + for { + req, err := connection.WaitForData() + if err != nil { + return err + } + + res, err := m.handleRequest(&authRequest, req) + if err != nil { + _ = connection.SendResult(m.handleError(err)) + return err + } + if res == nil { + return nil + } + if err := connection.SendResult(res); err != nil { + _ = connection.SendResult(m.handleError(err)) + return err + } + } + } + + if err := connectionHandler(); err != nil { + return err + } + + return nil +} diff --git a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go new file mode 100644 index 0000000..45acc70 --- /dev/null +++ b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go @@ -0,0 +1,1281 @@ +package main + +import ( + "errors" + "fmt" + "path/filepath" + "reflect" + "runtime" + "strings" + "testing" + "time" + + "github.com/msteinert/pam/v2" + "github.com/msteinert/pam/v2/cmd/pam-moduler/tests/internal/utils" +) + +func (r *Request) check(res *Result, expectedResults []interface{}) error { + switch res.Action { + case "return": + case "error": + return fmt.Errorf("module error: %v", res.ActionArgs...) + default: + return fmt.Errorf("unexpected action %v", res.Action) + } + + if !reflect.DeepEqual(res.ActionArgs, expectedResults) { + return fmt.Errorf("unexpected return values %#v vs %#v", + res.ActionArgs, expectedResults) + } + + return nil +} + +func (r *Request) checkRemote(listener *Listener, expectedResults []interface{}) error { + res, err := listener.DoRequest(r) + if err != nil { + return err + } + + return res.check(res, expectedResults) +} + +type checkedRequest struct { + r Request + exp []interface{} + compareWithTestState bool +} + +func (cr *checkedRequest) checkRemote(listener *Listener) error { + return cr.r.checkRemote(listener, cr.exp) +} + +func (cr *checkedRequest) check(res *Result) error { + return cr.r.check(res, cr.exp) +} + +func ensureUser(tx *pam.Transaction, expected string) error { + item := pam.User + if value, err := tx.GetItem(item); err != nil { + return err + } else if value != expected { + return fmt.Errorf("invalid item %v value: %s vs %v", item, value, expected) + } + return nil +} + +func ensureEnv(tx *pam.Transaction, variable string, expected string) error { + if env := tx.GetEnv(variable); env != expected { + return fmt.Errorf("unexpected env %s value: %s vs %s", variable, env, expected) + } + return nil +} + +func (r *Request) toBytes(t *testing.T) []byte { + t.Helper() + bytes, err := r.GOB() + if err != nil { + t.Fatalf("error: %v", err) + return nil + } + return bytes +} + +func (r *Request) toTransactionData(t *testing.T) []byte { + t.Helper() + return utils.TestBinaryDataEncoder(r.toBytes(t)) +} + +func Test_Moduler_IntegrationTesterModule(t *testing.T) { + t.Parallel() + if !pam.CheckPamHasStartConfdir() { + t.Skip("this requires PAM with Conf dir support") + } + + ts := utils.NewTestSetup(t, utils.WithWorkDir()) + modulePath := ts.GenerateModuleDefault(ts.GetCurrentFileDir()) + + type testState = map[string]interface{} + + tests := map[string]struct { + expectedError error + user string + credentials pam.ConversationHandler + checkedRequests []checkedRequest + setup func(*pam.Transaction, *Listener, testState) error + finish func(*pam.Transaction, *Listener, testState) error + }{ + "success": { + expectedError: nil, + }, + "get-item-Service": { + checkedRequests: []checkedRequest{{ + r: NewRequest("GetItem", pam.Service), + exp: []interface{}{"get-item-service", nil}, + }}, + }, + "get-item-User-empty": { + checkedRequests: []checkedRequest{{ + r: NewRequest("GetItem", pam.User), + exp: []interface{}{"", nil}, + }}, + }, + "get-item-User-preset": { + user: "test-user", + checkedRequests: []checkedRequest{{ + r: NewRequest("GetItem", pam.User), + exp: []interface{}{"test-user", nil}, + }}, + }, + "get-item-Authtok-empty": { + checkedRequests: []checkedRequest{{ + r: NewRequest("GetItem", pam.Authtok), + exp: []interface{}{"", nil}, + }}, + }, + "get-item-Oldauthtok-empty": { + checkedRequests: []checkedRequest{{ + r: NewRequest("GetItem", pam.Oldauthtok), + exp: []interface{}{"", nil}, + }}, + }, + "get-item-UserPrompt-empty": { + checkedRequests: []checkedRequest{{ + r: NewRequest("GetItem", pam.UserPrompt), + exp: []interface{}{"", nil}, + }}, + }, + "set-item-Service": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetItem", pam.Service, "foo-service"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetItem", pam.Service), + exp: []interface{}{"foo-service", nil}, + }, + }, + }, + "set-item-User-empty": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetItem", pam.User, "an-user"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetItem", pam.User), + exp: []interface{}{"an-user", nil}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureUser(tx, "an-user") + }, + }, + "set-item-User-preset": { + user: "test-user", + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetItem", pam.User, "an-user"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetItem", pam.User), + exp: []interface{}{"an-user", nil}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureUser(tx, "an-user") + }, + }, + "set-get-item-User-empty": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + return tx.SetItem(pam.User, "setup-user") + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetItem", pam.User), + exp: []interface{}{"setup-user", nil}, + }}, + }, + "set-get-item-User-preset": { + user: "test-user", + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + return tx.SetItem(pam.User, "setup-user") + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetItem", pam.User), + exp: []interface{}{"setup-user", nil}, + }}, + }, + "get-env-unset": { + checkedRequests: []checkedRequest{{ + r: NewRequest("GetEnv", "_PAM_GO_HOPEFULLY_NOT_SET"), + exp: []interface{}{""}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureEnv(tx, "_PAM_GO_HOPEFULLY_NOT_SET", "") + }, + }, + "get-env-preset": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + return tx.PutEnv("_PAM_GO_ENV_SET_VAR=foobar") + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{"foobar"}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureEnv(tx, "_PAM_GO_ENV_SET_VAR", "foobar") + }, + }, + "get-env-preset-empty": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + if err := tx.PutEnv("_PAM_GO_ENV_SET_VAR=value"); err != nil { + return err + } + return tx.PutEnv("_PAM_GO_ENV_SET_VAR=") + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{""}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureEnv(tx, "_PAM_GO_ENV_SET_VAR", "") + }, + }, + "get-env-preset-unset": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + if err := tx.PutEnv("_PAM_GO_ENV_SET_VAR=value"); err != nil { + return err + } + return tx.PutEnv("_PAM_GO_ENV_SET_VAR") + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{""}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureEnv(tx, "_PAM_GO_ENV_SET_VAR", "") + }, + }, + "put-env-not-preset": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR=a value"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{"a value"}, + }, + }, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureEnv(tx, "_PAM_GO_ENV_SET_VAR", "a value") + }, + }, + "put-env-preset": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + return tx.PutEnv("_PAM_GO_ENV_SET_VAR=foobar") + }, + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR=another value"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{"another value"}, + }, + }, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureEnv(tx, "_PAM_GO_ENV_SET_VAR", "another value") + }, + }, + "put-env-resets-not-preset": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR=a value"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{"a value"}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR="), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{""}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{""}, + }, + }, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureEnv(tx, "_PAM_GO_ENV_SET_VAR", "") + }, + }, + "put-env-resets-preset": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + return tx.PutEnv("_PAM_GO_ENV_SET_VAR=foobar") + }, + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR=a value"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{"a value"}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR="), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{""}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{""}, + }, + }, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureEnv(tx, "_PAM_GO_ENV_SET_VAR", "") + }, + }, + "put-env-unsets-not-set": { + expectedError: pam.ErrBadItem, + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR_NEVER_SET"), + exp: []interface{}{pam.ErrBadItem}, + }, + }, + }, + "put-env-unsets-empty-value": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR="), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnvList"), + exp: []interface{}{ + map[string]string{"_PAM_GO_ENV_SET_VAR": ""}, nil, + }, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnvList"), + exp: []interface{}{map[string]string{}, nil}, + }, + }, + }, + "put-env-invalid-syntax": { + expectedError: pam.ErrBadItem, + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "="), + exp: []interface{}{pam.ErrBadItem}, + }, + { + r: NewRequest("PutEnv", "=bar"), + exp: []interface{}{pam.ErrBadItem}, + }, + { + r: NewRequest("PutEnv", "with spaces"), + exp: []interface{}{pam.ErrBadItem}, + }, + }, + }, + "get-env-list-empty": { + checkedRequests: []checkedRequest{{ + r: NewRequest("GetEnvList"), + exp: []interface{}{map[string]string{}, nil}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return nil + }, + }, + "get-env-list-preset": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + expected := map[string]string{ + "_PAM_GO_ENV_SET_VAR1": "value1", + "_PAM_GO_ENV_SET_VAR2": "value due", + "_PAM_GO_ENV_SET_VAR3": "3", + "_PAM_GO_ENV_SET_VAR_EMPTY": "", + "_PAM_GO_ENV WITH SPACES": "yes works", + } + + for env, value := range expected { + if err := tx.PutEnv(fmt.Sprintf("%s=%s", env, value)); err != nil { + return err + } + } + ts["expected"] = expected + ts["expectedResults"] = [][]interface{}{{expected, nil}} + return nil + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetEnvList"), + compareWithTestState: true, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + if list, err := tx.GetEnvList(); err != nil { + return err + } else if !reflect.DeepEqual(list, ts["expected"]) { + return fmt.Errorf("Unexpected return values %#v vs %#v", + list, ts["expected"]) + } + return nil + }, + }, + "get-env-list-module-set": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + expected := map[string]string{ + "_PAM_GO_ENV_SET_VAR1": "value1", + "_PAM_GO_ENV_SET_VAR2": "value due", + "_PAM_GO_ENV_SET_VAR3": "3", + "_PAM_GO_ENV_SET_VAR_EMPTY": "", + "_PAM_GO_ENV WITH SPACES": "yes works", + } + + ts["expected"] = expected + ts["expectedResults"] = [][]interface{}{ + nil, nil, nil, nil, nil, nil, nil, {expected, nil}, + } + return nil + }, + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR1=value1"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR2=value due"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR3=3"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR_EMPTY="), + exp: []interface{}{nil}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR_TO_UNSET=unset"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR_TO_UNSET"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV WITH SPACES=yes works"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnvList"), + compareWithTestState: true, + }, + }, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + if list, err := tx.GetEnvList(); err != nil { + return err + } else if !reflect.DeepEqual(list, ts["expected"]) { + return fmt.Errorf("unexpected return values %#v vs %#v", + list, ts["expected"]) + } + return nil + }, + }, + "get-user-empty-no-conv-set": { + expectedError: pam.ErrConv, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetUser", "who are you? "), + exp: []interface{}{"", pam.ErrConv}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureUser(tx, "") + }, + }, + "get-user-empty-with-conv": { + credentials: utils.Credentials{ + User: "replying-user", + ExpectedMessage: "who are you? ", + ExpectedStyle: pam.PromptEchoOn, + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetUser", "who are you? "), + exp: []interface{}{"replying-user", nil}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureUser(tx, "replying-user") + }, + }, + "get-user-preset-without-conv": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + return tx.SetItem(pam.User, "setup-user") + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetUser", "who are you? "), + exp: []interface{}{"setup-user", nil}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureUser(tx, "setup-user") + }, + }, + "get-user-preset-with-conv": { + credentials: utils.Credentials{ + User: "replying-user", + ExpectedMessage: "No message should have been shown!", + ExpectedStyle: pam.PromptEchoOn, + }, + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + return tx.SetItem(pam.User, "setup-user") + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetUser", "who are you? "), + exp: []interface{}{"setup-user", nil}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureUser(tx, "setup-user") + }, + }, + "get-data-not-available": { + expectedError: pam.ErrNoModuleData, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetData", "some-data"), + exp: []interface{}{nil, pam.ErrNoModuleData}, + }}, + }, + "set-data-empty-nil": { + expectedError: pam.ErrNoModuleData, + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "", nil), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", ""), + exp: []interface{}{nil, pam.ErrNoModuleData}, + }, + }, + }, + "set-data-empty-to-value": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "", []string{"hello", "world"}), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", ""), + exp: []interface{}{[]string{"hello", "world"}, nil}, + }, + }, + }, + "set-data-to-value": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "some-error-data", + utils.SerializableError{Msg: "An error"}), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", "some-error-data"), + exp: []interface{}{utils.SerializableError{Msg: "An error"}, nil}, + }, + }, + }, + "set-data-to-value-replacing": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "some-data", + utils.SerializableError{Msg: "An error"}), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", "some-data"), + exp: []interface{}{utils.SerializableError{Msg: "An error"}, nil}, + }, + { + r: NewRequest("SetData", "some-data", "Hello"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", "some-data"), + exp: []interface{}{"Hello", nil}, + }, + }, + }, + "set-data-to-value-unset": { + expectedError: pam.ErrNoModuleData, + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "some-data", + utils.SerializableError{Msg: "An error"}), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", "some-data"), + exp: []interface{}{utils.SerializableError{Msg: "An error"}, nil}, + }, + { + r: NewRequest("SetData", "some-data", nil), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", "some-data"), + exp: []interface{}{nil, pam.ErrNoModuleData}, + }, + }, + }, + "start-conv-no-conv-set": { + expectedError: pam.ErrConv, + checkedRequests: []checkedRequest{ + { + r: NewRequest("StartConv", SerializableStringConvRequest{ + pam.TextInfo, + "hello PAM!", + }), + exp: []interface{}{nil, pam.ErrConv}, + }, + { + r: NewRequest("StartStringConv", pam.TextInfo, "hello PAM!"), + exp: []interface{}{nil, pam.ErrConv}, + }, + }, + }, + "start-conv-prompt-text-info": { + credentials: utils.Credentials{ + ExpectedMessage: "hello PAM!", + ExpectedStyle: pam.TextInfo, + TextInfo: "nice to see you, Go!", + }, + checkedRequests: []checkedRequest{ + { + r: NewRequest("StartConv", SerializableStringConvRequest{ + pam.TextInfo, + "hello PAM!", + }), + exp: []interface{}{SerializableStringConvResponse{ + pam.TextInfo, + "nice to see you, Go!", + }, nil}, + }, + { + r: NewRequest("StartStringConv", pam.TextInfo, "hello PAM!"), + exp: []interface{}{SerializableStringConvResponse{ + pam.TextInfo, + "nice to see you, Go!", + }, nil}, + }, + { + r: NewRequest("StartStringConvf", pam.TextInfo, "hello %s!", "PAM"), + exp: []interface{}{SerializableStringConvResponse{ + pam.TextInfo, + "nice to see you, Go!", + }, nil}, + }, + }, + }, + "start-conv-prompt-error-msg": { + credentials: utils.Credentials{ + ExpectedMessage: "This is wrong, PAM!", + ExpectedStyle: pam.ErrorMsg, + ErrorMsg: "ops, sorry...", + }, + checkedRequests: []checkedRequest{ + { + r: NewRequest("StartConv", SerializableStringConvRequest{ + pam.ErrorMsg, + "This is wrong, PAM!", + }), + exp: []interface{}{SerializableStringConvResponse{ + pam.ErrorMsg, + "ops, sorry...", + }, nil}, + }, + { + r: NewRequest("StartStringConv", pam.ErrorMsg, + "This is wrong, PAM!", + ), + exp: []interface{}{SerializableStringConvResponse{ + pam.ErrorMsg, + "ops, sorry...", + }, nil}, + }, + { + r: NewRequest("StartStringConvf", pam.ErrorMsg, + "This is wrong, %s!", "PAM", + ), + exp: []interface{}{SerializableStringConvResponse{ + pam.ErrorMsg, + "ops, sorry...", + }, nil}, + }, + }, + }, + "start-conv-prompt-echo-on": { + credentials: utils.Credentials{ + ExpectedMessage: "Give me your non-private infos", + ExpectedStyle: pam.PromptEchoOn, + EchoOn: "here's my public data", + }, + checkedRequests: []checkedRequest{ + { + r: NewRequest("StartConv", SerializableStringConvRequest{ + pam.PromptEchoOn, + "Give me your non-private infos", + }), + exp: []interface{}{SerializableStringConvResponse{ + pam.PromptEchoOn, + "here's my public data", + }, nil}, + }, + { + r: NewRequest("StartStringConv", pam.PromptEchoOn, + "Give me your non-private infos", + ), + exp: []interface{}{SerializableStringConvResponse{ + pam.PromptEchoOn, + "here's my public data", + }, nil}, + }, + }, + }, + "start-conv-prompt-echo-off": { + credentials: utils.Credentials{ + ExpectedMessage: "Give me your super-secret data", + ExpectedStyle: pam.PromptEchoOff, + EchoOff: "here's my private token", + }, + checkedRequests: []checkedRequest{ + { + r: NewRequest("StartConv", SerializableStringConvRequest{ + pam.PromptEchoOff, + "Give me your super-secret data", + }), + exp: []interface{}{SerializableStringConvResponse{ + pam.PromptEchoOff, + "here's my private token", + }, nil}, + }, + { + r: NewRequest("StartStringConv", pam.PromptEchoOff, + "Give me your super-secret data", + ), + exp: []interface{}{SerializableStringConvResponse{ + pam.PromptEchoOff, + "here's my private token", + }, nil}, + }, + }, + }, + "start-conv-text-info-handle-failure-message-mismatch": { + expectedError: pam.ErrConv, + credentials: utils.Credentials{ + ExpectedMessage: "This is an info message", + ExpectedStyle: pam.TextInfo, + TextInfo: "And this is what is returned", + }, + checkedRequests: []checkedRequest{ + { + r: NewRequest("StartConv", SerializableStringConvRequest{ + pam.TextInfo, + "This should have been an info message, but is not", + }), + exp: []interface{}{nil, pam.ErrConv}, + }, + { + r: NewRequest("StartStringConv", pam.TextInfo, + "This should have been an info message, but is not", + ), + exp: []interface{}{nil, pam.ErrConv}, + }, + }, + }, + "start-conv-text-info-handle-failure-style-mismatch": { + expectedError: pam.ErrConv, + credentials: utils.Credentials{ + ExpectedMessage: "This is an info message", + ExpectedStyle: pam.PromptEchoOff, + TextInfo: "And this is what is returned", + }, + checkedRequests: []checkedRequest{ + { + r: NewRequest("StartConv", SerializableStringConvRequest{ + pam.TextInfo, + "This is an info message", + }), + exp: []interface{}{nil, pam.ErrConv}, + }, + { + r: NewRequest("StartStringConv", pam.TextInfo, + "This is an info message", + ), + exp: []interface{}{nil, pam.ErrConv}, + }, + }, + }, + "start-conv-binary": { + credentials: utils.NewBinaryTransactionWithData([]byte( + "\x00This is a binary data request\xC5\x00\xffYes it is!"), + []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}), + checkedRequests: []checkedRequest{ + { + r: NewRequest("StartConv", SerializableBinaryConvRequest{ + utils.TestBinaryDataEncoder( + []byte("\x00This is a binary data request\xC5\x00\xffYes it is!")), + }), + exp: []interface{}{SerializableBinaryConvResponse{ + []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, + }, nil}, + }, + { + r: NewRequest("StartBinaryConv", + utils.TestBinaryDataEncoder( + []byte("\x00This is a binary data request\xC5\x00\xffYes it is!"))), + exp: []interface{}{SerializableBinaryConvResponse{ + []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, + }, nil}, + }, + }, + }, + "start-conv-binary-handle-failure-passed-data-mismatch": { + expectedError: pam.ErrConv, + credentials: utils.NewBinaryTransactionWithData([]byte( + "\x00This is a binary data request\xC5\x00\xffYes it is!"), + []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}), + checkedRequests: []checkedRequest{ + { + r: NewRequest("StartConv", SerializableBinaryConvRequest{ + (&Request{"Not the expected binary data", nil}).toTransactionData(t), + }), + exp: []interface{}{nil, pam.ErrConv}, + }, + { + r: NewRequest("StartBinaryConv", + (&Request{"Not the expected binary data", nil}).toTransactionData(t)), + exp: []interface{}{nil, pam.ErrConv}, + }, + }, + }, + "start-conv-binary-handle-failure-returned-data-mismatch": { + expectedError: pam.ErrConv, + credentials: utils.NewBinaryTransactionWithRandomData(100, + []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}), + checkedRequests: []checkedRequest{ + { + r: NewRequest("StartConv", SerializableBinaryConvRequest{ + (&Request{"Wrong binary data", nil}).toTransactionData(t), + }), + exp: []interface{}{nil, pam.ErrConv}, + }, + { + r: NewRequest("StartBinaryConv", + (&Request{"Wrong binary data", nil}).toTransactionData(t)), + exp: []interface{}{nil, pam.ErrConv}, + }, + }, + }, + "start-conv-binary-in-nil": { + credentials: utils.NewBinaryTransactionWithData(nil, + (&Request{"Binary data", []interface{}{true, 123, 0.5, "yay!"}}).toBytes(t)), + checkedRequests: []checkedRequest{ + { + r: NewRequest("StartConv", SerializableBinaryConvRequest{}), + exp: []interface{}{SerializableBinaryConvResponse{ + (&Request{"Binary data", []interface{}{true, 123, 0.5, "yay!"}}).toBytes(t), + }, nil}, + }, + { + r: NewRequest("StartBinaryConv", nil), + exp: []interface{}{SerializableBinaryConvResponse{ + (&Request{"Binary data", []interface{}{true, 123, 0.5, "yay!"}}).toBytes(t), + }, nil}, + }, + }, + }, + "start-conv-binary-out-nil": { + credentials: utils.NewBinaryTransactionWithData([]byte( + "\x00This is a binary data request\xC5\x00\xffGimme nil!"), nil), + checkedRequests: []checkedRequest{ + { + r: NewRequest("StartConv", SerializableBinaryConvRequest{ + utils.TestBinaryDataEncoder( + []byte("\x00This is a binary data request\xC5\x00\xffGimme nil!")), + }), + exp: []interface{}{SerializableBinaryConvResponse{}, nil}, + }, + { + r: NewRequest("StartBinaryConv", + utils.TestBinaryDataEncoder( + []byte("\x00This is a binary data request\xC5\x00\xffGimme nil!"))), + exp: []interface{}{SerializableBinaryConvResponse{}, nil}, + }, + }, + }, + } + + for name, tc := range tests { + tc := tc + name := name + t.Run(name, func(t *testing.T) { + t.Parallel() + socketPath := filepath.Join(ts.WorkDir(), name+".socket") + ts.CreateService(name, []utils.ServiceLine{ + {Action: utils.Auth, Control: utils.Requisite, Module: modulePath, + Args: []string{socketPath}}, + }) + + switch tc.credentials.(type) { + case pam.BinaryConversationHandler: + if !pam.CheckPamHasBinaryProtocol() { + t.Skip("Binary protocol is not supported") + } + case pam.BinaryPointerConversationHandler: + if !pam.CheckPamHasBinaryProtocol() { + t.Skip("Binary protocol is not supported") + } + } + + tx, err := pam.StartConfDir(name, tc.user, tc.credentials, ts.WorkDir()) + if err != nil { + t.Fatalf("start #error: %v", err) + } + defer func() { + err := tx.End() + if err != nil { + t.Fatalf("end #error: %v", err) + } + }() + + listener := NewListener(socketPath) + if err := listener.StartListening(); err != nil { + t.Fatalf("listening #error: %v", err) + } + + listenerHandler := func() error { + res, err := listener.WaitForData() + if err != nil { + return err + } + + if res == nil || res.Action != "hello" { + return errors.New("missing hello packet") + } + + req := NewRequest("GetItem", pam.Service) + if err := req.checkRemote(listener, + []interface{}{strings.ToLower(name), nil}); err != nil { + return err + } + + testState := testState{} + if tc.setup != nil { + if err := tc.setup(tx, listener, testState); err != nil { + return err + } + } + + for i, req := range tc.checkedRequests { + if req.compareWithTestState { + expectedResults, _ := testState["expectedResults"].([][]interface{}) + if err := req.r.checkRemote(listener, expectedResults[i]); err != nil { + return err + } + } else if err := req.checkRemote(listener); err != nil { + return err + } + } + + if tc.finish != nil { + if err := tc.finish(tx, listener, testState); err != nil { + return err + } + } + + if err := listener.SendRequest(&Request{Action: "bye"}); err != nil { + return err + } + + return nil + } + + serverError := make(chan error) + go func() { + serverError <- listenerHandler() + }() + + authResult := make(chan error) + go func() { + authResult <- tx.Authenticate(pam.Silent) + }() + + if err = <-serverError; err != nil { + t.Fatalf("communication #error: %v", err) + } + + err = <-authResult + if !errors.Is(err, tc.expectedError) { + t.Fatalf("authenticate #unexpected: %#v vs %#v", + err, tc.expectedError) + } + }) + } + + t.Cleanup(func() { + // Ensure GC will happen, so that transaction's pam_end will be called + runtime.GC() + time.Sleep(5 * time.Millisecond) + }) +} + +func Test_Moduler_IntegrationTesterModule_handleRequest(t *testing.T) { + t.Parallel() + + module := integrationTesterModule{} + mt := pam.NewModuleTransactionInvoker(nil) + + tests := []struct { + checkedRequest + name string + parallel bool + }{ + { + name: "putEnv", + checkedRequest: checkedRequest{ + r: NewRequest("PutEnv", "FOO_ENV=Bar"), + exp: []interface{}{pam.ErrAbort}, + }, + }, + { + parallel: true, + name: "get-item-Service", + checkedRequest: checkedRequest{ + r: NewRequest("GetItem", pam.Service), + exp: []interface{}{"", pam.ErrSystem}, + }, + }, + { + parallel: true, + name: "set-item-Service", + checkedRequest: checkedRequest{ + r: NewRequest("SetItem", pam.Service, "foo"), + exp: []interface{}{pam.ErrSystem}, + }, + }, + } + + for _, cr := range tests { + cr := cr + t.Run(cr.name, func(t *testing.T) { + if cr.parallel { + t.Parallel() + } + + authRequest := authRequest{mt, nil} + res, err := module.handleRequest(&authRequest, &cr.r) + if err != nil { + t.Fatalf("unexpected error %v", err) + } + + if res.Action != "return" { + t.Fatalf("unexpected result action %v", res.Action) + } + + if err := cr.check(res); err != nil { + t.Fatalf("unexpected result %v", err) + } + }) + } + + t.Run("missing-method", func(t *testing.T) { + t.Parallel() + req := NewRequest("Hopefully a missing method") + res, err := module.handleRequest(&authRequest{mt, nil}, &req) + + if err == nil { + t.Fatalf("error was expected, got %v", res) + } + if res != nil { + t.Fatalf("unexpected result %v", res) + } + }) + + t.Run("wrong-signature", func(t *testing.T) { + t.Parallel() + req := NewRequest("GetItem", "this", "and", 3, "of that") + res, err := module.handleRequest(&authRequest{mt, nil}, &req) + + if err == nil { + t.Fatalf("error was expected, got %v", res) + } + if res != nil { + t.Fatalf("unexpected result %v", res) + } + }) +} + +func Test_Moduler_IntegrationTesterModule_Authenticate(t *testing.T) { + t.Parallel() + + ts := utils.NewTestSetup(t, utils.WithWorkDir()) + module := integrationTesterModule{} + + tests := map[string]struct { + expectedError error + credentials pam.ConversationHandler + checkedRequests []checkedRequest + }{ + "success": { + expectedError: nil, + }, + "get-item-Service": { + expectedError: pam.ErrSystem, + checkedRequests: []checkedRequest{ + { + r: NewRequest("GetItem", pam.Service), + exp: []interface{}{"", pam.ErrSystem}, + }, + }, + }, + "get-item-User": { + expectedError: pam.ErrSystem, + checkedRequests: []checkedRequest{ + { + r: NewRequest("GetItem", pam.User), + exp: []interface{}{"", pam.ErrSystem}, + }, + }, + }, + "putEnv": { + expectedError: pam.ErrAbort, + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "FooBar=Baz"), + exp: []interface{}{pam.ErrAbort}, + }, + }, + }, + "SetData-nil": { + expectedError: pam.ErrSystem, + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "some-data", nil), + exp: []interface{}{pam.ErrSystem}, + }, + }, + }, + "SetData": { + expectedError: pam.ErrSystem, + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "some-data", true), + exp: []interface{}{pam.ErrSystem}, + }, + }, + }, + "StartConv": { + expectedError: pam.ErrSystem, + checkedRequests: []checkedRequest{{ + r: NewRequest("StartConv", SerializableStringConvRequest{ + pam.TextInfo, + "hello PAM!", + }), + exp: []interface{}{nil, pam.ErrSystem}, + }}, + }, + "StartStringConv": { + expectedError: pam.ErrSystem, + checkedRequests: []checkedRequest{{ + r: NewRequest("StartStringConv", pam.TextInfo, "hello PAM!"), + exp: []interface{}{nil, pam.ErrSystem}, + }}, + }, + "StartConv-Binary": { + expectedError: pam.ErrSystem, + checkedRequests: []checkedRequest{{ + r: NewRequest("StartConv", SerializableBinaryConvRequest{ + []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, + }), + exp: []interface{}{nil, pam.ErrSystem}, + }}, + }, + } + + for name, tc := range tests { + tc := tc + name := name + t.Run(name, func(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(ts.WorkDir(), name+".socket") + listener := NewListener(socketPath) + if err := listener.StartListening(); err != nil { + t.Fatalf("listening #error: %v", err) + } + + listenerHandler := func() error { + res, err := listener.WaitForData() + if err != nil { + return err + } + + if res == nil || res.Action != "hello" { + return errors.New("missing hello packet") + } + + for _, req := range tc.checkedRequests { + if err := req.checkRemote(listener); err != nil { + return err + } + } + + if err := listener.SendRequest(&Request{Action: "bye"}); err != nil { + return err + } + + return nil + } + + serverError := make(chan error) + go func() { + serverError <- listenerHandler() + }() + + authResult := make(chan error) + go func() { + authResult <- module.Authenticate( + pam.NewModuleTransactionInvoker(nil), + pam.Silent, []string{socketPath}) + }() + + if err := <-serverError; err != nil { + t.Fatalf("communication #error: %v", err) + } + + err := <-authResult + if !errors.Is(err, tc.expectedError) { + t.Fatalf("authenticate #unexpected: %#v vs %#v", + err, tc.expectedError) + } + }) + } +} diff --git a/cmd/pam-moduler/tests/integration-tester-module/pam_module.go b/cmd/pam-moduler/tests/integration-tester-module/pam_module.go new file mode 100644 index 0000000..e64a4f9 --- /dev/null +++ b/cmd/pam-moduler/tests/integration-tester-module/pam_module.go @@ -0,0 +1,95 @@ +// Code generated by "pam-moduler -type integrationTesterModule -parallel-conv"; DO NOT EDIT. + +//go:generate go build "-ldflags=-extldflags -Wl,-soname,pam_go.so" -buildmode=c-shared -o pam_go.so -tags go_pam_module + +// Package main is the package for the PAM module library. +package main + +/* +#cgo LDFLAGS: -lpam -fPIC +#include + +typedef const char _const_char_t; +*/ +import "C" + +import ( + "errors" + "fmt" + "github.com/msteinert/pam/v2" + "os" + "unsafe" +) + +var pamModuleHandler pam.ModuleHandler = &integrationTesterModule{} + +// sliceFromArgv returns a slice of strings given to the PAM module. +func sliceFromArgv(argc C.int, argv **C._const_char_t) []string { + r := make([]string, 0, argc) + for _, s := range unsafe.Slice(argv, argc) { + r = append(r, C.GoString(s)) + } + return r +} + +// handlePamCall is the function that translates C pam requests to Go. +func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int, + argv **C._const_char_t, moduleFunc pam.ModuleHandlerFunc) C.int { + if pamModuleHandler == nil { + return C.int(pam.ErrNoModuleData) + } + + if moduleFunc == nil { + return C.int(pam.ErrIgnore) + } + + mt := pam.NewModuleTransactionInvokerParallelConv(pam.NativeHandle(pamh)) + err := mt.InvokeHandler(moduleFunc, pam.Flags(flags), + sliceFromArgv(argc, argv)) + if err == nil { + return 0 + } + + if (pam.Flags(flags)&pam.Silent) == 0 && !errors.Is(err, pam.ErrIgnore) { + fmt.Fprintf(os.Stderr, "module returned error: %v\n", err) + } + + var pamErr pam.Error + if errors.As(err, &pamErr) { + return C.int(pamErr) + } + + return C.int(pam.ErrSystem) +} + +//export pam_sm_authenticate +func pam_sm_authenticate(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.Authenticate) +} + +//export pam_sm_setcred +func pam_sm_setcred(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.SetCred) +} + +//export pam_sm_acct_mgmt +func pam_sm_acct_mgmt(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.AcctMgmt) +} + +//export pam_sm_open_session +func pam_sm_open_session(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.OpenSession) +} + +//export pam_sm_close_session +func pam_sm_close_session(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.CloseSession) +} + +//export pam_sm_chauthtok +func pam_sm_chauthtok(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.ChangeAuthTok) +} + +func main() {} diff --git a/cmd/pam-moduler/tests/integration-tester-module/serialization.go b/cmd/pam-moduler/tests/integration-tester-module/serialization.go new file mode 100644 index 0000000..7a549c2 --- /dev/null +++ b/cmd/pam-moduler/tests/integration-tester-module/serialization.go @@ -0,0 +1,67 @@ +package main + +import ( + "encoding/gob" + + "github.com/msteinert/pam/v2" + "github.com/msteinert/pam/v2/cmd/pam-moduler/tests/internal/utils" +) + +// SerializablePamError represents a [pam.Error] in a +// serializable way that splits message and return code. +type SerializablePamError struct { + Msg string + RetStatus pam.Error +} + +// NewSerializablePamError initializes a SerializablePamError from +// the default status error message. +func NewSerializablePamError(status pam.Error) SerializablePamError { + return SerializablePamError{Msg: status.Error(), RetStatus: status} +} + +func (e *SerializablePamError) Error() string { + return e.RetStatus.Error() +} + +// SerializableStringConvRequest is a serializable string request. +type SerializableStringConvRequest struct { + Style pam.Style + Request string +} + +// SerializableStringConvResponse is a serializable string response. +type SerializableStringConvResponse struct { + Style pam.Style + Response string +} + +// SerializableBinaryConvRequest is a serializable binary request. +type SerializableBinaryConvRequest struct { + Request []byte +} + +// SerializableBinaryConvResponse is a serializable binary response. +type SerializableBinaryConvResponse struct { + Response []byte +} + +func init() { + gob.Register(map[string]string{}) + gob.Register(Request{}) + gob.Register(pam.Item(0)) + gob.Register(pam.Error(0)) + gob.Register(pam.Style(0)) + gob.Register([]pam.ConvResponse{}) + gob.RegisterName("main.SerializablePamError", + SerializablePamError{}) + gob.RegisterName("main.SerializableStringConvRequest", + SerializableStringConvRequest{}) + gob.RegisterName("main.SerializableStringConvResponse", + SerializableStringConvResponse{}) + gob.RegisterName("main.SerializableBinaryConvRequest", + SerializableBinaryConvRequest{}) + gob.RegisterName("main.SerializableBinaryConvResponse", + SerializableBinaryConvResponse{}) + gob.Register(utils.SerializableError{}) +} diff --git a/cmd/pam-moduler/tests/internal/utils/base-module.go b/cmd/pam-moduler/tests/internal/utils/base-module.go new file mode 100644 index 0000000..494b077 --- /dev/null +++ b/cmd/pam-moduler/tests/internal/utils/base-module.go @@ -0,0 +1,38 @@ +package utils + +import "github.com/msteinert/pam/v2" + +// BaseModule is the type for a base PAM module. +type BaseModule struct{} + +// AcctMgmt is the handler function for PAM AcctMgmt. +func (h *BaseModule) AcctMgmt(pam.ModuleTransaction, pam.Flags, []string) error { + return nil +} + +// Authenticate is the handler function for PAM Authenticate. +func (h *BaseModule) Authenticate(pam.ModuleTransaction, pam.Flags, []string) error { + return nil +} + +// ChangeAuthTok is the handler function for PAM ChangeAuthTok. +func (h *BaseModule) ChangeAuthTok(pam.ModuleTransaction, pam.Flags, []string) error { + return nil +} + +// OpenSession is the handler function for PAM OpenSession. +func (h *BaseModule) OpenSession(pam.ModuleTransaction, pam.Flags, []string) error { + return nil +} + +// CloseSession is the handler function for PAM CloseSession. +func (h *BaseModule) CloseSession(pam.ModuleTransaction, pam.Flags, []string) error { + return nil +} + +// SetCred is the handler function for PAM SetCred. +func (h *BaseModule) SetCred(pam.ModuleTransaction, pam.Flags, []string) error { + return nil +} + +var _ pam.ModuleHandler = &BaseModule{} diff --git a/cmd/pam-moduler/tests/internal/utils/base-module_test.go b/cmd/pam-moduler/tests/internal/utils/base-module_test.go new file mode 100644 index 0000000..461d90f --- /dev/null +++ b/cmd/pam-moduler/tests/internal/utils/base-module_test.go @@ -0,0 +1,35 @@ +package utils + +import ( + "testing" + + "github.com/msteinert/pam/v2" +) + +func TestMain(t *testing.T) { + bm := BaseModule{} + + if bm.AcctMgmt(nil, pam.Flags(0), nil) != nil { + t.Fatalf("Unexpected non-nil value") + } + + if bm.Authenticate(nil, pam.Flags(0), nil) != nil { + t.Fatalf("Unexpected non-nil value") + } + + if bm.ChangeAuthTok(nil, pam.Flags(0), nil) != nil { + t.Fatalf("Unexpected non-nil value") + } + + if bm.OpenSession(nil, pam.Flags(0), nil) != nil { + t.Fatalf("Unexpected non-nil value") + } + + if bm.CloseSession(nil, pam.Flags(0), nil) != nil { + t.Fatalf("Unexpected non-nil value") + } + + if bm.SetCred(nil, pam.Flags(0), nil) != nil { + t.Fatalf("Unexpected non-nil value") + } +} diff --git a/cmd/pam-moduler/tests/internal/utils/test-setup.go b/cmd/pam-moduler/tests/internal/utils/test-setup.go new file mode 100644 index 0000000..77fc71d --- /dev/null +++ b/cmd/pam-moduler/tests/internal/utils/test-setup.go @@ -0,0 +1,135 @@ +// Package utils contains the internal test utils +package utils + +import ( + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "testing" + + "github.com/msteinert/pam/v2" +) + +// TestSetup is an utility type for having a playground for test PAM modules. +type TestSetup struct { + t *testing.T + workDir string +} + +type withWorkDir struct{} + +//nolint:revive +func WithWorkDir() withWorkDir { + return withWorkDir{} +} + +// NewTestSetup creates a new TestSetup. +func NewTestSetup(t *testing.T, args ...interface{}) *TestSetup { + t.Helper() + + ts := &TestSetup{t: t} + for _, arg := range args { + switch argType := arg.(type) { + case withWorkDir: + ts.ensureWorkDir() + default: + t.Fatalf("Unknown parameter of type %v", argType) + } + } + + return ts +} + +// CreateTemporaryDir creates a temporary directory with provided basename. +func (ts *TestSetup) CreateTemporaryDir(basename string) string { + tmpDir, err := os.MkdirTemp(os.TempDir(), basename) + if err != nil { + ts.t.Fatalf("can't create service path %v", err) + } + + ts.t.Cleanup(func() { os.RemoveAll(tmpDir) }) + return tmpDir +} + +func (ts *TestSetup) ensureWorkDir() string { + if ts.workDir != "" { + return ts.workDir + } + + ts.workDir = ts.CreateTemporaryDir("go-pam-*") + return ts.workDir +} + +// WorkDir returns the test setup work directory. +func (ts TestSetup) WorkDir() string { + return ts.workDir +} + +// GenerateModule generates a PAM module for the provided path and name. +func (ts *TestSetup) GenerateModule(testModulePath string, moduleName string) string { + cmd := exec.Command("go", "generate", "-C", testModulePath) + out, err := cmd.CombinedOutput() + if err != nil { + ts.t.Fatalf("can't build pam module %v: %s", err, out) + } + + builtFile := filepath.Join(cmd.Dir, testModulePath, moduleName) + modulePath := filepath.Join(ts.ensureWorkDir(), filepath.Base(builtFile)) + if err = os.Rename(builtFile, modulePath); err != nil { + ts.t.Fatalf("can't move module: %v", err) + os.Remove(builtFile) + } + + return modulePath +} + +func (ts TestSetup) currentFile(skip int) string { + _, currentFile, _, ok := runtime.Caller(skip) + if !ok { + ts.t.Fatalf("can't get current binary path") + } + return currentFile +} + +// GetCurrentFile returns the current file path. +func (ts TestSetup) GetCurrentFile() string { + // This is a library so we care about the caller location + return ts.currentFile(2) +} + +// GetCurrentFileDir returns the current file directory. +func (ts TestSetup) GetCurrentFileDir() string { + return filepath.Dir(ts.currentFile(2)) +} + +// GenerateModuleDefault generates a default module. +func (ts *TestSetup) GenerateModuleDefault(testModulePath string) string { + return ts.GenerateModule(testModulePath, "pam_go.so") +} + +// CreateService creates a service file. +func (ts *TestSetup) CreateService(serviceName string, services []ServiceLine) string { + if !pam.CheckPamHasStartConfdir() { + ts.t.Skip("PAM has no support for custom service paths") + return "" + } + + serviceName = strings.ToLower(serviceName) + serviceFile := filepath.Join(ts.ensureWorkDir(), serviceName) + var contents = []string{} + + for _, s := range services { + contents = append(contents, strings.TrimRight(strings.Join([]string{ + s.Action.String(), s.Control.String(), s.Module, strings.Join(s.Args, " "), + }, "\t"), "\t")) + } + + if err := os.WriteFile(serviceFile, + []byte(strings.Join(contents, "\n")), 0600); err != nil { + ts.t.Fatalf("can't create service file %v: %v", serviceFile, err) + } + + return serviceFile +} diff --git a/cmd/pam-moduler/tests/internal/utils/test-setup_test.go b/cmd/pam-moduler/tests/internal/utils/test-setup_test.go new file mode 100644 index 0000000..f8a17a6 --- /dev/null +++ b/cmd/pam-moduler/tests/internal/utils/test-setup_test.go @@ -0,0 +1,180 @@ +package utils + +import ( + "fmt" + "math/rand" + "os" + "path/filepath" + "strings" + "testing" +) + +func isDir(t *testing.T, path string) bool { + t.Helper() + if file, err := os.Open(path); err == nil { + if fileInfo, err := file.Stat(); err == nil { + return fileInfo.IsDir() + } + t.Fatalf("error: %v", err) + } else { + t.Fatalf("error: %v", err) + } + return false +} + +func Test_CreateTemporaryDir(t *testing.T) { + t.Parallel() + ts := NewTestSetup(t) + dir := ts.CreateTemporaryDir("") + if !isDir(t, dir) { + t.Fatalf("%s not a directory", dir) + } + + dir = ts.CreateTemporaryDir("foo-prefix-*") + if !isDir(t, dir) { + t.Fatalf("%s not a directory", dir) + } +} + +func Test_TestSetupWithWorkDir(t *testing.T) { + t.Parallel() + ts := NewTestSetup(t, WithWorkDir()) + if !isDir(t, ts.WorkDir()) { + t.Fatalf("%s not a directory", ts.WorkDir()) + } +} + +func Test_CreateService(t *testing.T) { + t.Parallel() + ts := NewTestSetup(t) + + tests := map[string]struct { + services []ServiceLine + expectedContent string + }{ + "empty": {}, + "CApital-Empty": {}, + "auth-sufficient-permit": { + services: []ServiceLine{ + {Auth, Sufficient, Permit.String(), []string{}}, + }, + expectedContent: "auth sufficient pam_permit.so", + }, + "auth-sufficient-permit-args": { + services: []ServiceLine{ + {Auth, Required, Deny.String(), []string{"a b c [d e]"}}, + }, + expectedContent: "auth required pam_deny.so a b c [d e]", + }, + "complete-custom": { + services: []ServiceLine{ + {Account, Required, "pam_account_module.so", []string{"a", "b", "c", "[d e]"}}, + {Account, Required, Deny.String(), []string{}}, + {Auth, Requisite, "pam_auth_module.so", []string{}}, + {Auth, Requisite, Deny.String(), []string{}}, + {Password, Sufficient, "pam_password_module.so", []string{"arg"}}, + {Password, Sufficient, Deny.String(), []string{}}, + {Session, Optional, "pam_session_module.so", []string{""}}, + {Session, Optional, Deny.String(), []string{}}, + }, + expectedContent: `account required pam_account_module.so a b c [d e] +account required pam_deny.so +auth requisite pam_auth_module.so +auth requisite pam_deny.so +password sufficient pam_password_module.so arg +password sufficient pam_deny.so +session optional pam_session_module.so +session optional pam_deny.so`, + }, + } + + for name, tc := range tests { + tc := tc + name := name + t.Run(name, func(t *testing.T) { + t.Parallel() + service := ts.CreateService(name, tc.services) + + if filepath.Base(service) != strings.ToLower(name) { + t.Fatalf("Invalid service name %s", service) + } + + if bytes, err := os.ReadFile(service); err != nil { + t.Fatalf("Failed reading %s: %v", service, err) + } else { + if string(bytes) != tc.expectedContent { + t.Fatalf("Unexpected file content:\n%s\n---\n%s", + tc.expectedContent, string(bytes)) + } + } + }) + } +} + +func Test_GenerateModule(t *testing.T) { + ts := NewTestSetup(t) + dir := ts.CreateTemporaryDir("") + if !isDir(t, dir) { + t.Fatalf("%s not a directory", dir) + } + + f, err := os.Create(filepath.Join(dir, "test-generate.go")) + if err != nil { + t.Fatalf("can't create file %v", err) + } + defer f.Close() + + randomName := "" + for i := 0; i < 10; i++ { + // #nosec:G404 - it's a test, we don't care. + randomName += string(byte('a' + rand.Intn('z'-'a'))) + } + + wantFile := randomName + ".so" + fmt.Fprintf(f, `//go:generate touch %s +package generate_file +`, wantFile) + + mod, err := os.Create(filepath.Join(dir, "go.mod")) + if err != nil { + t.Fatalf("can't create file %v", err) + } + defer mod.Close() + + fmt.Fprintf(mod, `module example.com/greetings + +go 1.20 +`) + + fakeModule := ts.GenerateModule(dir, wantFile) + if _, err := os.Stat(fakeModule); err != nil { + t.Fatalf("module not generated %v", err) + } + + fmt.Fprint(f, `//go:generate touch pam_go.so +package generate_file +`, wantFile) + + fakeModule = ts.GenerateModuleDefault(dir) + if _, err := os.Stat(fakeModule); err != nil { + t.Fatalf("module not generated %v", err) + } +} + +func Test_GetCurrentFileDir(t *testing.T) { + t.Parallel() + + ts := NewTestSetup(t) + if !strings.HasSuffix(ts.GetCurrentFileDir(), filepath.Join("internal", "utils")) { + t.Fatalf("unexpected file %v", ts.GetCurrentFileDir()) + } +} + +func Test_GetCurrentFile(t *testing.T) { + t.Parallel() + + ts := NewTestSetup(t) + if !strings.HasSuffix(ts.GetCurrentFile(), filepath.Join("utils", "test-setup_test.go")) { + t.Fatalf("unexpected file %v", ts.GetCurrentFile()) + } +} diff --git a/cmd/pam-moduler/tests/internal/utils/test-utils.go b/cmd/pam-moduler/tests/internal/utils/test-utils.go new file mode 100644 index 0000000..fd6f11b --- /dev/null +++ b/cmd/pam-moduler/tests/internal/utils/test-utils.go @@ -0,0 +1,263 @@ +// Package utils contains the internal test utils +package utils + +//#include +import "C" + +import ( + "crypto/rand" + "encoding/binary" + "errors" + "fmt" + "reflect" + "unsafe" + + "github.com/msteinert/pam/v2" +) + +// Action represents a PAM action to perform. +type Action int + +const ( + // Account is the account. + Account Action = iota + 1 + // Auth is the auth. + Auth + // Password is the password. + Password + // Session is the session. + Session +) + +func (a Action) String() string { + switch a { + case Account: + return "account" + case Auth: + return "auth" + case Password: + return "password" + case Session: + return "session" + default: + return "" + } +} + +// Actions is a map with all the available Actions by their name. +var Actions = map[string]Action{ + Account.String(): Account, + Auth.String(): Auth, + Password.String(): Password, + Session.String(): Session, +} + +// Control represents how a PAM module should controlled in PAM service file. +type Control int + +const ( + // Required implies that the module is required. + Required Control = iota + 1 + // Requisite implies that the module is requisite. + Requisite + // Sufficient implies that the module is sufficient. + Sufficient + // Optional implies that the module is optional. + Optional +) + +func (c Control) String() string { + switch c { + case Required: + return "required" + case Requisite: + return "requisite" + case Sufficient: + return "sufficient" + case Optional: + return "optional" + default: + return "" + } +} + +// ServiceLine is the representation of a PAM module service file line. +type ServiceLine struct { + Action Action + Control Control + Module string + Args []string +} + +// FallBackModule is a type to represent the module that should be used as fallback. +type FallBackModule int + +const ( + // NoFallback add no fallback module. + NoFallback FallBackModule = iota + 1 + // Permit uses a module that always permits. + Permit + // Deny uses a module that always denys. + Deny +) + +func (a FallBackModule) String() string { + switch a { + case Permit: + return "pam_permit.so" + case Deny: + return "pam_deny.so" + default: + return "" + } +} + +// SerializableError is a representation of an error in a way can be serialized. +type SerializableError struct { + Msg string +} + +func (e *SerializableError) Error() string { + return e.Msg +} + +// Credentials is a test [pam.ConversationHandler] implementation. +type Credentials struct { + User string + Password string + EchoOn string + EchoOff string + TextInfo string + ErrorMsg string + ExpectedMessage string + CheckEmptyMessage bool + ExpectedStyle pam.Style + CheckZeroStyle bool + Context interface{} +} + +// RespondPAM handles PAM string conversations. +func (c Credentials) RespondPAM(s pam.Style, msg string) (string, error) { + if (c.ExpectedMessage != "" || c.CheckEmptyMessage) && + msg != c.ExpectedMessage { + return "", errors.Join(pam.ErrConv, + &SerializableError{ + fmt.Sprintf("unexpected prompt: %s vs %s", msg, c.ExpectedMessage), + }) + } + + if (c.ExpectedStyle != 0 || c.CheckZeroStyle) && + s != c.ExpectedStyle { + return "", errors.Join(pam.ErrConv, + &SerializableError{ + fmt.Sprintf("unexpected style: %#v vs %#v", s, c.ExpectedStyle), + }) + } + + switch s { + case pam.PromptEchoOn: + if c.User != "" { + return c.User, nil + } + return c.EchoOn, nil + case pam.PromptEchoOff: + if c.Password != "" { + return c.Password, nil + } + return c.EchoOff, nil + case pam.TextInfo: + return c.TextInfo, nil + case pam.ErrorMsg: + return c.ErrorMsg, nil + } + + return "", errors.Join(pam.ErrConv, + &SerializableError{fmt.Sprintf("unhandled style: %v", s)}) +} + +// BinaryTransaction represents a binary PAM transaction handler struct. +type BinaryTransaction struct { + data []byte + ExpectedNull bool + ReturnedData []byte +} + +// TestBinaryDataEncoder encodes a test binary data. +func TestBinaryDataEncoder(bytes []byte) []byte { + if len(bytes) > 0xff { + panic("Binary transaction size not supported") + } + + if bytes == nil { + return bytes + } + + data := make([]byte, 0, len(bytes)+1) + data = append(data, byte(len(bytes))) + data = append(data, bytes...) + return data +} + +// TestBinaryDataDecoder decodes a test binary data. +func TestBinaryDataDecoder(ptr pam.BinaryPointer) ([]byte, error) { + if ptr == nil { + return nil, nil + } + + length := uint8(*((*C.uint8_t)(ptr))) + if length == 0 { + return []byte{}, nil + } + return C.GoBytes(unsafe.Pointer(ptr), C.int(length+1))[1:], nil +} + +// NewBinaryTransactionWithData creates a new [pam.BinaryTransaction] from bytes. +func NewBinaryTransactionWithData(data []byte, retData []byte) BinaryTransaction { + t := BinaryTransaction{ReturnedData: retData} + t.data = TestBinaryDataEncoder(data) + t.ExpectedNull = data == nil + return t +} + +// NewBinaryTransactionWithRandomData creates a new [pam.BinaryTransaction] with random data. +func NewBinaryTransactionWithRandomData(size uint8, retData []byte) BinaryTransaction { + t := BinaryTransaction{ReturnedData: retData} + randomData := make([]byte, size) + if err := binary.Read(rand.Reader, binary.LittleEndian, &randomData); err != nil { + panic(err) + } + + t.data = TestBinaryDataEncoder(randomData) + return t +} + +// Data returns the bytes of the transaction. +func (b BinaryTransaction) Data() []byte { + return b.data +} + +// RespondPAM (not) handles the PAM string conversations. +func (b BinaryTransaction) RespondPAM(s pam.Style, msg string) (string, error) { + return "", errors.Join(pam.ErrConv, + &SerializableError{"unexpected non-binary request"}) +} + +// RespondPAMBinary handles the PAM binary conversations. +func (b BinaryTransaction) RespondPAMBinary(ptr pam.BinaryPointer) ([]byte, error) { + if ptr == nil && !b.ExpectedNull { + return nil, errors.Join(pam.ErrConv, + &SerializableError{"unexpected null binary data"}) + } else if ptr == nil { + return TestBinaryDataEncoder(b.ReturnedData), nil + } + + bytes, _ := TestBinaryDataDecoder(ptr) + if !reflect.DeepEqual(bytes, b.data[1:]) { + return nil, errors.Join(pam.ErrConv, + &SerializableError{ + fmt.Sprintf("data mismatch %#v vs %#v", bytes, b.data[1:]), + }) + } + + return TestBinaryDataEncoder(b.ReturnedData), nil +} diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..2f81a9a --- /dev/null +++ b/errors.go @@ -0,0 +1,94 @@ +package pam + +/* +#include +*/ +import "C" + +// Error is the Type for PAM Return types +type Error int + +// Pam Return types +const ( + // OpenErr indicates a dlopen() failure when dynamically loading a + // service module. + ErrOpen Error = C.PAM_OPEN_ERR + // ErrSymbol indicates a symbol not found. + ErrSymbol Error = C.PAM_SYMBOL_ERR + // ErrService indicates a error in service module. + ErrService Error = C.PAM_SERVICE_ERR + // ErrSystem indicates a system error. + ErrSystem Error = C.PAM_SYSTEM_ERR + // ErrBuf indicates a memory buffer error. + ErrBuf Error = C.PAM_BUF_ERR + // ErrPermDenied indicates a permission denied. + ErrPermDenied Error = C.PAM_PERM_DENIED + // ErrAuth indicates a authentication failure. + ErrAuth Error = C.PAM_AUTH_ERR + // ErrCredInsufficient indicates a can not access authentication data due to + // insufficient credentials. + ErrCredInsufficient Error = C.PAM_CRED_INSUFFICIENT + // ErrAuthinfoUnavail indicates that the underlying authentication service + // can not retrieve authentication information. + ErrAuthinfoUnavail Error = C.PAM_AUTHINFO_UNAVAIL + // ErrUserUnknown indicates a user not known to the underlying authentication + // module. + ErrUserUnknown Error = C.PAM_USER_UNKNOWN + // ErrMaxtries indicates that an authentication service has maintained a retry + // count which has been reached. No further retries should be attempted. + ErrMaxtries Error = C.PAM_MAXTRIES + // ErrNewAuthtokReqd indicates a new authentication token required. This is + // normally returned if the machine security policies require that the + // password should be changed because the password is nil or it has aged. + ErrNewAuthtokReqd Error = C.PAM_NEW_AUTHTOK_REQD + // ErrAcctExpired indicates that an user account has expired. + ErrAcctExpired Error = C.PAM_ACCT_EXPIRED + // ErrSession indicates a can not make/remove an entry for the + // specified session. + ErrSession Error = C.PAM_SESSION_ERR + // ErrCredUnavail indicates that an underlying authentication service can not + // retrieve user credentials. + ErrCredUnavail Error = C.PAM_CRED_UNAVAIL + // ErrCredExpired indicates that an user credentials expired. + ErrCredExpired Error = C.PAM_CRED_EXPIRED + // ErrCred indicates a failure setting user credentials. + ErrCred Error = C.PAM_CRED_ERR + // ErrNoModuleData indicates a no module specific data is present. + ErrNoModuleData Error = C.PAM_NO_MODULE_DATA + // ErrConv indicates a conversation error. + ErrConv Error = C.PAM_CONV_ERR + // ErrAuthtokErr indicates an authentication token manipulation error. + ErrAuthtok Error = C.PAM_AUTHTOK_ERR + // ErrAuthtokRecoveryErr indicates an authentication information cannot + // be recovered. + ErrAuthtokRecovery Error = C.PAM_AUTHTOK_RECOVERY_ERR + // ErrAuthtokLockBusy indicates am authentication token lock busy. + ErrAuthtokLockBusy Error = C.PAM_AUTHTOK_LOCK_BUSY + // ErrAuthtokDisableAging indicates an authentication token aging disabled. + ErrAuthtokDisableAging Error = C.PAM_AUTHTOK_DISABLE_AGING + // ErrTryAgain indicates a preliminary check by password service. + ErrTryAgain Error = C.PAM_TRY_AGAIN + // ErrIgnore indicates to ignore underlying account module regardless of + // whether the control flag is required, optional, or sufficient. + ErrIgnore Error = C.PAM_IGNORE + // ErrAbort indicates a critical error (module fail now request). + ErrAbort Error = C.PAM_ABORT + // ErrAuthtokExpired indicates an user's authentication token has expired. + ErrAuthtokExpired Error = C.PAM_AUTHTOK_EXPIRED + // ErrModuleUnknown indicates a module is not known. + ErrModuleUnknown Error = C.PAM_MODULE_UNKNOWN + // ErrBadItem indicates a bad item passed to pam_*_item(). + ErrBadItem Error = C.PAM_BAD_ITEM + // ErrConvAgain indicates a conversation function is event driven and data + // is not available yet. + ErrConvAgain Error = C.PAM_CONV_AGAIN + // ErrIncomplete indicates to please call this function again to complete + // authentication stack. Before calling again, verify that conversation + // is completed. + ErrIncomplete Error = C.PAM_INCOMPLETE +) + +// Error returns the error message for the given status. +func (status Error) Error() string { + return C.GoString(C.pam_strerror(nil, C.int(status))) +} diff --git a/example-module/module.go b/example-module/module.go new file mode 100644 index 0000000..634e3ac --- /dev/null +++ b/example-module/module.go @@ -0,0 +1,50 @@ +// These go:generate directive allow to generate the module by just using +// `go generate` once in the module directory. +// This is not strictly needed + +//go:generate go run github.com/msteinert/pam/v2/cmd/pam-moduler +//go:generate go generate --skip="pam_module.go" + +// Package main provides the module shared library. +package main + +import ( + "fmt" + + "github.com/msteinert/pam/v2" +) + +type exampleHandler struct{} + +var pamModuleHandler pam.ModuleHandler = &exampleHandler{} +var _ = pamModuleHandler + +// AcctMgmt is the module handle function for account management. +func (h *exampleHandler) AcctMgmt(mt pam.ModuleTransaction, flags pam.Flags, args []string) error { + return fmt.Errorf("AcctMgmt not implemented: %w", pam.ErrIgnore) +} + +// Authenticate is the module handle function for authentication. +func (h *exampleHandler) Authenticate(mt pam.ModuleTransaction, flags pam.Flags, args []string) error { + return pam.ErrAuthinfoUnavail +} + +// ChangeAuthTok is the module handle function for changing authentication token. +func (h *exampleHandler) ChangeAuthTok(mt pam.ModuleTransaction, flags pam.Flags, args []string) error { + return fmt.Errorf("ChangeAuthTok not implemented: %w", pam.ErrIgnore) +} + +// OpenSession is the module handle function for open session. +func (h *exampleHandler) OpenSession(mt pam.ModuleTransaction, flags pam.Flags, args []string) error { + return fmt.Errorf("OpenSession not implemented: %w", pam.ErrIgnore) +} + +// CloseSession is the module handle function for close session. +func (h *exampleHandler) CloseSession(mt pam.ModuleTransaction, flags pam.Flags, args []string) error { + return fmt.Errorf("CloseSession not implemented: %w", pam.ErrIgnore) +} + +// SetCred is the module handle function for set credentials. +func (h *exampleHandler) SetCred(mt pam.ModuleTransaction, flags pam.Flags, args []string) error { + return fmt.Errorf("SetCred not implemented: %w", pam.ErrIgnore) +} diff --git a/example-module/pam_module.go b/example-module/pam_module.go new file mode 100644 index 0000000..080e97c --- /dev/null +++ b/example-module/pam_module.go @@ -0,0 +1,96 @@ +// Code generated by "pam-moduler "; DO NOT EDIT. + +//go:generate go build "-ldflags=-extldflags -Wl,-soname,pam_go.so" -buildmode=c-shared -o pam_go.so -tags go_pam_module + +// Package main is the package for the PAM module library. +package main + +/* +#cgo LDFLAGS: -lpam -fPIC +#include + +typedef const char _const_char_t; +*/ +import "C" + +import ( + "errors" + "fmt" + "github.com/msteinert/pam/v2" + "os" + "unsafe" +) + +// Do a typecheck at compile time +var _ pam.ModuleHandler = pamModuleHandler + +// sliceFromArgv returns a slice of strings given to the PAM module. +func sliceFromArgv(argc C.int, argv **C._const_char_t) []string { + r := make([]string, 0, argc) + for _, s := range unsafe.Slice(argv, argc) { + r = append(r, C.GoString(s)) + } + return r +} + +// handlePamCall is the function that translates C pam requests to Go. +func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int, + argv **C._const_char_t, moduleFunc pam.ModuleHandlerFunc) C.int { + if pamModuleHandler == nil { + return C.int(pam.ErrNoModuleData) + } + + if moduleFunc == nil { + return C.int(pam.ErrIgnore) + } + + mt := pam.NewModuleTransactionInvoker(pam.NativeHandle(pamh)) + err := mt.InvokeHandler(moduleFunc, pam.Flags(flags), + sliceFromArgv(argc, argv)) + if err == nil { + return 0 + } + + if (pam.Flags(flags)&pam.Silent) == 0 && !errors.Is(err, pam.ErrIgnore) { + fmt.Fprintf(os.Stderr, "module returned error: %v\n", err) + } + + var pamErr pam.Error + if errors.As(err, &pamErr) { + return C.int(pamErr) + } + + return C.int(pam.ErrSystem) +} + +//export pam_sm_authenticate +func pam_sm_authenticate(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.Authenticate) +} + +//export pam_sm_setcred +func pam_sm_setcred(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.SetCred) +} + +//export pam_sm_acct_mgmt +func pam_sm_acct_mgmt(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.AcctMgmt) +} + +//export pam_sm_open_session +func pam_sm_open_session(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.OpenSession) +} + +//export pam_sm_close_session +func pam_sm_close_session(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.CloseSession) +} + +//export pam_sm_chauthtok +func pam_sm_chauthtok(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.ChangeAuthTok) +} + +func main() {} diff --git a/example_test.go b/example_test.go index 8a347a6..6782bdf 100644 --- a/example_test.go +++ b/example_test.go @@ -6,7 +6,7 @@ import ( "fmt" "os" - "github.com/msteinert/pam" + "github.com/msteinert/pam/v2" "golang.org/x/term" ) @@ -14,7 +14,7 @@ import ( // should cause PAM to ask its conversation handler for a username and password // in sequence. func Example() { - t, err := pam.StartFunc("", "", func(s pam.Style, msg string) (string, error) { + t, err := pam.StartFunc("passwd", "", func(s pam.Style, msg string) (string, error) { switch s { case pam.PromptEchoOff: fmt.Print(msg) @@ -40,12 +40,19 @@ func Example() { } }) if err != nil { - fmt.Fprintf(os.Stderr, "start: %s\n", err.Error()) + fmt.Fprintf(os.Stderr, "start: %v\n", err) os.Exit(1) } + defer func() { + err := t.End() + if err != nil { + fmt.Fprintf(os.Stderr, "end: %v\n", err) + os.Exit(1) + } + }() err = t.Authenticate(0) if err != nil { - fmt.Fprintf(os.Stderr, "authenticate: %s\n", err.Error()) + fmt.Fprintf(os.Stderr, "authenticate: %v\n", err) os.Exit(1) } fmt.Println("authentication succeeded!") diff --git a/go.mod b/go.mod index b3e7d5f..b30e2a4 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/msteinert/pam +module github.com/msteinert/pam/v2 go 1.20 diff --git a/module-transaction-mock.go b/module-transaction-mock.go new file mode 100644 index 0000000..c76087d --- /dev/null +++ b/module-transaction-mock.go @@ -0,0 +1,233 @@ +//go:build !go_pam_module + +package pam + +/* +#cgo CFLAGS: -Wall -std=c99 +#include +#include +#include + +void init_pam_conv(struct pam_conv *conv, uintptr_t appdata); +*/ +import "C" + +import ( + "errors" + "fmt" + "reflect" + "runtime" + "runtime/cgo" + "testing" + "unsafe" +) + +type mockModuleTransactionExpectations struct { + UserPrompt string + DataKey string +} + +type mockModuleTransactionReturnedData struct { + User string + InteractiveUser bool + Status Error +} + +type mockModuleTransaction struct { + moduleTransaction + T *testing.T + Expectations mockModuleTransactionExpectations + RetData mockModuleTransactionReturnedData + ConversationHandler ConversationHandler + moduleData map[string]uintptr + allocatedData []unsafe.Pointer + binaryProtocol bool +} + +func newMockModuleTransaction(m *mockModuleTransaction) *mockModuleTransaction { + m.moduleData = make(map[string]uintptr) + m.binaryProtocol = true + runtime.SetFinalizer(m, func(m *mockModuleTransaction) { + for _, ptr := range m.allocatedData { + C.free(ptr) + } + for _, handle := range m.moduleData { + _go_pam_data_cleanup(nil, C.uintptr_t(handle), C.PAM_DATA_SILENT) + } + }) + return m +} + +func (m *mockModuleTransaction) getUser(outUser **C.char, prompt *C.char) C.int { + goPrompt := C.GoString(prompt) + if goPrompt != m.Expectations.UserPrompt { + m.T.Fatalf("unexpected prompt: %s vs %s", goPrompt, m.Expectations.UserPrompt) + return C.int(ErrAbort) + } + + user := m.RetData.User + if m.RetData.InteractiveUser || (m.RetData.User == "" && m.ConversationHandler != nil) { + if m.ConversationHandler == nil { + m.T.Fatalf("no conversation handler provided") + } + u, err := m.ConversationHandler.RespondPAM(PromptEchoOn, goPrompt) + user = u + + if err != nil { + var pamErr Error + if errors.As(err, &pamErr) { + return C.int(pamErr) + } + return C.int(ErrAbort) + } + } + + cUser := C.CString(user) + m.allocatedData = append(m.allocatedData, unsafe.Pointer(cUser)) + + *outUser = cUser + return C.int(m.RetData.Status) +} + +func (m *mockModuleTransaction) getData(key *C.char, outHandle *C.uintptr_t) C.int { + goKey := C.GoString(key) + if m.Expectations.DataKey != "" && goKey != m.Expectations.DataKey { + m.T.Fatalf("data key mismatch: %#v vs %#v", goKey, m.Expectations.DataKey) + } + if handle, ok := m.moduleData[goKey]; ok { + *outHandle = C.uintptr_t(handle) + } else { + *outHandle = 0 + } + return C.int(m.RetData.Status) +} + +func (m *mockModuleTransaction) setData(key *C.char, handle C.uintptr_t) C.int { + goKey := C.GoString(key) + if m.Expectations.DataKey != "" && goKey != m.Expectations.DataKey { + m.T.Fatalf("data key mismatch: %#v vs %#v", goKey, m.Expectations.DataKey) + } + if oldHandle, ok := m.moduleData[goKey]; ok { + _go_pam_data_cleanup(nil, C.uintptr_t(oldHandle), C.PAM_DATA_REPLACE) + } + if handle != 0 { + m.moduleData[goKey] = uintptr(handle) + } + return C.int(m.RetData.Status) +} + +func (m *mockModuleTransaction) getConv() (*C.struct_pam_conv, error) { + if m.ConversationHandler != nil { + conv := C.struct_pam_conv{} + handler := cgo.NewHandle(m.ConversationHandler) + C.init_pam_conv(&conv, C.uintptr_t(handler)) + return &conv, nil + } + if C.int(m.RetData.Status) != success { + return nil, m.RetData.Status + } + return nil, nil +} + +func (m *mockModuleTransaction) hasBinaryProtocol() bool { + return m.binaryProtocol +} + +type mockConversationHandler struct { + User string + PromptEchoOn string + PromptEchoOff string + TextInfo string + ErrorMsg string + Binary []byte + ExpectedMessage string + ExpectedMessagesByStyle map[Style]string + ExpectedNil bool + ExpectedBinary []byte + CheckEmptyMessage bool + ExpectedStyle Style + CheckZeroStyle bool + IgnoreUnknownStyle bool +} + +func (c mockConversationHandler) RespondPAM(s Style, msg string) (string, error) { + var expectedMsg = c.ExpectedMessage + if msg, ok := c.ExpectedMessagesByStyle[s]; ok { + expectedMsg = msg + } + + if (expectedMsg != "" || c.CheckEmptyMessage) && + msg != expectedMsg { + return "", fmt.Errorf("%w: unexpected prompt: %s vs %s", + ErrConv, msg, c.ExpectedMessage) + } + + if (c.ExpectedStyle != 0 || c.CheckZeroStyle) && + s != c.ExpectedStyle { + return "", fmt.Errorf("%w: unexpected style: %#v vs %#v", + ErrConv, s, c.ExpectedStyle) + } + + switch s { + case PromptEchoOn: + if c.User != "" { + return c.User, nil + } + return c.PromptEchoOn, nil + case PromptEchoOff: + return c.PromptEchoOff, nil + case TextInfo: + return c.TextInfo, nil + case ErrorMsg: + return c.ErrorMsg, nil + } + + if c.IgnoreUnknownStyle { + return c.ExpectedMessage, nil + } + + return "", fmt.Errorf("%w: unhandled style: %v", ErrConv, s) +} + +func testBinaryDataEncoder(bytes []byte) []byte { + if len(bytes) > 0xff { + panic("Binary transaction size not supported") + } + + if bytes == nil { + return bytes + } + + data := make([]byte, 0, len(bytes)+1) + data = append(data, byte(len(bytes))) + data = append(data, bytes...) + return data +} + +func testBinaryDataDecoder(ptr BinaryPointer) ([]byte, error) { + if ptr == nil { + return nil, nil + } + + length := uint8(*((*C.uint8_t)(ptr))) + if length == 0 { + return []byte{}, nil + } + return C.GoBytes(unsafe.Pointer(ptr), C.int(length+1))[1:], nil +} + +func (c mockConversationHandler) RespondPAMBinary(ptr BinaryPointer) ([]byte, error) { + if ptr == nil && !c.ExpectedNil { + return nil, fmt.Errorf("%w: unexpected null binary data", ErrConv) + } else if ptr == nil { + return testBinaryDataEncoder(c.Binary), nil + } + + bytes, _ := testBinaryDataDecoder(ptr) + if !reflect.DeepEqual(bytes, c.ExpectedBinary) { + return nil, fmt.Errorf("%w: data mismatch %#v vs %#v", + ErrConv, bytes, c.ExpectedBinary) + } + + return testBinaryDataEncoder(c.Binary), nil +} diff --git a/module-transaction.go b/module-transaction.go new file mode 100644 index 0000000..fc754a1 --- /dev/null +++ b/module-transaction.go @@ -0,0 +1,627 @@ +// Package pam provides a wrapper for the PAM application API. +package pam + +/* +#include "transaction.h" +*/ +import "C" + +import ( + "errors" + "fmt" + "runtime" + "runtime/cgo" + "sync" + "sync/atomic" + "unsafe" +) + +const maxNumMsg = C.PAM_MAX_NUM_MSG + +// ModuleTransaction is an interface that a pam module transaction +// should implement. +type ModuleTransaction interface { + SetItem(Item, string) error + GetItem(Item) (string, error) + PutEnv(nameVal string) error + GetEnv(name string) string + GetEnvList() (map[string]string, error) + GetUser(prompt string) (string, error) + SetData(key string, data any) error + GetData(key string) (any, error) + StartStringConv(style Style, prompt string) (StringConvResponse, error) + StartStringConvf(style Style, format string, args ...interface{}) ( + StringConvResponse, error) + StartBinaryConv([]byte) (BinaryConvResponse, error) + StartConv(ConvRequest) (ConvResponse, error) + StartConvMulti([]ConvRequest) ([]ConvResponse, error) +} + +// ModuleHandlerFunc is a function type used by the ModuleHandler. +type ModuleHandlerFunc func(ModuleTransaction, Flags, []string) error + +// ModuleTransaction is the module-side handle for a PAM transaction. +type moduleTransaction struct { + transactionBase + convMutex *sync.Mutex +} + +// ModuleHandler is an interface for objects that can be used to create +// PAM modules from go. +type ModuleHandler interface { + AcctMgmt(ModuleTransaction, Flags, []string) error + Authenticate(ModuleTransaction, Flags, []string) error + ChangeAuthTok(ModuleTransaction, Flags, []string) error + CloseSession(ModuleTransaction, Flags, []string) error + OpenSession(ModuleTransaction, Flags, []string) error + SetCred(ModuleTransaction, Flags, []string) error +} + +// ModuleTransactionInvoker is an interface that a pam module transaction +// should implement to redirect requests from C handlers to go, +type ModuleTransactionInvoker interface { + ModuleTransaction + InvokeHandler(handler ModuleHandlerFunc, flags Flags, args []string) error +} + +// NewModuleTransactionParallelConv allows initializing a transaction from the +// module side. Conversations using this transaction can be multi-thread, but +// this requires the application loading the module to support this, otherwise +// we may just break their assumptions. +func NewModuleTransactionParallelConv(handle NativeHandle) ModuleTransaction { + return &moduleTransaction{transactionBase{handle: handle}, nil} +} + +// NewModuleTransactionInvoker allows initializing a transaction invoker from the +// module side. +func NewModuleTransactionInvoker(handle NativeHandle) ModuleTransactionInvoker { + return &moduleTransaction{transactionBase{handle: handle}, &sync.Mutex{}} +} + +// NewModuleTransactionInvokerParallelConv allows initializing a transaction invoker +// from the module side. +// Conversations using this transaction can be multi-thread, but this requires +// the application loading the module to support this, otherwise we may just +// break their assumptions. +func NewModuleTransactionInvokerParallelConv(handle NativeHandle) ModuleTransactionInvoker { + return &moduleTransaction{transactionBase{handle: handle}, nil} +} + +func (m *moduleTransaction) InvokeHandler(handler ModuleHandlerFunc, + flags Flags, args []string) error { + invoker := func() error { + if handler == nil { + return ErrIgnore + } + err := handler(m, flags, args) + if err != nil { + service, _ := m.GetItem(Service) + + var pamErr Error + if !errors.As(err, &pamErr) { + err = ErrSystem + } + + if pamErr == ErrIgnore || service == "" { + return err + } + + return fmt.Errorf("%s failed: %w", service, err) + } + return nil + } + err := invoker() + if errors.Is(err, Error(0)) { + err = nil + } + var status int32 + if err != nil { + status = int32(ErrSystem) + + var pamErr Error + if errors.As(err, &pamErr) { + status = int32(pamErr) + } + } + m.lastStatus.Store(status) + return err +} + +type moduleTransactionIface interface { + getUser(outUser **C.char, prompt *C.char) C.int + setData(key *C.char, handle C.uintptr_t) C.int + getData(key *C.char, outHandle *C.uintptr_t) C.int + getConv() (*C.struct_pam_conv, error) + hasBinaryProtocol() bool + startConv(conv *C.struct_pam_conv, nMsg C.int, + messages **C.struct_pam_message, + outResponses **C.struct_pam_response) C.int +} + +func (m *moduleTransaction) getUser(outUser **C.char, prompt *C.char) C.int { + return C.pam_get_user(m.handle, outUser, prompt) +} + +// getUserImpl is the default implementation for GetUser, but kept as private so +// that can be used to test the pam package +func (m *moduleTransaction) getUserImpl(iface moduleTransactionIface, + prompt string) (string, error) { + var user *C.char + var cPrompt = C.CString(prompt) + defer C.free(unsafe.Pointer(cPrompt)) + err := m.handlePamStatus(iface.getUser(&user, cPrompt)) + if err != nil { + return "", err + } + return C.GoString(user), nil +} + +// GetUser is similar to GetItem(User), but it would start a conversation if +// no user is currently set in PAM. +func (m *moduleTransaction) GetUser(prompt string) (string, error) { + return m.getUserImpl(m, prompt) +} + +// SetData allows to save any value in the module data that is preserved +// during the whole time the module is loaded. +func (m *moduleTransaction) SetData(key string, data any) error { + return m.setDataImpl(m, key, data) +} + +func (m *moduleTransaction) setData(key *C.char, handle C.uintptr_t) C.int { + return C.set_data(m.handle, key, handle) +} + +// setDataImpl is the implementation for SetData for testing purposes. +func (m *moduleTransaction) setDataImpl(iface moduleTransactionIface, + key string, data any) error { + var cKey = C.CString(key) + defer C.free(unsafe.Pointer(cKey)) + var handle cgo.Handle + if data != nil { + handle = cgo.NewHandle(data) + } + return m.handlePamStatus(iface.setData(cKey, C.uintptr_t(handle))) +} + +//export _go_pam_data_cleanup +func _go_pam_data_cleanup(h NativeHandle, handle C.uintptr_t, status C.int) { + cgo.Handle(handle).Delete() +} + +// GetData allows to get any value from the module data saved using SetData +// that is preserved across the whole time the module is loaded. +func (m *moduleTransaction) GetData(key string) (any, error) { + return m.getDataImpl(m, key) +} + +func (m *moduleTransaction) getData(key *C.char, outHandle *C.uintptr_t) C.int { + return C.get_data(m.handle, key, outHandle) +} + +// getDataImpl is the implementation for GetData for testing purposes. +func (m *moduleTransaction) getDataImpl(iface moduleTransactionIface, + key string) (any, error) { + var cKey = C.CString(key) + defer C.free(unsafe.Pointer(cKey)) + var handle C.uintptr_t + if err := m.handlePamStatus(iface.getData(cKey, &handle)); err != nil { + return nil, err + } + if goHandle := cgo.Handle(handle); goHandle != cgo.Handle(0) { + return goHandle.Value(), nil + } + + return nil, m.handlePamStatus(C.int(ErrNoModuleData)) +} + +// getConv is a private function to get the conversation pointer to be used +// with C.do_conv() to initiate conversations. +func (m *moduleTransaction) getConv() (*C.struct_pam_conv, error) { + var convPtr unsafe.Pointer + + if err := m.handlePamStatus( + C.pam_get_item(m.handle, C.PAM_CONV, &convPtr)); err != nil { + return nil, err + } + + return (*C.struct_pam_conv)(convPtr), nil +} + +// ConvRequest is an interface that all the Conversation requests should +// implement. +type ConvRequest interface { + Style() Style +} + +// ConvResponse is an interface that all the Conversation responses should +// implement. +type ConvResponse interface { + Style() Style +} + +// StringConvRequest is a ConvRequest for performing text-based conversations. +type StringConvRequest struct { + style Style + prompt string +} + +// NewStringConvRequest creates a new StringConvRequest. +func NewStringConvRequest(style Style, prompt string) StringConvRequest { + return StringConvRequest{style, prompt} +} + +// Style returns the conversation style of the StringConvRequest. +func (s StringConvRequest) Style() Style { + return s.style +} + +// Prompt returns the conversation style of the StringConvRequest. +func (s StringConvRequest) Prompt() string { + return s.prompt +} + +// StringConvResponse is an interface that string Conversation responses implements. +type StringConvResponse interface { + ConvResponse + Response() string +} + +// stringConvResponse is a StringConvResponse implementation used for text-based +// conversation responses. +type stringConvResponse struct { + style Style + response string +} + +// Style returns the conversation style of the StringConvResponse. +func (s stringConvResponse) Style() Style { + return s.style +} + +// Response returns the string response of the conversation. +func (s stringConvResponse) Response() string { + return s.response +} + +// BinaryFinalizer is a type of function that can be used to release +// the binary when it's not required anymore +type BinaryFinalizer func(BinaryPointer) + +// BinaryConvRequester is the interface that binary ConvRequests should +// implement +type BinaryConvRequester interface { + ConvRequest + Pointer() BinaryPointer + CreateResponse(BinaryPointer) BinaryConvResponse + Release() +} + +// BinaryConvRequest is a ConvRequest for performing binary conversations. +type BinaryConvRequest struct { + ptr atomic.Uintptr + finalizer BinaryFinalizer + responseFinalizer BinaryFinalizer +} + +// NewBinaryConvRequestFull creates a new BinaryConvRequest with finalizer +// for response BinaryResponse. +func NewBinaryConvRequestFull(ptr BinaryPointer, finalizer BinaryFinalizer, + responseFinalizer BinaryFinalizer) *BinaryConvRequest { + b := &BinaryConvRequest{finalizer: finalizer, responseFinalizer: responseFinalizer} + b.ptr.Store(uintptr(ptr)) + if ptr == nil || finalizer == nil { + return b + } + + // The ownership of the data here is temporary + runtime.SetFinalizer(b, func(b *BinaryConvRequest) { b.Release() }) + return b +} + +// NewBinaryConvRequest creates a new BinaryConvRequest +func NewBinaryConvRequest(ptr BinaryPointer, finalizer BinaryFinalizer) *BinaryConvRequest { + return NewBinaryConvRequestFull(ptr, finalizer, finalizer) +} + +// NewBinaryConvRequestFromBytes creates a new BinaryConvRequest from an array +// of bytes. +func NewBinaryConvRequestFromBytes(bytes []byte) *BinaryConvRequest { + if bytes == nil { + return &BinaryConvRequest{} + } + return NewBinaryConvRequest(BinaryPointer(C.CBytes(bytes)), + func(ptr BinaryPointer) { C.free(unsafe.Pointer(ptr)) }) +} + +// Style returns the response style for the request, so always BinaryPrompt. +func (b *BinaryConvRequest) Style() Style { + return BinaryPrompt +} + +// Pointer returns the conversation style of the StringConvRequest. +func (b *BinaryConvRequest) Pointer() BinaryPointer { + ptr := b.ptr.Load() + return *(*BinaryPointer)(unsafe.Pointer(&ptr)) +} + +// CreateResponse creates a new BinaryConvResponse from the request +func (b *BinaryConvRequest) CreateResponse(ptr BinaryPointer) BinaryConvResponse { + bcr := &binaryConvResponse{ptr, b.responseFinalizer, &sync.Mutex{}} + runtime.SetFinalizer(bcr, func(bcr *binaryConvResponse) { + bcr.Release() + }) + return bcr +} + +// Release releases the resources allocated by the request +func (b *BinaryConvRequest) Release() { + ptr := b.ptr.Swap(0) + if b.finalizer != nil { + b.finalizer(*(*BinaryPointer)(unsafe.Pointer(&ptr))) + runtime.SetFinalizer(b, nil) + } +} + +// BinaryDecoder is a function type for decode the a binary pointer data into +// bytes +type BinaryDecoder func(BinaryPointer) ([]byte, error) + +// BinaryConvResponse is a subtype of ConvResponse used for binary +// conversation responses. +type BinaryConvResponse interface { + ConvResponse + Data() BinaryPointer + Decode(BinaryDecoder) ([]byte, error) + Release() +} + +type binaryConvResponse struct { + ptr BinaryPointer + finalizer BinaryFinalizer + mutex *sync.Mutex +} + +// Style returns the response style for the response, so always BinaryPrompt. +func (b binaryConvResponse) Style() Style { + return BinaryPrompt +} + +// Data returns the response native pointer, it's up to the protocol to parse +// it accordingly. +func (b *binaryConvResponse) Data() BinaryPointer { + b.mutex.Lock() + defer b.mutex.Unlock() + return b.ptr +} + +// Decode decodes the binary data using the provided decoder function. +func (b *binaryConvResponse) Decode(decoder BinaryDecoder) ( + []byte, error) { + if decoder == nil { + return nil, errors.New("nil decoder provided") + } + b.mutex.Lock() + defer b.mutex.Unlock() + return decoder(b.ptr) +} + +// Release releases the binary conversation response data. +// This is also automatically via a finalizer, but applications may control +// this explicitly deferring execution of this. +func (b *binaryConvResponse) Release() { + b.mutex.Lock() + defer b.mutex.Unlock() + ptr := b.ptr + b.ptr = nil + if b.finalizer != nil { + b.finalizer(ptr) + } else { + C.free(unsafe.Pointer(ptr)) + } +} + +// StartStringConv starts a text-based conversation using the provided style +// and prompt. +func (m *moduleTransaction) StartStringConv(style Style, prompt string) ( + StringConvResponse, error) { + return m.startStringConvImpl(m, style, prompt) +} + +func (m *moduleTransaction) startStringConvImpl(iface moduleTransactionIface, + style Style, prompt string) ( + StringConvResponse, error) { + switch style { + case BinaryPrompt: + return nil, fmt.Errorf("%w: binary style is not supported", ErrConv) + } + + res, err := m.startConvImpl(iface, NewStringConvRequest(style, prompt)) + if err != nil { + return nil, err + } + + stringRes, _ := res.(stringConvResponse) + return stringRes, nil +} + +// StartStringConvf allows to start string conversation with formatting support. +func (m *moduleTransaction) StartStringConvf(style Style, format string, args ...interface{}) ( + StringConvResponse, error) { + return m.StartStringConv(style, fmt.Sprintf(format, args...)) +} + +// HasBinaryProtocol checks if binary protocol is supported. +func (m *moduleTransaction) hasBinaryProtocol() bool { + return CheckPamHasBinaryProtocol() +} + +// StartBinaryConv starts a binary conversation using the provided bytes. +func (m *moduleTransaction) StartBinaryConv(bytes []byte) ( + BinaryConvResponse, error) { + return m.startBinaryConvImpl(m, bytes) +} + +func (m *moduleTransaction) startBinaryConvImpl(iface moduleTransactionIface, + bytes []byte) ( + BinaryConvResponse, error) { + res, err := m.startConvImpl(iface, NewBinaryConvRequestFromBytes(bytes)) + if err != nil { + return nil, err + } + + binaryRes, _ := res.(BinaryConvResponse) + return binaryRes, nil +} + +// StartConv initiates a PAM conversation using the provided ConvRequest. +func (m *moduleTransaction) StartConv(req ConvRequest) ( + ConvResponse, error) { + return m.startConvImpl(m, req) +} + +func (m *moduleTransaction) startConvImpl(iface moduleTransactionIface, req ConvRequest) ( + ConvResponse, error) { + resp, err := m.startConvMultiImpl(iface, []ConvRequest{req}) + if err != nil { + return nil, err + } + if len(resp) != 1 { + return nil, fmt.Errorf("%w: not enough values returned", ErrConv) + } + return resp[0], nil +} + +func (m *moduleTransaction) startConv(conv *C.struct_pam_conv, nMsg C.int, + messages **C.struct_pam_message, outResponses **C.struct_pam_response) C.int { + return C.start_pam_conv(conv, nMsg, messages, outResponses) +} + +// startConvMultiImpl is the implementation for GetData for testing purposes. +func (m *moduleTransaction) startConvMultiImpl(iface moduleTransactionIface, + requests []ConvRequest) (responses []ConvResponse, err error) { + defer func() { + if err == nil { + _ = m.handlePamStatus(success) + return + } + var pamErr Error + if !errors.As(err, &pamErr) { + err = errors.Join(ErrConv, err) + pamErr = ErrConv + } + _ = m.handlePamStatus(C.int(pamErr)) + }() + + if len(requests) == 0 { + return nil, errors.New("no requests defined") + } + if len(requests) > maxNumMsg { + return nil, errors.New("too many requests") + } + + conv, err := iface.getConv() + if err != nil { + return nil, err + } + + if conv == nil || conv.conv == nil { + return nil, errors.New("impossible to find conv handler") + } + + // FIXME: Just use make([]C.struct_pam_message, 0, len(requests)) + // and append, when it's possible to use runtime.Pinner + var cMessagePtr *C.struct_pam_message + cMessages := (**C.struct_pam_message)(C.calloc(C.size_t(len(requests)), + (C.size_t)(unsafe.Sizeof(cMessagePtr)))) + defer C.free(unsafe.Pointer(cMessages)) + goMsgs := unsafe.Slice(cMessages, len(requests)) + + for i, req := range requests { + var cBytes unsafe.Pointer + switch r := req.(type) { + case StringConvRequest: + cBytes = unsafe.Pointer(C.CString(r.Prompt())) + defer C.free(cBytes) + case BinaryConvRequester: + if !iface.hasBinaryProtocol() { + return nil, errors.New("%w: binary protocol is not supported") + } + cBytes = unsafe.Pointer(r.Pointer()) + default: + return nil, fmt.Errorf("unsupported conversation type %#v", r) + } + + cMessage := (*C.struct_pam_message)(C.calloc(1, + (C.size_t)(unsafe.Sizeof(*goMsgs[i])))) + defer C.free(unsafe.Pointer(cMessage)) + cMessage.msg_style = C.int(req.Style()) + cMessage.msg = (*C.char)(cBytes) + goMsgs[i] = cMessage + } + + if m.convMutex != nil { + m.convMutex.Lock() + defer m.convMutex.Unlock() + } + var cResponses *C.struct_pam_response + ret := iface.startConv(conv, C.int(len(requests)), cMessages, &cResponses) + if ret != success { + return nil, Error(ret) + } + + goResponses := unsafe.Slice(cResponses, len(requests)) + defer func() { + for i, resp := range goResponses { + if resp.resp == nil { + continue + } + switch req := requests[i].(type) { + case BinaryConvRequester: + // In the binary prompt case, we need to rely on the provided + // finalizer to release the response, so let's create a new one. + req.CreateResponse(BinaryPointer(resp.resp)).Release() + default: + C.free(unsafe.Pointer(resp.resp)) + } + } + C.free(unsafe.Pointer(cResponses)) + }() + + responses = make([]ConvResponse, 0, len(requests)) + for i, resp := range goResponses { + request := requests[i] + msgStyle := request.Style() + switch msgStyle { + case PromptEchoOff: + fallthrough + case PromptEchoOn: + fallthrough + case ErrorMsg: + fallthrough + case TextInfo: + responses = append(responses, stringConvResponse{ + style: msgStyle, + response: C.GoString(resp.resp), + }) + case BinaryPrompt: + // Let's steal the resp ownership here, so that the request + // finalizer won't act on it. + bcr, _ := request.(BinaryConvRequester) + resp := bcr.CreateResponse(BinaryPointer(resp.resp)) + goResponses[i].resp = nil + responses = append(responses, resp) + default: + return nil, + fmt.Errorf("unsupported conversation type %v", msgStyle) + } + } + + return responses, nil +} + +// StartConvMulti initiates a PAM conversation with multiple ConvRequest's. +func (m *moduleTransaction) StartConvMulti(requests []ConvRequest) ( + []ConvResponse, error) { + return m.startConvMultiImpl(m, requests) +} diff --git a/module-transaction_test.go b/module-transaction_test.go new file mode 100644 index 0000000..0514694 --- /dev/null +++ b/module-transaction_test.go @@ -0,0 +1,1111 @@ +// Package pam provides a wrapper for the PAM application API. +package pam + +import ( + "errors" + "fmt" + "reflect" + "strings" + "testing" +) + +type customConvRequest int + +func (r customConvRequest) Style() Style { + return Style(r) +} + +func ensureNoError(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatalf("unexpected error %v", err) + } +} + +func Test_NewNullModuleTransaction(t *testing.T) { + t.Parallel() + t.Cleanup(maybeDoLeakCheck) + mt := moduleTransaction{} + + if mt.handle != nil { + t.Fatalf("unexpected handle value: %v", mt.handle) + } + + if s := Error(mt.lastStatus.Load()); s != success { + t.Fatalf("unexpected status: %v", s) + } + + tests := map[string]struct { + testFunc func(t *testing.T) (any, error) + expectedError error + ignoreError bool + }{ + "GetItem": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return mt.GetItem(Service) + }, + }, + "SetItem": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return nil, mt.SetItem(Service, "foo") + }, + }, + "GetEnv": { + ignoreError: true, + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return mt.GetEnv("foo"), nil + }, + }, + "PutEnv": { + expectedError: ErrAbort, + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return nil, mt.PutEnv("foo=bar") + }, + }, + "GetEnvList": { + expectedError: ErrBuf, + testFunc: func(t *testing.T) (any, error) { + t.Helper() + list, err := mt.GetEnvList() + if len(list) > 0 { + t.Fatalf("unexpected list: %v", list) + } + return nil, err + }, + }, + "GetUser": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return mt.GetUser("prompt") + }, + }, + "GetData": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return mt.GetData("some-data") + }, + }, + "SetData": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return nil, mt.SetData("foo", []interface{}{}) + }, + }, + "SetData-nil": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return nil, mt.SetData("foo", nil) + }, + }, + "StartConv-StringConv": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return mt.StartConv(NewStringConvRequest(TextInfo, "a prompt")) + }, + }, + "StartStringConv": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return mt.StartStringConv(TextInfo, "a prompt") + }, + }, + "StartStringConvf": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return mt.StartStringConvf(TextInfo, "a prompt %s", "with info") + }, + }, + "StartConvMulti": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return mt.StartConvMulti([]ConvRequest{ + NewStringConvRequest(TextInfo, "a prompt"), + NewStringConvRequest(ErrorMsg, "another prompt"), + NewBinaryConvRequest(BinaryPointer(&mt), nil), + NewBinaryConvRequestFromBytes([]byte("These are bytes!")), + NewBinaryConvRequestFromBytes([]byte{}), + NewBinaryConvRequestFromBytes(nil), + NewBinaryConvRequest(nil, nil), + }) + }, + }, + } + + for name, tc := range tests { + tc := tc + t.Run(name+"-error-check", func(t *testing.T) { + t.Parallel() + t.Cleanup(maybeDoLeakCheck) + data, err := tc.testFunc(t) + + switch d := data.(type) { + case string: + if d != "" { + t.Fatalf("empty value was expected, got %s", d) + } + case interface{}: + if !reflect.ValueOf(d).IsNil() { + t.Fatalf("nil value was expected, got %v", d) + } + default: + if d != nil { + t.Fatalf("nil value was expected, got %v", d) + } + } + + if tc.ignoreError { + return + } + if err == nil { + t.Fatal("error was expected, but got none") + } + + var expectedError error = ErrSystem + if tc.expectedError != nil { + expectedError = tc.expectedError + } + + if !errors.Is(err, expectedError) { + t.Fatalf("status %v was expected, but got %v", + expectedError, err) + } + }) + } + + for name, tc := range tests { + // These can't be parallel - we test a private value that is not thread safe + t.Run(name+"-lastStatus-check", func(t *testing.T) { + mt.lastStatus.Store(99999) + _, err := tc.testFunc(t) + status := Error(mt.lastStatus.Load()) + + if tc.ignoreError { + return + } + if err == nil { + t.Fatal("error was expected, but got none") + } + + expectedStatus := ErrSystem + if tc.expectedError != nil { + errors.As(err, &expectedStatus) + } + + if status != expectedStatus { + t.Fatalf("status %v was expected, but got %d", + expectedStatus, status) + } + }) + } +} + +func Test_ModuleTransaction_InvokeHandler(t *testing.T) { + t.Parallel() + t.Cleanup(maybeDoLeakCheck) + mt := &moduleTransaction{} + + err := mt.InvokeHandler(nil, 0, nil) + if !errors.Is(err, ErrIgnore) { + t.Fatalf("unexpected err: %v", err) + } + + tests := map[string]struct { + flags Flags + args []string + returnedError error + expectedError error + expectedErrorMsg string + }{ + "success": { + expectedError: nil, + }, + "success-with-flags": { + expectedError: nil, + flags: Silent | RefreshCred, + }, + "success-with-args": { + expectedError: nil, + args: []string{"foo", "bar"}, + }, + "success-with-args-and-flags": { + expectedError: nil, + flags: Silent | RefreshCred, + args: []string{"foo", "bar"}, + }, + "ignore": { + expectedError: ErrIgnore, + returnedError: ErrIgnore, + }, + "ignore-with-args-and-flags": { + expectedError: ErrIgnore, + returnedError: ErrIgnore, + args: []string{"foo", "bar"}, + }, + "generic-error": { + expectedError: ErrSystem, + returnedError: errors.New("this is a generic go error"), + expectedErrorMsg: "this is a generic go error", + }, + "transaction-error-service-error": { + expectedError: ErrService, + returnedError: errors.Join(ErrService, errors.New("ErrService")), + expectedErrorMsg: ErrService.Error(), + }, + "return-type-as-error-success": { + expectedError: nil, + returnedError: Error(0), + }, + "return-type-as-error": { + expectedError: ErrNoModuleData, + returnedError: ErrNoModuleData, + expectedErrorMsg: ErrNoModuleData.Error(), + }, + } + + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + err := mt.InvokeHandler(func(handlerMt ModuleTransaction, + handlerFlags Flags, handlerArgs []string) error { + if handlerMt != mt { + t.Fatalf("unexpected mt: %#v vs %#v", mt, handlerMt) + } + if handlerFlags != tc.flags { + t.Fatalf("unexpected mt: %#v vs %#v", tc.flags, handlerFlags) + } + if strings.Join(handlerArgs, "") != strings.Join(tc.args, "") { + t.Fatalf("unexpected mt: %#v vs %#v", tc.args, handlerArgs) + } + + return tc.returnedError + }, tc.flags, tc.args) + + status := Error(mt.lastStatus.Load()) + + if !errors.Is(err, tc.expectedError) { + t.Fatalf("unexpected err: %#v vs %#v", err, tc.expectedError) + } + + var expectedStatus Error + if err != nil { + var pamErr Error + if errors.As(err, &pamErr) { + expectedStatus = pamErr + } else { + expectedStatus = ErrSystem + } + } + + if status != expectedStatus { + t.Fatalf("unexpected status: %#v vs %#v", status, expectedStatus) + } + }) + } +} + +func testMockModuleTransaction(t *testing.T, mt *moduleTransaction) { + t.Helper() + t.Parallel() + t.Cleanup(maybeDoLeakCheck) + + tests := map[string]struct { + testFunc func(mock *mockModuleTransaction) (any, error) + mockExpectations mockModuleTransactionExpectations + mockRetData mockModuleTransactionReturnedData + conversationHandler ConversationHandler + + expectedError error + expectedValue any + ignoreError bool + }{ + "GetUser-empty": { + mockExpectations: mockModuleTransactionExpectations{ + UserPrompt: "who are you?"}, + expectedValue: "", + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.getUserImpl(mock, "who are you?") + }, + }, + "GetUser-preset-value": { + mockExpectations: mockModuleTransactionExpectations{ + UserPrompt: "who are you?"}, + mockRetData: mockModuleTransactionReturnedData{User: "dummy-user"}, + expectedValue: "dummy-user", + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.getUserImpl(mock, "who are you?") + }, + }, + "GetUser-conversation-value": { + mockExpectations: mockModuleTransactionExpectations{ + UserPrompt: "who are you?"}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: PromptEchoOn, + ExpectedMessage: "who are you?", + User: "returned-dummy-user", + }, + expectedValue: "returned-dummy-user", + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.getUserImpl(mock, "who are you?") + }, + }, + "GetUser-conversation-error-prompt": { + expectedError: ErrConv, + mockExpectations: mockModuleTransactionExpectations{ + UserPrompt: "who are you?"}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: PromptEchoOn, + ExpectedMessage: "who are you???", + }, + expectedValue: "", + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.getUserImpl(mock, "who are you?") + }, + }, + "GetUser-conversation-error-style": { + expectedError: ErrConv, + mockExpectations: mockModuleTransactionExpectations{ + UserPrompt: "who are you?"}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: PromptEchoOff, + ExpectedMessage: "who are you?", + }, + expectedValue: "", + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.getUserImpl(mock, "who are you?") + }, + }, + "GetData-not-available": { + expectedError: ErrNoModuleData, + mockExpectations: mockModuleTransactionExpectations{ + DataKey: "not-available-data"}, + expectedValue: nil, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.getDataImpl(mock, "not-available-data") + }, + }, + "GetData-not-available-other-failure": { + expectedError: ErrBuf, + mockExpectations: mockModuleTransactionExpectations{ + DataKey: "not-available-data"}, + mockRetData: mockModuleTransactionReturnedData{Status: ErrBuf}, + expectedValue: nil, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.getDataImpl(mock, "not-available-data") + }, + }, + "SetData-empty-nil": { + expectedError: ErrNoModuleData, + expectedValue: nil, + testFunc: func(mock *mockModuleTransaction) (any, error) { + ensureNoError(mock.T, mt.setDataImpl(mock, "", nil)) + return mt.getDataImpl(mock, "") + }, + }, + "SetData-empty-to-value": { + expectedValue: []string{"hello", "world"}, + testFunc: func(mock *mockModuleTransaction) (any, error) { + ensureNoError(mock.T, mt.setDataImpl(mock, "", + []string{"hello", "world"})) + return mt.getDataImpl(mock, "") + }, + }, + "SetData-to-value": { + expectedValue: []interface{}{"a string", true, 0.55, errors.New("oh no")}, + mockExpectations: mockModuleTransactionExpectations{ + DataKey: "some-data"}, + testFunc: func(mock *mockModuleTransaction) (any, error) { + ensureNoError(mock.T, mt.setDataImpl(mock, "some-data", + []interface{}{"a string", true, 0.55, errors.New("oh no")})) + return mt.getDataImpl(mock, "some-data") + }, + }, + "SetData-to-value-replacing": { + expectedValue: "just a value", + mockExpectations: mockModuleTransactionExpectations{ + DataKey: "replaced-data"}, + testFunc: func(mock *mockModuleTransaction) (any, error) { + ensureNoError(mock.T, mt.setDataImpl(mock, "replaced-data", + []interface{}{"a string", true, 0.55, errors.New("oh no")})) + ensureNoError(mock.T, mt.setDataImpl(mock, "replaced-data", + "just a value")) + return mt.getDataImpl(mock, "replaced-data") + }, + }, + "StartConv-no-conv-set": { + expectedError: ErrConv, + expectedValue: nil, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvImpl(mock, StringConvRequest{ + TextInfo, + "hello PAM!", + }) + }, + }, + "StartConv-text-info": { + expectedValue: stringConvResponse{TextInfo, "nice to see you, Go!"}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: TextInfo, + ExpectedMessage: "hello PAM!", + TextInfo: "nice to see you, Go!", + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvImpl(mock, StringConvRequest{ + TextInfo, + "hello PAM!", + }) + }, + }, + "StartConv-error-msg": { + expectedValue: stringConvResponse{ErrorMsg, "ops, sorry..."}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: ErrorMsg, + ExpectedMessage: "This is wrong, PAM!", + ErrorMsg: "ops, sorry...", + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvImpl(mock, StringConvRequest{ + ErrorMsg, + "This is wrong, PAM!", + }) + }, + }, + "StartConv-prompt-echo-on": { + expectedValue: stringConvResponse{PromptEchoOn, "here's my public data"}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: PromptEchoOn, + ExpectedMessage: "Give me your non-private infos", + PromptEchoOn: "here's my public data", + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvImpl(mock, StringConvRequest{ + PromptEchoOn, + "Give me your non-private infos", + }) + }, + }, + "StartConv-prompt-echo-off": { + expectedValue: stringConvResponse{PromptEchoOff, "here's my private data"}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: PromptEchoOff, + ExpectedMessage: "Give me your private secrets", + PromptEchoOff: "here's my private data", + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvImpl(mock, StringConvRequest{ + PromptEchoOff, + "Give me your private secrets", + }) + }, + }, + "StartConv-unknown-style": { + expectedError: ErrConv, + expectedValue: nil, + conversationHandler: mockConversationHandler{ + ExpectedStyle: Style(9999), + ExpectedMessage: "hello PAM!", + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvImpl(mock, StringConvRequest{ + Style(9999), + "hello PAM!", + }) + }, + }, + "StartConv-unknown-style-response": { + expectedError: ErrConv, + expectedValue: nil, + conversationHandler: mockConversationHandler{ + ExpectedStyle: Style(9999), + ExpectedMessage: "hello PAM!", + IgnoreUnknownStyle: true, + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvImpl(mock, StringConvRequest{ + Style(9999), + "hello PAM!", + }) + }, + }, + "StartStringConv-text-info": { + expectedValue: stringConvResponse{TextInfo, "nice to see you, Go!"}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: TextInfo, + ExpectedMessage: "hello PAM!", + TextInfo: "nice to see you, Go!", + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startStringConvImpl(mock, TextInfo, + "hello PAM!") + }, + }, + "StartStringConv-error-msg": { + expectedValue: stringConvResponse{ErrorMsg, "ops, sorry..."}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: ErrorMsg, + ExpectedMessage: "This is wrong, PAM!", + ErrorMsg: "ops, sorry...", + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startStringConvImpl(mock, ErrorMsg, + "This is wrong, PAM!") + }, + }, + "StartStringConv-prompt-echo-on": { + expectedValue: stringConvResponse{PromptEchoOn, "here's my public data"}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: PromptEchoOn, + ExpectedMessage: "Give me your non-private infos", + PromptEchoOn: "here's my public data", + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startStringConvImpl(mock, PromptEchoOn, + "Give me your non-private infos") + }, + }, + "StartStringConv-prompt-echo-off": { + expectedValue: stringConvResponse{PromptEchoOff, "here's my private data"}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: PromptEchoOff, + ExpectedMessage: "Give me your private secrets", + PromptEchoOff: "here's my private data", + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startStringConvImpl(mock, PromptEchoOff, + "Give me your private secrets") + }, + }, + "StartStringConv-binary": { + expectedError: ErrConv, + expectedValue: nil, + conversationHandler: mockConversationHandler{ + ExpectedStyle: BinaryPrompt, + ExpectedMessage: "require binary data", + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startStringConvImpl(mock, PromptEchoOff, + "require binary data") + }, + }, + "StartConvMulti-missing": { + expectedError: ErrConv, + expectedValue: ([]ConvResponse)(nil), + conversationHandler: mockConversationHandler{}, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvMultiImpl(mock, nil) + }, + }, + "StartConvMulti-too-many": { + expectedError: ErrConv, + expectedValue: ([]ConvResponse)(nil), + conversationHandler: mockConversationHandler{}, + testFunc: func(mock *mockModuleTransaction) (any, error) { + reqs := [maxNumMsg + 1]ConvRequest{} + return mt.startConvMultiImpl(mock, reqs[:]) + }, + }, + "StartConvMulti-unexpected-style": { + expectedError: ErrConv, + expectedValue: ([]ConvResponse)(nil), + conversationHandler: mockConversationHandler{}, + testFunc: func(mock *mockModuleTransaction) (any, error) { + var req ConvRequest = customConvRequest(0xdeadbeef) + return mt.startConvMultiImpl(mock, []ConvRequest{req}) + }, + }, + "StartConvMulti-string-as-binary": { + expectedError: ErrConv, + expectedValue: ([]ConvResponse)(nil), + conversationHandler: mockConversationHandler{}, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvMultiImpl(mock, []ConvRequest{ + NewStringConvRequest(BinaryPrompt, "no binary!"), + }) + }, + }, + "StartConvMulti-all-types": { + expectedValue: []any{ + []ConvResponse{ + stringConvResponse{TextInfo, "nice to see you, Go!"}, + stringConvResponse{ErrorMsg, "ops, sorry..."}, + stringConvResponse{PromptEchoOn, "here's my public data"}, + stringConvResponse{PromptEchoOff, "here's my private data"}, + }, + [][]byte{ + {0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, + }, + }, + conversationHandler: mockConversationHandler{ + TextInfo: "nice to see you, Go!", + ErrorMsg: "ops, sorry...", + PromptEchoOn: "here's my public data", + PromptEchoOff: "here's my private data", + Binary: []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, + ExpectedMessagesByStyle: map[Style]string{ + TextInfo: "hello PAM!", + ErrorMsg: "This is wrong, PAM!", + PromptEchoOn: "Give me your non-private infos", + PromptEchoOff: "Give me your private secrets", + }, + ExpectedBinary: []byte("\x00This is a binary data request\xC5\x00\xffYes it is!"), + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + requests := []ConvRequest{ + NewStringConvRequest(TextInfo, "hello PAM!"), + NewStringConvRequest(ErrorMsg, "This is wrong, PAM!"), + NewStringConvRequest(PromptEchoOn, "Give me your non-private infos"), + NewStringConvRequest(PromptEchoOff, "Give me your private secrets"), + NewBinaryConvRequestFromBytes( + testBinaryDataEncoder([]byte("\x00This is a binary data request\xC5\x00\xffYes it is!"))), + } + + data, err := mt.startConvMultiImpl(mock, requests) + if err != nil { + return data, err + } + + stringResponses := []ConvResponse{} + binaryResponses := [][]byte{} + for i, r := range data { + if r.Style() != requests[i].Style() { + mock.T.Fatalf("unexpected style %#v vs %#v", + r.Style(), requests[i].Style()) + } + + switch rt := r.(type) { + case BinaryConvResponse: + decoded, err := rt.Decode(testBinaryDataDecoder) + if err != nil { + return data, err + } + binaryResponses = append(binaryResponses, decoded) + case StringConvResponse: + stringResponses = append(stringResponses, r) + default: + mock.T.Fatalf("unexpected value %v", rt) + } + } + return []any{ + stringResponses, + binaryResponses, + }, err + }, + }, + "StartConvMulti-all-types-some-failing": { + expectedError: ErrConv, + expectedValue: []ConvResponse(nil), + conversationHandler: mockConversationHandler{ + TextInfo: "nice to see you, Go!", + ErrorMsg: "ops, sorry...", + PromptEchoOn: "here's my public data", + PromptEchoOff: "here's my private data", + Binary: []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, + ExpectedMessagesByStyle: map[Style]string{ + TextInfo: "hello PAM!", + ErrorMsg: "This is wrong, PAM!", + PromptEchoOn: "Give me your non-private infos", + PromptEchoOff: "Give me your private secrets", + Style(0xfaaf): "This will fail", + }, + ExpectedBinary: []byte("\x00This is a binary data request\xC5\x00\xffYes it is!"), + IgnoreUnknownStyle: true, + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + requests := []ConvRequest{ + NewStringConvRequest(TextInfo, "hello PAM!"), + NewStringConvRequest(ErrorMsg, "This is wrong, PAM!"), + NewStringConvRequest(PromptEchoOn, "Give me your non-private infos"), + NewStringConvRequest(PromptEchoOff, "Give me your private secrets"), + NewStringConvRequest(Style(0xfaaf), "This will fail"), + NewBinaryConvRequestFromBytes( + testBinaryDataEncoder([]byte("\x00This is a binary data request\xC5\x00\xffYes it is!"))), + } + + return mt.startConvMultiImpl(mock, requests) + }, + }, + "StartConv-Binary-unsupported": { + expectedValue: nil, + expectedError: ErrConv, + conversationHandler: mockConversationHandler{ + ExpectedStyle: BinaryPrompt, + ExpectedBinary: []byte("\x00This is a binary data request\xC5\x00\xffYes it is!"), + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + mock.binaryProtocol = false + bytes := testBinaryDataEncoder([]byte( + "\x00This is a binary data request\xC5\x00\xffYes it is!")) + return mt.startConvImpl(mock, NewBinaryConvRequestFromBytes(bytes)) + }, + }, + "StartConv-Binary": { + expectedValue: []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: BinaryPrompt, + ExpectedBinary: []byte("\x00This is a binary data request\xC5\x00\xffYes it is!"), + Binary: []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + bytes := testBinaryDataEncoder([]byte( + "\x00This is a binary data request\xC5\x00\xffYes it is!")) + data, err := mt.startConvImpl(mock, NewBinaryConvRequestFromBytes(bytes)) + if err != nil { + return data, err + } + bcr, _ := data.(BinaryConvResponse) + return bcr.Decode(testBinaryDataDecoder) + }, + }, + "StartConv-Binary-expected-data-mismatch": { + expectedError: ErrConv, + expectedValue: nil, + conversationHandler: mockConversationHandler{ + ExpectedStyle: BinaryPrompt, + ExpectedBinary: []byte("\x00This is not the expected data!"), + Binary: []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + bytes := testBinaryDataEncoder([]byte( + "\x00This is a binary data request\xC5\x00\xffYes it is!")) + return mt.startConvImpl(mock, NewBinaryConvRequestFromBytes(bytes)) + }, + }, + "StartConv-Binary-unexpected-nil": { + expectedError: ErrConv, + expectedValue: nil, + conversationHandler: mockConversationHandler{ + ExpectedStyle: BinaryPrompt, + ExpectedBinary: []byte("\x00This should not be nil"), + Binary: []byte("\x1ASome binary Dat\xaa"), + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvImpl(mock, NewBinaryConvRequestFromBytes(nil)) + }, + }, + "StartConv-Binary-expected-nil": { + expectedValue: []byte("\x1ASome binary Dat\xaa"), + conversationHandler: mockConversationHandler{ + ExpectedStyle: BinaryPrompt, + ExpectedNil: true, + ExpectedBinary: []byte("\x00This should not be nil"), + Binary: []byte("\x1ASome binary Dat\xaa"), + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + data, err := mt.startConvImpl(mock, NewBinaryConvRequestFromBytes(nil)) + if err != nil { + return data, err + } + bcr, _ := data.(BinaryConvResponse) + return bcr.Decode(testBinaryDataDecoder) + }, + }, + "StartConv-Binary-returns-nil": { + expectedValue: BinaryPointer(nil), + conversationHandler: mockConversationHandler{ + ExpectedStyle: BinaryPrompt, + ExpectedBinary: []byte("\x1ASome binary Dat\xaa"), + Binary: nil, + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + bytes := testBinaryDataEncoder([]byte("\x1ASome binary Dat\xaa")) + data, err := mt.startConvImpl(mock, NewBinaryConvRequestFromBytes(bytes)) + if err != nil { + return data, err + } + bcr, _ := data.(BinaryConvResponse) + return bcr.Data(), err + }, + }, + "StartBinaryConv": { + expectedValue: []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: BinaryPrompt, + ExpectedBinary: []byte("\x00This is a binary data request\xC5\x00\xffYes it is!"), + Binary: []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + bytes := testBinaryDataEncoder([]byte( + "\x00This is a binary data request\xC5\x00\xffYes it is!")) + data, err := mt.startConvImpl(mock, NewBinaryConvRequestFromBytes(bytes)) + if err != nil { + return data, err + } + bcr, _ := data.(BinaryConvResponse) + return bcr.Decode(testBinaryDataDecoder) + }, + }, + "StartBinaryConv-expected-data-mismatch": { + expectedError: ErrConv, + expectedValue: nil, + conversationHandler: mockConversationHandler{ + ExpectedStyle: BinaryPrompt, + ExpectedBinary: []byte("\x00This is not the expected data!"), + Binary: []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + bytes := testBinaryDataEncoder([]byte( + "\x00This is a binary data request\xC5\x00\xffYes it is!")) + return mt.startBinaryConvImpl(mock, bytes) + }, + }, + "StartBinaryConv-unexpected-nil": { + expectedError: ErrConv, + expectedValue: nil, + conversationHandler: mockConversationHandler{ + ExpectedStyle: BinaryPrompt, + ExpectedBinary: []byte("\x00This should not be nil"), + Binary: []byte("\x1ASome binary Dat\xaa"), + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startBinaryConvImpl(mock, nil) + }, + }, + "StartBinaryConv-expected-nil": { + expectedValue: []byte("\x1ASome binary Dat\xaa"), + conversationHandler: mockConversationHandler{ + ExpectedStyle: BinaryPrompt, + ExpectedNil: true, + ExpectedBinary: []byte("\x00This should not be nil"), + Binary: []byte("\x1ASome binary Dat\xaa"), + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + data, err := mt.startBinaryConvImpl(mock, nil) + if err != nil { + return data, err + } + return data.Decode(testBinaryDataDecoder) + }, + }, + "StartBinaryConv-returns-nil": { + expectedValue: BinaryPointer(nil), + conversationHandler: mockConversationHandler{ + ExpectedStyle: BinaryPrompt, + ExpectedBinary: []byte("\x1ASome binary Dat\xaa"), + Binary: nil, + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + bytes := testBinaryDataEncoder([]byte("\x1ASome binary Dat\xaa")) + data, err := mt.startBinaryConvImpl(mock, bytes) + if err != nil { + return data, err + } + return data.Data(), err + }, + }, + "StartConv-Binary-with-ConvFunc": { + expectedValue: []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, + conversationHandler: BinaryConversationFunc(func(ptr BinaryPointer) ([]byte, error) { + bytes, _ := testBinaryDataDecoder(ptr) + expectedBinary := []byte( + "\x00This is a binary data request\xC5\x00\xffYes it is!") + if !reflect.DeepEqual(bytes, expectedBinary) { + return nil, fmt.Errorf("%w, data mismatch %#v vs %#v", + ErrConv, bytes, expectedBinary) + } + return testBinaryDataEncoder([]byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}), nil + }), + testFunc: func(mock *mockModuleTransaction) (any, error) { + bytes := testBinaryDataEncoder([]byte( + "\x00This is a binary data request\xC5\x00\xffYes it is!")) + data, err := mt.startConvImpl(mock, NewBinaryConvRequestFromBytes(bytes)) + if err != nil { + return data, err + } + resp, _ := data.(BinaryConvResponse) + return resp.Decode(testBinaryDataDecoder) + }, + }, + "StartConv-Binary-with-ConvFunc-error": { + expectedError: ErrConv, + conversationHandler: BinaryConversationFunc(func(ptr BinaryPointer) ([]byte, error) { + return nil, errors.New("got an error") + }), + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvImpl(mock, NewBinaryConvRequestFromBytes([]byte{})) + }, + }, + "StartConv-String-with-ConvBinaryFunc": { + expectedError: ErrConv, + conversationHandler: BinaryConversationFunc(func(ptr BinaryPointer) ([]byte, error) { + return nil, nil + }), + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvImpl(mock, NewStringConvRequest(TextInfo, "prompt")) + }, + }, + "StartConv-Binary-with-PointerConvFunc": { + expectedValue: []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x95}, + conversationHandler: BinaryPointerConversationFunc(func(ptr BinaryPointer) (BinaryPointer, error) { + bytes, _ := testBinaryDataDecoder(ptr) + expectedBinary := []byte( + "\x00This is a binary data request\xC5\x00\xffYes it is! From bytes pointer.") + if !reflect.DeepEqual(bytes, expectedBinary) { + return nil, + fmt.Errorf("%w: data mismatch %#v vs %#v", ErrConv, bytes, expectedBinary) + } + return allocateCBytes(testBinaryDataEncoder([]byte{ + 0x01, 0x02, 0x03, 0x05, 0x00, 0x95})), nil + }), + testFunc: func(mock *mockModuleTransaction) (any, error) { + bytes := testBinaryDataEncoder([]byte( + "\x00This is a binary data request\xC5\x00\xffYes it is! From bytes pointer.")) + data, err := mt.startConvImpl(mock, NewBinaryConvRequestFromBytes(bytes)) + if err != nil { + return data, err + } + resp, _ := data.(BinaryConvResponse) + return resp.Decode(testBinaryDataDecoder) + }, + }, + "StartConv-Binary-with-PointerConvFunc-and-allocated-data": { + expectedValue: []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x95}, + conversationHandler: BinaryPointerConversationFunc(func(ptr BinaryPointer) (BinaryPointer, error) { + bytes, _ := testBinaryDataDecoder(ptr) + expectedBinary := []byte( + "\x00This is a binary data request\xC5\x00\xffYes it is! From pointer...") + if !reflect.DeepEqual(bytes, expectedBinary) { + return nil, + fmt.Errorf("%w: data mismatch %#v vs %#v", ErrConv, bytes, expectedBinary) + } + return allocateCBytes(testBinaryDataEncoder([]byte{ + 0x01, 0x02, 0x03, 0x05, 0x00, 0x95})), nil + }), + testFunc: func(mock *mockModuleTransaction) (any, error) { + bytes := testBinaryDataEncoder([]byte( + "\x00This is a binary data request\xC5\x00\xffYes it is! From pointer...")) + data, err := mt.startConvImpl(mock, + NewBinaryConvRequest(allocateCBytes(bytes), binaryPointerCBytesFinalizer)) + if err != nil { + return data, err + } + resp, _ := data.(BinaryConvResponse) + return resp.Decode(testBinaryDataDecoder) + }, + }, + "StartConv-Binary-with-PointerConvFunc-and-allocated-data-erroring": { + expectedValue: nil, + expectedError: ErrConv, + conversationHandler: BinaryPointerConversationFunc(func(ptr BinaryPointer) (BinaryPointer, error) { + bytes, _ := testBinaryDataDecoder(ptr) + expectedBinary := []byte( + "\x00This is a binary data request\xC5\x00\xffYes it is! From pointer...") + if !reflect.DeepEqual(bytes, expectedBinary) { + return nil, + fmt.Errorf("%w: data mismatch %#v vs %#v", ErrConv, bytes, expectedBinary) + } + return allocateCBytes(testBinaryDataEncoder([]byte{ + 0x01, 0x02, 0x03, 0x05, 0x00, 0x95})), ErrConv + }), + testFunc: func(mock *mockModuleTransaction) (any, error) { + bytes := testBinaryDataEncoder([]byte( + "\x00This is a binary data request\xC5\x00\xffYes it is! From pointer...")) + data, err := mt.startConvImpl(mock, + NewBinaryConvRequest(allocateCBytes(bytes), binaryPointerCBytesFinalizer)) + if err != nil { + return data, err + } + resp, _ := data.(BinaryConvResponse) + return resp.Decode(testBinaryDataDecoder) + }, + }, + "StartConv-Binary-with-PointerConvFunc-empty": { + expectedValue: []byte{}, + conversationHandler: BinaryPointerConversationFunc(func(ptr BinaryPointer) (BinaryPointer, error) { + bytes, _ := testBinaryDataDecoder(ptr) + expectedBinary := []byte( + "\x00This is an empty binary data request\xC5\x00\xffYes it is!") + if !reflect.DeepEqual(bytes, expectedBinary) { + return nil, + fmt.Errorf("%w: data mismatch %#v vs %#v", ErrConv, bytes, expectedBinary) + } + return allocateCBytes(testBinaryDataEncoder([]byte{})), nil + }), + testFunc: func(mock *mockModuleTransaction) (any, error) { + bytes := testBinaryDataEncoder([]byte( + "\x00This is an empty binary data request\xC5\x00\xffYes it is!")) + data, err := mt.startConvImpl(mock, NewBinaryConvRequestFromBytes(bytes)) + if err != nil { + return data, err + } + resp, _ := data.(BinaryConvResponse) + return resp.Decode(testBinaryDataDecoder) + }, + }, + "StartConv-Binary-with-PointerConvFunc-nil": { + expectedValue: []byte(nil), + conversationHandler: BinaryPointerConversationFunc(func(ptr BinaryPointer) (BinaryPointer, error) { + bytes, _ := testBinaryDataDecoder(ptr) + expectedBinary := []byte( + "\x00This is a nil binary data request\xC5\x00\xffYes it is!") + if !reflect.DeepEqual(bytes, expectedBinary) { + return nil, + fmt.Errorf("%w: data mismatch %#v vs %#v", ErrConv, bytes, expectedBinary) + } + return nil, nil + }), + testFunc: func(mock *mockModuleTransaction) (any, error) { + bytes := testBinaryDataEncoder([]byte( + "\x00This is a nil binary data request\xC5\x00\xffYes it is!")) + data, err := mt.startConvImpl(mock, NewBinaryConvRequestFromBytes(bytes)) + if err != nil { + return data, err + } + resp, _ := data.(BinaryConvResponse) + return resp.Decode(testBinaryDataDecoder) + }, + }, + "StartConv-Binary-with-PointerConvFunc-error": { + expectedError: ErrConv, + conversationHandler: BinaryPointerConversationFunc(func(ptr BinaryPointer) (BinaryPointer, error) { + return nil, errors.New("got an error") + }), + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvImpl(mock, NewBinaryConvRequestFromBytes([]byte{})) + }, + }, + "StartConv-String-with-ConvPointerBinaryFunc": { + expectedError: ErrConv, + conversationHandler: BinaryPointerConversationFunc(func(ptr BinaryPointer) (BinaryPointer, error) { + return nil, nil + }), + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvImpl(mock, NewStringConvRequest(TextInfo, "prompt")) + }, + }, + } + + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + t.Cleanup(maybeDoLeakCheck) + mock := newMockModuleTransaction(&mockModuleTransaction{T: t, + Expectations: tc.mockExpectations, RetData: tc.mockRetData, + ConversationHandler: tc.conversationHandler}) + data, err := tc.testFunc(mock) + + if !tc.ignoreError && !errors.Is(err, tc.expectedError) { + t.Fatalf("unexpected err: %#v vs %#v", err, tc.expectedError) + } + + if !reflect.DeepEqual(data, tc.expectedValue) { + t.Fatalf("data mismatch, %#v vs %#v", data, tc.expectedValue) + } + }) + } +} + +func Test_MockModuleTransaction(t *testing.T) { + mt, _ := NewModuleTransactionInvoker(nil).(*moduleTransaction) + testMockModuleTransaction(t, mt) +} + +func Test_MockModuleTransactionParallelConv(t *testing.T) { + mt, _ := NewModuleTransactionInvokerParallelConv(nil).(*moduleTransaction) + testMockModuleTransaction(t, mt) +} diff --git a/transaction.c b/transaction.c deleted file mode 100644 index df25cf6..0000000 --- a/transaction.c +++ /dev/null @@ -1,62 +0,0 @@ -#include "_cgo_export.h" -#include -#include -#include - -#ifdef __sun -#define PAM_CONST -#else -#define PAM_CONST const -#endif - -int cb_pam_conv( - int num_msg, - PAM_CONST struct pam_message **msg, - struct pam_response **resp, - void *appdata_ptr) -{ - *resp = calloc(num_msg, sizeof **resp); - if (num_msg <= 0 || num_msg > PAM_MAX_NUM_MSG) { - return PAM_CONV_ERR; - } - if (!*resp) { - return PAM_BUF_ERR; - } - for (size_t i = 0; i < num_msg; ++i) { - struct cbPAMConv_return result = cbPAMConv( - msg[i]->msg_style, - (char *)msg[i]->msg, - (uintptr_t)appdata_ptr); - if (result.r1 != PAM_SUCCESS) { - goto error; - } - (*resp)[i].resp = result.r0; - } - return PAM_SUCCESS; -error: - for (size_t i = 0; i < num_msg; ++i) { - if ((*resp)[i].resp) { - memset((*resp)[i].resp, 0, strlen((*resp)[i].resp)); - free((*resp)[i].resp); - } - } - memset(*resp, 0, num_msg * sizeof *resp); - free(*resp); - *resp = NULL; - return PAM_CONV_ERR; -} - -void init_pam_conv(struct pam_conv *conv, uintptr_t appdata) -{ - conv->conv = cb_pam_conv; - conv->appdata_ptr = (void *)appdata; -} - -// pam_start_confdir is a recent PAM api to declare a confdir (mostly for tests) -// weaken the linking dependency to detect if it’s present. -int pam_start_confdir(const char *service_name, const char *user, const struct pam_conv *pam_conversation, const char *confdir, pam_handle_t **pamh) __attribute__ ((weak)); -int check_pam_start_confdir(void) { - if (pam_start_confdir == NULL) - return 1; - return 0; -} diff --git a/transaction.go b/transaction.go index cc730fb..8b32704 100644 --- a/transaction.go +++ b/transaction.go @@ -1,32 +1,21 @@ // Package pam provides a wrapper for the PAM application API. package pam -//#include -//#include -//#include //#cgo CFLAGS: -Wall -std=c99 //#cgo LDFLAGS: -lpam -//void init_pam_conv(struct pam_conv *conv, uintptr_t); -//int pam_start_confdir(const char *service_name, const char *user, const struct pam_conv *pam_conversation, const char *confdir, pam_handle_t **pamh) __attribute__ ((weak)); -//int check_pam_start_confdir(void); // -//#ifdef PAM_BINARY_PROMPT -//#define BINARY_PROMPT_IS_SUPPORTED 1 -//#else -//#include -//#define PAM_BINARY_PROMPT INT_MAX -//#define BINARY_PROMPT_IS_SUPPORTED 0 -//#endif +//#include "transaction.h" import "C" import ( - "errors" - "runtime" - "runtime/cgo" "strings" + "sync/atomic" "unsafe" ) +// success indicates a successful function return. +const success = C.PAM_SUCCESS + // Style is the type of message that the conversation handler should display. type Style int @@ -37,155 +26,42 @@ const ( PromptEchoOff Style = C.PAM_PROMPT_ECHO_OFF // PromptEchoOn indicates the conversation handler should obtain a // string while echoing text. - PromptEchoOn = C.PAM_PROMPT_ECHO_ON + PromptEchoOn Style = C.PAM_PROMPT_ECHO_ON // ErrorMsg indicates the conversation handler should display an // error message. - ErrorMsg = C.PAM_ERROR_MSG + ErrorMsg Style = C.PAM_ERROR_MSG // TextInfo indicates the conversation handler should display some // text. - TextInfo = C.PAM_TEXT_INFO + TextInfo Style = C.PAM_TEXT_INFO + // BinaryPrompt indicates the conversation handler that should implement + // the private binary protocol + BinaryPrompt Style = C.PAM_BINARY_PROMPT ) -// ConversationHandler is an interface for objects that can be used as -// conversation callbacks during PAM authentication. -type ConversationHandler interface { - // RespondPAM receives a message style and a message string. If the - // message Style is PromptEchoOff or PromptEchoOn then the function - // should return a response string. - RespondPAM(Style, string) (string, error) -} - // BinaryPointer exposes the type used for the data in a binary conversation // it represents a pointer to data that is produced by the module and that // must be parsed depending on the protocol in use type BinaryPointer unsafe.Pointer -type BinaryConversationHandler interface { - ConversationHandler - // Respond receives a pointer to the binary message. It's up to the - // receiver to parse it according to the protocol specifications. - // The function can return a byte array that will passed as pointer back - // to the module. - RespondPAMBinary(BinaryPointer) ([]byte, error) +// NativeHandle is the type of the native PAM handle for a transaction so that +// it can be exported +type NativeHandle = *C.pam_handle_t + +// transactionBase is a handler for a PAM transaction that can be used to +// group the operations that can be performed both by the application and the +// module side +type transactionBase struct { + handle NativeHandle + lastStatus atomic.Int32 } -// ConversationFunc is an adapter to allow the use of ordinary functions as -// conversation callbacks. -type ConversationFunc func(Style, string) (string, error) - -// RespondPAM is a conversation callback adapter. -func (f ConversationFunc) RespondPAM(s Style, msg string) (string, error) { - return f(s, msg) -} - -// cbPAMConv is a wrapper for the conversation callback function. -//export cbPAMConv -func cbPAMConv(s C.int, msg *C.char, c C.uintptr_t) (*C.char, C.int) { - var r string - var err error - v := cgo.Handle(c).Value() - switch cb := v.(type) { - case ConversationHandler: - if s == C.PAM_BINARY_PROMPT { - return nil, C.PAM_AUTHINFO_UNAVAIL - } - r, err = cb.RespondPAM(Style(s), C.GoString(msg)) - case BinaryConversationHandler: - if s == C.PAM_BINARY_PROMPT { - bytes, err := cb.RespondPAMBinary(BinaryPointer(msg)) - if err != nil { - return nil, C.PAM_CONV_ERR - } - return (*C.char)(C.CBytes(bytes)), C.PAM_SUCCESS - } else { - r, err = cb.RespondPAM(Style(s), C.GoString(msg)) - } +// Allows to call pam functions managing return status +func (t *transactionBase) handlePamStatus(cStatus C.int) error { + t.lastStatus.Store(int32(cStatus)) + if status := Error(cStatus); status != success { + return status } - if err != nil { - return nil, C.PAM_CONV_ERR - } - return C.CString(r), C.PAM_SUCCESS -} - -// Transaction is the application's handle for a PAM transaction. -type Transaction struct { - handle *C.pam_handle_t - conv *C.struct_pam_conv - status C.int - c cgo.Handle -} - -// transactionFinalizer cleans up the PAM handle and deletes the callback -// function. -func transactionFinalizer(t *Transaction) { - C.pam_end(t.handle, t.status) - t.c.Delete() -} - -// Start initiates a new PAM transaction. Service is treated identically to -// how pam_start treats it internally. -// -// All application calls to PAM begin with Start*. The returned -// transaction provides an interface to the remainder of the API. -func Start(service, user string, handler ConversationHandler) (*Transaction, error) { - return start(service, user, handler, "") -} - -// StartFunc registers the handler func as a conversation handler. -func StartFunc(service, user string, handler func(Style, string) (string, error)) (*Transaction, error) { - return Start(service, user, ConversationFunc(handler)) -} - -// StartConfDir initiates a new PAM transaction. Service is treated identically to -// how pam_start treats it internally. -// confdir allows to define where all pam services are defined. This is used to provide -// custom paths for tests. -// -// All application calls to PAM begin with Start*. The returned -// transaction provides an interface to the remainder of the API. -func StartConfDir(service, user string, handler ConversationHandler, confDir string) (*Transaction, error) { - if !CheckPamHasStartConfdir() { - return nil, errors.New("StartConfDir() was used, but the pam version on the system is not recent enough") - } - - return start(service, user, handler, confDir) -} - -func start(service, user string, handler ConversationHandler, confDir string) (*Transaction, error) { - switch handler.(type) { - case BinaryConversationHandler: - if C.BINARY_PROMPT_IS_SUPPORTED == 0 { - return nil, errors.New("BinaryConversationHandler() was used, but it is not supported by this platform") - } - } - t := &Transaction{ - conv: &C.struct_pam_conv{}, - c: cgo.NewHandle(handler), - } - C.init_pam_conv(t.conv, C.uintptr_t(t.c)) - runtime.SetFinalizer(t, transactionFinalizer) - s := C.CString(service) - defer C.free(unsafe.Pointer(s)) - var u *C.char - if len(user) != 0 { - u = C.CString(user) - defer C.free(unsafe.Pointer(u)) - } - if confDir == "" { - t.status = C.pam_start(s, u, t.conv, &t.handle) - } else { - c := C.CString(confDir) - defer C.free(unsafe.Pointer(c)) - t.status = C.pam_start_confdir(s, u, t.conv, c, &t.handle) - } - if t.status != C.PAM_SUCCESS { - return nil, t - } - return t, nil -} - -func (t *Transaction) Error() string { - return C.GoString(C.pam_strerror(t.handle, C.int(t.status))) + return nil } // Item is a an PAM information type. @@ -196,38 +72,42 @@ const ( // Service is the name which identifies the PAM stack. Service Item = C.PAM_SERVICE // User identifies the username identity used by a service. - User = C.PAM_USER + User Item = C.PAM_USER // Tty is the terminal name. - Tty = C.PAM_TTY + Tty Item = C.PAM_TTY // Rhost is the requesting host name. - Rhost = C.PAM_RHOST + Rhost Item = C.PAM_RHOST // Authtok is the currently active authentication token. - Authtok = C.PAM_AUTHTOK + Authtok Item = C.PAM_AUTHTOK // Oldauthtok is the old authentication token. - Oldauthtok = C.PAM_OLDAUTHTOK + Oldauthtok Item = C.PAM_OLDAUTHTOK // Ruser is the requesting user name. - Ruser = C.PAM_RUSER + Ruser Item = C.PAM_RUSER // UserPrompt is the string use to prompt for a username. - UserPrompt = C.PAM_USER_PROMPT + UserPrompt Item = C.PAM_USER_PROMPT + // FailDelay is the app supplied function to override failure delays. + FailDelay Item = C.PAM_FAIL_DELAY + // Xdisplay is the X display name + Xdisplay Item = C.PAM_XDISPLAY + // Xauthdata is the X server authentication data. + Xauthdata Item = C.PAM_XAUTHDATA + // AuthtokType is the type for pam_get_authtok + AuthtokType Item = C.PAM_AUTHTOK_TYPE ) // SetItem sets a PAM information item. -func (t *Transaction) SetItem(i Item, item string) error { +func (t *transactionBase) SetItem(i Item, item string) error { cs := unsafe.Pointer(C.CString(item)) defer C.free(cs) - t.status = C.pam_set_item(t.handle, C.int(i), cs) - if t.status != C.PAM_SUCCESS { - return t - } - return nil + return t.handlePamStatus(C.pam_set_item(t.handle, C.int(i), cs)) } // GetItem retrieves a PAM information item. -func (t *Transaction) GetItem(i Item) (string, error) { +func (t *transactionBase) GetItem(i Item) (string, error) { var s unsafe.Pointer - t.status = C.pam_get_item(t.handle, C.int(i), &s) - if t.status != C.PAM_SUCCESS { - return "", t + err := t.handlePamStatus(C.pam_get_item(t.handle, C.int(i), &s)) + if err != nil { + return "", err } return C.GoString((*C.char)(s)), nil } @@ -243,107 +123,36 @@ const ( Silent Flags = C.PAM_SILENT // DisallowNullAuthtok indicates that authorization should fail // if the user does not have a registered authentication token. - DisallowNullAuthtok = C.PAM_DISALLOW_NULL_AUTHTOK + DisallowNullAuthtok Flags = C.PAM_DISALLOW_NULL_AUTHTOK // EstablishCred indicates that credentials should be established // for the user. - EstablishCred = C.PAM_ESTABLISH_CRED - // DeleteCred inidicates that credentials should be deleted. - DeleteCred = C.PAM_DELETE_CRED + EstablishCred Flags = C.PAM_ESTABLISH_CRED + // DeleteCred indicates that credentials should be deleted. + DeleteCred Flags = C.PAM_DELETE_CRED // ReinitializeCred indicates that credentials should be fully // reinitialized. - ReinitializeCred = C.PAM_REINITIALIZE_CRED + ReinitializeCred Flags = C.PAM_REINITIALIZE_CRED // RefreshCred indicates that the lifetime of existing credentials // should be extended. - RefreshCred = C.PAM_REFRESH_CRED + RefreshCred Flags = C.PAM_REFRESH_CRED // ChangeExpiredAuthtok indicates that the authentication token // should be changed if it has expired. - ChangeExpiredAuthtok = C.PAM_CHANGE_EXPIRED_AUTHTOK + ChangeExpiredAuthtok Flags = C.PAM_CHANGE_EXPIRED_AUTHTOK ) -// Authenticate is used to authenticate the user. -// -// Valid flags: Silent, DisallowNullAuthtok -func (t *Transaction) Authenticate(f Flags) error { - t.status = C.pam_authenticate(t.handle, C.int(f)) - if t.status != C.PAM_SUCCESS { - return t - } - return nil -} - -// SetCred is used to establish, maintain and delete the credentials of a -// user. -// -// Valid flags: EstablishCred, DeleteCred, ReinitializeCred, RefreshCred -func (t *Transaction) SetCred(f Flags) error { - t.status = C.pam_setcred(t.handle, C.int(f)) - if t.status != C.PAM_SUCCESS { - return t - } - return nil -} - -// AcctMgmt is used to determine if the user's account is valid. -// -// Valid flags: Silent, DisallowNullAuthtok -func (t *Transaction) AcctMgmt(f Flags) error { - t.status = C.pam_acct_mgmt(t.handle, C.int(f)) - if t.status != C.PAM_SUCCESS { - return t - } - return nil -} - -// ChangeAuthTok is used to change the authentication token. -// -// Valid flags: Silent, ChangeExpiredAuthtok -func (t *Transaction) ChangeAuthTok(f Flags) error { - t.status = C.pam_chauthtok(t.handle, C.int(f)) - if t.status != C.PAM_SUCCESS { - return t - } - return nil -} - -// OpenSession sets up a user session for an authenticated user. -// -// Valid flags: Slient -func (t *Transaction) OpenSession(f Flags) error { - t.status = C.pam_open_session(t.handle, C.int(f)) - if t.status != C.PAM_SUCCESS { - return t - } - return nil -} - -// CloseSession closes a previously opened session. -// -// Valid flags: Silent -func (t *Transaction) CloseSession(f Flags) error { - t.status = C.pam_close_session(t.handle, C.int(f)) - if t.status != C.PAM_SUCCESS { - return t - } - return nil -} - // PutEnv adds or changes the value of PAM environment variables. // // NAME=value will set a variable to a value. // NAME= will set a variable to an empty value. // NAME (without an "=") will delete a variable. -func (t *Transaction) PutEnv(nameval string) error { +func (t *transactionBase) PutEnv(nameval string) error { cs := C.CString(nameval) defer C.free(unsafe.Pointer(cs)) - t.status = C.pam_putenv(t.handle, cs) - if t.status != C.PAM_SUCCESS { - return t - } - return nil + return t.handlePamStatus(C.pam_putenv(t.handle, cs)) } // GetEnv is used to retrieve a PAM environment variable. -func (t *Transaction) GetEnv(name string) string { +func (t *transactionBase) GetEnv(name string) string { cs := C.CString(name) defer C.free(unsafe.Pointer(cs)) value := C.pam_getenv(t.handle, cs) @@ -358,13 +167,14 @@ func next(p **C.char) **C.char { } // GetEnvList returns a copy of the PAM environment as a map. -func (t *Transaction) GetEnvList() (map[string]string, error) { +func (t *transactionBase) GetEnvList() (map[string]string, error) { env := make(map[string]string) p := C.pam_getenvlist(t.handle) if p == nil { - t.status = C.PAM_BUF_ERR - return nil, t + t.lastStatus.Store(int32(ErrBuf)) + return nil, ErrBuf } + t.lastStatus.Store(success) for q := p; *q != nil; q = next(q) { chunks := strings.SplitN(C.GoString(*q), "=", 2) if len(chunks) == 2 { @@ -380,3 +190,8 @@ func (t *Transaction) GetEnvList() (map[string]string, error) { func CheckPamHasStartConfdir() bool { return C.check_pam_start_confdir() == 0 } + +// CheckPamHasBinaryProtocol return if pam on system supports PAM_BINARY_PROMPT +func CheckPamHasBinaryProtocol() bool { + return C.BINARY_PROMPT_IS_SUPPORTED != 0 +} diff --git a/transaction.h b/transaction.h new file mode 100644 index 0000000..292aa96 --- /dev/null +++ b/transaction.h @@ -0,0 +1,99 @@ +#pragma once + +#include +#include +#include +#include +#include + +#ifdef PAM_BINARY_PROMPT +#define BINARY_PROMPT_IS_SUPPORTED 1 +#else +#include +#define PAM_BINARY_PROMPT INT_MAX +#define BINARY_PROMPT_IS_SUPPORTED 0 +#endif + +#ifdef __sun +#define PAM_CONST +#else +#define PAM_CONST const +#endif + +extern int _go_pam_conv_handler(struct pam_message *, uintptr_t, char **reply); +extern void _go_pam_data_cleanup(pam_handle_t *, uintptr_t, int status); + +static inline int cb_pam_conv(int num_msg, PAM_CONST struct pam_message **msg, struct pam_response **resp, void *appdata_ptr) +{ + if (num_msg <= 0 || num_msg > PAM_MAX_NUM_MSG) + return PAM_CONV_ERR; + + *resp = calloc(num_msg, sizeof **resp); + if (!*resp) + return PAM_BUF_ERR; + + for (size_t i = 0; i < num_msg; ++i) { + int result = _go_pam_conv_handler((struct pam_message *)msg[i], (uintptr_t)appdata_ptr, &(*resp)[i].resp); + if (result != PAM_SUCCESS) + goto error; + } + + return PAM_SUCCESS; +error: + for (size_t i = 0; i < num_msg; ++i) { + if ((*resp)[i].resp) { +#ifdef PAM_BINARY_PROMPT + if (msg[i]->msg_style != PAM_BINARY_PROMPT) +#endif + memset((*resp)[i].resp, 0, strlen((*resp)[i].resp)); + free((*resp)[i].resp); + } + } + + memset(*resp, 0, num_msg * sizeof *resp); + free(*resp); + *resp = NULL; + return PAM_CONV_ERR; +} + +static inline void init_pam_conv(struct pam_conv *conv, uintptr_t appdata) +{ + conv->conv = cb_pam_conv; + conv->appdata_ptr = (void *)appdata; +} + +static inline int start_pam_conv(struct pam_conv *pc, int num_msgs, const struct pam_message **msgs, struct pam_response **out_resp) +{ + return pc->conv(num_msgs, msgs, out_resp, pc->appdata_ptr); +} + +// pam_start_confdir is a recent PAM api to declare a confdir (mostly for +// tests) weaken the linking dependency to detect if it’s present. +int pam_start_confdir(const char *service_name, const char *user, const struct pam_conv *pam_conversation, + const char *confdir, pam_handle_t **pamh) __attribute__((weak)); + +static inline int check_pam_start_confdir(void) +{ + if (pam_start_confdir == NULL) + return 1; + + return 0; +} + +static inline void data_cleanup(pam_handle_t *pamh, void *data, int error_status) +{ + _go_pam_data_cleanup(pamh, (uintptr_t)data, error_status); +} + +static inline int set_data(pam_handle_t *pamh, const char *name, uintptr_t handle) +{ + if (handle) + return pam_set_data(pamh, name, (void *)handle, data_cleanup); + + return pam_set_data(pamh, name, NULL, NULL); +} + +static inline int get_data(pam_handle_t *pamh, const char *name, uintptr_t *out_handle) +{ + return pam_get_data(pamh, name, (const void **)out_handle); +} diff --git a/transaction_test.go b/transaction_test.go index c7bcd2e..3166159 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -2,11 +2,44 @@ package pam import ( "errors" + "fmt" + "os" "os/user" + "path/filepath" + "runtime" + "sync/atomic" "testing" + "time" + "unsafe" ) +func maybeEndTransaction(t *testing.T, tx *Transaction) { + t.Helper() + + if tx == nil { + return + } + err := tx.End() + if err != nil { + t.Fatalf("end #error: %v", err) + } +} + +func ensureTransactionEnds(t *testing.T, tx *Transaction) { + t.Helper() + + runtime.SetFinalizer(tx, func(tx *Transaction) { + // #nosec:G103 - the pointer conversion is checked. + handle := atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&tx.handle))) + if handle == nil { + return + } + t.Fatalf("transaction has not been finalized") + }) +} + func TestPAM_001(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) u, _ := user.Current() if u.Uid != "0" { t.Skip("run this test as root") @@ -15,6 +48,8 @@ func TestPAM_001(t *testing.T) { tx, err := StartFunc("", "test", func(s Style, msg string) (string, error) { return p, nil }) + ensureTransactionEnds(t, tx) + defer maybeEndTransaction(t, tx) if err != nil { t.Fatalf("start #error: %v", err) } @@ -33,6 +68,7 @@ func TestPAM_001(t *testing.T) { } func TestPAM_002(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) u, _ := user.Current() if u.Uid != "0" { t.Skip("run this test as root") @@ -46,6 +82,8 @@ func TestPAM_002(t *testing.T) { } return "", errors.New("unexpected") }) + ensureTransactionEnds(t, tx) + defer maybeEndTransaction(t, tx) if err != nil { t.Fatalf("start #error: %v", err) } @@ -71,6 +109,7 @@ func (c Credentials) RespondPAM(s Style, msg string) (string, error) { } func TestPAM_003(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) u, _ := user.Current() if u.Uid != "0" { t.Skip("run this test as root") @@ -80,6 +119,8 @@ func TestPAM_003(t *testing.T) { Password: "secret", } tx, err := Start("", "", c) + ensureTransactionEnds(t, tx) + defer maybeEndTransaction(t, tx) if err != nil { t.Fatalf("start #error: %v", err) } @@ -90,6 +131,7 @@ func TestPAM_003(t *testing.T) { } func TestPAM_004(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) u, _ := user.Current() if u.Uid != "0" { t.Skip("run this test as root") @@ -98,6 +140,8 @@ func TestPAM_004(t *testing.T) { Password: "secret", } tx, err := Start("", "test", c) + ensureTransactionEnds(t, tx) + defer maybeEndTransaction(t, tx) if err != nil { t.Fatalf("start #error: %v", err) } @@ -108,16 +152,29 @@ func TestPAM_004(t *testing.T) { } func TestPAM_005(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) u, _ := user.Current() if u.Uid != "0" { t.Skip("run this test as root") } + if _, found := os.LookupEnv("GO_PAM_TEST_WITH_ASAN"); found { + t.Skip("test fails under ASAN") + } tx, err := StartFunc("passwd", "test", func(s Style, msg string) (string, error) { return "secret", nil }) + ensureTransactionEnds(t, tx) + defer maybeEndTransaction(t, tx) if err != nil { t.Fatalf("start #error: %v", err) } + service, err := tx.GetItem(Service) + if err != nil { + t.Fatalf("GetItem #error: %v", err) + } + if service != "passwd" { + t.Fatalf("Unexpected service: %v", service) + } err = tx.ChangeAuthTok(Silent) if err != nil { t.Fatalf("chauthtok #error: %v", err) @@ -125,6 +182,7 @@ func TestPAM_005(t *testing.T) { } func TestPAM_006(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) u, _ := user.Current() if u.Uid != "0" { t.Skip("run this test as root") @@ -132,6 +190,8 @@ func TestPAM_006(t *testing.T) { tx, err := StartFunc("passwd", u.Username, func(s Style, msg string) (string, error) { return "secret", nil }) + ensureTransactionEnds(t, tx) + defer maybeEndTransaction(t, tx) if err != nil { t.Fatalf("start #error: %v", err) } @@ -146,6 +206,7 @@ func TestPAM_006(t *testing.T) { } func TestPAM_007(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) u, _ := user.Current() if u.Uid != "0" { t.Skip("run this test as root") @@ -153,6 +214,8 @@ func TestPAM_007(t *testing.T) { tx, err := StartFunc("", "test", func(s Style, msg string) (string, error) { return "", errors.New("Sorry, it didn't work") }) + ensureTransactionEnds(t, tx) + defer maybeEndTransaction(t, tx) if err != nil { t.Fatalf("start #error: %v", err) } @@ -164,15 +227,24 @@ func TestPAM_007(t *testing.T) { if len(s) == 0 { t.Fatalf("error #expected an error message") } + if !errors.Is(err, ErrAuth) { + t.Fatalf("error #unexpected error %v", err) + } } func TestPAM_ConfDir(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) u, _ := user.Current() c := Credentials{ // the custom service always permits even with wrong password. Password: "wrongsecret", } tx, err := StartConfDir("permit-service", u.Username, c, "test-services") + defer func() { + if tx != nil { + _ = tx.End() + } + }() if !CheckPamHasStartConfdir() { if err == nil { t.Fatalf("start should have errored out as pam_start_confdir is not available: %v", err) @@ -180,6 +252,13 @@ func TestPAM_ConfDir(t *testing.T) { // nothing else we do, we don't support it. return } + service, err := tx.GetItem(Service) + if err != nil { + t.Fatalf("GetItem #error: %v", err) + } + if service != "permit-service" { + t.Fatalf("Unexpected service: %v", service) + } if err != nil { t.Fatalf("start #error: %v", err) } @@ -190,21 +269,36 @@ func TestPAM_ConfDir(t *testing.T) { } func TestPAM_ConfDir_FailNoServiceOrUnsupported(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) + if !CheckPamHasStartConfdir() { + t.Skip("this requires PAM with Conf dir support") + } u, _ := user.Current() c := Credentials{ Password: "secret", } - _, err := StartConfDir("does-not-exists", u.Username, c, ".") + tx, err := StartConfDir("does-not-exists", u.Username, c, ".") if err == nil { t.Fatalf("authenticate #expected an error") } + if tx != nil { + t.Fatalf("authenticate #unexpected transaction") + } s := err.Error() if len(s) == 0 { t.Fatalf("error #expected an error message") } + var pamErr Error + if !errors.As(err, &pamErr) { + t.Fatalf("error #unexpected type: %#v", err) + } + if pamErr != ErrAbort { + t.Fatalf("error #unexpected status: %v", pamErr) + } } func TestPAM_ConfDir_InfoMessage(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) u, _ := user.Current() var infoText string tx, err := StartConfDir("echo-service", u.Username, @@ -216,24 +310,46 @@ func TestPAM_ConfDir_InfoMessage(t *testing.T) { } return "", errors.New("unexpected") }), "test-services") + ensureTransactionEnds(t, tx) + defer maybeEndTransaction(t, tx) if err != nil { t.Fatalf("start #error: %v", err) } + service, err := tx.GetItem(Service) + if err != nil { + t.Fatalf("GetItem #error: %v", err) + } + if service != "echo-service" { + t.Fatalf("Unexpected service: %v", service) + } err = tx.Authenticate(0) if err != nil { t.Fatalf("authenticate #error: %v", err) } - if infoText != "This is an info message for user " + u.Username + " on echo-service" { + if infoText != "This is an info message for user "+u.Username+" on echo-service" { t.Fatalf("Unexpected info message: %v", infoText) } } func TestPAM_ConfDir_Deny(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) + if !CheckPamHasStartConfdir() { + t.Skip("this requires PAM with Conf dir support") + } u, _ := user.Current() tx, err := StartConfDir("deny-service", u.Username, Credentials{}, "test-services") + ensureTransactionEnds(t, tx) + defer maybeEndTransaction(t, tx) if err != nil { t.Fatalf("start #error: %v", err) } + service, err := tx.GetItem(Service) + if err != nil { + t.Fatalf("GetItem #error: %v", err) + } + if service != "deny-service" { + t.Fatalf("Unexpected service: %v", service) + } err = tx.Authenticate(0) if err == nil { t.Fatalf("authenticate #expected an error") @@ -242,15 +358,21 @@ func TestPAM_ConfDir_Deny(t *testing.T) { if len(s) == 0 { t.Fatalf("error #expected an error message") } + if !errors.Is(err, ErrAuth) { + t.Fatalf("error #unexpected error %v", err) + } } func TestPAM_ConfDir_PromptForUserName(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) c := Credentials{ User: "testuser", // the custom service only cares about correct user name. Password: "wrongsecret", } tx, err := StartConfDir("succeed-if-user-test", "", c, "test-services") + ensureTransactionEnds(t, tx) + defer maybeEndTransaction(t, tx) if !CheckPamHasStartConfdir() { if err == nil { t.Fatalf("start should have errored out as pam_start_confdir is not available: %v", err) @@ -268,11 +390,14 @@ func TestPAM_ConfDir_PromptForUserName(t *testing.T) { } func TestPAM_ConfDir_WrongUserName(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) c := Credentials{ - User: "wronguser", + User: "wronguser", Password: "wrongsecret", } tx, err := StartConfDir("succeed-if-user-test", "", c, "test-services") + ensureTransactionEnds(t, tx) + defer maybeEndTransaction(t, tx) if !CheckPamHasStartConfdir() { if err == nil { t.Fatalf("start should have errored out as pam_start_confdir is not available: %v", err) @@ -288,12 +413,21 @@ func TestPAM_ConfDir_WrongUserName(t *testing.T) { if len(s) == 0 { t.Fatalf("error #expected an error message") } + if !errors.Is(err, ErrAuth) { + t.Fatalf("error #unexpected error %v", err) + } } func TestItem(t *testing.T) { - tx, _ := StartFunc("passwd", "test", func(s Style, msg string) (string, error) { + t.Cleanup(maybeDoLeakCheck) + tx, err := StartFunc("passwd", "test", func(s Style, msg string) (string, error) { return "", nil }) + ensureTransactionEnds(t, tx) + defer maybeEndTransaction(t, tx) + if err != nil { + t.Fatalf("start #error: %v", err) + } s, err := tx.GetItem(Service) if err != nil { @@ -325,9 +459,12 @@ func TestItem(t *testing.T) { } func TestEnv(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) tx, err := StartFunc("", "", func(s Style, msg string) (string, error) { return "", nil }) + ensureTransactionEnds(t, tx) + defer maybeEndTransaction(t, tx) if err != nil { t.Fatalf("start #error: %v", err) } @@ -390,7 +527,143 @@ func TestEnv(t *testing.T) { } } +func Test_Error(t *testing.T) { + t.Parallel() + t.Cleanup(maybeDoLeakCheck) + if !CheckPamHasStartConfdir() { + t.Skip("this requires PAM with Conf dir support") + } + + statuses := map[string]error{ + "success": nil, + "open_err": ErrOpen, + "symbol_err": ErrSymbol, + "service_err": ErrService, + "system_err": ErrSystem, + "buf_err": ErrBuf, + "perm_denied": ErrPermDenied, + "auth_err": ErrAuth, + "cred_insufficient": ErrCredInsufficient, + "authinfo_unavail": ErrAuthinfoUnavail, + "user_unknown": ErrUserUnknown, + "maxtries": ErrMaxtries, + "new_authtok_reqd": ErrNewAuthtokReqd, + "acct_expired": ErrAcctExpired, + "session_err": ErrSession, + "cred_unavail": ErrCredUnavail, + "cred_expired": ErrCredExpired, + "cred_err": ErrCred, + "no_module_data": ErrNoModuleData, + "conv_err": ErrConv, + "authtok_err": ErrAuthtok, + "authtok_recover_err": ErrAuthtokRecovery, + "authtok_lock_busy": ErrAuthtokLockBusy, + "authtok_disable_aging": ErrAuthtokDisableAging, + "try_again": ErrTryAgain, + "ignore": nil, /* Ignore can't be returned */ + "abort": ErrAbort, + "authtok_expired": ErrAuthtokExpired, + "module_unknown": ErrModuleUnknown, + "bad_item": ErrBadItem, + "conv_again": ErrConvAgain, + "incomplete": ErrIncomplete, + } + + type Action int + const ( + account Action = iota + 1 + auth + password + session + ) + actions := map[string]Action{ + "account": account, + "auth": auth, + "password": password, + "session": session, + } + + c := Credentials{} + + servicePath := t.TempDir() + + for ret, expected := range statuses { + ret := ret + expected := expected + for actionName, action := range actions { + actionName := actionName + action := action + t.Run(fmt.Sprintf("%s %s", ret, actionName), func(t *testing.T) { + t.Parallel() + serviceName := ret + "-" + actionName + serviceFile := filepath.Join(servicePath, serviceName) + contents := fmt.Sprintf("%[1]s requisite pam_debug.so "+ + "auth=%[2]s cred=%[2]s acct=%[2]s prechauthtok=%[2]s "+ + "chauthtok=%[2]s open_session=%[2]s close_session=%[2]s\n"+ + "%[1]s requisite pam_permit.so\n", actionName, ret) + + if err := os.WriteFile(serviceFile, + []byte(contents), 0600); err != nil { + t.Fatalf("can't create service file %v: %v", serviceFile, err) + } + + tx, err := StartConfDir(serviceName, "user", c, servicePath) + ensureTransactionEnds(t, tx) + defer maybeEndTransaction(t, tx) + if err != nil { + t.Fatalf("start #error: %v", err) + } + + switch action { + case account: + err = tx.AcctMgmt(0) + case auth: + err = tx.Authenticate(0) + case password: + err = tx.ChangeAuthTok(0) + case session: + err = tx.OpenSession(0) + } + + if !errors.Is(err, expected) { + t.Fatalf("error #unexpected status %#v vs %#v", err, + expected) + } + + if err != nil { + var status Error + if !errors.As(err, &status) || err.Error() != status.Error() { + t.Fatalf("error #unexpected status %#v vs %#v", err.Error(), + status.Error()) + } + } + }) + } + } +} + +func Test_Finalizer(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) + if !CheckPamHasStartConfdir() { + t.Skip("this requires PAM with Conf dir support") + } + + func() { + tx, err := StartConfDir("permit-service", "", nil, "test-services") + ensureTransactionEnds(t, tx) + defer maybeEndTransaction(t, tx) + if err != nil { + t.Fatalf("start #error: %v", err) + } + }() + + runtime.GC() + // sleep to switch to finalizer goroutine + time.Sleep(5 * time.Millisecond) +} + func TestFailure_001(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) tx := Transaction{} _, err := tx.GetEnvList() if err == nil { @@ -399,6 +672,7 @@ func TestFailure_001(t *testing.T) { } func TestFailure_002(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) tx := Transaction{} err := tx.PutEnv("") if err == nil { @@ -407,6 +681,7 @@ func TestFailure_002(t *testing.T) { } func TestFailure_003(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) tx := Transaction{} err := tx.CloseSession(0) if err == nil { @@ -415,6 +690,7 @@ func TestFailure_003(t *testing.T) { } func TestFailure_004(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) tx := Transaction{} err := tx.OpenSession(0) if err == nil { @@ -423,6 +699,7 @@ func TestFailure_004(t *testing.T) { } func TestFailure_005(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) tx := Transaction{} err := tx.ChangeAuthTok(0) if err == nil { @@ -431,6 +708,7 @@ func TestFailure_005(t *testing.T) { } func TestFailure_006(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) tx := Transaction{} err := tx.AcctMgmt(0) if err == nil { @@ -439,6 +717,7 @@ func TestFailure_006(t *testing.T) { } func TestFailure_007(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) tx := Transaction{} err := tx.SetCred(0) if err == nil { @@ -447,6 +726,7 @@ func TestFailure_007(t *testing.T) { } func TestFailure_008(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) tx := Transaction{} err := tx.SetItem(User, "test") if err == nil { @@ -455,9 +735,19 @@ func TestFailure_008(t *testing.T) { } func TestFailure_009(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) tx := Transaction{} _, err := tx.GetItem(User) if err == nil { t.Fatalf("getenvlist #expected an error") } } + +func TestFailure_010(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) + tx := Transaction{} + err := tx.End() + if err != nil { + t.Fatalf("end #unexpected error %v", err) + } +} diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..ad61daa --- /dev/null +++ b/utils.go @@ -0,0 +1,42 @@ +// Package pam provides a wrapper for the PAM application API. +package pam + +/* +#include + +#ifdef __SANITIZE_ADDRESS__ +#include +#endif + +static inline void +maybe_do_leak_check (void) +{ +#ifdef __SANITIZE_ADDRESS__ + __lsan_do_leak_check(); +#endif +} +*/ +import "C" + +import ( + "os" + "runtime" + "time" + "unsafe" +) + +func maybeDoLeakCheck() { + runtime.GC() + time.Sleep(time.Millisecond * 20) + if os.Getenv("GO_PAM_SKIP_LEAK_CHECK") == "" { + C.maybe_do_leak_check() + } +} + +func allocateCBytes(bytes []byte) BinaryPointer { + return BinaryPointer(C.CBytes(bytes)) +} + +func binaryPointerCBytesFinalizer(ptr BinaryPointer) { + C.free(unsafe.Pointer(ptr)) +}