refactor: Add named returns #168
7 changed files with 69 additions and 64 deletions
20
cmd/cli.go
20
cmd/cli.go
|
@ -17,7 +17,7 @@ import (
|
|||
|
||||
// ParseCLI parses arguments given from the CLI and passes them into an `Options`
|
||||
// struct.
|
||||
func ParseCLI(args []string, version string) (*util.Options, error) {
|
||||
func ParseCLI(args []string, version string) (opts *util.Options, err error) {
|
||||
// Parse the standard flags
|
||||
opts, misc, err := parseFlags(args, version)
|
||||
if err != nil {
|
||||
|
@ -70,11 +70,11 @@ func ParseCLI(args []string, version string) (*util.Options, error) {
|
|||
opts.Logger.Info("Options fully populated")
|
||||
opts.Logger.Debug(fmt.Sprintf("%+v", opts))
|
||||
|
||||
return opts, nil
|
||||
return
|
||||
}
|
||||
|
||||
// Everything that has to do with CLI flags goes here (the posix style, eg. -a and --bbbb).
|
||||
func parseFlags(args []string, version string) (*util.Options, []string, error) {
|
||||
func parseFlags(args []string, version string) (opts *util.Options, flags []string, err error) {
|
||||
flagSet := flag.NewFlagSet(args[0], flag.ContinueOnError)
|
||||
|
||||
flagSet.Usage = func() {
|
||||
|
@ -162,7 +162,7 @@ func parseFlags(args []string, version string) (*util.Options, []string, error)
|
|||
flagSet.SortFlags = true
|
||||
|
||||
// Parse the flags
|
||||
if err := flagSet.Parse(args[1:]); err != nil {
|
||||
if err = flagSet.Parse(args[1:]); err != nil {
|
||||
return &util.Options{Logger: util.InitLogger(*verbosity)}, nil, fmt.Errorf("flag: %w", err)
|
||||
}
|
||||
|
||||
|
@ -172,7 +172,7 @@ func parseFlags(args []string, version string) (*util.Options, []string, error)
|
|||
return &util.Options{Logger: util.InitLogger(*verbosity)}, nil, fmt.Errorf("EDNS MBZ: %w", err)
|
||||
}
|
||||
|
||||
opts := util.Options{
|
||||
opts = &util.Options{
|
||||
Logger: util.InitLogger(*verbosity),
|
||||
IPv4: *ipv4,
|
||||
IPv6: *ipv6,
|
||||
|
@ -243,8 +243,8 @@ func parseFlags(args []string, version string) (*util.Options, []string, error)
|
|||
|
||||
// TODO: DRY
|
||||
if *subnet != "" {
|
||||
if err = util.ParseSubnet(*subnet, &opts); err != nil {
|
||||
return &opts, nil, fmt.Errorf("%w", err)
|
||||
if err = util.ParseSubnet(*subnet, opts); err != nil {
|
||||
return opts, nil, fmt.Errorf("%w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -254,10 +254,12 @@ func parseFlags(args []string, version string) (*util.Options, []string, error)
|
|||
if *versionFlag {
|
||||
fmt.Printf("awl version %s, built with %s\n", version, runtime.Version())
|
||||
|
||||
return &opts, nil, util.ErrNotError
|
||||
return opts, nil, util.ErrNotError
|
||||
}
|
||||
|
||||
return &opts, flagSet.Args(), nil
|
||||
flags = flagSet.Args()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var errNoArg = errors.New("no argument given")
|
||||
|
|
|
@ -21,15 +21,12 @@ import (
|
|||
// ToString turns the response into something that looks a lot like dig
|
||||
//
|
||||
// Much of this is taken from https://github.com/miekg/dns/blob/master/msg.go#L900
|
||||
func ToString(res util.Response, opts *util.Options) (string, error) {
|
||||
func ToString(res util.Response, opts *util.Options) (s string, err error) {
|
||||
if res.DNS == nil {
|
||||
return "<nil> MsgHdr", errNoMessage
|
||||
}
|
||||
|
||||
var (
|
||||
s string
|
||||
opt *dns.OPT
|
||||
)
|
||||
var opt *dns.OPT
|
||||
|
||||
if !opts.Short {
|
||||
if opts.Display.Comments {
|
||||
|
@ -143,7 +140,7 @@ func ToString(res util.Response, opts *util.Options) (string, error) {
|
|||
}
|
||||
}
|
||||
|
||||
return s, nil
|
||||
return
|
||||
}
|
||||
|
||||
func serverExtra(opts *util.Options) string {
|
||||
|
|
|
@ -19,7 +19,7 @@ type DNSCryptResolver struct {
|
|||
var _ Resolver = (*DNSCryptResolver)(nil)
|
||||
|
||||
// LookUp performs a DNS query.
|
||||
func (resolver *DNSCryptResolver) LookUp(msg *dns.Msg) (util.Response, error) {
|
||||
func (resolver *DNSCryptResolver) LookUp(msg *dns.Msg) (resp util.Response, err error) {
|
||||
client := dnscrypt.Client{
|
||||
Timeout: resolver.opts.Request.Timeout,
|
||||
UDPSize: 1232,
|
||||
|
@ -42,7 +42,7 @@ func (resolver *DNSCryptResolver) LookUp(msg *dns.Msg) (util.Response, error) {
|
|||
|
||||
resolverInf, err := client.Dial(resolver.opts.Request.Server)
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("dnscrypt: dial: %w", err)
|
||||
return resp, fmt.Errorf("dnscrypt: dial: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
@ -50,13 +50,15 @@ func (resolver *DNSCryptResolver) LookUp(msg *dns.Msg) (util.Response, error) {
|
|||
rtt := time.Since(now)
|
||||
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("dnscrypt: exchange: %w", err)
|
||||
return resp, fmt.Errorf("dnscrypt: exchange: %w", err)
|
||||
}
|
||||
|
||||
resp = util.Response{
|
||||
DNS: res,
|
||||
RTT: rtt,
|
||||
}
|
||||
|
||||
resolver.opts.Logger.Info("Request successful")
|
||||
|
||||
return util.Response{
|
||||
DNS: res,
|
||||
RTT: rtt,
|
||||
}, nil
|
||||
return
|
||||
}
|
||||
|
|
|
@ -23,9 +23,7 @@ type HTTPSResolver struct {
|
|||
var _ Resolver = (*HTTPSResolver)(nil)
|
||||
|
||||
// LookUp performs a DNS query.
|
||||
func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (util.Response, error) {
|
||||
var resp util.Response
|
||||
|
||||
func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (resp util.Response, err error) {
|
||||
resolver.client = http.Client{
|
||||
Timeout: resolver.opts.Request.Timeout,
|
||||
Transport: &http.Transport{
|
||||
|
@ -43,7 +41,7 @@ func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (util.Response, error) {
|
|||
|
||||
buf, err := msg.Pack()
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doh: packing: %w", err)
|
||||
return resp, fmt.Errorf("doh: packing: %w", err)
|
||||
}
|
||||
|
||||
resolver.opts.Logger.Debug("https: sending HTTPS request")
|
||||
|
@ -57,7 +55,7 @@ func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (util.Response, error) {
|
|||
|
||||
req, err := http.NewRequest(method, resolver.opts.Request.Server, bytes.NewBuffer(buf))
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doh: request creation: %w", err)
|
||||
return resp, fmt.Errorf("doh: request creation: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/dns-message")
|
||||
|
@ -68,23 +66,29 @@ func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (util.Response, error) {
|
|||
resp.RTT = time.Since(now)
|
||||
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doh: HTTP request: %w", err)
|
||||
// overwrite RTT or else tests will fail
|
||||
resp.RTT = 0
|
||||
|
||||
return resp, fmt.Errorf("doh: HTTP request: %w", err)
|
||||
}
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return util.Response{}, &util.ErrHTTPStatus{Code: res.StatusCode}
|
||||
// overwrite RTT or else tests will fail
|
||||
resp.RTT = 0
|
||||
|
||||
return resp, &util.ErrHTTPStatus{Code: res.StatusCode}
|
||||
}
|
||||
|
||||
resolver.opts.Logger.Debug("https: reading response")
|
||||
|
||||
fullRes, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doh: body read: %w", err)
|
||||
return resp, fmt.Errorf("doh: body read: %w", err)
|
||||
}
|
||||
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doh: body close: %w", err)
|
||||
return resp, fmt.Errorf("doh: body close: %w", err)
|
||||
}
|
||||
|
||||
resolver.opts.Logger.Debug("https: unpacking response")
|
||||
|
@ -93,7 +97,7 @@ func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (util.Response, error) {
|
|||
|
||||
err = resp.DNS.Unpack(fullRes)
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doh: dns message unpack: %w", err)
|
||||
return resp, fmt.Errorf("doh: dns message unpack: %w", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
|
|
|
@ -22,9 +22,7 @@ type QUICResolver struct {
|
|||
var _ Resolver = (*QUICResolver)(nil)
|
||||
|
||||
// LookUp performs a DNS query.
|
||||
func (resolver *QUICResolver) LookUp(msg *dns.Msg) (util.Response, error) {
|
||||
var resp util.Response
|
||||
|
||||
func (resolver *QUICResolver) LookUp(msg *dns.Msg) (resp util.Response, err error) {
|
||||
tls := &tls.Config{
|
||||
//nolint:gosec // This is intentional if the user requests it
|
||||
InsecureSkipVerify: resolver.opts.TLSNoVerify,
|
||||
|
@ -40,7 +38,7 @@ func (resolver *QUICResolver) LookUp(msg *dns.Msg) (util.Response, error) {
|
|||
|
||||
connection, err := quic.DialAddr(resolver.opts.Request.Server, tls, conf)
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doq: dial: %w", err)
|
||||
return resp, fmt.Errorf("doq: dial: %w", err)
|
||||
}
|
||||
|
||||
resolver.opts.Logger.Debug("quic: packing query")
|
||||
|
@ -48,7 +46,7 @@ func (resolver *QUICResolver) LookUp(msg *dns.Msg) (util.Response, error) {
|
|||
// Compress request to over-the-wire
|
||||
buf, err := msg.Pack()
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doq: pack: %w", err)
|
||||
return resp, fmt.Errorf("doq: pack: %w", err)
|
||||
}
|
||||
|
||||
t := time.Now()
|
||||
|
@ -57,21 +55,21 @@ func (resolver *QUICResolver) LookUp(msg *dns.Msg) (util.Response, error) {
|
|||
|
||||
stream, err := connection.OpenStream()
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doq: quic stream creation: %w", err)
|
||||
return resp, fmt.Errorf("doq: quic stream creation: %w", err)
|
||||
}
|
||||
|
||||
resolver.opts.Logger.Debug("quic: writing to stream")
|
||||
|
||||
_, err = stream.Write(buf)
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doq: quic stream write: %w", err)
|
||||
return resp, fmt.Errorf("doq: quic stream write: %w", err)
|
||||
}
|
||||
|
||||
resolver.opts.Logger.Debug("quic: reading stream")
|
||||
|
||||
fullRes, err := io.ReadAll(stream)
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doq: quic stream read: %w", err)
|
||||
return resp, fmt.Errorf("doq: quic stream read: %w", err)
|
||||
}
|
||||
|
||||
resp.RTT = time.Since(t)
|
||||
|
@ -80,14 +78,14 @@ func (resolver *QUICResolver) LookUp(msg *dns.Msg) (util.Response, error) {
|
|||
// Close with error: no error
|
||||
err = connection.CloseWithError(0, "")
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doq: quic connection close: %w", err)
|
||||
return resp, fmt.Errorf("doq: quic connection close: %w", err)
|
||||
}
|
||||
|
||||
resolver.opts.Logger.Debug("quic: closing stream")
|
||||
|
||||
err = stream.Close()
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doq: quic stream close: %w", err)
|
||||
return resp, fmt.Errorf("doq: quic stream close: %w", err)
|
||||
}
|
||||
|
||||
resp.DNS = &dns.Msg{}
|
||||
|
@ -96,8 +94,8 @@ func (resolver *QUICResolver) LookUp(msg *dns.Msg) (util.Response, error) {
|
|||
|
||||
err = resp.DNS.Unpack(fullRes)
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doq: unpack: %w", err)
|
||||
return resp, fmt.Errorf("doq: unpack: %w", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
return
|
||||
}
|
||||
|
|
|
@ -19,12 +19,7 @@ type StandardResolver struct {
|
|||
var _ Resolver = (*StandardResolver)(nil)
|
||||
|
||||
// LookUp performs a DNS query.
|
||||
func (resolver *StandardResolver) LookUp(msg *dns.Msg) (util.Response, error) {
|
||||
var (
|
||||
resp util.Response
|
||||
err error
|
||||
)
|
||||
|
||||
func (resolver *StandardResolver) LookUp(msg *dns.Msg) (resp util.Response, err error) {
|
||||
dnsClient := new(dns.Client)
|
||||
dnsClient.Dialer = &net.Dialer{
|
||||
Timeout: resolver.opts.Request.Timeout,
|
||||
|
@ -56,7 +51,7 @@ func (resolver *StandardResolver) LookUp(msg *dns.Msg) (util.Response, error) {
|
|||
|
||||
resp.DNS, resp.RTT, err = dnsClient.Exchange(msg, resolver.opts.Request.Server)
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("standard: DNS exchange: %w", err)
|
||||
return resp, fmt.Errorf("standard: DNS exchange: %w", err)
|
||||
}
|
||||
|
||||
switch dns.RcodeToString[resp.DNS.MsgHdr.Rcode] {
|
||||
|
@ -69,7 +64,7 @@ func (resolver *StandardResolver) LookUp(msg *dns.Msg) (util.Response, error) {
|
|||
resp.DNS, resp.RTT, err = dnsClient.Exchange(msg, resolver.opts.Request.Server)
|
||||
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("badcookie: DNS exchange: %w", err)
|
||||
return resp, fmt.Errorf("badcookie: DNS exchange: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -95,8 +90,8 @@ func (resolver *StandardResolver) LookUp(msg *dns.Msg) (util.Response, error) {
|
|||
}
|
||||
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("standard: DNS exchange: %w", err)
|
||||
return resp, fmt.Errorf("standard: DNS exchange: %w", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
return
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@ type Resolver interface {
|
|||
}
|
||||
|
||||
// LoadResolver loads the respective resolver for performing a DNS query.
|
||||
func LoadResolver(opts *util.Options) (Resolver, error) {
|
||||
func LoadResolver(opts *util.Options) (resolver Resolver, err error) {
|
||||
switch {
|
||||
case opts.HTTPS:
|
||||
opts.Logger.Info("loading DNS-over-HTTPS resolver")
|
||||
|
@ -32,10 +32,11 @@ func LoadResolver(opts *util.Options) (Resolver, error) {
|
|||
}
|
||||
|
||||
opts.Request.Server += opts.HTTPSOptions.Endpoint
|
||||
|
||||
return &HTTPSResolver{
|
||||
resolver = &HTTPSResolver{
|
||||
opts: opts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return
|
||||
case opts.QUIC:
|
||||
opts.Logger.Info("loading DNS-over-QUIC resolver")
|
||||
|
||||
|
@ -43,9 +44,11 @@ func LoadResolver(opts *util.Options) (Resolver, error) {
|
|||
opts.Request.Server = net.JoinHostPort(opts.Request.Server, strconv.Itoa(opts.Request.Port))
|
||||
}
|
||||
|
||||
return &QUICResolver{
|
||||
resolver = &QUICResolver{
|
||||
opts: opts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return
|
||||
case opts.DNSCrypt:
|
||||
opts.Logger.Info("loading DNSCrypt resolver")
|
||||
|
||||
|
@ -53,9 +56,11 @@ func LoadResolver(opts *util.Options) (Resolver, error) {
|
|||
opts.Request.Server = "sdns://" + opts.Request.Server
|
||||
}
|
||||
|
||||
return &DNSCryptResolver{
|
||||
resolver = &DNSCryptResolver{
|
||||
opts: opts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return
|
||||
default:
|
||||
opts.Logger.Info("loading standard/DNS-over-TLS resolver")
|
||||
|
||||
|
@ -63,8 +68,10 @@ func LoadResolver(opts *util.Options) (Resolver, error) {
|
|||
opts.Request.Server = net.JoinHostPort(opts.Request.Server, strconv.Itoa(opts.Request.Port))
|
||||
}
|
||||
|
||||
return &StandardResolver{
|
||||
resolver = &StandardResolver{
|
||||
opts: opts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue