From 3a4eeec8f2bf2022d10c984c4e280b19eba5351d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 31 Aug 2023 20:07:32 +0800 Subject: [PATCH] Add hysteria2 protocol --- constant/proxy.go | 3 + inbound/builder.go | 2 + inbound/hysteria2.go | 144 ++++++ inbound/hysteria_stub.go | 4 + option/hysteria2.go | 33 ++ option/inbound.go | 5 + option/outbound.go | 5 + outbound/builder.go | 2 + outbound/hysteria2.go | 122 +++++ outbound/hysteria_stub.go | 4 + test/ech_test.go | 78 ++++ test/hysteria2_test.go | 97 ++++ transport/hysteria2/client.go | 306 ++++++++++++ transport/hysteria2/client_paclet.go | 47 ++ transport/hysteria2/congestion/brutal.go | 151 ++++++ transport/hysteria2/congestion/pacer.go | 86 ++++ transport/hysteria2/internal/protocol/http.go | 68 +++ .../hysteria2/internal/protocol/padding.go | 31 ++ .../hysteria2/internal/protocol/proxy.go | 266 +++++++++++ transport/hysteria2/packet.go | 438 ++++++++++++++++++ transport/hysteria2/salamander.go | 106 +++++ transport/hysteria2/server.go | 336 ++++++++++++++ transport/hysteria2/server_packet.go | 55 +++ 23 files changed, 2389 insertions(+) create mode 100644 inbound/hysteria2.go create mode 100644 option/hysteria2.go create mode 100644 outbound/hysteria2.go create mode 100644 test/hysteria2_test.go create mode 100644 transport/hysteria2/client.go create mode 100644 transport/hysteria2/client_paclet.go create mode 100644 transport/hysteria2/congestion/brutal.go create mode 100644 transport/hysteria2/congestion/pacer.go create mode 100644 transport/hysteria2/internal/protocol/http.go create mode 100644 transport/hysteria2/internal/protocol/padding.go create mode 100644 transport/hysteria2/internal/protocol/proxy.go create mode 100644 transport/hysteria2/packet.go create mode 100644 transport/hysteria2/salamander.go create mode 100644 transport/hysteria2/server.go create mode 100644 transport/hysteria2/server_packet.go diff --git a/constant/proxy.go b/constant/proxy.go index 2b9d8945..1e9baee2 100644 --- a/constant/proxy.go +++ b/constant/proxy.go @@ -22,6 +22,7 @@ const ( TypeShadowsocksR = "shadowsocksr" TypeVLESS = "vless" TypeTUIC = "tuic" + TypeHysteria2 = "hysteria2" ) const ( @@ -65,6 +66,8 @@ func ProxyDisplayName(proxyType string) string { return "VLESS" case TypeTUIC: return "TUIC" + case TypeHysteria2: + return "Hysteria2" case TypeSelector: return "Selector" case TypeURLTest: diff --git a/inbound/builder.go b/inbound/builder.go index 4cd466af..513b016f 100644 --- a/inbound/builder.go +++ b/inbound/builder.go @@ -46,6 +46,8 @@ func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, o return NewVLESS(ctx, router, logger, options.Tag, options.VLESSOptions) case C.TypeTUIC: return NewTUIC(ctx, router, logger, options.Tag, options.TUICOptions) + case C.TypeHysteria2: + return NewHysteria2(ctx, router, logger, options.Tag, options.Hysteria2Options) default: return nil, E.New("unknown inbound type: ", options.Type) } diff --git a/inbound/hysteria2.go b/inbound/hysteria2.go new file mode 100644 index 00000000..4b7bb9a1 --- /dev/null +++ b/inbound/hysteria2.go @@ -0,0 +1,144 @@ +//go:build with_quic + +package inbound + +import ( + "context" + "net" + "net/http" + "net/http/httputil" + "net/url" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/tls" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/transport/hysteria2" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/auth" + E "github.com/sagernet/sing/common/exceptions" + N "github.com/sagernet/sing/common/network" +) + +var _ adapter.Inbound = (*Hysteria2)(nil) + +type Hysteria2 struct { + myInboundAdapter + tlsConfig tls.ServerConfig + server *hysteria2.Server +} + +func NewHysteria2(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.Hysteria2InboundOptions) (*Hysteria2, error) { + if options.TLS == nil || !options.TLS.Enabled { + return nil, C.ErrTLSRequired + } + tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS)) + if err != nil { + return nil, err + } + var salamanderPassword string + if options.Obfs != nil { + if options.Obfs.Password == "" { + return nil, E.New("missing obfs password") + } + switch options.Obfs.Type { + case hysteria2.ObfsTypeSalamander: + salamanderPassword = options.Obfs.Password + default: + return nil, E.New("unknown obfs type: ", options.Obfs.Type) + } + } + var masqueradeHandler http.Handler + if options.Masquerade != "" { + masqueradeURL, err := url.Parse(options.Masquerade) + if err != nil { + return nil, E.Cause(err, "parse masquerade URL") + } + switch masqueradeURL.Scheme { + case "file": + masqueradeHandler = http.FileServer(http.Dir(masqueradeURL.Path)) + case "http", "https": + masqueradeHandler = &httputil.ReverseProxy{ + Rewrite: func(r *httputil.ProxyRequest) { + r.SetURL(masqueradeURL) + r.Out.Host = r.In.Host + }, + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + w.WriteHeader(http.StatusBadGateway) + }, + } + default: + return nil, E.New("unknown masquerade URL scheme: ", masqueradeURL.Scheme) + } + } + inbound := &Hysteria2{ + myInboundAdapter: myInboundAdapter{ + protocol: C.TypeHysteria2, + network: []string{N.NetworkUDP}, + ctx: ctx, + router: router, + logger: logger, + tag: tag, + listenOptions: options.ListenOptions, + }, + tlsConfig: tlsConfig, + } + server, err := hysteria2.NewServer(hysteria2.ServerOptions{ + Context: ctx, + Logger: logger, + SendBPS: uint64(options.UpMbps * 1024 * 1024), + ReceiveBPS: uint64(options.DownMbps * 1024 * 1024), + SalamanderPassword: salamanderPassword, + TLSConfig: tlsConfig, + Users: common.Map(options.Users, func(it option.Hysteria2User) hysteria2.User { + return hysteria2.User(it) + }), + IgnoreClientBandwidth: options.IgnoreClientBandwidth, + Handler: adapter.NewUpstreamHandler(adapter.InboundContext{}, inbound.newConnection, inbound.newPacketConnection, nil), + MasqueradeHandler: masqueradeHandler, + }) + if err != nil { + return nil, err + } + inbound.server = server + return inbound, nil +} + +func (h *Hysteria2) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { + ctx = log.ContextWithNewID(ctx) + h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) + metadata = h.createMetadata(conn, metadata) + metadata.User, _ = auth.UserFromContext[string](ctx) + return h.router.RouteConnection(ctx, conn, metadata) +} + +func (h *Hysteria2) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + ctx = log.ContextWithNewID(ctx) + metadata = h.createPacketMetadata(conn, metadata) + metadata.User, _ = auth.UserFromContext[string](ctx) + h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination) + return h.router.RoutePacketConnection(ctx, conn, metadata) +} + +func (h *Hysteria2) Start() error { + if h.tlsConfig != nil { + err := h.tlsConfig.Start() + if err != nil { + return err + } + } + packetConn, err := h.myInboundAdapter.ListenUDP() + if err != nil { + return err + } + return h.server.Start(packetConn) +} + +func (h *Hysteria2) Close() error { + return common.Close( + &h.myInboundAdapter, + h.tlsConfig, + common.PtrOrNil(h.server), + ) +} diff --git a/inbound/hysteria_stub.go b/inbound/hysteria_stub.go index 1a56a5b6..fab86bb5 100644 --- a/inbound/hysteria_stub.go +++ b/inbound/hysteria_stub.go @@ -14,3 +14,7 @@ import ( func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaInboundOptions) (adapter.Inbound, error) { return nil, C.ErrQUICNotIncluded } + +func NewHysteria2(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.Hysteria2InboundOptions) (adapter.Inbound, error) { + return nil, C.ErrQUICNotIncluded +} diff --git a/option/hysteria2.go b/option/hysteria2.go new file mode 100644 index 00000000..48396ca1 --- /dev/null +++ b/option/hysteria2.go @@ -0,0 +1,33 @@ +package option + +type Hysteria2InboundOptions struct { + ListenOptions + UpMbps int `json:"up_mbps,omitempty"` + DownMbps int `json:"down_mbps,omitempty"` + Obfs *Hysteria2Obfs `json:"obfs,omitempty"` + Users []Hysteria2User `json:"users,omitempty"` + IgnoreClientBandwidth bool `json:"ignore_client_bandwidth,omitempty"` + TLS *InboundTLSOptions `json:"tls,omitempty"` + Masquerade string `json:"masquerade,omitempty"` +} + +type Hysteria2Obfs struct { + Type string `json:"type,omitempty"` + Password string `json:"password,omitempty"` +} + +type Hysteria2User struct { + Name string `json:"name,omitempty"` + Password string `json:"password,omitempty"` +} + +type Hysteria2OutboundOptions struct { + DialerOptions + ServerOptions + UpMbps int `json:"up_mbps,omitempty"` + DownMbps int `json:"down_mbps,omitempty"` + Obfs *Hysteria2Obfs `json:"obfs,omitempty"` + Password string `json:"password,omitempty"` + Network NetworkList `json:"network,omitempty"` + TLS *OutboundTLSOptions `json:"tls,omitempty"` +} diff --git a/option/inbound.go b/option/inbound.go index 64b45e6c..06408f58 100644 --- a/option/inbound.go +++ b/option/inbound.go @@ -24,6 +24,7 @@ type _Inbound struct { ShadowTLSOptions ShadowTLSInboundOptions `json:"-"` VLESSOptions VLESSInboundOptions `json:"-"` TUICOptions TUICInboundOptions `json:"-"` + Hysteria2Options Hysteria2InboundOptions `json:"-"` } type Inbound _Inbound @@ -61,6 +62,8 @@ func (h Inbound) MarshalJSON() ([]byte, error) { v = h.VLESSOptions case C.TypeTUIC: v = h.TUICOptions + case C.TypeHysteria2: + v = h.Hysteria2Options default: return nil, E.New("unknown inbound type: ", h.Type) } @@ -104,6 +107,8 @@ func (h *Inbound) UnmarshalJSON(bytes []byte) error { v = &h.VLESSOptions case C.TypeTUIC: v = &h.TUICOptions + case C.TypeHysteria2: + v = &h.Hysteria2Options default: return E.New("unknown inbound type: ", h.Type) } diff --git a/option/outbound.go b/option/outbound.go index 5e837741..3d780ef4 100644 --- a/option/outbound.go +++ b/option/outbound.go @@ -24,6 +24,7 @@ type _Outbound struct { ShadowsocksROptions ShadowsocksROutboundOptions `json:"-"` VLESSOptions VLESSOutboundOptions `json:"-"` TUICOptions TUICOutboundOptions `json:"-"` + Hysteria2Options Hysteria2OutboundOptions `json:"-"` SelectorOptions SelectorOutboundOptions `json:"-"` URLTestOptions URLTestOutboundOptions `json:"-"` } @@ -63,6 +64,8 @@ func (h Outbound) MarshalJSON() ([]byte, error) { v = h.VLESSOptions case C.TypeTUIC: v = h.TUICOptions + case C.TypeHysteria2: + v = h.Hysteria2Options case C.TypeSelector: v = h.SelectorOptions case C.TypeURLTest: @@ -110,6 +113,8 @@ func (h *Outbound) UnmarshalJSON(bytes []byte) error { v = &h.VLESSOptions case C.TypeTUIC: v = &h.TUICOptions + case C.TypeHysteria2: + v = &h.Hysteria2Options case C.TypeSelector: v = &h.SelectorOptions case C.TypeURLTest: diff --git a/outbound/builder.go b/outbound/builder.go index 92bdef27..141758d8 100644 --- a/outbound/builder.go +++ b/outbound/builder.go @@ -53,6 +53,8 @@ func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, t return NewVLESS(ctx, router, logger, tag, options.VLESSOptions) case C.TypeTUIC: return NewTUIC(ctx, router, logger, tag, options.TUICOptions) + case C.TypeHysteria2: + return NewHysteria2(ctx, router, logger, tag, options.Hysteria2Options) case C.TypeSelector: return NewSelector(router, logger, tag, options.SelectorOptions) case C.TypeURLTest: diff --git a/outbound/hysteria2.go b/outbound/hysteria2.go new file mode 100644 index 00000000..f974e9a8 --- /dev/null +++ b/outbound/hysteria2.go @@ -0,0 +1,122 @@ +//go:build with_quic + +package outbound + +import ( + "context" + "net" + "os" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" + "github.com/sagernet/sing-box/common/tls" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/transport/hysteria2" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +var ( + _ adapter.Outbound = (*TUIC)(nil) + _ adapter.InterfaceUpdateListener = (*TUIC)(nil) +) + +type Hysteria2 struct { + myOutboundAdapter + client *hysteria2.Client +} + +func NewHysteria2(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.Hysteria2OutboundOptions) (*Hysteria2, error) { + options.UDPFragmentDefault = true + if options.TLS == nil || !options.TLS.Enabled { + return nil, C.ErrTLSRequired + } + tlsConfig, err := tls.NewClient(ctx, options.Server, common.PtrValueOrDefault(options.TLS)) + if err != nil { + return nil, err + } + var salamanderPassword string + if options.Obfs != nil { + if options.Obfs.Password == "" { + return nil, E.New("missing obfs password") + } + switch options.Obfs.Type { + case hysteria2.ObfsTypeSalamander: + salamanderPassword = options.Obfs.Password + default: + return nil, E.New("unknown obfs type: ", options.Obfs.Type) + } + } + outboundDialer, err := dialer.New(router, options.DialerOptions) + if err != nil { + return nil, err + } + networkList := options.Network.Build() + client, err := hysteria2.NewClient(hysteria2.ClientOptions{ + Context: ctx, + Dialer: outboundDialer, + ServerAddress: options.ServerOptions.Build(), + SendBPS: uint64(options.UpMbps * 1024 * 1024), + ReceiveBPS: uint64(options.DownMbps * 1024 * 1024), + SalamanderPassword: salamanderPassword, + Password: options.Password, + TLSConfig: tlsConfig, + UDPDisabled: !common.Contains(networkList, N.NetworkUDP), + }) + if err != nil { + return nil, err + } + return &Hysteria2{ + myOutboundAdapter: myOutboundAdapter{ + protocol: C.TypeHysteria2, + network: networkList, + router: router, + logger: logger, + tag: tag, + dependencies: withDialerDependency(options.DialerOptions), + }, + client: client, + }, nil +} + +func (h *Hysteria2) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + switch N.NetworkName(network) { + case N.NetworkTCP: + h.logger.InfoContext(ctx, "outbound connection to ", destination) + return h.client.DialConn(ctx, destination) + case N.NetworkUDP: + conn, err := h.ListenPacket(ctx, destination) + if err != nil { + return nil, err + } + return bufio.NewBindPacketConn(conn, destination), nil + default: + return nil, E.New("unsupported network: ", network) + } +} + +func (h *Hysteria2) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + h.logger.InfoContext(ctx, "outbound packet connection to ", destination) + return h.client.ListenPacket(ctx) +} + +func (h *Hysteria2) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { + return NewConnection(ctx, h, conn, metadata) +} + +func (h *Hysteria2) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + return NewPacketConnection(ctx, h, conn, metadata) +} + +func (h *Hysteria2) InterfaceUpdated() error { + return h.client.CloseWithError(E.New("network changed")) +} + +func (h *Hysteria2) Close() error { + return h.client.CloseWithError(os.ErrClosed) +} diff --git a/outbound/hysteria_stub.go b/outbound/hysteria_stub.go index 62fae20c..84db5305 100644 --- a/outbound/hysteria_stub.go +++ b/outbound/hysteria_stub.go @@ -14,3 +14,7 @@ import ( func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaOutboundOptions) (adapter.Outbound, error) { return nil, C.ErrQUICNotIncluded } + +func NewHysteria2(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.Hysteria2OutboundOptions) (adapter.Outbound, error) { + return nil, C.ErrQUICNotIncluded +} diff --git a/test/ech_test.go b/test/ech_test.go index a792d8c1..c533c5b8 100644 --- a/test/ech_test.go +++ b/test/ech_test.go @@ -168,3 +168,81 @@ func TestECHQUIC(t *testing.T) { }) testSuitLargeUDP(t, clientPort, testPort) } + +func TestECHHysteria2(t *testing.T) { + _, certPem, keyPem := createSelfSignedCertificate(t, "example.org") + echConfig, echKey := common.Must2(tls.ECHKeygenDefault("not.example.org", false)) + startInstance(t, option.Options{ + Inbounds: []option.Inbound{ + { + Type: C.TypeMixed, + Tag: "mixed-in", + MixedOptions: option.HTTPMixedInboundOptions{ + ListenOptions: option.ListenOptions{ + Listen: option.NewListenAddress(netip.IPv4Unspecified()), + ListenPort: clientPort, + }, + }, + }, + { + Type: C.TypeHysteria2, + Hysteria2Options: option.Hysteria2InboundOptions{ + ListenOptions: option.ListenOptions{ + Listen: option.NewListenAddress(netip.IPv4Unspecified()), + ListenPort: serverPort, + }, + Users: []option.Hysteria2User{{ + Password: "password", + }}, + TLS: &option.InboundTLSOptions{ + Enabled: true, + ServerName: "example.org", + CertificatePath: certPem, + KeyPath: keyPem, + ECH: &option.InboundECHOptions{ + Enabled: true, + Key: []string{echKey}, + }, + }, + }, + }, + }, + Outbounds: []option.Outbound{ + { + Type: C.TypeDirect, + }, + { + Type: C.TypeHysteria2, + Tag: "hy2-out", + Hysteria2Options: option.Hysteria2OutboundOptions{ + ServerOptions: option.ServerOptions{ + Server: "127.0.0.1", + ServerPort: serverPort, + }, + Password: "password", + TLS: &option.OutboundTLSOptions{ + Enabled: true, + ServerName: "example.org", + CertificatePath: certPem, + ECH: &option.OutboundECHOptions{ + Enabled: true, + Config: []string{echConfig}, + }, + }, + }, + }, + }, + Route: &option.RouteOptions{ + Rules: []option.Rule{ + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultRule{ + Inbound: []string{"mixed-in"}, + Outbound: "hy2-out", + }, + }, + }, + }, + }) + testSuit(t, clientPort, testPort) +} diff --git a/test/hysteria2_test.go b/test/hysteria2_test.go new file mode 100644 index 00000000..22e9c164 --- /dev/null +++ b/test/hysteria2_test.go @@ -0,0 +1,97 @@ +package main + +import ( + "net/netip" + "testing" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/transport/hysteria2" +) + +func TestHysteria2Self(t *testing.T) { + t.Run("self", func(t *testing.T) { + testHysteria2Self(t, "") + }) + t.Run("self-salamander", func(t *testing.T) { + testHysteria2Self(t, "password") + }) +} + +func testHysteria2Self(t *testing.T, salamanderPassword string) { + _, certPem, keyPem := createSelfSignedCertificate(t, "example.org") + var obfs *option.Hysteria2Obfs + if salamanderPassword != "" { + obfs = &option.Hysteria2Obfs{ + Type: hysteria2.ObfsTypeSalamander, + Password: salamanderPassword, + } + } + startInstance(t, option.Options{ + Inbounds: []option.Inbound{ + { + Type: C.TypeMixed, + Tag: "mixed-in", + MixedOptions: option.HTTPMixedInboundOptions{ + ListenOptions: option.ListenOptions{ + Listen: option.NewListenAddress(netip.IPv4Unspecified()), + ListenPort: clientPort, + }, + }, + }, + { + Type: C.TypeHysteria2, + Hysteria2Options: option.Hysteria2InboundOptions{ + ListenOptions: option.ListenOptions{ + Listen: option.NewListenAddress(netip.IPv4Unspecified()), + ListenPort: serverPort, + }, + Obfs: obfs, + Users: []option.Hysteria2User{{ + Password: "password", + }}, + TLS: &option.InboundTLSOptions{ + Enabled: true, + ServerName: "example.org", + CertificatePath: certPem, + KeyPath: keyPem, + }, + }, + }, + }, + Outbounds: []option.Outbound{ + { + Type: C.TypeDirect, + }, + { + Type: C.TypeHysteria2, + Tag: "hy2-out", + Hysteria2Options: option.Hysteria2OutboundOptions{ + ServerOptions: option.ServerOptions{ + Server: "127.0.0.1", + ServerPort: serverPort, + }, + Obfs: obfs, + Password: "password", + TLS: &option.OutboundTLSOptions{ + Enabled: true, + ServerName: "example.org", + CertificatePath: certPem, + }, + }, + }, + }, + Route: &option.RouteOptions{ + Rules: []option.Rule{ + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultRule{ + Inbound: []string{"mixed-in"}, + Outbound: "hy2-out", + }, + }, + }, + }, + }) + testSuit(t, clientPort, testPort) +} diff --git a/transport/hysteria2/client.go b/transport/hysteria2/client.go new file mode 100644 index 00000000..633cbe8e --- /dev/null +++ b/transport/hysteria2/client.go @@ -0,0 +1,306 @@ +package hysteria2 + +import ( + "context" + "io" + "net" + "net/http" + "net/url" + "os" + "runtime" + "sync" + "time" + + "github.com/sagernet/quic-go" + "github.com/sagernet/sing-box/common/qtls" + "github.com/sagernet/sing-box/common/tls" + "github.com/sagernet/sing-box/transport/hysteria2/congestion" + "github.com/sagernet/sing-box/transport/hysteria2/internal/protocol" + tuicCongestion "github.com/sagernet/sing-box/transport/tuic/congestion" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +const ( + defaultStreamReceiveWindow = 8388608 // 8MB + defaultConnReceiveWindow = defaultStreamReceiveWindow * 5 / 2 // 20MB + defaultMaxIdleTimeout = 30 * time.Second + defaultKeepAlivePeriod = 10 * time.Second +) + +type ClientOptions struct { + Context context.Context + Dialer N.Dialer + ServerAddress M.Socksaddr + SendBPS uint64 + ReceiveBPS uint64 + SalamanderPassword string + Password string + TLSConfig tls.Config + UDPDisabled bool +} + +type Client struct { + ctx context.Context + dialer N.Dialer + serverAddr M.Socksaddr + sendBPS uint64 + receiveBPS uint64 + salamanderPassword string + password string + tlsConfig tls.Config + quicConfig *quic.Config + udpDisabled bool + + connAccess sync.RWMutex + conn *clientQUICConnection +} + +func NewClient(options ClientOptions) (*Client, error) { + quicConfig := &quic.Config{ + DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"), + MaxDatagramFrameSize: 1400, + EnableDatagrams: true, + InitialStreamReceiveWindow: defaultStreamReceiveWindow, + MaxStreamReceiveWindow: defaultStreamReceiveWindow, + InitialConnectionReceiveWindow: defaultConnReceiveWindow, + MaxConnectionReceiveWindow: defaultConnReceiveWindow, + MaxIdleTimeout: defaultMaxIdleTimeout, + KeepAlivePeriod: defaultKeepAlivePeriod, + } + return &Client{ + ctx: options.Context, + dialer: options.Dialer, + serverAddr: options.ServerAddress, + sendBPS: options.SendBPS, + receiveBPS: options.ReceiveBPS, + salamanderPassword: options.SalamanderPassword, + password: options.Password, + tlsConfig: options.TLSConfig, + quicConfig: quicConfig, + udpDisabled: options.UDPDisabled, + }, nil +} + +func (c *Client) offer(ctx context.Context) (*clientQUICConnection, error) { + conn := c.conn + if conn != nil && conn.active() { + return conn, nil + } + c.connAccess.Lock() + defer c.connAccess.Unlock() + conn = c.conn + if conn != nil && conn.active() { + return conn, nil + } + conn, err := c.offerNew(ctx) + if err != nil { + return nil, err + } + return conn, nil +} + +func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) { + udpConn, err := c.dialer.DialContext(ctx, "udp", c.serverAddr) + if err != nil { + return nil, err + } + var packetConn net.PacketConn + packetConn = bufio.NewUnbindPacketConn(udpConn) + if c.salamanderPassword != "" { + packetConn = NewSalamanderConn(packetConn, []byte(c.salamanderPassword)) + } + var quicConn quic.EarlyConnection + http3Transport, err := qtls.CreateTransport(packetConn, &quicConn, c.serverAddr, c.tlsConfig, c.quicConfig, true) + if err != nil { + udpConn.Close() + return nil, err + } + request := &http.Request{ + Method: http.MethodPost, + URL: &url.URL{ + Scheme: "https", + Host: protocol.URLHost, + Path: protocol.URLPath, + }, + Header: make(http.Header), + } + protocol.AuthRequestToHeader(request.Header, protocol.AuthRequest{Auth: c.password, Rx: c.receiveBPS}) + response, err := http3Transport.RoundTrip(request) + if err != nil { + if quicConn != nil { + quicConn.CloseWithError(0, "") + } + udpConn.Close() + return nil, err + } + if response.StatusCode != protocol.StatusAuthOK { + if quicConn != nil { + quicConn.CloseWithError(0, "") + } + udpConn.Close() + return nil, E.New("authentication failed, status code: ", response.StatusCode) + } + response.Body.Close() + authResponse := protocol.AuthResponseFromHeader(response.Header) + actualTx := authResponse.Rx + if actualTx == 0 || actualTx > c.sendBPS { + actualTx = c.sendBPS + } + if !authResponse.RxAuto && actualTx > 0 { + quicConn.SetCongestionControl(congestion.NewBrutalSender(actualTx)) + } else { + quicConn.SetCongestionControl(tuicCongestion.NewBBRSender( + tuicCongestion.DefaultClock{}, + tuicCongestion.GetInitialPacketSize(quicConn.RemoteAddr()), + tuicCongestion.InitialCongestionWindow*tuicCongestion.InitialMaxDatagramSize, + tuicCongestion.DefaultBBRMaxCongestionWindow*tuicCongestion.InitialMaxDatagramSize, + )) + } + conn := &clientQUICConnection{ + quicConn: quicConn, + rawConn: udpConn, + connDone: make(chan struct{}), + udpDisabled: c.udpDisabled || !authResponse.UDPEnabled, + udpConnMap: make(map[uint32]*udpPacketConn), + } + if !c.udpDisabled { + go c.loopMessages(conn) + } + c.conn = conn + return conn, nil +} + +func (c *Client) DialConn(ctx context.Context, destination M.Socksaddr) (net.Conn, error) { + conn, err := c.offer(ctx) + if err != nil { + return nil, err + } + stream, err := conn.quicConn.OpenStream() + if err != nil { + return nil, err + } + return &clientConn{ + Stream: stream, + destination: destination, + }, nil +} + +func (c *Client) ListenPacket(ctx context.Context) (net.PacketConn, error) { + if c.udpDisabled { + return nil, os.ErrInvalid + } + conn, err := c.offer(ctx) + if err != nil { + return nil, err + } + if conn.udpDisabled { + return nil, E.New("UDP disabled by server") + } + var sessionID uint32 + clientPacketConn := newUDPPacketConn(ctx, conn.quicConn, func() { + conn.udpAccess.Lock() + delete(conn.udpConnMap, sessionID) + conn.udpAccess.Unlock() + }) + conn.udpAccess.Lock() + sessionID = conn.udpSessionID + conn.udpSessionID++ + conn.udpConnMap[sessionID] = clientPacketConn + conn.udpAccess.Unlock() + clientPacketConn.sessionID = sessionID + return clientPacketConn, nil +} + +func (c *Client) CloseWithError(err error) error { + conn := c.conn + if conn != nil { + conn.closeWithError(err) + } + return nil +} + +type clientQUICConnection struct { + quicConn quic.Connection + rawConn io.Closer + closeOnce sync.Once + connDone chan struct{} + connErr error + udpDisabled bool + udpAccess sync.RWMutex + udpConnMap map[uint32]*udpPacketConn + udpSessionID uint32 +} + +func (c *clientQUICConnection) active() bool { + select { + case <-c.quicConn.Context().Done(): + return false + default: + } + select { + case <-c.connDone: + return false + default: + } + return true +} + +func (c *clientQUICConnection) closeWithError(err error) { + c.closeOnce.Do(func() { + c.connErr = err + close(c.connDone) + c.quicConn.CloseWithError(0, "") + }) +} + +type clientConn struct { + quic.Stream + destination M.Socksaddr + requestWritten bool + responseRead bool +} + +func (c *clientConn) NeedHandshake() bool { + return !c.requestWritten +} + +func (c *clientConn) Read(p []byte) (n int, err error) { + if c.responseRead { + return c.Stream.Read(p) + } + status, errorMessage, err := protocol.ReadTCPResponse(c.Stream) + if err != nil { + return + } + if !status { + err = E.New("remote error: ", errorMessage) + return + } + c.responseRead = true + return c.Stream.Read(p) +} + +func (c *clientConn) Write(p []byte) (n int, err error) { + if !c.requestWritten { + buffer := protocol.WriteTCPRequest(c.destination.String(), p) + defer buffer.Release() + _, err = c.Stream.Write(buffer.Bytes()) + if err != nil { + return + } + c.requestWritten = true + return len(p), nil + } + return c.Stream.Write(p) +} + +func (c *clientConn) LocalAddr() net.Addr { + return M.Socksaddr{} +} + +func (c *clientConn) RemoteAddr() net.Addr { + return M.Socksaddr{} +} diff --git a/transport/hysteria2/client_paclet.go b/transport/hysteria2/client_paclet.go new file mode 100644 index 00000000..59e946c1 --- /dev/null +++ b/transport/hysteria2/client_paclet.go @@ -0,0 +1,47 @@ +package hysteria2 + +import E "github.com/sagernet/sing/common/exceptions" + +func (c *Client) loopMessages(conn *clientQUICConnection) { + for { + message, err := conn.quicConn.ReceiveMessage(c.ctx) + if err != nil { + conn.closeWithError(E.Cause(err, "receive message")) + return + } + go func() { + hErr := c.handleMessage(conn, message) + if hErr != nil { + conn.closeWithError(E.Cause(hErr, "handle message")) + } + }() + } +} + +func (c *Client) handleMessage(conn *clientQUICConnection, data []byte) error { + message := udpMessagePool.Get().(*udpMessage) + err := decodeUDPMessage(message, data) + if err != nil { + message.release() + return E.Cause(err, "decode UDP message") + } + conn.handleUDPMessage(message) + return nil +} + +func (c *clientQUICConnection) handleUDPMessage(message *udpMessage) { + c.udpAccess.RLock() + udpConn, loaded := c.udpConnMap[message.sessionID] + c.udpAccess.RUnlock() + if !loaded { + message.releaseMessage() + return + } + select { + case <-udpConn.ctx.Done(): + message.releaseMessage() + return + default: + } + udpConn.inputPacket(message) +} diff --git a/transport/hysteria2/congestion/brutal.go b/transport/hysteria2/congestion/brutal.go new file mode 100644 index 00000000..c52350c8 --- /dev/null +++ b/transport/hysteria2/congestion/brutal.go @@ -0,0 +1,151 @@ +package congestion + +import ( + "time" + + "github.com/sagernet/quic-go/congestion" +) + +const ( + initMaxDatagramSize = 1252 + + pktInfoSlotCount = 4 + minSampleCount = 50 + minAckRate = 0.8 +) + +var _ congestion.CongestionControl = &BrutalSender{} + +type BrutalSender struct { + rttStats congestion.RTTStatsProvider + bps congestion.ByteCount + maxDatagramSize congestion.ByteCount + pacer *pacer + + pktInfoSlots [pktInfoSlotCount]pktInfo + ackRate float64 +} + +type pktInfo struct { + Timestamp int64 + AckCount uint64 + LossCount uint64 +} + +func NewBrutalSender(bps uint64) *BrutalSender { + bs := &BrutalSender{ + bps: congestion.ByteCount(bps), + maxDatagramSize: initMaxDatagramSize, + ackRate: 1, + } + bs.pacer = newPacer(func() congestion.ByteCount { + return congestion.ByteCount(float64(bs.bps) / bs.ackRate) + }) + return bs +} + +func (b *BrutalSender) SetRTTStatsProvider(rttStats congestion.RTTStatsProvider) { + b.rttStats = rttStats +} + +func (b *BrutalSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time { + return b.pacer.TimeUntilSend() +} + +func (b *BrutalSender) HasPacingBudget(now time.Time) bool { + return b.pacer.Budget(now) >= b.maxDatagramSize +} + +func (b *BrutalSender) CanSend(bytesInFlight congestion.ByteCount) bool { + return bytesInFlight < b.GetCongestionWindow() +} + +func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount { + rtt := b.rttStats.SmoothedRTT() + if rtt <= 0 { + return 10240 + } + return congestion.ByteCount(float64(b.bps) * rtt.Seconds() * 1.5 / b.ackRate) +} + +func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount, + packetNumber congestion.PacketNumber, bytes congestion.ByteCount, isRetransmittable bool, +) { + b.pacer.SentPacket(sentTime, bytes) +} + +func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount, + priorInFlight congestion.ByteCount, eventTime time.Time, +) { + currentTimestamp := eventTime.Unix() + slot := currentTimestamp % pktInfoSlotCount + if b.pktInfoSlots[slot].Timestamp == currentTimestamp { + b.pktInfoSlots[slot].AckCount++ + } else { + // uninitialized slot or too old, reset + b.pktInfoSlots[slot].Timestamp = currentTimestamp + b.pktInfoSlots[slot].AckCount = 1 + b.pktInfoSlots[slot].LossCount = 0 + } + b.updateAckRate(currentTimestamp) +} + +func (b *BrutalSender) OnPacketLost(number congestion.PacketNumber, lostBytes congestion.ByteCount, + priorInFlight congestion.ByteCount, +) { + currentTimestamp := time.Now().Unix() + slot := currentTimestamp % pktInfoSlotCount + if b.pktInfoSlots[slot].Timestamp == currentTimestamp { + b.pktInfoSlots[slot].LossCount++ + } else { + // uninitialized slot or too old, reset + b.pktInfoSlots[slot].Timestamp = currentTimestamp + b.pktInfoSlots[slot].AckCount = 0 + b.pktInfoSlots[slot].LossCount = 1 + } + b.updateAckRate(currentTimestamp) +} + +func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) { + b.maxDatagramSize = size + b.pacer.SetMaxDatagramSize(size) +} + +func (b *BrutalSender) updateAckRate(currentTimestamp int64) { + minTimestamp := currentTimestamp - pktInfoSlotCount + var ackCount, lossCount uint64 + for _, info := range b.pktInfoSlots { + if info.Timestamp < minTimestamp { + continue + } + ackCount += info.AckCount + lossCount += info.LossCount + } + if ackCount+lossCount < minSampleCount { + b.ackRate = 1 + } + rate := float64(ackCount) / float64(ackCount+lossCount) + if rate < minAckRate { + b.ackRate = minAckRate + } + b.ackRate = rate +} + +func (b *BrutalSender) InSlowStart() bool { + return false +} + +func (b *BrutalSender) InRecovery() bool { + return false +} + +func (b *BrutalSender) MaybeExitSlowStart() {} + +func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {} + +func maxDuration(a, b time.Duration) time.Duration { + if a > b { + return a + } + return b +} diff --git a/transport/hysteria2/congestion/pacer.go b/transport/hysteria2/congestion/pacer.go new file mode 100644 index 00000000..878985e5 --- /dev/null +++ b/transport/hysteria2/congestion/pacer.go @@ -0,0 +1,86 @@ +package congestion + +import ( + "math" + "time" + + "github.com/sagernet/quic-go/congestion" +) + +const ( + maxBurstPackets = 10 + minPacingDelay = time.Millisecond +) + +// The pacer implements a token bucket pacing algorithm. +type pacer struct { + budgetAtLastSent congestion.ByteCount + maxDatagramSize congestion.ByteCount + lastSentTime time.Time + getBandwidth func() congestion.ByteCount // in bytes/s +} + +func newPacer(getBandwidth func() congestion.ByteCount) *pacer { + p := &pacer{ + budgetAtLastSent: maxBurstPackets * initMaxDatagramSize, + maxDatagramSize: initMaxDatagramSize, + getBandwidth: getBandwidth, + } + return p +} + +func (p *pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) { + budget := p.Budget(sendTime) + if size > budget { + p.budgetAtLastSent = 0 + } else { + p.budgetAtLastSent = budget - size + } + p.lastSentTime = sendTime +} + +func (p *pacer) Budget(now time.Time) congestion.ByteCount { + if p.lastSentTime.IsZero() { + return p.maxBurstSize() + } + budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 + return minByteCount(p.maxBurstSize(), budget) +} + +func (p *pacer) maxBurstSize() congestion.ByteCount { + return maxByteCount( + congestion.ByteCount((minPacingDelay+time.Millisecond).Nanoseconds())*p.getBandwidth()/1e9, + maxBurstPackets*p.maxDatagramSize, + ) +} + +// TimeUntilSend returns when the next packet should be sent. +// It returns the zero value of time.Time if a packet can be sent immediately. +func (p *pacer) TimeUntilSend() time.Time { + if p.budgetAtLastSent >= p.maxDatagramSize { + return time.Time{} + } + return p.lastSentTime.Add(maxDuration( + minPacingDelay, + time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/ + float64(p.getBandwidth())))*time.Nanosecond, + )) +} + +func (p *pacer) SetMaxDatagramSize(s congestion.ByteCount) { + p.maxDatagramSize = s +} + +func maxByteCount(a, b congestion.ByteCount) congestion.ByteCount { + if a < b { + return b + } + return a +} + +func minByteCount(a, b congestion.ByteCount) congestion.ByteCount { + if a < b { + return a + } + return b +} diff --git a/transport/hysteria2/internal/protocol/http.go b/transport/hysteria2/internal/protocol/http.go new file mode 100644 index 00000000..abcc1a4f --- /dev/null +++ b/transport/hysteria2/internal/protocol/http.go @@ -0,0 +1,68 @@ +package protocol + +import ( + "net/http" + "strconv" +) + +const ( + URLHost = "hysteria" + URLPath = "/auth" + + RequestHeaderAuth = "Hysteria-Auth" + ResponseHeaderUDPEnabled = "Hysteria-UDP" + CommonHeaderCCRX = "Hysteria-CC-RX" + CommonHeaderPadding = "Hysteria-Padding" + + StatusAuthOK = 233 +) + +// AuthRequest is what client sends to server for authentication. +type AuthRequest struct { + Auth string + Rx uint64 // 0 = unknown, client asks server to use bandwidth detection +} + +// AuthResponse is what server sends to client when authentication is passed. +type AuthResponse struct { + UDPEnabled bool + Rx uint64 // 0 = unlimited + RxAuto bool // true = server asks client to use bandwidth detection +} + +func AuthRequestFromHeader(h http.Header) AuthRequest { + rx, _ := strconv.ParseUint(h.Get(CommonHeaderCCRX), 10, 64) + return AuthRequest{ + Auth: h.Get(RequestHeaderAuth), + Rx: rx, + } +} + +func AuthRequestToHeader(h http.Header, req AuthRequest) { + h.Set(RequestHeaderAuth, req.Auth) + h.Set(CommonHeaderCCRX, strconv.FormatUint(req.Rx, 10)) + h.Set(CommonHeaderPadding, authRequestPadding.String()) +} + +func AuthResponseFromHeader(h http.Header) AuthResponse { + resp := AuthResponse{} + resp.UDPEnabled, _ = strconv.ParseBool(h.Get(ResponseHeaderUDPEnabled)) + rxStr := h.Get(CommonHeaderCCRX) + if rxStr == "auto" { + // Special case for server requesting client to use bandwidth detection + resp.RxAuto = true + } else { + resp.Rx, _ = strconv.ParseUint(rxStr, 10, 64) + } + return resp +} + +func AuthResponseToHeader(h http.Header, resp AuthResponse) { + h.Set(ResponseHeaderUDPEnabled, strconv.FormatBool(resp.UDPEnabled)) + if resp.RxAuto { + h.Set(CommonHeaderCCRX, "auto") + } else { + h.Set(CommonHeaderCCRX, strconv.FormatUint(resp.Rx, 10)) + } + h.Set(CommonHeaderPadding, authResponsePadding.String()) +} diff --git a/transport/hysteria2/internal/protocol/padding.go b/transport/hysteria2/internal/protocol/padding.go new file mode 100644 index 00000000..9895cdcc --- /dev/null +++ b/transport/hysteria2/internal/protocol/padding.go @@ -0,0 +1,31 @@ +package protocol + +import ( + "math/rand" +) + +const ( + paddingChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" +) + +// padding specifies a half-open range [Min, Max). +type padding struct { + Min int + Max int +} + +func (p padding) String() string { + n := p.Min + rand.Intn(p.Max-p.Min) + bs := make([]byte, n) + for i := range bs { + bs[i] = paddingChars[rand.Intn(len(paddingChars))] + } + return string(bs) +} + +var ( + authRequestPadding = padding{Min: 256, Max: 2048} + authResponsePadding = padding{Min: 256, Max: 2048} + tcpRequestPadding = padding{Min: 64, Max: 512} + tcpResponsePadding = padding{Min: 128, Max: 1024} +) diff --git a/transport/hysteria2/internal/protocol/proxy.go b/transport/hysteria2/internal/protocol/proxy.go new file mode 100644 index 00000000..795b3cbf --- /dev/null +++ b/transport/hysteria2/internal/protocol/proxy.go @@ -0,0 +1,266 @@ +package protocol + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/sagernet/quic-go/quicvarint" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/rw" +) + +const ( + FrameTypeTCPRequest = 0x401 + + // Max length values are for preventing DoS attacks + + MaxAddressLength = 2048 + MaxMessageLength = 2048 + MaxPaddingLength = 4096 + + MaxUDPSize = 4096 + + maxVarInt1 = 63 + maxVarInt2 = 16383 + maxVarInt4 = 1073741823 + maxVarInt8 = 4611686018427387903 +) + +// TCPRequest format: +// 0x401 (QUIC varint) +// Address length (QUIC varint) +// Address (bytes) +// Padding length (QUIC varint) +// Padding (bytes) + +func ReadTCPRequest(r io.Reader) (string, error) { + bReader := quicvarint.NewReader(r) + addrLen, err := quicvarint.Read(bReader) + if err != nil { + return "", err + } + if addrLen == 0 || addrLen > MaxAddressLength { + return "", E.New("invalid address length") + } + addrBuf := make([]byte, addrLen) + _, err = io.ReadFull(r, addrBuf) + if err != nil { + return "", err + } + paddingLen, err := quicvarint.Read(bReader) + if err != nil { + return "", err + } + if paddingLen > MaxPaddingLength { + return "", E.New("invalid padding length") + } + if paddingLen > 0 { + _, err = io.CopyN(io.Discard, r, int64(paddingLen)) + if err != nil { + return "", err + } + } + return string(addrBuf), nil +} + +func WriteTCPRequest(addr string, payload []byte) *buf.Buffer { + padding := tcpRequestPadding.String() + paddingLen := len(padding) + addrLen := len(addr) + sz := int(quicvarint.Len(FrameTypeTCPRequest)) + + int(quicvarint.Len(uint64(addrLen))) + addrLen + + int(quicvarint.Len(uint64(paddingLen))) + paddingLen + buffer := buf.NewSize(sz + len(payload)) + bufferContent := buffer.Extend(sz) + i := varintPut(bufferContent, FrameTypeTCPRequest) + i += varintPut(bufferContent[i:], uint64(addrLen)) + i += copy(bufferContent[i:], addr) + i += varintPut(bufferContent[i:], uint64(paddingLen)) + copy(bufferContent[i:], padding) + buffer.Write(payload) + return buffer +} + +// TCPResponse format: +// Status (byte, 0=ok, 1=error) +// Message length (QUIC varint) +// Message (bytes) +// Padding length (QUIC varint) +// Padding (bytes) + +func ReadTCPResponse(r io.Reader) (bool, string, error) { + var status [1]byte + if _, err := io.ReadFull(r, status[:]); err != nil { + return false, "", err + } + bReader := quicvarint.NewReader(r) + msg, err := ReadVString(bReader) + if err != nil { + return false, "", err + } + paddingLen, err := quicvarint.Read(bReader) + if err != nil { + return false, "", err + } + if paddingLen > MaxPaddingLength { + return false, "", E.New("invalid padding length") + } + if paddingLen > 0 { + _, err = io.CopyN(io.Discard, r, int64(paddingLen)) + if err != nil { + return false, "", err + } + } + return status[0] == 0, msg, nil +} + +func WriteTCPResponse(ok bool, msg string, payload []byte) *buf.Buffer { + padding := tcpResponsePadding.String() + paddingLen := len(padding) + msgLen := len(msg) + sz := 1 + int(quicvarint.Len(uint64(msgLen))) + msgLen + + int(quicvarint.Len(uint64(paddingLen))) + paddingLen + buffer := buf.NewSize(sz + len(payload)) + if ok { + buffer.WriteByte(0) + } else { + buffer.WriteByte(1) + } + WriteVString(buffer, msg) + WriteUVariant(buffer, uint64(paddingLen)) + buffer.Extend(paddingLen) + buffer.Write(payload) + return buffer +} + +// UDPMessage format: +// Session ID (uint32 BE) +// Packet ID (uint16 BE) +// Fragment ID (uint8) +// Fragment count (uint8) +// Address length (QUIC varint) +// Address (bytes) +// Data... + +type UDPMessage struct { + SessionID uint32 // 4 + PacketID uint16 // 2 + FragID uint8 // 1 + FragCount uint8 // 1 + Addr string // varint + bytes + Data []byte +} + +func (m *UDPMessage) HeaderSize() int { + lAddr := len(m.Addr) + return 4 + 2 + 1 + 1 + int(quicvarint.Len(uint64(lAddr))) + lAddr +} + +func (m *UDPMessage) Size() int { + return m.HeaderSize() + len(m.Data) +} + +func (m *UDPMessage) Serialize(buf []byte) int { + // Make sure the buffer is big enough + if len(buf) < m.Size() { + return -1 + } + binary.BigEndian.PutUint32(buf, m.SessionID) + binary.BigEndian.PutUint16(buf[4:], m.PacketID) + buf[6] = m.FragID + buf[7] = m.FragCount + i := varintPut(buf[8:], uint64(len(m.Addr))) + i += copy(buf[8+i:], m.Addr) + i += copy(buf[8+i:], m.Data) + return 8 + i +} + +func ParseUDPMessage(msg []byte) (*UDPMessage, error) { + m := &UDPMessage{} + buf := bytes.NewBuffer(msg) + if err := binary.Read(buf, binary.BigEndian, &m.SessionID); err != nil { + return nil, err + } + if err := binary.Read(buf, binary.BigEndian, &m.PacketID); err != nil { + return nil, err + } + if err := binary.Read(buf, binary.BigEndian, &m.FragID); err != nil { + return nil, err + } + if err := binary.Read(buf, binary.BigEndian, &m.FragCount); err != nil { + return nil, err + } + lAddr, err := quicvarint.Read(buf) + if err != nil { + return nil, err + } + if lAddr == 0 || lAddr > MaxMessageLength { + return nil, E.New("invalid address length") + } + bs := buf.Bytes() + m.Addr = string(bs[:lAddr]) + m.Data = bs[lAddr:] + return m, nil +} + +func ReadVString(reader io.Reader) (string, error) { + length, err := quicvarint.Read(quicvarint.NewReader(reader)) + if err != nil { + return "", err + } + value, err := rw.ReadBytes(reader, int(length)) + if err != nil { + return "", err + } + return string(value), nil +} + +func WriteVString(writer io.Writer, value string) error { + err := WriteUVariant(writer, uint64(len(value))) + if err != nil { + return err + } + return rw.WriteString(writer, value) +} + +func WriteUVariant(writer io.Writer, value uint64) error { + var b [8]byte + return common.Error(writer.Write(b[:varintPut(b[:], value)])) +} + +// varintPut is like quicvarint.Append, but instead of appending to a slice, +// it writes to a fixed-size buffer. Returns the number of bytes written. +func varintPut(b []byte, i uint64) int { + if i <= maxVarInt1 { + b[0] = uint8(i) + return 1 + } + if i <= maxVarInt2 { + b[0] = uint8(i>>8) | 0x40 + b[1] = uint8(i) + return 2 + } + if i <= maxVarInt4 { + b[0] = uint8(i>>24) | 0x80 + b[1] = uint8(i >> 16) + b[2] = uint8(i >> 8) + b[3] = uint8(i) + return 4 + } + if i <= maxVarInt8 { + b[0] = uint8(i>>56) | 0xc0 + b[1] = uint8(i >> 48) + b[2] = uint8(i >> 40) + b[3] = uint8(i >> 32) + b[4] = uint8(i >> 24) + b[5] = uint8(i >> 16) + b[6] = uint8(i >> 8) + b[7] = uint8(i) + return 8 + } + panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i)) +} diff --git a/transport/hysteria2/packet.go b/transport/hysteria2/packet.go new file mode 100644 index 00000000..ed544cdd --- /dev/null +++ b/transport/hysteria2/packet.go @@ -0,0 +1,438 @@ +package hysteria2 + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "io" + "math" + "net" + "os" + "sync" + "time" + + "github.com/sagernet/quic-go" + "github.com/sagernet/quic-go/quicvarint" + "github.com/sagernet/sing-box/transport/hysteria2/internal/protocol" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/atomic" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/cache" + M "github.com/sagernet/sing/common/metadata" +) + +var udpMessagePool = sync.Pool{ + New: func() interface{} { + return new(udpMessage) + }, +} + +func releaseMessages(messages []*udpMessage) { + for _, message := range messages { + if message != nil { + *message = udpMessage{} + udpMessagePool.Put(message) + } + } +} + +type udpMessage struct { + sessionID uint32 + packetID uint16 + fragmentID uint8 + fragmentTotal uint8 + destination string + data *buf.Buffer +} + +func (m *udpMessage) release() { + *m = udpMessage{} + udpMessagePool.Put(m) +} + +func (m *udpMessage) releaseMessage() { + m.data.Release() + m.release() +} + +func (m *udpMessage) pack() *buf.Buffer { + buffer := buf.NewSize(m.headerSize() + m.data.Len()) + common.Must( + binary.Write(buffer, binary.BigEndian, m.sessionID), + binary.Write(buffer, binary.BigEndian, m.packetID), + binary.Write(buffer, binary.BigEndian, m.fragmentID), + binary.Write(buffer, binary.BigEndian, m.fragmentTotal), + protocol.WriteVString(buffer, m.destination), + common.Error(buffer.Write(m.data.Bytes())), + ) + return buffer +} + +func (m *udpMessage) headerSize() int { + return 8 + int(quicvarint.Len(uint64(len(m.destination)))) + len(m.destination) +} + +func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage { + if message.data.Len() <= maxPacketSize { + return []*udpMessage{message} + } + var fragments []*udpMessage + originPacket := message.data.Bytes() + udpMTU := maxPacketSize - message.headerSize() + for remaining := len(originPacket); remaining > 0; remaining -= udpMTU { + fragment := udpMessagePool.Get().(*udpMessage) + *fragment = *message + if remaining > udpMTU { + fragment.data = buf.As(originPacket[:udpMTU]) + originPacket = originPacket[udpMTU:] + } else { + fragment.data = buf.As(originPacket) + originPacket = nil + } + fragments = append(fragments, fragment) + } + fragmentTotal := uint16(len(fragments)) + for index, fragment := range fragments { + fragment.fragmentID = uint8(index) + fragment.fragmentTotal = uint8(fragmentTotal) + /*if index > 0 { + fragment.destination = "" + // not work in hysteria + }*/ + } + return fragments +} + +type udpPacketConn struct { + ctx context.Context + cancel common.ContextCancelCauseFunc + sessionID uint32 + quicConn quic.Connection + data chan *udpMessage + udpMTU int + udpMTUTime time.Time + packetId atomic.Uint32 + closeOnce sync.Once + defragger *udpDefragger + onDestroy func() +} + +func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, onDestroy func()) *udpPacketConn { + ctx, cancel := common.ContextWithCancelCause(ctx) + return &udpPacketConn{ + ctx: ctx, + cancel: cancel, + quicConn: quicConn, + data: make(chan *udpMessage, 64), + defragger: newUDPDefragger(), + onDestroy: onDestroy, + } +} + +func (c *udpPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + select { + case p := <-c.data: + buffer = p.data + destination = M.ParseSocksaddr(p.destination) + p.release() + return + case <-c.ctx.Done(): + return nil, M.Socksaddr{}, io.ErrClosedPipe + } +} + +func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + select { + case p := <-c.data: + _, err = buffer.ReadOnceFrom(p.data) + destination = M.ParseSocksaddr(p.destination) + p.releaseMessage() + return + case <-c.ctx.Done(): + return M.Socksaddr{}, io.ErrClosedPipe + } +} + +func (c *udpPacketConn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) { + select { + case p := <-c.data: + _, err = newBuffer().ReadOnceFrom(p.data) + destination = M.ParseSocksaddr(p.destination) + p.releaseMessage() + return + case <-c.ctx.Done(): + return M.Socksaddr{}, io.ErrClosedPipe + } +} + +func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + select { + case pkt := <-c.data: + n = copy(p, pkt.data.Bytes()) + destination := M.ParseSocksaddr(pkt.destination) + if destination.IsFqdn() { + addr = destination + } else { + addr = destination.UDPAddr() + } + pkt.releaseMessage() + return n, addr, nil + case <-c.ctx.Done(): + return 0, nil, io.ErrClosedPipe + } +} + +func (c *udpPacketConn) needFragment() bool { + nowTime := time.Now() + if c.udpMTU > 0 && nowTime.Sub(c.udpMTUTime) < 5*time.Second { + c.udpMTUTime = nowTime + return true + } + return false +} + +func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + defer buffer.Release() + select { + case <-c.ctx.Done(): + return net.ErrClosed + default: + } + if buffer.Len() > 0xffff { + return quic.ErrMessageTooLarge(0xffff) + } + packetId := c.packetId.Add(1) + if packetId > math.MaxUint16 { + c.packetId.Store(0) + packetId = 0 + } + message := udpMessagePool.Get().(*udpMessage) + *message = udpMessage{ + sessionID: c.sessionID, + packetID: uint16(packetId), + fragmentTotal: 1, + destination: destination.String(), + data: buffer, + } + defer message.releaseMessage() + var err error + if c.needFragment() && buffer.Len() > c.udpMTU { + err = c.writePackets(fragUDPMessage(message, c.udpMTU)) + } else { + err = c.writePacket(message) + } + if err == nil { + return nil + } + var tooLargeErr quic.ErrMessageTooLarge + if !errors.As(err, &tooLargeErr) { + return err + } + c.udpMTU = int(tooLargeErr) + c.udpMTUTime = time.Now() + return c.writePackets(fragUDPMessage(message, c.udpMTU)) +} + +func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + select { + case <-c.ctx.Done(): + return 0, net.ErrClosed + default: + } + if len(p) > 0xffff { + return 0, quic.ErrMessageTooLarge(0xffff) + } + packetId := c.packetId.Add(1) + if packetId > math.MaxUint16 { + c.packetId.Store(0) + packetId = 0 + } + message := udpMessagePool.Get().(*udpMessage) + *message = udpMessage{ + sessionID: c.sessionID, + packetID: uint16(packetId), + fragmentTotal: 1, + destination: addr.String(), + data: buf.As(p), + } + if c.needFragment() && len(p) > c.udpMTU { + err = c.writePackets(fragUDPMessage(message, c.udpMTU)) + if err == nil { + return len(p), nil + } + } else { + err = c.writePacket(message) + } + if err == nil { + return len(p), nil + } + var tooLargeErr quic.ErrMessageTooLarge + if !errors.As(err, &tooLargeErr) { + return + } + c.udpMTU = int(tooLargeErr) + c.udpMTUTime = time.Now() + err = c.writePackets(fragUDPMessage(message, c.udpMTU)) + if err == nil { + return len(p), nil + } + return +} + +func (c *udpPacketConn) inputPacket(message *udpMessage) { + if message.fragmentTotal <= 1 { + select { + case c.data <- message: + default: + } + } else { + newMessage := c.defragger.feed(message) + if newMessage != nil { + select { + case c.data <- newMessage: + default: + } + } + } +} + +func (c *udpPacketConn) writePackets(messages []*udpMessage) error { + defer releaseMessages(messages) + for _, message := range messages { + err := c.writePacket(message) + if err != nil { + return err + } + } + return nil +} + +func (c *udpPacketConn) writePacket(message *udpMessage) error { + buffer := message.pack() + defer buffer.Release() + return c.quicConn.SendMessage(buffer.Bytes()) +} + +func (c *udpPacketConn) Close() error { + c.closeOnce.Do(func() { + c.closeWithError(os.ErrClosed) + c.onDestroy() + }) + return nil +} + +func (c *udpPacketConn) closeWithError(err error) { + c.cancel(err) +} + +func (c *udpPacketConn) LocalAddr() net.Addr { + return c.quicConn.LocalAddr() +} + +func (c *udpPacketConn) SetDeadline(t time.Time) error { + return os.ErrInvalid +} + +func (c *udpPacketConn) SetReadDeadline(t time.Time) error { + return os.ErrInvalid +} + +func (c *udpPacketConn) SetWriteDeadline(t time.Time) error { + return os.ErrInvalid +} + +type udpDefragger struct { + packetMap *cache.LruCache[uint16, *packetItem] +} + +func newUDPDefragger() *udpDefragger { + return &udpDefragger{ + packetMap: cache.New( + cache.WithAge[uint16, *packetItem](10), + cache.WithUpdateAgeOnGet[uint16, *packetItem](), + cache.WithEvict[uint16, *packetItem](func(key uint16, value *packetItem) { + releaseMessages(value.messages) + }), + ), + } +} + +type packetItem struct { + access sync.Mutex + messages []*udpMessage + count uint8 +} + +func (d *udpDefragger) feed(m *udpMessage) *udpMessage { + if m.fragmentTotal <= 1 { + return m + } + if m.fragmentID >= m.fragmentTotal { + return nil + } + item, _ := d.packetMap.LoadOrStore(m.packetID, newPacketItem) + item.access.Lock() + defer item.access.Unlock() + if int(m.fragmentTotal) != len(item.messages) { + releaseMessages(item.messages) + item.messages = make([]*udpMessage, m.fragmentTotal) + item.count = 1 + item.messages[m.fragmentID] = m + return nil + } + if item.messages[m.fragmentID] != nil { + return nil + } + item.messages[m.fragmentID] = m + item.count++ + if int(item.count) != len(item.messages) { + return nil + } + newMessage := udpMessagePool.Get().(*udpMessage) + *newMessage = *item.messages[0] + var finalLength int + for _, message := range item.messages { + finalLength += message.data.Len() + } + if finalLength > 0 { + newMessage.data = buf.NewSize(finalLength) + for _, message := range item.messages { + newMessage.data.Write(message.data.Bytes()) + message.releaseMessage() + } + item.messages = nil + return newMessage + } + return nil +} + +func newPacketItem() *packetItem { + return new(packetItem) +} + +func decodeUDPMessage(message *udpMessage, data []byte) error { + reader := bytes.NewReader(data) + err := binary.Read(reader, binary.BigEndian, &message.sessionID) + if err != nil { + return err + } + err = binary.Read(reader, binary.BigEndian, &message.packetID) + if err != nil { + return err + } + err = binary.Read(reader, binary.BigEndian, &message.fragmentID) + if err != nil { + return err + } + err = binary.Read(reader, binary.BigEndian, &message.fragmentTotal) + if err != nil { + return err + } + message.destination, err = protocol.ReadVString(reader) + if err != nil { + return err + } + message.data = buf.As(data[len(data)-reader.Len():]) + return nil +} diff --git a/transport/hysteria2/salamander.go b/transport/hysteria2/salamander.go new file mode 100644 index 00000000..9b734d52 --- /dev/null +++ b/transport/hysteria2/salamander.go @@ -0,0 +1,106 @@ +package hysteria2 + +import ( + "net" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + "golang.org/x/crypto/blake2b" +) + +const salamanderSaltLen = 8 + +const ObfsTypeSalamander = "salamander" + +type Salamander struct { + net.PacketConn + password []byte +} + +func NewSalamanderConn(conn net.PacketConn, password []byte) net.PacketConn { + writer, isVectorised := bufio.CreateVectorisedPacketWriter(conn) + if isVectorised { + return &VectorisedSalamander{ + Salamander: Salamander{ + PacketConn: conn, + password: password, + }, + writer: writer, + } + } else { + return &Salamander{ + PacketConn: conn, + password: password, + } + } +} + +func (s *Salamander) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, addr, err = s.PacketConn.ReadFrom(p) + if err != nil { + return + } + if n <= salamanderSaltLen { + return 0, nil, E.New("salamander: packet too short") + } + key := blake2b.Sum256(append(s.password, p[:salamanderSaltLen]...)) + for index, c := range p[salamanderSaltLen:n] { + p[index] = c ^ key[index%blake2b.Size256] + } + return n - salamanderSaltLen, addr, nil +} + +func (s *Salamander) WriteTo(p []byte, addr net.Addr) (n int, err error) { + buffer := buf.NewSize(len(p) + salamanderSaltLen) + defer buffer.Release() + buffer.WriteRandom(salamanderSaltLen) + key := blake2b.Sum256(append(s.password, buffer.Bytes()...)) + for index, c := range p { + common.Must(buffer.WriteByte(c ^ key[index%blake2b.Size256])) + } + _, err = s.PacketConn.WriteTo(buffer.Bytes(), addr) + if err != nil { + return + } + return len(p), nil +} + +type VectorisedSalamander struct { + Salamander + writer N.VectorisedPacketWriter +} + +func (s *VectorisedSalamander) WriteTo(p []byte, addr net.Addr) (n int, err error) { + buffer := buf.NewSize(salamanderSaltLen) + buffer.WriteRandom(salamanderSaltLen) + key := blake2b.Sum256(append(s.password, buffer.Bytes()...)) + for i := range p { + p[i] ^= key[i%blake2b.Size256] + } + err = s.writer.WriteVectorisedPacket([]*buf.Buffer{buffer, buf.As(p)}, M.SocksaddrFromNet(addr)) + if err != nil { + return + } + return len(p), nil +} + +func (s *VectorisedSalamander) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error { + header := buf.NewSize(salamanderSaltLen) + defer header.Release() + header.WriteRandom(salamanderSaltLen) + key := blake2b.Sum256(append(s.password, header.Bytes()...)) + var bufferIndex int + for _, buffer := range buffers { + content := buffer.Bytes() + for index, c := range content { + content[bufferIndex+index] = c ^ key[bufferIndex+index%blake2b.Size256] + } + bufferIndex += len(content) + } + return s.writer.WriteVectorisedPacket(append([]*buf.Buffer{header}, buffers...), destination) +} diff --git a/transport/hysteria2/server.go b/transport/hysteria2/server.go new file mode 100644 index 00000000..53895720 --- /dev/null +++ b/transport/hysteria2/server.go @@ -0,0 +1,336 @@ +package hysteria2 + +import ( + "context" + "io" + "net" + "net/http" + "os" + "runtime" + "strings" + "sync" + + "github.com/sagernet/quic-go" + "github.com/sagernet/quic-go/http3" + "github.com/sagernet/sing-box/common/baderror" + "github.com/sagernet/sing-box/common/qtls" + "github.com/sagernet/sing-box/common/tls" + "github.com/sagernet/sing-box/transport/hysteria2/congestion" + "github.com/sagernet/sing-box/transport/hysteria2/internal/protocol" + tuicCongestion "github.com/sagernet/sing-box/transport/tuic/congestion" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/auth" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +type ServerOptions struct { + Context context.Context + Logger logger.Logger + SendBPS uint64 + ReceiveBPS uint64 + IgnoreClientBandwidth bool + SalamanderPassword string + TLSConfig tls.ServerConfig + Users []User + UDPDisabled bool + Handler ServerHandler + MasqueradeHandler http.Handler +} + +type User struct { + Name string + Password string +} + +type ServerHandler interface { + N.TCPConnectionHandler + N.UDPConnectionHandler +} + +type Server struct { + ctx context.Context + logger logger.Logger + sendBPS uint64 + receiveBPS uint64 + ignoreClientBandwidth bool + salamanderPassword string + tlsConfig tls.ServerConfig + quicConfig *quic.Config + userMap map[string]User + udpDisabled bool + handler ServerHandler + masqueradeHandler http.Handler + quicListener io.Closer +} + +func NewServer(options ServerOptions) (*Server, error) { + quicConfig := &quic.Config{ + DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"), + MaxDatagramFrameSize: 1400, + EnableDatagrams: !options.UDPDisabled, + MaxIncomingStreams: 1 << 60, + InitialStreamReceiveWindow: defaultStreamReceiveWindow, + MaxStreamReceiveWindow: defaultStreamReceiveWindow, + InitialConnectionReceiveWindow: defaultConnReceiveWindow, + MaxConnectionReceiveWindow: defaultConnReceiveWindow, + MaxIdleTimeout: defaultMaxIdleTimeout, + KeepAlivePeriod: defaultKeepAlivePeriod, + } + if len(options.Users) == 0 { + return nil, E.New("missing users") + } + userMap := make(map[string]User) + for _, user := range options.Users { + userMap[user.Password] = user + } + if options.MasqueradeHandler == nil { + options.MasqueradeHandler = http.NotFoundHandler() + } + return &Server{ + ctx: options.Context, + logger: options.Logger, + sendBPS: options.SendBPS, + receiveBPS: options.ReceiveBPS, + ignoreClientBandwidth: options.IgnoreClientBandwidth, + salamanderPassword: options.SalamanderPassword, + tlsConfig: options.TLSConfig, + quicConfig: quicConfig, + userMap: userMap, + udpDisabled: options.UDPDisabled, + handler: options.Handler, + masqueradeHandler: options.MasqueradeHandler, + }, nil +} + +func (s *Server) Start(conn net.PacketConn) error { + if s.salamanderPassword != "" { + conn = NewSalamanderConn(conn, []byte(s.salamanderPassword)) + } + err := qtls.ConfigureHTTP3(s.tlsConfig) + if err != nil { + return err + } + listener, err := qtls.Listen(conn, s.tlsConfig, s.quicConfig) + if err != nil { + return err + } + s.quicListener = listener + go s.loopConnections(listener) + return nil +} + +func (s *Server) Close() error { + return common.Close( + s.quicListener, + ) +} + +func (s *Server) loopConnections(listener qtls.QUICListener) { + for { + connection, err := listener.Accept(s.ctx) + if err != nil { + if strings.Contains(err.Error(), "server closed") { + s.logger.Debug(E.Cause(err, "listener closed")) + } else { + s.logger.Error(E.Cause(err, "listener closed")) + } + return + } + go s.handleConnection(connection) + } +} + +func (s *Server) handleConnection(connection quic.Connection) { + session := &serverSession{ + Server: s, + ctx: s.ctx, + quicConn: connection, + source: M.SocksaddrFromNet(connection.RemoteAddr()), + connDone: make(chan struct{}), + udpConnMap: make(map[uint32]*udpPacketConn), + } + httpServer := http3.Server{ + Handler: session, + StreamHijacker: session.handleStream0, + } + _ = httpServer.ServeQUICConn(connection) + _ = connection.CloseWithError(0, "") +} + +type serverSession struct { + *Server + ctx context.Context + quicConn quic.Connection + source M.Socksaddr + connAccess sync.Mutex + connDone chan struct{} + connErr error + authenticated bool + authUser *User + udpAccess sync.RWMutex + udpConnMap map[uint32]*udpPacketConn +} + +func (s *serverSession) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost && r.Host == protocol.URLHost && r.URL.Path == protocol.URLPath { + if s.authenticated { + protocol.AuthResponseToHeader(w.Header(), protocol.AuthResponse{ + UDPEnabled: !s.udpDisabled, + Rx: s.receiveBPS, + RxAuto: s.ignoreClientBandwidth, + }) + w.WriteHeader(protocol.StatusAuthOK) + return + } + request := protocol.AuthRequestFromHeader(r.Header) + user, loaded := s.userMap[request.Auth] + if !loaded { + s.masqueradeHandler.ServeHTTP(w, r) + return + } + s.authUser = &user + s.authenticated = true + if !s.ignoreClientBandwidth && request.Rx > 0 { + var sendBps uint64 + if s.sendBPS > 0 && s.sendBPS < request.Rx { + sendBps = s.sendBPS + } else { + sendBps = request.Rx + } + s.quicConn.SetCongestionControl(congestion.NewBrutalSender(sendBps)) + } else { + s.quicConn.SetCongestionControl(tuicCongestion.NewBBRSender( + tuicCongestion.DefaultClock{}, + tuicCongestion.GetInitialPacketSize(s.quicConn.RemoteAddr()), + tuicCongestion.InitialCongestionWindow*tuicCongestion.InitialMaxDatagramSize, + tuicCongestion.DefaultBBRMaxCongestionWindow*tuicCongestion.InitialMaxDatagramSize, + )) + } + protocol.AuthResponseToHeader(w.Header(), protocol.AuthResponse{ + UDPEnabled: !s.udpDisabled, + Rx: s.receiveBPS, + RxAuto: s.ignoreClientBandwidth, + }) + w.WriteHeader(protocol.StatusAuthOK) + if s.ctx.Done() != nil { + go func() { + select { + case <-s.ctx.Done(): + s.closeWithError(s.ctx.Err()) + case <-s.connDone: + } + }() + } + if !s.udpDisabled { + go s.loopMessages() + } + } else { + s.masqueradeHandler.ServeHTTP(w, r) + } +} + +func (s *serverSession) handleStream0(frameType http3.FrameType, connection quic.Connection, stream quic.Stream, err error) (bool, error) { + if !s.authenticated || err != nil { + return false, nil + } + if frameType != protocol.FrameTypeTCPRequest { + return false, nil + } + go func() { + hErr := s.handleStream(stream) + if hErr != nil { + stream.CancelRead(0) + stream.Close() + s.logger.Error(E.Cause(hErr, "handle stream request")) + } + }() + return true, nil +} + +func (s *serverSession) handleStream(stream quic.Stream) error { + destinationString, err := protocol.ReadTCPRequest(stream) + if err != nil { + return E.New("read TCP request") + } + var conn net.Conn = &serverConn{ + Stream: stream, + } + ctx := s.ctx + if s.authUser.Name != "" { + ctx = auth.ContextWithUser(s.ctx, s.authUser.Name) + } + _ = s.handler.NewConnection(ctx, conn, M.Metadata{ + Source: s.source, + Destination: M.ParseSocksaddr(destinationString), + }) + return nil +} + +func (s *serverSession) closeWithError(err error) { + s.connAccess.Lock() + defer s.connAccess.Unlock() + select { + case <-s.connDone: + return + default: + s.connErr = err + close(s.connDone) + } + if E.IsClosedOrCanceled(err) { + s.logger.Debug(E.Cause(err, "connection failed")) + } else { + s.logger.Error(E.Cause(err, "connection failed")) + } + _ = s.quicConn.CloseWithError(0, "") +} + +type serverConn struct { + quic.Stream + responseWritten bool +} + +func (c *serverConn) HandshakeFailure(err error) error { + if c.responseWritten { + return os.ErrClosed + } + c.responseWritten = true + buffer := protocol.WriteTCPResponse(false, err.Error(), nil) + defer buffer.Release() + return common.Error(c.Stream.Write(buffer.Bytes())) +} + +func (c *serverConn) Read(p []byte) (n int, err error) { + n, err = c.Stream.Read(p) + return n, baderror.WrapQUIC(err) +} + +func (c *serverConn) Write(p []byte) (n int, err error) { + if !c.responseWritten { + c.responseWritten = true + buffer := protocol.WriteTCPResponse(true, "", p) + defer buffer.Release() + _, err = c.Stream.Write(buffer.Bytes()) + if err != nil { + return 0, baderror.WrapQUIC(err) + } + return len(p), nil + } + n, err = c.Stream.Write(p) + return n, baderror.WrapQUIC(err) +} + +func (c *serverConn) LocalAddr() net.Addr { + return M.Socksaddr{} +} + +func (c *serverConn) RemoteAddr() net.Addr { + return M.Socksaddr{} +} + +func (c *serverConn) Close() error { + c.Stream.CancelRead(0) + return c.Stream.Close() +} diff --git a/transport/hysteria2/server_packet.go b/transport/hysteria2/server_packet.go new file mode 100644 index 00000000..c09ba6eb --- /dev/null +++ b/transport/hysteria2/server_packet.go @@ -0,0 +1,55 @@ +package hysteria2 + +import ( + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" +) + +func (s *serverSession) loopMessages() { + for { + message, err := s.quicConn.ReceiveMessage(s.ctx) + if err != nil { + s.closeWithError(E.Cause(err, "receive message")) + return + } + hErr := s.handleMessage(message) + if hErr != nil { + s.closeWithError(E.Cause(hErr, "handle message")) + return + } + } +} + +func (s *serverSession) handleMessage(data []byte) error { + message := udpMessagePool.Get().(*udpMessage) + err := decodeUDPMessage(message, data) + if err != nil { + message.release() + return E.Cause(err, "decode UDP message") + } + s.handleUDPMessage(message) + return nil +} + +func (s *serverSession) handleUDPMessage(message *udpMessage) { + s.udpAccess.RLock() + udpConn, loaded := s.udpConnMap[message.sessionID] + s.udpAccess.RUnlock() + if !loaded || common.Done(udpConn.ctx) { + udpConn = newUDPPacketConn(s.ctx, s.quicConn, func() { + s.udpAccess.Lock() + delete(s.udpConnMap, message.sessionID) + s.udpAccess.Unlock() + }) + udpConn.sessionID = message.sessionID + s.udpAccess.Lock() + s.udpConnMap[message.sessionID] = udpConn + s.udpAccess.Unlock() + go s.handler.NewPacketConnection(udpConn.ctx, udpConn, M.Metadata{ + Source: s.source, + Destination: M.ParseSocksaddr(message.destination), + }) + } + udpConn.inputPacket(message) +}