In order to properly test the interaction of a module transaction from the application's point of view, we need to perform operations in the module and ensure that the expected values are returned and handled. To do this without using the PAM APIs that we want to test, we use a simple trick: - Create an application that works as a server using a Unix socket - Create a module that connects to it - Pass the socket to the module via the module service file arguments - Add a basic protocol that allows the application to send a request and the module to reply to it - Use reflection and serialization to automatically call module methods and return the values to the application, where we do the checks
138 lines
3.1 KiB
Go
138 lines
3.1 KiB
Go
//go:generate go run github.com/msteinert/pam/v2/cmd/pam-moduler -type integrationTesterModule
|
|
//go:generate go generate --skip="pam_module.go"
|
|
|
|
// Package main is the package for the integration tester module PAM shared library.
|
|
package main
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
|
|
"github.com/msteinert/pam/v2"
|
|
"github.com/msteinert/pam/v2/cmd/pam-moduler/tests/internal/utils"
|
|
)
|
|
|
|
type integrationTesterModule struct {
|
|
utils.BaseModule
|
|
}
|
|
|
|
type authRequest struct {
|
|
mt pam.ModuleTransaction
|
|
lastError error
|
|
}
|
|
|
|
func (m *integrationTesterModule) handleRequest(authReq *authRequest, r *Request) (res *Result, err error) {
|
|
switch r.Action {
|
|
case "bye":
|
|
return nil, authReq.lastError
|
|
}
|
|
|
|
defer func() {
|
|
if p := recover(); p != nil {
|
|
if s, ok := p.(string); ok {
|
|
if strings.HasPrefix(s, "reflect:") {
|
|
res = nil
|
|
err = &utils.SerializableError{Msg: fmt.Sprintf(
|
|
"error on request %v: %v", *r, p)}
|
|
authReq.lastError = err
|
|
return
|
|
}
|
|
}
|
|
panic(p)
|
|
}
|
|
|
|
if err != nil {
|
|
authReq.lastError = err
|
|
}
|
|
}()
|
|
|
|
method := reflect.ValueOf(authReq.mt).MethodByName(r.Action)
|
|
if method == (reflect.Value{}) {
|
|
return nil, &utils.SerializableError{Msg: fmt.Sprintf(
|
|
"no method %s found", r.Action)}
|
|
}
|
|
|
|
var args []reflect.Value
|
|
for _, arg := range r.ActionArgs {
|
|
args = append(args, reflect.ValueOf(arg))
|
|
}
|
|
|
|
res = &Result{Action: "return"}
|
|
for _, ret := range method.Call(args) {
|
|
iface := ret.Interface()
|
|
switch value := iface.(type) {
|
|
case pam.Error:
|
|
authReq.lastError = value
|
|
res.ActionArgs = append(res.ActionArgs, value)
|
|
case error:
|
|
var pamError pam.Error
|
|
if errors.As(value, &pamError) {
|
|
retErr := &SerializablePamError{Msg: value.Error(),
|
|
RetStatus: pamError}
|
|
authReq.lastError = retErr
|
|
res.ActionArgs = append(res.ActionArgs, retErr)
|
|
return res, err
|
|
}
|
|
authReq.lastError = value
|
|
res.ActionArgs = append(res.ActionArgs,
|
|
&utils.SerializableError{Msg: value.Error()})
|
|
default:
|
|
res.ActionArgs = append(res.ActionArgs, iface)
|
|
}
|
|
}
|
|
return res, err
|
|
}
|
|
|
|
func (m *integrationTesterModule) handleError(err error) *Result {
|
|
return &Result{
|
|
Action: "error",
|
|
ActionArgs: []interface{}{&utils.SerializableError{Msg: err.Error()}},
|
|
}
|
|
}
|
|
|
|
func (m *integrationTesterModule) Authenticate(mt pam.ModuleTransaction, _ pam.Flags, args []string) error {
|
|
if len(args) != 1 {
|
|
return errors.New("Invalid arguments")
|
|
}
|
|
|
|
authRequest := authRequest{mt, nil}
|
|
connection := NewConnector(args[0])
|
|
if err := connection.Connect(); err != nil {
|
|
return err
|
|
}
|
|
|
|
connectionHandler := func() error {
|
|
if err := connection.SendRequest(&Request{Action: "hello"}); err != nil {
|
|
return err
|
|
}
|
|
|
|
for {
|
|
req, err := connection.WaitForData()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
res, err := m.handleRequest(&authRequest, req)
|
|
if err != nil {
|
|
_ = connection.SendResult(m.handleError(err))
|
|
return err
|
|
}
|
|
if res == nil {
|
|
return nil
|
|
}
|
|
if err := connection.SendResult(res); err != nil {
|
|
_ = connection.SendResult(m.handleError(err))
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
if err := connectionHandler(); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|