diff --git a/adapters/outbound/urltest.go b/adapters/outbound/urltest.go index b9efd8d2..60b4beda 100644 --- a/adapters/outbound/urltest.go +++ b/adapters/outbound/urltest.go @@ -110,8 +110,9 @@ func (u *URLTest) speedTest() { } defer atomic.StoreInt32(&u.once, 0) - picker, ctx, cancel := picker.WithTimeout(context.Background(), defaultURLTestTimeout) + ctx, cancel := context.WithTimeout(context.Background(), defaultURLTestTimeout) defer cancel() + picker := picker.WithoutAutoCancel(ctx) for _, p := range u.proxies { proxy := p picker.Go(func() (interface{}, error) { @@ -123,10 +124,12 @@ func (u *URLTest) speedTest() { }) } - fast := picker.Wait() + fast := picker.WaitWithoutCancel() if fast != nil { u.fast = fast.(C.Proxy) } + + picker.Wait() } func NewURLTest(option URLTestOption, proxies []C.Proxy) (*URLTest, error) { diff --git a/common/picker/picker.go b/common/picker/picker.go index a66b9bc6..a30202df 100644 --- a/common/picker/picker.go +++ b/common/picker/picker.go @@ -10,26 +10,42 @@ import ( // for groups of goroutines working on subtasks of a common task. // Inspired by errGroup type Picker struct { + ctx context.Context cancel func() wg sync.WaitGroup once sync.Once result interface{} + + firstDone chan struct{} +} + +func newPicker(ctx context.Context, cancel func()) *Picker { + return &Picker{ + ctx: ctx, + cancel: cancel, + firstDone: make(chan struct{}, 1), + } } // WithContext returns a new Picker and an associated Context derived from ctx. // and cancel when first element return. func WithContext(ctx context.Context) (*Picker, context.Context) { ctx, cancel := context.WithCancel(ctx) - return &Picker{cancel: cancel}, ctx + return newPicker(ctx, cancel), ctx } -// WithTimeout returns a new Picker and an associated Context derived from ctx with timeout, -// but it doesn't cancel when first element return. -func WithTimeout(ctx context.Context, timeout time.Duration) (*Picker, context.Context, context.CancelFunc) { +// WithTimeout returns a new Picker and an associated Context derived from ctx with timeout. +func WithTimeout(ctx context.Context, timeout time.Duration) (*Picker, context.Context) { ctx, cancel := context.WithTimeout(ctx, timeout) - return &Picker{}, ctx, cancel + return newPicker(ctx, cancel), ctx +} + +// WithoutAutoCancel returns a new Picker and an associated Context derived from ctx, +// but it wouldn't cancel context when the first element return. +func WithoutAutoCancel(ctx context.Context) *Picker { + return newPicker(ctx, nil) } // Wait blocks until all function calls from the Go method have returned, @@ -42,6 +58,16 @@ func (p *Picker) Wait() interface{} { return p.result } +// WaitWithoutCancel blocks until the first result return, if timeout will return nil. +func (p *Picker) WaitWithoutCancel() interface{} { + select { + case <-p.firstDone: + return p.result + case <-p.ctx.Done(): + return p.result + } +} + // Go calls the given function in a new goroutine. // The first call to return a nil error cancels the group; its result will be returned by Wait. func (p *Picker) Go(f func() (interface{}, error)) { @@ -53,6 +79,7 @@ func (p *Picker) Go(f func() (interface{}, error)) { if ret, err := f(); err == nil { p.once.Do(func() { p.result = ret + p.firstDone <- struct{}{} if p.cancel != nil { p.cancel() } diff --git a/common/picker/picker_test.go b/common/picker/picker_test.go index 7ff3712f..9e165009 100644 --- a/common/picker/picker_test.go +++ b/common/picker/picker_test.go @@ -4,6 +4,8 @@ import ( "context" "testing" "time" + + "github.com/stretchr/testify/assert" ) func sleepAndSend(ctx context.Context, delay int, input interface{}) func() (interface{}, error) { @@ -24,19 +26,41 @@ func TestPicker_Basic(t *testing.T) { picker.Go(sleepAndSend(ctx, 20, 1)) number := picker.Wait() - if number != nil && number.(int) != 1 { - t.Error("should recv 1", number) - } + assert.NotNil(t, number) + assert.Equal(t, number.(int), 1) } func TestPicker_Timeout(t *testing.T) { - picker, ctx, cancel := WithTimeout(context.Background(), time.Millisecond*5) - defer cancel() - + picker, ctx := WithTimeout(context.Background(), time.Millisecond*5) picker.Go(sleepAndSend(ctx, 20, 1)) number := picker.Wait() - if number != nil { - t.Error("should recv nil") - } + assert.Nil(t, number) +} + +func TestPicker_WaitWithoutAutoCancel(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*60) + defer cancel() + picker := WithoutAutoCancel(ctx) + + trigger := false + picker.Go(sleepAndSend(ctx, 10, 1)) + picker.Go(func() (interface{}, error) { + timer := time.NewTimer(time.Millisecond * time.Duration(30)) + select { + case <-timer.C: + trigger = true + return 2, nil + case <-ctx.Done(): + return nil, ctx.Err() + } + }) + elm := picker.WaitWithoutCancel() + + assert.NotNil(t, elm) + assert.Equal(t, elm.(int), 1) + + elm = picker.Wait() + assert.True(t, trigger) + assert.Equal(t, elm.(int), 1) } diff --git a/dns/resolver.go b/dns/resolver.go index 2276f789..1b68d7d6 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -171,10 +171,7 @@ func (r *Resolver) IsFakeIP() bool { } func (r *Resolver) batchExchange(clients []resolver, m *D.Msg) (msg *D.Msg, err error) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - fast, ctx := picker.WithContext(ctx) - + fast, ctx := picker.WithTimeout(context.Background(), time.Second) for _, client := range clients { r := client fast.Go(func() (interface{}, error) { diff --git a/hub/route/proxies.go b/hub/route/proxies.go index d485495e..201fd03f 100644 --- a/hub/route/proxies.go +++ b/hub/route/proxies.go @@ -9,7 +9,6 @@ import ( "time" A "github.com/Dreamacro/clash/adapters/outbound" - "github.com/Dreamacro/clash/common/picker" C "github.com/Dreamacro/clash/constant" T "github.com/Dreamacro/clash/tunnel" @@ -111,21 +110,17 @@ func getProxyDelay(w http.ResponseWriter, r *http.Request) { proxy := r.Context().Value(CtxKeyProxy).(C.Proxy) - picker, ctx, cancel := picker.WithTimeout(context.Background(), time.Millisecond*time.Duration(timeout)) + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*time.Duration(timeout)) defer cancel() - picker.Go(func() (interface{}, error) { - return proxy.URLTest(ctx, url) - }) - elm := picker.Wait() - if elm == nil { + delay, err := proxy.URLTest(ctx, url) + if ctx.Err() != nil { render.Status(r, http.StatusGatewayTimeout) render.JSON(w, r, ErrRequestTimeout) return } - delay := elm.(uint16) - if delay == 0 { + if err != nil || delay == 0 { render.Status(r, http.StatusServiceUnavailable) render.JSON(w, r, newError("An error occurred in the delay test")) return