// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package hostinet

import (
	"fmt"

	"golang.org/x/sys/unix"
	"gvisor.dev/gvisor/pkg/abi/linux"
	"gvisor.dev/gvisor/pkg/context"
	"gvisor.dev/gvisor/pkg/errors/linuxerr"
	"gvisor.dev/gvisor/pkg/fdnotifier"
	"gvisor.dev/gvisor/pkg/hostarch"
	"gvisor.dev/gvisor/pkg/log"
	"gvisor.dev/gvisor/pkg/marshal"
	"gvisor.dev/gvisor/pkg/marshal/primitive"
	"gvisor.dev/gvisor/pkg/safemem"
	"gvisor.dev/gvisor/pkg/sentry/arch"
	"gvisor.dev/gvisor/pkg/sentry/fs"
	"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
	"gvisor.dev/gvisor/pkg/sentry/kernel"
	ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
	"gvisor.dev/gvisor/pkg/sentry/socket"
	"gvisor.dev/gvisor/pkg/sentry/socket/control"
	"gvisor.dev/gvisor/pkg/syserr"
	"gvisor.dev/gvisor/pkg/usermem"
	"gvisor.dev/gvisor/pkg/waiter"
)

const (
	// sizeofInt32 is the size in bytes of an int32, used as the length of
	// integer-valued socket options.
	sizeofInt32 = 4

	// sizeofSockaddr is the size in bytes of the largest sockaddr type
	// supported by this package.
	sizeofSockaddr = unix.SizeofSockaddrInet6 // sizeof(sockaddr_in6) > sizeof(sockaddr_in)

	// maxControlLen is the maximum size of a control message buffer used in a
	// recvmsg or sendmsg syscall.
	maxControlLen = 1024
)

// LINT.IfChange

// socketOperations implements fs.FileOperations and socket.Socket for a socket
// implemented using a host socket.
//
// The socket-specific state and behavior live in the embedded socketOpsCommon.
// The embedded fsutil types are not saved (`state:"nosave"`); presumably they
// supply the remaining fs.FileOperations methods that are not defined in this
// file (TODO confirm against the fsutil package).
type socketOperations struct {
	fsutil.FilePipeSeek             `state:"nosave"`
	fsutil.FileNotDirReaddir        `state:"nosave"`
	fsutil.FileNoFsync              `state:"nosave"`
	fsutil.FileNoMMap               `state:"nosave"`
	fsutil.FileNoSplice             `state:"nosave"`
	fsutil.FileNoopFlush            `state:"nosave"`
	fsutil.FileUseInodeUnstableAttr `state:"nosave"`

	socketOpsCommon
}

// Compile-time assertion that *socketOperations satisfies socket.Socket.
var _ = socket.Socket(&socketOperations{})

// newSocketFile wraps the host socket fd in an fs.File whose operations are
// provided by socketOperations.
//
// The fd is registered with fdnotifier so that host readiness events are
// delivered to the socket's wait queue. If that registration fails, an error
// is returned and fd is left open; closing it is up to the caller. The
// returned file is opened for read and write, is non-seekable, and is
// non-blocking iff nonblock is set.
func newSocketFile(ctx context.Context, family int, stype linux.SockType, protocol int, fd int, nonblock bool) (*fs.File, *syserr.Error) {
	ops := new(socketOperations)
	common := &ops.socketOpsCommon
	common.family = family
	common.stype = stype
	common.protocol = protocol
	common.fd = fd

	if err := fdnotifier.AddFD(int32(fd), &common.queue); err != nil {
		return nil, syserr.FromError(err)
	}

	flags := fs.FileFlags{
		NonBlocking: nonblock,
		Read:        true,
		Write:       true,
		NonSeekable: true,
	}
	// The file holds its own reference on the dirent, so ours is dropped on
	// return.
	dirent := socket.NewDirent(ctx, socketDevice)
	defer dirent.DecRef(ctx)
	return fs.NewFile(ctx, dirent, flags, ops), nil
}

// Ioctl implements fs.FileOperations.Ioctl.
//
// The request is forwarded to the package-level ioctl helper together with
// the host socket fd; the fs.File argument is unused.
func (s *socketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
	return ioctl(ctx, s.fd, io, args)
}

// Read implements fs.FileOperations.Read.
//
// Data is read from the host socket directly into the memory backing dst.
// The file and offset arguments are ignored. Host EAGAIN/EWOULDBLOCK errors
// on the single-block path are reported as linuxerr.ErrWouldBlock.
func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
	readFromHost := func(blocks safemem.BlockSeq) (uint64, error) {
		// Do nothing unless all of dst.Addrs is usable and there is somewhere
		// to put data.
		if blocks.NumBytes() != uint64(dst.NumBytes()) || blocks.IsEmpty() {
			return 0, nil
		}
		if blocks.NumBlocks() != 1 {
			return readv(s.fd, safemem.IovecsFromBlockSeq(blocks))
		}
		// A lone block is read directly, which skips allocating a
		// []unix.Iovec.
		got, err := unix.Read(s.fd, blocks.Head().ToSlice())
		if err != nil {
			return 0, translateIOSyscallError(err)
		}
		return uint64(got), nil
	}
	return dst.CopyOutFrom(ctx, safemem.ReaderFunc(readFromHost))
}

