From fdba9a0a41ed02ed6866552f26011741f185737c Mon Sep 17 00:00:00 2001 From: Sam Therapy Date: Sat, 17 Dec 2022 16:52:50 +0000 Subject: [PATCH] refactor: Add named returns (#168) Add some named returns Co-authored-by: grumbulon Reviewed-on: https://git.froth.zone/sam/awl/pulls/168 Reviewed-by: grumbulon --- cmd/cli.go | 20 +++++++++++--------- pkg/query/print.go | 9 +++------ pkg/resolvers/DNSCrypt.go | 16 +++++++++------- pkg/resolvers/HTTPS.go | 24 ++++++++++++++---------- pkg/resolvers/QUIC.go | 22 ++++++++++------------ pkg/resolvers/general.go | 15 +++++---------- pkg/resolvers/resolver.go | 27 +++++++++++++++++---------- 7 files changed, 69 insertions(+), 64 deletions(-) diff --git a/cmd/cli.go b/cmd/cli.go index 3c06037..bb215d6 100644 --- a/cmd/cli.go +++ b/cmd/cli.go @@ -17,7 +17,7 @@ import ( // ParseCLI parses arguments given from the CLI and passes them into an `Options` // struct. -func ParseCLI(args []string, version string) (*util.Options, error) { +func ParseCLI(args []string, version string) (opts *util.Options, err error) { // Parse the standard flags opts, misc, err := parseFlags(args, version) if err != nil { @@ -70,11 +70,11 @@ func ParseCLI(args []string, version string) (*util.Options, error) { opts.Logger.Info("Options fully populated") opts.Logger.Debug(fmt.Sprintf("%+v", opts)) - return opts, nil + return } // Everything that has to do with CLI flags goes here (the posix style, eg. -a and --bbbb). -func parseFlags(args []string, version string) (*util.Options, []string, error) { +func parseFlags(args []string, version string) (opts *util.Options, flags []string, err error) { flagSet := flag.NewFlagSet(args[0], flag.ContinueOnError) flagSet.Usage = func() { @@ -162,7 +162,7 @@ func parseFlags(args []string, version string) (*util.Options, []string, error) flagSet.SortFlags = true // Parse the flags - if err := flagSet.Parse(args[1:]); err != nil { + if err = flagSet.Parse(args[1:]); err != nil { return &util.Options{Logger: util.InitLogger(*verbosity)}, nil, fmt.Errorf("flag: %w", err) } @@ -172,7 +172,7 @@ func parseFlags(args []string, version string) (*util.Options, []string, error) return &util.Options{Logger: util.InitLogger(*verbosity)}, nil, fmt.Errorf("EDNS MBZ: %w", err) } - opts := util.Options{ + opts = &util.Options{ Logger: util.InitLogger(*verbosity), IPv4: *ipv4, IPv6: *ipv6, @@ -243,8 +243,8 @@ func parseFlags(args []string, version string) (*util.Options, []string, error) // TODO: DRY if *subnet != "" { - if err = util.ParseSubnet(*subnet, &opts); err != nil { - return &opts, nil, fmt.Errorf("%w", err) + if err = util.ParseSubnet(*subnet, opts); err != nil { + return opts, nil, fmt.Errorf("%w", err) } } @@ -254,10 +254,12 @@ func parseFlags(args []string, version string) (*util.Options, []string, error) if *versionFlag { fmt.Printf("awl version %s, built with %s\n", version, runtime.Version()) - return &opts, nil, util.ErrNotError + return opts, nil, util.ErrNotError } - return &opts, flagSet.Args(), nil + flags = flagSet.Args() + + return } var errNoArg = errors.New("no argument given") diff --git a/pkg/query/print.go b/pkg/query/print.go index edfb72a..20881c7 100644 --- a/pkg/query/print.go +++ b/pkg/query/print.go @@ -21,15 +21,12 @@ import ( // ToString turns the response into something that looks a lot like dig // // Much of this is taken from https://github.com/miekg/dns/blob/master/msg.go#L900 -func ToString(res util.Response, opts *util.Options) (string, error) { +func ToString(res util.Response, opts *util.Options) (s string, err error) { if res.DNS == nil { return " MsgHdr", errNoMessage } - var ( - s string - opt *dns.OPT - ) + var opt *dns.OPT if !opts.Short { if opts.Display.Comments { @@ -143,7 +140,7 @@ func ToString(res util.Response, opts *util.Options) (string, error) { } } - return s, nil + return } func serverExtra(opts *util.Options) string { diff --git a/pkg/resolvers/DNSCrypt.go b/pkg/resolvers/DNSCrypt.go index 9bff6bd..e88471a 100644 --- a/pkg/resolvers/DNSCrypt.go +++ b/pkg/resolvers/DNSCrypt.go @@ -19,7 +19,7 @@ type DNSCryptResolver struct { var _ Resolver = (*DNSCryptResolver)(nil) // LookUp performs a DNS query. -func (resolver *DNSCryptResolver) LookUp(msg *dns.Msg) (util.Response, error) { +func (resolver *DNSCryptResolver) LookUp(msg *dns.Msg) (resp util.Response, err error) { client := dnscrypt.Client{ Timeout: resolver.opts.Request.Timeout, UDPSize: 1232, @@ -42,7 +42,7 @@ func (resolver *DNSCryptResolver) LookUp(msg *dns.Msg) (util.Response, error) { resolverInf, err := client.Dial(resolver.opts.Request.Server) if err != nil { - return util.Response{}, fmt.Errorf("dnscrypt: dial: %w", err) + return resp, fmt.Errorf("dnscrypt: dial: %w", err) } now := time.Now() @@ -50,13 +50,15 @@ func (resolver *DNSCryptResolver) LookUp(msg *dns.Msg) (util.Response, error) { rtt := time.Since(now) if err != nil { - return util.Response{}, fmt.Errorf("dnscrypt: exchange: %w", err) + return resp, fmt.Errorf("dnscrypt: exchange: %w", err) + } + + resp = util.Response{ + DNS: res, + RTT: rtt, } resolver.opts.Logger.Info("Request successful") - return util.Response{ - DNS: res, - RTT: rtt, - }, nil + return } diff --git a/pkg/resolvers/HTTPS.go b/pkg/resolvers/HTTPS.go index bfe1660..58e04f1 100644 --- a/pkg/resolvers/HTTPS.go +++ b/pkg/resolvers/HTTPS.go @@ -23,9 +23,7 @@ type HTTPSResolver struct { var _ Resolver = (*HTTPSResolver)(nil) // LookUp performs a DNS query. -func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (util.Response, error) { - var resp util.Response - +func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (resp util.Response, err error) { resolver.client = http.Client{ Timeout: resolver.opts.Request.Timeout, Transport: &http.Transport{ @@ -43,7 +41,7 @@ func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (util.Response, error) { buf, err := msg.Pack() if err != nil { - return util.Response{}, fmt.Errorf("doh: packing: %w", err) + return resp, fmt.Errorf("doh: packing: %w", err) } resolver.opts.Logger.Debug("https: sending HTTPS request") @@ -57,7 +55,7 @@ func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (util.Response, error) { req, err := http.NewRequest(method, resolver.opts.Request.Server, bytes.NewBuffer(buf)) if err != nil { - return util.Response{}, fmt.Errorf("doh: request creation: %w", err) + return resp, fmt.Errorf("doh: request creation: %w", err) } req.Header.Set("Content-Type", "application/dns-message") @@ -68,23 +66,29 @@ func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (util.Response, error) { resp.RTT = time.Since(now) if err != nil { - return util.Response{}, fmt.Errorf("doh: HTTP request: %w", err) + // overwrite RTT or else tests will fail + resp.RTT = 0 + + return resp, fmt.Errorf("doh: HTTP request: %w", err) } if res.StatusCode != http.StatusOK { - return util.Response{}, &util.ErrHTTPStatus{Code: res.StatusCode} + // overwrite RTT or else tests will fail + resp.RTT = 0 + + return resp, &util.ErrHTTPStatus{Code: res.StatusCode} } resolver.opts.Logger.Debug("https: reading response") fullRes, err := io.ReadAll(res.Body) if err != nil { - return util.Response{}, fmt.Errorf("doh: body read: %w", err) + return resp, fmt.Errorf("doh: body read: %w", err) } err = res.Body.Close() if err != nil { - return util.Response{}, fmt.Errorf("doh: body close: %w", err) + return resp, fmt.Errorf("doh: body close: %w", err) } resolver.opts.Logger.Debug("https: unpacking response") @@ -93,7 +97,7 @@ func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (util.Response, error) { err = resp.DNS.Unpack(fullRes) if err != nil { - return util.Response{}, fmt.Errorf("doh: dns message unpack: %w", err) + return resp, fmt.Errorf("doh: dns message unpack: %w", err) } return resp, nil diff --git a/pkg/resolvers/QUIC.go b/pkg/resolvers/QUIC.go index 53b96b7..648c0dc 100644 --- a/pkg/resolvers/QUIC.go +++ b/pkg/resolvers/QUIC.go @@ -22,9 +22,7 @@ type QUICResolver struct { var _ Resolver = (*QUICResolver)(nil) // LookUp performs a DNS query. -func (resolver *QUICResolver) LookUp(msg *dns.Msg) (util.Response, error) { - var resp util.Response - +func (resolver *QUICResolver) LookUp(msg *dns.Msg) (resp util.Response, err error) { tls := &tls.Config{ //nolint:gosec // This is intentional if the user requests it InsecureSkipVerify: resolver.opts.TLSNoVerify, @@ -40,7 +38,7 @@ func (resolver *QUICResolver) LookUp(msg *dns.Msg) (util.Response, error) { connection, err := quic.DialAddr(resolver.opts.Request.Server, tls, conf) if err != nil { - return util.Response{}, fmt.Errorf("doq: dial: %w", err) + return resp, fmt.Errorf("doq: dial: %w", err) } resolver.opts.Logger.Debug("quic: packing query") @@ -48,7 +46,7 @@ func (resolver *QUICResolver) LookUp(msg *dns.Msg) (util.Response, error) { // Compress request to over-the-wire buf, err := msg.Pack() if err != nil { - return util.Response{}, fmt.Errorf("doq: pack: %w", err) + return resp, fmt.Errorf("doq: pack: %w", err) } t := time.Now() @@ -57,21 +55,21 @@ func (resolver *QUICResolver) LookUp(msg *dns.Msg) (util.Response, error) { stream, err := connection.OpenStream() if err != nil { - return util.Response{}, fmt.Errorf("doq: quic stream creation: %w", err) + return resp, fmt.Errorf("doq: quic stream creation: %w", err) } resolver.opts.Logger.Debug("quic: writing to stream") _, err = stream.Write(buf) if err != nil { - return util.Response{}, fmt.Errorf("doq: quic stream write: %w", err) + return resp, fmt.Errorf("doq: quic stream write: %w", err) } resolver.opts.Logger.Debug("quic: reading stream") fullRes, err := io.ReadAll(stream) if err != nil { - return util.Response{}, fmt.Errorf("doq: quic stream read: %w", err) + return resp, fmt.Errorf("doq: quic stream read: %w", err) } resp.RTT = time.Since(t) @@ -80,14 +78,14 @@ func (resolver *QUICResolver) LookUp(msg *dns.Msg) (util.Response, error) { // Close with error: no error err = connection.CloseWithError(0, "") if err != nil { - return util.Response{}, fmt.Errorf("doq: quic connection close: %w", err) + return resp, fmt.Errorf("doq: quic connection close: %w", err) } resolver.opts.Logger.Debug("quic: closing stream") err = stream.Close() if err != nil { - return util.Response{}, fmt.Errorf("doq: quic stream close: %w", err) + return resp, fmt.Errorf("doq: quic stream close: %w", err) } resp.DNS = &dns.Msg{} @@ -96,8 +94,8 @@ func (resolver *QUICResolver) LookUp(msg *dns.Msg) (util.Response, error) { err = resp.DNS.Unpack(fullRes) if err != nil { - return util.Response{}, fmt.Errorf("doq: unpack: %w", err) + return resp, fmt.Errorf("doq: unpack: %w", err) } - return resp, nil + return } diff --git a/pkg/resolvers/general.go b/pkg/resolvers/general.go index e4a9435..bce7ee4 100644 --- a/pkg/resolvers/general.go +++ b/pkg/resolvers/general.go @@ -19,12 +19,7 @@ type StandardResolver struct { var _ Resolver = (*StandardResolver)(nil) // LookUp performs a DNS query. -func (resolver *StandardResolver) LookUp(msg *dns.Msg) (util.Response, error) { - var ( - resp util.Response - err error - ) - +func (resolver *StandardResolver) LookUp(msg *dns.Msg) (resp util.Response, err error) { dnsClient := new(dns.Client) dnsClient.Dialer = &net.Dialer{ Timeout: resolver.opts.Request.Timeout, @@ -56,7 +51,7 @@ func (resolver *StandardResolver) LookUp(msg *dns.Msg) (util.Response, error) { resp.DNS, resp.RTT, err = dnsClient.Exchange(msg, resolver.opts.Request.Server) if err != nil { - return util.Response{}, fmt.Errorf("standard: DNS exchange: %w", err) + return resp, fmt.Errorf("standard: DNS exchange: %w", err) } switch dns.RcodeToString[resp.DNS.MsgHdr.Rcode] { @@ -69,7 +64,7 @@ func (resolver *StandardResolver) LookUp(msg *dns.Msg) (util.Response, error) { resp.DNS, resp.RTT, err = dnsClient.Exchange(msg, resolver.opts.Request.Server) if err != nil { - return util.Response{}, fmt.Errorf("badcookie: DNS exchange: %w", err) + return resp, fmt.Errorf("badcookie: DNS exchange: %w", err) } } @@ -95,8 +90,8 @@ func (resolver *StandardResolver) LookUp(msg *dns.Msg) (util.Response, error) { } if err != nil { - return util.Response{}, fmt.Errorf("standard: DNS exchange: %w", err) + return resp, fmt.Errorf("standard: DNS exchange: %w", err) } - return resp, nil + return } diff --git a/pkg/resolvers/resolver.go b/pkg/resolvers/resolver.go index 3840c92..63103ad 100644 --- a/pkg/resolvers/resolver.go +++ b/pkg/resolvers/resolver.go @@ -22,7 +22,7 @@ type Resolver interface { } // LoadResolver loads the respective resolver for performing a DNS query. -func LoadResolver(opts *util.Options) (Resolver, error) { +func LoadResolver(opts *util.Options) (resolver Resolver, err error) { switch { case opts.HTTPS: opts.Logger.Info("loading DNS-over-HTTPS resolver") @@ -32,10 +32,11 @@ func LoadResolver(opts *util.Options) (Resolver, error) { } opts.Request.Server += opts.HTTPSOptions.Endpoint - - return &HTTPSResolver{ + resolver = &HTTPSResolver{ opts: opts, - }, nil + } + + return case opts.QUIC: opts.Logger.Info("loading DNS-over-QUIC resolver") @@ -43,9 +44,11 @@ func LoadResolver(opts *util.Options) (Resolver, error) { opts.Request.Server = net.JoinHostPort(opts.Request.Server, strconv.Itoa(opts.Request.Port)) } - return &QUICResolver{ + resolver = &QUICResolver{ opts: opts, - }, nil + } + + return case opts.DNSCrypt: opts.Logger.Info("loading DNSCrypt resolver") @@ -53,9 +56,11 @@ func LoadResolver(opts *util.Options) (Resolver, error) { opts.Request.Server = "sdns://" + opts.Request.Server } - return &DNSCryptResolver{ + resolver = &DNSCryptResolver{ opts: opts, - }, nil + } + + return default: opts.Logger.Info("loading standard/DNS-over-TLS resolver") @@ -63,8 +68,10 @@ func LoadResolver(opts *util.Options) (Resolver, error) { opts.Request.Server = net.JoinHostPort(opts.Request.Server, strconv.Itoa(opts.Request.Port)) } - return &StandardResolver{ + resolver = &StandardResolver{ opts: opts, - }, nil + } + + return } }