Skip to content

Commit 1916e11

Browse files
Use non-blocking send on pid unset, retry on close (#176)
* tinker with async close * docs * docs.. * add badfd condition * lint * format * rewrite * error checking * fmt * rewrite * add else block * add tests, dumb linter * linter... * LINTER * refactor to use switch * still tinkering... * tinkering... * improve docs * refactor, again
1 parent c21d056 commit 1916e11

File tree

3 files changed

+182
-9
lines changed

3 files changed

+182
-9
lines changed

audit.go

Lines changed: 67 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ const (
4949
// Netlink groups.
const (
	// NetlinkGroupNone is group 0, which is not used.
	NetlinkGroupNone = 0
	// NetlinkGroupReadLog is the "best effort" read only socket group,
	// defined in the kernel as AUDIT_NLGRP_READLOG.
	NetlinkGroupReadLog = 1
)
5454

5555
// WaitMode is a flag to control the behavior of methods that abstract
@@ -427,16 +427,11 @@ func (c *AuditClient) Receive(nonBlocking bool) (*RawAuditMessage, error) {
427427
// become no-ops.
428428
func (c *AuditClient) Close() error {
429429
var err error
430-
431430
// Only unregister and close the socket once.
432431
c.closeOnce.Do(func() {
433432
if c.clearPIDOnClose {
434433
// Unregister from the kernel for a clean exit.
435-
status := AuditStatus{
436-
Mask: AuditStatusPID,
437-
PID: 0,
438-
}
439-
err = c.set(status, NoWait)
434+
err = c.closeAndUnsetPid()
440435
}
441436

442437
err = errors.Join(err, c.Netlink.Close())
@@ -505,6 +500,70 @@ func (c *AuditClient) getReply(seq uint32) (*syscall.NetlinkMessage, error) {
505500
return &msg, nil
506501
}
507502

503+
// unset our pid from the audit subsystem and close the socket.
504+
// This is a sort of isolated refactor, meant to deal with the deadlocks that can happen when we're not careful with blocking operations throughout a lot of this code.
505+
func (c *AuditClient) closeAndUnsetPid() error {
506+
msg := syscall.NetlinkMessage{
507+
Header: syscall.NlMsghdr{
508+
Type: AuditSet,
509+
Flags: syscall.NLM_F_REQUEST,
510+
},
511+
Data: AuditStatus{
512+
Mask: AuditStatusPID,
513+
PID: 0,
514+
}.toWireFormat(),
515+
}
516+
517+
// If our request to unset the PID would block, then try to drain events from
518+
// the netlink socket, resend, try again.
519+
// In netlink, EAGAIN usually indicates our read buffer is full.
520+
// The auditd code (which I'm using as a reference implementation) doesn't wait for a response when unsetting the audit pid.
521+
// The retry count here is largely arbitrary, and provides a buffer for either transient errors (EINTR) or retries.
522+
retries := 5
523+
outer:
524+
for i := 0; i < retries; i++ {
525+
_, err := c.Netlink.SendNoWait(msg)
526+
switch {
527+
case err == nil:
528+
return nil
529+
case errors.Is(err, syscall.EINTR):
530+
// got a transient interrupt, try again
531+
continue
532+
case errors.Is(err, syscall.EAGAIN):
533+
// send would block, try to drain the receive socket. The recv count here is just so we have enough of a buffer to attempt a send again/
534+
// The number is just here so we ideally have enough of a buffer to attempt the send again.
535+
maxRecv := 10000
536+
for i := 0; i < maxRecv; i++ {
537+
_, err = c.Netlink.Receive(true, noParse)
538+
switch {
539+
case err == nil, errors.Is(err, syscall.EINTR), errors.Is(err, syscall.ENOBUFS):
540+
// continue with receive, try to read more data
541+
continue
542+
case errors.Is(err, syscall.EAGAIN):
543+
// receive would block, try to send again
544+
continue outer
545+
default:
546+
// if receive returns an other error, just return that.
547+
return err
548+
}
549+
}
550+
default:
551+
// if Send returns and other error, just return that
552+
return err
553+
}
554+
555+
}
556+
// we may not want to treat this as a hard error?
557+
// It's not a massive error if this fails, since the kernel will unset the PID if it can't communicate with the process,
558+
// so this is largely for neatness.
559+
return fmt.Errorf("could not unset pid from audit after retries")
560+
}
561+
562+
// noParse is a parser that ignores the received bytes and reports neither
// messages nor an error. closeAndUnsetPid passes it to Receive when it only
// needs to discard pending data from the socket.
func noParse(_ []byte) ([]syscall.NetlinkMessage, error) {
	var discarded []syscall.NetlinkMessage
	return discarded, nil
}
566+
508567
func (c *AuditClient) set(status AuditStatus, mode WaitMode) error {
509568
msg := syscall.NetlinkMessage{
510569
Header: syscall.NlMsghdr{
@@ -560,7 +619,7 @@ func parseNetlinkAuditMessage(buf []byte) ([]syscall.NetlinkMessage, error) {
560619
// https://github.com/linux-audit/audit-kernel/blob/v4.7/include/uapi/linux/audit.h#L318-L325
561620
type AuditStatusMask uint32
562621

563-
// Mask types for AuditStatus.
622+
// Mask types for AuditStatus, originally defined in the kernel's audit.h header.
564623
const (
565624
AuditStatusEnabled AuditStatusMask = 1 << iota
566625
AuditStatusFailure

audit_test.go

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,15 @@ import (
3030
"io"
3131
"os"
3232
"runtime"
33+
"slices"
34+
"sync"
3335
"syscall"
3436
"testing"
3537
"testing/quick"
3638
"time"
3739

3840
"github.com/stretchr/testify/assert"
41+
"github.com/stretchr/testify/require"
3942

4043
"github.com/elastic/go-libaudit/v2/rule"
4144
"github.com/elastic/go-libaudit/v2/rule/flags"
@@ -55,6 +58,107 @@ var (
5558
// -a always,exit -S open,truncate -F dir=/etc -F success=0
5659
const testRule = `BAAAAAIAAAACAAAABAAAAAAAAAAAEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGsAAABoAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAQAAAAEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAQAAAAvZXRj`
5760

61+
// TestNetlinkIface is a mock interface for testing close behavior
62+
type TestNetlinkIface struct {
63+
recvStack []error
64+
sendStack []error
65+
}
66+
67+
func (*TestNetlinkIface) Close() error {
68+
return nil
69+
}
70+
71+
func (tn *TestNetlinkIface) Send(_ syscall.NetlinkMessage) (uint32, error) {
72+
top := tn.sendStack[0]
73+
tn.sendStack = slices.Delete(tn.sendStack, 0, 1)
74+
return 0, top
75+
}
76+
77+
func (tn *TestNetlinkIface) SendNoWait(_ syscall.NetlinkMessage) (uint32, error) {
78+
top := tn.sendStack[0]
79+
tn.sendStack = slices.Delete(tn.sendStack, 0, 1)
80+
return 0, top
81+
}
82+
83+
func (tn *TestNetlinkIface) Receive(_ bool, _ NetlinkParser) ([]syscall.NetlinkMessage, error) {
84+
top := tn.recvStack[0]
85+
tn.recvStack = slices.Delete(tn.recvStack, 0, 1)
86+
return nil, top
87+
}
88+
89+
func TestCloseBehavior(t *testing.T) {
90+
testCases := []struct {
91+
name string
92+
cfg *TestNetlinkIface
93+
err error
94+
}{
95+
{
96+
name: "retry",
97+
cfg: &TestNetlinkIface{
98+
// cause the first send to error out
99+
sendStack: []error{syscall.EWOULDBLOCK, nil, nil},
100+
// force the close logic to drain
101+
recvStack: []error{syscall.ENOBUFS, syscall.ENOBUFS, syscall.EAGAIN},
102+
},
103+
err: nil,
104+
},
105+
{
106+
name: "repeated-send-fail",
107+
cfg: &TestNetlinkIface{
108+
// cause the first send to error out
109+
sendStack: []error{syscall.EWOULDBLOCK, syscall.EWOULDBLOCK, syscall.EWOULDBLOCK, nil},
110+
// force the close logic to drain
111+
recvStack: []error{syscall.EWOULDBLOCK, syscall.EWOULDBLOCK, nil, syscall.EWOULDBLOCK, nil, syscall.EWOULDBLOCK, nil},
112+
},
113+
err: nil,
114+
},
115+
{
116+
name: "transient-eintr-send",
117+
cfg: &TestNetlinkIface{
118+
// cause the first send to error out
119+
sendStack: []error{syscall.EINTR, nil, nil},
120+
// force the close logic to drain
121+
recvStack: []error{syscall.EAGAIN},
122+
},
123+
err: nil,
124+
},
125+
{
126+
name: "fail-recv-error",
127+
cfg: &TestNetlinkIface{
128+
// cause the first send to error out
129+
sendStack: []error{syscall.EWOULDBLOCK, nil, nil},
130+
// force the close logic to drain
131+
recvStack: []error{syscall.ENOBUFS, syscall.ENOBUFS, syscall.EBADFD},
132+
},
133+
err: syscall.EBADFD,
134+
},
135+
{
136+
name: "fail-send-error",
137+
cfg: &TestNetlinkIface{
138+
// cause the first send to error out
139+
sendStack: []error{syscall.EWOULDBLOCK, syscall.EBADFD, nil},
140+
// force the close logic to drain
141+
recvStack: []error{syscall.EAGAIN, syscall.EAGAIN},
142+
},
143+
err: syscall.EBADFD,
144+
},
145+
}
146+
147+
for _, test := range testCases {
148+
t.Run(test.name, func(t *testing.T) {
149+
testClient := AuditClient{
150+
Netlink: test.cfg,
151+
pendingAcks: []uint32{},
152+
clearPIDOnClose: true,
153+
closeOnce: sync.Once{},
154+
}
155+
156+
err := testClient.Close()
157+
require.True(t, errors.Is(err, test.err), "expected error %s", test.err)
158+
})
159+
}
160+
}
161+
58162
func TestAuditClientGetStatus(t *testing.T) {
59163
if os.Geteuid() != 0 {
60164
t.Skip("must be root to get audit status")

netlink.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import (
3636
// in the message and an error if it occurred.
3737
type NetlinkSender interface {
3838
Send(msg syscall.NetlinkMessage) (uint32, error)
39+
SendNoWait(msg syscall.NetlinkMessage) (uint32, error)
3940
}
4041

4142
// NetlinkReceiver receives data from the netlink socket and uses the provided
@@ -126,17 +127,26 @@ func getPortID(fd int) (uint32, error) {
126127
return addr.Pid, nil
127128
}
128129

130+
// SendNoWait sends a message to the netlink client in non-blocking mode. Behavior is otherwise identical to Send()
131+
func (c *NetlinkClient) SendNoWait(msg syscall.NetlinkMessage) (uint32, error) {
132+
return c.send(msg, syscall.MSG_DONTWAIT)
133+
}
134+
129135
// Send sends a netlink message and returns the sequence number used
130136
// in the message and an error if it occurred. If the PID is not set then
131137
// the value will be populated automatically (recommended).
132138
func (c *NetlinkClient) Send(msg syscall.NetlinkMessage) (uint32, error) {
139+
return c.send(msg, 0)
140+
}
141+
142+
func (c *NetlinkClient) send(msg syscall.NetlinkMessage, flags int) (uint32, error) {
133143
if msg.Header.Pid == 0 {
134144
msg.Header.Pid = c.pid
135145
}
136146

137147
msg.Header.Seq = atomic.AddUint32(&c.seq, 1)
138148
to := &syscall.SockaddrNetlink{}
139-
return msg.Header.Seq, syscall.Sendto(c.fd, serialize(msg), 0, to)
149+
return msg.Header.Seq, syscall.Sendto(c.fd, serialize(msg), flags, to)
140150
}
141151

142152
func serialize(msg syscall.NetlinkMessage) []byte {

0 commit comments

Comments
 (0)