Set TCP keepalive for WireGuard gVisor TCP connections

This commit is contained in:
世界 2023-04-26 10:47:32 +08:00
parent f53007cbf3
commit f949ddc0ab
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
2 changed files with 79 additions and 1 deletions

View File

@ -112,7 +112,7 @@ func (w *StackDevice) DialContext(ctx context.Context, network string, destinati
} }
switch N.NetworkName(network) { switch N.NetworkName(network) {
case N.NetworkTCP: case N.NetworkTCP:
tcpConn, err := gonet.DialTCPWithBind(ctx, w.stack, bind, addr, networkProtocol) tcpConn, err := DialTCPWithBind(ctx, w.stack, bind, addr, networkProtocol)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -0,0 +1,78 @@
//go:build with_gvisor
package wireguard
import (
"context"
"errors"
"fmt"
"net"
"time"
M "github.com/sagernet/sing/common/metadata"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
)
func DialTCPWithBind(ctx context.Context, s *stack.Stack, localAddr, remoteAddr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*gonet.TCPConn, error) {
// Create TCP endpoint, then connect.
var wq waiter.Queue
ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
if err != nil {
return nil, errors.New(err.String())
}
// Create wait queue entry that notifies a channel.
//
// We do this unconditionally as Connect will always return an error.
waitEntry, notifyCh := waiter.NewChannelEntry(waiter.WritableEvents)
wq.EventRegister(&waitEntry)
defer wq.EventUnregister(&waitEntry)
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
// Bind before connect if requested.
if localAddr != (tcpip.FullAddress{}) {
if err = ep.Bind(localAddr); err != nil {
return nil, fmt.Errorf("ep.Bind(%+v) = %s", localAddr, err)
}
}
err = ep.Connect(remoteAddr)
if _, ok := err.(*tcpip.ErrConnectStarted); ok {
select {
case <-ctx.Done():
ep.Close()
return nil, ctx.Err()
case <-notifyCh:
}
err = ep.LastError()
}
if err != nil {
ep.Close()
return nil, &net.OpError{
Op: "connect",
Net: "tcp",
Addr: M.SocksaddrFrom(M.AddrFromIP(net.IP(remoteAddr.Addr)), remoteAddr.Port).TCPAddr(),
Err: errors.New(err.String()),
}
}
// sing-box added: set keepalive
ep.SocketOptions().SetKeepAlive(true)
keepAliveIdle := tcpip.KeepaliveIdleOption(15 * time.Second)
ep.SetSockOpt(&keepAliveIdle)
keepAliveInterval := tcpip.KeepaliveIntervalOption(15 * time.Second)
ep.SetSockOpt(&keepAliveInterval)
return gonet.NewTCPConn(&wq, ep), nil
}