Skip to content

Commit 7a1c3b0

Browse files
committed
Guard shutdowns and clean up code
1 parent 2c4ba3c commit 7a1c3b0

5 files changed

Lines changed: 97 additions & 23 deletions

File tree

examples/client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,10 @@ def main():
140140
context = wolfssl.SSLContext(get_SSLmethod(args.v))
141141

142142
# enable debug, if native wolfSSL has been compiled with '--enable-debug'
143-
wolfssl.WolfSSL.enable_debug()
143+
try:
144+
wolfssl.WolfSSL.enable_debug()
145+
except RuntimeError:
146+
pass
144147

145148
context.load_cert_chain(args.c, args.k)
146149

examples/server.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,8 @@ def main():
119119
args = build_arg_parser().parse_args()
120120
# DTLS connection over UDP
121121
if args.u:
122-
# Set DTLSv1.2 as default if unspecified
123-
if args.v == 5:
122+
# Set DTLSv1.2 as default if unspecified
123+
if args.v > 2:
124124
args.v = 1
125125
bind_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
126126
bind_socket.bind(("" if args.b else "localhost", args.p))
@@ -136,7 +136,10 @@ def main():
136136
print("Server listening on port", bind_socket.getsockname()[1])
137137

138138
# enable debug, if native wolfSSL has been compiled with '--enable-debug'
139-
wolfssl.WolfSSL.enable_debug()
139+
try:
140+
wolfssl.WolfSSL.enable_debug()
141+
except RuntimeError:
142+
pass
140143

141144
context.load_cert_chain(args.c, args.k)
142145

@@ -170,6 +173,9 @@ def main():
170173
finally:
171174
if secure_socket:
172175
secure_socket.shutdown(socket.SHUT_RDWR)
176+
# Don't close for DTLS - secure_socket wraps the
177+
# shared bind_socket which is needed for
178+
# subsequent connections
173179
if not args.u:
174180
secure_socket.close()
175181

tests/test_client.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,47 @@ def test_get_version(ssl_server, ssl_version, tcp_socket):
9393
secure_socket.read(1024)
9494

9595

96+
def test_close_after_connected(ssl_server, tcp_socket):
97+
ctx = wolfssl.SSLContext(wolfssl.PROTOCOL_TLSv1_2)
98+
sock = ctx.wrap_socket(tcp_socket)
99+
sock.connect(('127.0.0.1', ssl_server.port))
100+
sock.write(b'hello wolfssl')
101+
sock.read(1024)
102+
sock.close()
103+
104+
105+
def test_recv_into_nbytes_zero(ssl_server, tcp_socket):
106+
ctx = wolfssl.SSLContext(wolfssl.PROTOCOL_TLSv1_2)
107+
sock = ctx.wrap_socket(tcp_socket)
108+
sock.connect(('127.0.0.1', ssl_server.port))
109+
sock.write(b'hello wolfssl')
110+
buf = bytearray(1024)
111+
n = sock.recv_into(buf, 0)
112+
assert n > 0
113+
sock.close()
114+
115+
116+
def test_unwrap_returns_socket(ssl_server, tcp_socket):
117+
import socket as _socket
118+
ctx = wolfssl.SSLContext(wolfssl.PROTOCOL_TLSv1_2)
119+
sock = ctx.wrap_socket(tcp_socket)
120+
sock.connect(('127.0.0.1', ssl_server.port))
121+
sock.write(b'hello wolfssl')
122+
sock.read(1024)
123+
raw = sock.unwrap()
124+
assert isinstance(raw, _socket.socket)
125+
raw.close()
126+
127+
128+
def test_sendall_large_buffer(ssl_server, tcp_socket):
129+
ctx = wolfssl.SSLContext(wolfssl.PROTOCOL_TLSv1_2)
130+
sock = ctx.wrap_socket(tcp_socket)
131+
sock.connect(('127.0.0.1', ssl_server.port))
132+
sock.sendall(b'x' * 8192)
133+
sock.read(1024)
134+
sock.close()
135+
136+
96137
def test_client_cert_verification_failure():
97138
"""
98139
Test that a connection fails when the server requires client certificates

tests/test_context.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,21 @@ def test_check_hostname_requires_cert_required(ssl_provider, ssl_context):
8888
def test_wrap_socket_server_side_mismatch(ssl_context, tcp_socket):
8989
with pytest.raises(ValueError):
9090
ssl_context.wrap_socket(tcp_socket, server_side=True)
91+
92+
93+
def test_close_without_handshake(ssl_context, tcp_socket):
94+
sock = ssl_context.wrap_socket(tcp_socket)
95+
sock.close()
96+
97+
98+
def test_close_releases_native_object(ssl_context, tcp_socket):
99+
sock = ssl_context.wrap_socket(tcp_socket)
100+
sock.close()
101+
sock.close()
102+
103+
104+
def test_operations_after_close_raise(ssl_context, tcp_socket):
105+
sock = ssl_context.wrap_socket(tcp_socket)
106+
sock.close()
107+
with pytest.raises(ValueError):
108+
sock.read()

wolfssl/__init__.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,6 @@ def sendall(self, data, flags=0):
612612

613613
while sent < length:
614614
ret = self.write(data[sent:])
615-
616615
sent += ret
617616

618617
return None
@@ -736,11 +735,15 @@ def unwrap(self):
736735
Returns the wrapped OS socket.
737736
"""
738737
if self.native_object != _ffi.NULL:
739-
_lib.wolfSSL_shutdown(self.native_object)
738+
if self._connected:
739+
# Single-step shutdown is intentional; any
740+
# bidirectional close_notify exchange is the
741+
# caller's responsibility on the raw socket.
742+
_lib.wolfSSL_shutdown(self.native_object)
740743
self._release_native_object()
741744

