# Source code for the ipdnsauth DNS server management command.

import datetime
import logging
import socketserver
import threading
import time

from django.conf import settings
from django.contrib.auth.models import User
from django.core.management.base import BaseCommand, CommandError
from dnslib import PTR, QTYPE, RR, A, DNSHeader, DNSRecord

from oioioi.ipdnsauth.models import IpToUser
from oioioi.ipdnsauth.utils import username_to_hostname

# Module-level logger used by the request handlers below.
logger = logging.getLogger(__name__)

# Most of the code was taken from an external source (the URL is not
# preserved in this copy of the file).
class BaseRequestHandler(socketserver.BaseRequestHandler):
    """Transport-independent part of the DNS request handling.

    Subclasses provide :meth:`get_data` and :meth:`send_data` for a concrete
    transport; this class parses the query, builds the reply and logs the
    exchange.
    """

    # Suffix of the names used for IPv4 reverse (PTR) lookups.
    _REVERSE_ZONE_SUFFIX = '.in-addr.arpa'

    def get_data(self):
        """Return the raw DNS message of the current request.

        Must be implemented by transport-specific subclasses.
        """
        raise NotImplementedError

    def send_data(self, data):
        """Send the packed DNS reply ``data`` back to the client.

        Must be implemented by transport-specific subclasses.
        """
        raise NotImplementedError

    def dns_response(self, data):
        """Build the packed reply for the raw DNS query ``data``.

        ``A`` queries for ``<host>.<IPAUTH_DNSSERVER_DOMAIN>`` are answered
        with every address registered for the user whose hostname matches;
        ``PTR`` queries under ``.in-addr.arpa`` are answered with the
        hostname of the user the address belongs to. Any other query gets a
        reply without answers.

        :param data: raw DNS query, as returned by :meth:`get_data`
        :return: the packed reply
        """
        request = DNSRecord.parse(data)
        logger.debug('%s', request)

        # The reply has to carry the id of the request so that the client
        # can match it with its query.
        reply = DNSRecord(
            DNSHeader(id=request.header.id, qr=1, aa=1, ra=1), q=request.q
        )

        qname = request.q.qname
        qn = str(qname)
        if qn.endswith('.'):
            qn = qn[:-1]
        qt = QTYPE[request.q.qtype]

        # partition() never raises, unlike split('.', 1) followed by
        # unpacking, so single-label names simply end up with an empty domain.
        qnhost, _, qndomain = qn.partition('.')

        #
        # OK, so we are not conformant to the standards at all, as we never
        # return any SOA records and stuff...
        #
        # The wildcard query type is spelled '*' or 'ANY' depending on the
        # dnslib version, so both are accepted.
        if qndomain == settings.IPAUTH_DNSSERVER_DOMAIN:
            if qt in ('*', 'ANY', 'A'):
                self._add_address_answers(reply, qname, qnhost)
        elif qn.endswith(self._REVERSE_ZONE_SUFFIX):
            if qt in ('*', 'ANY', 'PTR'):
                reversed_ip = qn[: -len(self._REVERSE_ZONE_SUFFIX)]
                self._add_pointer_answer(reply, qname, reversed_ip)

        logger.debug('%s', reply)
        return reply.pack()

    def _add_address_answers(self, reply, qname, qnhost):
        """Add an A record for every address of the user named ``qnhost``."""
        ttl = self.server.command.options['ttl']
        # distinct(): filtering across the iptouser relation returns a user
        # once per related row, which would duplicate the answers.
        users = User.objects.filter(iptouser__isnull=False).distinct()
        for user in users:
            if qnhost != username_to_hostname(user.username):
                continue
            for itu in user.iptouser_set.all():
                reply.add_answer(
                    RR(
                        rname=qname,
                        rtype=QTYPE.A,
                        rclass=1,
                        ttl=ttl,
                        rdata=A(itu.ip_addr),
                    )
                )

    def _add_pointer_answer(self, reply, qname, reversed_ip):
        """Add a PTR record for the address spelled backwards in ``reversed_ip``."""
        parts = reversed_ip.split('.')
        if len(parts) != 4:
            return
        ip = '.'.join(reversed(parts))
        try:
            iptu = IpToUser.objects.get(ip_addr=ip)
        except IpToUser.DoesNotExist:
            # Unknown address: leave the reply without answers.
            return
        fqdn = (
            username_to_hostname(iptu.user.username)
            + '.'
            + settings.IPAUTH_DNSSERVER_DOMAIN
            + '.'
        )
        reply.add_answer(
            RR(
                rname=qname,
                rtype=QTYPE.PTR,
                rclass=1,
                ttl=self.server.command.options['ttl'],
                rdata=PTR(fqdn),
            )
        )

    def handle(self):
        """Serve one request: read the query, answer it and log the exchange.

        Any exception raised while processing is logged as a warning and
        swallowed, so a broken request does not propagate out of the handler.
        """
        now = datetime.datetime.now(datetime.timezone.utc).strftime(
            '%Y-%m-%d %H:%M:%S.%f'
        )
        logger.debug(
            "%s request %s (%s %s):",
            self.__class__.__name__[:3],
            now,
            self.client_address[0],
            self.client_address[1],
        )
        try:
            data = self.get_data()
            # bytes.hex() replaces the Python 2 only data.encode('hex').
            logger.debug('%d %s', len(data), data.hex())
            self.send_data(self.dns_response(data))
        # pylint: disable=broad-except
        except Exception:
            logger.warning("Exception handling request", exc_info=True)
