transaction: Use Atomic to store/load the status

Transactions save the status of each operation in a status field, however
such field could be written concurrently by various operations, so we
need to be sure that:
 - We always return the status for the current operation
 - We store the status in a atomic way so that other actions won't
   create write races

In general, in a multi-thread operation one should not rely on
Transaction.Error() to get info about the last operation.
This commit is contained in:
Marco Trevisan (Treviño)
2023-09-29 23:01:50 +02:00
parent 3e4f7f5e4b
commit 911a346a00

View File

@@ -27,6 +27,7 @@ import (
"runtime"
"runtime/cgo"
"strings"
"sync/atomic"
"unsafe"
)
@@ -129,22 +130,22 @@ func cbPAMConv(s C.int, msg *C.char, c C.uintptr_t) (*C.char, C.int) {
//
//nolint:errname
type Transaction struct {
handle *C.pam_handle_t
conv *C.struct_pam_conv
status C.int
c cgo.Handle
handle *C.pam_handle_t
conv *C.struct_pam_conv
lastStatus atomic.Int32
c cgo.Handle
}
// transactionFinalizer cleans up the PAM handle and deletes the callback
// function.
func transactionFinalizer(t *Transaction) {
C.pam_end(t.handle, t.status)
C.pam_end(t.handle, C.int(t.lastStatus.Load()))
t.c.Delete()
}
// Allows to call pam functions managing return status
func (t *Transaction) handlePamStatus(cStatus C.int) error {
t.status = cStatus
t.lastStatus.Store(int32(cStatus))
if cStatus != success {
return t
}
@@ -212,13 +213,13 @@ func start(service, user string, handler ConversationHandler, confDir string) (*
err = t.handlePamStatus(C.pam_start_confdir(s, u, t.conv, c, &t.handle))
}
if err != nil {
return nil, errors.Join(Error(t.status), err)
return nil, errors.Join(Error(t.lastStatus.Load()), err)
}
return t, nil
}
func (t *Transaction) Error() string {
return Error(t.status).Error()
return Error(t.lastStatus.Load()).Error()
}
// Item is a an PAM information type.
@@ -363,8 +364,10 @@ func (t *Transaction) GetEnvList() (map[string]string, error) {
env := make(map[string]string)
p := C.pam_getenvlist(t.handle)
if p == nil {
return nil, t.handlePamStatus(C.int(ErrBuf))
t.lastStatus.Store(int32(ErrBuf))
return nil, t
}
t.lastStatus.Store(success)
for q := p; *q != nil; q = next(q) {
chunks := strings.SplitN(C.GoString(*q), "=", 2)
if len(chunks) == 2 {