5151from wolfssl ._methods import ( # noqa: F401
5252 PROTOCOL_SSLv23 , PROTOCOL_SSLv3 , PROTOCOL_TLSv1 ,
5353 PROTOCOL_TLSv1_1 , PROTOCOL_TLSv1_2 , PROTOCOL_TLSv1_3 ,
54- PROTOCOL_TLS , WolfSSLMethod as _WolfSSLMethod
54+ PROTOCOL_TLS , PROTOCOL_DTLSv1 , PROTOCOL_DTLSv1_2 ,
55+ WolfSSLMethod as _WolfSSLMethod
5556)
5657
5758CERT_NONE = 0
6465_SSL_ERROR_WANT_READ = 2
6566_SSL_ERROR_WANT_WRITE = 3
6667
68+ _SOCKADDR_SZ = 16
69+
6770_PY3 = sys .version_info [0 ] == 3
6871
6972
@@ -519,7 +522,6 @@ def enable_crl(self, options):
519522 """
520523 Enables CRL certificate revocation
521524 """
522-
523525 ret = _lib .wolfSSL_EnableCRL (self .native_object , options )
524526
525527 if ret != _SSL_SUCCESS :
@@ -529,7 +531,6 @@ def load_crl_file(self, path, filetype):
529531 """
530532 Load CRL certificate revocation
531533 """
532-
533534 ret = _lib .wolfSSL_LoadCRLFile (self .native_object ,
534535 t2b (path ) if path else _ffi .NULL ,
535536 filetype )
@@ -543,7 +544,12 @@ def write(self, data):
543544 Returns number of bytes of DATA actually transmitted.
544545 """
545546 self ._check_closed ("write" )
546- self ._check_connected ()
547+ # Check connected if not DTLS
548+ if self ._context .protocol < PROTOCOL_DTLSv1 :
549+ self ._check_connected ()
550+ # Complete handshake if DTLS connection
551+ else :
552+ self .do_handshake ()
547553
548554 data = t2b (data )
549555
@@ -599,7 +605,12 @@ def read(self, length=1024, buffer=None):
599605 Return zero-length string on EOF.
600606 """
601607 self ._check_closed ("read" )
602- self ._check_connected ()
608+ # Check connected if not DTLS
609+ if self ._context .protocol < PROTOCOL_DTLSv1 :
610+ self ._check_connected ()
611+ # Complete handshake if DTLS connection
612+ else :
613+ self .do_handshake ()
603614
604615 if buffer is not None :
605616 raise ValueError ("buffer not allowed in calls to "
@@ -630,7 +641,8 @@ def recv_into(self, buffer, nbytes=None, flags=0):
630641 to full size of buffer.
631642 """
632643 self ._check_closed ("read" )
633- self ._check_connected ()
644+ if self ._context .protocol < PROTOCOL_DTLSv1 :
645+ self ._check_connected ()
634646
635647 if buffer is None :
636648 raise ValueError ("buffer cannot be None" )
@@ -678,7 +690,8 @@ def shutdown(self, how):
678690 if self .native_object != _ffi .NULL :
679691 _lib .wolfSSL_shutdown (self .native_object )
680692 self ._release_native_object ()
681- self ._sock .shutdown (how )
693+ if self ._context .protocol < PROTOCOL_DTLSv1 :
694+ self ._sock .shutdown (how )
682695
683696 def unwrap (self ):
684697 """
@@ -698,12 +711,23 @@ def unwrap(self):
698711
699712 return sock
700713
714+ def add_peer (self , addr ):
715+ peerAddr = _lib .wolfSSL_dtls_create_peer (addr [1 ],t2b (addr [0 ]))
716+ if peerAddr == _ffi .NULL :
717+ raise SSLError ("Failed to create peer" )
718+ ret = _lib .wolfSSL_dtls_set_peer (self .native_object , peerAddr ,
719+ _SOCKADDR_SZ )
720+ if ret != _SSL_SUCCESS :
721+ raise SSLError ("Unable to set dtls peer. E(%d)" % ret )
722+ _lib .wolfSSL_dtls_free_peer (peerAddr )
723+
701724 def do_handshake (self , block = False ): # pylint: disable=unused-argument
702725 """
703726 Perform a TLS/SSL handshake.
704727 """
705728 self ._check_closed ("do_handshake" )
706- self ._check_connected ()
729+ if self ._context .protocol < PROTOCOL_DTLSv1 :
730+ self ._check_connected ()
707731
708732 if self ._server_side :
709733 ret = _lib .wolfSSL_accept (self .native_object )
@@ -756,13 +780,19 @@ def _real_connect(self, addr, connect_ex):
756780 if self ._connected :
757781 raise ValueError ("attempt to connect already-connected SSLSocket!" )
758782
759- if connect_ex :
760- err = self ._sock .connect_ex (addr )
783+ err = 0
784+ ret = _SSL_SUCCESS
785+
786+ if self ._context .protocol >= PROTOCOL_DTLSv1 :
787+ self .add_peer (addr )
761788 else :
762- err = 0
763- self ._sock .connect (addr )
789+ if connect_ex :
790+ err = self ._sock .connect_ex (addr )
791+ else :
792+ err = 0
793+ self ._sock .connect (addr )
764794
765- if err == 0 :
795+ if err == 0 and ret == _SSL_SUCCESS :
766796 self ._connected = True
767797 if self .do_handshake_on_connect :
768798 self .do_handshake ()
0 commit comments