diff --git a/pexec/managed_process.go b/pexec/managed_process.go index 14752998..7f03fecf 100644 --- a/pexec/managed_process.go +++ b/pexec/managed_process.go @@ -32,6 +32,10 @@ type ManagedProcess interface { // there's any system level issue stopping the process. Stop() error + // KillGroup will attempt to kill the process group and not wait for completion. Only use this if + // comfortable with leaking resources (in cases where exiting the program as quickly as possible is desired). + KillGroup() + // Status return nil when the process is both alive and owned. // If err is non-nil, process may be a) alive but not owned or b) dead. Status() error @@ -432,3 +436,30 @@ func (p *managedProcess) Stop() error { } return errors.Errorf("non-successful exit code: %d", p.cmd.ProcessState.ExitCode()) } + +// KillGroup kills the process group. +func (p *managedProcess) KillGroup() { + // Minimally hold a lock here so that we can signal the + // management goroutine to stop. We will attempt to kill the + // process even if p.stopped is true. + p.mu.Lock() + if !p.stopped { + close(p.killCh) + p.stopped = true + } + + if p.cmd == nil { + p.mu.Unlock() + return + } + p.mu.Unlock() + + // Since p.cmd is mutex guarded and we just signaled the manage + // goroutine to stop, no new Start can happen and therefore + // p.cmd can no longer be modified rendering it safe to read + // without a lock held. + // We are intentionally not checking the error here, we are already + // in a bad state. + //nolint:errcheck,gosec + p.forceKillGroup() +} diff --git a/pexec/managed_process_test.go b/pexec/managed_process_test.go index 75e3e44f..059c5564 100644 --- a/pexec/managed_process_test.go +++ b/pexec/managed_process_test.go @@ -16,6 +16,7 @@ import ( "strings" "sync" "sync/atomic" + "syscall" "testing" "time" @@ -411,10 +412,10 @@ func TestManagedProcessStop(t *testing.T) { bashScriptBuilder.WriteString("\n") } bashScriptBuilder.WriteString(fmt.Sprintf(`echo hello >> '%s' -while true -do echo hey -sleep 1 -done`, tempFile.Name())) + while true + do echo hey + sleep 1 + done`, tempFile.Name())) bashScriptBuilder.WriteString("\n") bashScript := bashScriptBuilder.String() @@ -571,6 +572,97 @@ done`, tempFile.Name())) }) } +func TestManagedProcessKillGroup(t *testing.T) { + t.Run("kill signaling with children", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("cannot test this on windows") + } + logger := golog.NewTestLogger(t) + + watcher1, tempFile1 := testutils.WatchedFile(t) + watcher2, tempFile2 := testutils.WatchedFile(t) + watcher3, tempFile3 := testutils.WatchedFile(t) + + // this script writes a string to the specified file every 100ms + script := ` + while true + do echo hello >> '%s' + sleep 0.1 + done + ` + + bashScript1 := fmt.Sprintf(script, tempFile1.Name()) + bashScript2 := fmt.Sprintf(script, tempFile2.Name()) + bashScriptParent := fmt.Sprintf(` + bash -c '%s' & + bash -c '%s' & + `+script, + bashScript1, + bashScript2, + tempFile3.Name(), + tempFile3.Name(), + ) + + proc := NewManagedProcess(ProcessConfig{ + Name: "bash", + Args: []string{"-c", bashScriptParent}, + }, logger) + + // To confirm that the processes have died, confirm that the size of the file stopped increasing + getSize := func(file *os.File) int64 { + info, _ := file.Stat() + return info.Size() + } + + file1SizeBeforeStart := getSize(tempFile1) + file2SizeBeforeStart := getSize(tempFile2) + file3SizeBeforeStart := getSize(tempFile3) + + test.That(t, proc.Start(context.Background()), test.ShouldBeNil) + + <-watcher1.Events + <-watcher2.Events + <-watcher3.Events + + proc.KillGroup() + + file1SizeAfterKill := getSize(tempFile1) + file2SizeAfterKill := getSize(tempFile2) + file3SizeAfterKill := getSize(tempFile3) + + test.That(t, file1SizeAfterKill, test.ShouldBeGreaterThan, file1SizeBeforeStart) + test.That(t, file2SizeAfterKill, test.ShouldBeGreaterThan, file2SizeBeforeStart) + test.That(t, file3SizeAfterKill, test.ShouldBeGreaterThan, file3SizeBeforeStart) + + // since KillGroup does not wait, we might have to check file size a few times as the kill + // might take a little to propagate. We want to make sure that the file size stops increasing. + testutils.WaitForAssertionWithSleep(t, 300*time.Millisecond, 50, func(tb testing.TB) { + tempSize1 := getSize(tempFile1) + tempSize2 := getSize(tempFile2) + tempSize3 := getSize(tempFile3) + + test.That(t, tempSize1, test.ShouldEqual, file1SizeAfterKill) + test.That(t, tempSize2, test.ShouldEqual, file2SizeAfterKill) + test.That(t, tempSize3, test.ShouldEqual, file3SizeAfterKill) + + file1SizeAfterKill = tempSize1 + file2SizeAfterKill = tempSize1 + file3SizeAfterKill = tempSize1 + }) + + // in CI, we have to send another signal to make sure the cmd.Wait() in + // the manage goroutine actually returns. + // We do not care about the error if it is expected. + // maybe related to https://github.com/golang/go/issues/18874 + if err := proc.(*managedProcess).cmd.Process.Signal(syscall.SIGTERM); err != nil { + test.That(t, errors.Is(err, os.ErrProcessDone), test.ShouldBeFalse) + } + + // wait on the managingCh to close + <-proc.(*managedProcess).managingCh + }) +} + func TestManagedProcessEnvironmentVariables(t *testing.T) { t.Run("set an environment variable on one-shot process", func(t *testing.T) { logger := golog.NewTestLogger(t) @@ -702,3 +794,5 @@ func (fp *fakeProcess) UnixPid() (int, error) { in reality tests should just depend on the methods they rely on. UnixPid is not one of those methods (for better or worse)`) } + +func (fp *fakeProcess) KillGroup() {} diff --git a/pexec/managed_process_unix.go b/pexec/managed_process_unix.go index 498be9e8..a5be4317 100644 --- a/pexec/managed_process_unix.go +++ b/pexec/managed_process_unix.go @@ -4,6 +4,7 @@ package pexec import ( "os" + "os/exec" "os/user" "strconv" "syscall" @@ -126,6 +127,15 @@ func (p *managedProcess) kill() (bool, error) { return forceKilled, nil } +// forceKillGroup kills everything in the process group. This will not wait for completion and may result the +// kill becoming a zombie process. +func (p *managedProcess) forceKillGroup() error { + pgidStr := strconv.Itoa(-p.cmd.Process.Pid) + p.logger.Infof("killing entire process group %d", p.cmd.Process.Pid) + //nolint:gosec + return exec.Command("kill", "-9", pgidStr).Start() +} + func isWaitErrUnknown(err string, forceKilled bool) bool { // This can easily happen if the process does not handle interrupts gracefully // and it won't provide us any exit code info. diff --git a/pexec/managed_process_windows.go b/pexec/managed_process_windows.go index ad1034b2..df758860 100644 --- a/pexec/managed_process_windows.go +++ b/pexec/managed_process_windows.go @@ -25,7 +25,6 @@ func parseSignal(sigStr, name string) (syscall.Signal, error) { return 0, errors.New("signals not supported on Windows") } - func (p *managedProcess) sysProcAttr() (*syscall.SysProcAttr, error) { ret := &syscall.SysProcAttr{ CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP, @@ -107,6 +106,13 @@ func (p *managedProcess) kill() (bool, error) { return forceKilled, nil } +// forceKillGroup kills everything in the process tree. This will not wait for completion and may result in a zombie process. +func (p *managedProcess) forceKillGroup() error { + pidStr := strconv.Itoa(p.cmd.Process.Pid) + p.logger.Infof("force killing entire process tree %d", p.cmd.Process.Pid) + return exec.Command("taskkill", "/t", "/f", "/pid", pidStr).Start() +} + func isWaitErrUnknown(err string, forceKilled bool) bool { if !forceKilled { return false