// Write implements fs.FileOperations.Write.
//
// Data is written to the host socket directly from the memory backing src.
// The file and offset arguments are ignored. Host EAGAIN/EWOULDBLOCK errors
// on the single-block path are reported as linuxerr.ErrWouldBlock.
func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
	writeToHost := func(blocks safemem.BlockSeq) (uint64, error) {
		// Do nothing unless all of src.Addrs is usable and there is something
		// to send.
		if blocks.NumBytes() != uint64(src.NumBytes()) || blocks.IsEmpty() {
			return 0, nil
		}
		if blocks.NumBlocks() != 1 {
			return writev(s.fd, safemem.IovecsFromBlockSeq(blocks))
		}
		// A lone block is written directly, which skips allocating a
		// []unix.Iovec.
		sent, err := unix.Write(s.fd, blocks.Head().ToSlice())
		if err != nil {
			return 0, translateIOSyscallError(err)
		}
		return uint64(sent), nil
	}
	return src.CopyInTo(ctx, safemem.WriterFunc(writeToHost))
}

// Socket implements socket.Provider.Socket.
//
// It returns (nil, nil) when this provider does not handle the request: the
// task is not using the host network stack, or the type/protocol combination
// is anything other than TCP over SOCK_STREAM or UDP over SOCK_DGRAM.
// Otherwise it creates a non-blocking, close-on-exec host socket and wraps it
// in a new file. The host socket is closed again if the file cannot be
// created.
func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*fs.File, *syserr.Error) {
	// Check that we are using the host network stack.
	stack := t.NetworkContext()
	if stack == nil {
		return nil, nil
	}
	if _, ok := stack.(*Stack); !ok {
		return nil, nil
	}

	// Only accept TCP and UDP.
	stype := stypeflags & linux.SOCK_TYPE_MASK
	switch stype {
	case unix.SOCK_STREAM:
		switch protocol {
		case 0, unix.IPPROTO_TCP:
			// ok
		default:
			return nil, nil
		}
	case unix.SOCK_DGRAM:
		switch protocol {
		case 0, unix.IPPROTO_UDP:
			// ok
		default:
			return nil, nil
		}
	default:
		return nil, nil
	}

	// Conservatively ignore all flags specified by the application and add
	// SOCK_NONBLOCK since socketOperations requires it. Pass a protocol of 0
	// to simplify the syscall filters, since 0 and IPPROTO_* are equivalent.
	fd, err := unix.Socket(p.family, int(stype)|unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC, 0)
	if err != nil {
		return nil, syserr.FromError(err)
	}

	file, serr := newSocketFile(t, p.family, stype, protocol, fd, stypeflags&unix.SOCK_NONBLOCK != 0)
	if serr != nil {
		// newSocketFile leaves fd open when it fails, so it must be closed
		// here to avoid leaking the host socket. Nothing useful can be done
		// if close itself fails.
		_ = unix.Close(fd)
		return nil, serr
	}
	return file, nil
}

// Pair implements socket.Provider.Pair.
//
// It always returns nil files and a nil error, without creating anything.
func (p *socketProvider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) {
	// Not supported by AF_INET/AF_INET6.
	return nil, nil, nil
}

// LINT.ThenChange(./socket_vfs2.go)

// socketOpsCommon contains the socket operations common to VFS1 and VFS2.
//
// +stateify savable
type socketOpsCommon struct {
	socket.SendReceiveTimeout

	// family, stype and protocol describe the socket as it was created; they
	// are reported back by Type.
	family   int            // Read-only.
	stype    linux.SockType // Read-only.
	protocol int            // Read-only.

	// queue holds the waiters registered through EventRegister. It is the
	// queue handed to fdnotifier for fd.
	queue waiter.Queue

	// fd is the host socket fd. It must have O_NONBLOCK, so that operations
	// will return EWOULDBLOCK instead of blocking on the host. This allows us to
	// handle blocking behavior independently in the sentry.
	fd int
}

// Release implements fs.FileOperations.Release.
//
// The host fd is removed from fdnotifier and then closed. The result of the
// close is deliberately discarded.
func (s *socketOpsCommon) Release(context.Context) {
	hostFD := s.fd
	fdnotifier.RemoveFD(int32(hostFD))
	_ = unix.Close(hostFD)
}

