test: use `T.Setenv` to set env vars in tests

This commit replaces `os.Setenv` with `t.Setenv` in tests. The
environment variable is automatically restored to its original value
when the test and all its subtests complete.

Reference: https://pkg.go.dev/testing#T.Setenv
Signed-off-by: Eng Zer Jun <engzerjun@gmail.com>
pull/1398/head
Eng Zer Jun 2 years ago
parent 835767a611
commit 37894e9b28
No known key found for this signature in database
GPG Key ID: DAEBBD2E34C111E6

@ -2,7 +2,6 @@ package cpu
import ( import (
"errors" "errors"
"os"
"os/exec" "os/exec"
"strconv" "strconv"
"strings" "strings"
@ -10,8 +9,7 @@ import (
) )
func TestTimesEmpty(t *testing.T) { func TestTimesEmpty(t *testing.T) {
orig := os.Getenv("HOST_PROC") t.Setenv("HOST_PROC", "testdata/linux/times_empty")
os.Setenv("HOST_PROC", "testdata/linux/times_empty")
_, err := Times(true) _, err := Times(true)
if err != nil { if err != nil {
t.Error("Times(true) failed") t.Error("Times(true) failed")
@ -20,12 +18,10 @@ func TestTimesEmpty(t *testing.T) {
if err != nil { if err != nil {
t.Error("Times(false) failed") t.Error("Times(false) failed")
} }
os.Setenv("HOST_PROC", orig)
} }
func TestCPUparseStatLine_424(t *testing.T) { func TestCPUparseStatLine_424(t *testing.T) {
orig := os.Getenv("HOST_PROC") t.Setenv("HOST_PROC", "testdata/linux/424/proc")
os.Setenv("HOST_PROC", "testdata/linux/424/proc")
{ {
l, err := Times(true) l, err := Times(true)
if err != nil || len(l) == 0 { if err != nil || len(l) == 0 {
@ -40,7 +36,6 @@ func TestCPUparseStatLine_424(t *testing.T) {
} }
t.Logf("Times(false): %#v", l) t.Logf("Times(false): %#v", l)
} }
os.Setenv("HOST_PROC", orig)
} }
func TestCPUCountsAgainstLscpu(t *testing.T) { func TestCPUCountsAgainstLscpu(t *testing.T) {
@ -93,9 +88,7 @@ func TestCPUCountsAgainstLscpu(t *testing.T) {
} }
func TestCPUCountsLogicalAndroid_1037(t *testing.T) { // https://github.com/shirou/gopsutil/issues/1037 func TestCPUCountsLogicalAndroid_1037(t *testing.T) { // https://github.com/shirou/gopsutil/issues/1037
orig := os.Getenv("HOST_PROC") t.Setenv("HOST_PROC", "testdata/linux/1037/proc")
os.Setenv("HOST_PROC", "testdata/linux/1037/proc")
defer os.Setenv("HOST_PROC", orig)
count, err := Counts(true) count, err := Counts(true)
if err != nil { if err != nil {

@ -4,7 +4,6 @@
package cpu package cpu
import ( import (
"os"
"path/filepath" "path/filepath"
"testing" "testing"
@ -30,13 +29,9 @@ var timesTests = []struct {
} }
func TestTimesPlan9(t *testing.T) { func TestTimesPlan9(t *testing.T) {
origRoot := os.Getenv("HOST_ROOT")
t.Cleanup(func() {
os.Setenv("HOST_ROOT", origRoot)
})
for _, tt := range timesTests { for _, tt := range timesTests {
t.Run(tt.mockedRootFS, func(t *testing.T) { t.Run(tt.mockedRootFS, func(t *testing.T) {
os.Setenv("HOST_ROOT", filepath.Join("testdata/plan9", tt.mockedRootFS)) t.Setenv("HOST_ROOT", filepath.Join("testdata/plan9", tt.mockedRootFS))
stats, err := Times(false) stats, err := Times(false)
skipIfNotImplementedErr(t, err) skipIfNotImplementedErr(t, err)
if err != nil { if err != nil {

@ -364,16 +364,6 @@ func HostDev(combineWith ...string) string {
return GetEnv("HOST_DEV", "/dev", combineWith...) return GetEnv("HOST_DEV", "/dev", combineWith...)
} }
// MockEnv set environment variable and return revert function.
// MockEnv should be used testing only.
func MockEnv(key string, value string) func() {
original := os.Getenv(key)
os.Setenv(key, value)
return func() {
os.Setenv(key, original)
}
}
// getSysctrlEnv sets LC_ALL=C in a list of env vars for use when running // getSysctrlEnv sets LC_ALL=C in a list of env vars for use when running
// sysctl commands (see DoSysctrl). // sysctl commands (see DoSysctrl).
func getSysctrlEnv(env []string) []string { func getSysctrlEnv(env []string) []string {

@ -4,7 +4,6 @@
package mem package mem
import ( import (
"os"
"path/filepath" "path/filepath"
"reflect" "reflect"
"strings" "strings"
@ -111,12 +110,9 @@ var virtualMemoryTests = []struct {
} }
func TestVirtualMemoryLinux(t *testing.T) { func TestVirtualMemoryLinux(t *testing.T) {
origProc := os.Getenv("HOST_PROC")
defer os.Setenv("HOST_PROC", origProc)
for _, tt := range virtualMemoryTests { for _, tt := range virtualMemoryTests {
t.Run(tt.mockedRootFS, func(t *testing.T) { t.Run(tt.mockedRootFS, func(t *testing.T) {
os.Setenv("HOST_PROC", filepath.Join("testdata/linux/virtualmemory/", tt.mockedRootFS, "proc")) t.Setenv("HOST_PROC", filepath.Join("testdata/linux/virtualmemory/", tt.mockedRootFS, "proc"))
stat, err := VirtualMemory() stat, err := VirtualMemory()
skipIfNotImplementedErr(t, err) skipIfNotImplementedErr(t, err)

@ -4,7 +4,6 @@
package mem package mem
import ( import (
"os"
"reflect" "reflect"
"testing" "testing"
) )
@ -27,14 +26,9 @@ var virtualMemoryTests = []struct {
} }
func TestVirtualMemoryPlan9(t *testing.T) { func TestVirtualMemoryPlan9(t *testing.T) {
origProc := os.Getenv("HOST_ROOT")
t.Cleanup(func() {
os.Setenv("HOST_ROOT", origProc)
})
for _, tt := range virtualMemoryTests { for _, tt := range virtualMemoryTests {
t.Run(tt.mockedRootFS, func(t *testing.T) { t.Run(tt.mockedRootFS, func(t *testing.T) {
os.Setenv("HOST_ROOT", "testdata/plan9/virtualmemory/") t.Setenv("HOST_ROOT", "testdata/plan9/virtualmemory/")
stat, err := VirtualMemory() stat, err := VirtualMemory()
skipIfNotImplementedErr(t, err) skipIfNotImplementedErr(t, err)
@ -62,14 +56,9 @@ var swapMemoryTests = []struct {
} }
func TestSwapMemoryPlan9(t *testing.T) { func TestSwapMemoryPlan9(t *testing.T) {
origProc := os.Getenv("HOST_ROOT")
t.Cleanup(func() {
os.Setenv("HOST_ROOT", origProc)
})
for _, tt := range swapMemoryTests { for _, tt := range swapMemoryTests {
t.Run(tt.mockedRootFS, func(t *testing.T) { t.Run(tt.mockedRootFS, func(t *testing.T) {
os.Setenv("HOST_ROOT", "testdata/plan9/virtualmemory/") t.Setenv("HOST_ROOT", "testdata/plan9/virtualmemory/")
swap, err := SwapMemory() swap, err := SwapMemory()
skipIfNotImplementedErr(t, err) skipIfNotImplementedErr(t, err)

@ -11,7 +11,6 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/shirou/gopsutil/v3/internal/common"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -62,8 +61,7 @@ func Test_Process_splitProcStat_fromFile(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
f := common.MockEnv("HOST_PROC", "testdata/linux") t.Setenv("HOST_PROC", "testdata/linux")
defer f()
for _, pid := range pids { for _, pid := range pids {
pid, err := strconv.ParseInt(pid.Name(), 0, 32) pid, err := strconv.ParseInt(pid.Name(), 0, 32)
if err != nil { if err != nil {
@ -99,8 +97,7 @@ func Test_fillFromCommWithContext(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
f := common.MockEnv("HOST_PROC", "testdata/linux") t.Setenv("HOST_PROC", "testdata/linux")
defer f()
for _, pid := range pids { for _, pid := range pids {
pid, err := strconv.ParseInt(pid.Name(), 0, 32) pid, err := strconv.ParseInt(pid.Name(), 0, 32)
if err != nil { if err != nil {
@ -121,8 +118,7 @@ func Test_fillFromStatusWithContext(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
f := common.MockEnv("HOST_PROC", "testdata/linux") t.Setenv("HOST_PROC", "testdata/linux")
defer f()
for _, pid := range pids { for _, pid := range pids {
pid, err := strconv.ParseInt(pid.Name(), 0, 32) pid, err := strconv.ParseInt(pid.Name(), 0, 32)
if err != nil { if err != nil {
@ -139,8 +135,7 @@ func Test_fillFromStatusWithContext(t *testing.T) {
} }
func Benchmark_fillFromCommWithContext(b *testing.B) { func Benchmark_fillFromCommWithContext(b *testing.B) {
f := common.MockEnv("HOST_PROC", "testdata/linux") b.Setenv("HOST_PROC", "testdata/linux")
defer f()
pid := 1060 pid := 1060
p, _ := NewProcess(int32(pid)) p, _ := NewProcess(int32(pid))
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@ -149,8 +144,7 @@ func Benchmark_fillFromCommWithContext(b *testing.B) {
} }
func Benchmark_fillFromStatusWithContext(b *testing.B) { func Benchmark_fillFromStatusWithContext(b *testing.B) {
f := common.MockEnv("HOST_PROC", "testdata/linux") b.Setenv("HOST_PROC", "testdata/linux")
defer f()
pid := 1060 pid := 1060
p, _ := NewProcess(int32(pid)) p, _ := NewProcess(int32(pid))
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@ -163,8 +157,7 @@ func Test_fillFromTIDStatWithContext_lx_brandz(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
f := common.MockEnv("HOST_PROC", "testdata/lx_brandz") t.Setenv("HOST_PROC", "testdata/lx_brandz")
defer f()
for _, pid := range pids { for _, pid := range pids {
pid, err := strconv.ParseInt(pid.Name(), 0, 32) pid, err := strconv.ParseInt(pid.Name(), 0, 32)
if err != nil { if err != nil {

Loading…
Cancel
Save