Rietveld Code Review Tool
Help | Bug tracker | Discussion group | Source code | Sign in
(62637)

Delta Between Two Patch Sets: Lib/ssl.py

Issue 22417: PEP 476: verify HTTPS certificates by default
Left Patch Set: Created 4 years, 9 months ago
Right Patch Set: Created 4 years, 7 months ago
Left:
Right:
Use n/p to move between diff chunks; N/P to move between comments. Please Sign in to add in-line comments.
Jump to:
Left: Side by side diff | Download
Right: Side by side diff | Download
« no previous file with change/comment | « Lib/http/client.py ('k') | Lib/test/test_httplib.py » ('j') | no next file with change/comment »
Toggle Intra-line Diffs ('i') | Expand Comments ('e') | Collapse Comments ('c') | Show Comments Hide Comments ('s')
LEFT RIGHT
1 # Wrapper module for _ssl, providing some additional facilities 1 # Wrapper module for _ssl, providing some additional facilities
2 # implemented in Python. Written by Bill Janssen. 2 # implemented in Python. Written by Bill Janssen.
3 3
4 """This module provides some more Pythonic support for SSL. 4 """This module provides some more Pythonic support for SSL.
5 5
6 Object types: 6 Object types:
7 7
8 SSLSocket -- subtype of socket.socket which does SSL over the socket 8 SSLSocket -- subtype of socket.socket which does SSL over the socket
9 9
10 Exceptions: 10 Exceptions:
(...skipping 74 matching lines...) Expand 10 before | Expand all | Expand 10 after
85 ALERT_DESCRIPTION_BAD_CERTIFICATE_STATUS_RESPONSE 85 ALERT_DESCRIPTION_BAD_CERTIFICATE_STATUS_RESPONSE
86 ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE 86 ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE
87 ALERT_DESCRIPTION_UNKNOWN_PSK_IDENTITY 87 ALERT_DESCRIPTION_UNKNOWN_PSK_IDENTITY
88 """ 88 """
89 89
90 import textwrap 90 import textwrap
91 import re 91 import re
92 import sys 92 import sys
93 import os 93 import os
94 from collections import namedtuple 94 from collections import namedtuple
95 from enum import Enum as _Enum 95 from enum import Enum as _Enum, IntEnum as _IntEnum
96 96
97 import _ssl # if we can't import it, let the error propagate 97 import _ssl # if we can't import it, let the error propagate
98 98
99 from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION 99 from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION
100 from _ssl import _SSLContext 100 from _ssl import _SSLContext, MemoryBIO
101 from _ssl import ( 101 from _ssl import (
102 SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError, 102 SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError,
103 SSLSyscallError, SSLEOFError, 103 SSLSyscallError, SSLEOFError,
104 ) 104 )
105 from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED 105 from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
106 from _ssl import (VERIFY_DEFAULT, VERIFY_CRL_CHECK_LEAF, VERIFY_CRL_CHECK_CHAIN, 106 from _ssl import (VERIFY_DEFAULT, VERIFY_CRL_CHECK_LEAF, VERIFY_CRL_CHECK_CHAIN,
107 VERIFY_X509_STRICT) 107 VERIFY_X509_STRICT)
108 from _ssl import txt2obj as _txt2obj, nid2obj as _nid2obj 108 from _ssl import txt2obj as _txt2obj, nid2obj as _nid2obj
109 from _ssl import RAND_status, RAND_egd, RAND_add, RAND_bytes, RAND_pseudo_bytes 109 from _ssl import RAND_status, RAND_egd, RAND_add, RAND_bytes, RAND_pseudo_bytes
110 110
111 def _import_symbols(prefix): 111 def _import_symbols(prefix):
112 for n in dir(_ssl): 112 for n in dir(_ssl):
113 if n.startswith(prefix): 113 if n.startswith(prefix):
114 globals()[n] = getattr(_ssl, n) 114 globals()[n] = getattr(_ssl, n)
115 115
116 _import_symbols('OP_') 116 _import_symbols('OP_')
117 _import_symbols('ALERT_DESCRIPTION_') 117 _import_symbols('ALERT_DESCRIPTION_')
118 _import_symbols('SSL_ERROR_') 118 _import_symbols('SSL_ERROR_')
119 119
120 from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN 120 from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN
121 121
122 from _ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1
123 from _ssl import _OPENSSL_API_VERSION 122 from _ssl import _OPENSSL_API_VERSION
124 123
125 124 _SSLMethod = _IntEnum('_SSLMethod',
126 _PROTOCOL_NAMES = { 125 {name: value for name, value in vars(_ssl).items()
127 PROTOCOL_TLSv1: "TLSv1", 126 if name.startswith('PROTOCOL_')})
128 PROTOCOL_SSLv23: "SSLv23", 127 globals().update(_SSLMethod.__members__)
129 PROTOCOL_SSLv3: "SSLv3", 128
130 } 129 _PROTOCOL_NAMES = {value: name for name, value in _SSLMethod.__members__.items() }
130
131 try: 131 try:
132 from _ssl import PROTOCOL_SSLv2
133 _SSLv2_IF_EXISTS = PROTOCOL_SSLv2 132 _SSLv2_IF_EXISTS = PROTOCOL_SSLv2
134 except ImportError: 133 except NameError:
135 _SSLv2_IF_EXISTS = None 134 _SSLv2_IF_EXISTS = None
136 else:
137 _PROTOCOL_NAMES[PROTOCOL_SSLv2] = "SSLv2"
138
139 try:
140 from _ssl import PROTOCOL_TLSv1_1, PROTOCOL_TLSv1_2
141 except ImportError:
142 pass
143 else:
144 _PROTOCOL_NAMES[PROTOCOL_TLSv1_1] = "TLSv1.1"
145 _PROTOCOL_NAMES[PROTOCOL_TLSv1_2] = "TLSv1.2"
146 135
147 if sys.platform == "win32": 136 if sys.platform == "win32":
148 from _ssl import enum_certificates, enum_crls 137 from _ssl import enum_certificates, enum_crls
149 138
150 from socket import socket, AF_INET, SOCK_STREAM, create_connection 139 from socket import socket, AF_INET, SOCK_STREAM, create_connection
151 from socket import SOL_SOCKET, SO_TYPE 140 from socket import SOL_SOCKET, SO_TYPE
152 import base64 # for DER-to-PEM translation 141 import base64 # for DER-to-PEM translation
153 import errno 142 import errno
154 143
155 144
(...skipping 50 matching lines...) Expand 10 before | Expand all | Expand 10 after
206 """ 195 """
207 pats = [] 196 pats = []
208 if not dn: 197 if not dn:
209 return False 198 return False
210 199
211 leftmost, *remainder = dn.split(r'.') 200 leftmost, *remainder = dn.split(r'.')
212 201
213 wildcards = leftmost.count('*') 202 wildcards = leftmost.count('*')
214 if wildcards > max_wildcards: 203 if wildcards > max_wildcards:
215 # Issue #17980: avoid denials of service by refusing more 204 # Issue #17980: avoid denials of service by refusing more
216 # than one wildcard per fragment. A survery of established 205 # than one wildcard per fragment. A survey of established
217 # policy among SSL implementations showed it to be a 206 # policy among SSL implementations showed it to be a
218 # reasonable choice. 207 # reasonable choice.
219 raise CertificateError( 208 raise CertificateError(
220 "too many wildcards in certificate DNS name: " + repr(dn)) 209 "too many wildcards in certificate DNS name: " + repr(dn))
221 210
222 # speed up common case w/o wildcards 211 # speed up common case w/o wildcards
223 if not wildcards: 212 if not wildcards:
224 return dn.lower() == hostname.lower() 213 return dn.lower() == hostname.lower()
225 214
226 # RFC 6125, section 6.4.3, subitem 1. 215 # RFC 6125, section 6.4.3, subitem 1.
(...skipping 129 matching lines...) Expand 10 before | Expand all | Expand 10 after
356 def wrap_socket(self, sock, server_side=False, 345 def wrap_socket(self, sock, server_side=False,
357 do_handshake_on_connect=True, 346 do_handshake_on_connect=True,
358 suppress_ragged_eofs=True, 347 suppress_ragged_eofs=True,
359 server_hostname=None): 348 server_hostname=None):
360 return SSLSocket(sock=sock, server_side=server_side, 349 return SSLSocket(sock=sock, server_side=server_side,
361 do_handshake_on_connect=do_handshake_on_connect, 350 do_handshake_on_connect=do_handshake_on_connect,
362 suppress_ragged_eofs=suppress_ragged_eofs, 351 suppress_ragged_eofs=suppress_ragged_eofs,
363 server_hostname=server_hostname, 352 server_hostname=server_hostname,
364 _context=self) 353 _context=self)
365 354
355 def wrap_bio(self, incoming, outgoing, server_side=False,
356 server_hostname=None):
357 sslobj = self._wrap_bio(incoming, outgoing, server_side=server_side,
358 server_hostname=server_hostname)
359 return SSLObject(sslobj)
360
366 def set_npn_protocols(self, npn_protocols): 361 def set_npn_protocols(self, npn_protocols):
367 protos = bytearray() 362 protos = bytearray()
368 for protocol in npn_protocols: 363 for protocol in npn_protocols:
369 b = bytes(protocol, 'ascii') 364 b = bytes(protocol, 'ascii')
370 if len(b) == 0 or len(b) > 255: 365 if len(b) == 0 or len(b) > 255:
371 raise SSLError('NPN protocols must be 1 to 255 in length') 366 raise SSLError('NPN protocols must be 1 to 255 in length')
372 protos.append(len(b)) 367 protos.append(len(b))
373 protos.extend(b) 368 protos.extend(b)
374 369
375 self._set_npn_protocols(protos) 370 self._set_npn_protocols(protos)
376 371
377 def _load_windows_store_certs(self, storename, purpose): 372 def _load_windows_store_certs(self, storename, purpose):
378 certs = bytearray() 373 certs = bytearray()
379 for cert, encoding, trust in enum_certificates(storename): 374 for cert, encoding, trust in enum_certificates(storename):
380 # CA certs are never PKCS#7 encoded 375 # CA certs are never PKCS#7 encoded
381 if encoding == "x509_asn": 376 if encoding == "x509_asn":
382 if trust is True or purpose.oid in trust: 377 if trust is True or purpose.oid in trust:
383 certs.extend(cert) 378 certs.extend(cert)
384 self.load_verify_locations(cadata=certs) 379 self.load_verify_locations(cadata=certs)
385 return certs 380 return certs
386 381
387 def load_default_certs(self, purpose=Purpose.SERVER_AUTH): 382 def load_default_certs(self, purpose=Purpose.SERVER_AUTH):
388 if not isinstance(purpose, _ASN1Object): 383 if not isinstance(purpose, _ASN1Object):
389 raise TypeError(purpose) 384 raise TypeError(purpose)
390 if sys.platform == "win32": 385 if sys.platform == "win32":
391 for storename in self._windows_cert_stores: 386 for storename in self._windows_cert_stores:
392 self._load_windows_store_certs(storename, purpose) 387 self._load_windows_store_certs(storename, purpose)
393 else: 388 self.set_default_verify_paths()
394 self.set_default_verify_paths()
395 389
396 390
397 def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None, 391 def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None,
398 capath=None, cadata=None): 392 capath=None, cadata=None):
399 """Create a SSLContext object with default settings. 393 """Create a SSLContext object with default settings.
400 394
401 NOTE: The protocol and settings may change anytime without prior 395 NOTE: The protocol and settings may change anytime without prior
402 deprecation. The values represent a fair balance between maximum 396 deprecation. The values represent a fair balance between maximum
403 compatibility and security. 397 compatibility and security.
404 """ 398 """
(...skipping 47 matching lines...) Expand 10 before | Expand all | Expand 10 after
452 objects in order to keep common settings in one place. The configuration 446 objects in order to keep common settings in one place. The configuration
453 is less restrictive than create_default_context()'s to increase backward 447 is less restrictive than create_default_context()'s to increase backward
454 compatibility. 448 compatibility.
455 """ 449 """
456 if not isinstance(purpose, _ASN1Object): 450 if not isinstance(purpose, _ASN1Object):
457 raise TypeError(purpose) 451 raise TypeError(purpose)
458 452
459 context = SSLContext(protocol) 453 context = SSLContext(protocol)
460 # SSLv2 considered harmful. 454 # SSLv2 considered harmful.
461 context.options |= OP_NO_SSLv2 455 context.options |= OP_NO_SSLv2
456 # SSLv3 has problematic security and is only required for really old
457 # clients such as IE6 on Windows XP
458 context.options |= OP_NO_SSLv3
462 459
463 if cert_reqs is not None: 460 if cert_reqs is not None:
464 context.verify_mode = cert_reqs 461 context.verify_mode = cert_reqs
465 context.check_hostname = check_hostname 462 context.check_hostname = check_hostname
466 463
467 if keyfile and not certfile: 464 if keyfile and not certfile:
468 raise ValueError("certfile must be specified") 465 raise ValueError("certfile must be specified")
469 if certfile or keyfile: 466 if certfile or keyfile:
470 context.load_cert_chain(certfile, keyfile) 467 context.load_cert_chain(certfile, keyfile)
471 468
472 # load CA root certs 469 # load CA root certs
473 if cafile or capath or cadata: 470 if cafile or capath or cadata:
474 context.load_verify_locations(cafile, capath, cadata) 471 context.load_verify_locations(cafile, capath, cadata)
475 elif context.verify_mode != CERT_NONE: 472 elif context.verify_mode != CERT_NONE:
476 # no explicit cafile, capath or cadata but the verify mode is 473 # no explicit cafile, capath or cadata but the verify mode is
477 # CERT_OPTIONAL or CERT_REQUIRED. Let's try to load default system 474 # CERT_OPTIONAL or CERT_REQUIRED. Let's try to load default system
478 # root CA certificates for the given purpose. This may fail silently. 475 # root CA certificates for the given purpose. This may fail silently.
479 context.load_default_certs(purpose) 476 context.load_default_certs(purpose)
480 477
481 return context 478 return context
482 479
483 # PEP 476 target for monkeypatching hack that reverts to old behaviour 480
481 class SSLObject:
482 """This class implements an interface on top of a low-level SSL object as
483 implemented by OpenSSL. This object captures the state of an SSL connection
484 but does not provide any network IO itself. IO needs to be performed
485 through separate "BIO" objects which are OpenSSL's IO abstraction layer.
486
487 This class does not have a public constructor. Instances are returned by
488 ``SSLContext.wrap_bio``. This class is typically used by framework authors
489 that want to implement asynchronous IO for SSL through memory buffers.
490
491 When compared to ``SSLSocket``, this object lacks the following features:
492
493 * Any form of network IO including methods such as ``recv`` and ``send``.
494 * The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery.
495 """
496
497 def __init__(self, sslobj, owner=None):
498 self._sslobj = sslobj
499 # Note: _sslobj takes a weak reference to owner
500 self._sslobj.owner = owner or self
501
502 @property
503 def context(self):
504 """The SSLContext that is currently in use."""
505 return self._sslobj.context
506
507 @context.setter
508 def context(self, ctx):
509 self._sslobj.context = ctx
510
511 @property
512 def server_side(self):
513 """Whether this is a server-side socket."""
514 return self._sslobj.server_side
515
516 @property
517 def server_hostname(self):
518 """The currently set server hostname (for SNI), or ``None`` if no
519 server hostname is set."""
520 return self._sslobj.server_hostname
521
522 def read(self, len=0, buffer=None):
523 """Read up to 'len' bytes from the SSL object and return them.
524
525 If 'buffer' is provided, read into this buffer and return the number of
526 bytes read.
527 """
528 if buffer is not None:
529 v = self._sslobj.read(len, buffer)
530 else:
531 v = self._sslobj.read(len or 1024)
532 return v
533
534 def write(self, data):
535 """Write 'data' to the SSL object and return the number of bytes
536 written.
537
538 The 'data' argument must support the buffer interface.
539 """
540 return self._sslobj.write(data)
541
542 def getpeercert(self, binary_form=False):
543 """Returns a formatted version of the data in the certificate provided
544 by the other end of the SSL channel.
545
546 Return None if no certificate was provided, {} if a certificate was
547 provided, but not validated.
548 """
549 return self._sslobj.peer_certificate(binary_form)
550
551 def selected_npn_protocol(self):
552 """Return the currently selected NPN protocol as a string, or ``None``
553 if a next protocol was not negotiated or if NPN is not supported by one
554 of the peers."""
555 if _ssl.HAS_NPN:
556 return self._sslobj.selected_npn_protocol()
557
558 def cipher(self):
559 """Return the currently selected cipher as a 3-tuple ``(name,
560 ssl_version, secret_bits)``."""
561 return self._sslobj.cipher()
562
563 def compression(self):
564 """Return the current compression algorithm in use, or ``None`` if
565 compression was not negotiated or not supported by one of the peers."""
566 return self._sslobj.compression()
567
568 def pending(self):
569 """Return the number of bytes that can be read immediately."""
570 return self._sslobj.pending()
571
572 def do_handshake(self):
573 """Start the SSL/TLS handshake."""
574 self._sslobj.do_handshake()
575 if self.context.check_hostname:
576 if not self.server_hostname:
577 raise ValueError("check_hostname needs server_hostname "
578 "argument")
579 match_hostname(self.getpeercert(), self.server_hostname)
580
581 def unwrap(self):
582 """Start the SSL shutdown handshake."""
583 return self._sslobj.shutdown()
584
585 def get_channel_binding(self, cb_type="tls-unique"):
586 """Get channel binding data for current connection. Raise ValueError
587 if the requested `cb_type` is not supported. Return bytes of the data
588 or None if the data is not available (e.g. before the handshake)."""
589 if cb_type not in CHANNEL_BINDING_TYPES:
590 raise ValueError("Unsupported channel binding type")
591 if cb_type != "tls-unique":
592 raise NotImplementedError(
593 "{0} channel binding type not implemented"
594 .format(cb_type))
595 return self._sslobj.tls_unique_cb()
596
597 def version(self):
598 """Return a string identifying the protocol version used by the
599 current SSL channel. """
600 return self._sslobj.version()
601
602
603 # Used by http.client if no context is explicitly passed.
484 _create_default_https_context = create_default_context 604 _create_default_https_context = create_default_context
AntoinePitrou 2014/09/18 15:41:59 Compared to _create_stdlib_context, this changes o
485 # To revert back to the old behaviour, monkeypatch the ssl module: 605
486 # ssl._create_default_https_context = ssl._create_unverified_context 606
487 607 # Backwards compatibility alias, even though it's not a public name.
488
489 # Minimise impact of PEP 476 patch on other modules in 3.4 and 2.7
490 # by providing a backwards compatibility alias for the old private name
491 _create_stdlib_context = _create_unverified_context 608 _create_stdlib_context = _create_unverified_context
492 609
493 610
494 class SSLSocket(socket): 611 class SSLSocket(socket):
495 """This class implements a subtype of socket.socket that wraps 612 """This class implements a subtype of socket.socket that wraps
496 the underlying OS socket in an SSL context when necessary, and 613 the underlying OS socket in an SSL context when necessary, and
497 provides read and write methods over that channel.""" 614 provides read and write methods over that channel."""
498 615
499 def __init__(self, sock=None, keyfile=None, certfile=None, 616 def __init__(self, sock=None, keyfile=None, certfile=None,
500 server_side=False, cert_reqs=CERT_NONE, 617 server_side=False, cert_reqs=CERT_NONE,
(...skipping 70 matching lines...) Expand 10 before | Expand all | Expand 10 after
571 connected = False 688 connected = False
572 else: 689 else:
573 connected = True 690 connected = True
574 691
575 self._closed = False 692 self._closed = False
576 self._sslobj = None 693 self._sslobj = None
577 self._connected = connected 694 self._connected = connected
578 if connected: 695 if connected:
579 # create the SSL object 696 # create the SSL object
580 try: 697 try:
581 self._sslobj = self._context._wrap_socket(self, server_side, 698 sslobj = self._context._wrap_socket(self, server_side,
582 server_hostname) 699 server_hostname)
700 self._sslobj = SSLObject(sslobj, owner=self)
583 if do_handshake_on_connect: 701 if do_handshake_on_connect:
584 timeout = self.gettimeout() 702 timeout = self.gettimeout()
585 if timeout == 0.0: 703 if timeout == 0.0:
586 # non-blocking 704 # non-blocking
587 raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets") 705 raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
588 self.do_handshake() 706 self.do_handshake()
589 707
590 except (OSError, ValueError): 708 except (OSError, ValueError):
591 self.close() 709 self.close()
592 raise 710 raise
(...skipping 24 matching lines...) Expand all
617 self.getpeername() 735 self.getpeername()
618 736
619 def read(self, len=0, buffer=None): 737 def read(self, len=0, buffer=None):
620 """Read up to LEN bytes and return them. 738 """Read up to LEN bytes and return them.
621 Return zero-length string on EOF.""" 739 Return zero-length string on EOF."""
622 740
623 self._checkClosed() 741 self._checkClosed()
624 if not self._sslobj: 742 if not self._sslobj:
625 raise ValueError("Read on closed or unwrapped SSL socket.") 743 raise ValueError("Read on closed or unwrapped SSL socket.")
626 try: 744 try:
627 if buffer is not None: 745 return self._sslobj.read(len, buffer)
628 v = self._sslobj.read(len, buffer)
629 else:
630 v = self._sslobj.read(len or 1024)
631 return v
632 except SSLError as x: 746 except SSLError as x:
633 if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs: 747 if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
634 if buffer is not None: 748 if buffer is not None:
635 return 0 749 return 0
636 else: 750 else:
637 return b'' 751 return b''
638 else: 752 else:
639 raise 753 raise
640 754
641 def write(self, data): 755 def write(self, data):
642 """Write DATA to the underlying SSL channel. Returns 756 """Write DATA to the underlying SSL channel. Returns
643 number of bytes of DATA actually transmitted.""" 757 number of bytes of DATA actually transmitted."""
644 758
645 self._checkClosed() 759 self._checkClosed()
646 if not self._sslobj: 760 if not self._sslobj:
647 raise ValueError("Write on closed or unwrapped SSL socket.") 761 raise ValueError("Write on closed or unwrapped SSL socket.")
648 return self._sslobj.write(data) 762 return self._sslobj.write(data)
649 763
650 def getpeercert(self, binary_form=False): 764 def getpeercert(self, binary_form=False):
651 """Returns a formatted version of the data in the 765 """Returns a formatted version of the data in the
652 certificate provided by the other end of the SSL channel. 766 certificate provided by the other end of the SSL channel.
653 Return None if no certificate was provided, {} if a 767 Return None if no certificate was provided, {} if a
654 certificate was provided, but not validated.""" 768 certificate was provided, but not validated."""
655 769
656 self._checkClosed() 770 self._checkClosed()
657 self._check_connected() 771 self._check_connected()
658 return self._sslobj.peer_certificate(binary_form) 772 return self._sslobj.getpeercert(binary_form)
659 773
660 def selected_npn_protocol(self): 774 def selected_npn_protocol(self):
661 self._checkClosed() 775 self._checkClosed()
662 if not self._sslobj or not _ssl.HAS_NPN: 776 if not self._sslobj or not _ssl.HAS_NPN:
663 return None 777 return None
664 else: 778 else:
665 return self._sslobj.selected_npn_protocol() 779 return self._sslobj.selected_npn_protocol()
666 780
667 def cipher(self): 781 def cipher(self):
668 self._checkClosed() 782 self._checkClosed()
669 if not self._sslobj: 783 if not self._sslobj:
670 return None 784 return None
671 else: 785 else:
672 return self._sslobj.cipher() 786 return self._sslobj.cipher()
673 787
674 def compression(self): 788 def compression(self):
675 self._checkClosed() 789 self._checkClosed()
676 if not self._sslobj: 790 if not self._sslobj:
677 return None 791 return None
678 else: 792 else:
679 return self._sslobj.compression() 793 return self._sslobj.compression()
680 794
681 def send(self, data, flags=0): 795 def send(self, data, flags=0):
682 self._checkClosed() 796 self._checkClosed()
683 if self._sslobj: 797 if self._sslobj:
684 if flags != 0: 798 if flags != 0:
685 raise ValueError( 799 raise ValueError(
686 "non-zero flags not allowed in calls to send() on %s" % 800 "non-zero flags not allowed in calls to send() on %s" %
687 self.__class__) 801 self.__class__)
688 try: 802 return self._sslobj.write(data)
689 v = self._sslobj.write(data)
690 except SSLError as x:
691 if x.args[0] == SSL_ERROR_WANT_READ:
692 return 0
693 elif x.args[0] == SSL_ERROR_WANT_WRITE:
694 return 0
695 else:
696 raise
697 else:
698 return v
699 else: 803 else:
700 return socket.send(self, data, flags) 804 return socket.send(self, data, flags)
701 805
702 def sendto(self, data, flags_or_addr, addr=None): 806 def sendto(self, data, flags_or_addr, addr=None):
703 self._checkClosed() 807 self._checkClosed()
704 if self._sslobj: 808 if self._sslobj:
705 raise ValueError("sendto not allowed on instances of %s" % 809 raise ValueError("sendto not allowed on instances of %s" %
706 self.__class__) 810 self.__class__)
707 elif addr is None: 811 elif addr is None:
708 return socket.sendto(self, data, flags_or_addr) 812 return socket.sendto(self, data, flags_or_addr)
(...skipping 15 matching lines...) Expand all
724 self.__class__) 828 self.__class__)
725 amount = len(data) 829 amount = len(data)
726 count = 0 830 count = 0
727 while (count < amount): 831 while (count < amount):
728 v = self.send(data[count:]) 832 v = self.send(data[count:])
729 count += v 833 count += v
730 return amount 834 return amount
731 else: 835 else:
732 return socket.sendall(self, data, flags) 836 return socket.sendall(self, data, flags)
733 837
838 def sendfile(self, file, offset=0, count=None):
839 """Send a file, possibly by using os.sendfile() if this is a
840 clear-text socket. Return the total number of bytes sent.
841 """
842 if self._sslobj is None:
843 # os.sendfile() works with plain sockets only
844 return super().sendfile(file, offset, count)
845 else:
846 return self._sendfile_use_send(file, offset, count)
847
734 def recv(self, buflen=1024, flags=0): 848 def recv(self, buflen=1024, flags=0):
735 self._checkClosed() 849 self._checkClosed()
736 if self._sslobj: 850 if self._sslobj:
737 if flags != 0: 851 if flags != 0:
738 raise ValueError( 852 raise ValueError(
739 "non-zero flags not allowed in calls to recv() on %s" % 853 "non-zero flags not allowed in calls to recv() on %s" %
740 self.__class__) 854 self.__class__)
741 return self.read(buflen) 855 return self.read(buflen)
742 else: 856 else:
743 return socket.recv(self, buflen, flags) 857 return socket.recv(self, buflen, flags)
(...skipping 44 matching lines...) Expand 10 before | Expand all | Expand 10 after
788 else: 902 else:
789 return 0 903 return 0
790 904
791 def shutdown(self, how): 905 def shutdown(self, how):
792 self._checkClosed() 906 self._checkClosed()
793 self._sslobj = None 907 self._sslobj = None
794 socket.shutdown(self, how) 908 socket.shutdown(self, how)
795 909
796 def unwrap(self): 910 def unwrap(self):
797 if self._sslobj: 911 if self._sslobj:
798 s = self._sslobj.shutdown() 912 s = self._sslobj.unwrap()
799 self._sslobj = None 913 self._sslobj = None
800 return s 914 return s
801 else: 915 else:
802 raise ValueError("No SSL wrapper around " + str(self)) 916 raise ValueError("No SSL wrapper around " + str(self))
803 917
804 def _real_close(self): 918 def _real_close(self):
805 self._sslobj = None 919 self._sslobj = None
806 socket._real_close(self) 920 socket._real_close(self)
807 921
808 def do_handshake(self, block=False): 922 def do_handshake(self, block=False):
809 """Perform a TLS/SSL handshake.""" 923 """Perform a TLS/SSL handshake."""
810 self._check_connected() 924 self._check_connected()
811 timeout = self.gettimeout() 925 timeout = self.gettimeout()
812 try: 926 try:
813 if timeout == 0.0 and block: 927 if timeout == 0.0 and block:
814 self.settimeout(None) 928 self.settimeout(None)
815 self._sslobj.do_handshake() 929 self._sslobj.do_handshake()
816 finally: 930 finally:
817 self.settimeout(timeout) 931 self.settimeout(timeout)
818 932
819 if self.context.check_hostname:
820 if not self.server_hostname:
821 raise ValueError("check_hostname needs server_hostname "
822 "argument")
823 match_hostname(self.getpeercert(), self.server_hostname)
824
825 def _real_connect(self, addr, connect_ex): 933 def _real_connect(self, addr, connect_ex):
826 if self.server_side: 934 if self.server_side:
827 raise ValueError("can't connect in server-side mode") 935 raise ValueError("can't connect in server-side mode")
828 # Here we assume that the socket is client-side, and not 936 # Here we assume that the socket is client-side, and not
829 # connected at the time of the call. We connect it, then wrap it. 937 # connected at the time of the call. We connect it, then wrap it.
830 if self._connected: 938 if self._connected:
831 raise ValueError("attempt to connect already-connected SSLSocket!") 939 raise ValueError("attempt to connect already-connected SSLSocket!")
832 self._sslobj = self.context._wrap_socket(self, False, self.server_hostname) 940 sslobj = self.context._wrap_socket(self, False, self.server_hostname)
941 self._sslobj = SSLObject(sslobj, owner=self)
833 try: 942 try:
834 if connect_ex: 943 if connect_ex:
835 rc = socket.connect_ex(self, addr) 944 rc = socket.connect_ex(self, addr)
836 else: 945 else:
837 rc = None 946 rc = None
838 socket.connect(self, addr) 947 socket.connect(self, addr)
839 if not rc: 948 if not rc:
840 self._connected = True 949 self._connected = True
841 if self.do_handshake_on_connect: 950 if self.do_handshake_on_connect:
842 self.do_handshake() 951 self.do_handshake()
(...skipping 22 matching lines...) Expand all
865 do_handshake_on_connect=self.do_handshake_on_connect, 974 do_handshake_on_connect=self.do_handshake_on_connect,
866 suppress_ragged_eofs=self.suppress_ragged_eofs, 975 suppress_ragged_eofs=self.suppress_ragged_eofs,
867 server_side=True) 976 server_side=True)
868 return newsock, addr 977 return newsock, addr
869 978
870 def get_channel_binding(self, cb_type="tls-unique"): 979 def get_channel_binding(self, cb_type="tls-unique"):
871 """Get channel binding data for current connection. Raise ValueError 980 """Get channel binding data for current connection. Raise ValueError
872 if the requested `cb_type` is not supported. Return bytes of the data 981 if the requested `cb_type` is not supported. Return bytes of the data
873 or None if the data is not available (e.g. before the handshake). 982 or None if the data is not available (e.g. before the handshake).
874 """ 983 """
875 if cb_type not in CHANNEL_BINDING_TYPES:
876 raise ValueError("Unsupported channel binding type")
877 if cb_type != "tls-unique":
878 raise NotImplementedError(
879 "{0} channel binding type not implemented"
880 .format(cb_type))
881 if self._sslobj is None: 984 if self._sslobj is None:
882 return None 985 return None
883 return self._sslobj.tls_unique_cb() 986 return self._sslobj.get_channel_binding(cb_type)
987
988 def version(self):
989 """
990 Return a string identifying the protocol version used by the
991 current SSL channel, or None if there is no established channel.
992 """
993 if self._sslobj is None:
994 return None
995 return self._sslobj.version()
884 996
885 997
886 def wrap_socket(sock, keyfile=None, certfile=None, 998 def wrap_socket(sock, keyfile=None, certfile=None,
887 server_side=False, cert_reqs=CERT_NONE, 999 server_side=False, cert_reqs=CERT_NONE,
888 ssl_version=PROTOCOL_SSLv23, ca_certs=None, 1000 ssl_version=PROTOCOL_SSLv23, ca_certs=None,
889 do_handshake_on_connect=True, 1001 do_handshake_on_connect=True,
890 suppress_ragged_eofs=True, 1002 suppress_ragged_eofs=True,
891 ciphers=None): 1003 ciphers=None):
892 1004
893 return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile, 1005 return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile,
894 server_side=server_side, cert_reqs=cert_reqs, 1006 server_side=server_side, cert_reqs=cert_reqs,
895 ssl_version=ssl_version, ca_certs=ca_certs, 1007 ssl_version=ssl_version, ca_certs=ca_certs,
896 do_handshake_on_connect=do_handshake_on_connect, 1008 do_handshake_on_connect=do_handshake_on_connect,
897 suppress_ragged_eofs=suppress_ragged_eofs, 1009 suppress_ragged_eofs=suppress_ragged_eofs,
898 ciphers=ciphers) 1010 ciphers=ciphers)
899 1011
900 # some utility functions 1012 # some utility functions
901 1013
902 def cert_time_to_seconds(cert_time): 1014 def cert_time_to_seconds(cert_time):
903 """Takes a date-time string in standard ASN1_print form 1015 """Return the time in seconds since the Epoch, given the timestring
904 ("MON DAY 24HOUR:MINUTE:SEC YEAR TIMEZONE") and return 1016 representing the "notBefore" or "notAfter" date from a certificate
905 a Python time value in seconds past the epoch.""" 1017 in ``"%b %d %H:%M:%S %Y %Z"`` strptime format (C locale).
906 1018
907 import time 1019 "notBefore" or "notAfter" dates must use UTC (RFC 5280).
908 return time.mktime(time.strptime(cert_time, "%b %d %H:%M:%S %Y GMT")) 1020
1021 Month is one of: Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec
1022 UTC should be specified as GMT (see ASN1_TIME_print())
1023 """
1024 from time import strptime
1025 from calendar import timegm
1026
1027 months = (
1028 "Jan","Feb","Mar","Apr","May","Jun",
1029 "Jul","Aug","Sep","Oct","Nov","Dec"
1030 )
1031 time_format = ' %d %H:%M:%S %Y GMT' # NOTE: no month, fixed GMT
1032 try:
1033 month_number = months.index(cert_time[:3].title()) + 1
1034 except ValueError:
1035 raise ValueError('time data %r does not match '
1036 'format "%%b%s"' % (cert_time, time_format))
1037 else:
1038 # found valid month
1039 tt = strptime(cert_time[3:], time_format)
1040 # return an integer, the previous mktime()-based implementation
1041 # returned a float (fractional seconds are always zero here).
1042 return timegm((tt[0], month_number) + tt[2:6])
909 1043
910 PEM_HEADER = "-----BEGIN CERTIFICATE-----" 1044 PEM_HEADER = "-----BEGIN CERTIFICATE-----"
911 PEM_FOOTER = "-----END CERTIFICATE-----" 1045 PEM_FOOTER = "-----END CERTIFICATE-----"
912 1046
913 def DER_cert_to_PEM_cert(der_cert_bytes): 1047 def DER_cert_to_PEM_cert(der_cert_bytes):
914 """Takes a certificate in binary DER format and returns the 1048 """Takes a certificate in binary DER format and returns the
915 PEM version of it as a string.""" 1049 PEM version of it as a string."""
916 1050
917 f = str(base64.standard_b64encode(der_cert_bytes), 'ASCII', 'strict') 1051 f = str(base64.standard_b64encode(der_cert_bytes), 'ASCII', 'strict')
918 return (PEM_HEADER + '\n' + 1052 return (PEM_HEADER + '\n' +
919 textwrap.fill(f, 64) + '\n' + 1053 textwrap.fill(f, 64) + '\n' +
920 PEM_FOOTER + '\n') 1054 PEM_FOOTER + '\n')
921 1055
922 def PEM_cert_to_DER_cert(pem_cert_string): 1056 def PEM_cert_to_DER_cert(pem_cert_string):
923 """Takes a certificate in ASCII PEM format and returns the 1057 """Takes a certificate in ASCII PEM format and returns the
924 DER-encoded version of it as a byte sequence""" 1058 DER-encoded version of it as a byte sequence"""
925 1059
926 if not pem_cert_string.startswith(PEM_HEADER): 1060 if not pem_cert_string.startswith(PEM_HEADER):
927 raise ValueError("Invalid PEM encoding; must start with %s" 1061 raise ValueError("Invalid PEM encoding; must start with %s"
928 % PEM_HEADER) 1062 % PEM_HEADER)
929 if not pem_cert_string.strip().endswith(PEM_FOOTER): 1063 if not pem_cert_string.strip().endswith(PEM_FOOTER):
930 raise ValueError("Invalid PEM encoding; must end with %s" 1064 raise ValueError("Invalid PEM encoding; must end with %s"
931 % PEM_FOOTER) 1065 % PEM_FOOTER)
932 d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)] 1066 d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)]
933 return base64.decodebytes(d.encode('ASCII', 'strict')) 1067 return base64.decodebytes(d.encode('ASCII', 'strict'))
934 1068
935 def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None): 1069 def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv23, ca_certs=None):
936 """Retrieve the certificate from the server at the specified address, 1070 """Retrieve the certificate from the server at the specified address,
937 and return it as a PEM-encoded string. 1071 and return it as a PEM-encoded string.
938 If 'ca_certs' is specified, validate the server cert against it. 1072 If 'ca_certs' is specified, validate the server cert against it.
939 If 'ssl_version' is specified, use it in the connection attempt.""" 1073 If 'ssl_version' is specified, use it in the connection attempt."""
940 1074
941 host, port = addr 1075 host, port = addr
942 if ca_certs is not None: 1076 if ca_certs is not None:
943 cert_reqs = CERT_REQUIRED 1077 cert_reqs = CERT_REQUIRED
944 else: 1078 else:
945 cert_reqs = CERT_NONE 1079 cert_reqs = CERT_NONE
946 context = _create_stdlib_context(ssl_version, 1080 context = _create_stdlib_context(ssl_version,
947 cert_reqs=cert_reqs, 1081 cert_reqs=cert_reqs,
948 cafile=ca_certs) 1082 cafile=ca_certs)
949 with create_connection(addr) as sock: 1083 with create_connection(addr) as sock:
950 with context.wrap_socket(sock) as sslsock: 1084 with context.wrap_socket(sock) as sslsock:
951 dercert = sslsock.getpeercert(True) 1085 dercert = sslsock.getpeercert(True)
952 return DER_cert_to_PEM_cert(dercert) 1086 return DER_cert_to_PEM_cert(dercert)
953 1087
954 def get_protocol_name(protocol_code): 1088 def get_protocol_name(protocol_code):
955 return _PROTOCOL_NAMES.get(protocol_code, '<unknown>') 1089 return _PROTOCOL_NAMES.get(protocol_code, '<unknown>')
LEFTRIGHT

RSS Feeds Recent Issues | This issue
This is Rietveld 894c83f36cb7+