742745
sock = socket(family=self._sock.family,
743-
sock_type=self._sock.type,
746+
type=self._sock.type,
744747
proto=self._sock.proto,
745748
fileno=self._sock.fileno())
746749

@@ -750,19 +753,19 @@ def unwrap(self):
750753
return sock
751754

752755
def add_peer(self, addr):
753-
peerAddr = _lib.wolfSSL_dtls_create_peer(addr[1],t2b(addr[0]))
754-
if peerAddr == _ffi.NULL:
755-
raise SSLError("Failed to create peer")
756-
try:
757-
ret = _lib.wolfSSL_dtls_set_peer(
758-
self.native_object, peerAddr,
759-
_SOCKADDR_SZ)
760-
if ret != _SSL_SUCCESS:
761-
raise SSLError(
762-
"Unable to set dtls peer."
763-
" E(%d)" % ret)
764-
finally:
765-
_lib.wolfSSL_dtls_free_peer(peerAddr)
756+
peerAddr = _lib.wolfSSL_dtls_create_peer(addr[1], t2b(addr[0]))
757+
if peerAddr == _ffi.NULL:
758+
raise SSLError("Failed to create peer")
759+
try:
760+
ret = _lib.wolfSSL_dtls_set_peer(
761+
self.native_object, peerAddr,
762+
_SOCKADDR_SZ)
763+
if ret != _SSL_SUCCESS:
764+
raise SSLError(
765+
"Unable to set dtls peer."
766+
" E(%d)" % ret)
767+
finally:
768+
_lib.wolfSSL_dtls_free_peer(peerAddr)
766769

767770
def do_handshake(self, block=False): # pylint: disable=unused-argument
768771
"""
@@ -912,7 +915,11 @@ def version(self):
912915
# API and are provided here for compatibility.
913916
def close(self):
914917
if self.native_object != _ffi.NULL:
915-
_lib.wolfSSL_shutdown(self.native_object)
918+
if self._connected:
919+
# Single-step shutdown is intentional here; the
920+
# socket is about to be closed so a bidirectional
921+
# close_notify exchange is not required.
922+
_lib.wolfSSL_shutdown(self.native_object)
916923
self._release_native_object()
917924
self._sock.close()
918925

@@ -1048,8 +1055,7 @@ def _get_passwd(self, passwd, sz, rw, userdata):
10481055
"Problem getting password from callback")
10491056
if not isinstance(result, bytes):
10501057
raise ValueError(
1051-
"Password callback must return bytes,"
1052-
" not str")
1058+
"Password callback must return bytes")
10531059
if len(result) > sz:
10541060
raise ValueError(
10551061
"Problem with password returned"

0 commit comments

Comments
 (0)