From 0a768767646147cef51b9a472040716f85cd2c27 Mon Sep 17 00:00:00 2001 From: Skyxim Date: Wed, 6 Jul 2022 21:25:25 +0800 Subject: [PATCH] fix: h3 of doh fall back logic --- dns/doh.go | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/dns/doh.go b/dns/doh.go index c79e7ba2..2e3f685a 100644 --- a/dns/doh.go +++ b/dns/doh.go @@ -70,9 +70,7 @@ func (dc *dohClient) doRequest(req *http.Request) (msg *D.Msg, err error) { client := &http.Client{Transport: dc.transport} resp, err := client.Do(req) if err != nil { - if err != nil { - return nil, err - } + return nil, err } defer resp.Body.Close() @@ -168,6 +166,14 @@ func (doh *dohTransport) RoundTrip(req *http.Request) (*http.Response, error) { var resp *http.Response var err error var bodyBytes []byte + var h3Err bool + var fallbackErr bool + defer func() { + if doh.preferH3 && h3Err { + doh.canUseH3.Store(doh.preferH3 && fallbackErr) + } + }() + if req.Body != nil { bodyBytes, err = ioutil.ReadAll(req.Body) } @@ -175,20 +181,17 @@ func (doh *dohTransport) RoundTrip(req *http.Request) (*http.Response, error) { req.Body = ioutil.NopCloser(bytes.NewReader(bodyBytes)) if doh.preferH3 && doh.canUseH3.Load() { resp, err = doh.h3.RoundTrip(req) - if err == nil { + h3Err = err != nil + if !h3Err { return resp, err } else { - doh.canUseH3.Store(false) req.Body = ioutil.NopCloser(bytes.NewReader(bodyBytes)) } } resp, err = doh.Transport.RoundTrip(req) - if err != nil { - if doh.preferH3 { - doh.canUseH3.Store(true) - } - + fallbackErr = err != nil + if fallbackErr { return resp, err }