diff --git a/inbound/default.go b/inbound/default.go index 631a3adf..83d0c45d 100644 --- a/inbound/default.go +++ b/inbound/default.go @@ -325,7 +325,7 @@ func (a *myInboundAdapter) NewError(ctx context.Context, err error) { func NewError(logger log.ContextLogger, ctx context.Context, err error) { common.Close(err) if E.IsClosedOrCanceled(err) { - logger.DebugContext(ctx, "connection closed") + logger.TraceContext(ctx, "connection closed: ", err) return } logger.ErrorContext(ctx, err) diff --git a/outbound/dns.go b/outbound/dns.go index 8f93e75f..03539131 100644 --- a/outbound/dns.go +++ b/outbound/dns.go @@ -6,6 +6,8 @@ import ( "io" "net" "os" + "sync" + "time" "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" @@ -14,6 +16,7 @@ import ( "github.com/sagernet/sing/common/buf" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/task" "golang.org/x/net/dns/dnsmessage" ) @@ -45,6 +48,7 @@ func (d *DNS) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.Pa } func (d *DNS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { + defer conn.Close() ctx = adapter.WithContext(ctx, &metadata) _buffer := buf.StackNewSize(1024) defer common.KeepAlive(_buffer) @@ -97,45 +101,69 @@ func (d *DNS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter } func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + defer conn.Close() ctx = adapter.WithContext(ctx, &metadata) _buffer := buf.StackNewSize(1024) defer common.KeepAlive(_buffer) buffer := common.Dup(_buffer) defer buffer.Release() - for { - buffer.FullReset() - destination, err := conn.ReadPacket(buffer) - if err != nil { - return err - } - var message dnsmessage.Message - err = message.Unpack(buffer.Bytes()) - if err != nil { - return err - } - if len(message.Questions) > 0 { - question := message.Questions[0] - metadata.Domain = string(question.Name.Data[:question.Name.Length-1]) - d.logger.DebugContext(ctx, "inbound dns query ", formatDNSQuestion(question), " from ", metadata.Source) - } - go func() error { - response, err := d.router.Exchange(ctx, &message) + var wg sync.WaitGroup + fastClose, cancel := context.WithCancel(ctx) + err := task.Run(fastClose, func() error { + var count int + for { + buffer.FullReset() + destination, err := conn.ReadPacket(buffer) if err != nil { return err } - _responseBuffer := buf.StackNewSize(1024) - defer common.KeepAlive(_responseBuffer) - responseBuffer := common.Dup(_responseBuffer) - defer responseBuffer.Release() - n, err := response.AppendPack(responseBuffer.Index(0)) + var message dnsmessage.Message + err = message.Unpack(buffer.Bytes()) if err != nil { return err } - responseBuffer.Truncate(len(n)) - err = conn.WritePacket(responseBuffer, destination) - return err - }() - } + if len(message.Questions) > 0 { + question := message.Questions[0] + metadata.Domain = string(question.Name.Data[:question.Name.Length-1]) + d.logger.DebugContext(ctx, "inbound dns query ", formatDNSQuestion(question), " from ", metadata.Source) + } + wg.Add(1) + go func() error { + defer wg.Done() + response, err := d.router.Exchange(ctx, &message) + if err != nil { + return err + } + _responseBuffer := buf.StackNewSize(1024) + defer common.KeepAlive(_responseBuffer) + responseBuffer := common.Dup(_responseBuffer) + defer responseBuffer.Release() + n, err := response.AppendPack(responseBuffer.Index(0)) + if err != nil { + return err + } + responseBuffer.Truncate(len(n)) + err = conn.WritePacket(responseBuffer, destination) + return err + }() + count++ + if count == 2 { + break + } + } + cancel() + return nil + }, func() error { + timer := time.NewTimer(5 * time.Second) + select { + case <-timer.C: + cancel() + case <-fastClose.Done(): + } + return nil + }) + wg.Wait() + return err } func formatDNSQuestion(question dnsmessage.Question) string {