Index: Doc/whatsnew/3.2.rst =================================================================== --- Doc/whatsnew/3.2.rst (revisione 84638) +++ Doc/whatsnew/3.2.rst (copia locale) @@ -389,6 +389,12 @@ (Contributed by Giampaolo RodolĂ ; :issue:`8807`.) +* :func:`socket.create_connection` now supports the context manager protocol + to unconditionally consume :exc:`socket.error` exceptions and to close the + socket when done. + + (Contributed by Giampaolo RodolĂ ; :issue:`9794`.) + Multi-threading =============== Index: Doc/library/socket.rst =================================================================== --- Doc/library/socket.rst (revisione 84638) +++ Doc/library/socket.rst (copia locale) @@ -213,6 +213,9 @@ .. versionchanged:: 3.2 *source_address* was added. + .. versionchanged:: 3.2 + support for the :keyword:`with` statement was added. + .. function:: getaddrinfo(host, port, family=0, type=0, proto=0, flags=0) Index: Lib/socket.py =================================================================== --- Lib/socket.py (revisione 84638) +++ Lib/socket.py (copia locale) @@ -93,6 +93,13 @@ self._io_refs = 0 self._closed = False + def __enter__(self): + return self + + def __exit__(self, *args): + if not self._closed: + self.close() + def __repr__(self): """Wrap __repr__() to reveal the real class name.""" s = _socket.socket.__repr__(self) Index: Lib/test/test_socket.py =================================================================== --- Lib/test/test_socket.py (revisione 84638) +++ Lib/test/test_socket.py (copia locale) @@ -1595,6 +1595,49 @@ self.cli.close() +@unittest.skipUnless(thread, 'Threading required for this test.') +class ContextManagersTest(ThreadedTCPSocketTest): + + def _testSocketClass(self): + # base test + with socket.socket() as sock: + self.assertFalse(sock._closed) + self.assertTrue(sock._closed) + # close inside with block + with socket.socket() as sock: + sock.close() + self.assertTrue(sock._closed) + # exception inside with block + with socket.socket() as sock: + self.assertRaises(socket.error, sock.sendall, b'foo') + self.assertTrue(sock._closed) + + def testCreateConnectionBase(self): + conn, addr = self.serv.accept() + data = conn.recv(1024) + conn.sendall(data) + + def _testCreateConnectionBase(self): + address = self.serv.getsockname() + with socket.create_connection(address) as sock: + self.assertFalse(sock._closed) + sock.sendall(b'foo') + self.assertEqual(sock.recv(1024), b'foo') + self.assertTrue(sock._closed) + + def testCreateConnectionClose(self): + conn, addr = self.serv.accept() + data = conn.recv(1024) + conn.sendall(data) + + def _testCreateConnectionClose(self): + address = self.serv.getsockname() + with socket.create_connection(address) as sock: + sock.close() + self.assertTrue(sock._closed) + self.assertRaises(socket.error, sock.sendall, b'foo') + + def test_main(): tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest, TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, UDPTimeoutTest ] @@ -1609,6 +1652,7 @@ NetworkConnectionNoServer, NetworkConnectionAttributesTest, NetworkConnectionBehaviourTest, + ContextManagersTest, ]) if hasattr(socket, "socketpair"): tests.append(BasicSocketPairTest)