#!/usr/bin/env python
# NOTE(review): asynchat, asyncore and smtpd were deprecated and later removed
# from the standard library in Python 3.12 (PEP 594); this module therefore
# needs an older interpreter.
import asynchat
import asyncore
from email.utils import formatdate
import errno
import smtpd
import smtplib
import socket
import threading
import unittest


class TestSMTPChannel(smtpd.SMTPChannel):
    """
    This derived class has had to be created because smtpd does not support
    use of custom channel maps, although they are allowed by asyncore's design.
    Issue #11959 has been raised to address this, and if resolved
    satisfactorily, some of this code can be removed.

    The constructor deliberately skips ``smtpd.SMTPChannel.__init__`` and
    sets up the channel state by hand, so that the channel is registered in
    the caller-supplied *sockmap* instead of asyncore's global socket map.
    """

    # NOTE(review): the attributes initialised below mirror one particular
    # version of smtpd.SMTPChannel.__init__; other smtpd versions may expect
    # additional state - confirm against the interpreter in use.
    def __init__(self, server, conn, addr, sockmap):
        asynchat.async_chat.__init__(self, conn, sockmap)
        self.smtp_server = server
        self.conn = conn
        self.addr = addr
        self.received_lines = []
        self.smtp_state = self.COMMAND
        self.seen_greeting = ''
        self.mailfrom = None
        self.rcpttos = []
        self.received_data = ''
        self.fqdn = socket.getfqdn()
        self.num_bytes = 0
        try:
            self.peer = conn.getpeername()
        except socket.error as err:
            # a race condition may occur if the other end is closing
            # before we can get the peername
            self.close()
            if err.args[0] != errno.ENOTCONN:
                raise
            return
        self.push('220 %s %s' % (self.fqdn, smtpd.__version__))
        self.set_terminator(b'\r\n')


class TestSMTPServer(smtpd.SMTPServer):
    """
    This class implements a test SMTP server.

    :param addr: A (host, port) tuple which the server listens on.
                 You can specify a port value of zero: the server's
                 *port* attribute will hold the actual port number
                 used, which can be used in client connections.
    :param handler: A callable which will be called to process
                    incoming messages. The handler will be passed
                    the client address tuple, who the message is from,
                    a list of recipients and the message data.
    :param poll_interval: The interval, in seconds, used in the underlying
                          :func:`select` or :func:`poll` call by
                          :func:`asyncore.loop`.
    :param sockmap: A dictionary which will be used to hold
                    :class:`asyncore.dispatcher` instances used by
                    :func:`asyncore.loop`. This avoids changing the
                    :mod:`asyncore` module's global state.
    """

    channel_class = TestSMTPChannel

    def __init__(self, addr, handler, poll_interval, sockmap):
        self._localaddr = addr
        self._remoteaddr = None
        self.sockmap = sockmap
        # smtpd.SMTPServer.__init__ is skipped on purpose: it would register
        # the listening socket in asyncore's global map rather than sockmap.
        asyncore.dispatcher.__init__(self, map=sockmap)
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.setblocking(0)
            self.set_socket(sock, map=sockmap)
            # try to re-use a server port if possible
            self.set_reuse_addr()
            self.bind(addr)
            # Read the port back: it differs from addr[1] when port 0 was
            # requested and the OS chose one.
            self.port = sock.getsockname()[1]
            self.listen(5)
        except BaseException:
            # Release whatever was set up so far, then propagate the error.
            self.close()
            raise
        self._handler = handler
        self._thread = None
        self.poll_interval = poll_interval

    def handle_accepted(self, conn, addr):
        """
        Redefined only because the base class does not pass in a map,
        forcing use of a global in :mod:`asyncore`.
        """
        # The channel's constructor registers it in self.sockmap, so no
        # reference needs to be kept here.
        self.channel_class(self, conn, addr, self.sockmap)

    def process_message(self, peer, mailfrom, rcpttos, data):
        """
        Delegates to the handler passed in to the server's constructor.

        Typically, this will be a test case method.
        :param peer: The client (host, port) tuple.
        :param mailfrom: The address of the sender.
        :param rcpttos: The addresses of the recipients.
        :param data: The message.
        """
        self._handler(peer, mailfrom, rcpttos, data)

    def start(self):
        """
        Start the server running on a separate daemon thread.
        """
        self._thread = t = threading.Thread(target=self.serve_forever,
                                            args=(self.poll_interval,))
        # Thread.setDaemon() is deprecated; the attribute is the supported API.
        t.daemon = True
        t.start()

    def serve_forever(self, poll_interval):
        """
        Run the :mod:`asyncore` loop until normal termination
        conditions arise.
        :param poll_interval: The interval, in seconds, used in the underlying
                              :func:`select` or :func:`poll` call by
                              :func:`asyncore.loop`.
        """
        asyncore.loop(poll_interval, map=self.sockmap)

    def stop(self, timeout=None):
        """
        Stop the thread by closing the server instance.
        Wait for the server thread to terminate.

        Safe to call even if :meth:`start` was never called.

        :param timeout: How long to wait for the server thread
                        to terminate.
        """
        self.close()
        # Guard against stop() without a prior start(), as ControlMixin does.
        if self._thread is not None:
            self._thread.join(timeout)
            self._thread = None