class TCPRequestHandler(BaseRequestHandler):
    """DNS transport over TCP.

    Every message on the stream is preceded by its length, encoded as a
    two-byte big-endian integer.
    """

    def _recv_exact(self, count):
        """Read exactly ``count`` bytes from the connection.

        A single recv() may return less than requested, so keep reading
        until the whole span has arrived.

        :raises EOFError: if the peer closes the connection too early
        """
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = self.request.recv(remaining)
            if not chunk:
                raise EOFError(
                    "Connection closed before the whole TCP packet arrived"
                )
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def get_data(self):
        """Return one DNS message read from the connection, without its prefix."""
        # The payload is binary, so it must not be stripped of "whitespace".
        size = int.from_bytes(self._recv_exact(2), 'big')
        return self._recv_exact(size)

    def send_data(self, data):
        """Send ``data`` preceded by its two-byte length prefix.

        :raises OverflowError: if ``data`` is longer than 65535 bytes
        """
        prefix = len(data).to_bytes(2, 'big')
        return self.request.sendall(prefix + data)
class UDPRequestHandler(BaseRequestHandler):
    """DNS transport over UDP: one datagram in, one datagram out."""

    def get_data(self):
        """Return the received datagram unchanged.

        ``self.request`` is used as a ``(data, socket)`` pair. The data is a
        binary DNS message, so it must not be stripped: leading or trailing
        bytes that happen to be ASCII whitespace are part of the message.
        """
        return self.request[0]

    def send_data(self, data):
        """Send ``data`` to the client through the socket in ``self.request``."""
        return self.request[1].sendto(data, self.client_address)
class UDPServer(socketserver.ThreadingUDPServer):
    """Threading UDP server that keeps a reference to the command running it."""

    def __init__(self, command, *args, **kwargs):
        """Set up the server and remember ``command``.

        The remaining arguments are passed to the socketserver base class
        unchanged.
        """
        super().__init__(*args, **kwargs)
        # Request handlers reach the command through ``self.server.command``.
        self.command = command
class TCPServer(socketserver.ThreadingTCPServer):
    """Threading TCP server that keeps a reference to the command running it."""

    def __init__(self, command, *args, **kwargs):
        """Set up the server and remember ``command``.

        The remaining arguments are passed to the socketserver base class
        unchanged.
        """
        super().__init__(*args, **kwargs)
        # Request handlers reach the command through ``self.server.command``.
        self.command = command
class Command(BaseCommand):
    """Management command that runs the ipdnsauth DNS server.

    It starts one UDP and one TCP server on the same address and keeps
    running for as long as any of them is being served.
    """

    help = (
        "DNS server for ipdnsauth.\n\nAnswers DNS queries for names "
        "and IP addresses managed by ipdnsauth module."
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Parsed command-line options; set by handle() and read by the
        # request handlers through ``server.command.options``.
        self.options = None

    def add_arguments(self, parser):
        """Register the ``--port``, ``--bind-addr`` and ``--ttl`` options."""
        parser.add_argument(
            '--port', '-p', type=int, default=8053, help="Specify port to listen on"
        )
        parser.add_argument(
            '--bind-addr',
            dest='bind_addr',
            type=str,
            default='',
            help="IP address to bind the server",
        )
        parser.add_argument(
            '--ttl', type=int, default=60, help="Specify TTL for returned records"
        )

    def handle(self, *args, **options):
        """Start the UDP and TCP servers and wait until both have stopped.

        :raises CommandError: if ``IPAUTH_DNSSERVER_DOMAIN`` is not set
        """
        if not getattr(settings, 'IPAUTH_DNSSERVER_DOMAIN', None):
            raise CommandError("IPAUTH_DNSSERVER_DOMAIN not set in settings")

        self.options = options

        listen_addr = (options['bind_addr'], options['port'])
        servers = [
            ('udp', UDPServer(self, listen_addr, UDPRequestHandler)),
            ('tcp', TCPServer(self, listen_addr, TCPRequestHandler)),
        ]

        threads = []
        for name, server in servers:
            thread = threading.Thread(
                target=server.serve_forever,
                name=('ipauth-dnsserver-' + name),
                daemon=True,
            )
            thread.start()
            threads.append(thread)

        # Poll while any server thread is still running. An unconditional
        # ``while True`` here would make the join below unreachable and keep
        # the command alive after both servers have died.
        while any(thread.is_alive() for thread in threads):
            time.sleep(1)

        # Terminate the script in case both threads terminate
        for thread in threads:
            thread.join()