Skip to content

Commit eeab7a9

Browse files
committed
fix: make exec start atomic with PID registration
Addresses a race condition where an exec process could start but not be registered in execPIDs before the kill loop begins, leaving orphan processes.

- Add StartExecCmd() that atomically checks kill-loop state, starts the command, and registers the PID under a single lock
- Add ExecStarter interface for testability
- Update ExecContainer and ExecSyncContainer to use the atomic method

Signed-off-by: Willian Paixao <[email protected]>
1 parent a97e4b9 commit eeab7a9

File tree

3 files changed

+162
-57
lines changed

3 files changed

+162
-57
lines changed

internal/oci/container.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,35 @@ func (c *Container) AddExecPID(pid int, shouldKill bool) error {
888888
return nil
889889
}
890890

891+
// StartExecCmd atomically starts an exec command and registers its PID.
892+
func (c *Container) StartExecCmd(cmd ExecStarter, shouldKill bool) (int, error) {
893+
c.stopLock.Lock()
894+
defer c.stopLock.Unlock()
895+
896+
// Check before starting - if kill loop has begun, don't start new execs
897+
if c.stopKillLoopBegun {
898+
return 0, errors.New("cannot start exec: container is being killed")
899+
}
900+
901+
// Start the command while holding the lock
902+
if err := cmd.Start(); err != nil {
903+
return 0, err
904+
}
905+
906+
pid := cmd.GetPid()
907+
logrus.Debugf("Started and tracking exec PID %d for container %s (should kill = %t) ...", pid, c.ID(), shouldKill)
908+
c.execPIDs[pid] = shouldKill
909+
910+
return pid, nil
911+
}
912+
913+
// ExecStarter abstracts the subset of exec.Cmd behavior needed by
// StartExecCmd: launching a process and reporting its PID. The indirection
// decouples Container from exec.Cmd and lets tests substitute a mock.
type ExecStarter interface {
	// Start launches the underlying command, returning any launch error.
	Start() error
	// GetPid reports the PID of the started process. It is only meaningful
	// after a successful Start.
	GetPid() int
}
919+
891920
// DeleteExecPID is for deregistering a pid after it has exited.
892921
func (c *Container) DeleteExecPID(pid int) {
893922
c.stopLock.Lock()

internal/oci/container_test.go

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -869,6 +869,85 @@ var _ = t.Describe("Container", func() {
869869
})
870870
})
871871

