#!/usr/bin/env python
"""
Tests for :class:`logging.handlers.SMTPHandler` and
:class:`logging.handlers.HTTPHandler`, run against small in-process SMTP and
HTTP servers which are served from daemon threads.
"""
import asynchat
import asyncore
import errno
from http.server import HTTPServer, BaseHTTPRequestHandler
import logging
import logging.handlers
import smtpd
import socket
import threading
import unittest
from urllib.parse import urlparse, parse_qs


class TestSMTPChannel(smtpd.SMTPChannel):
    """
    An SMTP channel which registers itself in a caller-supplied socket map.

    This derived class has had to be created because smtpd does not support
    use of custom channel maps, although they are allowed by asyncore's
    design. Issue #11959 has been raised to address this, and if resolved
    satisfactorily, some of this code can be removed.

    NOTE(review): the constructor below initialises the channel's state by
    hand instead of calling ``smtpd.SMTPChannel.__init__``, so it depends on
    the attribute set of the smtpd version it was written against. Confirm
    it against the smtpd shipped with the target Python. Also note that
    asyncore, asynchat and smtpd are no longer part of the standard library
    from Python 3.12 onwards.

    :param server: The server which accepted the connection.
    :param conn: The connected socket.
    :param addr: The address of the connecting client.
    :param sockmap: The dictionary in which the channel is registered,
                    rather than the :mod:`asyncore` global socket map.
    """

    def __init__(self, server, conn, addr, sockmap):
        asynchat.async_chat.__init__(self, conn, sockmap)
        self.smtp_server = server
        self.conn = conn
        self.addr = addr
        self.received_lines = []
        self.smtp_state = self.COMMAND
        self.seen_greeting = ''
        self.mailfrom = None
        self.rcpttos = []
        self.received_data = ''
        self.fqdn = socket.getfqdn()
        self.num_bytes = 0
        try:
            self.peer = conn.getpeername()
        except socket.error as err:
            # a race condition may occur if the other end is closing
            # before we can get the peername
            self.close()
            if err.args[0] != errno.ENOTCONN:
                raise
            return
        self.push('220 %s %s' % (self.fqdn, smtpd.__version__))
        self.set_terminator(b'\r\n')


class TestSMTPServer(smtpd.SMTPServer):
    """
    This class implements a test SMTP server.

    :param addr: A (host, port) tuple which the server listens on.
                 You can specify a port value of zero: the server's
                 *port* attribute will hold the actual port number
                 used, which can be used in client connections.
    :param handler: A callable which will be called to process
                    incoming messages. The handler will be passed
                    the client address tuple, who the message is from,
                    a list of recipients and the message data.
    :param poll_interval: The interval, in seconds, used in the underlying
                          :func:`select` or :func:`poll` call by
                          :func:`asyncore.loop`.
    :param sockmap: A dictionary which will be used to hold
                    :class:`asyncore.dispatcher` instances used by
                    :func:`asyncore.loop`. This avoids changing the
                    :mod:`asyncore` module's global state.
    """
    channel_class = TestSMTPChannel

    def __init__(self, addr, handler, poll_interval, sockmap):
        self._localaddr = addr
        self._remoteaddr = None
        self.sockmap = sockmap
        asyncore.dispatcher.__init__(self, map=sockmap)
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.setblocking(False)
            self.set_socket(sock, map=sockmap)
            # try to re-use a server port if possible
            self.set_reuse_addr()
            self.bind(addr)
            # With a requested port of zero, this is the port actually bound.
            self.port = sock.getsockname()[1]
            self.listen(5)
        except BaseException:
            # Release the dispatcher whatever went wrong, then let the
            # original error propagate to the caller.
            self.close()
            raise
        self._handler = handler
        self._thread = None
        self.poll_interval = poll_interval

    def handle_accepted(self, conn, addr):
        """
        Create a channel for an accepted connection.

        Redefined only because the base class does not pass in a
        map, forcing use of a global in :mod:`asyncore`.
        """
        # The channel registers itself in self.sockmap on construction, so
        # no reference to it needs to be kept here.
        self.channel_class(self, conn, addr, self.sockmap)

    def process_message(self, peer, mailfrom, rcpttos, data):
        """
        Delegates to the handler passed in to the server's constructor.

        Typically, this will be a test case method.

        :param peer: The client (host, port) tuple.
        :param mailfrom: The address of the sender.
        :param rcpttos: The addresses of the recipients.
        :param data: The message.
        """
        self._handler(peer, mailfrom, rcpttos, data)

    def start(self):
        """
        Start the server running on a separate daemon thread.
        """
        self._thread = t = threading.Thread(target=self.serve_forever,
                                            args=(self.poll_interval,))
        # Thread.setDaemon() is deprecated; set the attribute instead.
        t.daemon = True
        t.start()

    def serve_forever(self, poll_interval):
        """
        Run the :mod:`asyncore` loop until normal termination
        conditions arise.

        :param poll_interval: The interval, in seconds, used in the
                              underlying :func:`select` or :func:`poll`
                              call by :func:`asyncore.loop`.
        """
        asyncore.loop(poll_interval, map=self.sockmap)

    def stop(self, timeout=None):
        """
        Stop the thread by closing the server instance.
        Wait for the server thread to terminate.

        Calling this on a server which was never started only closes it.

        :param timeout: How long to wait for the server thread
                        to terminate.
        """
        self.close()
        if self._thread is not None:
            self._thread.join(timeout)
            self._thread = None


