diff --git a/packages/binding-mqtt/src/mqtt-client-factory.ts b/packages/binding-mqtt/src/mqtt-client-factory.ts index 78d64f77b..1a897f65d 100644 --- a/packages/binding-mqtt/src/mqtt-client-factory.ts +++ b/packages/binding-mqtt/src/mqtt-client-factory.ts @@ -23,8 +23,11 @@ import MqttClient from "./mqtt-client"; const debug = createDebugLogger("binding-mqtt", "mqtt-client-factory"); export default class MqttClientFactory implements ProtocolClientFactory { - public readonly scheme: string = "mqtt"; + public readonly scheme: string; private readonly clients: Array = []; + constructor(scheme: string = "mqtt") { + this.scheme = scheme; + } getClient(): ProtocolClient { const client = new MqttClient(); diff --git a/packages/binding-mqtt/test/mqtt-broker-server-interaction-test.integration.ts b/packages/binding-mqtt/test/mqtt-broker-server-interaction-test.integration.ts index 4f8c9fb69..d989978a8 100644 --- a/packages/binding-mqtt/test/mqtt-broker-server-interaction-test.integration.ts +++ b/packages/binding-mqtt/test/mqtt-broker-server-interaction-test.integration.ts @@ -103,7 +103,10 @@ describe("MQTT broker server interaction implementation", () => { servient = new Servient(); brokerServer = new MqttBrokerServer({ uri: brokerUri, selfHost: true }); servient.addServer(brokerServer); - servient.addClientFactory(new MqttClientFactory()); + servient.addClientFactory(new MqttClientFactory("mqtt")); + servient.addClientFactory(new MqttClientFactory("mqtts")); + servient.addClientFactory(new MqttClientFactory("ws+mqtt")); + servient.addClientFactory(new MqttClientFactory("wss+mqtt")); servient.start().then((WoT) => { WoT.produce(mqttThingModel) .then((thing) => { diff --git a/packages/binding-mqtt/test/mqtt-client-subscribe-test.integration.ts b/packages/binding-mqtt/test/mqtt-client-subscribe-test.integration.ts index 35ccc126f..a9bab15c6 100644 --- a/packages/binding-mqtt/test/mqtt-client-subscribe-test.integration.ts +++ b/packages/binding-mqtt/test/mqtt-client-subscribe-test.integration.ts @@ -49,7 +49,10 @@ describe("MQTT client implementation - integration", () => { brokerServer = new MqttBrokerServer({ uri: brokerUri, selfHost: true }); servient.addServer(brokerServer); - servient.addClientFactory(new MqttClientFactory()); + servient.addClientFactory(new MqttClientFactory("mqtt")); + servient.addClientFactory(new MqttClientFactory("mqtts")); + servient.addClientFactory(new MqttClientFactory("ws+mqtt")); + servient.addClientFactory(new MqttClientFactory("wss+mqtt")); servient.start().then((WoT) => { expect(brokerServer.getPort()).to.equal(brokerPort); @@ -110,7 +113,10 @@ describe("MQTT client implementation - integration", () => { }); servient.addServer(brokerServer); - servient.addClientFactory(new MqttClientFactory()); + servient.addClientFactory(new MqttClientFactory("mqtt")); + servient.addClientFactory(new MqttClientFactory("mqtts")); + servient.addClientFactory(new MqttClientFactory("ws+mqtt")); + servient.addClientFactory(new MqttClientFactory("wss+mqtt")); servient.addClientFactory(new MqttsClientFactory({ rejectUnauthorized: false })); servient.start().then((WoT) => { diff --git a/packages/binding-mqtt/test/mqtt-ws-test.integration.ts b/packages/binding-mqtt/test/mqtt-ws-test.integration.ts new file mode 100644 index 000000000..09560c1a0 --- /dev/null +++ b/packages/binding-mqtt/test/mqtt-ws-test.integration.ts @@ -0,0 +1,98 @@ +import { Servient } from "@node-wot/core"; +import MqttClientFactory from "../src/mqtt-client-factory"; +import * as WoT from "wot-typescript-definitions"; +import assert from "assert"; + +import aedes from "aedes"; +import * as http from "http"; +import * as WebSocket from "ws"; +import { createWebSocketStream } from "ws"; + +describe("MQTT over WebSocket Integration Test", () => { + let servient: Servient; + let wot: typeof WoT; + + let broker: ReturnType; + let httpServer: http.Server; + + before(async () => { + broker = aedes(); + + httpServer = http.createServer(); + + const wsServer = new WebSocket.Server({ server: httpServer }); + + wsServer.on("connection", (ws) => { + const stream = createWebSocketStream(ws); + broker.handle(stream); + }); + + await new Promise((resolve) => { + httpServer.listen(9001, resolve); + }); + + servient = new Servient(); + + servient.addClientFactory(new MqttClientFactory("mqtt")); + servient.addClientFactory(new MqttClientFactory("mqtts")); + servient.addClientFactory(new MqttClientFactory("ws+mqtt")); + servient.addClientFactory(new MqttClientFactory("wss+mqtt")); + + wot = await servient.start(); + }); + + after(async () => { + await servient.shutdown(); + + await new Promise((resolve) => { + httpServer.close(() => { + broker.close(); + resolve(); + }); + }); + }); + + it("should consume Thing via ws+mqtt composite scheme", async () => { + const td = { + title: "MQTTTestThing", + securityDefinitions: { nosec_sc: { scheme: "nosec" } }, + security: ["nosec_sc"], + properties: { + status: { + type: "string", + forms: [ + { + href: "ws://localhost:9001/test/status", + subprotocol: "mqtt", + op: ["readproperty"], + }, + ], + }, + }, + }; + + const thing = await wot.consume(td as any); + + // Publish test message manually through broker + await new Promise((resolve, reject) => { + broker.publish( + { + cmd: "publish", + topic: "test/status", + payload: Buffer.from("online"), + qos: 0, + retain: false, + dup: false, + }, + (err?: Error) => { + if (err) reject(err); + else resolve(); + } + ); + }); + + const output = await thing.readProperty("status"); + + assert.ok(output); + }); +}); diff --git a/packages/core/src/consumed-thing.ts b/packages/core/src/consumed-thing.ts index a17c72a2c..3a010baa7 100644 --- a/packages/core/src/consumed-thing.ts +++ b/packages/core/src/consumed-thing.ts @@ -516,45 +516,65 @@ export default class ConsumedThing extends Thing implements IConsumedThing { if (options.formIndex >= 0 && options.formIndex < forms.length) { form = forms[options.formIndex]; + const scheme = Helpers.extractScheme(form.href); - if (this.#servient.hasClientFor(scheme)) { - debug(`ConsumedThing '${this.title}' got client for '${scheme}'`); - client = this.#servient.getClientFor(scheme); - - if (!this.#clients.get(scheme)) { - // new client - this.ensureClientSecurity(client, form); - this.#clients.set(scheme, client); - } + const subprotocol = form.subprotocol; + const composite = subprotocol ? `${scheme}+${subprotocol}` : scheme; + + // Determine which scheme to use + let resolvedScheme: string; + + if (this.#servient.hasClientFor(composite)) { + resolvedScheme = composite; + } else if (this.#servient.hasClientFor(scheme)) { + resolvedScheme = scheme; } else { throw new Error(`ConsumedThing '${this.title}' missing ClientFactory for '${scheme}'`); } + debug(`ConsumedThing '${this.title}' using client for '${resolvedScheme}'`); + + // Reuse client if already cached + if (this.#clients.has(resolvedScheme)) { + client = this.#clients.get(resolvedScheme)!; + } else { + client = this.#servient.getClientFor(resolvedScheme); + this.ensureClientSecurity(client, form); + this.#clients.set(resolvedScheme, client); + } } else { throw new Error(`ConsumedThing '${this.title}' missing formIndex '${options.formIndex}'`); } } else { - const schemes = forms.map((link) => Helpers.extractScheme(link.href)); - const cacheIdx = schemes.findIndex((scheme) => this.#clients.has(scheme)); + const schemes = forms.map((link) => { + return Helpers.extractScheme(link.href); + }); + + const resolvedSchemes = forms.map((link) => { + const scheme = Helpers.extractScheme(link.href); + const subprotocol = link.subprotocol; + const composite = subprotocol ? `${scheme}+${subprotocol}` : scheme; + return this.#servient.hasClientFor(composite) ? composite : scheme; + }); + const cacheIdx = resolvedSchemes.findIndex((scheme) => this.#clients.has(scheme)); if (cacheIdx !== -1) { // from cache - debug(`ConsumedThing '${this.title}' chose cached client for '${schemes[cacheIdx]}'`); - // if cacheIdx is valid, then clients *contains* schemes[cacheIdx] - client = this.#clients.get(schemes[cacheIdx])!; + debug(`ConsumedThing '${this.title}' chose cached client for '${resolvedSchemes[cacheIdx]}'`); + client = this.#clients.get(resolvedSchemes[cacheIdx])!; form = this.findForm(forms, op, affordance, schemes, cacheIdx); } else { // new client debug(`ConsumedThing '${this.title}' has no client in cache (${cacheIdx})`); - const srvIdx = schemes.findIndex((scheme) => this.#servient.hasClientFor(scheme)); + const srvIdx = resolvedSchemes.findIndex((scheme) => this.#servient.hasClientFor(scheme)); if (srvIdx === -1) throw new Error(`ConsumedThing '${this.title}' missing ClientFactory for '${schemes}'`); - client = this.#servient.getClientFor(schemes[srvIdx]); + client = this.#servient.getClientFor(resolvedSchemes[srvIdx]); debug(`ConsumedThing '${this.title}' got new client for '${schemes[srvIdx]}'`); - this.#clients.set(schemes[srvIdx], client); + this.#clients.set(resolvedSchemes[srvIdx], client); form = this.findForm(forms, op, affordance, schemes, srvIdx); this.ensureClientSecurity(client, form);