// Readiness implements waiter.Waitable.Readiness.
//
// It returns the result of a non-blocking poll of the host fd for the events
// in mask.
func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
	return fdnotifier.NonBlockingPoll(int32(s.fd), mask)
}

// EventRegister implements waiter.Waitable.EventRegister.
//
// The entry is added to the socket's queue and fdnotifier is then asked to
// update its state for the host fd. If that update fails, the entry is taken
// back out of the queue and the error is returned.
func (s *socketOpsCommon) EventRegister(e *waiter.Entry) error {
	s.queue.EventRegister(e)
	err := fdnotifier.UpdateFD(int32(s.fd))
	if err == nil {
		return nil
	}
	// Roll back the registration so the queue is unchanged on failure.
	s.queue.EventUnregister(e)
	return err
}

// EventUnregister implements waiter.Waitable.EventUnregister.
//
// The entry is removed from the socket's queue and fdnotifier is then asked
// to update its state for the host fd. A failure of that update panics, as
// this method has no way to report an error.
func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
	s.queue.EventUnregister(e)
	err := fdnotifier.UpdateFD(int32(s.fd))
	if err == nil {
		return
	}
	panic(err)
}

// Connect implements socket.Socket.Connect.
//
// sockaddr is truncated to sizeofSockaddr bytes before being passed to the
// host. When the host reports EINPROGRESS and blocking is set, the task waits
// until the socket becomes writable and the outcome is then read from
// SO_ERROR. Any other host error is returned immediately.
func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
	addr := sockaddr
	if len(addr) > sizeofSockaddr {
		addr = addr[:sizeofSockaddr]
	}

	_, _, errno := unix.Syscall(unix.SYS_CONNECT, uintptr(s.fd), uintptr(firstBytePtr(addr)), uintptr(len(addr)))
	switch {
	case errno == 0:
		return nil
	case !blocking || errno != unix.EINPROGRESS:
		return syserr.FromError(translateIOSyscallError(errno))
	}

	// "EINPROGRESS: The socket is nonblocking and the connection cannot be
	// completed immediately. It is possible to select(2) or poll(2) for
	// completion by selecting the socket for writing. After select(2)
	// indicates writability, use getsockopt(2) to read the SO_ERROR option at
	// level SOL-SOCKET to determine whether connect() completed successfully
	// (SO_ERROR is zero) or unsuccessfully (SO_ERROR is one of the usual error
	// codes listed here, explaining the reason for the failure)." - connect(2)
	mask := waiter.WritableEvents
	entry, ch := waiter.NewChannelEntry(mask)
	s.EventRegister(&entry)
	defer s.EventUnregister(&entry)

	// Only block if the socket is not already writable; readiness is checked
	// after registering so that an event in between is not missed.
	if ready := s.Readiness(mask); ready&mask == 0 {
		if err := t.Block(ch); err != nil {
			return syserr.FromError(err)
		}
	}

	soErr, err := unix.GetsockoptInt(s.fd, unix.SOL_SOCKET, unix.SO_ERROR)
	switch {
	case err != nil:
		return syserr.FromError(err)
	case soErr != 0:
		return syserr.FromError(unix.Errno(uintptr(soErr)))
	}
	return nil
}