class ControlMixin:
    """
    This mixin is used to start a server on a separate thread, and
    shut it down programmatically. Request handling is simplified - instead
    of needing to derive a suitable RequestHandler subclass, you just
    provide a callable which will be passed each received request to be
    processed.

    :param handler: A handler callable which will be called with a
                    single parameter - the request - in order to
                    process the request. This handler is called on the
                    server thread, effectively meaning that requests are
                    processed serially. While not quite Web scale ;-),
                    this should be fine for testing applications.
    :param poll_interval: The polling interval in seconds.
    """

    def __init__(self, handler, poll_interval):
        self._thread = None
        self.poll_interval = poll_interval
        self._handler = handler
        # Set by serve_forever() just before the service loop is entered.
        self.ready = threading.Event()

    def start(self):
        """
        Create a daemon thread to run the server, and start it.
        """
        self._thread = t = threading.Thread(target=self.serve_forever,
                                            args=(self.poll_interval,))
        # Thread.setDaemon() is deprecated; set the attribute instead.
        t.daemon = True
        t.start()

    def serve_forever(self, poll_interval):
        """
        Run the server. Set the ready flag before entering the
        service loop.

        :param poll_interval: The polling interval in seconds, passed on
                              to the next ``serve_forever`` in the MRO.
        """
        self.ready.set()
        super().serve_forever(poll_interval)

    def stop(self, timeout=None):
        """
        Tell the server thread to stop, and wait for it to do so.

        :param timeout: How long to wait for the server thread
                        to terminate.
        """
        self.shutdown()
        if self._thread is not None:
            self._thread.join(timeout)
            self._thread = None
        self.server_close()
        self.ready.clear()


class TestHTTPServer(ControlMixin, HTTPServer):
    """
    An HTTP server which is controllable using :class:`ControlMixin`.

    :param addr: A tuple with the IP address and port to listen on.
    :param handler: A handler callable which will be called with a
                    single parameter - the request - in order to
                    process the request.
    :param poll_interval: The polling interval in seconds.
    :param log: Pass ``True`` to enable log messages.
    """

    def __init__(self, addr, handler, poll_interval=0.5, log=False):
        class DelegatingHTTPRequestHandler(BaseHTTPRequestHandler):
            """Route every ``do_*`` HTTP method to the server's handler."""

            def __getattr__(self, name, default=None):
                # Only reached for attributes which are not otherwise
                # defined: any do_<METHOD> lookup resolves to
                # process_request, so all HTTP verbs are accepted.
                if name.startswith('do_'):
                    return self.process_request
                raise AttributeError(name)

            def process_request(self):
                self.server._handler(self)

            def log_message(self, format, *args):
                # Stay quiet unless logging was asked for.
                if log:
                    super().log_message(format, *args)

        HTTPServer.__init__(self, addr, DelegatingHTTPRequestHandler)
        ControlMixin.__init__(self, handler, poll_interval)


