diff --git a/net/net_linux.go b/net/net_linux.go index 0bddfe3..bf0bf31 100644 --- a/net/net_linux.go +++ b/net/net_linux.go @@ -292,6 +292,12 @@ func Connections(kind string) ([]ConnectionStat, error) { return ConnectionsPid(kind, 0) } +// Return a list of network connections opened returning at most `max` +// connections for each running process. +func ConnectionsMax(kind string, max int) ([]ConnectionStat, error) { + return ConnectionsPidMax(kind, 0, max) +} + // Return a list of network connections opened by a process. func ConnectionsPid(kind string, pid int32) ([]ConnectionStat, error) { tmap, ok := netConnectionKindMap[kind] @@ -302,9 +308,33 @@ func ConnectionsPid(kind string, pid int32) ([]ConnectionStat, error) { var err error var inodes map[string][]inodeMap if pid == 0 { - inodes, err = getProcInodesAll(root) + inodes, err = getProcInodesAll(root, 0) + } else { + inodes, err = getProcInodes(root, pid, 0) + if len(inodes) == 0 { + // no connection for the pid + return []ConnectionStat{}, nil + } + } + if err != nil { + return nil, fmt.Errorf("cound not get pid(s), %d", pid) + } + return statsFromInodes(root, pid, tmap, inodes) +} + +// Return up to `max` network connections opened by a process. +func ConnectionsPidMax(kind string, pid int32, max int) ([]ConnectionStat, error) { + tmap, ok := netConnectionKindMap[kind] + if !ok { + return nil, fmt.Errorf("invalid kind, %s", kind) + } + root := common.HostProc() + var err error + var inodes map[string][]inodeMap + if pid == 0 { + inodes, err = getProcInodesAll(root, max) } else { - inodes, err = getProcInodes(root, pid) + inodes, err = getProcInodes(root, pid, max) if len(inodes) == 0 { // no connection for the pid return []ConnectionStat{}, nil @@ -313,10 +343,14 @@ func ConnectionsPid(kind string, pid int32) ([]ConnectionStat, error) { if err != nil { return nil, fmt.Errorf("cound not get pid(s), %d", pid) } + return statsFromInodes(root, pid, tmap, inodes) +} +func statsFromInodes(root string, pid int32, tmap []netConnectionKindType, inodes map[string][]inodeMap) ([]ConnectionStat, error) { dupCheckMap := make(map[string]bool) var ret []ConnectionStat + var err error for _, t := range tmap { var path string var ls []connTmp @@ -368,11 +402,15 @@ func ConnectionsPid(kind string, pid int32) ([]ConnectionStat, error) { } // getProcInodes returnes fd of the pid. -func getProcInodes(root string, pid int32) (map[string][]inodeMap, error) { +func getProcInodes(root string, pid int32, max int) (map[string][]inodeMap, error) { ret := make(map[string][]inodeMap) dir := fmt.Sprintf("%s/%d/fd", root, pid) - files, err := ioutil.ReadDir(dir) + f, err := os.Open(dir) + if err != nil { + return ret, nil + } + files, err := f.Readdir(max) if err != nil { return ret, nil } @@ -484,7 +522,7 @@ func (p *process) fillFromStatus() error { return nil } -func getProcInodesAll(root string) (map[string][]inodeMap, error) { +func getProcInodesAll(root string, max int) (map[string][]inodeMap, error) { pids, err := Pids() if err != nil { return nil, err @@ -492,7 +530,7 @@ func getProcInodesAll(root string) (map[string][]inodeMap, error) { ret := make(map[string][]inodeMap) for _, pid := range pids { - t, err := getProcInodes(root, pid) + t, err := getProcInodes(root, pid, max) if err != nil { return ret, err } diff --git a/net/net_linux_test.go b/net/net_linux_test.go index 47faeb9..a881b7a 100644 --- a/net/net_linux_test.go +++ b/net/net_linux_test.go @@ -15,11 +15,32 @@ func TestGetProcInodesAll(t *testing.T) { } root := common.HostProc("") - v, err := getProcInodesAll(root) + v, err := getProcInodesAll(root, 0) assert.Nil(t, err) assert.NotEmpty(t, v) } +func TestConnectionsMax(t *testing.T) { + if os.Getenv("CIRCLECI") == "true" { + t.Skip("Skip CI") + } + + max := 10 + v, err := ConnectionsMax("tcp", max) + assert.Nil(t, err) + assert.NotEmpty(t, v) + + cxByPid := map[int32]int{} + for _, cx := range v { + if cx.Pid > 0 { + cxByPid[cx.Pid]++ + } + } + for _, c := range cxByPid { + assert.True(t, c <= max) + } +} + type AddrTest struct { IP string Port int