diff -r a47996c10579 Doc/library/smtpd.rst --- a/Doc/library/smtpd.rst Wed Aug 06 18:55:54 2014 +0300 +++ b/Doc/library/smtpd.rst Sat Aug 09 00:38:23 2014 +0000 @@ -60,6 +60,15 @@ argument will be a unicode string. If it is set to ``False``, it will be a bytes object. + .. method:: validate_recipient_address(address) + + Returns ``True``. Override this in a subclass to accept or deny a + recipient addresses conditionally. + *address* is a string of the form ``local-part@domain``. If this method + returns ``False`` on a given address :meth:`SMTPChannel.smtp_RCPT` + answers with ``b'554 <[address]>: Relay access denied.'``. + Use this method to avoid becoming an open relay. + .. attribute:: channel_class Override this in subclasses to use a custom :class:`SMTPChannel` for @@ -71,6 +80,9 @@ .. versionchanged:: 3.5 the *decode_data* argument was added, and *localaddr* and *remoteaddr* may now contain IPv6 addresses. + .. versionadded:: 3.6 + The :meth:`validate_recipient_address` method was added. + DebuggingServer Objects ----------------------- diff -r a47996c10579 Lib/smtpd.py --- a/Lib/smtpd.py Wed Aug 06 18:55:54 2014 +0300 +++ b/Lib/smtpd.py Sat Aug 09 00:38:23 2014 +0000 @@ -558,6 +558,9 @@ if not address: self.push('501 Syntax: RCPT TO:
') return + if not self.smtp_server.validate_recipient_address(address): + self.push('554 <%s>: Relay access denied.' % address) + return self.rcpttos.append(address) print('recips:', self.rcpttos, file=DEBUGSTREAM) self.push('250 OK') @@ -655,6 +658,13 @@ """ raise NotImplementedError + def validate_recipient_address(self, address): + """ + Override this method to determine if the server should relay messages + to a given address. Accepts any recipient by default. + """ + return True + class DebuggingServer(SMTPServer): # Do something with the gathered message diff -r a47996c10579 Lib/test/test_smtpd.py --- a/Lib/test/test_smtpd.py Wed Aug 06 18:55:54 2014 +0300 +++ b/Lib/test/test_smtpd.py Sat Aug 09 00:38:23 2014 +0000 @@ -55,6 +55,11 @@ with self.assertWarns(DeprecationWarning): smtpd.SMTPServer((support.HOST, 0), ('b', 0)) + def test_validate_recipient_address(self): + server = smtpd.SMTPServer( + (support.HOST, 0), ('b', 0), decode_data=True) + self.assertTrue(server.validate_recipient_address('test@example.com')) + def tearDown(self): asyncore.close_all() asyncore.socket = smtpd.socket = socket @@ -80,6 +85,41 @@ self.assertEqual(server.socket.family, socket.AF_INET) +class SMTPDRestrictedServerTest(unittest.TestCase): + def setUp(self): + smtpd.socket = asyncore.socket = mock_socket + self.allowed_domains = ['example.com', 'spam.egg'] + + class RestrictedSMTPServer(DummyServer): + allowed_domains = self.allowed_domains + + def validate_recipient_address(self, address): + if address.split('@')[1] in self.allowed_domains: + return True + return False + + self.server = RestrictedSMTPServer( + (support.HOST, 0), ('b', 0), decode_data=True) + conn, addr = self.server.accept() + self.channel = smtpd.SMTPChannel( + self.server, conn, addr, decode_data=True) + + def write_line(self, line): + self.channel.socket.queue_recv(line) + self.channel.handle_read() + + def test_validate_recipient_address_with_restriction(self): + self.write_line(b'EHLO example') + self.write_line(b'MAIL FROM: foo@example.com') + + for address in ['spam@egg.foo', 'foo@spam.egg', 'test@example.com']: + self.write_line(('RCPT TO: %s' % address).encode('ascii')) + self.assertEqual( + self.channel.socket.last.decode('ascii').split(' ')[0], + "250" if address.split('@')[1] in self.allowed_domains + else "554") + + class SMTPDChannelTest(unittest.TestCase): def setUp(self): smtpd.socket = asyncore.socket = mock_socket