diff --git a/cmd/pam-moduler/moduler.go b/cmd/pam-moduler/moduler.go index 165d5a6..0a74125 100644 --- a/cmd/pam-moduler/moduler.go +++ b/cmd/pam-moduler/moduler.go @@ -69,6 +69,7 @@ var ( moduleBuildFlags = flag.String("build-flags", "", "comma-separated list of go build flags to use when generating the module") moduleBuildTags = flag.String("build-tags", "", "comma-separated list of build tags to use when generating the module") noMain = flag.Bool("no-main", false, "whether to add an empty main to generated file") + parallelConv = flag.Bool("parallel-conv", false, "whether to support performing PAM conversations in parallel") ) // Usage is a replacement usage function for the flags package. @@ -137,6 +138,7 @@ func main() { generateTags: generateTags, noMain: *noMain, typeName: *typeName, + parallelConv: *parallelConv, } // Print the header and package clause. @@ -169,6 +171,7 @@ type Generator struct { generateTags []string buildFlags []string noMain bool + parallelConv bool } func (g *Generator) printf(format string, args ...interface{}) { @@ -186,6 +189,11 @@ func (g *Generator) generate() { buildTagsArg = fmt.Sprintf("-tags %s", strings.Join(g.generateTags, ",")) } + var transactionCreator = "NewModuleTransactionInvoker" + if g.parallelConv { + transactionCreator = "NewModuleTransactionInvokerParallelConv" + } + // We use a slice since we want to keep order, for reproducible builds. vFuncs := []struct { cName string @@ -257,8 +265,8 @@ func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int, return C.int(pam.ErrIgnore) } - mt := pam.NewModuleTransactionInvoker(pam.NativeHandle(pamh)) - err := mt.InvokeHandler(moduleFunc, pam.Flags(flags), + mt := pam.%s(pam.NativeHandle(pamh)) + err := mt.InvokeHandler(moduleFunc, pam.Flags(flags), sliceFromArgv(argc, argv)) if err == nil { return 0 @@ -275,7 +283,7 @@ func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int, return C.int(pam.ErrSystem) } -`) +`, transactionCreator) for _, f := range vFuncs { g.printf(` diff --git a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go index 7991d5b..fcdeaa9 100644 --- a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go +++ b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go @@ -1,4 +1,4 @@ -//go:generate go run github.com/msteinert/pam/v2/cmd/pam-moduler -type integrationTesterModule +//go:generate go run github.com/msteinert/pam/v2/cmd/pam-moduler -type integrationTesterModule -parallel-conv //go:generate go generate --skip="pam_module.go" // Package main is the package for the integration tester module PAM shared library. diff --git a/cmd/pam-moduler/tests/integration-tester-module/pam_module.go b/cmd/pam-moduler/tests/integration-tester-module/pam_module.go index 39a22b7..e64a4f9 100644 --- a/cmd/pam-moduler/tests/integration-tester-module/pam_module.go +++ b/cmd/pam-moduler/tests/integration-tester-module/pam_module.go @@ -1,4 +1,4 @@ -// Code generated by "pam-moduler -type integrationTesterModule"; DO NOT EDIT. +// Code generated by "pam-moduler -type integrationTesterModule -parallel-conv"; DO NOT EDIT. //go:generate go build "-ldflags=-extldflags -Wl,-soname,pam_go.so" -buildmode=c-shared -o pam_go.so -tags go_pam_module @@ -43,7 +43,7 @@ func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int, return C.int(pam.ErrIgnore) } - mt := pam.NewModuleTransactionInvoker(pam.NativeHandle(pamh)) + mt := pam.NewModuleTransactionInvokerParallelConv(pam.NativeHandle(pamh)) err := mt.InvokeHandler(moduleFunc, pam.Flags(flags), sliceFromArgv(argc, argv)) if err == nil { diff --git a/module-transaction.go b/module-transaction.go index df1bfa3..fc754a1 100644 --- a/module-transaction.go +++ b/module-transaction.go @@ -43,6 +43,7 @@ type ModuleHandlerFunc func(ModuleTransaction, Flags, []string) error // ModuleTransaction is the module-side handle for a PAM transaction. type moduleTransaction struct { transactionBase + convMutex *sync.Mutex } // ModuleHandler is an interface for objects that can be used to create @@ -63,10 +64,27 @@ type ModuleTransactionInvoker interface { InvokeHandler(handler ModuleHandlerFunc, flags Flags, args []string) error } -// NewModuleTransactionInvoker allows initializing a transaction invoker from -// the module side. +// NewModuleTransactionParallelConv allows initializing a transaction from the +// module side. Conversations using this transaction can be multi-thread, but +// this requires the application loading the module to support this, otherwise +// we may just break their assumptions. +func NewModuleTransactionParallelConv(handle NativeHandle) ModuleTransaction { + return &moduleTransaction{transactionBase{handle: handle}, nil} +} + +// NewModuleTransactionInvoker allows initializing a transaction invoker from the +// module side. func NewModuleTransactionInvoker(handle NativeHandle) ModuleTransactionInvoker { - return &moduleTransaction{transactionBase{handle: handle}} + return &moduleTransaction{transactionBase{handle: handle}, &sync.Mutex{}} +} + +// NewModuleTransactionInvokerParallelConv allows initializing a transaction invoker +// from the module side. +// Conversations using this transaction can be multi-thread, but this requires +// the application loading the module to support this, otherwise we may just +// break their assumptions. +func NewModuleTransactionInvokerParallelConv(handle NativeHandle) ModuleTransactionInvoker { + return &moduleTransaction{transactionBase{handle: handle}, nil} } func (m *moduleTransaction) InvokeHandler(handler ModuleHandlerFunc, @@ -542,6 +560,10 @@ func (m *moduleTransaction) startConvMultiImpl(iface moduleTransactionIface, goMsgs[i] = cMessage } + if m.convMutex != nil { + m.convMutex.Lock() + defer m.convMutex.Unlock() + } var cResponses *C.struct_pam_response ret := iface.startConv(conv, C.int(len(requests)), cMessages, &cResponses) if ret != success { diff --git a/module-transaction_test.go b/module-transaction_test.go index 85233d3..2e678e0 100644 --- a/module-transaction_test.go +++ b/module-transaction_test.go @@ -305,11 +305,10 @@ func Test_ModuleTransaction_InvokeHandler(t *testing.T) { } } -func Test_MockModuleTransaction(t *testing.T) { +func testMockModuleTransaction(t *testing.T, mt *moduleTransaction) { + t.Helper() t.Parallel() - mt, _ := NewModuleTransactionInvoker(nil).(*moduleTransaction) - tests := map[string]struct { testFunc func(mock *mockModuleTransaction) (any, error) mockExpectations mockModuleTransactionExpectations @@ -914,3 +913,13 @@ func Test_MockModuleTransaction(t *testing.T) { }) } } + +func Test_MockModuleTransaction(t *testing.T) { + mt, _ := NewModuleTransactionInvoker(nil).(*moduleTransaction) + testMockModuleTransaction(t, mt) +} + +func Test_MockModuleTransactionParallelConv(t *testing.T) { + mt, _ := NewModuleTransactionInvokerParallelConv(nil).(*moduleTransaction) + testMockModuleTransaction(t, mt) +}