// Accept implements socket.Socket.Accept.
//
// It accepts a connection on the host socket, wraps the new host fd in a
// sentry file (VFS2 or VFS1 depending on kernel.VFS2Enabled) and installs it
// in the task's FD table. When peerRequested is set, the peer address filled
// in by the host is unmarshalled and returned along with its length. When
// blocking is set, the task waits for readable events while the host reports
// that the call would block. Of the application's flags, only SOCK_NONBLOCK
// and SOCK_CLOEXEC are honored, and they are applied to the new sentry file
// and FD rather than to the host socket.
func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
	var peerAddr linux.SockAddr
	var peerAddrBuf []byte
	var peerAddrlen uint32
	var peerAddrPtr *byte
	var peerAddrlenPtr *uint32
	if peerRequested {
		// Give the host a buffer large enough for the largest supported
		// sockaddr. The pointers stay nil when no peer address is wanted.
		peerAddrBuf = make([]byte, sizeofSockaddr)
		peerAddrlen = uint32(len(peerAddrBuf))
		peerAddrPtr = &peerAddrBuf[0]
		peerAddrlenPtr = &peerAddrlen
	}

	// Conservatively ignore all flags specified by the application and add
	// SOCK_NONBLOCK since socketOpsCommon requires it.
	fd, syscallErr := accept4(s.fd, peerAddrPtr, peerAddrlenPtr, unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC)
	if blocking {
		var ch chan struct{}
		for syscallErr == linuxerr.ErrWouldBlock {
			if ch != nil {
				if syscallErr = t.Block(ch); syscallErr != nil {
					break
				}
			} else {
				// First pass: register for readable events, then retry the
				// accept once before blocking so that a connection arriving
				// in between is not missed.
				var e waiter.Entry
				e, ch = waiter.NewChannelEntry(waiter.ReadableEvents)
				s.EventRegister(&e)
				defer s.EventUnregister(&e)
			}
			fd, syscallErr = accept4(s.fd, peerAddrPtr, peerAddrlenPtr, unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC)
		}
	}

	if peerRequested {
		peerAddr = socket.UnmarshalSockAddr(s.family, peerAddrBuf[:peerAddrlen])
	}
	if syscallErr != nil {
		return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr)
	}

	var (
		kfd  int32
		kerr error
	)
	if kernel.VFS2Enabled {
		f, err := newVFS2Socket(t, s.family, s.stype, s.protocol, fd, uint32(flags&unix.SOCK_NONBLOCK))
		if err != nil {
			// The new file did not take ownership of fd; close it here.
			_ = unix.Close(fd)
			return 0, nil, 0, err
		}
		defer f.DecRef(t)

		kfd, kerr = t.NewFDFromVFS2(0, f, kernel.FDFlags{
			CloseOnExec: flags&unix.SOCK_CLOEXEC != 0,
		})
		t.Kernel().RecordSocketVFS2(f)
	} else {
		f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&unix.SOCK_NONBLOCK != 0)
		if err != nil {
			// The new file did not take ownership of fd; close it here.
			_ = unix.Close(fd)
			return 0, nil, 0, err
		}
		defer f.DecRef(t)

		kfd, kerr = t.NewFDFrom(0, f, kernel.FDFlags{
			CloseOnExec: flags&unix.SOCK_CLOEXEC != 0,
		})
		t.Kernel().RecordSocket(f)
	}

	return kfd, peerAddr, peerAddrlen, syserr.FromError(kerr)
}

// Bind implements socket.Socket.Bind.
//
// sockaddr is truncated to sizeofSockaddr bytes and handed to the host's
// bind(2) for the underlying fd. The task argument is unused.
func (s *socketOpsCommon) Bind(_ *kernel.Task, sockaddr []byte) *syserr.Error {
	addr := sockaddr
	if len(addr) > sizeofSockaddr {
		addr = addr[:sizeofSockaddr]
	}

	if _, _, errno := unix.Syscall(unix.SYS_BIND, uintptr(s.fd), uintptr(firstBytePtr(addr)), uintptr(len(addr))); errno != 0 {
		return syserr.FromError(errno)
	}
	return nil
}

// Listen implements socket.Socket.Listen.
//
// It calls listen(2) on the host fd with the given backlog and converts the
// result to a *syserr.Error. The task argument is unused.
func (s *socketOpsCommon) Listen(_ *kernel.Task, backlog int) *syserr.Error {
	return syserr.FromError(unix.Listen(s.fd, backlog))
}

// Shutdown implements socket.Socket.Shutdown.
//
// how must be one of SHUT_RD, SHUT_WR or SHUT_RDWR; any other value is
// rejected with ErrInvalidArgument without calling the host. The task
// argument is unused.
func (s *socketOpsCommon) Shutdown(_ *kernel.Task, how int) *syserr.Error {
	if how != unix.SHUT_RD && how != unix.SHUT_WR && how != unix.SHUT_RDWR {
		return syserr.ErrInvalidArgument
	}
	return syserr.FromError(unix.Shutdown(s.fd, how))
}

