diff -r 3ae2cd85a908 Doc/library/smtpd.rst --- a/Doc/library/smtpd.rst Sun Mar 09 11:18:16 2014 +0100 +++ b/Doc/library/smtpd.rst Tue Jun 10 23:50:31 2014 +0200 @@ -51,6 +51,15 @@ containing the contents of the e-mail (which should be in :rfc:`2822` format). + .. method:: validate_recipient_address(address) + + Returns ``True``. Override this in a subclass to accept or deny a + recipient addresses conditionally. + *address* is a string of the form ``local-part@domain``. If this method + returns ``False`` on a given address :meth:`SMTPChannel.smtp_RCPT` + answers with ``b'554 <[address]>: Relay access denied.'``. + Use this method to avoid becoming an open relay. + .. attribute:: channel_class Override this in subclasses to use a custom :class:`SMTPChannel` for @@ -59,6 +68,9 @@ .. versionchanged:: 3.4 The *map* argument was added. + .. versionadded:: 3.6 + The :meth:`validate_recipient_address` method was added. + DebuggingServer Objects ----------------------- diff -r 3ae2cd85a908 Lib/smtpd.py --- a/Lib/smtpd.py Sun Mar 09 11:18:16 2014 +0100 +++ b/Lib/smtpd.py Tue Jun 10 23:50:31 2014 +0200 @@ -538,6 +538,9 @@ if not address: self.push('501 Syntax: RCPT TO:
') return + if not self.smtp_server.validate_recipient_address(address): + self.push('554 <%s>: Relay access denied.' % address) + return self.rcpttos.append(address) print('recips:', self.rcpttos, file=DEBUGSTREAM) self.push('250 OK') @@ -626,6 +629,13 @@ """ raise NotImplementedError + def validate_recipient_address(self, address): + """ + Override this method to determine if the server should relay messages + to a given address. Accepts any recipient by default. + """ + return True + class DebuggingServer(SMTPServer): # Do something with the gathered message diff -r 3ae2cd85a908 Lib/test/test_smtpd.py --- a/Lib/test/test_smtpd.py Sun Mar 09 11:18:16 2014 +0100 +++ b/Lib/test/test_smtpd.py Tue Jun 10 23:50:31 2014 +0200 @@ -45,11 +45,49 @@ write_line(b'DATA') self.assertRaises(NotImplementedError, write_line, b'spam\r\n.\r\n') + def test_validate_recipient_address(self): + server = smtpd.SMTPServer((support.HOST, 0), ('b', 0)) + self.assertTrue(server.validate_recipient_address('test@example.com')) + def tearDown(self): asyncore.close_all() asyncore.socket = smtpd.socket = socket +class SMTPDRestrictedServerTest(unittest.TestCase): + def setUp(self): + smtpd.socket = asyncore.socket = mock_socket + self.allowed_domains = ['example.com', 'spam.egg'] + + class RestrictedSMTPServer(DummyServer): + allowed_domains = self.allowed_domains + + def validate_recipient_address(self, address): + print(self.allowed_domains) + if address.split('@')[1] in self.allowed_domains: + return True + return False + + self.server = RestrictedSMTPServer((support.HOST, 0), ('b', 0)) + conn, addr = self.server.accept() + self.channel = smtpd.SMTPChannel(self.server, conn, addr) + + def write_line(self, line): + self.channel.socket.queue_recv(line) + self.channel.handle_read() + + def test_validate_recipient_address_with_restriction(self): + self.write_line(b'EHLO example') + self.write_line(b'MAIL FROM: foo@example.com') + + for address in ['spam@egg.foo', 'foo@spam.egg', 'test@example.com']: + self.write_line(('RCPT TO: %s' % address).encode('ascii')) + self.assertEqual( + self.channel.socket.last.decode('ascii').split(' ')[0], + "250" if address.split('@')[1] in self.allowed_domains + else "554") + + class SMTPDChannelTest(unittest.TestCase): def setUp(self): smtpd.socket = asyncore.socket = mock_socket