refactor: Add named returns #168

Merged
sam merged 3 commits from named-returns into master 2022-12-17 16:52:50 +00:00
7 changed files with 69 additions and 64 deletions

View file

@ -17,7 +17,7 @@ import (
// ParseCLI parses arguments given from the CLI and passes them into an `Options`
// struct.
func ParseCLI(args []string, version string) (*util.Options, error) {
func ParseCLI(args []string, version string) (opts *util.Options, err error) {
// Parse the standard flags
opts, misc, err := parseFlags(args, version)
if err != nil {
@ -70,11 +70,11 @@ func ParseCLI(args []string, version string) (*util.Options, error) {
opts.Logger.Info("Options fully populated")
opts.Logger.Debug(fmt.Sprintf("%+v", opts))
return opts, nil
return
}
// Everything that has to do with CLI flags goes here (the posix style, eg. -a and --bbbb).
func parseFlags(args []string, version string) (*util.Options, []string, error) {
func parseFlags(args []string, version string) (opts *util.Options, flags []string, err error) {
flagSet := flag.NewFlagSet(args[0], flag.ContinueOnError)
flagSet.Usage = func() {
@ -162,7 +162,7 @@ func parseFlags(args []string, version string) (*util.Options, []string, error)
flagSet.SortFlags = true
// Parse the flags
if err := flagSet.Parse(args[1:]); err != nil {
if err = flagSet.Parse(args[1:]); err != nil {
return &util.Options{Logger: util.InitLogger(*verbosity)}, nil, fmt.Errorf("flag: %w", err)
}
@ -172,7 +172,7 @@ func parseFlags(args []string, version string) (*util.Options, []string, error)
return &util.Options{Logger: util.InitLogger(*verbosity)}, nil, fmt.Errorf("EDNS MBZ: %w", err)
}
opts := util.Options{
opts = &util.Options{
Logger: util.InitLogger(*verbosity),
IPv4: *ipv4,
IPv6: *ipv6,
@ -243,8 +243,8 @@ func parseFlags(args []string, version string) (*util.Options, []string, error)
// TODO: DRY
if *subnet != "" {
if err = util.ParseSubnet(*subnet, &opts); err != nil {
return &opts, nil, fmt.Errorf("%w", err)
if err = util.ParseSubnet(*subnet, opts); err != nil {
return opts, nil, fmt.Errorf("%w", err)
}
}
@ -254,10 +254,12 @@ func parseFlags(args []string, version string) (*util.Options, []string, error)
if *versionFlag {
fmt.Printf("awl version %s, built with %s\n", version, runtime.Version())
return &opts, nil, util.ErrNotError
return opts, nil, util.ErrNotError
}
return &opts, flagSet.Args(), nil
flags = flagSet.Args()
return
}
var errNoArg = errors.New("no argument given")

View file

@ -21,15 +21,12 @@ import (
// ToString turns the response into something that looks a lot like dig
//
// Much of this is taken from https://github.com/miekg/dns/blob/master/msg.go#L900
func ToString(res util.Response, opts *util.Options) (string, error) {
func ToString(res util.Response, opts *util.Options) (s string, err error) {
if res.DNS == nil {
return "<nil> MsgHdr", errNoMessage
}
var (
s string
opt *dns.OPT
)
var opt *dns.OPT
if !opts.Short {
if opts.Display.Comments {
@ -143,7 +140,7 @@ func ToString(res util.Response, opts *util.Options) (string, error) {
}
}
return s, nil
return
}
func serverExtra(opts *util.Options) string {

View file

@ -19,7 +19,7 @@ type DNSCryptResolver struct {
var _ Resolver = (*DNSCryptResolver)(nil)
// LookUp performs a DNS query.
func (resolver *DNSCryptResolver) LookUp(msg *dns.Msg) (util.Response, error) {
func (resolver *DNSCryptResolver) LookUp(msg *dns.Msg) (resp util.Response, err error) {
client := dnscrypt.Client{
Timeout: resolver.opts.Request.Timeout,
UDPSize: 1232,
@ -42,7 +42,7 @@ func (resolver *DNSCryptResolver) LookUp(msg *dns.Msg) (util.Response, error) {
resolverInf, err := client.Dial(resolver.opts.Request.Server)
if err != nil {
return util.Response{}, fmt.Errorf("dnscrypt: dial: %w", err)
return resp, fmt.Errorf("dnscrypt: dial: %w", err)
}
now := time.Now()
@ -50,13 +50,15 @@ func (resolver *DNSCryptResolver) LookUp(msg *dns.Msg) (util.Response, error) {
rtt := time.Since(now)
if err != nil {
return util.Response{}, fmt.Errorf("dnscrypt: exchange: %w", err)
return resp, fmt.Errorf("dnscrypt: exchange: %w", err)
}
resp = util.Response{
DNS: res,
RTT: rtt,
}
resolver.opts.Logger.Info("Request successful")
return util.Response{
DNS: res,
RTT: rtt,
}, nil
return
}

View file

@ -23,9 +23,7 @@ type HTTPSResolver struct {
var _ Resolver = (*HTTPSResolver)(nil)
// LookUp performs a DNS query.
func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (util.Response, error) {
var resp util.Response
func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (resp util.Response, err error) {
resolver.client = http.Client{
Timeout: resolver.opts.Request.Timeout,
Transport: &http.Transport{
@ -43,7 +41,7 @@ func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (util.Response, error) {
buf, err := msg.Pack()
if err != nil {
return util.Response{}, fmt.Errorf("doh: packing: %w", err)
return resp, fmt.Errorf("doh: packing: %w", err)
}
resolver.opts.Logger.Debug("https: sending HTTPS request")
@ -57,7 +55,7 @@ func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (util.Response, error) {
req, err := http.NewRequest(method, resolver.opts.Request.Server, bytes.NewBuffer(buf))
if err != nil {
return util.Response{}, fmt.Errorf("doh: request creation: %w", err)
return resp, fmt.Errorf("doh: request creation: %w", err)
}
req.Header.Set("Content-Type", "application/dns-message")
@ -68,23 +66,29 @@ func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (util.Response, error) {
resp.RTT = time.Since(now)
if err != nil {
return util.Response{}, fmt.Errorf("doh: HTTP request: %w", err)
// overwrite RTT or else tests will fail
resp.RTT = 0
return resp, fmt.Errorf("doh: HTTP request: %w", err)
}
if res.StatusCode != http.StatusOK {
return util.Response{}, &util.ErrHTTPStatus{Code: res.StatusCode}
// overwrite RTT or else tests will fail
resp.RTT = 0
return resp, &util.ErrHTTPStatus{Code: res.StatusCode}
}
resolver.opts.Logger.Debug("https: reading response")
fullRes, err := io.ReadAll(res.Body)
if err != nil {
return util.Response{}, fmt.Errorf("doh: body read: %w", err)
return resp, fmt.Errorf("doh: body read: %w", err)
}
err = res.Body.Close()
if err != nil {
return util.Response{}, fmt.Errorf("doh: body close: %w", err)
return resp, fmt.Errorf("doh: body close: %w", err)
}
resolver.opts.Logger.Debug("https: unpacking response")
@ -93,7 +97,7 @@ func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (util.Response, error) {
err = resp.DNS.Unpack(fullRes)
if err != nil {
return util.Response{}, fmt.Errorf("doh: dns message unpack: %w", err)
return resp, fmt.Errorf("doh: dns message unpack: %w", err)
}
return resp, nil

View file

@ -22,9 +22,7 @@ type QUICResolver struct {
var _ Resolver = (*QUICResolver)(nil)
// LookUp performs a DNS query.
func (resolver *QUICResolver) LookUp(msg *dns.Msg) (util.Response, error) {
var resp util.Response
func (resolver *QUICResolver) LookUp(msg *dns.Msg) (resp util.Response, err error) {
tls := &tls.Config{
//nolint:gosec // This is intentional if the user requests it
InsecureSkipVerify: resolver.opts.TLSNoVerify,
@ -40,7 +38,7 @@ func (resolver *QUICResolver) LookUp(msg *dns.Msg) (util.Response, error) {
connection, err := quic.DialAddr(resolver.opts.Request.Server, tls, conf)
if err != nil {
return util.Response{}, fmt.Errorf("doq: dial: %w", err)
return resp, fmt.Errorf("doq: dial: %w", err)
}
resolver.opts.Logger.Debug("quic: packing query")
@ -48,7 +46,7 @@ func (resolver *QUICResolver) LookUp(msg *dns.Msg) (util.Response, error) {
// Compress request to over-the-wire
buf, err := msg.Pack()
if err != nil {
return util.Response{}, fmt.Errorf("doq: pack: %w", err)
return resp, fmt.Errorf("doq: pack: %w", err)
}
t := time.Now()
@ -57,21 +55,21 @@ func (resolver *QUICResolver) LookUp(msg *dns.Msg) (util.Response, error) {
stream, err := connection.OpenStream()
if err != nil {
return util.Response{}, fmt.Errorf("doq: quic stream creation: %w", err)
return resp, fmt.Errorf("doq: quic stream creation: %w", err)
}
resolver.opts.Logger.Debug("quic: writing to stream")
_, err = stream.Write(buf)
if err != nil {
return util.Response{}, fmt.Errorf("doq: quic stream write: %w", err)
return resp, fmt.Errorf("doq: quic stream write: %w", err)
}
resolver.opts.Logger.Debug("quic: reading stream")
fullRes, err := io.ReadAll(stream)
if err != nil {
return util.Response{}, fmt.Errorf("doq: quic stream read: %w", err)
return resp, fmt.Errorf("doq: quic stream read: %w", err)
}
resp.RTT = time.Since(t)
@ -80,14 +78,14 @@ func (resolver *QUICResolver) LookUp(msg *dns.Msg) (util.Response, error) {
// Close with error: no error
err = connection.CloseWithError(0, "")
if err != nil {
return util.Response{}, fmt.Errorf("doq: quic connection close: %w", err)
return resp, fmt.Errorf("doq: quic connection close: %w", err)
}
resolver.opts.Logger.Debug("quic: closing stream")
err = stream.Close()
if err != nil {
return util.Response{}, fmt.Errorf("doq: quic stream close: %w", err)
return resp, fmt.Errorf("doq: quic stream close: %w", err)
}
resp.DNS = &dns.Msg{}
@ -96,8 +94,8 @@ func (resolver *QUICResolver) LookUp(msg *dns.Msg) (util.Response, error) {
err = resp.DNS.Unpack(fullRes)
if err != nil {
return util.Response{}, fmt.Errorf("doq: unpack: %w", err)
return resp, fmt.Errorf("doq: unpack: %w", err)
}
return resp, nil
return
}

View file

@ -19,12 +19,7 @@ type StandardResolver struct {
var _ Resolver = (*StandardResolver)(nil)
// LookUp performs a DNS query.
func (resolver *StandardResolver) LookUp(msg *dns.Msg) (util.Response, error) {
var (
resp util.Response
err error
)
func (resolver *StandardResolver) LookUp(msg *dns.Msg) (resp util.Response, err error) {
dnsClient := new(dns.Client)
dnsClient.Dialer = &net.Dialer{
Timeout: resolver.opts.Request.Timeout,
@ -56,7 +51,7 @@ func (resolver *StandardResolver) LookUp(msg *dns.Msg) (util.Response, error) {
resp.DNS, resp.RTT, err = dnsClient.Exchange(msg, resolver.opts.Request.Server)
if err != nil {
return util.Response{}, fmt.Errorf("standard: DNS exchange: %w", err)
return resp, fmt.Errorf("standard: DNS exchange: %w", err)
}
switch dns.RcodeToString[resp.DNS.MsgHdr.Rcode] {
@ -69,7 +64,7 @@ func (resolver *StandardResolver) LookUp(msg *dns.Msg) (util.Response, error) {
resp.DNS, resp.RTT, err = dnsClient.Exchange(msg, resolver.opts.Request.Server)
if err != nil {
return util.Response{}, fmt.Errorf("badcookie: DNS exchange: %w", err)
return resp, fmt.Errorf("badcookie: DNS exchange: %w", err)
}
}
@ -95,8 +90,8 @@ func (resolver *StandardResolver) LookUp(msg *dns.Msg) (util.Response, error) {
}
if err != nil {
return util.Response{}, fmt.Errorf("standard: DNS exchange: %w", err)
return resp, fmt.Errorf("standard: DNS exchange: %w", err)
}
return resp, nil
return
}

View file

@ -22,7 +22,7 @@ type Resolver interface {
}
// LoadResolver loads the respective resolver for performing a DNS query.
func LoadResolver(opts *util.Options) (Resolver, error) {
func LoadResolver(opts *util.Options) (resolver Resolver, err error) {
switch {
case opts.HTTPS:
opts.Logger.Info("loading DNS-over-HTTPS resolver")
@ -32,10 +32,11 @@ func LoadResolver(opts *util.Options) (Resolver, error) {
}
opts.Request.Server += opts.HTTPSOptions.Endpoint
return &HTTPSResolver{
resolver = &HTTPSResolver{
opts: opts,
}, nil
}
return
case opts.QUIC:
opts.Logger.Info("loading DNS-over-QUIC resolver")
@ -43,9 +44,11 @@ func LoadResolver(opts *util.Options) (Resolver, error) {
opts.Request.Server = net.JoinHostPort(opts.Request.Server, strconv.Itoa(opts.Request.Port))
}
return &QUICResolver{
resolver = &QUICResolver{
opts: opts,
}, nil
}
return
case opts.DNSCrypt:
opts.Logger.Info("loading DNSCrypt resolver")
@ -53,9 +56,11 @@ func LoadResolver(opts *util.Options) (Resolver, error) {
opts.Request.Server = "sdns://" + opts.Request.Server
}
return &DNSCryptResolver{
resolver = &DNSCryptResolver{
opts: opts,
}, nil
}
return
default:
opts.Logger.Info("loading standard/DNS-over-TLS resolver")
@ -63,8 +68,10 @@ func LoadResolver(opts *util.Options) (Resolver, error) {
opts.Request.Server = net.JoinHostPort(opts.Request.Server, strconv.Itoa(opts.Request.Port))
}
return &StandardResolver{
resolver = &StandardResolver{
opts: opts,
}, nil
}
return
}
}