diff --git a/process/process.go b/process/process.go index 35a822c..0792122 100644 --- a/process/process.go +++ b/process/process.go @@ -6,6 +6,7 @@ import ( "errors" "runtime" "sort" + "sync" "syscall" "time" @@ -26,6 +27,7 @@ type Process struct { name string status string parent int32 + parentMutex *sync.RWMutex // for windows ppid cache numCtxSwitches *NumCtxSwitchesStat uids []int32 gids []int32 @@ -167,7 +169,10 @@ func NewProcess(pid int32) (*Process, error) { } func NewProcessWithContext(ctx context.Context, pid int32) (*Process, error) { - p := &Process{Pid: pid} + p := &Process{ + Pid: pid, + parentMutex: new(sync.RWMutex), + } exists, err := PidExistsWithContext(ctx, pid) if err != nil { diff --git a/process/process_race_test.go b/process/process_race_test.go new file mode 100644 index 0000000..fd444a8 --- /dev/null +++ b/process/process_race_test.go @@ -0,0 +1,30 @@ +// +build race + +package process + +import ( + "sync" + "testing" +) + +func Test_Process_Ppid_Race(t *testing.T) { + wg := sync.WaitGroup{} + testCount := 10 + p := testGetProcess() + wg.Add(testCount) + for i := 0; i < testCount; i++ { + go func(j int) { + ppid, err := p.Ppid() + wg.Done() + skipIfNotImplementedErr(t, err) + if err != nil { + t.Errorf("Ppid() failed, %v", err) + } + + if j == 9 { + t.Logf("Ppid(): %d", ppid) + } + }(i) + } + wg.Wait() +} diff --git a/process/process_windows.go b/process/process_windows.go index 5c3a402..1501dff 100644 --- a/process/process_windows.go +++ b/process/process_windows.go @@ -237,8 +237,9 @@ func PidExistsWithContext(ctx context.Context, pid int32) (bool, error) { func (p *Process) PpidWithContext(ctx context.Context) (int32, error) { // if cached already, return from cache - if p.parent != 0 { - return p.parent, nil + cachedPpid := p.getPpid() + if cachedPpid != 0 { + return cachedPpid, nil } ppid, _, _, err := getFromSnapProcess(p.Pid) @@ -246,8 +247,8 @@ func (p *Process) PpidWithContext(ctx context.Context) (int32, error) { return 0, err } - // if no errors, cache it - p.parent = ppid + // no errors and not cached already, so cache it + p.setPpid(ppid) return ppid, nil } @@ -258,8 +259,11 @@ func (p *Process) NameWithContext(ctx context.Context) (string, error) { return "", fmt.Errorf("could not get Name: %s", err) } - // if no errors, cache ppid + // if no errors and not cached already, cache ppid p.parent = ppid + if 0 == p.getPpid() { + p.setPpid(ppid) + } return name, nil } @@ -456,8 +460,11 @@ func (p *Process) NumThreadsWithContext(ctx context.Context) (int32, error) { return 0, err } - // if no errors, cache ppid + // if no errors and not cached already, cache ppid p.parent = ppid + if 0 == p.getPpid() { + p.setPpid(ppid) + } return ret, nil } @@ -613,6 +620,21 @@ func (p *Process) KillWithContext(ctx context.Context) error { return process.Kill() } +// retrieve Ppid in a thread-safe manner +func (p *Process) getPpid() int32 { + p.parentMutex.RLock() + defer p.parentMutex.RUnlock() + return p.parent +} + +// cache Ppid in a thread-safe manner (WINDOWS ONLY) +// see https://psutil.readthedocs.io/en/latest/#psutil.Process.ppid +func (p *Process) setPpid(ppid int32) { + p.parentMutex.Lock() + defer p.parentMutex.Unlock() + p.parent = ppid +} + func getFromSnapProcess(pid int32) (int32, int32, string, error) { snap, err := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPPROCESS, uint32(pid)) if err != nil {