Skip to content

Commit 55e3b79

Browse files
Merge commit from fork
* Fix injection issues * Add exclusion /debug/vars key; consolidate func; apply to bulk and live loader too
1 parent ffa993c commit 55e3b79

9 files changed

Lines changed: 342 additions & 26 deletions

File tree

dgraph/cmd/alpha/http_test.go

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -827,24 +827,39 @@ func TestHealth(t *testing.T) {
827827
require.True(t, info[0].Uptime > int64(time.Duration(1)))
828828
}
829829

830-
// TestPprofCmdlineNotExposed ensures that /debug/pprof/cmdline is not reachable
831-
// without authentication. The endpoint exposes the full process command line,
832-
// which may include the admin token passed via --security "token=...".
833-
// The other pprof sub-endpoints should remain accessible.
834-
func TestPprofCmdlineNotExposed(t *testing.T) {
835-
// cmdline must be blocked — it leaks the admin token from process args.
830+
// TestCmdlineEndpointsNotExposed ensures that endpoints which expose the full
831+
// process command line are not reachable without authentication. Both
832+
// /debug/pprof/cmdline (net/http/pprof) and /debug/vars (expvar, which
833+
// publishes os.Args as "cmdline") can leak the admin token passed via
834+
// --security "token=...".
835+
func TestCmdlineEndpointsNotExposed(t *testing.T) {
836+
// /debug/pprof/cmdline must be blocked.
836837
resp, err := http.Get(fmt.Sprintf("%s/debug/pprof/cmdline", addr))
837838
require.NoError(t, err)
838839
defer resp.Body.Close()
839840
require.Equal(t, http.StatusNotFound, resp.StatusCode,
840841
"/debug/pprof/cmdline should return 404; got %d", resp.StatusCode)
841842

842-
// Sanity-check that other pprof endpoints are still reachable.
843-
resp2, err := http.Get(fmt.Sprintf("%s/debug/pprof/heap", addr))
843+
// /debug/vars must still be reachable but must NOT include "cmdline".
844+
resp2, err := http.Get(fmt.Sprintf("%s/debug/vars", addr))
844845
require.NoError(t, err)
845846
defer resp2.Body.Close()
846847
require.Equal(t, http.StatusOK, resp2.StatusCode,
847-
"/debug/pprof/heap should return 200; got %d", resp2.StatusCode)
848+
"/debug/vars should return 200; got %d", resp2.StatusCode)
849+
body, err := io.ReadAll(resp2.Body)
850+
require.NoError(t, err)
851+
var vars map[string]json.RawMessage
852+
require.NoError(t, json.Unmarshal(body, &vars))
853+
_, hasCmdline := vars["cmdline"]
854+
require.False(t, hasCmdline,
855+
"/debug/vars response must not contain the cmdline key")
856+
857+
// Sanity-check that other pprof endpoints are still reachable.
858+
resp3, err := http.Get(fmt.Sprintf("%s/debug/pprof/heap", addr))
859+
require.NoError(t, err)
860+
defer resp3.Body.Close()
861+
require.Equal(t, http.StatusOK, resp3.StatusCode,
862+
"/debug/pprof/heap should return 200; got %d", resp3.StatusCode)
848863
}
849864

850865
func setDrainingMode(t *testing.T, enable bool, accessJwt string) {

dgraph/cmd/alpha/run.go

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -620,16 +620,7 @@ func setupServer(closer *z.Closer, enableMcp bool) {
620620
}
621621
}
622622

623-
// Block /debug/pprof/cmdline — importing net/http/pprof registers it on
624-
// http.DefaultServeMux, but it exposes the full process command line which
625-
// may include the admin token from --security "token=...".
626-
serverHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
627-
if r.URL.Path == "/debug/pprof/cmdline" {
628-
http.NotFound(w, r)
629-
return
630-
}
631-
http.DefaultServeMux.ServeHTTP(w, r)
632-
})
623+
serverHandler := x.SanitizedDefaultServeMux()
633624
go x.StartListenHttpAndHttps(httpListener, tlsCfg, x.ServerCloser, serverHandler)
634625

