From b22331f3528535255b648f3b8faf60ee9902ba1c Mon Sep 17 00:00:00 2001 From: Chris Conlon Date: Thu, 5 Mar 2026 17:02:32 -0700 Subject: [PATCH] JSSE: add PSK support via JNI callbacks and WolfSSLParameters --- examples/provider/PskClientEngine.java | 259 +++++++ examples/provider/PskClientEngine.sh | 5 + examples/provider/PskClientSocket.java | 170 +++++ examples/provider/PskClientSocket.sh | 5 + examples/provider/PskServerEngine.java | 261 +++++++ examples/provider/PskServerEngine.sh | 5 + examples/provider/PskServerSocket.java | 174 +++++ examples/provider/PskServerSocket.sh | 5 + examples/provider/README.md | 31 + native/com_wolfssl_WolfSSLContext.c | 248 ++++--- .../provider/jsse/WolfSSLEngineHelper.java | 101 ++- .../provider/jsse/WolfSSLParameters.java | 267 +++++-- .../jsse/WolfSSLParametersHelper.java | 38 + .../jsse/adapter/WolfSSLJDK8Helper.java | 6 +- .../jsse/test/WolfSSLJSSETestSuite.java | 3 +- .../jsse/test/WolfSSLParametersPskTest.java | 690 ++++++++++++++++++ 16 files changed, 2070 insertions(+), 198 deletions(-) create mode 100644 examples/provider/PskClientEngine.java create mode 100755 examples/provider/PskClientEngine.sh create mode 100644 examples/provider/PskClientSocket.java create mode 100755 examples/provider/PskClientSocket.sh create mode 100644 examples/provider/PskServerEngine.java create mode 100755 examples/provider/PskServerEngine.sh create mode 100644 examples/provider/PskServerSocket.java create mode 100755 examples/provider/PskServerSocket.sh create mode 100644 src/test/com/wolfssl/provider/jsse/test/WolfSSLParametersPskTest.java diff --git a/examples/provider/PskClientEngine.java b/examples/provider/PskClientEngine.java new file mode 100644 index 00000000..849a35c7 --- /dev/null +++ b/examples/provider/PskClientEngine.java @@ -0,0 +1,259 @@ +/* PskClientEngine.java + * + * Copyright (C) 2006-2026 wolfSSL Inc. + * + * This file is part of wolfSSL. + * + * wolfSSL is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * wolfSSL is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1335, USA + */ + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.SocketChannel; +import java.security.Security; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLSession; + +import com.wolfssl.WolfSSLSession; +import com.wolfssl.WolfSSLPskClientCallback; +import com.wolfssl.provider.jsse.WolfSSLProvider; +import com.wolfssl.provider.jsse.WolfSSLParameters; + +/** + * Simple PSK client example using SSLEngine and WolfSSLParameters. + * + * This example demonstrates configuring PSK through WolfSSLParameters with + * SSLEngine. + * + * Usage: PskClientEngine [host] [port] + */ +public class PskClientEngine { + + private static final String DEFAULT_HOST = "localhost"; + private static final int DEFAULT_PORT = 11111; + + public static void main(String[] args) throws Exception { + + String host = DEFAULT_HOST; + int port = DEFAULT_PORT; + + if (args.length >= 1) { + host = args[0]; + } + if (args.length >= 2) { + port = Integer.parseInt(args[1]); + } + + /* Install wolfJSSE provider */ + Security.insertProviderAt(new WolfSSLProvider(), 1); + + SSLContext ctx = SSLContext.getInstance("TLSv1.2", "wolfJSSE"); + ctx.init(null, null, null); + + /* Create SSLEngine */ + SSLEngine engine = ctx.createSSLEngine(host, port); + engine.setUseClientMode(true); + + /* Find a PSK cipher suite available in this build */ + String pskCipher = findPskCipher(engine.getSupportedCipherSuites()); + + /* Configure PSK via WolfSSLParameters */ + WolfSSLParameters params = new WolfSSLParameters(); + params.setPskClientCb(new MyPskClientCallback()); + params.setCipherSuites(new String[]{pskCipher}); + engine.setSSLParameters(params); + System.out.println("Using cipher: " + pskCipher); + + /* Connect via SocketChannel */ + SocketChannel sc = SocketChannel.open( + new InetSocketAddress(host, port)); + System.out.println("Connected to " + host + ":" + port); + + try { + /* Perform handshake */ + doHandshake(engine, sc); + System.out.println("SSL handshake complete"); + SSLSession sess = engine.getSession(); + System.out.println(" Protocol: " + sess.getProtocol()); + System.out.println(" Cipher: " + sess.getCipherSuite()); + + /* Send application data */ + String msg = "Hello from PSK Engine client!"; + ByteBuffer appOut = ByteBuffer.wrap(msg.getBytes()); + ByteBuffer netOut = ByteBuffer.allocate( + sess.getPacketBufferSize()); + + SSLEngineResult res = engine.wrap(appOut, netOut); + netOut.flip(); + while (netOut.hasRemaining()) { + sc.write(netOut); + } + System.out.println("Sent: " + msg); + + /* Receive response */ + ByteBuffer netIn = ByteBuffer.allocate( + sess.getPacketBufferSize()); + ByteBuffer appIn = ByteBuffer.allocate( + sess.getApplicationBufferSize()); + + sc.read(netIn); + netIn.flip(); + res = engine.unwrap(netIn, appIn); + appIn.flip(); + byte[] data = new byte[appIn.remaining()]; + appIn.get(data); + System.out.println("Received: " + new String(data)); + + engine.closeOutbound(); + + } finally { + sc.close(); + } + + System.out.println("Connection closed"); + } + + /** + * Perform TLS handshake using SSLEngine over SocketChannel. + */ + private static void doHandshake(SSLEngine engine, SocketChannel sc) + throws Exception { + + SSLSession sess = engine.getSession(); + int netSize = sess.getPacketBufferSize(); + int appSize = sess.getApplicationBufferSize(); + + ByteBuffer localNet = ByteBuffer.allocate(netSize); + ByteBuffer peerNet = ByteBuffer.allocate(netSize); + ByteBuffer localApp = ByteBuffer.allocate(0); + ByteBuffer peerApp = ByteBuffer.allocate(appSize); + + engine.beginHandshake(); + HandshakeStatus hs = engine.getHandshakeStatus(); + + while (hs != HandshakeStatus.FINISHED && + hs != HandshakeStatus.NOT_HANDSHAKING) { + + SSLEngineResult res; + switch (hs) { + case NEED_WRAP: + localNet.clear(); + res = engine.wrap(localApp, localNet); + hs = res.getHandshakeStatus(); + localNet.flip(); + while (localNet.hasRemaining()) { + sc.write(localNet); + } + break; + + case NEED_UNWRAP: + if (sc.read(peerNet) < 0) { + throw new IOException( + "Channel closed during handshake"); + } + peerNet.flip(); + res = engine.unwrap(peerNet, peerApp); + peerNet.compact(); + hs = res.getHandshakeStatus(); + + if (res.getStatus() == + SSLEngineResult.Status.BUFFER_UNDERFLOW) { + /* Need more data, continue reading */ + continue; + } + break; + + case NEED_TASK: + Runnable task; + while ((task = engine.getDelegatedTask()) != null) { + task.run(); + } + hs = engine.getHandshakeStatus(); + break; + + default: + break; + } + } + } + + /** + * Find first available ephemeral PSK cipher suite from supported list. + * Prefers ECDHE over DHE, AES-GCM over others. Falls back to static PSK + * if no ephemeral suite is available. + */ + private static String findPskCipher(String[] suites) { + + String ecdhe = null; + String dhe = null; + String plain = null; + + for (String s : suites) { + if (s.startsWith("TLS_ECDHE_PSK_WITH_")) { + if (ecdhe == null || s.contains("GCM")) { + ecdhe = s; + } + } + else if (s.startsWith("TLS_DHE_PSK_WITH_")) { + if (dhe == null || s.contains("GCM")) { + dhe = s; + } + } + else if (s.startsWith("TLS_PSK_WITH_")) { + if (plain == null) { + plain = s; + } + } + } + + if (ecdhe != null) { return ecdhe; } + if (dhe != null) { return dhe; } + if (plain != null) { return plain; } + + throw new RuntimeException( + "No PSK cipher suites available. " + + "No PSK cipher suites compiled into wolfSSL"); + } + + /** + * PSK client callback implementation. + */ + static class MyPskClientCallback implements WolfSSLPskClientCallback { + + public long pskClientCallback(WolfSSLSession ssl, String hint, + StringBuffer identity, long idMaxLen, byte[] key, long keyMaxLen) { + + System.out.println("PSK Client Callback:"); + System.out.println(" Hint: " + hint); + + String id = "Client_identity"; + if (id.length() > idMaxLen || keyMaxLen < 4) { + return 0; + } + identity.append(id); + + key[0] = 26; + key[1] = 43; + key[2] = 60; + key[3] = 77; + + return 4; + } + } +} diff --git a/examples/provider/PskClientEngine.sh b/examples/provider/PskClientEngine.sh new file mode 100755 index 00000000..42e10a5d --- /dev/null +++ b/examples/provider/PskClientEngine.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +cd ./examples/build +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:../../lib/:/usr/local/lib +java -classpath ../../lib/wolfssl.jar:../../lib/wolfssl-jsse.jar:./ -Dsun.boot.library.path=../../lib/ PskClientEngine "$@" diff --git a/examples/provider/PskClientSocket.java b/examples/provider/PskClientSocket.java new file mode 100644 index 00000000..5762348c --- /dev/null +++ b/examples/provider/PskClientSocket.java @@ -0,0 +1,170 @@ +/* PskClientSocket.java + * + * Copyright (C) 2006-2026 wolfSSL Inc. + * + * This file is part of wolfSSL. + * + * wolfSSL is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * wolfSSL is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1335, USA + */ + +import java.io.InputStream; +import java.io.OutputStream; +import java.security.Security; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocket; + +import com.wolfssl.WolfSSLSession; +import com.wolfssl.WolfSSLPskClientCallback; +import com.wolfssl.provider.jsse.WolfSSLProvider; +import com.wolfssl.provider.jsse.WolfSSLParameters; + +/** + * Simple PSK client example using SSLSocket and WolfSSLParameters. + * + * This example demonstrates configuring PSK through WolfSSLParameters and the + * standard setSSLParameters() API. + * + * Usage: PskClientSocket [host] [port] + */ +public class PskClientSocket { + + private static final String DEFAULT_HOST = "localhost"; + private static final int DEFAULT_PORT = 11111; + + public static void main(String[] args) throws Exception { + + String host = DEFAULT_HOST; + int port = DEFAULT_PORT; + + if (args.length >= 1) { + host = args[0]; + } + if (args.length >= 2) { + port = Integer.parseInt(args[1]); + } + + /* Install wolfJSSE provider */ + Security.insertProviderAt(new WolfSSLProvider(), 1); + + /* Create SSLContext with no KeyManager/TrustManager + * since PSK does not use certificates */ + SSLContext ctx = SSLContext.getInstance("TLSv1.2", "wolfJSSE"); + ctx.init(null, null, null); + + /* Create socket */ + SSLSocket sock = (SSLSocket)ctx.getSocketFactory() + .createSocket(host, port); + + /* Find a PSK cipher suite available in this build */ + String pskCipher = findPskCipher(sock.getSupportedCipherSuites()); + + /* Configure PSK via WolfSSLParameters */ + WolfSSLParameters params = new WolfSSLParameters(); + params.setPskClientCb(new MyPskClientCallback()); + params.setCipherSuites(new String[]{pskCipher}); + sock.setSSLParameters(params); + System.out.println("Using cipher: " + pskCipher); + + System.out.println("Connected to " + host + ":" + port); + + /* Do handshake */ + sock.startHandshake(); + System.out.println("SSL handshake complete"); + System.out.println(" Protocol: " + sock.getSession().getProtocol()); + System.out.println(" Cipher: " + sock.getSession().getCipherSuite()); + + /* Send/receive data */ + OutputStream out = sock.getOutputStream(); + InputStream in = sock.getInputStream(); + + String msg = "Hello from PSK client!"; + out.write(msg.getBytes()); + System.out.println("Sent: " + msg); + + byte[] buf = new byte[1024]; + int n = in.read(buf); + if (n > 0) { + System.out.println("Received: " + new String(buf, 0, n)); + } + + sock.close(); + System.out.println("Connection closed"); + } + + /** + * Find first available ephemeral PSK cipher suite from supported list. + * Prefers ECDHE over DHE, AES-GCM over others. Falls back to static PSK + * if no ephemeral suite is available. + */ + private static String findPskCipher(String[] suites) { + + String ecdhe = null; + String dhe = null; + String plain = null; + + for (String s : suites) { + if (s.startsWith("TLS_ECDHE_PSK_WITH_")) { + if (ecdhe == null || s.contains("GCM")) { + ecdhe = s; + } + } + else if (s.startsWith("TLS_DHE_PSK_WITH_")) { + if (dhe == null || s.contains("GCM")) { + dhe = s; + } + } + else if (s.startsWith("TLS_PSK_WITH_")) { + if (plain == null) { + plain = s; + } + } + } + + if (ecdhe != null) { return ecdhe; } + if (dhe != null) { return dhe; } + if (plain != null) { return plain; } + + throw new RuntimeException( + "No PSK cipher suites available. " + + "No PSK cipher suites compiled into wolfSSL"); + } + + /** + * PSK client callback implementation. + */ + static class MyPskClientCallback implements WolfSSLPskClientCallback { + + public long pskClientCallback(WolfSSLSession ssl, String hint, + StringBuffer identity, long idMaxLen, byte[] key, long keyMaxLen) { + + System.out.println("PSK Client Callback:"); + System.out.println(" Hint: " + hint); + + String id = "Client_identity"; + if (id.length() > idMaxLen || keyMaxLen < 4) { + return 0; + } + identity.append(id); + + /* Pre-shared key: 0x1a2b3c4d */ + key[0] = 26; + key[1] = 43; + key[2] = 60; + key[3] = 77; + + return 4; + } + } +} diff --git a/examples/provider/PskClientSocket.sh b/examples/provider/PskClientSocket.sh new file mode 100755 index 00000000..60b39197 --- /dev/null +++ b/examples/provider/PskClientSocket.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +cd ./examples/build +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:../../lib/:/usr/local/lib +java -classpath ../../lib/wolfssl.jar:../../lib/wolfssl-jsse.jar:./ -Dsun.boot.library.path=../../lib/ PskClientSocket "$@" diff --git a/examples/provider/PskServerEngine.java b/examples/provider/PskServerEngine.java new file mode 100644 index 00000000..351b0066 --- /dev/null +++ b/examples/provider/PskServerEngine.java @@ -0,0 +1,261 @@ +/* PskServerEngine.java + * + * Copyright (C) 2006-2026 wolfSSL Inc. + * + * This file is part of wolfSSL. + * + * wolfSSL is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * wolfSSL is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1335, USA + */ + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.security.Security; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLSession; + +import com.wolfssl.WolfSSLSession; +import com.wolfssl.WolfSSLPskServerCallback; +import com.wolfssl.provider.jsse.WolfSSLProvider; +import com.wolfssl.provider.jsse.WolfSSLParameters; + +/** + * Simple PSK server example using SSLEngine and WolfSSLParameters. + * + * This example demonstrates configuring PSK through WolfSSLParameters with + * SSLEngine. + * + * Usage: PskServerEngine [port] + */ +public class PskServerEngine { + + private static final int DEFAULT_PORT = 11111; + + public static void main(String[] args) throws Exception { + + int port = DEFAULT_PORT; + if (args.length >= 1) { + port = Integer.parseInt(args[0]); + } + + /* Install wolfJSSE provider */ + Security.insertProviderAt(new WolfSSLProvider(), 1); + + SSLContext ctx = SSLContext.getInstance("TLSv1.2", "wolfJSSE"); + ctx.init(null, null, null); + + /* Create server socket */ + ServerSocketChannel ssc = ServerSocketChannel.open(); + ssc.socket().bind(new InetSocketAddress(port)); + System.out.println("PSK Engine Server listening on port " + port); + + /* Accept one client */ + SocketChannel sc = ssc.accept(); + System.out.println("Client connected from " + + sc.socket().getInetAddress().getHostAddress()); + + try { + /* Create SSLEngine */ + SSLEngine engine = ctx.createSSLEngine(); + engine.setUseClientMode(false); + + /* Find a PSK cipher suite available in this build */ + String pskCipher = findPskCipher(engine.getSupportedCipherSuites()); + + /* Configure PSK via WolfSSLParameters */ + WolfSSLParameters params = new WolfSSLParameters(); + params.setPskServerCb(new MyPskServerCallback()); + params.setPskIdentityHint("wolfssl psk hint"); + params.setCipherSuites(new String[]{pskCipher}); + engine.setSSLParameters(params); + System.out.println("Using cipher: " + pskCipher); + + /* Perform handshake */ + doHandshake(engine, sc); + System.out.println("SSL handshake complete"); + SSLSession sess = engine.getSession(); + System.out.println(" Protocol: " + sess.getProtocol()); + System.out.println(" Cipher: " + sess.getCipherSuite()); + + /* Receive application data */ + ByteBuffer netIn = ByteBuffer.allocate( + sess.getPacketBufferSize()); + ByteBuffer appIn = ByteBuffer.allocate( + sess.getApplicationBufferSize()); + + sc.read(netIn); + netIn.flip(); + SSLEngineResult res = engine.unwrap(netIn, appIn); + appIn.flip(); + byte[] data = new byte[appIn.remaining()]; + appIn.get(data); + String received = new String(data); + System.out.println("Received: " + received); + + /* Echo back */ + ByteBuffer appOut = ByteBuffer.wrap(received.getBytes()); + ByteBuffer netOut = ByteBuffer.allocate(sess.getPacketBufferSize()); + res = engine.wrap(appOut, netOut); + netOut.flip(); + while (netOut.hasRemaining()) { + sc.write(netOut); + } + System.out.println("Echoed back: " + received); + + engine.closeOutbound(); + + } finally { + sc.close(); + ssc.close(); + } + + System.out.println("Server closed"); + } + + /** + * Perform TLS handshake using SSLEngine over SocketChannel. + */ + private static void doHandshake(SSLEngine engine, SocketChannel sc) + throws Exception { + + SSLSession sess = engine.getSession(); + int netSize = sess.getPacketBufferSize(); + int appSize = sess.getApplicationBufferSize(); + + ByteBuffer localNet = ByteBuffer.allocate(netSize); + ByteBuffer peerNet = ByteBuffer.allocate(netSize); + ByteBuffer localApp = ByteBuffer.allocate(0); + ByteBuffer peerApp = ByteBuffer.allocate(appSize); + + engine.beginHandshake(); + HandshakeStatus hs = engine.getHandshakeStatus(); + + while (hs != HandshakeStatus.FINISHED && + hs != HandshakeStatus.NOT_HANDSHAKING) { + + SSLEngineResult res; + switch (hs) { + case NEED_WRAP: + localNet.clear(); + res = engine.wrap(localApp, localNet); + hs = res.getHandshakeStatus(); + localNet.flip(); + while (localNet.hasRemaining()) { + sc.write(localNet); + } + break; + + case NEED_UNWRAP: + if (sc.read(peerNet) < 0) { + throw new IOException( + "Channel closed during handshake"); + } + peerNet.flip(); + res = engine.unwrap(peerNet, peerApp); + peerNet.compact(); + hs = res.getHandshakeStatus(); + + if (res.getStatus() == + SSLEngineResult.Status.BUFFER_UNDERFLOW) { + continue; + } + break; + + case NEED_TASK: + Runnable task; + while ((task = engine.getDelegatedTask()) != null) { + task.run(); + } + hs = engine.getHandshakeStatus(); + break; + + default: + break; + } + } + } + + /** + * Find first available ephemeral PSK cipher suite from supported list. + * Prefers ECDHE over DHE, AES-GCM over others. Falls back to static PSK + * if no ephemeral suite is available. + */ + private static String findPskCipher(String[] suites) { + + String ecdhe = null; + String dhe = null; + String plain = null; + + for (String s : suites) { + if (s.startsWith("TLS_ECDHE_PSK_WITH_")) { + if (ecdhe == null || s.contains("GCM")) { + ecdhe = s; + } + } + else if (s.startsWith("TLS_DHE_PSK_WITH_")) { + if (dhe == null || s.contains("GCM")) { + dhe = s; + } + } + else if (s.startsWith("TLS_PSK_WITH_")) { + if (plain == null) { + plain = s; + } + } + } + + if (ecdhe != null) { return ecdhe; } + if (dhe != null) { return dhe; } + if (plain != null) { return plain; } + + throw new RuntimeException( + "No PSK cipher suites available. " + + "No PSK cipher suites compiled into wolfSSL"); + } + + /** + * PSK server callback implementation. + */ + static class MyPskServerCallback implements WolfSSLPskServerCallback { + + public long pskServerCallback(WolfSSLSession ssl, String identity, + byte[] key, long keyMaxLen) { + + System.out.println("PSK Server Callback:"); + System.out.println(" Identity: " + identity); + + if (!"Client_identity".equals(identity)) { + System.out.println("Unknown client identity!"); + return 0; + } + + if (keyMaxLen < 4) { + return 0; + } + + key[0] = 26; + key[1] = 43; + key[2] = 60; + key[3] = 77; + + return 4; + } + } +} diff --git a/examples/provider/PskServerEngine.sh b/examples/provider/PskServerEngine.sh new file mode 100755 index 00000000..88a1eea3 --- /dev/null +++ b/examples/provider/PskServerEngine.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +cd ./examples/build +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:../../lib/:/usr/local/lib +java -classpath ../../lib/wolfssl.jar:../../lib/wolfssl-jsse.jar:./ -Dsun.boot.library.path=../../lib/ PskServerEngine "$@" diff --git a/examples/provider/PskServerSocket.java b/examples/provider/PskServerSocket.java new file mode 100644 index 00000000..3e91f688 --- /dev/null +++ b/examples/provider/PskServerSocket.java @@ -0,0 +1,174 @@ +/* PskServerSocket.java + * + * Copyright (C) 2006-2026 wolfSSL Inc. + * + * This file is part of wolfSSL. + * + * wolfSSL is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * wolfSSL is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1335, USA + */ + +import java.io.InputStream; +import java.io.OutputStream; +import java.security.Security; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLServerSocket; +import javax.net.ssl.SSLSocket; + +import com.wolfssl.WolfSSLSession; +import com.wolfssl.WolfSSLPskServerCallback; +import com.wolfssl.provider.jsse.WolfSSLProvider; +import com.wolfssl.provider.jsse.WolfSSLParameters; + +/** + * Simple PSK server example using SSLSocket and WolfSSLParameters. + * + * This example demonstrates configuring PSK through WolfSSLParameters and the + * standard setSSLParameters() API. + * + * Usage: PskServerSocket [port] + */ +public class PskServerSocket { + + private static final int DEFAULT_PORT = 11111; + + public static void main(String[] args) throws Exception { + + int port = DEFAULT_PORT; + if (args.length >= 1) { + port = Integer.parseInt(args[0]); + } + + /* Install wolfJSSE provider */ + Security.insertProviderAt(new WolfSSLProvider(), 1); + + /* Create SSLContext with no KeyManager/TrustManager since PSK does not + * use certificates */ + SSLContext ctx = SSLContext.getInstance("TLSv1.2", "wolfJSSE"); + ctx.init(null, null, null); + + /* Create server socket */ + SSLServerSocket ss = (SSLServerSocket)ctx.getServerSocketFactory() + .createServerSocket(port); + + System.out.println("PSK Server listening on port " + port); + + /* Accept one client */ + SSLSocket sock = (SSLSocket)ss.accept(); + System.out.println("Client connected from " + + sock.getInetAddress().getHostAddress()); + + /* Find a PSK cipher suite available in this build */ + String pskCipher = findPskCipher(sock.getSupportedCipherSuites()); + + /* Configure PSK via WolfSSLParameters */ + WolfSSLParameters params = new WolfSSLParameters(); + params.setPskServerCb(new MyPskServerCallback()); + params.setPskIdentityHint("wolfssl psk hint"); + params.setCipherSuites(new String[]{pskCipher}); + sock.setSSLParameters(params); + System.out.println("Using cipher: " + pskCipher); + + /* Do handshake */ + sock.startHandshake(); + System.out.println("SSL handshake complete"); + System.out.println(" Protocol: " + sock.getSession().getProtocol()); + System.out.println(" Cipher: " + sock.getSession().getCipherSuite()); + + /* Read/write data */ + InputStream in = sock.getInputStream(); + OutputStream out = sock.getOutputStream(); + + byte[] buf = new byte[1024]; + int n = in.read(buf); + if (n > 0) { + String received = new String(buf, 0, n); + System.out.println("Received: " + received); + out.write(received.getBytes()); + System.out.println("Echoed back: " + received); + } + + sock.close(); + ss.close(); + System.out.println("Server closed"); + } + + /** + * Find first available ephemeral PSK cipher suite from supported list. + * Prefers ECDHE over DHE, AES-GCM over others. Falls back to static PSK + * if no ephemeral suite is available. + */ + private static String findPskCipher(String[] suites) { + + String ecdhe = null; + String dhe = null; + String plain = null; + + for (String s : suites) { + if (s.startsWith("TLS_ECDHE_PSK_WITH_")) { + if (ecdhe == null || s.contains("GCM")) { + ecdhe = s; + } + } + else if (s.startsWith("TLS_DHE_PSK_WITH_")) { + if (dhe == null || s.contains("GCM")) { + dhe = s; + } + } + else if (s.startsWith("TLS_PSK_WITH_")) { + if (plain == null) { + plain = s; + } + } + } + + if (ecdhe != null) { return ecdhe; } + if (dhe != null) { return dhe; } + if (plain != null) { return plain; } + + throw new RuntimeException( + "No PSK cipher suites available. " + + "No PSK cipher suites compiled into wolfSSL"); + } + + /** + * PSK server callback implementation. + */ + static class MyPskServerCallback implements WolfSSLPskServerCallback { + + public long pskServerCallback(WolfSSLSession ssl, String identity, + byte[] key, long keyMaxLen) { + + System.out.println("PSK Server Callback:"); + System.out.println(" Identity: " + identity); + + if (!"Client_identity".equals(identity)) { + System.out.println("Unknown client identity!"); + return 0; + } + + if (keyMaxLen < 4) { + return 0; + } + + /* Pre-shared key: 0x1a2b3c4d */ + key[0] = 26; + key[1] = 43; + key[2] = 60; + key[3] = 77; + + return 4; + } + } +} diff --git a/examples/provider/PskServerSocket.sh b/examples/provider/PskServerSocket.sh new file mode 100755 index 00000000..caca32ac --- /dev/null +++ b/examples/provider/PskServerSocket.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +cd ./examples/build +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:../../lib/:/usr/local/lib +java -classpath ../../lib/wolfssl.jar:../../lib/wolfssl-jsse.jar:./ -Dsun.boot.library.path=../../lib/ PskServerSocket "$@" diff --git a/examples/provider/README.md b/examples/provider/README.md index b97236b5..6f44c941 100644 --- a/examples/provider/README.md +++ b/examples/provider/README.md @@ -201,6 +201,37 @@ $ ./examples/provider/DualProviderFIPSTest.sh The `WOLFCRYPTJNI_DIR` environment variable can be set to point to the wolfcrypt-jni build directory (defaults to `../../wolfcryptjni`). +## PSK Examples (SSLSocket and SSLEngine) + +Example client/server applications that demonstrate PSK (Pre-Shared Key) +authentication using `WolfSSLParameters` with both SSLSocket and SSLEngine. + +PSK callbacks and cipher suites are configured through `WolfSSLParameters`, +which extends `SSLParameters` and is applied via the standard +`setSSLParameters()` API. + +**PskServerSocket.java** - PSK server using SSLSocket \ +**PskClientSocket.java** - PSK client using SSLSocket \ +**PskServerEngine.java** - PSK server using SSLEngine \ +**PskClientEngine.java** - PSK client using SSLEngine + +SSLSocket examples: + +``` +$ ./examples/provider/PskServerSocket.sh +$ ./examples/provider/PskClientSocket.sh +``` + +SSLEngine examples: + +``` +$ ./examples/provider/PskServerEngine.sh +$ ./examples/provider/PskClientEngine.sh +``` + +Optional arguments: `[host] [port]` for clients, `[port]` for servers. Default +host is `localhost`, default port is `11111`. + ## Support Please contact the wolfSSL support team at support@wolfssl.com with any diff --git a/native/com_wolfssl_WolfSSLContext.c b/native/com_wolfssl_WolfSSLContext.c index b4ea500d..3795fada 100644 --- a/native/com_wolfssl_WolfSSLContext.c +++ b/native/com_wolfssl_WolfSSLContext.c @@ -5910,8 +5910,9 @@ unsigned int NativePskClientCb(WOLFSSL* ssl, const char* hint, char* identity, if ((*jenv)->ExceptionOccurred(jenv)) { (*jenv)->ExceptionDescribe(jenv); (*jenv)->ExceptionClear(jenv); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } @@ -5919,10 +5920,10 @@ unsigned int NativePskClientCb(WOLFSSL* ssl, const char* hint, char* identity, g_cachedSSLObj = (jobject*) wolfSSL_get_jobject((WOLFSSL*)ssl); if (!g_cachedSSLObj) { (*jenv)->ThrowNew(jenv, excClass, - "Can't get native WolfSSLSession object reference in " - "NativePskClientCb"); - if (needsDetach) + "Can't get WolfSSLSession object reference in NativePskClientCb"); + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } @@ -5930,55 +5931,56 @@ unsigned int NativePskClientCb(WOLFSSL* ssl, const char* hint, char* identity, sessClass = (*jenv)->GetObjectClass(jenv, (jobject)(*g_cachedSSLObj)); if (!sessClass) { (*jenv)->ThrowNew(jenv, excClass, - "Can't get native WolfSSLSession class reference in " - "NativePskClientCb"); - if (needsDetach) + "Can't get WolfSSLSession class reference in NativePskClientCb"); + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } /* lookup WolfSSLContext private member fieldID */ ctxFid = (*jenv)->GetFieldID(jenv, sessClass, "ctx", - "Lcom/wolfssl/WolfSSLContext;"); + "Lcom/wolfssl/WolfSSLContext;"); if (!ctxFid) { if ((*jenv)->ExceptionOccurred(jenv)) { (*jenv)->ExceptionDescribe(jenv); (*jenv)->ExceptionClear(jenv); } (*jenv)->ThrowNew(jenv, excClass, - "Can't get native WolfSSLContext field ID in " - "NativePSKClientCb"); - if (needsDetach) + "Can't get WolfSSLContext field ID in NativePSKClientCb"); + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } /* find WolfSSLSession.getAssociatedContextPtr() method */ getCtxMethodId = (*jenv)->GetMethodID(jenv, sessClass, - "getAssociatedContextPtr", - "()Lcom/wolfssl/WolfSSLContext;"); + "getAssociatedContextPtr", "()Lcom/wolfssl/WolfSSLContext;"); if (!getCtxMethodId) { if ((*jenv)->ExceptionOccurred(jenv)) { (*jenv)->ExceptionDescribe(jenv); (*jenv)->ExceptionClear(jenv); } (*jenv)->ThrowNew(jenv, excClass, - "Can't get getAssociatedContextPtr() method ID in " - "NativePSKClientCb"); - if (needsDetach) + "Can't get getAssociatedContextPtr() method ID in " + "NativePSKClientCb"); + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } /* get WolfSSLContext(ctx) object from Java WolfSSLSession object */ ctxRef = (*jenv)->CallObjectMethod(jenv, (jobject)(*g_cachedSSLObj), - getCtxMethodId); + getCtxMethodId); CheckException(jenv); if (!ctxRef) { (*jenv)->ThrowNew(jenv, excClass, - "Can't get WolfSSLContext object in NativePskClientCb"); - if (needsDetach) + "Can't get WolfSSLContext object in NativePskClientCb"); + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } @@ -5986,54 +5988,53 @@ unsigned int NativePskClientCb(WOLFSSL* ssl, const char* hint, char* identity, innerCtxClass = (*jenv)->GetObjectClass(jenv, ctxRef); if (!innerCtxClass) { (*jenv)->ThrowNew(jenv, excClass, - "Can't get native WolfSSLContext class reference in " - "NativePskClientCb"); + "Can't get native WolfSSLContext class reference in " + "NativePskClientCb"); (*jenv)->DeleteLocalRef(jenv, ctxRef); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } /* figure out if we need to call the CTX or SSL level callback */ - /* 1. Get internPskClientCb FieldID */ + /* 1. Get internPskClientCb FieldID */ internPskClientCbFid = (*jenv)->GetFieldID(jenv, innerCtxClass, - "internPskClientCb", - "Lcom/wolfssl/WolfSSLPskClientCallback;"); + "internPskClientCb", "Lcom/wolfssl/WolfSSLPskClientCallback;"); if (!internPskClientCbFid) { if ((*jenv)->ExceptionOccurred(jenv)) { (*jenv)->ExceptionDescribe(jenv); (*jenv)->ExceptionClear(jenv); } (*jenv)->ThrowNew(jenv, excClass, - "Can't get native internPskClientCb field ID in " - "NativePSKClientCb"); + "Can't get internPskClientCb field ID in NativePSKClientCb"); (*jenv)->DeleteLocalRef(jenv, ctxRef); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } - /* 2. Get WolfSSLPskClientCallback object (or null) */ + /* 2. Get WolfSSLPskClientCallback object (or null) */ internPskClientCbObj = (*jenv)->GetObjectField(jenv, ctxRef, - internPskClientCbFid); + internPskClientCbFid); if (!internPskClientCbObj) { - printf("Using SSL level PSK Client callback!!!\n"); usingSSLCallback = 1; } if (usingSSLCallback == 1) { /* WolfSSLSession level callback */ pskClientMethodId = (*jenv)->GetMethodID(jenv, sessClass, - "internalPskClientCallback", - "(Lcom/wolfssl/WolfSSLSession;Ljava/lang/String;" - "Ljava/lang/StringBuffer;J[BJ)J"); + "internalPskClientCallback", + "(Lcom/wolfssl/WolfSSLSession;Ljava/lang/String;" + "Ljava/lang/StringBuffer;J[BJ)J"); } else { /* WolfSSLContext level callback */ pskClientMethodId = (*jenv)->GetMethodID(jenv, innerCtxClass, - "internalPskClientCallback", - "(Lcom/wolfssl/WolfSSLSession;Ljava/lang/String;" - "Ljava/lang/StringBuffer;J[BJ)J"); + "internalPskClientCallback", + "(Lcom/wolfssl/WolfSSLSession;Ljava/lang/String;" + "Ljava/lang/StringBuffer;J[BJ)J"); } if (!pskClientMethodId) { @@ -6043,10 +6044,11 @@ unsigned int NativePskClientCb(WOLFSSL* ssl, const char* hint, char* identity, } (*jenv)->ThrowNew(jenv, excClass, - "Error getting internalPskClientCallback method from JNI"); + "Error getting internalPskClientCallback method from JNI"); (*jenv)->DeleteLocalRef(jenv, ctxRef); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } @@ -6058,10 +6060,11 @@ unsigned int NativePskClientCb(WOLFSSL* ssl, const char* hint, char* identity, (*jenv)->ExceptionClear(jenv); } (*jenv)->ThrowNew(jenv, excClass, - "Error creating String for PSK client hint"); + "Error creating String for PSK client hint"); (*jenv)->DeleteLocalRef(jenv, ctxRef); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } @@ -6073,29 +6076,29 @@ unsigned int NativePskClientCb(WOLFSSL* ssl, const char* hint, char* identity, (*jenv)->ExceptionClear(jenv); } (*jenv)->ThrowNew(jenv, excClass, - "Error finding StringBuffer class for PSK client identity"); + "Error finding StringBuffer class for PSK client identity"); (*jenv)->DeleteLocalRef(jenv, ctxRef); (*jenv)->DeleteLocalRef(jenv, hintString); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } /* find StringBuffer Constructor */ - strBufMethodId = (*jenv)->GetMethodID(jenv, strBufClass, - "", "()V"); + strBufMethodId = (*jenv)->GetMethodID(jenv, strBufClass, "", "()V"); if (!strBufMethodId) { if ((*jenv)->ExceptionOccurred(jenv)) { (*jenv)->ExceptionDescribe(jenv); (*jenv)->ExceptionClear(jenv); } (*jenv)->ThrowNew(jenv, excClass, - "Can't get StringBuffer constructor method ID " - "in NativePskClientCb"); + "Can't get StringBuffer constructor methodID in NativePskClientCb"); (*jenv)->DeleteLocalRef(jenv, ctxRef); (*jenv)->DeleteLocalRef(jenv, hintString); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } @@ -6107,11 +6110,12 @@ unsigned int NativePskClientCb(WOLFSSL* ssl, const char* hint, char* identity, (*jenv)->ExceptionClear(jenv); } (*jenv)->ThrowNew(jenv, excClass, - "Can't get StringBuffer object in NativePskClientCb"); + "Can't get StringBuffer object in NativePskClientCb"); (*jenv)->DeleteLocalRef(jenv, ctxRef); (*jenv)->DeleteLocalRef(jenv, hintString); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } @@ -6119,12 +6123,13 @@ unsigned int NativePskClientCb(WOLFSSL* ssl, const char* hint, char* identity, keyArray = (*jenv)->NewByteArray(jenv, max_key_len); if (!keyArray) { (*jenv)->ThrowNew(jenv, excClass, - "Error creating jbyteArray for PSK client key"); + "Error creating jbyteArray for PSK client key"); (*jenv)->DeleteLocalRef(jenv, ctxRef); (*jenv)->DeleteLocalRef(jenv, hintString); (*jenv)->DeleteLocalRef(jenv, strBufObj); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } @@ -6132,8 +6137,8 @@ unsigned int NativePskClientCb(WOLFSSL* ssl, const char* hint, char* identity, if (usingSSLCallback == 1) { /* call WolfSSLSession level callback */ retval = (*jenv)->CallLongMethod(jenv, (jobject)(*g_cachedSSLObj), - pskClientMethodId, (jobject)(*g_cachedSSLObj), hintString, - strBufObj, (jlong)id_max_len, keyArray, (jlong)max_key_len); + pskClientMethodId, (jobject)(*g_cachedSSLObj), hintString, + strBufObj, (jlong)id_max_len, keyArray, (jlong)max_key_len); if ((*jenv)->ExceptionOccurred(jenv)) { (*jenv)->ExceptionDescribe(jenv); (*jenv)->ExceptionClear(jenv); @@ -6141,15 +6146,16 @@ unsigned int NativePskClientCb(WOLFSSL* ssl, const char* hint, char* identity, (*jenv)->DeleteLocalRef(jenv, hintString); (*jenv)->DeleteLocalRef(jenv, strBufObj); (*jenv)->DeleteLocalRef(jenv, keyArray); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } } else { /* call WolfSSLContext level callback */ retval = (*jenv)->CallLongMethod(jenv, ctxRef, pskClientMethodId, - (jobject)(*g_cachedSSLObj), hintString, strBufObj, - (jlong)id_max_len, keyArray, (jlong)max_key_len); + (jobject)(*g_cachedSSLObj), hintString, strBufObj, + (jlong)id_max_len, keyArray, (jlong)max_key_len); if ((*jenv)->ExceptionOccurred(jenv)) { (*jenv)->ExceptionDescribe(jenv); (*jenv)->ExceptionClear(jenv); @@ -6157,8 +6163,9 @@ unsigned int NativePskClientCb(WOLFSSL* ssl, const char* hint, char* identity, (*jenv)->DeleteLocalRef(jenv, hintString); (*jenv)->DeleteLocalRef(jenv, strBufObj); (*jenv)->DeleteLocalRef(jenv, keyArray); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } } @@ -6174,27 +6181,29 @@ unsigned int NativePskClientCb(WOLFSSL* ssl, const char* hint, char* identity, (*jenv)->DeleteLocalRef(jenv, hintString); (*jenv)->DeleteLocalRef(jenv, strBufObj); (*jenv)->DeleteLocalRef(jenv, keyArray); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } /* get the String from the StringBuffer */ toStringId = (*jenv)->GetMethodID(jenv, strBufClass, - "toString", "()Ljava/lang/String;"); + "toString", "()Ljava/lang/String;"); if (!toStringId) { if ((*jenv)->ExceptionOccurred(jenv)) { (*jenv)->ExceptionDescribe(jenv); (*jenv)->ExceptionClear(jenv); } (*jenv)->ThrowNew(jenv, excClass, - "Error getting String ID from StringBuffer in PSK CB"); + "Error getting String ID from StringBuffer in PSK CB"); (*jenv)->DeleteLocalRef(jenv, ctxRef); (*jenv)->DeleteLocalRef(jenv, hintString); (*jenv)->DeleteLocalRef(jenv, strBufObj); (*jenv)->DeleteLocalRef(jenv, keyArray); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } @@ -6203,13 +6212,14 @@ unsigned int NativePskClientCb(WOLFSSL* ssl, const char* hint, char* identity, CheckException(jenv); if (!bufString) { (*jenv)->ThrowNew(jenv, excClass, - "Error getting String from StringBuffer in PSK CB"); + "Error getting String from StringBuffer in PSK CB"); (*jenv)->DeleteLocalRef(jenv, ctxRef); (*jenv)->DeleteLocalRef(jenv, hintString); (*jenv)->DeleteLocalRef(jenv, strBufObj); (*jenv)->DeleteLocalRef(jenv, keyArray); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } @@ -6217,14 +6227,15 @@ unsigned int NativePskClientCb(WOLFSSL* ssl, const char* hint, char* identity, tmpString = (*jenv)->GetStringUTFChars(jenv, bufString, 0); if (!tmpString) { (*jenv)->ThrowNew(jenv, excClass, - "Error with GetStringUTFChars in PSK Client CB"); + "Error with GetStringUTFChars in PSK Client CB"); (*jenv)->DeleteLocalRef(jenv, ctxRef); (*jenv)->DeleteLocalRef(jenv, hintString); (*jenv)->DeleteLocalRef(jenv, strBufObj); (*jenv)->DeleteLocalRef(jenv, keyArray); (*jenv)->DeleteLocalRef(jenv, bufString); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } strcpy(identity, tmpString); @@ -6237,8 +6248,9 @@ unsigned int NativePskClientCb(WOLFSSL* ssl, const char* hint, char* identity, (*jenv)->DeleteLocalRef(jenv, hintString); (*jenv)->DeleteLocalRef(jenv, strBufObj); (*jenv)->DeleteLocalRef(jenv, keyArray); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return retval; } @@ -6332,8 +6344,9 @@ unsigned int NativePskServerCb(WOLFSSL* ssl, const char* identity, if ((*jenv)->ExceptionOccurred(jenv)) { (*jenv)->ExceptionDescribe(jenv); (*jenv)->ExceptionClear(jenv); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } @@ -6341,10 +6354,10 @@ unsigned int NativePskServerCb(WOLFSSL* ssl, const char* identity, g_cachedSSLObj = (jobject*) wolfSSL_get_jobject((WOLFSSL*)ssl); if (!g_cachedSSLObj) { (*jenv)->ThrowNew(jenv, excClass, - "Can't get native WolfSSLSession object reference in " - "NativePskServerCb"); - if (needsDetach) + "Can't get WolfSSLSession object reference in NativePskServerCb"); + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } @@ -6352,55 +6365,56 @@ unsigned int NativePskServerCb(WOLFSSL* ssl, const char* identity, sessClass = (*jenv)->GetObjectClass(jenv, (jobject)(*g_cachedSSLObj)); if (!sessClass) { (*jenv)->ThrowNew(jenv, excClass, - "Can't get native WolfSSLSession class reference in " - "NativePskServerCb"); - if (needsDetach) + "Can't get WolfSSLSession class reference in NativePskServerCb"); + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } /* lookup WolfSSLContext private member fieldID */ ctxFid = (*jenv)->GetFieldID(jenv, sessClass, "ctx", - "Lcom/wolfssl/WolfSSLContext;"); + "Lcom/wolfssl/WolfSSLContext;"); if (!ctxFid) { if ((*jenv)->ExceptionOccurred(jenv)) { (*jenv)->ExceptionDescribe(jenv); (*jenv)->ExceptionClear(jenv); } (*jenv)->ThrowNew(jenv, excClass, - "Can't get native WolfSSLContext field ID in " - "NativePSKClientCb"); - if (needsDetach) + "Can't get WolfSSLContext field ID in NativePSKClientCb"); + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } /* find WolfSSLSession.getAssociatedContextPtr() method */ getCtxMethodId = (*jenv)->GetMethodID(jenv, sessClass, - "getAssociatedContextPtr", - "()Lcom/wolfssl/WolfSSLContext;"); + "getAssociatedContextPtr", "()Lcom/wolfssl/WolfSSLContext;"); if (!getCtxMethodId) { if ((*jenv)->ExceptionOccurred(jenv)) { (*jenv)->ExceptionDescribe(jenv); (*jenv)->ExceptionClear(jenv); } (*jenv)->ThrowNew(jenv, excClass, - "Can't get getAssociatedContextPtr() method ID in " - "NativePSKClientCb"); - if (needsDetach) + "Can't get getAssociatedContextPtr() method ID in " + "NativePSKClientCb"); + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } /* get WolfSSLContext(ctx) object from Java WolfSSLSession object */ ctxRef = (*jenv)->CallObjectMethod(jenv, (jobject)(*g_cachedSSLObj), - getCtxMethodId); + getCtxMethodId); CheckException(jenv); if (!ctxRef) { (*jenv)->ThrowNew(jenv, excClass, "Can't get WolfSSLContext object in NativePskServerCb"); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } @@ -6408,17 +6422,17 @@ unsigned int NativePskServerCb(WOLFSSL* ssl, const char* identity, innerCtxClass = (*jenv)->GetObjectClass(jenv, ctxRef); if (!innerCtxClass) { (*jenv)->ThrowNew(jenv, excClass, - "Can't get native WolfSSLContext class reference in " - "NativePskServerCb"); + "Can't get WolfSSLContext class reference in NativePskServerCb"); (*jenv)->DeleteLocalRef(jenv, ctxRef); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } /* figure out if we need to call the CTX or SSL level callback */ - /* 1. Get the internPskServerCb FieldID */ + /* 1. Get the internPskServerCb FieldID */ internPskServerCbFid = (*jenv)->GetFieldID(jenv, innerCtxClass, "internPskServerCb", "Lcom/wolfssl/WolfSSLPskServerCallback;"); @@ -6428,34 +6442,31 @@ unsigned int NativePskServerCb(WOLFSSL* ssl, const char* identity, (*jenv)->ExceptionClear(jenv); } (*jenv)->ThrowNew(jenv, excClass, - "Can't get native internPskServerCb field ID in " - "NativePskServerCb"); + "Can't get internPskServerCb field ID in NativePskServerCb"); (*jenv)->DeleteLocalRef(jenv, ctxRef); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } - /* 2. Get WolfSSLPskServerCallback object (or null) */ + /* 2. Get WolfSSLPskServerCallback object (or null) */ internPskServerCbObj = (*jenv)->GetObjectField(jenv, ctxRef, - internPskServerCbFid); + internPskServerCbFid); if (!internPskServerCbObj) { - printf("Using SSL level PSK Server callback!!!\n"); usingSSLCallback = 1; } if (usingSSLCallback == 1) { /* WolfSSLSession level callback */ pskServerMethodId = (*jenv)->GetMethodID(jenv, sessClass, - "internalPskServerCallback", - "(Lcom/wolfssl/WolfSSLSession;Ljava/lang/String;" - "[BJ)J"); + "internalPskServerCallback", + "(Lcom/wolfssl/WolfSSLSession;Ljava/lang/String;[BJ)J"); } else { /* WolfSSLContext level callback */ pskServerMethodId = (*jenv)->GetMethodID(jenv, innerCtxClass, - "internalPskServerCallback", - "(Lcom/wolfssl/WolfSSLSession;Ljava/lang/String;" - "[BJ)J"); + "internalPskServerCallback", + "(Lcom/wolfssl/WolfSSLSession;Ljava/lang/String;[BJ)J"); } if (!pskServerMethodId) { @@ -6464,10 +6475,11 @@ unsigned int NativePskServerCb(WOLFSSL* ssl, const char* identity, (*jenv)->ExceptionClear(jenv); } (*jenv)->ThrowNew(jenv, excClass, - "Error getting internalPskServerCallback method from JNI"); + "Error getting internalPskServerCallback method from JNI"); (*jenv)->DeleteLocalRef(jenv, ctxRef); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } @@ -6479,10 +6491,11 @@ unsigned int NativePskServerCb(WOLFSSL* ssl, const char* identity, (*jenv)->ExceptionClear(jenv); } (*jenv)->ThrowNew(jenv, excClass, - "Error creating String for PSK client identity"); + "Error creating String for PSK client identity"); (*jenv)->DeleteLocalRef(jenv, ctxRef); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } @@ -6490,41 +6503,44 @@ unsigned int NativePskServerCb(WOLFSSL* ssl, const char* identity, keyArray = (*jenv)->NewByteArray(jenv, max_key_len); if (!keyArray) { (*jenv)->ThrowNew(jenv, excClass, - "Error creating jbyteArray for PSK server key"); + "Error creating jbyteArray for PSK server key"); (*jenv)->DeleteLocalRef(jenv, ctxRef); (*jenv)->DeleteLocalRef(jenv, identityString); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } /* call Java PSK server callback */ if (usingSSLCallback == 1) { retval = (*jenv)->CallLongMethod(jenv, (jobject)(*g_cachedSSLObj), - pskServerMethodId, (jobject)(*g_cachedSSLObj), identityString, - keyArray, (jlong)max_key_len); + pskServerMethodId, (jobject)(*g_cachedSSLObj), identityString, + keyArray, (jlong)max_key_len); if ((*jenv)->ExceptionOccurred(jenv)) { (*jenv)->ExceptionDescribe(jenv); (*jenv)->ExceptionClear(jenv); (*jenv)->DeleteLocalRef(jenv, ctxRef); (*jenv)->DeleteLocalRef(jenv, identityString); (*jenv)->DeleteLocalRef(jenv, keyArray); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } } else { retval = (*jenv)->CallLongMethod(jenv, ctxRef, pskServerMethodId, - (jobject)(*g_cachedSSLObj), identityString, - keyArray, (jlong)max_key_len); + (jobject)(*g_cachedSSLObj), identityString, + keyArray, (jlong)max_key_len); if ((*jenv)->ExceptionOccurred(jenv)) { (*jenv)->ExceptionDescribe(jenv); (*jenv)->ExceptionClear(jenv); (*jenv)->DeleteLocalRef(jenv, ctxRef); (*jenv)->DeleteLocalRef(jenv, identityString); (*jenv)->DeleteLocalRef(jenv, keyArray); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } } @@ -6539,8 +6555,9 @@ unsigned int NativePskServerCb(WOLFSSL* ssl, const char* identity, (*jenv)->DeleteLocalRef(jenv, ctxRef); (*jenv)->DeleteLocalRef(jenv, identityString); (*jenv)->DeleteLocalRef(jenv, keyArray); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return 0; } } @@ -6549,8 +6566,9 @@ unsigned int NativePskServerCb(WOLFSSL* ssl, const char* identity, (*jenv)->DeleteLocalRef(jenv, ctxRef); (*jenv)->DeleteLocalRef(jenv, identityString); (*jenv)->DeleteLocalRef(jenv, keyArray); - if (needsDetach) + if (needsDetach) { (*g_vm)->DetachCurrentThread(g_vm); + } return retval; } diff --git a/src/java/com/wolfssl/provider/jsse/WolfSSLEngineHelper.java b/src/java/com/wolfssl/provider/jsse/WolfSSLEngineHelper.java index cb35d970..dadceb7f 100644 --- a/src/java/com/wolfssl/provider/jsse/WolfSSLEngineHelper.java +++ b/src/java/com/wolfssl/provider/jsse/WolfSSLEngineHelper.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -48,6 +49,8 @@ import com.wolfssl.WolfSSLSession; import com.wolfssl.WolfSSLException; import com.wolfssl.WolfSSLJNIException; +import com.wolfssl.WolfSSLPskClientCallback; +import com.wolfssl.WolfSSLPskServerCallback; /** * This is a helper class to account for similar methods between SSLSocket @@ -934,6 +937,7 @@ private void setLocalServerNames() { boolean isHttpsConnection = this.clientMode && this.hostname != null && this.peerAddr != null && + this.params.getWolfSSLServerNames() == null && this.params.getServerNames() == null; /* Enable SNI if explicitly requested via property or if @@ -950,15 +954,27 @@ else if (this.clientMode) { () -> "jsse.enableSNIExtension property set to true, " + "enabling SNI"); - /* Explicitly set if user has set through SSLParameters */ - List names = this.params.getServerNames(); - if (names != null && names.size() > 0) { - /* Should only be one server name */ + /* Explicitly set if user has set through SSLParameters. Check + * wolfSSL-specific server names first, then fall back to standard + * SSLParameters server names (set via parent setServerNames()). */ + List names = + this.params.getWolfSSLServerNames(); + List parentNames = + this.params.getServerNames(); + if (names != null && !names.isEmpty()) { WolfSSLSNIServerName sni = names.get(0); if (sni != null) { this.ssl.useSNI((byte)sni.getType(), sni.getEncoded()); } - } else if (autoSNI) { + } else if (parentNames != null && !parentNames.isEmpty()) { + SNIServerName sni = parentNames.get(0); + if (sni != null) { + this.ssl.useSNI((byte)sni.getType(), sni.getEncoded()); + } + } + + if ((names == null || names.isEmpty()) && + (parentNames == null || parentNames.isEmpty()) && autoSNI) { if (this.peerAddr != null && this.jdkTlsTrustNameService) { WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO, () -> "setting SNI extension with " + @@ -967,6 +983,7 @@ else if (this.clientMode) { this.ssl.useSNI((byte)0, this.peerAddr.getHostName().getBytes()); + } else if (this.hostname != null) { if (peerAddr != null) { WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO, @@ -1252,6 +1269,70 @@ private void setLocalExtendedMasterSecret() { } } + private void setLocalPskSettings() throws SSLException { + + WolfSSLPskClientCallback clientCb = this.params.getPskClientCb(); + WolfSSLPskServerCallback serverCb = this.params.getPskServerCb(); + String identityHint = this.params.getPskIdentityHint(); + boolean keepArr = this.params.getKeepArrays(); + + try { + if (clientCb != null) { + if (this.clientMode) { + WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO, + () -> "setting PSK client callback " + + "from WolfSSLParameters"); + this.ssl.setPskClientCb(clientCb); + } else { + WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO, + () -> "ignoring PSK client callback " + + "in server mode"); + } + } + + if (serverCb != null) { + if (!this.clientMode) { + WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO, + () -> "setting PSK server callback " + + "from WolfSSLParameters"); + this.ssl.setPskServerCb(serverCb); + } else { + WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO, + () -> "ignoring PSK server callback " + + "in client mode"); + } + } + + if (identityHint != null) { + if (!this.clientMode) { + WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO, + () -> "setting PSK identity hint " + + "from WolfSSLParameters"); + int ret = this.ssl.usePskIdentityHint(identityHint); + if (ret != WolfSSL.SSL_SUCCESS) { + throw new SSLException( + "Error setting PSK identity hint, ret = " + ret); + } + } else { + WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO, + () -> "ignoring PSK identity hint in client mode"); + } + } + + if (keepArr) { + WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO, + () -> "enabling keepArrays from WolfSSLParameters"); + this.ssl.keepArrays(); + } + + } catch (IllegalStateException | WolfSSLJNIException e) { + SSLException sslEx = new SSLException( + "Error setting PSK parameters: " + e.getMessage()); + sslEx.initCause(e); + throw sslEx; + } + } + private void setLocalParams(SSLSocket socket, SSLEngine engine) throws SSLException { @@ -1269,6 +1350,7 @@ private void setLocalParams(SSLSocket socket, SSLEngine engine) this.setLocalSupportedCurves(); this.setLocalMaximumPacketSize(); this.setLocalExtendedMasterSecret(); + this.setLocalPskSettings(); } /** @@ -1515,8 +1597,7 @@ else if (peerAddr != null) { this.session.updateStoredSessionValues(); if (!this.clientMode && !matchSNI()) { - throw new SSLHandshakeException( - "Unrecognized Server Name"); + throw new SSLHandshakeException("Unrecognized Server Name"); } return ret; @@ -1592,11 +1673,11 @@ private boolean isLegacyDHEnabled() { * @return true on success or false if no match was found */ protected synchronized boolean matchSNI(){ - List matchers = this.params.getSNIMatchers(); + Collection matchers = this.params.getSNIMatchers(); if (matchers != null && !matchers.isEmpty()) { /* Match a server name to SNI requested by Client */ - List serverNames = this.session - .getRequestedServerNames(); + List serverNames = + this.session.getRequestedServerNames(); if (serverNames != null && !serverNames.isEmpty()) { for (SNIServerName serverName : serverNames) { if (serverName.getType() == WolfSSL.WOLFSSL_SNI_HOST_NAME) { diff --git a/src/java/com/wolfssl/provider/jsse/WolfSSLParameters.java b/src/java/com/wolfssl/provider/jsse/WolfSSLParameters.java index d6621c1e..1141a995 100644 --- a/src/java/com/wolfssl/provider/jsse/WolfSSLParameters.java +++ b/src/java/com/wolfssl/provider/jsse/WolfSSLParameters.java @@ -22,39 +22,72 @@ import java.util.List; import javax.net.ssl.SNIMatcher; +import javax.net.ssl.SSLParameters; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import com.wolfssl.WolfSSLPskClientCallback; +import com.wolfssl.WolfSSLPskServerCallback; + /** - * wolfJSSE implementation of SSLParameters + * wolfJSSE implementation of SSLParameters. + * + * Extends {@link javax.net.ssl.SSLParameters} so that instances can be + * passed directly to {@code SSLEngine.setSSLParameters()} and + * {@code SSLSocket.setSSLParameters()}. * * This class includes the functionality of java SSLParameters, but allows * wolfJSSE better control over settings, interop with older Java versions, * etc. Strings set and returned should be cloned. * - * This class is used internally to wolfJSSE. When a SSLParameters needs to - * be returned to an application (ex: SSLContext.getDefaultSSLParameters(), - * SSLContext.getSupportedSSLParameters()) wolfJSSE calls - * WolfSSLEngineHelper.decoupleParams() which creates a SSLParameters object - * from a WolfSSLParameters. + * In addition to the standard SSLParameters settings, this class exposes + * wolfSSL-specific options such as PSK callbacks, PSK identity hint, and + * keepArrays. Applications can configure these fields and then call + * {@code engine.setSSLParameters(wolfParams)} to apply them before the + * TLS handshake. + * + * @author wolfSSL */ -final class WolfSSLParameters { +public class WolfSSLParameters extends SSLParameters { private String[] cipherSuites; private String[] protocols; private boolean wantClientAuth = false; private boolean needClientAuth = false; private String endpointIdAlgorithm = null; - private List serverNames; - private List sniMatchers; - private boolean useCipherSuiteOrder = true; + private List wolfSSLServerNames; String[] applicationProtocols = new String[0]; private boolean useSessionTickets = false; private byte[] alpnProtocols = null; /* Default to 0, means use implicit implementation size */ private int maxPacketSize = 0; + /* PSK callbacks and settings, set via public API */ + private WolfSSLPskClientCallback pskClientCb = null; + private WolfSSLPskServerCallback pskServerCb = null; + private String pskIdentityHint = null; + private boolean keepArrays = false; + + /* Local storage for cipher-suite-order preference, to support + * runtimes where SSLParameters.get/setUseCipherSuitesOrder() is + * unavailable. */ + private Boolean useCipherSuitesOrder = null; + + /** Default WolfSSLParameters constructor */ + @SuppressWarnings("this-escape") + public WolfSSLParameters() { + super(); + /* wolfJSSE defaults to honoring server cipher order */ + this.useCipherSuitesOrder = Boolean.TRUE; + try { + super.setUseCipherSuitesOrder(true); + } catch (NoSuchMethodError e) { + /* Older runtimes may not have this method, + * state kept in local useCipherSuitesOrder field */ + } + } + /* create duplicate copy of these parameters */ protected synchronized WolfSSLParameters copy() { WolfSSLParameters cp = new WolfSSLParameters(); @@ -62,31 +95,47 @@ protected synchronized WolfSSLParameters copy() { cp.setProtocols(this.protocols); cp.wantClientAuth = this.wantClientAuth; cp.needClientAuth = this.needClientAuth; - cp.setServerNames(this.getServerNames()); + cp.setWolfSSLServerNames(this.getWolfSSLServerNames()); cp.useSessionTickets = this.useSessionTickets; cp.endpointIdAlgorithm = this.endpointIdAlgorithm; cp.setApplicationProtocols(this.applicationProtocols); - cp.useCipherSuiteOrder = this.useCipherSuiteOrder; + try { + cp.setUseCipherSuitesOrder(this.getUseCipherSuitesOrder()); + } catch (NoSuchMethodError e) { + /* Fall back to local field copy for older runtimes where parent + * final methods are absent */ + cp.useCipherSuitesOrder = this.useCipherSuitesOrder; + } cp.maxPacketSize = this.maxPacketSize; + cp.pskClientCb = this.pskClientCb; + cp.pskServerCb = this.pskServerCb; + cp.pskIdentityHint = this.pskIdentityHint; + cp.keepArrays = this.keepArrays; if (alpnProtocols != null && alpnProtocols.length != 0) { cp.setAlpnProtocols(this.alpnProtocols); } - /* TODO: duplicate other properties here when WolfSSLParameters - * can handle them */ + /* Copy SNI matchers and server names using parent final methods. + * Server names set via the standard SSLParameters.setServerNames() + * are stored in the parent and must be copied separately from + * wolfSSLServerNames. */ cp.setSNIMatchers(this.getSNIMatchers()); + cp.setServerNames(this.getServerNames()); + return cp; } - String[] getCipherSuites() { + @Override + public String[] getCipherSuites() { if (this.cipherSuites == null) { return null; } return this.cipherSuites.clone(); } - void setCipherSuites(String[] cipherSuites) { + @Override + public void setCipherSuites(String[] cipherSuites) { /* cipherSuites array is sanitized by wolfJSSE caller */ if (cipherSuites == null) { this.cipherSuites = null; @@ -96,14 +145,16 @@ void setCipherSuites(String[] cipherSuites) { } } - synchronized String[] getProtocols() { + @Override + public synchronized String[] getProtocols() { if (this.protocols == null) { return null; } return this.protocols.clone(); } - synchronized void setProtocols(String[] protocols) { + @Override + public synchronized void setProtocols(String[] protocols) { /* protocols array is sanitized by wolfJSSE caller */ if (protocols == null) { this.protocols = null; @@ -113,11 +164,13 @@ synchronized void setProtocols(String[] protocols) { } } - boolean getWantClientAuth() { + @Override + public boolean getWantClientAuth() { return this.wantClientAuth; } - void setWantClientAuth(boolean wantClientAuth) { + @Override + public void setWantClientAuth(boolean wantClientAuth) { /* wantClientAuth OR needClientAuth can be set true, not both */ this.wantClientAuth = wantClientAuth; if (this.wantClientAuth) { @@ -125,11 +178,13 @@ void setWantClientAuth(boolean wantClientAuth) { } } - boolean getNeedClientAuth() { + @Override + public boolean getNeedClientAuth() { return this.needClientAuth; } - void setNeedClientAuth(boolean needClientAuth) { + @Override + public void setNeedClientAuth(boolean needClientAuth) { /* wantClientAuth OR needClientAuth can be set true, not both */ this.needClientAuth = needClientAuth; if (this.needClientAuth) { @@ -137,33 +192,49 @@ void setNeedClientAuth(boolean needClientAuth) { } } - String getEndpointIdentificationAlgorithm() { + @Override + public String getEndpointIdentificationAlgorithm() { return this.endpointIdAlgorithm; } - void setEndpointIdentificationAlgorithm(String algorithm) { + @Override + public void setEndpointIdentificationAlgorithm(String algorithm) { this.endpointIdAlgorithm = algorithm; } - void setServerNames(List serverNames) { + /** + * Set wolfSSL SNI server names. + * Uses WolfSSLSNIServerName type to maintain compatibility with older Java + * versions. This is separate from the parent SSLParameters setServerNames() + * which uses SNIServerName. + * + * @param serverNames list of WolfSSLSNIServerName to set, or null to clear + */ + public void setWolfSSLServerNames(List serverNames) { if (serverNames == null) { - this.serverNames = null; + this.wolfSSLServerNames = null; } else { - this.serverNames = Collections.unmodifiableList( - new ArrayList(serverNames)); + this.wolfSSLServerNames = Collections.unmodifiableList( + new ArrayList(serverNames)); } } - List getServerNames() { - if (this.serverNames == null) { + /** + * Get wolfSSL SNI server names. + * Returns WolfSSLSNIServerName type for internal use. + * + * @return list of WolfSSLSNIServerName, or null if not set + */ + public List getWolfSSLServerNames() { + if (this.wolfSSLServerNames == null) { return null; } else { return Collections.unmodifiableList( - new ArrayList(this.serverNames)); + new ArrayList(this.wolfSSLServerNames)); } } - /* not part of Java SSLParameters. Needed here for Android compatibility */ + /* Not part of Java SSLParameters. Needed here for Android compatibility */ void setUseSessionTickets(boolean useTickets) { this.useSessionTickets = useTickets; } @@ -173,7 +244,6 @@ boolean getUseSessionTickets() { } void setAlpnProtocols(byte[] alpnProtos) { - if (alpnProtos == null || alpnProtos.length == 0) { throw new IllegalArgumentException( "ALPN protocol array null or zero length"); @@ -186,45 +256,29 @@ byte[] getAlpnProtos() { return this.alpnProtocols; } - /* TODO, create our own class for SNIMatcher, in case Java doesn't support it */ - void setSNIMatchers(Collection matchers) { - if (matchers != null && !matchers.isEmpty()) { - if (this.sniMatchers == null) { - this.sniMatchers = new ArrayList(); - } - for (SNIMatcher matcher : matchers) { - this.sniMatchers.add(matcher); - } - } else { - this.sniMatchers = new ArrayList(); - } - } - - /* TODO, create our own class for SNIMatcher, in case Java doesn't support it */ - List getSNIMatchers() { - if (this.sniMatchers != null && !this.sniMatchers.isEmpty()) { - return Collections.unmodifiableList(new ArrayList(sniMatchers)); - } else { - return Collections.emptyList(); - } - } - - void setUseCipherSuitesOrder(boolean honorOrder) { - this.useCipherSuiteOrder = honorOrder; - } - - boolean getUseCipherSuitesOrder() { - return this.useCipherSuiteOrder; - } - - String[] getApplicationProtocols() { + /* + * SSLParameters.setSNIMatchers() and getSNIMatchers() are final. This + * class delegates to the parent for SNI matcher storage and does not + * maintain its own field. + * + * SSLParameters.setServerNames() and getServerNames() are also final, + * so wolfSSL-specific server names use + * setWolfSSLServerNames()/getWolfSSLServerNames() instead. + * + * SSLParameters.setUseCipherSuitesOrder() and getUseCipherSuitesOrder() + * are also final. The local useCipherSuitesOrder field provides a backup + * copy so copy() can transfer the value without calling the parent + * methods, which may not exist on older runtimes. + */ + + public String[] getApplicationProtocols() { if (this.applicationProtocols == null) { return null; } return this.applicationProtocols.clone(); } - void setApplicationProtocols(String[] protocols) { + public void setApplicationProtocols(String[] protocols) { if (protocols == null) { this.applicationProtocols = new String[0]; } @@ -233,12 +287,87 @@ void setApplicationProtocols(String[] protocols) { } } - int getMaximumPacketSize() { + public int getMaximumPacketSize() { return this.maxPacketSize; } - void setMaximumPacketSize(int maximumPacketSize) { + public void setMaximumPacketSize(int maximumPacketSize) { this.maxPacketSize = maximumPacketSize; } -} + /** + * Set the PSK client callback to be used for this connection. + * + * @param callback PSK client callback implementation, or null to clear + */ + public void setPskClientCb(WolfSSLPskClientCallback callback) { + this.pskClientCb = callback; + } + + /** + * Get the PSK client callback set for this connection. + * + * @return PSK client callback, or null if not set + */ + public WolfSSLPskClientCallback getPskClientCb() { + return this.pskClientCb; + } + + /** + * Set the PSK server callback to be used for this connection. + * + * @param callback PSK server callback implementation, or null to clear + */ + public void setPskServerCb(WolfSSLPskServerCallback callback) { + this.pskServerCb = callback; + } + + /** + * Get the PSK server callback set for this connection. + * + * @return PSK server callback, or null if not set + */ + public WolfSSLPskServerCallback getPskServerCb() { + return this.pskServerCb; + } + + /** + * Set the PSK identity hint for this connection. + * + * @param hint PSK identity hint string, or null to clear + */ + public void setPskIdentityHint(String hint) { + this.pskIdentityHint = hint; + } + + /** + * Get the PSK identity hint set for this connection. + * + * @return PSK identity hint string, or null if not set + */ + public String getPskIdentityHint() { + return this.pskIdentityHint; + } + + /** + * Set whether to keep handshake arrays after handshake completion. + *

