
Writing a DNS proxy in Go


I had long wanted to get rid of ads. The easiest way to do this on all my devices at once turned out to be running my own DNS server that blocks lookups of advertising domains.

At first I used dnsmasq, but I wanted to download blocklists from the Internet and collect some usage statistics, so I decided to write my own server.

Of course, it is not written entirely from scratch: all the DNS handling comes from the github.com/miekg/dns library.

Configuration


The program starts, naturally, by loading the configuration file. I immediately wanted the config to be reloaded automatically whenever it changes, so that the server does not need to be restarted. The fsnotify package came in handy for this.
Config structure:

    import "time"

    type Config struct {
        Nameservers    []string      `yaml:"nameservers"`
        Blocklist      []string      `yaml:"blocklist"`
        BlockAddress4  string        `yaml:"blockAddress4"`
        BlockAddress6  string        `yaml:"blockAddress6"`
        ConfigUpdate   bool          `yaml:"configUpdate"`
        UpdateInterval time.Duration `yaml:"updateInterval"`
    }

The most interesting part here is tracking configuration file updates. With fsnotify this is quite simple: we create a Watcher, add the file to it, and listen for events on its channel. True Go style!

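    import (
        "log"

        "github.com/fsnotify/fsnotify"
    )

    // configWatcher reloads the config every time the file is written and
    // stops watching once a reloaded config has configUpdate set to false.
    func configWatcher() {
        watcher, err := fsnotify.NewWatcher()
        if err != nil {
            log.Fatal(err)
        }
        defer watcher.Close()

        if err := watcher.Add(*configFile); err != nil {
            log.Fatal(err)
        }
        for {
            select {
            case event, ok := <-watcher.Events:
                if !ok {
                    return // watcher was closed
                }
                if event.Op&fsnotify.Write != fsnotify.Write {
                    continue
                }
                log.Println("Config file updated, reload config")
                c, err := loadConfig()
                if err != nil {
                    log.Println("Bad config:", err)
                    continue
                }
                log.Println("Config successfully updated")
                // Handlers read config concurrently, so strictly this write should be synchronized.
                config = c
                if !c.ConfigUpdate {
                    return
                }
            case err, ok := <-watcher.Errors:
                if !ok {
                    return
                }
                log.Println("error:", err)
            }
        }
    }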
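The watcher assigns the new value to the global `config` from its own goroutine while request handlers read it, which the Go memory model treats as a data race. One way to make the swap safe is to keep the config behind an `atomic.Pointer` (Go 1.19 or later). This is a sketch under that assumption: `currentConfig` and `nameservers` are hypothetical names, and `Config` is trimmed to two fields.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// Config is a trimmed-down stand-in for the article's Config type.
type Config struct {
	Nameservers  []string
	ConfigUpdate bool
}

// currentConfig is a hypothetical replacement for the plain global variable:
// readers call Load, the watcher calls Store, and no mutex is needed.
var currentConfig atomic.Pointer[Config]

// nameservers is what a request handler would call instead of reading config directly.
func nameservers() []string {
	return currentConfig.Load().Nameservers
}

func main() {
	currentConfig.Store(&Config{Nameservers: []string{"8.8.8.8:53"}})

	var wg sync.WaitGroup
	for i := 0; i < 4; i++ { // concurrent readers, like request handlers
		wg.Add(1)
		go func() {
			defer wg.Done()
			_ = nameservers() // sees either the old or the new config, never a torn one
		}()
	}
	// The watcher goroutine swaps in a freshly loaded config.
	currentConfig.Store(&Config{Nameservers: []string{"1.1.1.1:53"}, ConfigUpdate: true})
	wg.Wait()
	fmt.Println(nameservers())
}
```

The important rule is that a `*Config` is never modified after `Store`: the watcher always builds a fresh value with `loadConfig` and swaps the pointer.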


Blacklist


Since the goal is to block unwanted sites, the list of them has to be stored somewhere. Under a small load, a simple hash table with blocked domains as keys and empty structs as values is enough. Note that each key must end with a dot, because names in DNS questions are fully qualified (`example.com.`).
Since the list is never read and written at the same time, we can do without mutexes.

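    import "strings"

    // BlackList is a set of blocked domains stored as fully qualified names.
    type BlackList struct {
        data map[string]struct{}
    }

    // Add normalizes a domain and stores it; it returns false for empty input.
    func (b *BlackList) Add(server string) bool {
        server = strings.ToLower(strings.TrimSpace(server))
        if len(server) == 0 {
            return false
        }
        if !strings.HasSuffix(server, ".") {
            server += "."
        }
        if b.data == nil {
            b.data = make(map[string]struct{}) // writing to a nil map would panic
        }
        b.data[server] = struct{}{}
        return true
    }

    // Contains reports an exact match; DNS names are case-insensitive.
    func (b *BlackList) Contains(server string) bool {
        _, ok := b.data[strings.ToLower(server)]
        return ok
    }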


Caching


Initially I planned to do without a cache: my devices do not generate many requests. But one evening someone discovered my server and flooded it with the same request at about 100 rps. That is not much, but the requests are proxied to real name servers (Google's, in my case), and getting banned by them would be very unpleasant.

The main difficulty is that there are many different requests and they must be cached separately, so the result is a two-level hash table: first by request type, then by domain.

    import (
        "sync"
        "time"

        "github.com/miekg/dns"
    )

    type Cache interface {
        Get(reqType uint16, domain string) dns.RR
        Set(reqType uint16, domain string, ip dns.RR)
    }

    type CacheItem struct {
        Ip  dns.RR
        Die time.Time // expiry moment, derived from the record TTL
    }

    // MemoryCache is a two-level map: request type -> domain -> item.
    type MemoryCache struct {
        cache  map[uint16]map[string]*CacheItem
        locker sync.RWMutex
    }

    func (c *MemoryCache) Get(reqType uint16, domain string) dns.RR {
        c.locker.RLock()
        defer c.locker.RUnlock()
        if m, ok := c.cache[reqType]; ok {
            if item, ok := m[domain]; ok && item.Die.After(time.Now()) {
                return item.Ip
            }
        }
        return nil
    }

    func (c *MemoryCache) Set(reqType uint16, domain string, ip dns.RR) {
        c.locker.Lock()
        defer c.locker.Unlock()
        if c.cache == nil {
            c.cache = make(map[uint16]map[string]*CacheItem)
        }
        m, ok := c.cache[reqType]
        if !ok {
            m = make(map[string]*CacheItem)
            c.cache[reqType] = m
        }
        m[domain] = &CacheItem{
            Ip:  ip,
            Die: time.Now().Add(time.Duration(ip.Header().Ttl) * time.Second),
        }
    }


Handler


The main part of the program is, of course, the handler for incoming requests, so I saved it for dessert. The basic logic is roughly this: receive a request, check whether the domain is in the blacklist, check whether the answer is in the cache, and otherwise proxy the request to a real server.
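The order of checks in the handler can be modelled without the DNS library. In this sketch, `resolver`, `resolve`, and the `upstream` callback are hypothetical; answers are plain strings rather than `*dns.Msg`, and there is no TTL and no locking, so it only illustrates the decision order. Returning `blockAddress4` for A queries and `blockAddress6` for AAAA queries is an assumption based on the config fields, and refusing other record types is purely illustrative.

```go
package main

import (
	"errors"
	"fmt"
)

// DNS resource record type codes (same values as dns.TypeA and dns.TypeAAAA).
const (
	typeA    uint16 = 1
	typeAAAA uint16 = 28
)

var errBlocked = errors.New("blocked domain")

// resolver is a hypothetical model of the handler's decision order.
type resolver struct {
	blocked       map[string]struct{}          // fully qualified names, e.g. "ads.example.com."
	cache         map[uint16]map[string]string // query type -> domain -> answer
	blockAddress4 string
	blockAddress6 string
	upstream      func(qtype uint16, name string) (string, error) // stands in for Lookup
}

func (r *resolver) resolve(qtype uint16, name string) (string, error) {
	// 1. Blacklist: answer with the configured block address.
	if _, ok := r.blocked[name]; ok {
		switch qtype {
		case typeA:
			return r.blockAddress4, nil
		case typeAAAA:
			return r.blockAddress6, nil
		default:
			return "", errBlocked // illustrative choice for other record types
		}
	}
	// 2. Cache: indexing nil maps is safe and simply reports a miss.
	if answer, ok := r.cache[qtype][name]; ok {
		return answer, nil
	}
	// 3. Proxy to a real server and remember the answer.
	answer, err := r.upstream(qtype, name)
	if err != nil {
		return "", err
	}
	if r.cache == nil {
		r.cache = make(map[uint16]map[string]string)
	}
	if r.cache[qtype] == nil {
		r.cache[qtype] = make(map[string]string)
	}
	r.cache[qtype][name] = answer
	return answer, nil
}

func main() {
	calls := 0
	r := &resolver{
		blocked:       map[string]struct{}{"ads.example.com.": {}},
		blockAddress4: "0.0.0.0",
		blockAddress6: "::",
		upstream: func(qtype uint16, name string) (string, error) {
			calls++
			return "192.0.2.1", nil // fake upstream answer (documentation address)
		},
	}
	fmt.Println(r.resolve(typeA, "ads.example.com.")) // 0.0.0.0 <nil>
	r.resolve(typeA, "example.com.")
	r.resolve(typeA, "example.com.")
	fmt.Println(calls) // 1: the repeated query was served from the cache
}
```

This is also why the cache helps against a flood of identical requests: only the first one reaches the upstream server.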

The most interesting part is the Lookup function. It queries the name servers one after another without waiting for the earlier queries to finish: a new server is added every 5 seconds (the ticker period) until an answer arrives, and the first successful response from any of them is returned.

    import (
        "errors"
        "log"
        "sync"
        "time"

        "github.com/miekg/dns"
    )

    func Lookup(req *dns.Msg) (*dns.Msg, error) {
        if len(req.Question) == 0 {
            return nil, errors.New("request has no question")
        }
        c := &dns.Client{
            Net:          "tcp",
            ReadTimeout:  5 * time.Second,
            WriteTimeout: 5 * time.Second,
        }
        qName := req.Question[0].Name

        res := make(chan *dns.Msg, 1)
        var wg sync.WaitGroup
        L := func(nameserver string) {
            defer wg.Done()
            r, _, err := c.Exchange(req, nameserver)
            totalRequestsToGoogle.Inc()
            if err != nil {
                log.Printf("%s socket error on %s: %v", qName, nameserver, err)
                return
            }
            // SERVFAIL means "no answer"; other codes such as NXDOMAIN are valid replies.
            if r == nil || r.Rcode == dns.RcodeServerFailure {
                return
            }
            // Keep only the first answer; later ones are dropped without blocking.
            select {
            case res <- r:
            default:
            }
        }

        ticker := time.NewTicker(5 * time.Second)
        defer ticker.Stop()
        // Start a lookup on each nameserver top-down, one per tick (every 5 seconds)...
        for _, nameserver := range config.Nameservers {
            wg.Add(1)
            go L(nameserver)
            // ...but exit early if we already have an answer.
            select {
            case r := <-res:
                return r, nil
            case <-ticker.C:
                continue
            }
        }

        // Wait for all the nameservers to finish.
        wg.Wait()
        select {
        case r := <-res:
            return r, nil
        default:
            return nil, errors.New("can't resolve ip for " + qName)
        }
    }


Metrics


For metrics we use the Prometheus Go client. It is very simple to use: declare a counter, register it, and call its Inc() method in the right place. The main thing to remember is to start a web server with the Prometheus handler so that Prometheus can scrape the metrics.

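    import (
        "log"
        "net/http"

        "github.com/prometheus/client_golang/prometheus"
        "github.com/prometheus/client_golang/prometheus/promhttp"
    )

    // Exported as dns_requests_total{type="tcp"}.
    var totalRequestsTcp = prometheus.NewCounter(prometheus.CounterOpts{
        Namespace:   "dns",
        Subsystem:   "requests",
        Name:        "total",
        Help:        "total requests",
        ConstLabels: prometheus.Labels{"type": "tcp"},
    })

    // runPrometheus registers the counters and serves them at :9970/metrics.
    func runPrometheus() {
        prometheus.MustRegister(totalRequestsTcp)
        http.Handle("/metrics", promhttp.Handler())
        log.Fatal(http.ListenAndServe(":9970", nil))
    }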


I think the main function needs no introduction or description. The code in this article is abridged.

The full code is available in the repository (fixes and additions are welcome, of course). The repository also contains a Dockerfile and an example CI configuration for GitLab.

Thanks for your attention.

Source: https://habr.com/ru/post/348280/