872+
// Tests for Container.StartExecCmd: verifies the kill-loop guard, the
// graceful-termination window, error propagation from Start, and PID
// registration on success.
t.Describe("StartExecCmd", func() {
	It("should fail when kill loop has begun", func() {
		// Given - container is stopping AND the kill loop has started
		sut.SetAsStopping()
		sut.SetStopKillLoopBegun()

		mockStarter := &mockExecStarter{
			startFunc: func() error { return nil },
			pid:       12345,
		}

		// When
		pid, err := sut.StartExecCmd(mockStarter, true)

		// Then - Should fail because kill loop has begun
		Expect(err).To(HaveOccurred())
		Expect(err.Error()).To(ContainSubstring("container is being killed"))
		Expect(pid).To(Equal(0))
		// Start should NOT have been called - the guard must run
		// before the command is launched, or an orphan could leak.
		Expect(mockStarter.startCalled).To(BeFalse())
	})

	It("should succeed during graceful termination", func() {
		// Given - Container is stopping but kill loop hasn't begun
		sut.SetAsStopping()

		mockStarter := &mockExecStarter{
			startFunc: func() error { return nil },
			pid:       12345,
		}

		// When
		pid, err := sut.StartExecCmd(mockStarter, true)

		// Then - Should succeed because stopKillLoopBegun is false
		Expect(err).ToNot(HaveOccurred())
		Expect(pid).To(Equal(12345))
		Expect(mockStarter.startCalled).To(BeTrue())
	})

	It("should propagate start errors", func() {
		// Given - a starter whose Start always fails
		expectedErr := errors.New("start failed")
		mockStarter := &mockExecStarter{
			startFunc: func() error { return expectedErr },
			pid:       0,
		}

		// When
		pid, err := sut.StartExecCmd(mockStarter, true)

		// Then - Should propagate the error unwrapped
		Expect(err).To(Equal(expectedErr))
		Expect(pid).To(Equal(0))
		Expect(mockStarter.startCalled).To(BeTrue())
	})

	It("should register PID atomically on success", func() {
		// Given
		mockStarter := &mockExecStarter{
			startFunc: func() error { return nil },
			pid:       54321,
		}

		// When
		pid, err := sut.StartExecCmd(mockStarter, false)

		// Then
		Expect(err).ToNot(HaveOccurred())
		Expect(pid).To(Equal(54321))

		// Verify PID is registered - we can check by trying to delete it
		// (DeleteExecPID doesn't error on non-existent PIDs)
		// NOTE(review): this assertion cannot fail even when registration
		// is broken, since DeleteExecPID is a no-op for unknown PIDs;
		// consider asserting via an exported accessor if one exists.
		Expect(func() {
			sut.DeleteExecPID(54321)
		}).ToNot(Panic())
	})
})
950+
872951
t.Describe("SetAsDoneStopping", func() {
873952
It("should complete without error when no watchers exist", func() {
874953
// Given - No watchers registered
@@ -1012,3 +1091,19 @@ var _ = t.Describe("SpoofedContainer", func() {
10121091
Expect(sut.Sandbox()).To(Equal("sbox"))
10131092
})
10141093
})
1094+
1095+
// mockExecStarter is a test double satisfying the ExecStarter interface.
// startFunc supplies the behavior of Start, pid is the value reported by
// GetPid, and startCalled records whether Start was ever invoked.
type mockExecStarter struct {
	startFunc   func() error
	pid         int
	startCalled bool
}

// Start records the invocation and delegates to the configured startFunc.
func (m *mockExecStarter) Start() error {
	m.startCalled = true
	return m.startFunc()
}

// GetPid returns the configured fake PID.
func (m *mockExecStarter) GetPid() int {
	return m.pid
}
}

internal/oci/runtime_oci.go

Lines changed: 38 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,19 @@ type exitCodeInfo struct {
101101
Message string `json:"message,omitempty"`
102102
}
103103

104+
// execCmdWrapper adapts a *exec.Cmd to the ExecStarter interface so that
// Container.StartExecCmd can launch and track real exec processes.
type execCmdWrapper struct {
	cmd *exec.Cmd
}

// Start launches the wrapped command.
func (w *execCmdWrapper) Start() error {
	return w.cmd.Start()
}

// GetPid reports the PID of the wrapped command's process. Call only after
// a successful Start: cmd.Process is nil until the process has been started.
func (w *execCmdWrapper) GetPid() int {
	return w.cmd.Process.Pid
}
116+
104117
// CreateContainer creates a container.
105118
func (r *runtimeOCI) CreateContainer(ctx context.Context, c *Container, cgroupParent string, restore bool) (retErr error) {
106119
ctx, span := log.StartSpan(ctx)
@@ -542,12 +555,10 @@ func (r *runtimeOCI) ExecContainer(ctx context.Context, c *Container, cmd []stri
542555
execCmd.Stderr = stderr
543556
}
544557

545-
if err := execCmd.Start(); err != nil {
546-
return err
547-
}
548-
549-
pid := execCmd.Process.Pid
550-
if err := c.AddExecPID(pid, true); err != nil {
558+
// Atomically start the exec and register its PID to prevent race conditions
559+
// where the exec starts but isn't registered before the kill loop begins.
560+
pid, err := c.StartExecCmd(&execCmdWrapper{cmd: execCmd}, true)
561+
if err != nil {
551562
return err
552563
}
553564
defer c.DeleteExecPID(pid)
@@ -710,7 +721,9 @@ func (r *runtimeOCI) ExecSyncContainer(ctx context.Context, c *Container, comman
710721
cmd.ExtraFiles = append(cmd.ExtraFiles, childPipe, childStartPipe)
711722
r.prepareEnv(cmd, true)
712723

713-
err = cmd.Start()
724+
// Atomically start the command and register its PID to prevent race conditions
725+
// where the exec starts but isn't registered before the kill loop begins.
726+
pid, err := c.StartExecCmd(&execCmdWrapper{cmd: cmd}, false)
714727
if err != nil {
715728
childPipe.Close()
716729
childStartPipe.Close()
@@ -745,14 +758,12 @@ func (r *runtimeOCI) ExecSyncContainer(ctx context.Context, c *Container, comman
745758
if !errors.As(waitErr, &exitErr) || exitErr.ExitCode() != -1 {
746759
retErr = fmt.Errorf("failed to wait %w after failing with: %w", waitErr, retErr)
747760
}
761+
// Clean up the PID registration since the exec failed
762+
c.DeleteExecPID(pid)
748763
}
749764
}()
750765

751-
// A neat trick we can do is register the exec PID before we send info down the start pipe.
752-
// Doing so guarantees we can short circuit the exec process if the container is stopping already.
753-
if err := c.AddExecPID(cmd.Process.Pid, false); err != nil {
754-
return err
755-
}
766+
// The exec PID was already registered atomically by StartExecCmd above
756767

757768
if r.handler.MonitorExecCgroup == config.MonitorExecCgroupContainer && r.config.InfraCtrCPUSet != "" {
758769
// Update the exec's cgroup
@@ -761,7 +772,7 @@ func (r *runtimeOCI) ExecSyncContainer(ctx context.Context, c *Container, comman
761772
return err
762773
}
763774

764-
err = cgmgr.MoveProcessToContainerCgroup(containerPid, cmd.Process.Pid)
775+
err = cgmgr.MoveProcessToContainerCgroup(containerPid, pid)
765776
if err != nil {
766777
return err
767778
}
@@ -784,9 +795,6 @@ func (r *runtimeOCI) ExecSyncContainer(ctx context.Context, c *Container, comman
784795
}
785796
}
786797

787-
// defer in case the Pid is changed after Wait()
788-
pid := cmd.Process.Pid
789-
790798
// first, wait till the command is done
791799
waitErr := cmd.Wait()
792800

@@ -975,47 +983,34 @@ func (r *runtimeOCI) StopLoopForContainer(ctx context.Context, c *Container, bm
975983
// when CRI-O is run directly in the foreground in the terminal.
976984
ctx, stop := signal.NotifyContext(ctx, os.Interrupt)
977985

986+
c.opLock.Lock()
987+
978988
defer func() {
979989
// Kill the exec PIDs after the main container to avoid pod lifecycle regressions:
980990
// Ref: https://github.com/kubernetes/kubernetes/issues/124743
981-
c.opLock.Lock()
982991
c.KillExecPIDs()
983992
c.state.Finished = time.Now()
984993
c.opLock.Unlock()
985994
c.SetAsDoneStopping()
986995
}()
987996

988-
c.opLock.Lock()
989-
isPaused := c.state.Status == ContainerStatePaused
990-
c.opLock.Unlock()
991-
992-
if isPaused {
993-
c.opLock.Lock()
994-
997+
if c.state.Status == ContainerStatePaused {
995998
if _, err := r.runtimeCmd("resume", c.ID()); err != nil {
996999
log.Errorf(ctx, "Failed to unpause container %s: %v", c.Name(), err)
9971000
}
998-
999-
c.opLock.Unlock()
10001001
}
10011002

10021003
// Begin the actual kill.
1003-
c.opLock.Lock()
1004-
1005-
_, killErr := r.runtimeCmd("kill", c.ID(), c.GetStopSignal())
1006-
if killErr != nil {
1004+
if _, err := r.runtimeCmd("kill", c.ID(), c.GetStopSignal()); err != nil {
10071005
if err := c.Living(); err != nil {
10081006
// The initial container process either doesn't exist, or isn't ours.
10091007
// Set state accordingly.
10101008
c.state.Finished = time.Now()
1011-
c.opLock.Unlock()
10121009

10131010
return
10141011
}
10151012
}
10161013

1017-
c.opLock.Unlock()
1018-
10191014
done := make(chan struct{})
10201015

10211016
go func() {
@@ -1030,11 +1025,8 @@ func (r *runtimeOCI) StopLoopForContainer(ctx context.Context, c *Container, bm
10301025
// Periodically check if the container is still running.
10311026
// This avoids busy-waiting and reduces resource usage while
10321027
// ensuring timely detection of container termination.
1033-
c.opLock.RLock()
1034-
err := c.Living()
1035-
c.opLock.RUnlock()
1036-
1037-
if err != nil {
1028+
//
1029+
if err := c.Living(); err != nil {
10381030
// The initial container process either doesn't exist, or isn't ours.
10391031
if !errors.Is(err, ErrNotFound) {
10401032
log.Warnf(ctx, "Failed to find process for container %s: %v", c.ID(), err)
@@ -1054,20 +1046,15 @@ func (r *runtimeOCI) StopLoopForContainer(ctx context.Context, c *Container, bm
10541046
targetTime := time.Now().AddDate(+1, 0, 0) // A year from this one.
10551047

10561048
blockedTimer := time.AfterFunc(stopProcessBlockedInterval, func() {
1057-
c.opLock.RLock()
1058-
state, err := c.ProcessState()
1059-
initPid := c.state.InitPid
1060-
c.opLock.RUnlock()
1061-
1062-
if err == nil && state == "D" {
1049+
if state, err := c.ProcessState(); err == nil && state == "D" {
10631050
log.Errorf(ctx,
10641051
"Detected process (%d) blocked in uninterruptible sleep for more than %d seconds for container %s",
1065-
initPid, int(time.Since(startTime)/time.Second), c.ID(),
1052+
c.state.InitPid, int(time.Since(startTime)/time.Second), c.ID(),
10661053
)
10671054
} else {
10681055
log.Warnf(ctx,
10691056
"Detected process (%d) in state %s blocked for more than %d seconds for container %s. One of the child processes might be in uninterruptible sleep.",
1070-
initPid, state, int(time.Since(startTime)/time.Second), c.ID(),
1057+
c.state.InitPid, state, int(time.Since(startTime)/time.Second), c.ID(),
10711058
)
10721059
}
10731060
})
@@ -1105,21 +1092,15 @@ func (r *runtimeOCI) StopLoopForContainer(ctx context.Context, c *Container, bm
11051092
killContainer:
11061093
// We cannot use ExponentialBackoff() here as its stop conditions are not flexible enough.
11071094
kwait.BackoffUntil(func() {
1108-
c.opLock.Lock()
1109-
1110-
_, killErr := r.runtimeCmd("kill", c.ID(), "KILL")
1111-
if killErr != nil {
1112-
if !errors.Is(killErr, ErrNotFound) {
1113-
log.Errorf(ctx, "Killing container %v failed: %v", c.ID(), killErr)
1095+
if _, err := r.runtimeCmd("kill", c.ID(), "KILL"); err != nil {
1096+
if !errors.Is(err, ErrNotFound) {
1097+
log.Errorf(ctx, "Killing container %v failed: %v", c.ID(), err)
11141098
} else {
1115-
log.Debugf(ctx, "Error while killing container %s: %v", c.ID(), killErr)
1099+
log.Debugf(ctx, "Error while killing container %s: %v", c.ID(), err)
11161100
}
11171101
}
11181102

1119-
err := c.Living()
1120-
c.opLock.Unlock()
1121-
1122-
if err != nil {
1103+
if err := c.Living(); err != nil {
11231104
log.Debugf(ctx, "Container is no longer alive")
11241105
stop()
11251106

0 commit comments

Comments
 (0)