Index: Lib/socket.py =================================================================== --- Lib/socket.py (revisione 84591) +++ Lib/socket.py (copia locale) @@ -93,6 +93,13 @@ self._io_refs = 0 self._closed = False + def __enter__(self): + return self + + def __exit__(self, *args): + if not self._closed: + self.close() + def __repr__(self): """Wrap __repr__() to reveal the real class name.""" s = _socket.socket.__repr__(self) Index: Lib/test/test_socket.py =================================================================== --- Lib/test/test_socket.py (revisione 84591) +++ Lib/test/test_socket.py (copia locale) @@ -15,6 +15,7 @@ import array from weakref import proxy import signal +import socketserver def try_address(host, port=0, family=socket.AF_INET): """Try to bind a socket on the given host:port and return True @@ -1564,6 +1565,52 @@ self.cli.close() +@unittest.skipUnless(thread, 'Threading required for this test.') +class ContextManagersTest(unittest.TestCase): + + def setUp(self): + class EchoHandler(socketserver.BaseRequestHandler): + def handle(self): + data = self.request.recv(1024) + self.request.sendall(data) + + self.server = socketserver.TCPServer((HOST, 0), EchoHandler) + self.thread = threading.Thread(target=self.server.serve_forever) + self.thread.start() + + def tearDown(self): + self.server.shutdown() + self.thread.join() + + def test_case(self): + # base test + with socket.socket() as sock: + self.assertFalse(sock._closed) + self.assertTrue(sock._closed) + # close inside with block + with socket.socket() as sock: + sock.close() + self.assertTrue(sock._closed) + # exception inside with block + with socket.socket() as sock: + self.assertRaises(socket.error, sock.sendall, b'foo') + self.assertTrue(sock._closed) + + # same for create_connection + address = self.server.server_address + with socket.create_connection(address) as sock: + self.assertFalse(sock._closed) + sock.sendall(b'foo') + self.assertEqual(sock.recv(1024), b'foo') + self.assertTrue(sock._closed) + + self.assertTrue(sock._closed) + with socket.create_connection(address) as sock: + sock.close() + self.assertTrue(sock._closed) + self.assertRaises(socket.error, sock.sendall, b'foo') + + def test_main(): tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest, TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, UDPTimeoutTest ] @@ -1578,6 +1625,7 @@ NetworkConnectionNoServer, NetworkConnectionAttributesTest, NetworkConnectionBehaviourTest, + ContextManagersTest, ]) if hasattr(socket, "socketpair"): tests.append(BasicSocketPairTest)