refactor: Add named returns #168

Merged
sam merged 3 commits from named-returns into master 2022-12-17 16:52:50 +00:00
7 changed files with 62 additions and 45 deletions
Showing only changes of commit 91969133ff - Show all commits

View file

@ -70,11 +70,11 @@ func ParseCLI(args []string, version string) (opts *util.Options, err error) {
opts.Logger.Info("Options fully populated")
opts.Logger.Debug(fmt.Sprintf("%+v", opts))
return opts, nil
return
}
// Everything that has to do with CLI flags goes here (the posix style, eg. -a and --bbbb).
func parseFlags(args []string, version string) (*util.Options, []string, error) {
func parseFlags(args []string, version string) (opts *util.Options, flags []string, err error) {
flagSet := flag.NewFlagSet(args[0], flag.ContinueOnError)
flagSet.Usage = func() {
@ -162,7 +162,7 @@ func parseFlags(args []string, version string) (*util.Options, []string, error)
flagSet.SortFlags = true
// Parse the flags
if err := flagSet.Parse(args[1:]); err != nil {
if err = flagSet.Parse(args[1:]); err != nil {
return &util.Options{Logger: util.InitLogger(*verbosity)}, nil, fmt.Errorf("flag: %w", err)
}
@ -172,7 +172,7 @@ func parseFlags(args []string, version string) (*util.Options, []string, error)
return &util.Options{Logger: util.InitLogger(*verbosity)}, nil, fmt.Errorf("EDNS MBZ: %w", err)
}
opts := util.Options{
opts = &util.Options{
Logger: util.InitLogger(*verbosity),
IPv4: *ipv4,
IPv6: *ipv6,
@ -243,8 +243,8 @@ func parseFlags(args []string, version string) (*util.Options, []string, error)
// TODO: DRY
if *subnet != "" {
if err = util.ParseSubnet(*subnet, &opts); err != nil {
return &opts, nil, fmt.Errorf("%w", err)
if err = util.ParseSubnet(*subnet, opts); err != nil {
return opts, nil, fmt.Errorf("%w", err)
}
}
@ -254,10 +254,12 @@ func parseFlags(args []string, version string) (*util.Options, []string, error)
if *versionFlag {
fmt.Printf("awl version %s, built with %s\n", version, runtime.Version())
return &opts, nil, util.ErrNotError
return opts, nil, util.ErrNotError
}
return &opts, flagSet.Args(), nil
flags = flagSet.Args()
return
}
var errNoArg = errors.New("no argument given")

View file

@ -140,7 +140,7 @@ func ToString(res util.Response, opts *util.Options) (s string, err error) {
}
}
return s, nil
return
}
func serverExtra(opts *util.Options) string {

View file

@ -42,7 +42,7 @@ func (resolver *DNSCryptResolver) LookUp(msg *dns.Msg) (resp util.Response, err
resolverInf, err := client.Dial(resolver.opts.Request.Server)
if err != nil {
return util.Response{}, fmt.Errorf("dnscrypt: dial: %w", err)
return resp, fmt.Errorf("dnscrypt: dial: %w", err)
}
now := time.Now()
@ -50,13 +50,15 @@ func (resolver *DNSCryptResolver) LookUp(msg *dns.Msg) (resp util.Response, err
rtt := time.Since(now)
if err != nil {
return util.Response{}, fmt.Errorf("dnscrypt: exchange: %w", err)
return resp, fmt.Errorf("dnscrypt: exchange: %w", err)
}
resp = util.Response{
DNS: res,
RTT: rtt,
}
resolver.opts.Logger.Info("Request successful")
return util.Response{
DNS: res,
RTT: rtt,
}, nil
return
}

View file

@ -41,7 +41,7 @@ func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (resp util.Response, err err
buf, err := msg.Pack()
if err != nil {
return util.Response{}, fmt.Errorf("doh: packing: %w", err)
return resp, fmt.Errorf("doh: packing: %w", err)
}
resolver.opts.Logger.Debug("https: sending HTTPS request")
@ -55,7 +55,7 @@ func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (resp util.Response, err err
req, err := http.NewRequest(method, resolver.opts.Request.Server, bytes.NewBuffer(buf))
if err != nil {
return util.Response{}, fmt.Errorf("doh: request creation: %w", err)
return resp, fmt.Errorf("doh: request creation: %w", err)
}
req.Header.Set("Content-Type", "application/dns-message")
@ -66,23 +66,29 @@ func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (resp util.Response, err err
resp.RTT = time.Since(now)
if err != nil {
return util.Response{}, fmt.Errorf("doh: HTTP request: %w", err)
// overwrite RTT or else tests will fail
resp.RTT = 0
return resp, fmt.Errorf("doh: HTTP request: %w", err)
}
if res.StatusCode != http.StatusOK {
return util.Response{}, &util.ErrHTTPStatus{Code: res.StatusCode}
// overwrite RTT or else tests will fail
resp.RTT = 0
return resp, &util.ErrHTTPStatus{Code: res.StatusCode}
}
resolver.opts.Logger.Debug("https: reading response")
fullRes, err := io.ReadAll(res.Body)
if err != nil {
return util.Response{}, fmt.Errorf("doh: body read: %w", err)
return resp, fmt.Errorf("doh: body read: %w", err)
}
err = res.Body.Close()
if err != nil {
return util.Response{}, fmt.Errorf("doh: body close: %w", err)
return resp, fmt.Errorf("doh: body close: %w", err)
}
resolver.opts.Logger.Debug("https: unpacking response")
@ -91,7 +97,7 @@ func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (resp util.Response, err err
err = resp.DNS.Unpack(fullRes)
if err != nil {
return util.Response{}, fmt.Errorf("doh: dns message unpack: %w", err)
return resp, fmt.Errorf("doh: dns message unpack: %w", err)
}
return resp, nil

View file

@ -38,7 +38,7 @@ func (resolver *QUICResolver) LookUp(msg *dns.Msg) (resp util.Response, err erro
connection, err := quic.DialAddr(resolver.opts.Request.Server, tls, conf)
if err != nil {
return util.Response{}, fmt.Errorf("doq: dial: %w", err)
return resp, fmt.Errorf("doq: dial: %w", err)
}
resolver.opts.Logger.Debug("quic: packing query")
@ -46,7 +46,7 @@ func (resolver *QUICResolver) LookUp(msg *dns.Msg) (resp util.Response, err erro
// Compress request to over-the-wire
buf, err := msg.Pack()
if err != nil {
return util.Response{}, fmt.Errorf("doq: pack: %w", err)
return resp, fmt.Errorf("doq: pack: %w", err)
}
t := time.Now()
@ -55,21 +55,21 @@ func (resolver *QUICResolver) LookUp(msg *dns.Msg) (resp util.Response, err erro
stream, err := connection.OpenStream()
if err != nil {
return util.Response{}, fmt.Errorf("doq: quic stream creation: %w", err)
return resp, fmt.Errorf("doq: quic stream creation: %w", err)
}
resolver.opts.Logger.Debug("quic: writing to stream")
_, err = stream.Write(buf)
if err != nil {
return util.Response{}, fmt.Errorf("doq: quic stream write: %w", err)
return resp, fmt.Errorf("doq: quic stream write: %w", err)
}
resolver.opts.Logger.Debug("quic: reading stream")
fullRes, err := io.ReadAll(stream)
if err != nil {
return util.Response{}, fmt.Errorf("doq: quic stream read: %w", err)
return resp, fmt.Errorf("doq: quic stream read: %w", err)
}
resp.RTT = time.Since(t)
@ -78,14 +78,14 @@ func (resolver *QUICResolver) LookUp(msg *dns.Msg) (resp util.Response, err erro
// Close with error: no error
err = connection.CloseWithError(0, "")
if err != nil {
return util.Response{}, fmt.Errorf("doq: quic connection close: %w", err)
return resp, fmt.Errorf("doq: quic connection close: %w", err)
}
resolver.opts.Logger.Debug("quic: closing stream")
err = stream.Close()
if err != nil {
return util.Response{}, fmt.Errorf("doq: quic stream close: %w", err)
return resp, fmt.Errorf("doq: quic stream close: %w", err)
}
resp.DNS = &dns.Msg{}
@ -94,8 +94,8 @@ func (resolver *QUICResolver) LookUp(msg *dns.Msg) (resp util.Response, err erro
err = resp.DNS.Unpack(fullRes)
if err != nil {
return util.Response{}, fmt.Errorf("doq: unpack: %w", err)
return resp, fmt.Errorf("doq: unpack: %w", err)
}
return resp, nil
return
}

View file

@ -51,7 +51,7 @@ func (resolver *StandardResolver) LookUp(msg *dns.Msg) (resp util.Response, err
resp.DNS, resp.RTT, err = dnsClient.Exchange(msg, resolver.opts.Request.Server)
if err != nil {
return util.Response{}, fmt.Errorf("standard: DNS exchange: %w", err)
return resp, fmt.Errorf("standard: DNS exchange: %w", err)
}
switch dns.RcodeToString[resp.DNS.MsgHdr.Rcode] {
@ -64,7 +64,7 @@ func (resolver *StandardResolver) LookUp(msg *dns.Msg) (resp util.Response, err
resp.DNS, resp.RTT, err = dnsClient.Exchange(msg, resolver.opts.Request.Server)
if err != nil {
return util.Response{}, fmt.Errorf("badcookie: DNS exchange: %w", err)
return resp, fmt.Errorf("badcookie: DNS exchange: %w", err)
}
}
@ -90,8 +90,8 @@ func (resolver *StandardResolver) LookUp(msg *dns.Msg) (resp util.Response, err
}
if err != nil {
return util.Response{}, fmt.Errorf("standard: DNS exchange: %w", err)
return resp, fmt.Errorf("standard: DNS exchange: %w", err)
}
return resp, nil
return
}

View file

@ -22,7 +22,7 @@ type Resolver interface {
}
// LoadResolver loads the respective resolver for performing a DNS query.
func LoadResolver(opts *util.Options) (Resolver, error) {
func LoadResolver(opts *util.Options) (resolver Resolver, err error) {
switch {
case opts.HTTPS:
opts.Logger.Info("loading DNS-over-HTTPS resolver")
@ -32,10 +32,11 @@ func LoadResolver(opts *util.Options) (Resolver, error) {
}
opts.Request.Server += opts.HTTPSOptions.Endpoint
return &HTTPSResolver{
resolver = &HTTPSResolver{
opts: opts,
}, nil
}
return
case opts.QUIC:
opts.Logger.Info("loading DNS-over-QUIC resolver")
@ -43,9 +44,11 @@ func LoadResolver(opts *util.Options) (Resolver, error) {
opts.Request.Server = net.JoinHostPort(opts.Request.Server, strconv.Itoa(opts.Request.Port))
}
return &QUICResolver{
resolver = &QUICResolver{
opts: opts,
}, nil
}
return
case opts.DNSCrypt:
opts.Logger.Info("loading DNSCrypt resolver")
@ -53,9 +56,11 @@ func LoadResolver(opts *util.Options) (Resolver, error) {
opts.Request.Server = "sdns://" + opts.Request.Server
}
return &DNSCryptResolver{
resolver = &DNSCryptResolver{
opts: opts,
}, nil
}
return
default:
opts.Logger.Info("loading standard/DNS-over-TLS resolver")
@ -63,8 +68,10 @@ func LoadResolver(opts *util.Options) (Resolver, error) {
opts.Request.Server = net.JoinHostPort(opts.Request.Server, strconv.Itoa(opts.Request.Port))
}
return &StandardResolver{
resolver = &StandardResolver{
opts: opts,
}, nil
}
return
}
}