Skip to content

Commit 9ae5a15

Browse files
Add ConvertOrderBy method for MongoDB-style sorting
Add `ORDER BY` support with `ConvertOrderBy` method. - Converts MongoDB-style sort objects to PostgreSQL `ORDER BY` clauses. - Supports both regular columns and JSONB fields with dual sorting. - Includes integration tests and fuzz tests.
1 parent d710938 commit 9ae5a15

7 files changed

Lines changed: 488 additions & 0 deletions

File tree

README.md

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,39 @@ values := []any{"aztec", "nuke", "", 2, 10}
9797
(given "customdata" is configured with `filter.WithNestedJSONB("customdata", "password", "playerCount")`)
9898

9999

100+
## Order By Support
101+
102+
In addition to filtering, this package also supports converting MongoDB-style sort objects into PostgreSQL ORDER BY clauses using the `ConvertOrderBy` method:
103+
104+
```go
105+
// Convert a sort object to an ORDER BY clause
106+
sortInput := []byte(`{"playerCount": -1, "name": 1}`)
107+
orderBy, err := converter.ConvertOrderBy(sortInput)
108+
if err != nil {
109+
// handle error
110+
}
111+
fmt.Println(orderBy) // "playerCount" DESC, "name" ASC
112+
113+
db.Query("SELECT * FROM games ORDER BY " + orderBy)
114+
```
115+
116+
### Sort Direction Values:
117+
- `1`: Ascending (ASC)
118+
- `-1`: Descending (DESC)
119+
120+
### JSONB Field Sorting:
121+
For JSONB fields, the package generates sophisticated ORDER BY clauses that handle both numeric and text sorting:
122+
123+
```go
124+
// With WithNestedJSONB("metadata", "created_at"):
125+
sortInput := []byte(`{"score": -1}`)
126+
orderBy, err := converter.ConvertOrderBy(sortInput)
127+
// Generates: (CASE WHEN jsonb_typeof(metadata->'score') = 'number' THEN (metadata->>'score')::numeric END) DESC NULLS LAST, metadata->>'score' DESC NULLS LAST
128+
```
129+
130+
This ensures proper sorting whether the JSONB field contains numeric or text values.
131+
132+
100133
## Difference with MongoDB
101134

102135
- The MongoDB query filters don't have the option to compare fields with each other. This package adds the `$field` operator to compare fields with each other.

filter/converter.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,3 +404,77 @@ func (c *Converter) isNestedColumn(column string) bool {
404404
}
405405
return true
406406
}
407+
408+
// ConvertOrderBy converts a JSON object with field names and sort directions
409+
// into a PostgreSQL ORDER BY clause. The JSON object should have keys with values
410+
// of 1 (ASC) or -1 (DESC).
411+
//
412+
// For JSONB fields, it generates clauses that handle both numeric and text sorting.
413+
//
414+
// Example: {"playerCount": -1, "name": 1} -> "playerCount DESC, name ASC"
415+
func (c *Converter) ConvertOrderBy(query []byte) (string, error) {
416+
keyValues, err := objectInOrder(query)
417+
if err != nil {
418+
return "", err
419+
}
420+
421+
parts := make([]string, 0, len(keyValues))
422+
423+
for _, kv := range keyValues {
424+
key, value := kv.Key, kv.Value
425+
426+
if !isValidPostgresIdentifier(key) {
427+
return "", fmt.Errorf("invalid column name: %s", key)
428+
}
429+
if !c.isColumnAllowed(key) {
430+
return "", ColumnNotAllowedError{Column: key}
431+
}
432+
433+
// Convert value to number for direction
434+
var direction string
435+
switch v := value.(type) {
436+
case json.Number:
437+
if num, err := v.Int64(); err == nil {
438+
switch num {
439+
case 1:
440+
direction = "ASC"
441+
case -1:
442+
direction = "DESC"
443+
default:
444+
return "", InvalidOrderDirectionError{Field: key, Value: value}
445+
}
446+
} else {
447+
return "", InvalidOrderDirectionError{Field: key, Value: value}
448+
}
449+
case float64:
450+
switch v {
451+
case 1:
452+
direction = "ASC"
453+
case -1:
454+
direction = "DESC"
455+
default:
456+
return "", InvalidOrderDirectionError{Field: key, Value: value}
457+
}
458+
default:
459+
return "", InvalidOrderDirectionError{Field: key, Value: value}
460+
}
461+
462+
var fieldClause string
463+
if c.isNestedColumn(key) {
464+
// For JSONB fields, handle both numeric and text sorting.
465+
// We need to use the raw JSONB reference for jsonb_typeof, but columnName() for the actual sorting
466+
fieldClause = fmt.Sprintf("(CASE WHEN jsonb_typeof(%s->'%s') = 'number' THEN (%s)::numeric END) %s NULLS LAST, %s %s NULLS LAST", c.nestedColumn, key, c.columnName(key), direction, c.columnName(key), direction)
467+
} else {
468+
// Regular field.
469+
fieldClause = fmt.Sprintf(`%s %s`, c.columnName(key), direction)
470+
}
471+
472+
parts = append(parts, fieldClause)
473+
}
474+
475+
if len(parts) == 0 {
476+
return "", nil
477+
}
478+
479+
return strings.Join(parts, ", "), nil
480+
}