// GetSockOpt implements socket.Socket.GetSockOpt.
//
// Only options with a known value size are forwarded to the host. The size
// comes from the table below or, for anything not listed there, from
// getSockOptLen. Options with no known size fail with ErrProtocolNotAvailable,
// and an outLen that is negative or smaller than the option's size fails with
// ErrInvalidArgument. The value returned by the host is passed through
// postGetSockOpt and returned as a primitive.ByteSlice.
func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, optValAddr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
	if outLen < 0 {
		return nil, syserr.ErrInvalidArgument
	}

	// Only allow known and safe options.
	size, needsCopyIn := getSockOptLen(t, level, name)
	switch level {
	case linux.SOL_IP:
		switch name {
		case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_TTL, linux.IP_RECVTTL, linux.IP_PKTINFO, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR:
			size = sizeofInt32
		}
	case linux.SOL_IPV6:
		switch name {
		case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_RECVPKTINFO, linux.IPV6_UNICAST_HOPS, linux.IPV6_MULTICAST_HOPS, linux.IPV6_RECVHOPLIMIT, linux.IPV6_RECVERR, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR:
			size = sizeofInt32
		}
	case linux.SOL_SOCKET:
		switch name {
		case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP:
			size = sizeofInt32
		case linux.SO_LINGER:
			size = unix.SizeofLinger
		case linux.SO_RCVTIMEO, linux.SO_SNDTIMEO:
			size = linux.SizeOfTimeval
		}
	case linux.SOL_TCP:
		switch name {
		case linux.TCP_NODELAY:
			size = sizeofInt32
		case linux.TCP_INFO:
			size = linux.SizeOfTCPInfo
		}
	}

	switch {
	case size == 0:
		return nil, syserr.ErrProtocolNotAvailable // ENOPROTOOPT
	case outLen < size:
		return nil, syserr.ErrInvalidArgument
	}

	buf := make([]byte, size)
	if needsCopyIn {
		// getsockopt is normally output-only, but some custom options also
		// read the caller's buffer as input, so it is copied in for those
		// options alone.
		if _, err := t.CopyInBytes(optValAddr, buf); err != nil {
			return nil, syserr.FromError(err)
		}
	}

	res, err := getsockopt(s.fd, level, name, buf)
	if err != nil {
		return nil, syserr.FromError(err)
	}
	out := primitive.ByteSlice(postGetSockOpt(t, level, name, res))
	return &out, nil
}

// SetSockOpt implements socket.Socket.SetSockOpt.
//
// Only options with a known value size are forwarded to the host. The size
// comes from the table below or, for anything not listed there, from
// setSockOptLen. Options with no known size are silently accepted without
// touching the host. A value shorter than the option's size fails with
// ErrInvalidArgument; a longer one is truncated to that size.
func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
	// Only allow known and safe options.
	size := setSockOptLen(t, level, name)
	switch level {
	case linux.SOL_IP:
		switch name {
		case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_TTL, linux.IP_RECVTTL, linux.IP_PKTINFO, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR:
			size = sizeofInt32
		}
	case linux.SOL_IPV6:
		switch name {
		case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_RECVPKTINFO, linux.IPV6_UNICAST_HOPS, linux.IPV6_MULTICAST_HOPS, linux.IPV6_RECVHOPLIMIT, linux.IPV6_RECVERR, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR:
			size = sizeofInt32
		}
	case linux.SOL_SOCKET:
		switch name {
		case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP:
			size = sizeofInt32
		}
	case linux.SOL_TCP:
		switch name {
		case linux.TCP_NODELAY, linux.TCP_INQ:
			size = sizeofInt32
		}
	}

	if size == 0 {
		// Pretend to accept socket options we don't understand. This seems
		// dangerous, but it's what netstack does...
		return nil
	}
	if len(opt) < size {
		return syserr.ErrInvalidArgument
	}
	val := opt[:size]

	if _, _, errno := unix.Syscall6(unix.SYS_SETSOCKOPT, uintptr(s.fd), uintptr(level), uintptr(name), uintptr(firstBytePtr(val)), uintptr(len(val)), 0); errno != 0 {
		return syserr.FromError(errno)
	}
	return nil
}

// recvMsgFromHost performs one non-blocking recvmsg on the host fd into iovs.
//
// A sender address buffer of sizeofSockaddr bytes is supplied only when
// senderRequested is set, and a control buffer of at most maxControlLen bytes
// only when controlLen is non-zero. On success it returns the byte count from
// recvmsg, the message flags set by the host, and the address and control
// buffers trimmed to the lengths the host reported. On failure all results
// other than the error are zero/nil.
func (s *socketOpsCommon) recvMsgFromHost(iovs []unix.Iovec, flags int, senderRequested bool, controlLen uint64) (uint64, int, []byte, []byte, error) {
	var (
		hdr     unix.Msghdr
		nameBuf []byte
		cmsgBuf []byte
	)

	if count := len(iovs); count > 0 {
		hdr.Iov = &iovs[0]
		hdr.Iovlen = uint64(count)
	}
	if senderRequested {
		nameBuf = make([]byte, sizeofSockaddr)
		hdr.Name = &nameBuf[0]
		hdr.Namelen = uint32(sizeofSockaddr)
	}
	if controlLen > 0 {
		if controlLen > maxControlLen {
			controlLen = maxControlLen
		}
		cmsgBuf = make([]byte, controlLen)
		hdr.Control = &cmsgBuf[0]
		hdr.Controllen = controlLen
	}

	// We always do a non-blocking recv*().
	n, err := recvmsg(s.fd, &hdr, flags|unix.MSG_DONTWAIT)
	if err != nil {
		return 0 /* n */, 0 /* mFlags */, nil /* senderAddrBuf */, nil /* controlBuf */, err
	}
	return n, int(hdr.Flags), nameBuf[:hdr.Namelen], cmsgBuf[:hdr.Controllen], nil
}

