diff --git a/disk/disk_windows.go b/disk/disk_windows.go index 8a1a28d..e17db3e 100644 --- a/disk/disk_windows.go +++ b/disk/disk_windows.go @@ -7,7 +7,6 @@ import ( "bytes" "context" "fmt" - "sync" "syscall" "unsafe" @@ -51,6 +50,7 @@ func init() { key, err := registry.OpenKey(registry.LOCAL_MACHINE, `SYSTEM\CurrentControlSet\Services\PartMgr`, registry.SET_VALUE) if err == nil { key.SetDWordValue("EnableCounterForIoctl", 1) + key.Close() } } @@ -86,28 +86,22 @@ func PartitionsWithContext(ctx context.Context, all bool) ([]PartitionStat, erro warnings := Warnings{ Verbose: true, } - var ret []PartitionStat - retChan := make(chan []PartitionStat) - errChan := make(chan error) - lpBuffer := make([]byte, 254) - - var waitgrp sync.WaitGroup - waitgrp.Add(1) - defer waitgrp.Done() - - f := func() { - defer func() { - waitgrp.Wait() - // fires when this func and the outside func finishes. - close(errChan) - close(retChan) - }() + + var errLogicalDrives error + retChan := make(chan PartitionStat) + quitChan := make(chan struct{}) + defer close(quitChan) + + getPartitions := func() { + defer close(retChan) + + lpBuffer := make([]byte, 254) diskret, _, err := procGetLogicalDriveStringsW.Call( uintptr(len(lpBuffer)), uintptr(unsafe.Pointer(&lpBuffer[0]))) if diskret == 0 { - errChan <- err + errLogicalDrives = err return } for _, v := range lpBuffer { @@ -153,27 +147,37 @@ func PartitionsWithContext(ctx context.Context, all bool) ([]PartitionStat, erro opts = append(opts, "compress") } - d := PartitionStat{ + select { + case retChan <- PartitionStat{ Mountpoint: path, Device: path, - Fstype: string(bytes.Replace(lpFileSystemNameBuffer, []byte("\x00"), []byte(""), -1)), + Fstype: string(bytes.ReplaceAll(lpFileSystemNameBuffer, []byte("\x00"), []byte(""))), Opts: opts, + }: + case <-quitChan: + return } - ret = append(ret, d) } } } - retChan <- ret } - go f() - select { - case err := <-errChan: - return ret, err - case ret := <-retChan: - return ret, warnings.Reference() - case <-ctx.Done(): - return ret, ctx.Err() + go getPartitions() + + var ret []PartitionStat + for { + select { + case p, ok := <-retChan: + if !ok { + if errLogicalDrives != nil { + return ret, errLogicalDrives + } + return ret, warnings.Reference() + } + ret = append(ret, p) + case <-ctx.Done(): + return ret, ctx.Err() + } } }