diff --git a/inbound/shadowtls.go b/inbound/shadowtls.go index b6c1480f..d3da651b 100644 --- a/inbound/shadowtls.go +++ b/inbound/shadowtls.go @@ -91,7 +91,7 @@ func (s *ShadowTLS) NewConnection(ctx context.Context, conn net.Conn, metadata a hashConn := shadowtls.NewHashWriteConn(conn, s.password) go bufio.Copy(hashConn, handshakeConn) var request *buf.Buffer - request, err = s.copyUntilHandshakeFinishedV2(handshakeConn, conn, hashConn, s.fallbackAfter) + request, err = s.copyUntilHandshakeFinishedV2(ctx, handshakeConn, conn, hashConn, s.fallbackAfter) if err == nil { handshakeConn.Close() return s.newConnection(ctx, bufio.NewCachedConn(shadowtls.NewConn(conn), request), metadata) @@ -135,7 +135,7 @@ func (s *ShadowTLS) copyUntilHandshakeFinished(dst io.Writer, src io.Reader) err } } -func (s *ShadowTLS) copyUntilHandshakeFinishedV2(dst net.Conn, src io.Reader, hash *shadowtls.HashWriteConn, fallbackAfter int) (*buf.Buffer, error) { +func (s *ShadowTLS) copyUntilHandshakeFinishedV2(ctx context.Context, dst net.Conn, src io.Reader, hash *shadowtls.HashWriteConn, fallbackAfter int) (*buf.Buffer, error) { const applicationData = 0x17 var tlsHdr [5]byte var applicationDataCount int @@ -152,9 +152,17 @@ func (s *ShadowTLS) copyUntilHandshakeFinishedV2(dst net.Conn, src io.Reader, ha data.Release() return nil, err } - if length >= 8 && bytes.Equal(data.To(8), hash.Sum()) { - data.Advance(8) - return data, nil + if hash.HasContent() && length >= 8 { + checksum := hash.Sum() + if bytes.Equal(data.To(8), checksum) { + s.logger.TraceContext(ctx, "match current hashcode") + data.Advance(8) + return data, nil + } else if hash.LastSum() != nil && bytes.Equal(data.To(8), hash.LastSum()) { + s.logger.TraceContext(ctx, "match last hashcode") + data.Advance(8) + return data, nil + } } _, err = io.Copy(dst, io.MultiReader(bytes.NewReader(tlsHdr[:]), data)) data.Release() diff --git a/transport/shadowtls/hash.go b/transport/shadowtls/hash.go index 0706f865..b6873121 100644 --- a/transport/shadowtls/hash.go +++ b/transport/shadowtls/hash.go @@ -34,19 +34,25 @@ func (c *HashReadConn) Sum() []byte { type HashWriteConn struct { net.Conn - hmac hash.Hash + hmac hash.Hash + hasContent bool + lastSum []byte } func NewHashWriteConn(conn net.Conn, password string) *HashWriteConn { return &HashWriteConn{ - conn, - hmac.New(sha1.New, []byte(password)), + Conn: conn, + hmac: hmac.New(sha1.New, []byte(password)), } } func (c *HashWriteConn) Write(p []byte) (n int, err error) { if c.hmac != nil { + if c.hasContent { + c.lastSum = c.Sum() + } c.hmac.Write(p) + c.hasContent = true } return c.Conn.Write(p) } @@ -55,6 +61,14 @@ func (c *HashWriteConn) Sum() []byte { return c.hmac.Sum(nil)[:8] } +func (c *HashWriteConn) LastSum() []byte { + return c.lastSum +} + func (c *HashWriteConn) Fallback() { c.hmac = nil } + +func (c *HashWriteConn) HasContent() bool { + return c.hasContent +}