// RecvMsg implements socket.Socket.RecvMsg.
//
// Only MSG_DONTWAIT, MSG_PEEK, MSG_TRUNC and MSG_ERRQUEUE are accepted; any
// other flag fails with ErrInvalidArgument. The host receive is always
// non-blocking. Unless MSG_DONTWAIT or MSG_ERRQUEUE is set, the task waits
// for readable events (bounded by deadline when haveDeadline is set) and
// retries while the host reports that the call would block. If the deadline
// expires first, the result is ErrWouldBlock rather than ETIMEDOUT, matching
// SendMsg and the EAGAIN/EWOULDBLOCK behavior documented for SO_RCVTIMEO.
//
// On success it returns the number of bytes received, the message flags set
// by the host, the sender address and its length (the address only when
// senderRequested is set), and the parsed control messages.
func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
	// Only allow known and safe flags.
	if flags&^(unix.MSG_DONTWAIT|unix.MSG_PEEK|unix.MSG_TRUNC|unix.MSG_ERRQUEUE) != 0 {
		return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument
	}

	// These are filled in by the most recent host receive performed inside
	// copyToDst.
	var senderAddrBuf []byte
	var controlBuf []byte
	var msgFlags int
	copyToDst := func() (int64, error) {
		var n uint64
		var err error
		if dst.NumBytes() == 0 {
			// We want to make the recvmsg(2) call to the host even if dst is empty
			// to fetch control messages, sender address or errors if any occur.
			n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(nil, flags, senderRequested, controlLen)
			return int64(n), err
		}

		recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
			// Refuse to do anything if any part of dst.Addrs was unusable.
			if uint64(dst.NumBytes()) != dsts.NumBytes() {
				return 0, nil
			}
			if dsts.IsEmpty() {
				return 0, nil
			}

			n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(safemem.IovecsFromBlockSeq(dsts), flags, senderRequested, controlLen)
			return n, err
		})
		return dst.CopyOutFrom(t, recvmsgToBlocks)
	}

	var ch chan struct{}
	n, err := copyToDst()
	// recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT.
	if flags&(unix.MSG_DONTWAIT|unix.MSG_ERRQUEUE) == 0 {
		for err == linuxerr.ErrWouldBlock {
			// We only expect blocking to come from the actual syscall, in which
			// case it can't have returned any data.
			if n != 0 {
				panic(fmt.Sprintf("CopyOutFrom: got (%d, %v), wanted (0, %v)", n, err, err))
			}
			if ch != nil {
				if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
					// A receive timeout is reported to the application as
					// EAGAIN/EWOULDBLOCK, not ETIMEDOUT.
					if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
						err = linuxerr.ErrWouldBlock
					}
					break
				}
			} else {
				// First pass: register for readable events, then retry the
				// receive once before blocking so that data arriving in
				// between is not missed.
				var e waiter.Entry
				e, ch = waiter.NewChannelEntry(waiter.ReadableEvents)
				s.EventRegister(&e)
				defer s.EventUnregister(&e)
			}
			n, err = copyToDst()
		}
	}
	if err != nil {
		return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
	}

	var senderAddr linux.SockAddr
	if senderRequested {
		senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf)
	}

	unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf)
	if err != nil {
		return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
	}
	return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), parseUnixControlMessages(unixControlMessages), nil
}

