From 0f4a0a7275313f164c7b61ab8f081a191197a772 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Fri, 25 Nov 2022 08:08:14 +0800 Subject: [PATCH] chore: add tuic outbound close #133 --- README.md | 19 ++ adapter/outbound/tuic.go | 216 +++++++++++++++++ adapter/parser.go | 13 +- component/dialer/dialer.go | 28 ++- constant/adapters.go | 3 + docs/config.yaml | 14 ++ go.mod | 2 +- transport/tuic/client.go | 377 ++++++++++++++++++++++++++++++ transport/tuic/protocol.go | 468 +++++++++++++++++++++++++++++++++++++ 9 files changed, 1125 insertions(+), 15 deletions(-) create mode 100644 adapter/outbound/tuic.go create mode 100644 transport/tuic/client.go create mode 100644 transport/tuic/protocol.go diff --git a/README.md b/README.md index ce16182d..4a3fc3b5 100644 --- a/README.md +++ b/README.md @@ -227,6 +227,25 @@ proxies: udp: true ``` +Support outbound transport protocol `Tuic` +```yaml +proxies: + - name: "tuic" + server: www.example.com + port: 10443 + type: tuic + token: TOKEN + # ip: 127.0.0.1 + # heartbeat_interval: 10000 + # alpn: [h3] + # disable_sni: true + reduce_rtt: true +# request_timeout: 8000 + udp_relay_mode: native + # skip-cert-verify: true + +``` + ### IPTABLES configuration Work on Linux OS who's supported `iptables` diff --git a/adapter/outbound/tuic.go b/adapter/outbound/tuic.go new file mode 100644 index 00000000..43843fdb --- /dev/null +++ b/adapter/outbound/tuic.go @@ -0,0 +1,216 @@ +package outbound + +import ( + "context" + "crypto/sha256" + "crypto/tls" + "encoding/hex" + "encoding/pem" + "fmt" + "net" + "os" + "strconv" + "sync" + "time" + + "github.com/lucas-clemente/quic-go" + + "github.com/Dreamacro/clash/component/dialer" + tlsC "github.com/Dreamacro/clash/component/tls" + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/transport/tuic" +) + +type Tuic struct { + *Base + option *TuicOption + getClient func(opts ...dialer.Option) *tuic.Client +} + +type TuicOption struct { + BasicOption + Name string `proxy:"name"` + Server string `proxy:"server"` + Port int `proxy:"port"` + Token string `proxy:"token"` + Ip string `proxy:"ip,omitempty"` + HeartbeatInterval int `proxy:"heartbeat_interval,omitempty"` + ALPN []string `proxy:"alpn,omitempty"` + ReduceRtt bool `proxy:"reduce_rtt,omitempty"` + RequestTimeout int `proxy:"request_timeout,omitempty"` + UdpRelayMode string `proxy:"udp_relay_mode,omitempty"` + DisableSni bool `proxy:"disable_sni,omitempty"` + + SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` + Fingerprint string `proxy:"fingerprint,omitempty"` + CustomCA string `proxy:"ca,omitempty"` + CustomCAString string `proxy:"ca_str,omitempty"` + ReceiveWindowConn int `proxy:"recv_window_conn,omitempty"` + ReceiveWindow int `proxy:"recv_window,omitempty"` + DisableMTUDiscovery bool `proxy:"disable_mtu_discovery,omitempty"` +} + +// DialContext implements C.ProxyAdapter +func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) { + opts = t.Base.DialOptions(opts...) + conn, err := t.getClient(opts...).DialContext(ctx, metadata, func(ctx context.Context) (net.PacketConn, net.Addr, error) { + pc, err := dialer.ListenPacket(ctx, "udp", "", opts...) + if err != nil { + return nil, nil, err + } + addr, err := resolveUDPAddrWithPrefer(ctx, "udp", t.addr, t.prefer) + if err != nil { + return nil, nil, err + } + return pc, addr, err + }) + if err != nil { + return nil, err + } + return NewConn(conn, t), err +} + +// ListenPacketContext implements C.ProxyAdapter +func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) { + opts = t.Base.DialOptions(opts...) + pc, err := t.getClient(opts...).ListenPacketContext(ctx, metadata, func(ctx context.Context) (net.PacketConn, net.Addr, error) { + pc, err := dialer.ListenPacket(ctx, "udp", "", opts...) + if err != nil { + return nil, nil, err + } + addr, err := resolveUDPAddrWithPrefer(ctx, "udp", t.addr, t.prefer) + if err != nil { + return nil, nil, err + } + return pc, addr, err + }) + if err != nil { + return nil, err + } + return newPacketConn(pc, t), nil +} + +func NewTuic(option TuicOption) (*Tuic, error) { + addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port)) + serverName := option.Server + + tlsConfig := &tls.Config{ + ServerName: serverName, + InsecureSkipVerify: option.SkipCertVerify, + MinVersion: tls.VersionTLS13, + } + + var bs []byte + var err error + if len(option.CustomCA) > 0 { + bs, err = os.ReadFile(option.CustomCA) + if err != nil { + return nil, fmt.Errorf("hysteria %s load ca error: %w", addr, err) + } + } else if option.CustomCAString != "" { + bs = []byte(option.CustomCAString) + } + + if len(bs) > 0 { + block, _ := pem.Decode(bs) + if block == nil { + return nil, fmt.Errorf("CA cert is not PEM") + } + + fpBytes := sha256.Sum256(block.Bytes) + if len(option.Fingerprint) == 0 { + option.Fingerprint = hex.EncodeToString(fpBytes[:]) + } + } + + if len(option.Fingerprint) != 0 { + var err error + tlsConfig, err = tlsC.GetSpecifiedFingerprintTLSConfig(tlsConfig, option.Fingerprint) + if err != nil { + return nil, err + } + } else { + tlsConfig = tlsC.GetGlobalFingerprintTLCConfig(tlsConfig) + } + + if len(option.ALPN) > 0 { + tlsConfig.NextProtos = option.ALPN + } else { + tlsConfig.NextProtos = []string{"h3"} + } + + if option.RequestTimeout == 0 { + option.RequestTimeout = 8000 + } + + if option.HeartbeatInterval <= 0 { + option.HeartbeatInterval = 10000 + } + + if option.UdpRelayMode != "quic" { + option.UdpRelayMode = "native" + } + + quicConfig := &quic.Config{ + InitialStreamReceiveWindow: uint64(option.ReceiveWindowConn), + MaxStreamReceiveWindow: uint64(option.ReceiveWindowConn), + InitialConnectionReceiveWindow: uint64(option.ReceiveWindow), + MaxConnectionReceiveWindow: uint64(option.ReceiveWindow), + KeepAlivePeriod: time.Duration(option.HeartbeatInterval) * time.Millisecond, + DisablePathMTUDiscovery: option.DisableMTUDiscovery, + EnableDatagrams: true, + } + if option.ReceiveWindowConn == 0 { + quicConfig.InitialStreamReceiveWindow = DefaultStreamReceiveWindow / 10 + quicConfig.MaxStreamReceiveWindow = DefaultStreamReceiveWindow + } + if option.ReceiveWindow == 0 { + quicConfig.InitialConnectionReceiveWindow = DefaultConnectionReceiveWindow / 10 + quicConfig.MaxConnectionReceiveWindow = DefaultConnectionReceiveWindow + } + + if len(option.Ip) > 0 { + addr = net.JoinHostPort(option.Ip, strconv.Itoa(option.Port)) + } + host := option.Server + if option.DisableSni { + host = "" + tlsConfig.ServerName = "" + } + tkn := tuic.GenTKN(option.Token) + clientMap := make(map[any]*tuic.Client) + clientMapMutex := sync.Mutex{} + getClient := func(opts ...dialer.Option) *tuic.Client { + o := *dialer.ApplyOptions(opts...) + + clientMapMutex.Lock() + defer clientMapMutex.Unlock() + if client, ok := clientMap[o]; ok && client != nil { + return client + } + client := &tuic.Client{ + TlsConfig: tlsConfig, + QuicConfig: quicConfig, + Host: host, + Token: tkn, + UdpRelayMode: option.UdpRelayMode, + ReduceRtt: option.ReduceRtt, + RequestTimeout: option.RequestTimeout, + } + clientMap[o] = client + return client + } + + return &Tuic{ + Base: &Base{ + name: option.Name, + addr: addr, + tp: C.Tuic, + udp: true, + iface: option.Interface, + prefer: C.NewDNSPrefer(option.IPVersion), + }, + option: &option, + getClient: getClient, + }, nil +} diff --git a/adapter/parser.go b/adapter/parser.go index 5d145998..0ce054f8 100644 --- a/adapter/parser.go +++ b/adapter/parser.go @@ -89,12 +89,19 @@ func ParseProxy(mapping map[string]any) (C.Proxy, error) { } proxy, err = outbound.NewHysteria(*hyOption) case "wireguard": - hyOption := &outbound.WireGuardOption{} - err = decoder.Decode(mapping, hyOption) + wgOption := &outbound.WireGuardOption{} + err = decoder.Decode(mapping, wgOption) if err != nil { break } - proxy, err = outbound.NewWireGuard(*hyOption) + proxy, err = outbound.NewWireGuard(*wgOption) + case "tuic": + tuicOption := &outbound.TuicOption{} + err = decoder.Decode(mapping, tuicOption) + if err != nil { + break + } + proxy, err = outbound.NewTuic(*tuicOption) default: return nil, fmt.Errorf("unsupport proxy type: %s", proxyType) } diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index d3dc36ba..cb87061c 100644 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -22,7 +22,7 @@ var ( ErrorDisableIPv6 = errors.New("IPv6 is disabled, dialer cancel") ) -func DialContext(ctx context.Context, network, address string, options ...Option) (net.Conn, error) { +func ApplyOptions(options ...Option) *option { opt := &option{ interfaceName: DefaultInterface.Load(), routingMark: int(DefaultRoutingMark.Load()), @@ -36,6 +36,12 @@ func DialContext(ctx context.Context, network, address string, options ...Option o(opt) } + return opt +} + +func DialContext(ctx context.Context, network, address string, options ...Option) (net.Conn, error) { + opt := ApplyOptions(options...) + if opt.network == 4 || opt.network == 6 { if strings.Contains(network, "tcp") { network = "tcp" @@ -204,15 +210,15 @@ func dualStackDialContext(ctx context.Context, network, address string, opt *opt } } case <-ctx.Done(): - err=ctx.Err() + err = ctx.Err() break } } - if err==nil { - err=fmt.Errorf("dual stack dial failed") - }else{ - err=fmt.Errorf("dual stack dial failed:%w",err) + if err == nil { + err = fmt.Errorf("dual stack dial failed") + } else { + err = fmt.Errorf("dual stack dial failed:%w", err) } return nil, err } @@ -322,7 +328,7 @@ func concurrentDialContext(ctx context.Context, network string, ips []netip.Addr if fallback.done && fallback.error == nil { return fallback.Conn, nil } - finalError=ctx.Err() + finalError = ctx.Err() break } } @@ -339,10 +345,10 @@ func concurrentDialContext(ctx context.Context, network string, ips []netip.Addr return nil, fallback.error } - if finalError==nil { - finalError=fmt.Errorf("all ips %v tcp shake hands failed", ips) - }else{ - finalError=fmt.Errorf("concurrent dial failed:%w",finalError) + if finalError == nil { + finalError = fmt.Errorf("all ips %v tcp shake hands failed", ips) + } else { + finalError = fmt.Errorf("concurrent dial failed:%w", finalError) } return nil, finalError diff --git a/constant/adapters.go b/constant/adapters.go index 47826a74..53d03fb0 100644 --- a/constant/adapters.go +++ b/constant/adapters.go @@ -32,6 +32,7 @@ const ( Trojan Hysteria WireGuard + Tuic ) const ( @@ -168,6 +169,8 @@ func (at AdapterType) String() string { return "Hysteria" case WireGuard: return "WireGuard" + case Tuic: + return "Tuic" case Relay: return "Relay" diff --git a/docs/config.yaml b/docs/config.yaml index 36050e28..39ed60be 100644 --- a/docs/config.yaml +++ b/docs/config.yaml @@ -458,6 +458,20 @@ proxies: public-key: Cr8hWlKvtDt7nrvf+f0brNQQzabAqrjfBvas9pmowjo= udp: true + - name: tuic + server: www.example.com + port: 10443 + type: tuic + token: TOKEN + # ip: 127.0.0.1 + # heartbeat_interval: 10000 + # alpn: [h3] + # disable_sni: true + reduce_rtt: true +# request_timeout: 8000 + udp_relay_mode: native + # skip-cert-verify: true + # ShadowsocksR # The supported ciphers (encryption methods): all stream ciphers in ss # The supported obfses: diff --git a/go.mod b/go.mod index feef8810..6a06a917 100644 --- a/go.mod +++ b/go.mod @@ -42,6 +42,7 @@ require ( golang.org/x/sys v0.2.0 google.golang.org/protobuf v1.28.1 gopkg.in/yaml.v3 v3.0.1 + lukechampine.com/blake3 v1.1.7 ) @@ -75,5 +76,4 @@ require ( golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect golang.org/x/tools v0.1.12 // indirect gvisor.dev/gvisor v0.0.0-20220901235040-6ca97ef2ce1c // indirect - lukechampine.com/blake3 v1.1.7 // indirect ) diff --git a/transport/tuic/client.go b/transport/tuic/client.go new file mode 100644 index 00000000..3d9506c1 --- /dev/null +++ b/transport/tuic/client.go @@ -0,0 +1,377 @@ +package tuic + +import ( + "bufio" + "bytes" + "context" + "crypto/tls" + "errors" + "math/rand" + "net" + "net/netip" + "sync" + "time" + + "github.com/lucas-clemente/quic-go" + + N "github.com/Dreamacro/clash/common/net" + C "github.com/Dreamacro/clash/constant" +) + +type Client struct { + TlsConfig *tls.Config + QuicConfig *quic.Config + Host string + Token [32]byte + UdpRelayMode string + ReduceRtt bool + RequestTimeout int + + quicConn quic.Connection + connMutex sync.Mutex + + udpInputMap sync.Map +} + +func (t *Client) getQuicConn(ctx context.Context, dialFn func(ctx context.Context) (net.PacketConn, net.Addr, error)) (quic.Connection, error) { + t.connMutex.Lock() + defer t.connMutex.Unlock() + if t.quicConn != nil { + return t.quicConn, nil + } + pc, addr, err := dialFn(ctx) + if err != nil { + return nil, err + } + var quicConn quic.Connection + if t.ReduceRtt { + quicConn, err = quic.DialEarlyContext(ctx, pc, addr, t.Host, t.TlsConfig, t.QuicConfig) + } else { + quicConn, err = quic.DialContext(ctx, pc, addr, t.Host, t.TlsConfig, t.QuicConfig) + } + if err != nil { + return nil, err + } + + sendAuthentication := func(quicConn quic.Connection) (err error) { + defer func() { + t.deferQuicConn(quicConn, err) + }() + stream, err := quicConn.OpenUniStream() + if err != nil { + return err + } + buf := &bytes.Buffer{} + err = NewAuthenticate(t.Token).WriteTo(buf) + if err != nil { + return err + } + _, err = buf.WriteTo(stream) + if err != nil { + return err + } + err = stream.Close() + if err != nil { + return + } + return nil + } + + go sendAuthentication(quicConn) + + go func(quicConn quic.Connection) (err error) { + defer func() { + t.deferQuicConn(quicConn, err) + }() + switch t.UdpRelayMode { + case "quic": + for { + var stream quic.ReceiveStream + stream, err = quicConn.AcceptUniStream(context.Background()) + if err != nil { + return err + } + go func() (err error) { + var assocId uint32 + defer func() { + t.deferQuicConn(quicConn, err) + if err != nil && assocId != 0 { + if val, ok := t.udpInputMap.LoadAndDelete(assocId); ok { + if conn, ok := val.(net.Conn); ok { + _ = conn.Close() + } + } + } + }() + reader := bufio.NewReader(stream) + packet, err := ReadPacket(reader) + if err != nil { + return + } + assocId = packet.ASSOC_ID + if val, ok := t.udpInputMap.Load(assocId); ok { + if conn, ok := val.(net.Conn); ok { + writer := bufio.NewWriterSize(conn, packet.BytesLen()) + _ = packet.WriteTo(writer) + _ = writer.Flush() + } + } + return + }() + } + default: // native + for { + var message []byte + message, err = quicConn.ReceiveMessage() + if err != nil { + return err + } + go func() (err error) { + var assocId uint32 + defer func() { + t.deferQuicConn(quicConn, err) + if err != nil && assocId != 0 { + if val, ok := t.udpInputMap.LoadAndDelete(assocId); ok { + if conn, ok := val.(net.Conn); ok { + _ = conn.Close() + } + } + } + }() + buffer := bytes.NewBuffer(message) + packet, err := ReadPacket(buffer) + if err != nil { + return + } + assocId = packet.ASSOC_ID + if val, ok := t.udpInputMap.Load(assocId); ok { + if conn, ok := val.(net.Conn); ok { + _, _ = conn.Write(message) + } + } + return + }() + } + } + }(quicConn) + + t.quicConn = quicConn + return quicConn, nil +} + +func (t *Client) deferQuicConn(quicConn quic.Connection, err error) { + var netError net.Error + if err != nil && errors.As(err, &netError) { + t.connMutex.Lock() + defer t.connMutex.Unlock() + if t.quicConn == quicConn { + t.udpInputMap.Range(func(key, value any) bool { + if conn, ok := value.(net.Conn); ok { + _ = conn.Close() + } + return true + }) + t.udpInputMap = sync.Map{} // new one + t.quicConn = nil + } + } +} + +func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn func(ctx context.Context) (net.PacketConn, net.Addr, error)) (net.Conn, error) { + quicConn, err := t.getQuicConn(ctx, dialFn) + if err != nil { + return nil, err + } + defer func() { + t.deferQuicConn(quicConn, err) + }() + buf := &bytes.Buffer{} + err = NewConnect(NewAddress(metadata)).WriteTo(buf) + if err != nil { + return nil, err + } + stream, err := quicConn.OpenStream() + if err != nil { + return nil, err + } + _, err = buf.WriteTo(stream) + if err != nil { + return nil, err + } + if t.RequestTimeout > 0 { + _ = stream.SetReadDeadline(time.Now().Add(time.Duration(t.RequestTimeout) * time.Millisecond)) + } + conn := N.NewBufferedConn(&quicStreamConn{stream, quicConn.LocalAddr(), quicConn.RemoteAddr()}) + response, err := ReadResponse(conn) + if err != nil { + return nil, err + } + if response.IsFailed() { + _ = stream.Close() + return nil, errors.New("connect failed") + } + _ = stream.SetReadDeadline(time.Time{}) + return conn, err +} + +type quicStreamConn struct { + quic.Stream + lAddr net.Addr + rAddr net.Addr +} + +func (q *quicStreamConn) LocalAddr() net.Addr { + return q.lAddr +} + +func (q *quicStreamConn) RemoteAddr() net.Addr { + return q.rAddr +} + +var _ net.Conn = &quicStreamConn{} + +func (t *Client) ListenPacketContext(ctx context.Context, metadata *C.Metadata, dialFn func(ctx context.Context) (net.PacketConn, net.Addr, error)) (net.PacketConn, error) { + quicConn, err := t.getQuicConn(ctx, dialFn) + if err != nil { + return nil, err + } + + pipe1, pipe2 := net.Pipe() + inputCh := make(chan udpData) + var connId uint32 + for { + connId = rand.Uint32() + _, loaded := t.udpInputMap.LoadOrStore(connId, pipe1) + if !loaded { + break + } + } + pc := &quicStreamPacketConn{ + connId: connId, + quicConn: quicConn, + lAddr: quicConn.LocalAddr(), + client: t, + inputConn: N.NewBufferedConn(pipe2), + inputCh: inputCh, + } + return pc, nil +} + +type udpData struct { + data []byte + addr net.Addr + err error +} + +type quicStreamPacketConn struct { + connId uint32 + quicConn quic.Connection + lAddr net.Addr + client *Client + inputConn *N.BufferedConn + inputCh chan udpData + + closeOnce sync.Once + closeErr error +} + +func (q *quicStreamPacketConn) Close() error { + q.closeOnce.Do(func() { + q.closeErr = q.close() + }) + return q.closeErr +} + +func (q *quicStreamPacketConn) close() (err error) { + defer func() { + q.client.deferQuicConn(q.quicConn, err) + }() + buf := &bytes.Buffer{} + err = NewDissociate(q.connId).WriteTo(buf) + if err != nil { + return + } + stream, err := q.quicConn.OpenUniStream() + if err != nil { + return + } + _, err = buf.WriteTo(stream) + if err != nil { + return + } + err = stream.Close() + if err != nil { + return + } + return +} + +func (q *quicStreamPacketConn) SetDeadline(t time.Time) error { + //TODO implement me + return nil +} + +func (q *quicStreamPacketConn) SetReadDeadline(t time.Time) error { + return q.inputConn.SetReadDeadline(t) +} + +func (q *quicStreamPacketConn) SetWriteDeadline(t time.Time) error { + //TODO implement me + return nil +} + +func (q *quicStreamPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + packet, err := ReadPacket(q.inputConn) + if err != nil { + return + } + n = copy(p, packet.DATA) + addr = packet.ADDR.UDPAddr() + return +} + +func (q *quicStreamPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + defer func() { + q.client.deferQuicConn(q.quicConn, err) + }() + addr.String() + buf := &bytes.Buffer{} + addrPort, err := netip.ParseAddrPort(addr.String()) + if err != nil { + return + } + err = NewPacket(q.connId, uint16(len(p)), NewAddressAddrPort(addrPort), p).WriteTo(buf) + if err != nil { + return + } + switch q.client.UdpRelayMode { + case "quic": + var stream quic.SendStream + stream, err = q.quicConn.OpenUniStream() + if err != nil { + return + } + _, err = buf.WriteTo(stream) + if err != nil { + return + } + err = stream.Close() + if err != nil { + return + } + default: // native + err = q.quicConn.SendMessage(buf.Bytes()) + if err != nil { + return + } + } + n = len(p) + + return +} + +func (q *quicStreamPacketConn) LocalAddr() net.Addr { + return q.lAddr +} + +var _ net.PacketConn = &quicStreamPacketConn{} diff --git a/transport/tuic/protocol.go b/transport/tuic/protocol.go new file mode 100644 index 00000000..b3b000b4 --- /dev/null +++ b/transport/tuic/protocol.go @@ -0,0 +1,468 @@ +package tuic + +import ( + "encoding/binary" + "fmt" + "io" + "net" + "net/netip" + "strconv" + + "github.com/lucas-clemente/quic-go" + "lukechampine.com/blake3" + + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/transport/socks5" +) + +type BufferedReader interface { + io.Reader + io.ByteReader +} + +type BufferedWriter interface { + io.Writer + io.ByteWriter +} + +type CommandType byte + +const ( + AuthenticateType = CommandType(0x00) + ConnectType = CommandType(0x01) + PacketType = CommandType(0x02) + DissociateType = CommandType(0x03) + HeartbeatType = CommandType(0x04) + ResponseType = CommandType(0x05) +) + +func (c CommandType) String() string { + switch c { + case AuthenticateType: + return "Authenticate" + case ConnectType: + return "Connect" + case PacketType: + return "Packet" + case DissociateType: + return "Dissociate" + case HeartbeatType: + return "Heartbeat" + case ResponseType: + return "Response" + default: + return fmt.Sprintf("UnknowCommand: %#x", byte(c)) + } +} + +func (c CommandType) BytesLen() int { + return 1 +} + +type CommandHead struct { + VER byte + TYPE CommandType +} + +func NewCommandHead(TYPE CommandType) CommandHead { + return CommandHead{ + VER: 0x04, + TYPE: TYPE, + } +} + +func ReadCommandHead(reader BufferedReader) (c CommandHead, err error) { + c.VER, err = reader.ReadByte() + if err != nil { + return + } + TYPE, err := reader.ReadByte() + if err != nil { + return + } + c.TYPE = CommandType(TYPE) + return +} + +func (c CommandHead) WriteTo(writer BufferedWriter) (err error) { + err = writer.WriteByte(c.VER) + if err != nil { + return + } + err = writer.WriteByte(byte(c.TYPE)) + if err != nil { + return + } + return +} + +func (c CommandHead) BytesLen() int { + return 1 + c.TYPE.BytesLen() +} + +type Authenticate struct { + CommandHead + TKN [32]byte +} + +func NewAuthenticate(TKN [32]byte) Authenticate { + return Authenticate{ + CommandHead: NewCommandHead(AuthenticateType), + TKN: TKN, + } +} + +func GenTKN(token string) [32]byte { + return blake3.Sum256([]byte(token)) +} + +func (c Authenticate) WriteTo(writer BufferedWriter) (err error) { + err = c.CommandHead.WriteTo(writer) + if err != nil { + return + } + _, err = writer.Write(c.TKN[:]) + if err != nil { + return + } + return +} + +func (c Authenticate) BytesLen() int { + return c.CommandHead.BytesLen() + 32 +} + +type Connect struct { + CommandHead + ADDR Address +} + +func NewConnect(ADDR Address) Connect { + return Connect{ + CommandHead: NewCommandHead(ConnectType), + ADDR: ADDR, + } +} + +func (c Connect) WriteTo(writer BufferedWriter) (err error) { + err = c.CommandHead.WriteTo(writer) + if err != nil { + return + } + err = c.ADDR.WriteTo(writer) + if err != nil { + return + } + return +} + +func (c Connect) BytesLen() int { + return c.CommandHead.BytesLen() + c.ADDR.BytesLen() +} + +type Packet struct { + CommandHead + ASSOC_ID uint32 + LEN uint16 + ADDR Address + DATA []byte +} + +func NewPacket(ASSOC_ID uint32, LEN uint16, ADDR Address, DATA []byte) Packet { + return Packet{ + CommandHead: NewCommandHead(PacketType), + ASSOC_ID: ASSOC_ID, + LEN: LEN, + ADDR: ADDR, + DATA: DATA, + } +} + +func ReadPacket(reader BufferedReader) (c Packet, err error) { + c.CommandHead, err = ReadCommandHead(reader) + if err != nil { + return + } + if c.CommandHead.TYPE != PacketType { + err = fmt.Errorf("error command type: %s", c.CommandHead.TYPE) + } + err = binary.Read(reader, binary.BigEndian, &c.ASSOC_ID) + if err != nil { + return + } + err = binary.Read(reader, binary.BigEndian, &c.LEN) + if err != nil { + return + } + c.ADDR, err = ReadAddress(reader) + if err != nil { + return + } + c.DATA = make([]byte, c.LEN) + _, err = io.ReadFull(reader, c.DATA) + if err != nil { + return + } + return +} + +func (c Packet) WriteTo(writer BufferedWriter) (err error) { + err = c.CommandHead.WriteTo(writer) + if err != nil { + return + } + err = binary.Write(writer, binary.BigEndian, c.ASSOC_ID) + if err != nil { + return + } + err = binary.Write(writer, binary.BigEndian, c.LEN) + if err != nil { + return + } + err = c.ADDR.WriteTo(writer) + if err != nil { + return + } + _, err = writer.Write(c.DATA) + if err != nil { + return + } + return +} + +func (c Packet) BytesLen() int { + return c.CommandHead.BytesLen() + 4 + 2 + c.ADDR.BytesLen() + len(c.DATA) +} + +type Dissociate struct { + CommandHead + ASSOC_ID uint32 +} + +func NewDissociate(ASSOC_ID uint32) Dissociate { + return Dissociate{ + CommandHead: NewCommandHead(DissociateType), + ASSOC_ID: ASSOC_ID, + } +} + +func (c Dissociate) WriteTo(writer BufferedWriter) (err error) { + err = c.CommandHead.WriteTo(writer) + if err != nil { + return + } + err = binary.Write(writer, binary.BigEndian, c.ASSOC_ID) + if err != nil { + return + } + return +} + +func (c Dissociate) BytesLen() int { + return c.CommandHead.BytesLen() + 4 +} + +type Heartbeat struct { + CommandHead +} + +func NewHeartbeat() Heartbeat { + return Heartbeat{ + CommandHead: NewCommandHead(HeartbeatType), + } +} + +func ReadHeartbeat(reader BufferedReader) (c Response, err error) { + c.CommandHead, err = ReadCommandHead(reader) + if err != nil { + return + } + if c.CommandHead.TYPE != HeartbeatType { + err = fmt.Errorf("error command type: %s", c.CommandHead.TYPE) + } + return +} + +type Response struct { + CommandHead + REP byte +} + +func NewResponse(REP byte) Response { + return Response{ + CommandHead: NewCommandHead(ResponseType), + REP: REP, + } +} + +func ReadResponse(reader BufferedReader) (c Response, err error) { + c.CommandHead, err = ReadCommandHead(reader) + if err != nil { + return + } + if c.CommandHead.TYPE != ResponseType { + err = fmt.Errorf("error command type: %s", c.CommandHead.TYPE) + } + c.REP, err = reader.ReadByte() + if err != nil { + return + } + return +} + +func (c Response) WriteTo(writer BufferedWriter) (err error) { + err = c.CommandHead.WriteTo(writer) + if err != nil { + return + } + err = writer.WriteByte(c.REP) + if err != nil { + return + } + return +} + +func (c Response) IsSucceed() bool { + return c.REP == 0x00 +} + +func (c Response) IsFailed() bool { + return c.REP == 0xff +} + +func (c Response) BytesLen() int { + return c.CommandHead.BytesLen() + 1 +} + +// Addr types +const ( + AtypDomainName byte = 0 + AtypIPv4 byte = 1 + AtypIPv6 byte = 2 +) + +type Address struct { + TYPE byte + ADDR []byte + PORT uint16 +} + +func NewAddress(metadata *C.Metadata) Address { + var addrType byte + var addr []byte + switch metadata.AddrType() { + case socks5.AtypIPv4: + addrType = AtypIPv4 + addr = make([]byte, net.IPv4len) + copy(addr[:], metadata.DstIP.AsSlice()) + case socks5.AtypIPv6: + addrType = AtypIPv6 + addr = make([]byte, net.IPv6len) + copy(addr[:], metadata.DstIP.AsSlice()) + case socks5.AtypDomainName: + addrType = AtypDomainName + addr = make([]byte, len(metadata.Host)+1) + addr[0] = byte(len(metadata.Host)) + copy(addr[1:], metadata.Host) + } + + port, _ := strconv.ParseUint(metadata.DstPort, 10, 16) + + return Address{ + TYPE: addrType, + ADDR: addr, + PORT: uint16(port), + } +} + +func NewAddressAddrPort(addrPort netip.AddrPort) Address { + var addrType byte + var addr []byte + if addrPort.Addr().Is4() { + addrType = AtypIPv4 + addr = make([]byte, net.IPv4len) + } else { + addrType = AtypIPv6 + addr = make([]byte, net.IPv6len) + } + copy(addr[:], addrPort.Addr().AsSlice()) + return Address{ + TYPE: addrType, + ADDR: addr, + PORT: addrPort.Port(), + } +} + +func ReadAddress(reader BufferedReader) (c Address, err error) { + c.TYPE, err = reader.ReadByte() + if err != nil { + return + } + switch c.TYPE { + case AtypIPv4: + c.ADDR = make([]byte, net.IPv4len) + _, err = io.ReadFull(reader, c.ADDR) + if err != nil { + return + } + case AtypIPv6: + c.ADDR = make([]byte, net.IPv6len) + _, err = io.ReadFull(reader, c.ADDR) + if err != nil { + return + } + case AtypDomainName: + var addrLen byte + addrLen, err = reader.ReadByte() + if err != nil { + return + } + c.ADDR = make([]byte, addrLen+1) + c.ADDR[0] = addrLen + _, err = io.ReadFull(reader, c.ADDR[1:]) + if err != nil { + return + } + } + + err = binary.Read(reader, binary.BigEndian, &c.PORT) + if err != nil { + return + } + return +} + +func (c Address) WriteTo(writer BufferedWriter) (err error) { + err = writer.WriteByte(c.TYPE) + if err != nil { + return + } + _, err = writer.Write(c.ADDR[:]) + if err != nil { + return + } + err = binary.Write(writer, binary.BigEndian, c.PORT) + if err != nil { + return + } + return +} + +func (c Address) UDPAddr() *net.UDPAddr { + return &net.UDPAddr{ + IP: c.ADDR, + Port: int(c.PORT), + Zone: "", + } +} + +func (c Address) BytesLen() int { + return 1 + len(c.ADDR) + 2 +} + +const ( + ProtocolError = quic.ApplicationErrorCode(0xfffffff0) + AuthenticationFailed = quic.ApplicationErrorCode(0xfffffff1) + AuthenticationTimeout = quic.ApplicationErrorCode(0xfffffff2) + BadCommand = quic.ApplicationErrorCode(0xfffffff3) +)