diff --git a/README.md b/README.md index 9cb9b3c..f2c8a3d 100644 --- a/README.md +++ b/README.md @@ -79,6 +79,7 @@ See [below](#other-servers) for examples of use with other servers. | **[`writeToDisk`](#writetodisk)** | `boolean\|Function` | `false` | Instructs the module to write files to the configured location on disk as specified in your Rspack configuration. | | **[`outputFileSystem`](#outputfilesystem)** | `Object` | [`memfs`](https://github.com/streamich/memfs) | Set the default file system which will be used by Rspack as primary destination of generated files. | | **[`modifyResponseData`](#modifyresponsedata)** | `Function` | `undefined` | Allows to set up a callback to change the response data. | +| **[`forwardError`](#forwarderror)** | `boolean` | `false` | Enable or disable forwarding errors to the next middleware. | The middleware accepts an `options` Object. The following is a property reference for the Object. @@ -313,6 +314,33 @@ devMiddleware(compiler, { }); ``` +### forwardError + +Type: `Boolean` +Default: `false` + +When enabled, handled errors are forwarded to the next middleware instead of being rendered as the built-in HTML error page. +This allows Connect/Express/Router and the Koa/Hono wrappers to hand errors to your application's own error middleware. +Hapi still does not support this option because its request lifecycle does not expose equivalent `next(err)` forwarding semantics. + +```js +const express = require("express"); +const { devMiddleware } = require("@rspack/dev-middleware"); +const { rspack } = require("@rspack/core"); + +const compiler = rspack({ + /* Rspack configuration */ +}); + +const app = express(); + +app.use(devMiddleware(compiler, { forwardError: true })); + +app.use((error, req, res, next) => { + res.status(500).send("Something broke!"); +}); +``` + ## API `@rspack/dev-middleware` also provides convenience methods that can be use to diff --git a/src/index.js b/src/index.js index 09bb80c..ece0d30 100644 --- a/src/index.js +++ b/src/index.js @@ -127,6 +127,7 @@ const noop = () => {}; * @property {boolean=} lastModified options to generate last modified header * @property {(boolean | number | string | { maxAge?: number, immutable?: boolean })=} cacheControl options to generate cache headers * @property {boolean=} cacheImmutable enable immutable cache headers for immutable assets (defaults to true when omitted) + * @property {boolean=} forwardError forward errors to the next middleware */ /** @@ -192,6 +193,42 @@ const noop = () => {}; * @typedef {T & { [P in K]: NonNullable }} WithoutUndefined */ +/** + * @param {import("fs").ReadStream} stream readable stream + * @param {(error?: Error) => void} callback callback + * @returns {void} + */ +function waitUntilStreamReady(stream, callback) { + let isResolved = false; + + /** + * @param {Error=} error error + * @returns {void} + */ + const onEvent = (error) => { + if (isResolved) { + return; + } + + isResolved = true; + + stream.removeListener("error", onEvent); + stream.removeListener("readable", onEvent); + stream.removeListener("end", onEvent); + + if (error) { + stream.destroy(); + } + + callback(error); + }; + + stream.once("error", onEvent); + stream.once("readable", onEvent); + // Empty stream + stream.once("end", onEvent); +} + /** * @template {IncomingMessage} [RequestInternal=IncomingMessage] * @template {ServerResponse} [ResponseInternal=ServerResponse] @@ -435,10 +472,17 @@ function koaWrapper(compiler, options) { * @param {import("fs").ReadStream} stream readable stream */ res.stream = (stream) => { - ctx.body = stream; + waitUntilStreamReady(stream, (error) => { + if (error) { + reject(error); + return; + } - isFinished = true; - resolve(); + ctx.body = stream; + + isFinished = true; + resolve(); + }); }; /** * @param {string | Buffer} data data @@ -476,6 +520,10 @@ function koaWrapper(compiler, options) { }, ); } catch (err) { + if (options?.forwardError) { + throw err; + } + ctx.status = /** @type {Error & { statusCode: number }} */ (err).statusCode || /** @type {Error & { status: number }} */ (err).status || @@ -602,10 +650,17 @@ function honoWrapper(compiler, options) { * @param {import("fs").ReadStream} stream readable stream */ res.stream = (stream) => { - body = stream; + waitUntilStreamReady(stream, (error) => { + if (error) { + reject(error); + return; + } - isFinished = true; - resolve(); + body = stream; + + isFinished = true; + resolve(); + }); }; /** @@ -651,6 +706,10 @@ function honoWrapper(compiler, options) { }, ); } catch (err) { + if (options?.forwardError) { + throw err; + } + context.status(500); return context.json({ message: /** @type {Error} */ (err).message }); diff --git a/src/middleware.js b/src/middleware.js index 973fef1..44d3375 100644 --- a/src/middleware.js +++ b/src/middleware.js @@ -193,8 +193,6 @@ function wrapper(context) { } const acceptedMethods = context.options.methods || ["GET", "HEAD"]; - // TODO do we need an option here? - const forwardError = false; initState(res); @@ -212,13 +210,23 @@ function wrapper(context) { * @returns {Promise} */ async function sendError(message, status, options) { - if (forwardError) { + if (context.options.forwardError) { + if (!getHeadersSent(res)) { + const headers = getResponseHeaders(res); + + for (let i = 0; i < headers.length; i++) { + removeResponseHeader(res, headers[i]); + } + } + const error = /** @type {Error & { statusCode: number }} */ (new Error(message)); error.statusCode = status; await goNext(error); + + return; } const escapeHtml = getEscapeHtml(); diff --git a/test/middleware.test.js b/test/middleware.test.js index b65611c..abd28aa 100644 --- a/test/middleware.test.js +++ b/test/middleware.test.js @@ -120,6 +120,11 @@ async function frameworkFactory( case "koa": { // eslint-disable-next-line new-cap const app = new framework(); + + if (typeof options.setupApp === "function") { + options.setupApp(app); + } + const koaMiddleware = devMiddleware.koaWrapper( compiler, devMiddlewareOptions, @@ -145,8 +150,10 @@ async function frameworkFactory( case "hono": { // eslint-disable-next-line new-cap const app = new framework(); - const server = await startServer(name, app); - const req = request(server); + + if (typeof options.setupApp === "function") { + options.setupApp(app); + } const instance = devMiddleware.honoWrapper( compiler, devMiddlewareOptions, @@ -164,12 +171,19 @@ async function frameworkFactory( } } + const server = await startServer(name, app); + const req = request(server); + return [server, req, instance.devMiddleware]; } default: { const isRouter = name === "router"; const app = framework(); + if (typeof options.setupApp === "function") { + options.setupApp(app); + } + const instance = devMiddleware(compiler, devMiddlewareOptions); const middlewares = typeof options.setupMiddlewares === "function" @@ -3604,6 +3618,270 @@ describe.each([ }); }); + describe("should call the app error middleware for handled errors when forwardError is enabled", () => { + let compiler; + + const outputPath = path.resolve( + __dirname, + "./outputs/basic-test-errors-headers-sent", + ); + + let nextWasCalled = false; + let forwardedError; + + beforeAll(async () => { + compiler = getCompiler({ + ...webpackConfig, + output: { + filename: "bundle.js", + path: outputPath, + }, + }); + + [server, req, instance] = await frameworkFactory( + name, + framework, + compiler, + { + etag: "weak", + lastModified: true, + forwardError: true, + }, + { + setupApp: (app) => { + if (name === "hono") { + app.onError((error, c) => { + forwardedError = error; + nextWasCalled = true; + c.status(500); + + return c.text("error"); + }); + } + }, + setupMiddlewares: (middlewares) => { + if (name === "hapi") { + // There's no such thing as "the next route handler" in hapi. One request is matched to one or no route handlers. + } else if (name === "koa") { + middlewares.unshift(async (ctx, next) => { + try { + await next(); + } catch (error) { + forwardedError = error; + nextWasCalled = true; + ctx.status = 500; + ctx.body = "error"; + } + }); + } else { + middlewares.push((error, _req, res, _next) => { + forwardedError = error; + nextWasCalled = true; + res.statusCode = 500; + res.end("error"); + }); + } + + return middlewares; + }, + }, + ); + + instance.context.outputFileSystem.mkdirSync(outputPath, { + recursive: true, + }); + instance.context.outputFileSystem.writeFileSync( + path.resolve(outputPath, "index.html"), + "HTML", + ); + instance.context.outputFileSystem.writeFileSync( + path.resolve(outputPath, "image.svg"), + "svg image", + ); + instance.context.outputFileSystem.writeFileSync( + path.resolve(outputPath, "file.text"), + "text", + ); + + const originalMethod = + instance.context.outputFileSystem.createReadStream; + + instance.context.outputFileSystem.createReadStream = + function createReadStream(...args) { + if (args[0].endsWith("image.svg")) { + const brokenStream = new this.ReadStream(...args); + + brokenStream._read = function _read() { + const error = new Error("test"); + error.code = "ENAMETOOLONG"; + this.emit("error", error); + this.end(); + this.destroy(); + }; + + return brokenStream; + } + + return originalMethod(...args); + }; + }); + + beforeEach(() => { + nextWasCalled = false; + forwardedError = undefined; + }); + + afterAll(async () => { + await close(server, instance); + }); + + it("should work with piping stream", async () => { + const response = await req.get("/file.text"); + + expect(response.statusCode).toBe(200); + expect(nextWasCalled).toBe(false); + expect(forwardedError).toBeUndefined(); + }); + + it('should return the "500" code for requests above root', async () => { + const response = await req.get("/public/..%2f../middleware.test.js"); + + expect(response.statusCode).toBe(500); + + if (name === "hapi") { + expect(nextWasCalled).toBe(false); + expect(forwardedError).toBeUndefined(); + } else { + if (name !== "hono") { + expect(response.text).toBe("error"); + } + expect(nextWasCalled).toBe(true); + expect(forwardedError).toBeInstanceOf(Error); + expect(forwardedError.statusCode).toBe(403); + } + }); + + it('should return the "500" code for the "GET" request to the bundle file with etag and wrong "if-match" header', async () => { + const firstResponse = await req.get("/file.text"); + + expect(firstResponse.statusCode).toBe(200); + expect(firstResponse.headers.etag).toBeDefined(); + expect(firstResponse.headers.etag.startsWith("W/")).toBe(true); + + const response = await req.get("/file.text").set("if-match", "test"); + + expect(response.statusCode).toBe(500); + + if (name === "hapi") { + expect(nextWasCalled).toBe(false); + expect(forwardedError).toBeUndefined(); + } else { + if (name !== "hono") { + expect(response.text).toBe("error"); + } + expect(nextWasCalled).toBe(true); + expect(forwardedError).toBeInstanceOf(Error); + expect(forwardedError.statusCode).toBe(412); + } + }); + + it('should return the "500" code for the "GET" request with the invalid range header', async () => { + const response = await req + .get("/file.text") + .set("Range", "bytes=9999999-"); + + expect(response.statusCode).toBe(500); + + if (name === "hapi") { + expect(nextWasCalled).toBe(false); + expect(forwardedError).toBeUndefined(); + } else { + if (name !== "hono") { + expect(response.text).toBe("error"); + } + expect(nextWasCalled).toBe(true); + expect(forwardedError).toBeInstanceOf(Error); + expect(forwardedError.statusCode).toBe(416); + } + }); + + it('should return the "500" code for the "GET" request to the "image.svg" file when it throws a reading error', async () => { + const response = await req.get("/image.svg"); + + expect(response.statusCode).toBe(500); + + if (name === "hapi") { + expect(nextWasCalled).toBe(false); + expect(forwardedError).toBeUndefined(); + } else { + if (name !== "hono") { + expect(response.text).toBe("error"); + } + expect(nextWasCalled).toBe(true); + expect(forwardedError).toBeInstanceOf(Error); + expect(forwardedError.statusCode).toBe(404); + } + }); + + it('should return the "200" code for the "HEAD" request to the bundle file', async () => { + const response = await req.head("/file.text"); + + expect(response.statusCode).toBe(200); + expect(response.text).toBeUndefined(); + expect(nextWasCalled).toBe(false); + expect(forwardedError).toBeUndefined(); + }); + + it('should return the "304" code for the "GET" request to the bundle file with etag and "if-none-match"', async () => { + const firstResponse = await req.get("/file.text"); + + expect(firstResponse.statusCode).toBe(200); + expect(firstResponse.headers.etag).toBeDefined(); + expect(firstResponse.headers.etag.startsWith("W/")).toBe(true); + + const secondResponse = await req + .get("/file.text") + .set("if-none-match", firstResponse.headers.etag); + + expect(secondResponse.statusCode).toBe(304); + expect(secondResponse.headers.etag).toBeDefined(); + expect(secondResponse.headers.etag.startsWith("W/")).toBe(true); + + const thirdResponse = await req + .get("/file.text") + .set("if-none-match", firstResponse.headers.etag); + + expect(thirdResponse.statusCode).toBe(304); + expect(thirdResponse.headers.etag).toBeDefined(); + expect(thirdResponse.headers.etag.startsWith("W/")).toBe(true); + expect(nextWasCalled).toBe(false); + expect(forwardedError).toBeUndefined(); + }); + + it('should return the "304" code for the "GET" request to the bundle file with lastModified and "if-modified-since" header', async () => { + const firstResponse = await req.get("/file.text"); + + expect(firstResponse.statusCode).toBe(200); + expect(firstResponse.headers["last-modified"]).toBeDefined(); + + const secondResponse = await req + .get("/file.text") + .set("if-modified-since", firstResponse.headers["last-modified"]); + + expect(secondResponse.statusCode).toBe(304); + expect(secondResponse.headers["last-modified"]).toBeDefined(); + + const thirdResponse = await req + .get("/file.text") + .set("if-modified-since", secondResponse.headers["last-modified"]); + + expect(thirdResponse.statusCode).toBe(304); + expect(thirdResponse.headers["last-modified"]).toBeDefined(); + expect(nextWasCalled).toBe(false); + expect(forwardedError).toBeUndefined(); + }); + }); + describe("should fallthrough for not found files", () => { let compiler; diff --git a/types/index.d.ts b/types/index.d.ts index b76d003..a744c3b 100644 --- a/types/index.d.ts +++ b/types/index.d.ts @@ -1,155 +1,3 @@ -/** @typedef {import("@rspack/core").Compiler} Compiler */ -/** @typedef {import("@rspack/core").MultiCompiler} MultiCompiler */ -/** @typedef {import("@rspack/core").Configuration} Configuration */ -/** @typedef {import("@rspack/core").Stats} Stats */ -/** @typedef {import("@rspack/core").MultiStats} MultiStats */ -/** @typedef {import("fs").ReadStream} ReadStream */ -/** - * @typedef {object} ExtendedServerResponse - * @property {{ rspack?: { devMiddleware?: Context }, webpack?: { devMiddleware?: Context } }=} locals locals - */ -/** @typedef {import("http").IncomingMessage} IncomingMessage */ -/** @typedef {import("http").ServerResponse & ExtendedServerResponse} ServerResponse */ -/** @typedef {any} EXPECTED_ANY */ -/** @typedef {Function} EXPECTED_FUNCTION */ -/** - * @callback NextFunction - * @param {EXPECTED_ANY=} err error - * @returns {void} - */ -/** - * @typedef {NonNullable} WatchOptions - */ -/** - * @typedef {boolean | Configuration["devServer"] | undefined} DevServerOption - */ -/** - * @typedef {Compiler["watching"]} Watching - */ -/** - * @typedef {ReturnType} MultiWatching - */ -/** - * @typedef {import("@rspack/core").OutputFileSystem & { createReadStream?: import("fs").createReadStream, statSync: import("fs").statSync, readFileSync: import("fs").readFileSync }} OutputFileSystem - */ -/** @typedef {ReturnType} Logger */ -/** @typedef {{ close(callback: (err?: Error | null | undefined) => void): void }} ClosableWatching */ -/** - * @callback Callback - * @param {(Stats | MultiStats)=} stats - */ -/** - * @typedef {object} ResponseData - * @property {Buffer | ReadStream} data data - * @property {number} byteLength byte length - */ -/** - * @template {IncomingMessage} [RequestInternal=IncomingMessage] - * @template {ServerResponse} [ResponseInternal=ServerResponse] - * @callback ModifyResponseData - * @param {RequestInternal} req req - * @param {ResponseInternal} res res - * @param {Buffer | ReadStream} data data - * @param {number} byteLength byte length - * @returns {ResponseData} - */ -/** - * @template {IncomingMessage} [RequestInternal=IncomingMessage] - * @template {ServerResponse} [ResponseInternal=ServerResponse] - * @typedef {object} Context - * @property {boolean} state state - * @property {Stats | MultiStats | undefined} stats stats - * @property {Callback[]} callbacks callbacks - * @property {Options} options options - * @property {Compiler | MultiCompiler} compiler compiler - * @property {Watching | MultiWatching | undefined} watching watching - * @property {Logger} logger logger - * @property {OutputFileSystem} outputFileSystem output file system - */ -/** - * @template {IncomingMessage} [RequestInternal=IncomingMessage] - * @template {ServerResponse} [ResponseInternal=ServerResponse] - * @typedef {WithoutUndefined, "watching">} FilledContext - */ -/** @typedef {Record | { key: string, value: number | string }[]} NormalizedHeaders */ -/** - * @template {IncomingMessage} [RequestInternal=IncomingMessage] - * @template {ServerResponse} [ResponseInternal=ServerResponse] - * @typedef {NormalizedHeaders | ((req: RequestInternal, res: ResponseInternal, context: Context) => void | undefined | NormalizedHeaders) | undefined} Headers - */ -/** - * @template {IncomingMessage} [RequestInternal = IncomingMessage] - * @template {ServerResponse} [ResponseInternal = ServerResponse] - * @typedef {object} Options - * @property {{ [key: string]: string }=} mimeTypes mime types - * @property {(string | undefined)=} mimeTypeDefault mime type default - * @property {(boolean | ((targetPath: string) => boolean))=} writeToDisk write to disk - * @property {string[]=} methods methods - * @property {Headers=} headers headers - * @property {NonNullable["publicPath"]=} publicPath public path - * @property {Configuration["stats"]=} stats stats - * @property {boolean=} serverSideRender is server side render - * @property {OutputFileSystem=} outputFileSystem output file system - * @property {(boolean | string)=} index index - * @property {ModifyResponseData=} modifyResponseData modify response data - * @property {"weak" | "strong"=} etag options to generate etag header - * @property {boolean=} lastModified options to generate last modified header - * @property {(boolean | number | string | { maxAge?: number, immutable?: boolean })=} cacheControl options to generate cache headers - * @property {boolean=} cacheImmutable enable immutable cache headers for immutable assets (defaults to true when omitted) - */ -/** - * @template {IncomingMessage} [RequestInternal=IncomingMessage] - * @template {ServerResponse} [ResponseInternal=ServerResponse] - * @callback Middleware - * @param {RequestInternal} req - * @param {ResponseInternal} res - * @param {NextFunction} next - * @returns {Promise} - */ -/** @typedef {import("./utils/getFilenameFromUrl.js").Extra} Extra */ -/** - * @callback GetFilenameFromUrl - * @param {string} url - * @param {Extra=} extra - * @returns {string | undefined} - */ -/** - * @callback WaitUntilValid - * @param {Callback} callback - */ -/** - * @callback Invalidate - * @param {Callback} callback - */ -/** - * @callback Close - * @param {(err: Error | null | undefined) => void} callback - */ -/** - * @template {IncomingMessage} RequestInternal - * @template {ServerResponse} ResponseInternal - * @typedef {object} AdditionalMethods - * @property {GetFilenameFromUrl} getFilenameFromUrl get filename from url - * @property {WaitUntilValid} waitUntilValid wait until valid - * @property {Invalidate} invalidate invalidate - * @property {Close} close close - * @property {Context} context context - */ -/** - * @template {IncomingMessage} [RequestInternal=IncomingMessage] - * @template {ServerResponse} [ResponseInternal=ServerResponse] - * @typedef {Middleware & AdditionalMethods} API - */ -/** - * @template T - * @template {keyof T} K - * @typedef {Omit & Partial} WithOptional - */ -/** - * @template T - * @template {keyof T} K - * @typedef {T & { [P in K]: NonNullable }} WithoutUndefined - */ /** * @template {IncomingMessage} [RequestInternal=IncomingMessage] * @template {ServerResponse} [ResponseInternal=ServerResponse] @@ -367,6 +215,10 @@ export type Options< * enable immutable cache headers for immutable assets (defaults to true when omitted) */ cacheImmutable?: boolean | undefined; + /** + * forward errors to the next middleware + */ + forwardError?: boolean | undefined; }; export type Middleware< RequestInternal extends IncomingMessage = import("http").IncomingMessage,