+ * When enabled, wolfSSL will retain internal arrays after the handshake, + * which is needed for some PSK use cases where session data must be + * accessed after handshake completion. + * + * @param keep true to keep arrays, false otherwise + */ + public void setKeepArrays(boolean keep) { + this.keepArrays = keep; + } + + /** + * Get whether keepArrays is enabled. + * + * @return true if keepArrays is enabled, false otherwise + */ + public boolean getKeepArrays() { + return this.keepArrays; + } +} diff --git a/src/java/com/wolfssl/provider/jsse/WolfSSLParametersHelper.java b/src/java/com/wolfssl/provider/jsse/WolfSSLParametersHelper.java index 7abb9d7b..e8103aa5 100644 --- a/src/java/com/wolfssl/provider/jsse/WolfSSLParametersHelper.java +++ b/src/java/com/wolfssl/provider/jsse/WolfSSLParametersHelper.java @@ -337,6 +337,44 @@ protected static void importParams(SSLParameters in, /* Not available, just ignore and continue */ } + /* If input is a WolfSSLParameters, copy wolfJSSE specific fields + * that are not part of the standard SSLParameters API */ + if (in instanceof WolfSSLParameters) { + WolfSSLParameters wolfIn = (WolfSSLParameters)in; + out.setPskClientCb(wolfIn.getPskClientCb()); + out.setPskServerCb(wolfIn.getPskServerCb()); + out.setPskIdentityHint(wolfIn.getPskIdentityHint()); + out.setKeepArrays(wolfIn.getKeepArrays()); + out.setWolfSSLServerNames(wolfIn.getWolfSSLServerNames()); + } else { + /* Clear any existing PSK-related configuration to avoid + * leakage across importParams() calls when input is not + * WolfSSLParameters. */ + out.setPskClientCb(null); + out.setPskServerCb(null); + out.setPskIdentityHint(null); + out.setKeepArrays(false); + + /* Clear wolfSSL-specific SNI names only if input has no standard + * server names. JDK8Helper.getServerNames() only sets + * wolfSSLServerNames when in.getServerNames() is non-null, so + * stale values could persist otherwise. */ + boolean hasStdSni = false; + try { + if (getServerNames != null) { + Object sni = getServerNames.invoke(in); + if (sni != null) { + hasStdSni = true; + } + } + } catch (IllegalAccessException | + InvocationTargetException e) { + /* Not available, assume no standard SNI */ + } + if (!hasStdSni) { + out.setWolfSSLServerNames(null); + } + } } } diff --git a/src/java/com/wolfssl/provider/jsse/adapter/WolfSSLJDK8Helper.java b/src/java/com/wolfssl/provider/jsse/adapter/WolfSSLJDK8Helper.java index ddbb04f8..062a312c 100644 --- a/src/java/com/wolfssl/provider/jsse/adapter/WolfSSLJDK8Helper.java +++ b/src/java/com/wolfssl/provider/jsse/adapter/WolfSSLJDK8Helper.java @@ -60,7 +60,7 @@ protected static void setServerNames(final SSLParameters out, "WolfSSLJDK8Helper.setServerNames() cannot be null"); } - List wsni = in.getServerNames(); + List wsni = in.getWolfSSLServerNames(); if (wsni != null) { /* convert WolfSSLSNIServerName list to SNIServerName */ final ArrayList sni = @@ -109,8 +109,8 @@ protected static void getServerNames(final SSLParameters in, name.getEncoded())); } - /* call WolfSSLParameters.setServerNames() */ - out.setServerNames(wsni); + /* call WolfSSLParameters.setWolfSSLServerNames() */ + out.setWolfSSLServerNames(wsni); } } diff --git a/src/test/com/wolfssl/provider/jsse/test/WolfSSLJSSETestSuite.java b/src/test/com/wolfssl/provider/jsse/test/WolfSSLJSSETestSuite.java index 2826862f..9acb931a 100644 --- a/src/test/com/wolfssl/provider/jsse/test/WolfSSLJSSETestSuite.java +++ b/src/test/com/wolfssl/provider/jsse/test/WolfSSLJSSETestSuite.java @@ -38,7 +38,8 @@ WolfSSLSessionContextTest.class, WolfSSLX509Test.class, WolfSSLKeyX509Test.class, - WolfSSLServiceLoaderTest.class + WolfSSLServiceLoaderTest.class, + WolfSSLParametersPskTest.class }) diff --git a/src/test/com/wolfssl/provider/jsse/test/WolfSSLParametersPskTest.java b/src/test/com/wolfssl/provider/jsse/test/WolfSSLParametersPskTest.java new file mode 100644 index 00000000..9d7d1775 --- /dev/null +++ b/src/test/com/wolfssl/provider/jsse/test/WolfSSLParametersPskTest.java @@ -0,0 +1,690 @@ +/* WolfSSLParametersPskTest.java + * + * Copyright (C) 2006-2026 wolfSSL Inc. + * + * This file is part of wolfSSL. + * + * wolfSSL is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * wolfSSL is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1335, USA + */ + +package com.wolfssl.provider.jsse.test; + +import com.wolfssl.WolfSSL; +import com.wolfssl.WolfSSLException; +import com.wolfssl.WolfSSLSession; +import com.wolfssl.WolfSSLPskClientCallback; +import com.wolfssl.WolfSSLPskServerCallback; +import com.wolfssl.provider.jsse.WolfSSLProvider; +import com.wolfssl.provider.jsse.WolfSSLParameters; + +import java.io.InputStream; +import java.io.OutputStream; +import java.net.InetAddress; + +import java.nio.ByteBuffer; +import java.security.Provider; +import java.security.Security; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLServerSocket; +import javax.net.ssl.SSLSocket; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; + +/** + * Tests for PSK support via WolfSSLParameters. + * + * @author wolfSSL + */ +public class WolfSSLParametersPskTest { + + private static final String engineProvider = "wolfJSSE"; + private static String pskCipher = null; + + @Rule + public Timeout globalTimeout = + new Timeout(60, TimeUnit.SECONDS); + + private static WolfSSLPskClientCallback testClientCb = + new WolfSSLPskClientCallback() { + public long pskClientCallback(WolfSSLSession ssl, String hint, + StringBuffer identity, long idMaxLen, byte[] key, + long keyMaxLen) { + + String id = "Client_identity"; + if (id.length() > idMaxLen || + keyMaxLen < 4) { + return 0; + } + identity.append(id); + key[0] = 26; + key[1] = 43; + key[2] = 60; + key[3] = 77; + return 4; + } + }; + + private static WolfSSLPskServerCallback testServerCb = + new WolfSSLPskServerCallback() { + public long pskServerCallback(WolfSSLSession ssl, String identity, + byte[] key, long keyMaxLen) { + + if (!"Client_identity".equals(identity)) { + return 0; + } + if (keyMaxLen < 4) { + return 0; + } + key[0] = 26; + key[1] = 43; + key[2] = 60; + key[3] = 77; + return 4; + } + }; + + @BeforeClass + public static void testSetup() throws WolfSSLException { + + System.out.println("WolfSSLParametersPskTest Class"); + + /* Install wolfJSSE provider */ + Security.insertProviderAt(new WolfSSLProvider(), 1); + Provider p = Security.getProvider(engineProvider); + assertNotNull(p); + + /* Skip all tests if PSK not compiled in */ + Assume.assumeTrue("PSK not enabled, skipping PSK tests", + WolfSSL.isEnabledPSK() == 1); + + /* Find an available PSK cipher suite, preferring ephemeral + * (ECDHE > DHE) over static PSK */ + try { + String ecdhe = null; + String dhe = null; + String plain = null; + String[] suites = null; + + SSLContext ctx = SSLContext.getInstance("TLSv1.2", engineProvider); + ctx.init(null, null, null); + suites = ctx.createSSLEngine().getSupportedCipherSuites(); + + for (String s : suites) { + if (s.startsWith("TLS_ECDHE_PSK_WITH_")) { + if (ecdhe == null || s.contains("GCM")) { + ecdhe = s; + } + } + else if (s.startsWith("TLS_DHE_PSK_WITH_")) { + if (dhe == null || s.contains("GCM")) { + dhe = s; + } + } + else if (s.startsWith("TLS_PSK_WITH_")) { + if (plain == null) { + plain = s; + } + } + } + if (ecdhe != null) { + pskCipher = ecdhe; + } + else if (dhe != null) { + pskCipher = dhe; + } + else { + pskCipher = plain; + } + + } catch (Exception e) { + /* ignore */ + } + + Assume.assumeTrue("No PSK cipher suite available", pskCipher != null); + } + + @Test + public void testExtendsSSLParameters() { + System.out.print("\tExtends SSLParameters\t\t... "); + + WolfSSLParameters wp = new WolfSSLParameters(); + assertTrue(wp instanceof SSLParameters); + + System.out.println("passed"); + } + + @Test + public void testGetSetPskClientCb() { + System.out.print("\tGet/Set PSK client cb\t\t... "); + + WolfSSLParameters wp = new WolfSSLParameters(); + assertNull(wp.getPskClientCb()); + + wp.setPskClientCb(testClientCb); + assertEquals(testClientCb, wp.getPskClientCb()); + + wp.setPskClientCb(null); + assertNull(wp.getPskClientCb()); + + System.out.println("passed"); + } + + @Test + public void testGetSetPskServerCb() { + System.out.print("\tGet/Set PSK server cb\t\t... "); + + WolfSSLParameters wp = new WolfSSLParameters(); + assertNull(wp.getPskServerCb()); + + wp.setPskServerCb(testServerCb); + assertEquals(testServerCb, wp.getPskServerCb()); + + wp.setPskServerCb(null); + assertNull(wp.getPskServerCb()); + + System.out.println("passed"); + } + + @Test + public void testGetSetPskIdentityHint() { + System.out.print("\tGet/Set PSK identity hint\t... "); + + WolfSSLParameters wp = new WolfSSLParameters(); + assertNull(wp.getPskIdentityHint()); + + wp.setPskIdentityHint("test_hint"); + assertEquals("test_hint", wp.getPskIdentityHint()); + + wp.setPskIdentityHint(null); + assertNull(wp.getPskIdentityHint()); + + System.out.println("passed"); + } + + @Test + public void testGetSetKeepArrays() { + System.out.print("\tGet/Set keepArrays\t\t... "); + + WolfSSLParameters wp = new WolfSSLParameters(); + assertFalse(wp.getKeepArrays()); + + wp.setKeepArrays(true); + assertTrue(wp.getKeepArrays()); + + wp.setKeepArrays(false); + assertFalse(wp.getKeepArrays()); + + System.out.println("passed"); + } + + @Test + public void testUseCipherSuitesOrderDefault() { + System.out.print("\tCipher suite order default\t... "); + + WolfSSLParameters wp = new WolfSSLParameters(); + assertTrue("useCipherSuitesOrder should default to true", + wp.getUseCipherSuitesOrder()); + + System.out.println("passed"); + } + + @Test + public void testUseCipherSuitesOrderGetSet() { + System.out.print("\tGet/Set useCipherSuitesOrder\t... "); + + WolfSSLParameters wp = new WolfSSLParameters(); + + wp.setUseCipherSuitesOrder(false); + assertFalse(wp.getUseCipherSuitesOrder()); + + wp.setUseCipherSuitesOrder(true); + assertTrue(wp.getUseCipherSuitesOrder()); + + System.out.println("passed"); + } + + @Test + public void testPskFieldsNotLeakedViaGetSSLParameters() + throws Exception { + + System.out.print("\tPSK fields not in getSSLParams\t... "); + + SSLContext ctx = SSLContext.getInstance("TLSv1.2", engineProvider); + ctx.init(null, null, null); + SSLEngine engine = ctx.createSSLEngine(); + + /* getSSLParameters() returns a standard SSLParameters, PSK fields + * should not leak through */ + SSLParameters sp = engine.getSSLParameters(); + assertNotNull(sp); + if (sp instanceof WolfSSLParameters) { + WolfSSLParameters wp = (WolfSSLParameters)sp; + assertNull(wp.getPskClientCb()); + assertNull(wp.getPskServerCb()); + assertNull(wp.getPskIdentityHint()); + assertFalse(wp.getKeepArrays()); + } + + System.out.println("passed"); + } + + @Test + public void testPskClearedOnPlainSSLParamsImport() + throws Exception { + + System.out.print("\tPSK cleared by plain SSLParams\t... "); + + SSLContext ctx = SSLContext.getInstance("TLSv1.2", engineProvider); + ctx.init(null, null, null); + + SSLEngine serverEngine = ctx.createSSLEngine(); + serverEngine.setUseClientMode(false); + + /* Set WolfSSLParameters with PSK callbacks on server */ + WolfSSLParameters serverParams = new WolfSSLParameters(); + serverParams.setPskServerCb(testServerCb); + serverParams.setPskIdentityHint("wolfssl hint"); + serverParams.setCipherSuites(new String[] {pskCipher}); + serverEngine.setSSLParameters(serverParams); + + /* Overwrite with plain SSLParameters, should clear PSK + * callbacks but keep PSK cipher suite to force PSK path */ + SSLParameters plainParams = new SSLParameters(); + plainParams.setCipherSuites(new String[] {pskCipher}); + serverEngine.setSSLParameters(plainParams); + + SSLEngine clientEngine = ctx.createSSLEngine("localhost", 0); + clientEngine.setUseClientMode(true); + + WolfSSLParameters clientParams = new WolfSSLParameters(); + clientParams.setPskClientCb(testClientCb); + clientParams.setCipherSuites(new String[] {pskCipher}); + clientEngine.setSSLParameters(clientParams); + + /* Handshake should fail because server PSK callback was + * cleared by the plain SSLParameters import */ + boolean handshakeSucceeded = false; + try { + doInMemoryHandshake(clientEngine, serverEngine); + handshakeSucceeded = true; + } catch (Exception e) { + /* Expected: handshake fails without PSK callback */ + } catch (AssertionError e) { + /* Expected: doInMemoryHandshake loop exhausted */ + } + + assertFalse("PSK handshake should fail after PSK callback cleared by " + + "plain SSLParameters import", handshakeSucceeded); + + clientEngine.closeOutbound(); + serverEngine.closeOutbound(); + + System.out.println("passed"); + } + + @Test + public void testPskEngineHandshake() throws Exception { + + System.out.print("\tPSK SSLEngine handshake\t\t... "); + + SSLContext ctx = SSLContext.getInstance("TLSv1.2", engineProvider); + ctx.init(null, null, null); + + SSLEngine serverEngine = ctx.createSSLEngine(); + serverEngine.setUseClientMode(false); + + WolfSSLParameters serverParams = new WolfSSLParameters(); + serverParams.setPskServerCb(testServerCb); + serverParams.setPskIdentityHint("wolfssl hint"); + serverParams.setCipherSuites( + new String[]{pskCipher}); + serverEngine.setSSLParameters(serverParams); + + SSLEngine clientEngine = ctx.createSSLEngine("localhost", 0); + clientEngine.setUseClientMode(true); + + WolfSSLParameters clientParams = new WolfSSLParameters(); + clientParams.setPskClientCb(testClientCb); + clientParams.setCipherSuites( + new String[]{pskCipher}); + clientEngine.setSSLParameters(clientParams); + + /* Do in-memory handshake */ + doInMemoryHandshake(clientEngine, serverEngine); + + /* Verify handshake completed */ + assertEquals(HandshakeStatus.NOT_HANDSHAKING, + clientEngine.getHandshakeStatus()); + assertEquals(HandshakeStatus.NOT_HANDSHAKING, + serverEngine.getHandshakeStatus()); + + clientEngine.closeOutbound(); + serverEngine.closeOutbound(); + + System.out.println("passed"); + } + + @Test + public void testPskEngineKeepArrays() throws Exception { + + System.out.print("\tPSK SSLEngine keepArrays\t... "); + + SSLContext ctx = SSLContext.getInstance("TLSv1.2", engineProvider); + ctx.init(null, null, null); + + SSLEngine serverEngine = ctx.createSSLEngine(); + serverEngine.setUseClientMode(false); + + WolfSSLParameters serverParams = new WolfSSLParameters(); + serverParams.setPskServerCb(testServerCb); + serverParams.setPskIdentityHint("wolfssl hint"); + serverParams.setKeepArrays(true); + serverParams.setCipherSuites( + new String[]{pskCipher}); + serverEngine.setSSLParameters(serverParams); + + SSLEngine clientEngine = ctx.createSSLEngine("localhost", 0); + clientEngine.setUseClientMode(true); + + WolfSSLParameters clientParams = new WolfSSLParameters(); + clientParams.setPskClientCb(testClientCb); + clientParams.setKeepArrays(true); + clientParams.setCipherSuites( + new String[]{pskCipher}); + clientEngine.setSSLParameters(clientParams); + + doInMemoryHandshake(clientEngine, serverEngine); + + assertEquals(HandshakeStatus.NOT_HANDSHAKING, + clientEngine.getHandshakeStatus()); + + clientEngine.closeOutbound(); + serverEngine.closeOutbound(); + + System.out.println("passed"); + } + + @Test + public void testPskSocketHandshake() throws Exception { + + System.out.print("\tPSK SSLSocket handshake\t\t... "); + + SSLContext ctx = SSLContext.getInstance("TLSv1.2", engineProvider); + ctx.init(null, null, null); + + SSLServerSocket ss = (SSLServerSocket)ctx.getServerSocketFactory() + .createServerSocket(0); + int port = ss.getLocalPort(); + + final CountDownLatch latch = new CountDownLatch(1); + final Exception[] serverEx = new Exception[1]; + + /* Server thread */ + Thread serverThread = new Thread(() -> { + try { + SSLSocket serverSock = (SSLSocket)ss.accept(); + + WolfSSLParameters sp = new WolfSSLParameters(); + sp.setPskServerCb(testServerCb); + sp.setPskIdentityHint("wolfssl hint"); + sp.setCipherSuites( + new String[]{pskCipher}); + serverSock.setSSLParameters(sp); + + serverSock.startHandshake(); + + /* Read/write to confirm connection */ + InputStream in = serverSock.getInputStream(); + OutputStream out = serverSock.getOutputStream(); + byte[] buf = new byte[64]; + int n = in.read(buf); + out.write(buf, 0, n); + + serverSock.close(); + + } catch (Exception e) { + serverEx[0] = e; + } finally { + latch.countDown(); + } + }); + serverThread.start(); + + try { + SSLSocket clientSock = (SSLSocket) + ctx.getSocketFactory().createSocket( + InetAddress.getLoopbackAddress(), port); + + WolfSSLParameters cp = new WolfSSLParameters(); + cp.setPskClientCb(testClientCb); + cp.setCipherSuites( + new String[]{pskCipher}); + clientSock.setSSLParameters(cp); + + clientSock.startHandshake(); + + /* Send data and read echo */ + OutputStream out = clientSock.getOutputStream(); + InputStream in = clientSock.getInputStream(); + out.write("hello".getBytes()); + byte[] buf = new byte[64]; + int n = in.read(buf); + assertEquals("hello", new String(buf, 0, n)); + + clientSock.close(); + + } finally { + ss.close(); + } + + assertTrue("Server thread timed out", + latch.await(10, TimeUnit.SECONDS)); + + if (serverEx[0] != null) { + fail("Server thread failed: " + + serverEx[0].getMessage()); + } + + System.out.println("passed"); + } + + @Test + public void testPskSocketKeepArrays() throws Exception { + + System.out.print("\tPSK SSLSocket keepArrays\t... "); + + SSLContext ctx = SSLContext.getInstance("TLSv1.2", engineProvider); + ctx.init(null, null, null); + + SSLServerSocket ss = (SSLServerSocket)ctx.getServerSocketFactory() + .createServerSocket(0); + int port = ss.getLocalPort(); + + final CountDownLatch latch = new CountDownLatch(1); + final Exception[] serverEx = new Exception[1]; + + Thread serverThread = new Thread(() -> { + try { + SSLSocket serverSock = (SSLSocket)ss.accept(); + + WolfSSLParameters sp = new WolfSSLParameters(); + sp.setPskServerCb(testServerCb); + sp.setPskIdentityHint("wolfssl hint"); + sp.setKeepArrays(true); + sp.setCipherSuites( + new String[]{pskCipher}); + serverSock.setSSLParameters(sp); + + serverSock.startHandshake(); + + InputStream in = serverSock.getInputStream(); + OutputStream out = serverSock.getOutputStream(); + byte[] buf = new byte[64]; + int n = in.read(buf); + out.write(buf, 0, n); + + serverSock.close(); + + } catch (Exception e) { + serverEx[0] = e; + } finally { + latch.countDown(); + } + }); + serverThread.start(); + + try { + SSLSocket clientSock = (SSLSocket) + ctx.getSocketFactory().createSocket( + InetAddress.getLoopbackAddress(), port); + + WolfSSLParameters cp = new WolfSSLParameters(); + cp.setPskClientCb(testClientCb); + cp.setKeepArrays(true); + cp.setCipherSuites( + new String[]{pskCipher}); + clientSock.setSSLParameters(cp); + + clientSock.startHandshake(); + + OutputStream out = clientSock.getOutputStream(); + InputStream in = clientSock.getInputStream(); + out.write("hello".getBytes()); + byte[] buf = new byte[64]; + int n = in.read(buf); + assertEquals("hello", new String(buf, 0, n)); + + clientSock.close(); + + } finally { + ss.close(); + } + + assertTrue("Server thread timed out", + latch.await(10, TimeUnit.SECONDS)); + + if (serverEx[0] != null) { + fail("Server thread failed: " + serverEx[0].getMessage()); + } + + System.out.println("passed"); + } + + /** + * Perform SSLEngine handshake using in-memory buffers (no sockets). Data + * produced by one engine is fed directly to the other. + */ + private void doInMemoryHandshake(SSLEngine client, SSLEngine server) + throws Exception { + + int netSize = Math.max(client.getSession().getPacketBufferSize(), + server.getSession().getPacketBufferSize()); + int appSize = Math.max( + client.getSession().getApplicationBufferSize(), + server.getSession().getApplicationBufferSize()); + + /* Network buffers: client writes to cToS, server reads from cToS. + * Server writes to sToC, client reads from sToC. */ + ByteBuffer cToS = ByteBuffer.allocate(netSize); + ByteBuffer sToC = ByteBuffer.allocate(netSize); + ByteBuffer clientApp = ByteBuffer.allocate(appSize); + ByteBuffer serverApp = ByteBuffer.allocate(appSize); + ByteBuffer emptyApp = ByteBuffer.allocate(0); + + client.beginHandshake(); + server.beginHandshake(); + + HandshakeStatus chs = client.getHandshakeStatus(); + HandshakeStatus shs = server.getHandshakeStatus(); + + int maxLoops = 200; + int loops = 0; + + while (loops < maxLoops) { + boolean cDone = + (chs == HandshakeStatus.NOT_HANDSHAKING || + chs == HandshakeStatus.FINISHED); + boolean sDone = + (shs == HandshakeStatus.NOT_HANDSHAKING || + shs == HandshakeStatus.FINISHED); + if (cDone && sDone) { + break; + } + + /* Process client side */ + if (chs == HandshakeStatus.NEED_WRAP) { + SSLEngineResult res = client.wrap(emptyApp, cToS); + chs = res.getHandshakeStatus(); + } + else if (chs == HandshakeStatus.NEED_UNWRAP) { + sToC.flip(); + SSLEngineResult res = client.unwrap(sToC, clientApp); + sToC.compact(); + chs = res.getHandshakeStatus(); + } + else if (chs == HandshakeStatus.NEED_TASK) { + Runnable task; + while ((task = + client.getDelegatedTask()) != null) { + task.run(); + } + chs = client.getHandshakeStatus(); + } + + /* Process server side */ + if (shs == HandshakeStatus.NEED_WRAP) { + SSLEngineResult res = server.wrap(emptyApp, sToC); + shs = res.getHandshakeStatus(); + } + else if (shs == HandshakeStatus.NEED_UNWRAP) { + cToS.flip(); + SSLEngineResult res = server.unwrap(cToS, serverApp); + cToS.compact(); + shs = res.getHandshakeStatus(); + } + else if (shs == HandshakeStatus.NEED_TASK) { + Runnable task; + while ((task = + server.getDelegatedTask()) != null) { + task.run(); + } + shs = server.getHandshakeStatus(); + } + loops++; + } + + if (loops >= maxLoops) { + fail("Handshake did not complete in " + maxLoops + + " iterations, chs=" + chs + " shs=" + shs); + } + } +}