635626
go func() {

dgraph/cmd/alpha/upsert_test.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2999,6 +2999,76 @@ func TestLargeStringIndex(t *testing.T) {
29992999
`{"predicate":"name_term","type":"string","index":true,"tokenizer":["term"]}`)
30003000
}
30013001

3002+
// TestDQLInjectionViaCondField is a security regression test for LEAD-001.
3003+
// It verifies that a crafted cond field containing an injected DQL query block
3004+
// is rejected by input validation before reaching the query builder.
3005+
func TestDQLInjectionViaCondField(t *testing.T) {
3006+
require.NoError(t, dropAll())
3007+
require.NoError(t, alterSchema(`
3008+
name: string @index(exact) .
3009+
email: string @index(exact) .
3010+
secret: string .
3011+
`))
3012+
3013+
// Seed data that should NOT be readable via a mutation request.
3014+
seed := `{
3015+
set {
3016+
_:u1 <dgraph.type> "User" .
3017+
_:u1 <name> "Alice" .
3018+
_:u1 <email> "alice@example.com" .
3019+
_:u1 <secret> "SSN-111-22-3333" .
3020+
3021+
_:u2 <dgraph.type> "User" .
3022+
_:u2 <name> "Bob" .
3023+
_:u2 <email> "bob@example.com" .
3024+
_:u2 <secret> "API_KEY_secret_abc123" .
3025+
}
3026+
}`
3027+
_, err := mutationWithTs(mutationInp{body: seed, typ: "application/rdf", commitNow: true})
3028+
require.NoError(t, err)
3029+
3030+
// Craft the injection payload. The cond value closes the @if() clause and
3031+
// appends an entirely new named query block "leak" that would exfiltrate all
3032+
// data if the injection were not blocked.
3033+
injectionPayload := `{
3034+
"query": "{ q(func: uid(0x1)) { uid } }",
3035+
"mutations": [{
3036+
"set": [{"uid": "0x1", "dgraph.type": "Dummy"}],
3037+
"cond": "@if(eq(name, \"nonexistent\"))\n leak(func: has(dgraph.type)) { uid dgraph.type name email secret }"
3038+
}]
3039+
}`
3040+
3041+
// The injection payload must be rejected by cond validation.
3042+
_, err = mutationWithTs(mutationInp{
3043+
body: injectionPayload,
3044+
typ: "application/json",
3045+
commitNow: true,
3046+
})
3047+
require.Error(t, err)
3048+
require.Contains(t, err.Error(), "invalid cond value")
3049+
3050+
// Verify that no mutation was applied — the request was rejected entirely.
3051+
q := `{ q(func: has(dgraph.type)) { uid name } }`
3052+
res, _, err := queryWithTs(queryInp{body: q, typ: "application/dql"})
3053+
require.NoError(t, err)
3054+
require.NotContains(t, res, "Dummy")
3055+
3056+
// Verify that a legitimate conditional upsert still works.
3057+
legitimateUpsert := `{
3058+
"query": "{ q(func: eq(name, \"Alice\")) { v as uid } }",
3059+
"mutations": [{
3060+
"set": [{"uid": "uid(v)", "email": "alice-updated@example.com"}],
3061+
"cond": "@if(eq(len(v), 1))"
3062+
}]
3063+
}`
3064+
_, err = mutationWithTs(mutationInp{
3065+
body: legitimateUpsert,
3066+
typ: "application/json",
3067+
commitNow: true,
3068+
})
3069+
require.NoError(t, err)
3070+
}
3071+
30023072
func TestStringWithQuote(t *testing.T) {
30033073
require.NoError(t, dropAll())
30043074
require.NoError(t, alterSchemaWithRetry(`name: string @unique @index(exact) .`))

dgraph/cmd/bulk/run.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ func RunBulkLoader(opt BulkOptions) {
280280
maxOpenFilesWarning()
281281

282282
go func() {
283-
log.Fatal(http.ListenAndServe(opt.HttpAddr, nil))
283+
log.Fatal(http.ListenAndServe(opt.HttpAddr, x.SanitizedDefaultServeMux()))
284284
}()
285285
http.HandleFunc("/jemalloc", x.JemallocHandler)
286286

dgraph/cmd/debug/run.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -943,7 +943,7 @@ func run() {
943943
go func() {
944944
for i := 8080; i < 9080; i++ {
945945
fmt.Printf("Listening for /debug HTTP requests at port: %d\n", i)
946-
if err := http.ListenAndServe(fmt.Sprintf("localhost:%d", i), nil); err != nil {
946+
if err := http.ListenAndServe(fmt.Sprintf("localhost:%d", i), x.SanitizedDefaultServeMux()); err != nil {
947947
fmt.Println("Port busy. Trying another one...")
948948
continue
949949
}

dgraph/cmd/live/run.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -705,7 +705,7 @@ func run() error {
705705
z.SetTmpDir(opt.tmpDir)
706706

707707
go func() {
708-
if err := http.ListenAndServe(opt.httpAddr, nil); err != nil {
708+
if err := http.ListenAndServe(opt.httpAddr, x.SanitizedDefaultServeMux()); err != nil {
709709
glog.Errorf("Error while starting HTTP server: %+v", err)
710710
}
711711
}()

edgraph/server.go

Lines changed: 110 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"fmt"
1414
"math"
1515
"net"
16+
"regexp"
1617
"sort"
1718
"strconv"
1819
"strings"
@@ -709,11 +710,102 @@ func validateMutation(ctx context.Context, edges []*pb.DirectedEdge) error {
709710
return nil
710711
}
711712

713+
// validateCondValue checks that a cond string is a well-formed @if(...) or @filter(...)
714+
// clause with balanced parentheses and no trailing content. This prevents DQL injection
715+
// via crafted cond values that close the parenthesized expression and append additional
716+
// query blocks.
717+
func validateCondValue(cond string) error {
718+
cond = strings.TrimSpace(cond)
719+
if cond == "" {
720+
return nil
721+
}
722+
723+
lower := strings.ToLower(cond)
724+
if !strings.HasPrefix(lower, "@if(") && !strings.HasPrefix(lower, "@filter(") {
725+
return errors.Errorf("invalid cond value: must start with @if( or @filter(")
726+
}
727+
728+
openIdx := strings.Index(cond, "(")
729+
if openIdx == -1 {
730+
return errors.Errorf("invalid cond value: missing opening parenthesis")
731+
}
732+
733+
depth := 0
734+
inString := false
735+
escaped := false
736+
closingIdx := -1
737+
738+
for i := openIdx; i < len(cond); i++ {
739+
ch := cond[i]
740+
if escaped {
741+
escaped = false
742+
continue
743+
}
744+
if ch == '\\' {
745+
escaped = true
746+
continue
747+
}
748+
if ch == '"' {
749+
inString = !inString
750+
continue
751+
}
752+
if inString {
753+
continue
754+
}
755+
if ch == '(' {
756+
depth++
757+
} else if ch == ')' {
758+
depth--
759+
if depth == 0 {
760+
closingIdx = i
761+
break
762+
}
763+
}
764+
}
765+
766+
if closingIdx == -1 {
767+
return errors.Errorf("invalid cond value: unbalanced parentheses")
768+
}
769+
770+
trailing := strings.TrimSpace(cond[closingIdx+1:])
771+
if trailing != "" {
772+
return errors.Errorf("invalid cond value: unexpected content after condition")
773+
}
774+
775+
return nil
776+
}
777+
778+
// valVarRegexp matches a valid val(variableName) reference used in upsert mutations.
779+
var valVarRegexp = regexp.MustCompile(`^val\([a-zA-Z_][a-zA-Z0-9_.]*\)$`)
780+
781+
// validateValObjectId checks that an ObjectId starting with "val(" is a well-formed
782+
// val(variableName) reference and contains no injected DQL syntax.
783+
func validateValObjectId(objectId string) error {
784+
if !valVarRegexp.MatchString(objectId) {
785+
return errors.Errorf("invalid val() reference in ObjectId: %q", objectId)
786+
}
787+
return nil
788+
}
789+
790+
// langTagRegexp matches a valid BCP 47 language tag (letters, digits, hyphens).
791+
var langTagRegexp = regexp.MustCompile(`^[a-zA-Z]+(-[a-zA-Z0-9]+)*$`)
792+
793+
// validateLangTag checks that a language tag contains only safe characters.
794+
func validateLangTag(lang string) error {
795+
if lang == "" {
796+
return nil
797+
}
798+
if !langTagRegexp.MatchString(lang) {
799+
return errors.Errorf("invalid language tag: %q", lang)
800+
}
801+
return nil
802+
}
803+
712804
// buildUpsertQuery modifies the query to evaluate the
713805
// @if condition defined in Conditional Upsert.
714-
func buildUpsertQuery(qc *queryContext) string {
806+
func buildUpsertQuery(qc *queryContext) (string, error) {
715807
if qc.req.Query == "" || len(qc.gmuList) == 0 {
716-
return qc.req.Query
808+
return qc.req.Query, nil
717809
}
718810

719811
qc.condVars = make([]string, len(qc.req.Mutations))
@@ -724,6 +816,10 @@ func buildUpsertQuery(qc *queryContext) string {
724816
for i, gmu := range qc.gmuList {
725817
isCondUpsert := strings.TrimSpace(gmu.Cond) != ""
726818
if isCondUpsert {
819+
if err := validateCondValue(gmu.Cond); err != nil {
820+
return "", err
821+
}
822+
727823
qc.condVars[i] = fmt.Sprintf("__dgraph_upsertcheck_%v__", strconv.Itoa(i))
728824
qc.uidRes[qc.condVars[i]] = nil
729825
// @if in upsert is same as @filter in the query
@@ -753,7 +849,7 @@ func buildUpsertQuery(qc *queryContext) string {
753849
}
754850

755851
x.Check2(upsertQB.WriteString(`}`))
756-
return upsertQB.String()
852+
return upsertQB.String(), nil
757853
}
758854

759855
// updateMutations updates the mutation and replaces uid(var) and val(var) with
@@ -1581,7 +1677,11 @@ func parseRequest(ctx context.Context, qc *queryContext) error {
15811677

15821678
qc.uidRes = make(map[string][]string)
15831679
qc.valRes = make(map[string]*types.ShardedMap)
1584-
upsertQuery = buildUpsertQuery(qc)
1680+
var err error
1681+
upsertQuery, err = buildUpsertQuery(qc)
1682+
if err != nil {
1683+
return err
1684+
}
15851685
needVars = findMutationVars(qc)
15861686
if upsertQuery == "" {
15871687
if len(needVars) > 0 {
@@ -1777,6 +1877,9 @@ func addQueryIfUnique(qctx context.Context, qc *queryContext) error {
17771877
// during the automatic serialization of a structure into JSON.
17781878
predicateName := fmt.Sprintf("<%v>", pred.Predicate)
17791879
if pred.Lang != "" {
1880+
if err := validateLangTag(pred.Lang); err != nil {
1881+
return err
1882+
}
17801883
predicateName = fmt.Sprintf("%v@%v", predicateName, pred.Lang)
17811884
}
17821885

@@ -1814,6 +1917,9 @@ func addQueryIfUnique(qctx context.Context, qc *queryContext) error {
18141917
}
18151918
qc.uniqueVars[uniqueVarMapKey] = uniquePredMeta{queryVar: queryVar}
18161919
} else {
1920+
if err := validateValObjectId(pred.ObjectId); err != nil {
1921+
return err
1922+
}
18171923
valQueryVar := fmt.Sprintf("__dgraph_uniquecheck_val_%v__", uniqueVarMapKey)
18181924
query := fmt.Sprintf(`%v as var(func: eq(%v,%v)){
18191925
uid

0 commit comments

Comments
 (0)