diff --git a/internal/common/common.go b/internal/common/common.go index ebf70ea..f1e4154 100644 --- a/internal/common/common.go +++ b/internal/common/common.go @@ -352,6 +352,16 @@ func HostDev(combineWith ...string) string { return GetEnv("HOST_DEV", "/dev", combineWith...) } +// MockEnv set environment variable and return revert function. +// MockEnv should be used testing only. +func MockEnv(key string, value string) func() { + original := os.Getenv(key) + os.Setenv(key, value) + return func() { + os.Setenv(key, original) + } +} + // getSysctrlEnv sets LC_ALL=C in a list of env vars for use when running // sysctl commands (see DoSysctrl). func getSysctrlEnv(env []string) []string { diff --git a/process/process_linux_test.go b/process/process_linux_test.go index f4e8164..bfbe96d 100644 --- a/process/process_linux_test.go +++ b/process/process_linux_test.go @@ -5,9 +5,10 @@ package process import ( "context" "io/ioutil" - "os" "strconv" "testing" + + "github.com/shirou/gopsutil/internal/common" ) func Test_fillFromStatusWithContext(t *testing.T) { @@ -15,10 +16,8 @@ func Test_fillFromStatusWithContext(t *testing.T) { if err != nil { t.Error(err) } - original := os.Getenv("HOST_PROC") - os.Setenv("HOST_PROC", "testdata/linux") - defer os.Setenv("HOST_PROC", original) - + f := common.MockEnv("HOST_PROC", "testdata/linux") + defer f() for _, pid := range pids { pid, _ := strconv.ParseInt(pid.Name(), 0, 32) p, _ := NewProcess(int32(pid)) diff --git a/v3/internal/common/common.go b/v3/internal/common/common.go index ebf70ea..f1e4154 100644 --- a/v3/internal/common/common.go +++ b/v3/internal/common/common.go @@ -352,6 +352,16 @@ func HostDev(combineWith ...string) string { return GetEnv("HOST_DEV", "/dev", combineWith...) } +// MockEnv set environment variable and return revert function. +// MockEnv should be used testing only. +func MockEnv(key string, value string) func() { + original := os.Getenv(key) + os.Setenv(key, value) + return func() { + os.Setenv(key, original) + } +} + // getSysctrlEnv sets LC_ALL=C in a list of env vars for use when running // sysctl commands (see DoSysctrl). func getSysctrlEnv(env []string) []string { diff --git a/v3/process/process_linux_test.go b/v3/process/process_linux_test.go index f4e8164..ada0ab8 100644 --- a/v3/process/process_linux_test.go +++ b/v3/process/process_linux_test.go @@ -5,9 +5,10 @@ package process import ( "context" "io/ioutil" - "os" "strconv" "testing" + + "github.com/shirou/gopsutil/v3/internal/common" ) func Test_fillFromStatusWithContext(t *testing.T) { @@ -15,10 +16,8 @@ func Test_fillFromStatusWithContext(t *testing.T) { if err != nil { t.Error(err) } - original := os.Getenv("HOST_PROC") - os.Setenv("HOST_PROC", "testdata/linux") - defer os.Setenv("HOST_PROC", original) - + f := common.MockEnv("HOST_PROC", "testdata/linux") + defer f() for _, pid := range pids { pid, _ := strconv.ParseInt(pid.Name(), 0, 32) p, _ := NewProcess(int32(pid))