filter/converter_test.go

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package filter_test
22

33
import (
44
"database/sql"
5+
"encoding/json"
56
"fmt"
67
"reflect"
78
"testing"
@@ -641,3 +642,134 @@ func TestConverter_AccessControl(t *testing.T) {
641642
t.Run("nested but disallow password, disallow",
642643
f(`{"password": "hacks"}`, no("password"), filter.WithNestedJSONB("meta", "created_at"), filter.WithDisallowColumns("password")))
643644
}
645+
646+
func TestConverter_ConvertOrderBy(t *testing.T) {
647+
tests := []struct {
648+
name string
649+
options []filter.Option
650+
input string
651+
expected string
652+
err error
653+
}{
654+
{
655+
"single field ascending",
656+
[]filter.Option{filter.WithAllowAllColumns()},
657+
`{"playerCount": 1}`,
658+
`"playerCount" ASC`,
659+
nil,
660+
},
661+
{
662+
"single field descending",
663+
[]filter.Option{filter.WithAllowAllColumns()},
664+
`{"playerCount": -1}`,
665+
`"playerCount" DESC`,
666+
nil,
667+
},
668+
{
669+
"multiple fields",
670+
[]filter.Option{filter.WithAllowAllColumns()},
671+
`{"playerCount": -1, "name": 1}`,
672+
`"playerCount" DESC, "name" ASC`,
673+
nil,
674+
},
675+
{
676+
"nested JSONB single field ascending",
677+
[]filter.Option{filter.WithNestedJSONB("customdata", "created_at")},
678+
`{"map": 1}`,
679+
`(CASE WHEN jsonb_typeof(customdata->'map') = 'number' THEN ("customdata"->>'map')::numeric END) ASC NULLS LAST, "customdata"->>'map' ASC NULLS LAST`,
680+
nil,
681+
},
682+
{
683+
"nested JSONB single field descending",
684+
[]filter.Option{filter.WithNestedJSONB("customdata", "created_at")},
685+
`{"map": -1}`,
686+
`(CASE WHEN jsonb_typeof(customdata->'map') = 'number' THEN ("customdata"->>'map')::numeric END) DESC NULLS LAST, "customdata"->>'map' DESC NULLS LAST`,
687+
nil,
688+
},
689+
{
690+
"nested JSONB multiple fields",
691+
[]filter.Option{filter.WithNestedJSONB("customdata", "created_at")},
692+
`{"map": 1, "bar": -1}`,
693+
`(CASE WHEN jsonb_typeof(customdata->'map') = 'number' THEN ("customdata"->>'map')::numeric END) ASC NULLS LAST, "customdata"->>'map' ASC NULLS LAST, (CASE WHEN jsonb_typeof(customdata->'bar') = 'number' THEN ("customdata"->>'bar')::numeric END) DESC NULLS LAST, "customdata"->>'bar' DESC NULLS LAST`,
694+
nil,
695+
},
696+
{
697+
"mixed nested and regular fields",
698+
[]filter.Option{filter.WithNestedJSONB("customdata", "created_at")},
699+
`{"created_at": 1, "map": -1}`,
700+
`"created_at" ASC, (CASE WHEN jsonb_typeof(customdata->'map') = 'number' THEN ("customdata"->>'map')::numeric END) DESC NULLS LAST, "customdata"->>'map' DESC NULLS LAST`,
701+
nil,
702+
},
703+
{
704+
"field name with spaces",
705+
[]filter.Option{filter.WithAllowAllColumns()},
706+
`{"my_field": 1}`,
707+
`"my_field" ASC`,
708+
nil,
709+
},
710+
{
711+
"empty object",
712+
[]filter.Option{filter.WithAllowAllColumns()},
713+
`{}`,
714+
``,
715+
nil,
716+
},
717+
{
718+
"invalid field name for SQL injection",
719+
[]filter.Option{filter.WithAllowAllColumns()},
720+
`{"my field": 1}`,
721+
``,
722+
fmt.Errorf("invalid column name: my field"),
723+
},
724+
{
725+
"invalid direction value",
726+
[]filter.Option{filter.WithAllowAllColumns()},
727+
`{"playerCount": 2}`,
728+
``,
729+
filter.InvalidOrderDirectionError{Field: "playerCount", Value: json.Number("2")},
730+
},
731+
{
732+
"invalid direction string",
733+
[]filter.Option{filter.WithAllowAllColumns()},
734+
`{"playerCount": "asc"}`,
735+
``,
736+
filter.InvalidOrderDirectionError{Field: "playerCount", Value: "asc"},
737+
},
738+
{
739+
"disallowed column",
740+
[]filter.Option{filter.WithAllowColumns("name")},
741+
`{"playerCount": 1}`,
742+
``,
743+
filter.ColumnNotAllowedError{Column: "playerCount"},
744+
},
745+
}
746+
747+
for _, tt := range tests {
748+
t.Run(tt.name, func(t *testing.T) {
749+
converter, err := filter.NewConverter(tt.options...)
750+
if err != nil {
751+
t.Fatalf("Failed to create converter: %v", err)
752+
}
753+
754+
result, err := converter.ConvertOrderBy([]byte(tt.input))
755+
756+
if tt.err != nil {
757+
if err == nil {
758+
t.Fatalf("Expected error %v, got nil", tt.err)
759+
}
760+
if err.Error() != tt.err.Error() {
761+
t.Errorf("Expected error %v, got %v", tt.err, err)
762+
}
763+
return
764+
}
765+
766+
if err != nil {
767+
t.Fatalf("Unexpected error: %v", err)
768+
}
769+
770+
if result != tt.expected {
771+
t.Errorf("Expected %q, got %q", tt.expected, result)
772+
}
773+
})
774+
}
775+
}

filter/errors.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,12 @@ type ColumnNotAllowedError struct {
1212
func (e ColumnNotAllowedError) Error() string {
1313
return fmt.Sprintf("column not allowed: %s", e.Column)
1414
}
15+
16+
type InvalidOrderDirectionError struct {
17+
Field string
18+
Value any
19+
}
20+
21+
func (e InvalidOrderDirectionError) Error() string {
22+
return fmt.Sprintf("invalid order direction for field %s: %v (must be 1 or -1)", e.Field, e.Value)
23+
}

filter/util.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
package filter
22

3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"fmt"
7+
)
8+
39
func isNumeric(v any) bool {
410
// json.Unmarshal returns float64 for all numbers
511
// so we only need to check for float64.
@@ -71,3 +77,47 @@ func isValidPostgresIdentifier(s string) bool {
7177

7278
return true
7379
}
80+
81+
func objectInOrder(b []byte) ([]struct{ Key string; Value any }, error) {
82+
dec := json.NewDecoder(bytes.NewReader(b))
83+
dec.UseNumber()
84+
85+
// expect {
86+
tok, err := dec.Token()
87+
if err != nil {
88+
return nil, err
89+
}
90+
if d, ok := tok.(json.Delim); !ok || d != '{' {
91+
return nil, fmt.Errorf("expected object, got %v", tok)
92+
}
93+
94+
var result []struct{ Key string; Value any }
95+
96+
for dec.More() {
97+
// key
98+
tok, err := dec.Token()
99+
if err != nil {
100+
return nil, err
101+
}
102+
key, ok := tok.(string)
103+
if !ok {
104+
return nil, fmt.Errorf("expected string key, got %v", tok)
105+
}
106+
107+
// value
108+
var v any
109+
if err := dec.Decode(&v); err != nil {
110+
return nil, err
111+
}
112+
113+
result = append(result, struct{ Key string; Value any }{Key: key, Value: v})
114+
}
115+
116+
// consume }
117+
_, err = dec.Token()
118+
if err != nil {
119+
return nil, err
120+
}
121+
122+
return result, nil
123+
}

0 commit comments

Comments
 (0)