Index: Lib/asyncore.py =================================================================== --- Lib/asyncore.py (revisione 88396) +++ Lib/asyncore.py (copia locale) @@ -521,6 +521,114 @@ self.log_info('unhandled close event', 'warning') self.close() + +try: + import ssl +except ImportError: + pass +else: + class ssl_dispatcher(dispatcher): + """A dispatcher subclass supporting SSL.""" + + _ssl_accepting = False + _ssl_established = False + _ssl_closing = False + + # --- public API + + def secure_connection(self, ssl_context, server_side=False, + server_hostname=None): + """Setup encrypted connection.""" + self.socket = ssl_context.wrap_socket(self.socket, + do_handshake_on_connect=False, suppress_ragged_eofs=True, + server_side=server_side, server_hostname=server_hostname) + self._ssl_accepting = True + + def ssl_shutdown(self): + """Tear down SSL layer switching back to a clear text connection.""" + if not self._ssl_established: + raise ValueError("not using SSL") + self._ssl_closing = True + try: + self.socket = self.socket.unwrap() + except ssl.SSLError as err: + if err.args[0] in (ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE): + return + elif err.args[0] == ssl.SSL_ERROR_SSL: + pass + else: + raise + except socket.error as err: + # Any "socket error" corresponds to a SSL_ERROR_SYSCALL + # return from OpenSSL's SSL_shutdown(), corresponding to + # a closed socket condition. See also: + # http://www.mail-archive.com/openssl-users@openssl.org/msg60710.html + pass + self._ssl_closing = False + self.handle_ssl_shutdown() + + def handle_ssl_established(self): + """Called when the SSL handshake has completed.""" + self.log_info('unhandled handle_ssl_established event', 'warning') + + def handle_ssl_shutdown(self): + """Called when SSL shutdown() has completed""" + self.log_info('unhandled handle_ssl_shutdown event', 'warning') + + # --- internals + + def _do_ssl_handshake(self): + try: + self.socket.do_handshake() + except ssl.SSLError as err: + if err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): + return + elif err.args[0] == ssl.SSL_ERROR_EOF: + return self.handle_close() + raise + else: + self._ssl_accepting = False + self._ssl_established = True + self.handle_ssl_established() + + def handle_read_event(self): + if self._ssl_accepting: + self._do_ssl_handshake() + elif self._ssl_closing: + self.ssl_shutdown() + else: + super(ssl_dispatcher, self).handle_read_event() + + def handle_write_event(self): + if self._ssl_accepting: + self._do_ssl_handshake() + elif self._ssl_closing: + self.ssl_shutdown() + else: + super(ssl_dispatcher, self).handle_write_event() + + def send(self, data): + try: + return super(ssl_dispatcher, self).send(data) + except ssl.SSLError as err: + if err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN): + return 0 + raise + + def recv(self, buffer_size): + try: + return super(ssl_dispatcher, self).recv(buffer_size) + except ssl.SSLError as err: + if err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN): + self.handle_close() + return '' + if err.args[0] in (ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE): + return '' + raise + + # --------------------------------------------------------------------------- # adds simple buffered output capability, useful for simple clients. # [for more sophisticated usage use asynchat.async_chat] Index: Lib/test/test_asyncore.py =================================================================== --- Lib/test/test_asyncore.py (revisione 88393) +++ Lib/test/test_asyncore.py (copia locale) @@ -471,8 +471,11 @@ """A server which listens on an address and dispatches the connection to a handler. """ + handler = BaseTestHandler - def __init__(self, handler=BaseTestHandler, host=HOST, port=0): + def __init__(self, handler=None, host=HOST, port=0): + if handler is None: + handler = self.handler asyncore.dispatcher.__init__(self) self.create_socket(socket.AF_INET, socket.SOCK_STREAM) self.set_reuse_addr() @@ -739,11 +742,304 @@ use_poll = True +try: + import ssl +except ImportError: + ssl = None +else: + ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ssl_context.load_cert_chain(os.path.join(os.path.dirname(__file__), "keycert.pem")) + + class SSLBaseTestHandler(asyncore.ssl_dispatcher): + + def __init__(self, sock=None): + asyncore.ssl_dispatcher.__init__(self, sock) + self.flag = False + + def handle_accept(self): + raise Exception("handle_accept not supposed to be called") + + def handle_accepted(self): + raise Exception("handle_accepted not supposed to be called") + + def handle_connect(self): + raise Exception("handle_connect not supposed to be called") + + def handle_expt(self): + raise Exception("handle_expt not supposed to be called") + + def handle_close(self): + raise Exception("handle_close not supposed to be called") + + def handle_ssl_shutdown(self): + raise Exception("handle_ssl_shutdown not supposed to be called") + + def handle_error(self): + raise + + + class SSLTCPServer(TCPServer): + handler = SSLBaseTestHandler + + def handle_accepted(self, sock, addr): + handler = self.handler(sock) + handler.secure_connection(ssl_context, server_side=True) + + class SSLBaseClient(SSLBaseTestHandler): + + def __init__(self, address): + SSLBaseTestHandler.__init__(self) + self.create_socket(socket.AF_INET, socket.SOCK_STREAM) + self.connect(address) + + def handle_connect(self): + self.secure_connection(ssl_context) + + +@unittest.skipIf(ssl is None, "ssl module not available") +class TestSSLDispatcher(unittest.TestCase): + + use_poll = False + + def tearDown(self): + asyncore.close_all() + + def loop_waiting_for_flag(self, instance, timeout=5): + timeout = float(timeout) / 100 + count = 100 + while asyncore.socket_map and count > 0: + asyncore.loop(timeout=0.01, count=1, use_poll=self.use_poll) + if instance.flag: + return + count -= 1 + time.sleep(timeout) + self.fail("flag not set") + + def test_ssl_established(self): + + class TestClient(SSLBaseClient): + def handle_ssl_established(self): + self.flag = True + + server = SSLTCPServer() + client = TestClient(server.address) + self.loop_waiting_for_flag(client) + + def test_handle_connect(self): + # make sure handle_connect is called on connect() + + class TestClient(SSLBaseClient): + def handle_connect(self): + self.flag = True + + server = SSLTCPServer() + client = TestClient(server.address) + self.loop_waiting_for_flag(client) + + def test_ssl_shutdown(self): + + class TestHandler(SSLBaseTestHandler): + + def handle_ssl_established(self): + self.ssl_shutdown() + + def handle_ssl_shutdown(self): + flag.l.append(1) + + class TestClient(SSLBaseClient): + + def handle_ssl_established(self): + self.ssl_shutdown() + + def handle_ssl_shutdown(self): + flag.l.append(2) + + class Flag: + l = [] + @property + def flag(self): + return 1 in self.l and 2 in self.l + + flag = Flag() + server = SSLTCPServer(TestHandler) + client = TestClient(server.address) + self.loop_waiting_for_flag(flag) + + def test_handle_accept(self): + # make sure handle_accept() is called when a client connects + + class TestListener(SSLBaseTestHandler): + + def __init__(self): + SSLBaseTestHandler.__init__(self) + self.create_socket(socket.AF_INET, socket.SOCK_STREAM) + self.bind((HOST, 0)) + self.listen(5) + self.address = self.socket.getsockname()[:2] + + def handle_accept(self): + self.flag = True + + server = TestListener() + client = SSLBaseClient(server.address) + self.loop_waiting_for_flag(server) + + def test_handle_accepted(self): + # make sure handle_accepted() is called when a client connects + + class TestListener(SSLBaseTestHandler): + + def __init__(self): + SSLBaseTestHandler.__init__(self) + self.create_socket(socket.AF_INET, socket.SOCK_STREAM) + self.bind((HOST, 0)) + self.listen(5) + self.address = self.socket.getsockname()[:2] + + def handle_accept(self): + asyncore.dispatcher.handle_accept(self) + + def handle_accepted(self, sock, addr): + sock.close() + self.flag = True + + server = TestListener() + client = SSLBaseClient(server.address) + self.loop_waiting_for_flag(server) + + def test_handle_read(self): + # make sure handle_read is called on data received + + class TestClient(SSLBaseClient): + def handle_read(self): + self.flag = True + + class TestHandler(SSLBaseTestHandler): + + def __init__(self, conn): + SSLBaseTestHandler.__init__(self, conn) + + def handle_ssl_established(self): + self.send(b'x' * 1024) + + server = SSLTCPServer(TestHandler) + client = TestClient(server.address) + self.loop_waiting_for_flag(client) + + def test_handle_write(self): + # make sure handle_write is called + + class TestClient(SSLBaseClient): + def handle_write(self): + self.flag = True + + server = SSLTCPServer() + client = TestClient(server.address) + self.loop_waiting_for_flag(client) + + + def test_handle_close(self): + # make sure handle_close is called when the other end closes + # the connection + + class TestClient(SSLBaseClient): + + def handle_read(self): + # in order to make handle_close be called we are supposed + # to make at least one recv() call + self.recv(1024) + + def handle_close(self): + self.flag = True + self.close() + + class TestHandler(SSLBaseTestHandler): + def __init__(self, conn): + SSLBaseTestHandler.__init__(self, conn) + + def handle_ssl_established(self): + self.close() + + server = SSLTCPServer(TestHandler) + client = TestClient(server.address) + self.loop_waiting_for_flag(client) + + + def test_handle_error(self): + + class TestClient(SSLBaseClient): + def handle_write(self): + 1.0 / 0 + def handle_error(self): + self.flag = True + try: + raise + except ZeroDivisionError: + pass + else: + raise Exception("exception not raised") + + server = SSLTCPServer() + client = TestClient(server.address) + self.loop_waiting_for_flag(client) + + def test_connection_attributes(self): + server = SSLTCPServer() + client = SSLBaseClient(server.address) + + # we start disconnected + self.assertFalse(server.connected) + self.assertTrue(server.accepting) + # this can't be taken for granted across all platforms + #self.assertFalse(client.connected) + self.assertFalse(client.accepting) + + # execute some loops so that client connects to server + asyncore.loop(timeout=0.01, use_poll=self.use_poll, count=100) + self.assertFalse(server.connected) + self.assertTrue(server.accepting) + self.assertTrue(client.connected) + self.assertFalse(client.accepting) + + # disconnect the client + client.close() + self.assertFalse(server.connected) + self.assertTrue(server.accepting) + self.assertFalse(client.connected) + self.assertFalse(client.accepting) + + # stop serving + server.close() + self.assertFalse(server.connected) + self.assertFalse(server.accepting) + + def test_create_socket(self): + s = asyncore.dispatcher() + s.create_socket(socket.AF_INET, socket.SOCK_STREAM) + self.assertEqual(s.socket.family, socket.AF_INET) + SOCK_NONBLOCK = getattr(socket, 'SOCK_NONBLOCK', 0) + self.assertEqual(s.socket.type, socket.SOCK_STREAM | SOCK_NONBLOCK) + + def test_bind(self): + s1 = asyncore.dispatcher() + s1.create_socket(socket.AF_INET, socket.SOCK_STREAM) + s1.bind((HOST, 0)) + s1.listen(5) + port = s1.socket.getsockname()[1] + + s2 = asyncore.dispatcher() + s2.create_socket(socket.AF_INET, socket.SOCK_STREAM) + # EADDRINUSE indicates the socket was correctly bound + self.assertRaises(socket.error, s2.bind, (HOST, port)) + + + def test_main(): tests = [HelperFunctionTests, DispatcherTests, DispatcherWithSendTests, DispatcherWithSendTests_UsePoll, TestAPI_UseSelect, - TestAPI_UsePoll, FileWrapperTest] + TestAPI_UsePoll, FileWrapperTest, TestSSLDispatcher] run_unittest(*tests) if __name__ == "__main__": test_main() +