diff --git a/cpu/cpu_windows.go b/cpu/cpu_windows.go index b5bf825..b012fde 100644 --- a/cpu/cpu_windows.go +++ b/cpu/cpu_windows.go @@ -3,6 +3,7 @@ package cpu import ( + "context" "fmt" "unsafe" @@ -81,8 +82,9 @@ func Info() ([]InfoStat, error) { var ret []InfoStat var dst []Win32_Processor q := wmi.CreateQuery(&dst, "") - err := wmi.Query(q, &dst) - if err != nil { + ctx, cancel := context.WithTimeout(context.Background(), common.Timeout) + defer cancel() + if err := common.WMIQueryWithContext(ctx, q, &dst); err != nil { return ret, err } @@ -113,8 +115,11 @@ func Info() ([]InfoStat, error) { // Name property is the key by which overall, per cpu and per core metric is known. func PerfInfo() ([]Win32_PerfFormattedData_Counters_ProcessorInformation, error) { var ret []Win32_PerfFormattedData_Counters_ProcessorInformation + q := wmi.CreateQuery(&ret, "") - err := wmi.Query(q, &ret) + ctx, cancel := context.WithTimeout(context.Background(), common.Timeout) + defer cancel() + err := common.WMIQueryWithContext(ctx, q, &ret) return ret, err } @@ -123,7 +128,9 @@ func PerfInfo() ([]Win32_PerfFormattedData_Counters_ProcessorInformation, error) func ProcInfo() ([]Win32_PerfFormattedData_PerfOS_System, error) { var ret []Win32_PerfFormattedData_PerfOS_System q := wmi.CreateQuery(&ret, "") - err := wmi.Query(q, &ret) + ctx, cancel := context.WithTimeout(context.Background(), common.Timeout) + defer cancel() + err := common.WMIQueryWithContext(ctx, q, &ret) return ret, err } diff --git a/disk/disk_windows.go b/disk/disk_windows.go index 6cc22a2..9e5f681 100644 --- a/disk/disk_windows.go +++ b/disk/disk_windows.go @@ -4,9 +4,9 @@ package disk import ( "bytes" + "context" "unsafe" - "github.com/StackExchange/wmi" "github.com/shirou/gopsutil/internal/common" "golang.org/x/sys/windows" ) @@ -132,7 +132,9 @@ func IOCounters(names ...string) (map[string]IOCountersStat, error) { ret := make(map[string]IOCountersStat, 0) var dst []Win32_PerfFormattedData - err := wmi.Query("SELECT * FROM Win32_PerfFormattedData_PerfDisk_LogicalDisk ", &dst) + ctx, cancel := context.WithTimeout(context.Background(), common.Timeout) + defer cancel() + err := common.WMIQueryWithContext(ctx, "SELECT * FROM Win32_PerfFormattedData_PerfDisk_LogicalDisk", &dst) if err != nil { return ret, err } diff --git a/host/host_windows.go b/host/host_windows.go index 9894302..920cec9 100644 --- a/host/host_windows.go +++ b/host/host_windows.go @@ -3,6 +3,7 @@ package host import ( + "context" "fmt" "os" "runtime" @@ -109,7 +110,9 @@ func getMachineGuid() (string, error) { func GetOSInfo() (Win32_OperatingSystem, error) { var dst []Win32_OperatingSystem q := wmi.CreateQuery(&dst, "") - err := wmi.Query(q, &dst) + ctx, cancel := context.WithTimeout(context.Background(), common.Timeout) + defer cancel() + err := common.WMIQueryWithContext(ctx, q, &dst) if err != nil { return Win32_OperatingSystem{}, err } diff --git a/internal/common/common_windows.go b/internal/common/common_windows.go index 8e0177d..1dffe61 100644 --- a/internal/common/common_windows.go +++ b/internal/common/common_windows.go @@ -3,8 +3,10 @@ package common import ( + "context" "unsafe" + "github.com/StackExchange/wmi" "golang.org/x/sys/windows" ) @@ -49,7 +51,7 @@ var ( ModNt = windows.NewLazyDLL("ntdll.dll") ModPdh = windows.NewLazyDLL("pdh.dll") ModPsapi = windows.NewLazyDLL("psapi.dll") - + ProcGetSystemTimes = Modkernel32.NewProc("GetSystemTimes") ProcNtQuerySystemInformation = ModNt.NewProc("NtQuerySystemInformation") PdhOpenQuery = ModPdh.NewProc("PdhOpenQuery") @@ -110,3 +112,18 @@ func CreateCounter(query windows.Handle, pname, cname string) (*CounterInfo, err Counter: counter, }, nil } + +// WMIQueryWithContext - wraps wmi.Query with a timed-out context to avoid hanging +func WMIQueryWithContext(ctx context.Context, query string, dst interface{}, connectServerArgs ...interface{}) error { + errChan := make(chan error, 1) + go func() { + errChan <- wmi.Query(query, dst, connectServerArgs...) + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errChan: + return err + } +} diff --git a/process/process_windows.go b/process/process_windows.go index bbd7762..be91f97 100644 --- a/process/process_windows.go +++ b/process/process_windows.go @@ -3,6 +3,7 @@ package process import ( + "context" "fmt" "strings" "syscall" @@ -130,8 +131,10 @@ func GetWin32Proc(pid int32) ([]Win32_Process, error) { var dst []Win32_Process query := fmt.Sprintf("WHERE ProcessId = %d", pid) q := wmi.CreateQuery(&dst, query) - - if err := wmi.Query(q, &dst); err != nil { + ctx, cancel := context.WithTimeout(context.Background(), common.Timeout) + defer cancel() + err := common.WMIQueryWithContext(ctx, q, &dst) + if err != nil { return []Win32_Process{}, fmt.Errorf("could not get win32Proc: %s", err) } @@ -333,7 +336,9 @@ func (p *Process) MemoryInfoEx() (*MemoryInfoExStat, error) { func (p *Process) Children() ([]*Process, error) { var dst []Win32_Process query := wmi.CreateQuery(&dst, fmt.Sprintf("Where ParentProcessId = %d", p.Pid)) - err := wmi.Query(query, &dst) + ctx, cancel := context.WithTimeout(context.Background(), common.Timeout) + defer cancel() + err := common.WMIQueryWithContext(ctx, query, &dst) if err != nil { return nil, err }