// parseUnixControlMessages converts control messages received from the host
// into a socket.ControlMessages value.
//
// Each message is dispatched on its header level and type, and its data is
// unmarshalled into the matching field of controlMessages.IP. Messages whose
// level or type is not handled below are ignored. The data length is not
// validated here before unmarshalling.
func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) socket.ControlMessages {
	controlMessages := socket.ControlMessages{}
	for _, unixCmsg := range unixControlMessages {
		switch unixCmsg.Header.Level {
		case linux.SOL_SOCKET:
			switch unixCmsg.Header.Type {
			case linux.SO_TIMESTAMP:
				controlMessages.IP.HasTimestamp = true
				ts := linux.Timeval{}
				ts.UnmarshalUnsafe(unixCmsg.Data)
				controlMessages.IP.Timestamp = ts.ToTime()
			}

		case linux.SOL_IP:
			switch unixCmsg.Header.Type {
			case linux.IP_TOS:
				controlMessages.IP.HasTOS = true
				var tos primitive.Uint8
				tos.UnmarshalUnsafe(unixCmsg.Data)
				controlMessages.IP.TOS = uint8(tos)

			case linux.IP_TTL:
				controlMessages.IP.HasTTL = true
				var ttl primitive.Uint32
				ttl.UnmarshalUnsafe(unixCmsg.Data)
				controlMessages.IP.TTL = uint32(ttl)

			case linux.IP_PKTINFO:
				controlMessages.IP.HasIPPacketInfo = true
				var packetInfo linux.ControlMessageIPPacketInfo
				packetInfo.UnmarshalUnsafe(unixCmsg.Data)
				controlMessages.IP.PacketInfo = packetInfo

			case linux.IP_RECVORIGDSTADDR:
				// Stored as a pointer; there is no separate "has" flag.
				var addr linux.SockAddrInet
				addr.UnmarshalUnsafe(unixCmsg.Data)
				controlMessages.IP.OriginalDstAddress = &addr

			case unix.IP_RECVERR:
				var errCmsg linux.SockErrCMsgIPv4
				errCmsg.UnmarshalBytes(unixCmsg.Data)
				controlMessages.IP.SockErr = &errCmsg
			}

		case linux.SOL_IPV6:
			switch unixCmsg.Header.Type {
			case linux.IPV6_TCLASS:
				controlMessages.IP.HasTClass = true
				var tclass primitive.Uint32
				tclass.UnmarshalUnsafe(unixCmsg.Data)
				controlMessages.IP.TClass = uint32(tclass)

			case linux.IPV6_PKTINFO:
				controlMessages.IP.HasIPv6PacketInfo = true
				var packetInfo linux.ControlMessageIPv6PacketInfo
				packetInfo.UnmarshalUnsafe(unixCmsg.Data)
				controlMessages.IP.IPv6PacketInfo = packetInfo

			case linux.IPV6_HOPLIMIT:
				controlMessages.IP.HasHopLimit = true
				var hoplimit primitive.Uint32
				hoplimit.UnmarshalUnsafe(unixCmsg.Data)
				controlMessages.IP.HopLimit = uint32(hoplimit)

			case linux.IPV6_RECVORIGDSTADDR:
				// Shares the OriginalDstAddress field with the IPv4 case.
				var addr linux.SockAddrInet6
				addr.UnmarshalUnsafe(unixCmsg.Data)
				controlMessages.IP.OriginalDstAddress = &addr

			case unix.IPV6_RECVERR:
				// Shares the SockErr field with the IPv4 case.
				var errCmsg linux.SockErrCMsgIPv6
				errCmsg.UnmarshalBytes(unixCmsg.Data)
				controlMessages.IP.SockErr = &errCmsg
			}

		case linux.SOL_TCP:
			switch unixCmsg.Header.Type {
			case linux.TCP_INQ:
				controlMessages.IP.HasInq = true
				var inq primitive.Int32
				inq.UnmarshalUnsafe(unixCmsg.Data)
				controlMessages.IP.Inq = int32(inq)
			}
		}
	}
	return controlMessages
}