class ServerTestCase(unittest.TestCase):
    """
    Exercise the SMTP and HTTP logging handlers against the test servers.
    """

    # Upper bound, in seconds, on how long a test waits for a server thread
    # to report that it has received something. Without a bound, a server
    # which never receives the message would hang the test run.
    WAIT_TIMEOUT = 30.0

    def test_smtp(self):
        # A record handled by the SMTPHandler arrives at the SMTP server.
        sockmap = {}
        self.messages = []
        self.handled = threading.Event()
        server = TestSMTPServer(('localhost', 0), self.process_message,
                                0.001, sockmap)
        server.start()
        h = None
        try:
            addr = ('localhost', server.port)
            h = logging.handlers.SMTPHandler(addr, 'me', 'you', 'Log')
            self.assertEqual(h.toaddrs, ['you'])
            r = logging.makeLogRecord({'msg': 'Hello'})
            h.handle(r)
            received = self.handled.wait(self.WAIT_TIMEOUT)
        finally:
            # Always release the server and the handler, even when an
            # assertion above fails.
            server.stop(self.WAIT_TIMEOUT)
            if h is not None:
                h.close()
        self.assertTrue(received, 'SMTP server received no message')
        self.assertEqual(len(self.messages), 1)
        peer, mailfrom, rcpttos, data = self.messages[0]
        self.assertEqual(mailfrom, 'me')
        self.assertEqual(rcpttos, ['you'])
        self.assertIn('\nSubject: Log\n', data)
        self.assertTrue(data.endswith('\n\nHello'))

    def test_http(self):
        # The log message sent to the HTTPHandler is properly received.
        addr = ('localhost', 0)
        logger = logging.getLogger("http")
        self.log_data = None
        self.handled = threading.Event()
        server = TestHTTPServer(addr, self.handle_request, 0.01)
        server.start()
        try:
            server.ready.wait()
            host = 'localhost:%d' % server.server_port
            hdlr = logging.handlers.HTTPHandler(host, '/frob')
            logger.addHandler(hdlr)
            try:
                for method in ('GET', 'POST'):
                    hdlr.method = method
                    # Reset the flag so that the wait below really waits
                    # for *this* request, not the previous one.
                    self.handled.clear()
                    msg = "sp\xe4m"
                    logger.error(msg)
                    self.assertTrue(self.handled.wait(self.WAIT_TIMEOUT),
                                    'HTTP server received no request')
                    self.assertEqual(self.log_data.path, '/frob')
                    self.assertEqual(self.command, method)
                    if method == 'GET':
                        d = parse_qs(self.log_data.query)
                    else:
                        d = parse_qs(self.post_data.decode('utf-8'))
                    self.assertEqual(d['name'], ['http'])
                    self.assertEqual(d['funcName'], ['test_http'])
                    self.assertEqual(d['msg'], [msg])
            finally:
                logger.removeHandler(hdlr)
                hdlr.close()
        finally:
            server.stop(2.0)

    def handle_request(self, request):
        """
        Record what the HTTP server received and answer with a 200.

        Called on the server thread with the request handler instance.

        :param request: The request handler for the incoming request.
        """
        self.command = request.command
        self.log_data = urlparse(request.path)
        if self.command == 'POST':
            try:
                rlen = int(request.headers['Content-Length'])
                self.post_data = request.rfile.read(rlen)
            except Exception:
                # Best effort: a missing or malformed body is reported to
                # the test as None rather than failing the server thread.
                self.post_data = None
        request.send_response(200)
        # send_response() only buffers the status line and headers; they
        # are not written to the client until end_headers() is called.
        request.end_headers()
        self.handled.set()

    def process_message(self, *args):
        """
        Record a message received by the SMTP server.

        :param args: The (peer, mailfrom, rcpttos, data) values passed on
                     by :meth:`TestSMTPServer.process_message`.
        """
        self.messages.append(args)
        self.handled.set()


if __name__ == '__main__':
    unittest.main()