refactor of query.go #27

Merged
sam merged 4 commits from refactor into master 2022-07-03 20:45:12 +00:00
9 changed files with 270 additions and 162 deletions
Showing only changes of commit cd63082d8f - Show all commits

16
cli.go
View file

@ -9,7 +9,7 @@ import (
"strings"
"git.froth.zone/sam/awl/conf"
"git.froth.zone/sam/awl/util"
"git.froth.zone/sam/awl/query"
"github.com/miekg/dns"
"github.com/urfave/cli/v2"
@ -137,6 +137,10 @@ func prepareCLI() *cli.App {
Aliases: []string{"x"},
Usage: "do a reverse lookup",
},
&cli.BoolFlag{
Name: "debug",
Usage: "enable verbose logging",
},
},
Action: doQuery,
}
@ -144,9 +148,9 @@ func prepareCLI() *cli.App {
}
// Parse the wildcard arguments, drill style
func parseArgs(args []string) (util.Answers, error) {
func parseArgs(args []string) (query.Answers, error) {
var (
resp util.Response
resp query.Response
err error
)
for _, arg := range args {
@ -158,7 +162,7 @@ func parseArgs(args []string) (util.Answers, error) {
case strings.Contains(arg, "."):
resp.Answers.Name, err = idna.ToUnicode(arg)
if err != nil {
return util.Answers{}, err
return query.Answers{}, err
}
case ok:
// If it's a DNS request, it's a DNS request (obviously)
@ -167,7 +171,7 @@ func parseArgs(args []string) (util.Answers, error) {
//else, assume it's a name
resp.Answers.Name, err = idna.ToUnicode(arg)
if err != nil {
return util.Answers{}, err
return query.Answers{}, err
}
}
@ -193,5 +197,5 @@ func parseArgs(args []string) (util.Answers, error) {
}
}
return util.Answers{Server: resp.Answers.Server, Request: resp.Answers.Request, Name: resp.Answers.Name}, nil
return query.Answers{Server: resp.Answers.Server, Request: resp.Answers.Request, Name: resp.Answers.Name}, nil
}

View file

@ -6,7 +6,7 @@ import (
"os"
"testing"
"git.froth.zone/sam/awl/util"
"git.froth.zone/sam/awl/query"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -21,19 +21,19 @@ func TestApp(t *testing.T) {
func TestArgParse(t *testing.T) {
tests := []struct {
in []string
want util.Answers
want query.Answers
}{
{
[]string{"@::1", "localhost", "AAAA"},
util.Answers{Server: "::1", Request: dns.TypeAAAA, Name: "localhost"},
query.Answers{Server: "::1", Request: dns.TypeAAAA, Name: "localhost"},
},
{
[]string{"@1.0.0.1", "google.com"},
util.Answers{Server: "1.0.0.1", Request: dns.TypeA, Name: "google.com"},
query.Answers{Server: "1.0.0.1", Request: dns.TypeA, Name: "google.com"},
},
{
[]string{"@8.8.4.4"},
util.Answers{Server: "8.8.4.4", Request: dns.TypeNS, Name: "."},
query.Answers{Server: "8.8.4.4", Request: dns.TypeNS, Name: "."},
},
}
for _, test := range tests {
@ -50,6 +50,25 @@ func TestQuery(t *testing.T) {
err := app.Run(args)
assert.NotNil(t, err)
}
func TestNoArgs(t *testing.T) {
app := prepareCLI()
args := os.Args[0:1]
args = append(args, "--no-truncate")
err := app.Run(args)
assert.Nil(t, err)
}
func TestFlags(t *testing.T) {
app := prepareCLI()
args := os.Args[0:1]
args = append(args, "--debug")
args = append(args, "--short")
args = append(args, "-4")
err := app.Run(args)
assert.Nil(t, err)
}
func TestHTTPS(t *testing.T) {
app := prepareCLI()
args := os.Args[0:1]
@ -60,6 +79,25 @@ func TestHTTPS(t *testing.T) {
assert.Nil(t, err)
}
func TestJSON(t *testing.T) {
app := prepareCLI()
args := os.Args[0:1]
args = append(args, "-j")
args = append(args, "git.froth.zone")
err := app.Run(args)
assert.Nil(t, err)
}
func TestQUIC(t *testing.T) {
app := prepareCLI()
args := os.Args[0:1]
args = append(args, "-Q")
args = append(args, "@dns.adguard.com")
args = append(args, "git.froth.zone")
err := app.Run(args)
assert.Nil(t, err)
}
func FuzzCli(f *testing.F) {
testcases := []string{"git.froth.zone", "", "!12345", "google.com.edu.org.fr"}
for _, tc := range testcases {

160
query.go
View file

@ -10,7 +10,6 @@ import (
"strings"
"time"
"git.froth.zone/sam/awl/logawl"
"git.froth.zone/sam/awl/query"
"git.froth.zone/sam/awl/util"
"github.com/miekg/dns"
@ -19,106 +18,73 @@ import (
func doQuery(c *cli.Context) error {
var (
err error
resp util.Response
isHTTPS bool
Logger = logawl.New() //init logger
err error
)
resp.Answers, err = parseArgs(c.Args().Slice())
Options := query.Options{
Logger: util.InitLogger(c.Bool("debug")),
Port: c.Int("port"),
IPv4: c.Bool("4"),
IPv6: c.Bool("6"),
DNSSEC: c.Bool("dnssec"),
Short: c.Bool("short"),
TCP: c.Bool("tcp"),
TLS: c.Bool("tls"),
HTTPS: c.Bool("https"),
QUIC: c.Bool("quic"),
Truncate: c.Bool("no-truncate"),
AA: c.Bool("aa"),
TC: c.Bool("tc"),
Z: c.Bool("z"),
CD: c.Bool("cd"),
NoRD: c.Bool("no-rd"),
NoRA: c.Bool("no-ra"),
Reverse: c.Bool("reverse"),
Debug: c.Bool("debug"),
}
Options.Answers, err = parseArgs(c.Args().Slice())
if err != nil {
Logger.Error("Unable to parse args")
Options.Logger.Error("Unable to parse args")
return err
}
port := c.Int("port")
if c.Bool("debug") {
Logger.SetLevel(3)
}
Logger.Debug("Starting awl")
// If port is not set, set it
if port == 0 {
if c.Bool("tls") || c.Bool("quic") {
port = 853
} else {
port = 53
}
}
if c.Bool("https") || strings.HasPrefix(resp.Answers.Server, "https://") {
// add https:// if it doesn't already exist
if !strings.HasPrefix(resp.Answers.Server, "https://") {
resp.Answers.Server = "https://" + resp.Answers.Server
}
isHTTPS = true
} else {
resp.Answers.Server = net.JoinHostPort(resp.Answers.Server, strconv.Itoa(port))
}
// Process the IP/Phone number so a PTR/NAPTR can be done
if c.Bool("reverse") {
if dns.TypeToString[resp.Answers.Request] == "A" {
resp.Answers.Request = dns.StringToType["PTR"]
}
resp.Answers.Name, err = util.ReverseDNS(resp.Answers.Name, dns.TypeToString[resp.Answers.Request])
if err != nil {
return err
}
}
// if the domain is not canonical, make it canonical
if !strings.HasSuffix(resp.Answers.Name, ".") {
resp.Answers.Name = fmt.Sprintf("%s.", resp.Answers.Name)
}
msg := new(dns.Msg)
msg.SetQuestion(resp.Answers.Name, resp.Answers.Request)
// Make this authoritative (does this do anything?)
if c.Bool("aa") {
msg.Authoritative = true
// if the domain is not canonical, make it canonical
if !strings.HasSuffix(Options.Answers.Name, ".") {
Options.Answers.Name = fmt.Sprintf("%s.", Options.Answers.Name)
}
// Set truncated flag (why)
if c.Bool("tc") {
msg.Truncated = true
}
// Set the zero flag if requested (does nothing)
if c.Bool("z") {
Logger.Debug("Setting message to zero")
msg.Zero = true
}
// Disable DNSSEC validation
if c.Bool("cd") {
msg.CheckingDisabled = true
}
// Disable wanting recursion
if c.Bool("no-rd") {
msg.RecursionDesired = false
}
// Disable recursion being available (I don't think this does anything)
if c.Bool("no-ra") {
msg.RecursionAvailable = false
}
// Set DNSSEC if requested
if c.Bool("dnssec") {
Logger.Debug("Using DNSSEC")
msg.SetEdns0(1232, true)
msg.SetQuestion(Options.Answers.Name, Options.Answers.Request)
// If port is not set, set it
if Options.Port == 0 {
if Options.TLS || Options.QUIC {
Options.Port = 853
} else {
Options.Port = 53
}
}
var in *dns.Msg
resolver, err := query.LoadResolver(Options.Answers.Server, Options)
if err != nil {
return err
}
if Options.Debug {
Options.Logger.SetLevel(3)
}
Options.Logger.Debug("Starting awl")
var in = Options.Answers.DNS
// Make the DNS request
if isHTTPS {
in, resp.Answers.RTT, err = query.ResolveHTTPS(msg, resp.Answers.Server)
} else if c.Bool("quic") {
in, resp.Answers.RTT, err = query.ResolveQUIC(msg, resp.Answers.Server)
if Options.HTTPS {
in, Options.Answers.RTT, err = resolver.LookUp(msg)
} else if Options.QUIC {
in, Options.Answers.RTT, err = resolver.LookUp(msg)
} else {
Options.Answers.Server = net.JoinHostPort(Options.Answers.Server, strconv.Itoa(Options.Port))
d := new(dns.Client)
// Set TCP/UDP, depending on flags
if c.Bool("tcp") || c.Bool("tls") {
if Options.TCP || Options.TLS {
d.Net = "tcp"
} else {
d.Net = "udp"
@ -126,32 +92,32 @@ func doQuery(c *cli.Context) error {
// Set IPv4 or IPv6, depending on flags
switch {
case c.Bool("4"):
case Options.IPv4:
d.Net += "4"
case c.Bool("6"):
case Options.IPv6:
d.Net += "6"
}
// Add TLS, if requested
if c.Bool("tls") {
if Options.TLS {
d.Net += "-tls"
}
in, resp.Answers.RTT, err = d.Exchange(msg, resp.Answers.Server)
in, Options.Answers.RTT, err = d.Exchange(msg, Options.Answers.Server)
if err != nil {
return err
}
// If UDP truncates, use TCP instead (unless truncation is to be ignored)
if in.MsgHdr.Truncated && !c.Bool("no-truncate") {
if in.MsgHdr.Truncated && !Options.Truncate {
fmt.Printf(";; Truncated, retrying with TCP\n\n")
d.Net = "tcp"
switch {
case c.Bool("4"):
case Options.IPv4:
d.Net += "4"
case c.Bool("6"):
case Options.IPv4:
d.Net += "6"
}
in, resp.Answers.RTT, err = d.Exchange(msg, resp.Answers.Server)
in, Options.Answers.RTT, err = d.Exchange(msg, Options.Answers.Server)
}
}
@ -169,8 +135,8 @@ func doQuery(c *cli.Context) error {
if !c.Bool("short") {
// Print everything
fmt.Println(in)
fmt.Println(";; Query time:", resp.Answers.RTT)
fmt.Println(";; SERVER:", resp.Answers.Server)
fmt.Println(";; Query time:", Options.Answers.RTT)
fmt.Println(";; SERVER:", Options.Answers.Server)
fmt.Println(";; WHEN:", time.Now().Format(time.RFC1123Z))
fmt.Println(";; MSG SIZE rcvd:", in.Len())
} else {

View file

@ -12,17 +12,21 @@ import (
"github.com/miekg/dns"
)
// Resolve a DNS-over-HTTPS query
//
// Currently only supports POST requests
func ResolveHTTPS(msg *dns.Msg, server string) (*dns.Msg, time.Duration, error) {
type HTTPSResolver struct {
server string
opts Options
}
func (r *HTTPSResolver) LookUp(msg *dns.Msg) (*dns.Msg, time.Duration, error) {
var resp Response
httpR := &http.Client{}
buf, err := msg.Pack()
if err != nil {
return nil, 0, err
}
r.opts.Logger.Debug("making DoH request")
// query := server + "?dns=" + base64.RawURLEncoding.EncodeToString(buf)
req, err := http.NewRequest("POST", server, bytes.NewBuffer(buf))
req, err := http.NewRequest("POST", r.server, bytes.NewBuffer(buf))
if err != nil {
return nil, 0, fmt.Errorf("DoH: %s", err.Error())
}
@ -31,7 +35,7 @@ func ResolveHTTPS(msg *dns.Msg, server string) (*dns.Msg, time.Duration, error)
now := time.Now()
res, err := httpR.Do(req)
rtt := time.Since(now)
resp.Answers.RTT = time.Since(now)
if err != nil {
return nil, 0, fmt.Errorf("DoH HTTP request error: %s", err.Error())
@ -46,11 +50,12 @@ func ResolveHTTPS(msg *dns.Msg, server string) (*dns.Msg, time.Duration, error)
if err != nil {
return nil, 0, fmt.Errorf("DoH body read error: %s", err.Error())
}
response := dns.Msg{}
err = response.Unpack(fullRes)
resp.DNS = dns.Msg{}
r.opts.Logger.Debug("unpacking response")
err = resp.DNS.Unpack(fullRes)
if err != nil {
return nil, 0, fmt.Errorf("DoH dns message unpack error: %s", err.Error())
}
return &response, rtt, nil
return &resp.DNS, resp.Answers.RTT, nil
}

View file

@ -11,12 +11,18 @@ import (
"github.com/miekg/dns"
)
// Resolve DNS over QUIC, the hip new standard (for privacy I think, IDK)
func ResolveQUIC(msg *dns.Msg, server string) (*dns.Msg, time.Duration, error) {
type QUICResolver struct {
server string
opts Options
}
func (r *QUICResolver) LookUp(msg *dns.Msg) (*dns.Msg, time.Duration, error) {
var resp Response
tls := &tls.Config{
NextProtos: []string{"doq"},
}
connection, err := quic.DialAddr(server, tls, nil)
connection, err := quic.DialAddr(r.server, tls, nil)
if err != nil {
return nil, 0, err
}
@ -40,7 +46,7 @@ func ResolveQUIC(msg *dns.Msg, server string) (*dns.Msg, time.Duration, error) {
if err != nil {
return nil, 0, err
}
rtt := time.Since(t)
resp.Answers.RTT = time.Since(t)
// Close with error: no error
err = connection.CloseWithError(0, "")
@ -53,11 +59,10 @@ func ResolveQUIC(msg *dns.Msg, server string) (*dns.Msg, time.Duration, error) {
return nil, 0, err
}
response := dns.Msg{}
err = response.Unpack(fullRes)
resp.DNS = dns.Msg{}
err = resp.DNS.Unpack(fullRes)
if err != nil {
return nil, 0, err
}
return &response, rtt, nil
return &resp.DNS, resp.Answers.RTT, nil
}

View file

@ -14,7 +14,13 @@ import (
func TestResolveHTTPS(t *testing.T) {
var err error
testCase := util.Answers{Server: "dns9.quad9.net/dns-query", Request: dns.TypeA, Name: "git.froth.zone"}
opts := Options{
HTTPS: true,
Logger: util.InitLogger(false),
}
testCase := Answers{Server: "dns9.quad9.net/dns-query", Request: dns.TypeA, Name: "git.froth.zone"}
resolver, err := LoadResolver(testCase.Server, opts)
if !strings.HasPrefix(testCase.Server, "https://") {
testCase.Server = "https://" + testCase.Server
}
@ -27,28 +33,36 @@ func TestResolveHTTPS(t *testing.T) {
msg.SetQuestion(testCase.Name, testCase.Request)
msg = msg.SetQuestion(testCase.Name, testCase.Request)
var in *dns.Msg
in, testCase.RTT, err = ResolveHTTPS(msg, testCase.Server)
in, testCase.RTT, err = resolver.LookUp(msg)
assert.Nil(t, err)
assert.NotNil(t, in)
}
func Test2ResolveHTTPS(t *testing.T) {
opts := Options{
HTTPS: true,
Logger: util.InitLogger(false),
}
var err error
testCase := util.Answers{Server: "dns9.quad9.net/dns-query", Request: dns.TypeA, Name: "git.froth.zone"}
testCase := Answers{Server: "dns9.quad9.net/dns-query", Request: dns.TypeA, Name: "git.froth.zone"}
resolver, err := LoadResolver(testCase.Server, opts)
msg := new(dns.Msg)
msg.SetQuestion(testCase.Name, testCase.Request)
msg = msg.SetQuestion(testCase.Name, testCase.Request)
var in *dns.Msg
in, testCase.RTT, err = ResolveHTTPS(msg, testCase.Server)
in, testCase.RTT, err = resolver.LookUp(msg)
assert.NotNil(t, err)
assert.Nil(t, in)
}
func Test3ResolveHTTPS(t *testing.T) {
opts := Options{
HTTPS: true,
Logger: util.InitLogger(false),
}
var err error
testCase := util.Answers{Server: "dns9..quad9.net/dns-query", Request: dns.TypeA, Name: "git.froth.zone."}
testCase := Answers{Server: "dns9..quad9.net/dns-query", Request: dns.TypeA, Name: "git.froth.zone."}
if !strings.HasPrefix(testCase.Server, "https://") {
testCase.Server = "https://" + testCase.Server
}
@ -56,30 +70,33 @@ func Test3ResolveHTTPS(t *testing.T) {
if !strings.HasSuffix(testCase.Name, ".") {
testCase.Name = fmt.Sprintf("%s.", testCase.Name)
}
resolver, err := LoadResolver(testCase.Server, opts)
msg := new(dns.Msg)
msg.SetQuestion(testCase.Name, testCase.Request)
msg = msg.SetQuestion(testCase.Name, testCase.Request)
var in *dns.Msg
in, testCase.RTT, err = ResolveHTTPS(msg, testCase.Server)
in, testCase.RTT, err = resolver.LookUp(msg)
assert.NotNil(t, err)
assert.Nil(t, in)
}
func TestQuic(t *testing.T) {
var err error
testCase := util.Answers{Server: "dns.adguard.com", Request: dns.TypeA, Name: "git.froth.zone"}
testCase2 := util.Answers{Server: "dns.adguard.com", Request: dns.TypeA, Name: "git.froth.zone"}
var testCases []util.Answers
opts := Options{
QUIC: true,
Logger: util.InitLogger(false),
Port: 853,
Answers: Answers{Server: "dns.adguard.com"},
}
testCase := Answers{Server: "dns.//./,,adguard.com", Request: dns.TypeA, Name: "git.froth.zone"}
testCase2 := Answers{Server: "dns.adguard.com", Request: dns.TypeA, Name: "git.froth.zone"}
var testCases []Answers
testCases = append(testCases, testCase)
testCases = append(testCases, testCase2)
for i := range testCases {
switch i {
case 0:
port := 853
testCases[i].Server = net.JoinHostPort(testCases[i].Server, strconv.Itoa(port))
fmt.Println(testCases[i].Server)
resolver, err := LoadResolver(testCases[i].Server, opts)
// if the domain is not canonical, make it canonical
if !strings.HasSuffix(testCase.Name, ".") {
testCases[i].Name = fmt.Sprintf("%s.", testCases[i].Name)
@ -88,21 +105,21 @@ func TestQuic(t *testing.T) {
msg.SetQuestion(testCase.Name, testCase.Request)
msg = msg.SetQuestion(testCase.Name, testCase.Request)
var in *dns.Msg
in, testCase.RTT, err = ResolveQUIC(msg, testCase.Server)
in, testCase.RTT, err = resolver.LookUp(msg)
assert.NotNil(t, err)
assert.Nil(t, in)
case 1:
port := 853
testCases[i].Server = net.JoinHostPort(testCases[i].Server, strconv.Itoa(port))
resolver, err := LoadResolver(testCase2.Server, opts)
testCase2.Server = net.JoinHostPort(testCase2.Server, strconv.Itoa(opts.Port))
// if the domain is not canonical, make it canonical
if !strings.HasSuffix(testCase.Name, ".") {
testCases[i].Name = fmt.Sprintf("%s.", testCases[i].Name)
if !strings.HasSuffix(testCase2.Name, ".") {
testCase2.Name = fmt.Sprintf("%s.", testCase2.Name)
}
msg := new(dns.Msg)
msg.SetQuestion(testCases[i].Name, testCases[i].Request)
msg = msg.SetQuestion(testCases[i].Name, testCases[i].Request)
msg.SetQuestion(testCase2.Name, testCase2.Request)
msg = msg.SetQuestion(testCase2.Name, testCase2.Request)
var in *dns.Msg
in, testCase.RTT, err = ResolveQUIC(msg, testCases[i].Server)
in, testCase.RTT, err = resolver.LookUp(msg)
assert.Nil(t, err)
assert.NotNil(t, in)
}

74
query/resolver.go Normal file
View file

@ -0,0 +1,74 @@
package query
import (
"net"
"strconv"
"strings"
"time"
"git.froth.zone/sam/awl/logawl"
"github.com/miekg/dns"
)
// represent all CLI flags
type Options struct {
Logger *logawl.Logger
Port int
IPv4 bool
IPv6 bool
DNSSEC bool
Short bool
TCP bool
TLS bool
HTTPS bool
QUIC bool
Truncate bool
AA bool
TC bool
Z bool
CD bool
NoRD bool
NoRA bool
Reverse bool
Debug bool
Answers Answers
}
type Response struct {
Answers Answers `json:"Response"` // These be DNS query answers
DNS dns.Msg
}
// The Answers struct is the basic structure of a DNS request
// to be returned to the user upon making a request
type Answers struct {
Server string `json:"Server"` // The server to make the DNS request from
DNS *dns.Msg
Request uint16 `json:"Request"` // The type of request
Name string `json:"Name"` // The domain name to make a DNS request for
RTT time.Duration `json:"RTT"` // The time it took to make the DNS query
}
type Resolver interface {
LookUp(*dns.Msg) (*dns.Msg, time.Duration, error)
}
func LoadResolver(server string, opts Options) (Resolver, error) {
if opts.HTTPS {
if !strings.HasPrefix(server, "https://") {
server = "https://" + server
}
return &HTTPSResolver{
server: server,
opts: opts,
}, nil
} else if opts.QUIC {
server = net.JoinHostPort(opts.Answers.Server, strconv.Itoa(opts.Port))
return &QUICResolver{
server: server,
opts: opts,
}, nil
}
return nil, nil
}

View file

@ -6,24 +6,10 @@ import (
"errors"
"fmt"
"strings"
"time"
"github.com/miekg/dns"
)
type Response struct {
Answers Answers `json:"Response"` // These be DNS query answers
}
// The Answers struct is the basic structure of a DNS request
// to be returned to the user upon making a request
type Answers struct {
Server string `json:"Server"` // The server to make the DNS request from
Request uint16 `json:"Request"` // The type of request
Name string `json:"Name"` // The domain name to make a DNS request for
RTT time.Duration `json:"RTT"` // The time it took to make the DNS query
}
// Given an IP or phone number, return a canonical string to be queried
func ReverseDNS(address string, query string) (string, error) {
if query == "PTR" {

13
util/logger.go Normal file
View file

@ -0,0 +1,13 @@
package util
import "git.froth.zone/sam/awl/logawl"
func InitLogger(debug bool) (Logger *logawl.Logger) {
Logger = logawl.New()
if debug {
Logger.SetLevel(3)
}
return
}