// SendMsg implements socket.Socket.SendMsg.
//
// Only MSG_DONTWAIT, MSG_EOR, MSG_FASTOPEN, MSG_MORE and MSG_NOSIGNAL are
// accepted; any other flag fails with ErrInvalidArgument. The host send is
// always non-blocking. Unless MSG_DONTWAIT is set, the task waits for
// writable events (bounded by deadline when haveDeadline is set) and retries
// while the host reports that the call would block. A deadline expiry is
// reported as ErrWouldBlock.
//
// to is the optional destination address. controlMessages are packed into a
// buffer of at most maxControlLen bytes and sent along with the data. A
// zero-length src takes a separate path that issues a single sendto(2) and
// neither blocks nor sends control messages.
//
// It returns the number of bytes sent.
func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
	// Only allow known and safe flags.
	if flags&^(unix.MSG_DONTWAIT|unix.MSG_EOR|unix.MSG_FASTOPEN|unix.MSG_MORE|unix.MSG_NOSIGNAL) != 0 {
		return 0, syserr.ErrInvalidArgument
	}

	// If the src is zero-length, call SENDTO directly with a null buffer in
	// order to generate poll/epoll notifications.
	if src.NumBytes() == 0 {
		sysflags := flags | unix.MSG_DONTWAIT
		n, _, errno := unix.Syscall6(unix.SYS_SENDTO, uintptr(s.fd), 0, 0, uintptr(sysflags), uintptr(firstBytePtr(to)), uintptr(len(to)))
		if errno != 0 {
			return 0, syserr.FromError(errno)
		}
		return int(n), nil
	}

	// Cap the control buffer at maxControlLen.
	space := uint64(control.CmsgsSpace(t, controlMessages))
	if space > maxControlLen {
		space = maxControlLen
	}
	controlBuf := make([]byte, 0, space)
	// PackControlMessages will append up to space bytes to controlBuf.
	controlBuf = control.PackControlMessages(t, controlMessages, controlBuf)

	sendmsgFromBlocks := safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) {
		// Refuse to do anything if any part of src.Addrs was unusable.
		if uint64(src.NumBytes()) != srcs.NumBytes() {
			return 0, nil
		}
		if srcs.IsEmpty() && len(controlBuf) == 0 {
			return 0, nil
		}

		// We always do a non-blocking send*().
		sysflags := flags | unix.MSG_DONTWAIT

		if srcs.NumBlocks() == 1 && len(controlBuf) == 0 {
			// Skip allocating []unix.Iovec.
			src := srcs.Head()
			n, _, errno := unix.Syscall6(unix.SYS_SENDTO, uintptr(s.fd), src.Addr(), uintptr(src.Len()), uintptr(sysflags), uintptr(firstBytePtr(to)), uintptr(len(to)))
			if errno != 0 {
				return 0, translateIOSyscallError(errno)
			}
			return uint64(n), nil
		}

		// General case: multiple blocks and/or control data require
		// sendmsg(2).
		iovs := safemem.IovecsFromBlockSeq(srcs)
		msg := unix.Msghdr{
			Iov:    &iovs[0],
			Iovlen: uint64(len(iovs)),
		}
		if len(to) != 0 {
			msg.Name = &to[0]
			msg.Namelen = uint32(len(to))
		}
		if len(controlBuf) != 0 {
			msg.Control = &controlBuf[0]
			msg.Controllen = uint64(len(controlBuf))
		}
		return sendmsg(s.fd, &msg, sysflags)
	})

	var ch chan struct{}
	n, err := src.CopyInTo(t, sendmsgFromBlocks)
	if flags&unix.MSG_DONTWAIT == 0 {
		for err == linuxerr.ErrWouldBlock {
			// We only expect blocking to come from the actual syscall, in which
			// case it can't have returned any data.
			if n != 0 {
				panic(fmt.Sprintf("CopyInTo: got (%d, %v), wanted (0, %v)", n, err, err))
			}
			if ch != nil {
				if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
					// A send timeout is reported as ErrWouldBlock rather
					// than ETIMEDOUT.
					if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
						err = linuxerr.ErrWouldBlock
					}
					break
				}
			} else {
				// First pass: register for writable events, then retry the
				// send once before blocking so that an event in between is
				// not missed.
				var e waiter.Entry
				e, ch = waiter.NewChannelEntry(waiter.WritableEvents)
				s.EventRegister(&e)
				defer s.EventUnregister(&e)
			}
			n, err = src.CopyInTo(t, sendmsgFromBlocks)
		}
	}

	return int(n), syserr.FromError(err)
}

// translateIOSyscallError maps the host errors EAGAIN and EWOULDBLOCK to
// linuxerr.ErrWouldBlock, which is the value the blocking loops in this file
// compare against. Every other error, including nil, is returned unchanged.
func translateIOSyscallError(err error) error {
	if err != unix.EAGAIN && err != unix.EWOULDBLOCK {
		return err
	}
	return linuxerr.ErrWouldBlock
}

// State implements socket.Socket.State.
//
// The state is taken from the State field of the host's TCP_INFO for the
// socket. It returns 0 when TCP_INFO cannot be fetched or has an unexpected
// size; a warning is logged in those cases, except for ENOPROTOOPT.
func (s *socketOpsCommon) State() uint32 {
	raw, err := getsockopt(s.fd, unix.SOL_TCP, unix.TCP_INFO, make([]byte, linux.SizeOfTCPInfo))
	if err != nil {
		// For non-TCP sockets, silently ignore the failure.
		if err != unix.ENOPROTOOPT {
			log.Warningf("Failed to get TCP socket info from %+v: %v", s, err)
		}
		return 0
	}
	if got := len(raw); got != linux.SizeOfTCPInfo {
		// Unmarshal below will panic if getsockopt returns a buffer of
		// unexpected size.
		log.Warningf("Failed to get TCP socket info from %+v: getsockopt(2) returned %d bytes, expecting %d bytes.", s, got, linux.SizeOfTCPInfo)
		return 0
	}

	var info linux.TCPInfo
	info.UnmarshalUnsafe(raw[:info.SizeBytes()])
	return uint32(info.State)
}

// Type implements socket.Socket.Type.
//
// It returns the address family, socket type and protocol stored in the
// socket's read-only fields.
func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) {
	return s.family, s.stype, s.protocol
}

// socketProvider creates host-backed sockets for a single address family. Its
// methods in this file implement socket.Provider.
type socketProvider struct {
	// family is the address family passed to the host when creating sockets.
	family int
}

// init registers a VFS1 and a VFS2 socket provider for each of AF_INET and
// AF_INET6.
func init() {
	families := [...]int{unix.AF_INET, unix.AF_INET6}
	for i := range families {
		family := families[i]
		socket.RegisterProvider(family, &socketProvider{family: family})
		socket.RegisterProviderVFS2(family, &socketProviderVFS2{family})
	}
}
