Skip to content

Commit 6225da9

Browse files
dunglasCopilot
andauthored
refactor: improve ExtensionWorkers API (#1952)
* refactor: improve ExtensionWorkers API * Update workerextension.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update workerextension.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update caddy/app.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * review * fix tests * docs * errors * improved error handling * fix race * add missing return * use %q in Errorf --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 407ef09 commit 6225da9

11 files changed

Lines changed: 219 additions & 203 deletions

caddy/app.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"path/filepath"
88
"strconv"
99
"strings"
10+
"sync"
1011
"time"
1112

1213
"github.com/caddyserver/caddy/v2"
@@ -18,6 +19,22 @@ import (
1819
"github.com/dunglas/frankenphp/internal/fastabs"
1920
)
2021

22+
var (
23+
options []frankenphp.Option
24+
optionsMU sync.RWMutex
25+
)
26+
27+
// EXPERIMENTAL: RegisterWorkers provides a way for extensions to register frankenphp.Workers
28+
func RegisterWorkers(name, fileName string, num int, wo ...frankenphp.WorkerOption) frankenphp.Workers {
29+
w, opt := frankenphp.WithExtensionWorkers(name, fileName, num, wo...)
30+
31+
optionsMU.Lock()
32+
options = append(options, opt)
33+
optionsMU.Unlock()
34+
35+
return w
36+
}
37+
2138
// FrankenPHPApp represents the global "frankenphp" directive in the Caddyfile
2239
// it's responsible for starting up the global PHP instance and all threads
2340
//
@@ -118,6 +135,11 @@ func (f *FrankenPHPApp) Start() error {
118135
frankenphp.WithPhpIni(f.PhpIni),
119136
frankenphp.WithMaxWaitTime(f.MaxWaitTime),
120137
}
138+
139+
optionsMU.RLock()
140+
opts = append(opts, options...)
141+
optionsMU.RUnlock()
142+
121143
for _, w := range append(f.Workers) {
122144
workerOpts := []frankenphp.WorkerOption{
123145
frankenphp.WithWorkerEnv(w.Env),
@@ -151,6 +173,10 @@ func (f *FrankenPHPApp) Stop() error {
151173
f.NumThreads = 0
152174
f.MaxWaitTime = 0
153175

176+
optionsMU.Lock()
177+
options = nil
178+
optionsMU.Unlock()
179+
154180
return nil
155181
}
156182

caddy/caddy_test.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -965,15 +965,15 @@ func TestMaxWaitTime(t *testing.T) {
965965
for range 10 {
966966
go func() {
967967
statusCode := getStatusCode("http://localhost:"+testPort+"/sleep.php?sleep=10", t)
968-
if statusCode == http.StatusGatewayTimeout {
968+
if statusCode == http.StatusServiceUnavailable {
969969
success.Store(true)
970970
}
971971
wg.Done()
972972
}()
973973
}
974974
wg.Wait()
975975

976-
require.True(t, success.Load(), "At least one request should have failed with a 504 Gateway Timeout status")
976+
require.True(t, success.Load(), "At least one request should have failed with a 503 Service Unavailable status")
977977
}
978978

979979
func TestMaxWaitTimeWorker(t *testing.T) {
@@ -1012,23 +1012,26 @@ func TestMaxWaitTimeWorker(t *testing.T) {
10121012
for range 10 {
10131013
go func() {
10141014
statusCode := getStatusCode("http://localhost:"+testPort+"/sleep.php?sleep=10&iteration=1", t)
1015-
if statusCode == http.StatusGatewayTimeout {
1015+
if statusCode == http.StatusServiceUnavailable {
10161016
success.Store(true)
10171017
}
10181018
wg.Done()
10191019
}()
10201020
}
10211021
wg.Wait()
1022-
require.True(t, success.Load(), "At least one request should have failed with a 504 Gateway Timeout status")
1022+
require.True(t, success.Load(), "At least one request should have failed with a 503 Service Unavailable status")
10231023

10241024
// Fetch metrics
10251025
resp, err := http.Get("http://localhost:2999/metrics")
10261026
require.NoError(t, err, "failed to fetch metrics")
1027-
defer resp.Body.Close()
1027+
t.Cleanup(func() {
1028+
require.NoError(t, resp.Body.Close())
1029+
})
10281030

10291031
// Read and parse metrics
10301032
metrics := new(bytes.Buffer)
10311033
_, err = metrics.ReadFrom(resp.Body)
1034+
require.NoError(t, err)
10321035

10331036
expectedMetrics := `
10341037
# TYPE frankenphp_worker_queue_depth gauge

caddy/module.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package caddy
22

33
import (
44
"encoding/json"
5+
"errors"
56
"fmt"
67
"log/slog"
78
"net/http"
@@ -192,8 +193,11 @@ func (f *FrankenPHPModule) ServeHTTP(w http.ResponseWriter, r *http.Request, _ c
192193
frankenphp.WithOriginalRequest(&origReq),
193194
frankenphp.WithWorkerName(workerName),
194195
)
196+
if err != nil {
197+
return caddyhttp.Error(http.StatusInternalServerError, err)
198+
}
195199

196-
if err = frankenphp.ServeHTTP(w, fr); err != nil {
200+
if err = frankenphp.ServeHTTP(w, fr); err != nil && !errors.As(err, &frankenphp.ErrRejected{}) {
197201
return caddyhttp.Error(http.StatusInternalServerError, err)
198202
}
199203

context.go

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package frankenphp
22

33
import (
44
"context"
5+
"errors"
6+
"fmt"
57
"log/slog"
68
"net/http"
79
"os"
@@ -117,23 +119,25 @@ func (fc *frankenPHPContext) closeContext() {
117119
}
118120

119121
// validate checks if the request should be outright rejected
120-
func (fc *frankenPHPContext) validate() bool {
122+
func (fc *frankenPHPContext) validate() error {
121123
if strings.Contains(fc.request.URL.Path, "\x00") {
122-
fc.rejectBadRequest("Invalid request path")
124+
fc.reject(ErrInvalidRequestPath)
123125

124-
return false
126+
return ErrInvalidRequestPath
125127
}
126128

127129
contentLengthStr := fc.request.Header.Get("Content-Length")
128130
if contentLengthStr != "" {
129131
if contentLength, err := strconv.Atoi(contentLengthStr); err != nil || contentLength < 0 {
130-
fc.rejectBadRequest("invalid Content-Length header: " + contentLengthStr)
132+
e := fmt.Errorf("%w: %q", ErrInvalidContentLengthHeader, contentLengthStr)
133+
134+
fc.reject(e)
131135

132-
return false
136+
return e
133137
}
134138
}
135139

136-
return true
140+
return nil
137141
}
138142

139143
func (fc *frankenPHPContext) clientHasClosed() bool {
@@ -149,16 +153,22 @@ func (fc *frankenPHPContext) clientHasClosed() bool {
149153
}
150154
}
151155

152-
// reject sends a response with the given status code and message
153-
func (fc *frankenPHPContext) reject(statusCode int, message string) {
156+
// reject sends a response with the given status code and error
157+
func (fc *frankenPHPContext) reject(err error) {
154158
if fc.isDone {
155159
return
156160
}
157161

162+
re := &ErrRejected{}
163+
if !errors.As(err, re) {
164+
// Should never happen
165+
panic("only instance of ErrRejected can be passed to reject")
166+
}
167+
158168
rw := fc.responseWriter
159169
if rw != nil {
160-
rw.WriteHeader(statusCode)
161-
_, _ = rw.Write([]byte(message))
170+
rw.WriteHeader(re.status)
171+
_, _ = rw.Write([]byte(err.Error()))
162172

163173
if f, ok := rw.(http.Flusher); ok {
164174
f.Flush()
@@ -167,7 +177,3 @@ func (fc *frankenPHPContext) reject(statusCode int, message string) {
167177

168178
fc.closeContext()
169179
}
170-
171-
func (fc *frankenPHPContext) rejectBadRequest(message string) {
172-
fc.reject(http.StatusBadRequest, message)
173-
}

frankenphp.go

Lines changed: 37 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ var (
5252
ErrScriptExecution = errors.New("error during PHP script execution")
5353
ErrNotRunning = errors.New("FrankenPHP is not running. For proper configuration visit: https://frankenphp.dev/docs/config/#caddyfile-config")
5454

55+
ErrInvalidRequestPath = ErrRejected{"invalid request path", http.StatusBadRequest}
56+
ErrInvalidContentLengthHeader = ErrRejected{"invalid Content-Length header", http.StatusBadRequest}
57+
ErrMaxWaitTimeExceeded = ErrRejected{"maximum request handling time exceeded", http.StatusServiceUnavailable}
58+
5559
isRunning bool
5660
onServerShutdown []func()
5761

@@ -63,34 +67,43 @@ var (
6367
maxWaitTime time.Duration
6468
)
6569

70+
type ErrRejected struct {
71+
message string
72+
status int
73+
}
74+
75+
func (e ErrRejected) Error() string {
76+
return e.message
77+
}
78+
6679
type syslogLevel int
6780

6881
const (
69-
emerg syslogLevel = iota // system is unusable
70-
alert // action must be taken immediately
71-
crit // critical conditions
72-
err // error conditions
73-
warning // warning conditions
74-
notice // normal but significant condition
75-
info // informational
76-
debug // debug-level messages
82+
syslogLevelEmerg syslogLevel = iota // system is unusable
83+
syslogLevelAlert // action must be taken immediately
84+
syslogLevelCrit // critical conditions
85+
syslogLevelErr // error conditions
86+
syslogLevelWarn // warning conditions
87+
syslogLevelNotice // normal but significant condition
88+
syslogLevelInfo // informational
89+
syslogLevelDebug // debug-level messages
7790
)
7891

7992
func (l syslogLevel) String() string {
8093
switch l {
81-
case emerg:
94+
case syslogLevelEmerg:
8295
return "emerg"
83-
case alert:
96+
case syslogLevelAlert:
8497
return "alert"
85-
case crit:
98+
case syslogLevelCrit:
8699
return "crit"
87-
case err:
100+
case syslogLevelErr:
88101
return "err"
89-
case warning:
102+
case syslogLevelWarn:
90103
return "warning"
91-
case notice:
104+
case syslogLevelNotice:
92105
return "notice"
93-
case debug:
106+
case syslogLevelDebug:
94107
return "debug"
95108
default:
96109
return "info"
@@ -210,11 +223,6 @@ func Init(options ...Option) error {
210223

211224
registerExtensions()
212225

213-
// add registered external workers
214-
for _, ew := range extensionWorkers {
215-
options = append(options, WithWorkers(ew.name, ew.fileName, ew.num, ew.options...))
216-
}
217-
218226
opt := &opt{}
219227
for _, o := range options {
220228
if err := o(opt); err != nil {
@@ -336,20 +344,17 @@ func ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) error
336344

337345
fc.responseWriter = responseWriter
338346

339-
if !fc.validate() {
340-
return nil
347+
if err := fc.validate(); err != nil {
348+
return err
341349
}
342350

343351
// Detect if a worker is available to handle this request
344352
if fc.worker != nil {
345-
fc.worker.handleRequest(fc)
346-
347-
return nil
353+
return fc.worker.handleRequest(fc)
348354
}
349355

350356
// If no worker was available, send the request to non-worker threads
351-
handleRequestWithRegularPHPThreads(fc)
352-
return nil
357+
return handleRequestWithRegularPHPThreads(fc)
353358
}
354359

355360
//export go_ub_write
@@ -566,19 +571,19 @@ func go_log(message *C.char, level C.int) {
566571
m := C.GoString(message)
567572

568573
var le syslogLevel
569-
if level < C.int(emerg) || level > C.int(debug) {
570-
le = info
574+
if level < C.int(syslogLevelEmerg) || level > C.int(syslogLevelDebug) {
575+
le = syslogLevelInfo
571576
} else {
572577
le = syslogLevel(level)
573578
}
574579

575580
switch le {
576-
case emerg, alert, crit, err:
581+
case syslogLevelEmerg, syslogLevelAlert, syslogLevelCrit, syslogLevelErr:
577582
logger.LogAttrs(context.Background(), slog.LevelError, m, slog.String("syslog_level", syslogLevel(level).String()))
578583

579-
case warning:
584+
case syslogLevelWarn:
580585
logger.LogAttrs(context.Background(), slog.LevelWarn, m, slog.String("syslog_level", syslogLevel(level).String()))
581-
case debug:
586+
case syslogLevelDebug:
582587
logger.LogAttrs(context.Background(), slog.LevelDebug, m, slog.String("syslog_level", syslogLevel(level).String()))
583588

584589
default:

frankenphp_test.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,17 @@ func runTest(t *testing.T, test func(func(http.ResponseWriter, *http.Request), *
7878
}
7979

8080
err := frankenphp.Init(initOpts...)
81-
require.Nil(t, err)
81+
require.NoError(t, err)
8282
defer frankenphp.Shutdown()
8383

8484
handler := func(w http.ResponseWriter, r *http.Request) {
8585
req, err := frankenphp.NewRequestWithContext(r, frankenphp.WithRequestDocumentRoot(testDataDir, false))
8686
assert.NoError(t, err)
8787

8888
err = frankenphp.ServeHTTP(w, req)
89-
assert.NoError(t, err)
89+
if err != nil && !errors.As(err, &frankenphp.ErrRejected{}) {
90+
assert.Fail(t, fmt.Sprintf("Received unexpected error:\n%+v", err))
91+
}
9092
}
9193

9294
var ts *httptest.Server
@@ -109,6 +111,7 @@ func runTest(t *testing.T, test func(func(http.ResponseWriter, *http.Request), *
109111

110112
func testRequest(req *http.Request, handler func(http.ResponseWriter, *http.Request), t *testing.T) (string, *http.Response) {
111113
t.Helper()
114+
112115
w := httptest.NewRecorder()
113116
handler(w, req)
114117
resp := w.Result()
@@ -988,7 +991,7 @@ func FuzzRequest(f *testing.F) {
988991
// The response status must be 400 if the request path contains null bytes
989992
if strings.Contains(req.URL.Path, "\x00") {
990993
assert.Equal(t, 400, resp.StatusCode)
991-
assert.Contains(t, body, "Invalid request path")
994+
assert.Contains(t, body, "invalid request path")
992995
return
993996
}
994997

0 commit comments

Comments
 (0)