diff --git a/src/net/interface_linux.go b/src/net/interface_linux.go index 7856dae8fc878b..4561382ce72d90 100644 --- a/src/net/interface_linux.go +++ b/src/net/interface_linux.go @@ -6,6 +6,7 @@ package net import ( "os" + "sync" "syscall" "unsafe" ) @@ -121,7 +122,11 @@ func linkFlags(rawFlags uint32) Flags { // network interfaces. Otherwise it returns addresses for a specific // interface. func interfaceAddrTable(ifi *Interface) ([]Addr, error) { - tab, err := syscall.NetlinkRIB(syscall.RTM_GETADDR, syscall.AF_UNSPEC) + var ifindex int + if ifi != nil { + ifindex = ifi.Index + } + tab, err := netlinkRIB(getAddrRequest(syscall.AF_UNSPEC, ifindex)) if err != nil { return nil, os.NewSyscallError("netlinkrib", err) } @@ -255,3 +260,87 @@ func parseProcNetIGMP6(path string, ifi *Interface) []Addr { } return ifmat } + +func getAddrRequest(family, ifindex int) []byte { + reqLen := syscall.NLMSG_HDRLEN + syscall.SizeofIfAddrmsg + req := make([]byte, reqLen) + h := (*syscall.NlMsghdr)(unsafe.Pointer(&req[0])) + h.Len = uint32(reqLen) + h.Flags = uint16(syscall.NLM_F_REQUEST | syscall.NLM_F_DUMP) + h.Type = uint16(syscall.RTM_GETADDR) + h.Seq = uint32(1) + iam := (*syscall.IfAddrmsg)(unsafe.Pointer(&req[syscall.NLMSG_HDRLEN])) + iam.Family = uint8(family) + iam.Index = uint32(ifindex) + return req +} + +var pageBufPool = &sync.Pool{New: func() any { + b := make([]byte, syscall.Getpagesize()) + return &b +}} + +// These constants aren't in the syscall package, which is frozen. +// Values taken from golang.org/x/sys/unix. +const ( + _SOL_NETLINK = 0x10e + _NETLINK_GET_STRICT_CHK = 0x0c +) + +// Modified version of syscall.NetlinkRIB that sets NETLINK_GET_STRICT_CHK and +// sends request bytes directly to the socket. +func netlinkRIB(req []byte) ([]byte, error) { + s, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW|syscall.SOCK_CLOEXEC, syscall.NETLINK_ROUTE) + if err != nil { + return nil, err + } + defer syscall.Close(s) + _ = syscall.SetsockoptInt(s, _SOL_NETLINK, _NETLINK_GET_STRICT_CHK, 1) + sa := &syscall.SockaddrNetlink{Family: syscall.AF_NETLINK} + if err := syscall.Bind(s, sa); err != nil { + return nil, err + } + if err := syscall.Sendto(s, req, 0, sa); err != nil { + return nil, err + } + lsa, err := syscall.Getsockname(s) + if err != nil { + return nil, err + } + lsanl, ok := lsa.(*syscall.SockaddrNetlink) + if !ok { + return nil, syscall.EINVAL + } + var tab []byte + rbNew := pageBufPool.Get().(*[]byte) + defer pageBufPool.Put(rbNew) +done: + for { + rb := *rbNew + nr, _, err := syscall.Recvfrom(s, rb, 0) + if err != nil { + return nil, err + } + if nr < syscall.NLMSG_HDRLEN { + return nil, syscall.EINVAL + } + rb = rb[:nr] + tab = append(tab, rb...) + msgs, err := syscall.ParseNetlinkMessage(rb) + if err != nil { + return nil, err + } + for _, m := range msgs { + if m.Header.Seq != 1 || m.Header.Pid != lsanl.Pid { + return nil, syscall.EINVAL + } + if m.Header.Type == syscall.NLMSG_DONE { + break done + } + if m.Header.Type == syscall.NLMSG_ERROR { + return nil, syscall.EINVAL + } + } + } + return tab, nil +}