diff --git a/cmd/pam-moduler/moduler.go b/cmd/pam-moduler/moduler.go index 68f4852..165d5a6 100644 --- a/cmd/pam-moduler/moduler.go +++ b/cmd/pam-moduler/moduler.go @@ -11,7 +11,7 @@ // // For example: // -// //go:generate go run github.com/msteinert/pam/pam-moduler +// //go:generate go run github.com/msteinert/pam/v2/pam-moduler // //go:generate go generate --skip="pam_module" // package main // diff --git a/cmd/pam-moduler/tests/integration-tester-module/communication.go b/cmd/pam-moduler/tests/integration-tester-module/communication.go new file mode 100644 index 0000000..67bada4 --- /dev/null +++ b/cmd/pam-moduler/tests/integration-tester-module/communication.go @@ -0,0 +1,230 @@ +// Package main is the package for the integration tester module PAM shared library. +package main + +import ( + "bytes" + "encoding/gob" + "errors" + "fmt" + "io" + "net" + "runtime" +) + +// Request is a serializable integration module tester structure request. +type Request struct { + Action string + ActionArgs []interface{} +} + +// Result is a serializable integration module tester structure result. +type Result = Request + +// NewRequest returns a new Request. +func NewRequest(action string, actionArgs ...interface{}) Request { + return Request{action, actionArgs} +} + +// GOB serializes the request in binary format. +func (r *Request) GOB() ([]byte, error) { + b := bytes.Buffer{} + e := gob.NewEncoder(&b) + if err := e.Encode(r); err != nil { + return nil, err + } + return b.Bytes(), nil +} + +// NewRequestFromGOB gets a Request from a serialized binary. +func NewRequestFromGOB(data []byte) (*Request, error) { + b := bytes.Buffer{} + b.Write(data) + d := gob.NewDecoder(&b) + + var req Request + if err := d.Decode(&req); err != nil { + return nil, err + } + return &req, nil +} + +const bufSize = 1024 + +type connectionHandler struct { + inOutData chan []byte + outErr chan error + SocketPath string +} + +// Listener is a socket listener. +type Listener struct { + connectionHandler + listener net.Listener +} + +// NewListener creates a new Listener. +func NewListener(socketPath string) *Listener { + if len(socketPath) > 90 { + // See https://manpages.ubuntu.com/manpages/jammy/man7/sys_un.h.7posix.html#application%20usage + panic(fmt.Sprintf("Socket path %s too long", socketPath)) + } + return &Listener{connectionHandler{SocketPath: socketPath}, nil} +} + +// WaitForData waits for result data (or an error) on connection to be returned. +func (c *connectionHandler) WaitForData() (*Result, error) { + data, err := <-c.inOutData, <-c.outErr + if err != nil { + if errors.Is(err, io.EOF) { + return nil, nil + } + return nil, err + } + + req, err := NewRequestFromGOB(data) + if err != nil { + return nil, err + } + + return req, nil +} + +// SendRequest sends a request to the connection. +func (c *connectionHandler) SendRequest(req *Request) error { + bytes, err := req.GOB() + if err != nil { + return err + } + + c.inOutData <- bytes + return nil +} + +// SendResult sends the Result to the connection. +func (c *connectionHandler) SendResult(res *Result) error { + return c.SendRequest(res) +} + +// DoRequest performs a Request on the connection, waiting for data. +func (c *connectionHandler) DoRequest(req *Request) (*Result, error) { + if err := c.SendRequest(req); err != nil { + return nil, err + } + + return c.WaitForData() +} + +// Send performs a request. +func (r *Request) Send(c *connectionHandler) error { + return c.SendRequest(r) +} + +// ErrAlreadyListening is the error if a listener is already set. +var ErrAlreadyListening = errors.New("listener already set") + +// StartListening initiates the unix listener. +func (l *Listener) StartListening() error { + if l.listener != nil { + return ErrAlreadyListening + } + + listener, err := net.Listen("unix", l.SocketPath) + if err != nil { + return err + } + + l.listener = listener + l.inOutData, l.outErr = make(chan []byte), make(chan error) + + go func() { + bytes, err := func() ([]byte, error) { + for { + c, err := l.listener.Accept() + if err != nil { + return nil, err + } + + for { + buf := make([]byte, bufSize) + nr, err := c.Read(buf) + if err != nil { + return buf, err + } + + data := buf[0:nr] + l.inOutData <- data + l.outErr <- nil + + _, err = c.Write(<-l.inOutData) + if err != nil { + return nil, err + } + } + } + }() + + l.inOutData <- bytes + l.outErr <- err + }() + + return nil +} + +// Connector is a connection type. +type Connector struct { + connectionHandler + connection net.Conn +} + +// NewConnector creates a new connection. +func NewConnector(socketPath string) *Connector { + return &Connector{connectionHandler{SocketPath: socketPath}, nil} +} + +// ErrAlreadyConnected is the error if a connection is already set. +var ErrAlreadyConnected = errors.New("connection already set") + +// Connect connects to a listening unix socket. +func (c *Connector) Connect() error { + if c.connection != nil { + return ErrAlreadyConnected + } + + connection, err := net.Dial("unix", c.SocketPath) + if err != nil { + return err + } + + runtime.SetFinalizer(c, func(c *Connector) { + c.connection.Close() + }) + + c.connection = connection + c.inOutData, c.outErr = make(chan []byte), make(chan error) + + go func() { + buf := make([]byte, bufSize) + writeAndRead := func() ([]byte, error) { + data := <-c.inOutData + _, err := c.connection.Write(data) + if err != nil { + return nil, err + } + + n, err := c.connection.Read(buf[:]) + if err != nil { + return nil, err + } + + return buf[0:n], nil + } + + for { + bytes, err := writeAndRead() + c.inOutData <- bytes + c.outErr <- err + } + }() + + return nil +} diff --git a/cmd/pam-moduler/tests/integration-tester-module/communication_test.go b/cmd/pam-moduler/tests/integration-tester-module/communication_test.go new file mode 100644 index 0000000..7abc2e3 --- /dev/null +++ b/cmd/pam-moduler/tests/integration-tester-module/communication_test.go @@ -0,0 +1,107 @@ +package main + +import ( + "errors" + "path/filepath" + "reflect" + "testing" + + "github.com/msteinert/pam/v2/cmd/pam-moduler/tests/internal/utils" +) + +func ensureNoError(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatalf("unexpected error %v", err) + } +} + +func ensureError(t *testing.T, err error, expected error) { + t.Helper() + if err == nil { + t.Fatalf("error was expected, got none") + } + if !errors.Is(err, expected) { + t.Fatalf("error %v was expected, got %v", err, expected) + } +} + +func ensureEqual(t *testing.T, a any, b any) { + t.Helper() + if !reflect.DeepEqual(a, b) { + t.Fatalf("values mismatch %v vs %v", a, b) + } +} + +func Test_Communication(t *testing.T) { + t.Parallel() + + ts := utils.NewTestSetup(t, utils.WithWorkDir()) + + for _, name := range []string{"test-1", "test-2"} { + name := name + t.Run(name, func(t *testing.T) { + t.Parallel() + socketPath := filepath.Join(ts.WorkDir(), name+".socket") + + listener := NewListener(socketPath) + connector := NewConnector(socketPath) + + ensureNoError(t, listener.StartListening()) + ensureNoError(t, connector.Connect()) + + ensureError(t, listener.StartListening(), ErrAlreadyListening) + ensureError(t, connector.Connect(), ErrAlreadyConnected) + + resChan, errChan := make(chan *Result), make(chan error) + go func() { + res, err := listener.WaitForData() + resChan <- res + errChan <- err + }() + + req := NewRequest("A Request") + ensureNoError(t, connector.SendRequest(&req)) + + res, err := <-resChan, <-errChan + ensureNoError(t, err) + ensureEqual(t, *res, req) + + go func() { + res := NewRequest("Listener result") + ensureNoError(t, listener.SendResult(&res)) + }() + + res, err = connector.WaitForData() + ensureNoError(t, err) + ensureEqual(t, *res, NewRequest("Listener result")) + + go func() { + req, err := listener.WaitForData() + res := NewRequest("Response", *req) + + defer func() { + resChan <- &res + errChan <- err + }() + ensureNoError(t, listener.SendResult(&res)) + }() + + done := make(chan bool) + req = NewRequest("Requesting...") + go func() { + defer func() { + done <- true + }() + res, err := connector.DoRequest(&req) + ensureNoError(t, err) + ensureEqual(t, *res, NewRequest("Response", req)) + }() + + res, err = <-resChan, <-errChan + ensureNoError(t, err) + ensureEqual(t, *res, NewRequest("Response", req)) + <-done + }) + } +} diff --git a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go new file mode 100644 index 0000000..995e0c2 --- /dev/null +++ b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go @@ -0,0 +1,137 @@ +//go:generate go run github.com/msteinert/pam/v2/cmd/pam-moduler -type integrationTesterModule +//go:generate go generate --skip="pam_module.go" + +// Package main is the package for the integration tester module PAM shared library. +package main + +import ( + "errors" + "fmt" + "reflect" + "strings" + + "github.com/msteinert/pam/v2" + "github.com/msteinert/pam/v2/cmd/pam-moduler/tests/internal/utils" +) + +type integrationTesterModule struct { + utils.BaseModule +} + +type authRequest struct { + mt pam.ModuleTransaction + lastError error +} + +func (m *integrationTesterModule) handleRequest(authReq *authRequest, r *Request) (res *Result, err error) { + switch r.Action { + case "bye": + return nil, authReq.lastError + } + + defer func() { + if p := recover(); p != nil { + if s, ok := p.(string); ok { + if strings.HasPrefix(s, "reflect:") { + res = nil + err = &utils.SerializableError{Msg: fmt.Sprintf( + "error on request %v: %v", *r, p)} + authReq.lastError = err + return + } + } + panic(p) + } + + if err != nil { + authReq.lastError = err + } + }() + + method := reflect.ValueOf(authReq.mt).MethodByName(r.Action) + if method == (reflect.Value{}) { + return nil, &utils.SerializableError{Msg: fmt.Sprintf( + "no method %s found", r.Action)} + } + + var args []reflect.Value + for _, arg := range r.ActionArgs { + args = append(args, reflect.ValueOf(arg)) + } + + res = &Result{Action: "return"} + for _, ret := range method.Call(args) { + iface := ret.Interface() + switch value := iface.(type) { + case pam.Error: + authReq.lastError = value + res.ActionArgs = append(res.ActionArgs, value) + case error: + var pamError pam.Error + if errors.As(value, &pamError) { + retErr := &SerializablePamError{Msg: value.Error(), + RetStatus: pamError} + authReq.lastError = retErr + res.ActionArgs = append(res.ActionArgs, retErr) + return res, err + } + authReq.lastError = value + res.ActionArgs = append(res.ActionArgs, + &utils.SerializableError{Msg: value.Error()}) + default: + res.ActionArgs = append(res.ActionArgs, iface) + } + } + return res, err +} + +func (m *integrationTesterModule) handleError(err error) *Result { + return &Result{ + Action: "error", + ActionArgs: []interface{}{&utils.SerializableError{Msg: err.Error()}}, + } +} + +func (m *integrationTesterModule) Authenticate(mt pam.ModuleTransaction, _ pam.Flags, args []string) error { + if len(args) != 1 { + return errors.New("Invalid arguments") + } + + authRequest := authRequest{mt, nil} + connection := NewConnector(args[0]) + if err := connection.Connect(); err != nil { + return err + } + + connectionHandler := func() error { + if err := connection.SendRequest(&Request{Action: "hello"}); err != nil { + return err + } + + for { + req, err := connection.WaitForData() + if err != nil { + return err + } + + res, err := m.handleRequest(&authRequest, req) + if err != nil { + _ = connection.SendResult(m.handleError(err)) + return err + } + if res == nil { + return nil + } + if err := connection.SendResult(res); err != nil { + _ = connection.SendResult(m.handleError(err)) + return err + } + } + } + + if err := connectionHandler(); err != nil { + return err + } + + return nil +} diff --git a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go new file mode 100644 index 0000000..ecde5ce --- /dev/null +++ b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go @@ -0,0 +1,783 @@ +package main + +import ( + "errors" + "fmt" + "path/filepath" + "reflect" + "runtime" + "strings" + "testing" + "time" + + "github.com/msteinert/pam/v2" + "github.com/msteinert/pam/v2/cmd/pam-moduler/tests/internal/utils" +) + +func (r *Request) check(res *Result, expectedResults []interface{}) error { + switch res.Action { + case "return": + case "error": + return fmt.Errorf("module error: %v", res.ActionArgs...) + default: + return fmt.Errorf("unexpected action %v", res.Action) + } + + if !reflect.DeepEqual(res.ActionArgs, expectedResults) { + return fmt.Errorf("unexpected return values %#v vs %#v", + res.ActionArgs, expectedResults) + } + + return nil +} + +func (r *Request) checkRemote(listener *Listener, expectedResults []interface{}) error { + res, err := listener.DoRequest(r) + if err != nil { + return err + } + + return res.check(res, expectedResults) +} + +type checkedRequest struct { + r Request + exp []interface{} + compareWithTestState bool +} + +func (cr *checkedRequest) checkRemote(listener *Listener) error { + return cr.r.checkRemote(listener, cr.exp) +} + +func (cr *checkedRequest) check(res *Result) error { + return cr.r.check(res, cr.exp) +} + +func ensureItem(tx *pam.Transaction, item pam.Item, expected string) error { + if value, err := tx.GetItem(item); err != nil { + return err + } else if value != expected { + return fmt.Errorf("invalid item %v value: %s vs %v", item, value, expected) + } + return nil +} + +func ensureEnv(tx *pam.Transaction, variable string, expected string) error { + if env := tx.GetEnv(variable); env != expected { + return fmt.Errorf("unexpected env %s value: %s vs %s", variable, env, expected) + } + return nil +} + +func Test_Moduler_IntegrationTesterModule(t *testing.T) { + t.Parallel() + if !pam.CheckPamHasStartConfdir() { + t.Skip("this requires PAM with Conf dir support") + } + + ts := utils.NewTestSetup(t, utils.WithWorkDir()) + modulePath := ts.GenerateModuleDefault(ts.GetCurrentFileDir()) + + type testState = map[string]interface{} + + tests := map[string]struct { + expectedError error + user string + credentials pam.ConversationHandler + checkedRequests []checkedRequest + setup func(*pam.Transaction, *Listener, testState) error + finish func(*pam.Transaction, *Listener, testState) error + }{ + "success": { + expectedError: nil, + }, + "get-item-Service": { + checkedRequests: []checkedRequest{{ + r: NewRequest("GetItem", pam.Service), + exp: []interface{}{"get-item-service", nil}, + }}, + }, + "get-item-User-empty": { + checkedRequests: []checkedRequest{{ + r: NewRequest("GetItem", pam.User), + exp: []interface{}{"", nil}, + }}, + }, + "get-item-User-preset": { + user: "test-user", + checkedRequests: []checkedRequest{{ + r: NewRequest("GetItem", pam.User), + exp: []interface{}{"test-user", nil}, + }}, + }, + "get-item-Authtok-empty": { + checkedRequests: []checkedRequest{{ + r: NewRequest("GetItem", pam.Authtok), + exp: []interface{}{"", nil}, + }}, + }, + "get-item-Oldauthtok-empty": { + checkedRequests: []checkedRequest{{ + r: NewRequest("GetItem", pam.Oldauthtok), + exp: []interface{}{"", nil}, + }}, + }, + "get-item-UserPrompt-empty": { + checkedRequests: []checkedRequest{{ + r: NewRequest("GetItem", pam.UserPrompt), + exp: []interface{}{"", nil}, + }}, + }, + "set-item-Service": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetItem", pam.Service, "foo-service"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetItem", pam.Service), + exp: []interface{}{"foo-service", nil}, + }, + }, + }, + "set-item-User-empty": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetItem", pam.User, "an-user"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetItem", pam.User), + exp: []interface{}{"an-user", nil}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureItem(tx, pam.User, "an-user") + }, + }, + "set-item-User-preset": { + user: "test-user", + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetItem", pam.User, "an-user"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetItem", pam.User), + exp: []interface{}{"an-user", nil}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureItem(tx, pam.User, "an-user") + }, + }, + "set-get-item-User-empty": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + return tx.SetItem(pam.User, "setup-user") + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetItem", pam.User), + exp: []interface{}{"setup-user", nil}, + }}, + }, + "set-get-item-User-preset": { + user: "test-user", + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + return tx.SetItem(pam.User, "setup-user") + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetItem", pam.User), + exp: []interface{}{"setup-user", nil}, + }}, + }, + "get-env-unset": { + checkedRequests: []checkedRequest{{ + r: NewRequest("GetEnv", "_PAM_GO_HOPEFULLY_NOT_SET"), + exp: []interface{}{""}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureEnv(tx, "_PAM_GO_HOPEFULLY_NOT_SET", "") + }, + }, + "get-env-preset": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + return tx.PutEnv("_PAM_GO_ENV_SET_VAR=foobar") + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{"foobar"}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureEnv(tx, "_PAM_GO_ENV_SET_VAR", "foobar") + }, + }, + "get-env-preset-empty": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + if err := tx.PutEnv("_PAM_GO_ENV_SET_VAR=value"); err != nil { + return err + } + return tx.PutEnv("_PAM_GO_ENV_SET_VAR=") + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{""}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureEnv(tx, "_PAM_GO_ENV_SET_VAR", "") + }, + }, + "get-env-preset-unset": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + if err := tx.PutEnv("_PAM_GO_ENV_SET_VAR=value"); err != nil { + return err + } + return tx.PutEnv("_PAM_GO_ENV_SET_VAR") + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{""}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureEnv(tx, "_PAM_GO_ENV_SET_VAR", "") + }, + }, + "put-env-not-preset": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR=a value"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{"a value"}, + }, + }, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureEnv(tx, "_PAM_GO_ENV_SET_VAR", "a value") + }, + }, + "put-env-preset": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + return tx.PutEnv("_PAM_GO_ENV_SET_VAR=foobar") + }, + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR=another value"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{"another value"}, + }, + }, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureEnv(tx, "_PAM_GO_ENV_SET_VAR", "another value") + }, + }, + "put-env-resets-not-preset": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR=a value"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{"a value"}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR="), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{""}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{""}, + }, + }, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureEnv(tx, "_PAM_GO_ENV_SET_VAR", "") + }, + }, + "put-env-resets-preset": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + return tx.PutEnv("_PAM_GO_ENV_SET_VAR=foobar") + }, + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR=a value"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{"a value"}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR="), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{""}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{""}, + }, + }, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureEnv(tx, "_PAM_GO_ENV_SET_VAR", "") + }, + }, + "put-env-unsets-not-set": { + expectedError: pam.ErrBadItem, + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR_NEVER_SET"), + exp: []interface{}{pam.ErrBadItem}, + }, + }, + }, + "put-env-unsets-empty-value": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR="), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnvList"), + exp: []interface{}{ + map[string]string{"_PAM_GO_ENV_SET_VAR": ""}, nil, + }, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnvList"), + exp: []interface{}{map[string]string{}, nil}, + }, + }, + }, + "put-env-invalid-syntax": { + expectedError: pam.ErrBadItem, + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "="), + exp: []interface{}{pam.ErrBadItem}, + }, + { + r: NewRequest("PutEnv", "=bar"), + exp: []interface{}{pam.ErrBadItem}, + }, + { + r: NewRequest("PutEnv", "with spaces"), + exp: []interface{}{pam.ErrBadItem}, + }, + }, + }, + "get-env-list-empty": { + checkedRequests: []checkedRequest{{ + r: NewRequest("GetEnvList"), + exp: []interface{}{map[string]string{}, nil}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return nil + }, + }, + "get-env-list-preset": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + expected := map[string]string{ + "_PAM_GO_ENV_SET_VAR1": "value1", + "_PAM_GO_ENV_SET_VAR2": "value due", + "_PAM_GO_ENV_SET_VAR3": "3", + "_PAM_GO_ENV_SET_VAR_EMPTY": "", + "_PAM_GO_ENV WITH SPACES": "yes works", + } + + for env, value := range expected { + if err := tx.PutEnv(fmt.Sprintf("%s=%s", env, value)); err != nil { + return err + } + } + ts["expected"] = expected + ts["expectedResults"] = [][]interface{}{{expected, nil}} + return nil + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetEnvList"), + compareWithTestState: true, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + if list, err := tx.GetEnvList(); err != nil { + return err + } else if !reflect.DeepEqual(list, ts["expected"]) { + return fmt.Errorf("Unexpected return values %#v vs %#v", + list, ts["expected"]) + } + return nil + }, + }, + "get-env-list-module-set": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + expected := map[string]string{ + "_PAM_GO_ENV_SET_VAR1": "value1", + "_PAM_GO_ENV_SET_VAR2": "value due", + "_PAM_GO_ENV_SET_VAR3": "3", + "_PAM_GO_ENV_SET_VAR_EMPTY": "", + "_PAM_GO_ENV WITH SPACES": "yes works", + } + + ts["expected"] = expected + ts["expectedResults"] = [][]interface{}{ + nil, nil, nil, nil, nil, nil, nil, {expected, nil}, + } + return nil + }, + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR1=value1"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR2=value due"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR3=3"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR_EMPTY="), + exp: []interface{}{nil}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR_TO_UNSET=unset"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR_TO_UNSET"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV WITH SPACES=yes works"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnvList"), + compareWithTestState: true, + }, + }, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + if list, err := tx.GetEnvList(); err != nil { + return err + } else if !reflect.DeepEqual(list, ts["expected"]) { + return fmt.Errorf("unexpected return values %#v vs %#v", + list, ts["expected"]) + } + return nil + }, + }, + } + + for name, tc := range tests { + tc := tc + name := name + t.Run(name, func(t *testing.T) { + t.Parallel() + socketPath := filepath.Join(ts.WorkDir(), name+".socket") + ts.CreateService(name, []utils.ServiceLine{ + {Action: utils.Auth, Control: utils.Requisite, Module: modulePath, + Args: []string{socketPath}}, + }) + + tx, err := pam.StartConfDir(name, tc.user, tc.credentials, ts.WorkDir()) + if err != nil { + t.Fatalf("start #error: %v", err) + } + defer func() { + err := tx.End() + if err != nil { + t.Fatalf("end #error: %v", err) + } + }() + + listener := NewListener(socketPath) + if err := listener.StartListening(); err != nil { + t.Fatalf("listening #error: %v", err) + } + + listenerHandler := func() error { + res, err := listener.WaitForData() + if err != nil { + return err + } + + if res == nil || res.Action != "hello" { + return errors.New("missing hello packet") + } + + req := NewRequest("GetItem", pam.Service) + if err := req.checkRemote(listener, + []interface{}{strings.ToLower(name), nil}); err != nil { + return err + } + + testState := testState{} + if tc.setup != nil { + if err := tc.setup(tx, listener, testState); err != nil { + return err + } + } + + for i, req := range tc.checkedRequests { + if req.compareWithTestState { + expectedResults, _ := testState["expectedResults"].([][]interface{}) + if err := req.r.checkRemote(listener, expectedResults[i]); err != nil { + return err + } + } else if err := req.checkRemote(listener); err != nil { + return err + } + } + + if tc.finish != nil { + if err := tc.finish(tx, listener, testState); err != nil { + return err + } + } + + if err := listener.SendRequest(&Request{Action: "bye"}); err != nil { + return err + } + + return nil + } + + serverError := make(chan error) + go func() { + serverError <- listenerHandler() + }() + + authResult := make(chan error) + go func() { + authResult <- tx.Authenticate(pam.Silent) + }() + + if err = <-serverError; err != nil { + t.Fatalf("communication #error: %v", err) + } + + err = <-authResult + if !errors.Is(err, tc.expectedError) { + t.Fatalf("authenticate #unexpected: %#v vs %#v", + err, tc.expectedError) + } + }) + } + + t.Cleanup(func() { + // Ensure GC will happen, so that transaction's pam_end will be called + runtime.GC() + time.Sleep(5 * time.Millisecond) + }) +} + +func Test_Moduler_IntegrationTesterModule_handleRequest(t *testing.T) { + t.Parallel() + + module := integrationTesterModule{} + mt := pam.NewModuleTransactionInvoker(nil) + + tests := []struct { + checkedRequest + name string + parallel bool + }{ + { + name: "putEnv", + checkedRequest: checkedRequest{ + r: NewRequest("PutEnv", "FOO_ENV=Bar"), + exp: []interface{}{pam.ErrAbort}, + }, + }, + { + parallel: true, + name: "get-item-Service", + checkedRequest: checkedRequest{ + r: NewRequest("GetItem", pam.Service), + exp: []interface{}{"", pam.ErrSystem}, + }, + }, + { + parallel: true, + name: "set-item-Service", + checkedRequest: checkedRequest{ + r: NewRequest("SetItem", pam.Service, "foo"), + exp: []interface{}{pam.ErrSystem}, + }, + }, + } + + for _, cr := range tests { + cr := cr + t.Run(cr.name, func(t *testing.T) { + if cr.parallel { + t.Parallel() + } + + authRequest := authRequest{mt, nil} + res, err := module.handleRequest(&authRequest, &cr.r) + if err != nil { + t.Fatalf("unexpected error %v", err) + } + + if res.Action != "return" { + t.Fatalf("unexpected result action %v", res.Action) + } + + if err := cr.check(res); err != nil { + t.Fatalf("unexpected result %v", err) + } + }) + } + + t.Run("missing-method", func(t *testing.T) { + t.Parallel() + req := NewRequest("Hopefully a missing method") + res, err := module.handleRequest(&authRequest{mt, nil}, &req) + + if err == nil { + t.Fatalf("error was expected, got %v", res) + } + if res != nil { + t.Fatalf("unexpected result %v", res) + } + }) + + t.Run("wrong-signature", func(t *testing.T) { + t.Parallel() + req := NewRequest("GetItem", "this", "and", 3, "of that") + res, err := module.handleRequest(&authRequest{mt, nil}, &req) + + if err == nil { + t.Fatalf("error was expected, got %v", res) + } + if res != nil { + t.Fatalf("unexpected result %v", res) + } + }) +} + +func Test_Moduler_IntegrationTesterModule_Authenticate(t *testing.T) { + t.Parallel() + + ts := utils.NewTestSetup(t, utils.WithWorkDir()) + module := integrationTesterModule{} + + tests := map[string]struct { + expectedError error + credentials pam.ConversationHandler + checkedRequests []checkedRequest + }{ + "success": { + expectedError: nil, + }, + "get-item-Service": { + expectedError: pam.ErrSystem, + checkedRequests: []checkedRequest{ + { + r: NewRequest("GetItem", pam.Service), + exp: []interface{}{"", pam.ErrSystem}, + }, + }, + }, + "get-item-User": { + expectedError: pam.ErrSystem, + checkedRequests: []checkedRequest{ + { + r: NewRequest("GetItem", pam.User), + exp: []interface{}{"", pam.ErrSystem}, + }, + }, + }, + "putEnv": { + expectedError: pam.ErrAbort, + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "FooBar=Baz"), + exp: []interface{}{pam.ErrAbort}, + }, + }, + }, + } + + for name, tc := range tests { + tc := tc + name := name + t.Run(name, func(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(ts.WorkDir(), name+".socket") + listener := NewListener(socketPath) + if err := listener.StartListening(); err != nil { + t.Fatalf("listening #error: %v", err) + } + + listenerHandler := func() error { + res, err := listener.WaitForData() + if err != nil { + return err + } + + if res == nil || res.Action != "hello" { + return errors.New("missing hello packet") + } + + for _, req := range tc.checkedRequests { + if err := req.checkRemote(listener); err != nil { + return err + } + } + + if err := listener.SendRequest(&Request{Action: "bye"}); err != nil { + return err + } + + return nil + } + + serverError := make(chan error) + go func() { + serverError <- listenerHandler() + }() + + authResult := make(chan error) + go func() { + authResult <- module.Authenticate( + pam.NewModuleTransactionInvoker(nil), + pam.Silent, []string{socketPath}) + }() + + if err := <-serverError; err != nil { + t.Fatalf("communication #error: %v", err) + } + + err := <-authResult + if !errors.Is(err, tc.expectedError) { + t.Fatalf("authenticate #unexpected: %#v vs %#v", + err, tc.expectedError) + } + }) + } +} diff --git a/cmd/pam-moduler/tests/integration-tester-module/pam_module.go b/cmd/pam-moduler/tests/integration-tester-module/pam_module.go new file mode 100644 index 0000000..39a22b7 --- /dev/null +++ b/cmd/pam-moduler/tests/integration-tester-module/pam_module.go @@ -0,0 +1,95 @@ +// Code generated by "pam-moduler -type integrationTesterModule"; DO NOT EDIT. + +//go:generate go build "-ldflags=-extldflags -Wl,-soname,pam_go.so" -buildmode=c-shared -o pam_go.so -tags go_pam_module + +// Package main is the package for the PAM module library. +package main + +/* +#cgo LDFLAGS: -lpam -fPIC +#include + +typedef const char _const_char_t; +*/ +import "C" + +import ( + "errors" + "fmt" + "github.com/msteinert/pam/v2" + "os" + "unsafe" +) + +var pamModuleHandler pam.ModuleHandler = &integrationTesterModule{} + +// sliceFromArgv returns a slice of strings given to the PAM module. +func sliceFromArgv(argc C.int, argv **C._const_char_t) []string { + r := make([]string, 0, argc) + for _, s := range unsafe.Slice(argv, argc) { + r = append(r, C.GoString(s)) + } + return r +} + +// handlePamCall is the function that translates C pam requests to Go. +func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int, + argv **C._const_char_t, moduleFunc pam.ModuleHandlerFunc) C.int { + if pamModuleHandler == nil { + return C.int(pam.ErrNoModuleData) + } + + if moduleFunc == nil { + return C.int(pam.ErrIgnore) + } + + mt := pam.NewModuleTransactionInvoker(pam.NativeHandle(pamh)) + err := mt.InvokeHandler(moduleFunc, pam.Flags(flags), + sliceFromArgv(argc, argv)) + if err == nil { + return 0 + } + + if (pam.Flags(flags)&pam.Silent) == 0 && !errors.Is(err, pam.ErrIgnore) { + fmt.Fprintf(os.Stderr, "module returned error: %v\n", err) + } + + var pamErr pam.Error + if errors.As(err, &pamErr) { + return C.int(pamErr) + } + + return C.int(pam.ErrSystem) +} + +//export pam_sm_authenticate +func pam_sm_authenticate(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.Authenticate) +} + +//export pam_sm_setcred +func pam_sm_setcred(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.SetCred) +} + +//export pam_sm_acct_mgmt +func pam_sm_acct_mgmt(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.AcctMgmt) +} + +//export pam_sm_open_session +func pam_sm_open_session(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.OpenSession) +} + +//export pam_sm_close_session +func pam_sm_close_session(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.CloseSession) +} + +//export pam_sm_chauthtok +func pam_sm_chauthtok(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.ChangeAuthTok) +} + +func main() {} diff --git a/cmd/pam-moduler/tests/integration-tester-module/serialization.go b/cmd/pam-moduler/tests/integration-tester-module/serialization.go new file mode 100644 index 0000000..33b26a7 --- /dev/null +++ b/cmd/pam-moduler/tests/integration-tester-module/serialization.go @@ -0,0 +1,35 @@ +package main + +import ( + "encoding/gob" + + "github.com/msteinert/pam/v2" + "github.com/msteinert/pam/v2/cmd/pam-moduler/tests/internal/utils" +) + +// SerializablePamError represents a [pam.Error] in a +// serializable way that splits message and return code. +type SerializablePamError struct { + Msg string + RetStatus pam.Error +} + +// NewSerializablePamError initializes a SerializablePamError from +// the default status error message. +func NewSerializablePamError(status pam.Error) SerializablePamError { + return SerializablePamError{Msg: status.Error(), RetStatus: status} +} + +func (e *SerializablePamError) Error() string { + return e.RetStatus.Error() +} + +func init() { + gob.Register(map[string]string{}) + gob.Register(Request{}) + gob.Register(pam.Item(0)) + gob.Register(pam.Error(0)) + gob.RegisterName("main.SerializablePamError", + SerializablePamError{}) + gob.Register(utils.SerializableError{}) +} diff --git a/cmd/pam-moduler/tests/internal/utils/test-utils.go b/cmd/pam-moduler/tests/internal/utils/test-utils.go index 556f160..3fc6b0c 100644 --- a/cmd/pam-moduler/tests/internal/utils/test-utils.go +++ b/cmd/pam-moduler/tests/internal/utils/test-utils.go @@ -97,3 +97,12 @@ func (a FallBackModule) String() string { return "" } } + +// SerializableError is a representation of an error in a way can be serialized. +type SerializableError struct { + Msg string +} + +func (e *SerializableError) Error() string { + return e.Msg +}