class ControlMixin(object):
    """
    This mixin is used to start a server on a separate thread, and
    shut it down programmatically. Request handling is simplified - instead
    of needing to derive a suitable RequestHandler subclass, you just
    provide a callable which will be passed each received request to be
    processed.

    :param handler: A handler callable which will be called with a
                    single parameter - the request - in order to
                    process the request. This handler is called on the
                    server thread, effectively meaning that requests are
                    processed serially. While not quite Web scale ;-),
                    this should be fine for testing applications.
    :param poll_interval: The polling interval in seconds.
    """

    def __init__(self, handler, poll_interval):
        self._thread = None
        self.poll_interval = poll_interval
        self._handler = handler
        # Set once the server thread is about to enter its service loop.
        self.ready = threading.Event()

    def start(self):
        """
        Create a daemon thread to run the server, and start it.
        """
        self._thread = t = threading.Thread(target=self.serve_forever,
                                            args=(self.poll_interval,))
        # Thread.setDaemon() is deprecated; the attribute is the supported API.
        t.daemon = True
        t.start()

    def serve_forever(self, poll_interval):
        """
        Run the server. Set the ready flag before entering the
        service loop.
        """
        self.ready.set()
        # Delegates to the next class in the MRO (the concrete server class
        # this mixin is combined with).
        super(ControlMixin, self).serve_forever(poll_interval)

    def stop(self, timeout=None):
        """
        Tell the server thread to stop, and wait for it to do so.

        :param timeout: How long to wait for the server thread
                        to terminate.
        """
        self.shutdown()
        if self._thread is not None:
            self._thread.join(timeout)
            self._thread = None
        self.server_close()
        self.ready.clear()


class ServerTestCase(unittest.TestCase):

    # Upper bound (seconds) on waiting for the server to hand over a message;
    # keeps a broken server from hanging the whole test run.
    handled_timeout = 10.0

    def test_smtp(self):
        sockmap = {}
        # Initialise shared state before the server thread can touch it.
        self.messages = []
        self.handled = threading.Event()
        server = TestSMTPServer(('localhost', 0), self.process_message, 0.001,
                                sockmap)
        server.start()
        try:
            fromaddr = 'me'
            toaddrs = ['you']
            subject = 'Log'
            msg = "From: %s\r\nTo: %s\r\nSubject: %s\r\nDate: %s\r\n\r\nHello" % (
                fromaddr, ",".join(toaddrs), subject, formatdate())
            smtp = smtplib.SMTP('localhost', server.port)
            try:
                smtp.sendmail(fromaddr, toaddrs, msg)
                smtp.quit()
            finally:
                # Ensure the client socket is released even if sendmail fails.
                smtp.close()
            self.assertTrue(self.handled.wait(self.handled_timeout),
                            'server did not process the message in time')
        finally:
            server.stop()
        self.assertEqual(len(self.messages), 1)
        peer, mailfrom, rcpttos, data = self.messages[0]
        self.assertEqual(mailfrom, fromaddr)
        self.assertEqual(rcpttos, toaddrs)
        self.assertIn('\nSubject: Log\n', data)
        self.assertTrue(data.endswith('\n\nHello'))

    def process_message(self, *args):
        # Called on the server thread by TestSMTPServer.process_message.
        self.messages.append(args)
        self.handled.set()


if __name__ == '__main__':
    unittest.main()