diff --git a/hooks/config.go b/hooks/config.go new file mode 100644 index 00000000000..4cec6a03bf6 --- /dev/null +++ b/hooks/config.go @@ -0,0 +1,122 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package hooks + +import ( + "context" + "sync" + "sync/atomic" + + "github.com/dgraph-io/dgo/v250/protos/api" + "github.com/dgraph-io/dgraph/v25/protos/pb" +) + +type ZeroHooks interface { + AssignUIDs(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) + AssignTimestamps(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) + AssignNsIDs(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) + CommitOrAbort(ctx context.Context, tc *api.TxnContext) (*api.TxnContext, error) + ApplyMutations(ctx context.Context, m *pb.Mutations) (*api.TxnContext, error) +} + +type ZeroHooksFns struct { + AssignUIDsFn func(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) + AssignTimestampsFn func(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) + AssignNsIDsFn func(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) + CommitOrAbortFn func(ctx context.Context, tc *api.TxnContext) (*api.TxnContext, error) + ApplyMutationsFn func(ctx context.Context, m *pb.Mutations) (*api.TxnContext, error) +} + +func (h ZeroHooksFns) AssignUIDs(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) { + return h.AssignUIDsFn(ctx, num) +} + +func (h ZeroHooksFns) AssignTimestamps(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) { + return h.AssignTimestampsFn(ctx, num) +} + +func (h ZeroHooksFns) AssignNsIDs(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) { + return h.AssignNsIDsFn(ctx, num) +} + +func (h ZeroHooksFns) CommitOrAbort(ctx context.Context, tc *api.TxnContext) (*api.TxnContext, error) { + return h.CommitOrAbortFn(ctx, tc) +} + +func (h ZeroHooksFns) ApplyMutations(ctx context.Context, m *pb.Mutations) (*api.TxnContext, error) { + return h.ApplyMutationsFn(ctx, m) +} + +// Config holds the configuration for embedded mode operation. +type Config struct { + // Hooks for bypassing Zero operations + ZeroHooks ZeroHooks + + // DataDir is the directory where data files are stored + DataDir string + + // CacheSizeMB is the size of the in-memory cache in megabytes + CacheSizeMB int64 +} + +var ( + // globalConfig holds the current embedded configuration + globalConfig atomic.Pointer[Config] + + defaultZeroHooks atomic.Value + + // enabled tracks whether embedded mode is active + enabled atomic.Bool + + // mu protects initialization + mu sync.Mutex +) + +// Enable activates embedded mode with the given configuration. +// This must be called before any Dgraph operations. +func Enable(cfg *Config) { + mu.Lock() + defer mu.Unlock() + + globalConfig.Store(cfg) + enabled.Store(true) +} + +// Disable deactivates embedded mode. +func Disable() { + mu.Lock() + defer mu.Unlock() + + enabled.Store(false) + globalConfig.Store(nil) +} + +// IsEnabled returns true if embedded mode is currently active. +func IsEnabled() bool { + return enabled.Load() +} + +// GetConfig returns the current embedded configuration, or nil if not enabled. +func GetConfig() *Config { + return globalConfig.Load() +} + +func SetDefaultZeroHooks(h ZeroHooks) { + defaultZeroHooks.Store(h) +} + +// GetHooks returns the active Zero hooks. +// If embedded mode is not enabled, it returns the default hooks implementation. +func GetHooks() ZeroHooks { + cfg := globalConfig.Load() + if cfg != nil && cfg.ZeroHooks != nil { + return cfg.ZeroHooks + } + if h := defaultZeroHooks.Load(); h != nil { + return h.(ZeroHooks) + } + panic("no ZeroHooks configured - ensure worker package is imported or hooks.Enable() is called") +} diff --git a/hooks/config_test.go b/hooks/config_test.go new file mode 100644 index 00000000000..0980fcb6739 --- /dev/null +++ b/hooks/config_test.go @@ -0,0 +1,238 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package hooks + +import ( + "context" + "sync" + "sync/atomic" + "testing" + + "github.com/dgraph-io/dgo/v250/protos/api" + "github.com/dgraph-io/dgraph/v25/protos/pb" + "github.com/stretchr/testify/require" +) + +// mockZeroHooks is a test implementation of ZeroHooks +type mockZeroHooks struct { + assignUIDsCalled bool + assignTimestampsCalled bool + assignNsIDsCalled bool + commitOrAbortCalled bool + applyMutationsCalled bool +} + +func (m *mockZeroHooks) AssignUIDs(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) { + m.assignUIDsCalled = true + return &pb.AssignedIds{StartId: 1, EndId: 10}, nil +} + +func (m *mockZeroHooks) AssignTimestamps(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) { + m.assignTimestampsCalled = true + return &pb.AssignedIds{StartId: 100}, nil +} + +func (m *mockZeroHooks) AssignNsIDs(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) { + m.assignNsIDsCalled = true + return &pb.AssignedIds{StartId: 1}, nil +} + +func (m *mockZeroHooks) CommitOrAbort(ctx context.Context, tc *api.TxnContext) (*api.TxnContext, error) { + m.commitOrAbortCalled = true + return &api.TxnContext{CommitTs: 200}, nil +} + +func (m *mockZeroHooks) ApplyMutations(ctx context.Context, mut *pb.Mutations) (*api.TxnContext, error) { + m.applyMutationsCalled = true + return &api.TxnContext{}, nil +} + +func resetGlobalState() { + mu.Lock() + defer mu.Unlock() + enabled.Store(false) + globalConfig.Store(nil) + // Store a true nil by using a new atomic.Value (zero value has nil) + defaultZeroHooks = atomic.Value{} +} + +func TestEnableDisable(t *testing.T) { + resetGlobalState() + defer resetGlobalState() + + require.False(t, IsEnabled(), "should start disabled") + require.Nil(t, GetConfig(), "config should be nil when disabled") + + cfg := &Config{ + DataDir: "/tmp/test", + CacheSizeMB: 128, + } + Enable(cfg) + + require.True(t, IsEnabled(), "should be enabled after Enable()") + require.NotNil(t, GetConfig(), "config should not be nil after Enable()") + require.Equal(t, "/tmp/test", GetConfig().DataDir) + require.Equal(t, int64(128), GetConfig().CacheSizeMB) + + Disable() + + require.False(t, IsEnabled(), "should be disabled after Disable()") + require.Nil(t, GetConfig(), "config should be nil after Disable()") +} + +func TestGetHooksWithCustomConfig(t *testing.T) { + resetGlobalState() + defer resetGlobalState() + + mock := &mockZeroHooks{} + cfg := &Config{ + ZeroHooks: mock, + } + Enable(cfg) + + hooks := GetHooks() + require.NotNil(t, hooks) + require.Equal(t, mock, hooks, "should return custom hooks from config") +} + +func TestGetHooksWithDefaultHooks(t *testing.T) { + resetGlobalState() + defer resetGlobalState() + + mock := &mockZeroHooks{} + SetDefaultZeroHooks(mock) + + hooks := GetHooks() + require.NotNil(t, hooks) + require.Equal(t, mock, hooks, "should return default hooks when no config") +} + +func TestGetHooksPanicsWhenNoHooksConfigured(t *testing.T) { + resetGlobalState() + defer resetGlobalState() + + require.Panics(t, func() { + GetHooks() + }, "should panic when no hooks configured") +} + +func TestGetHooksCustomOverridesDefault(t *testing.T) { + resetGlobalState() + defer resetGlobalState() + + defaultMock := &mockZeroHooks{} + SetDefaultZeroHooks(defaultMock) + + customMock := &mockZeroHooks{} + cfg := &Config{ + ZeroHooks: customMock, + } + Enable(cfg) + + hooks := GetHooks() + require.Equal(t, customMock, hooks, "custom hooks should override default") +} + +func TestZeroHooksFnsWrapper(t *testing.T) { + var assignUIDsCalled, timestampsCalled, nsIDsCalled, commitCalled, mutationsCalled bool + + wrapper := ZeroHooksFns{ + AssignUIDsFn: func(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) { + assignUIDsCalled = true + return &pb.AssignedIds{StartId: 1}, nil + }, + AssignTimestampsFn: func(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) { + timestampsCalled = true + return &pb.AssignedIds{StartId: 100}, nil + }, + AssignNsIDsFn: func(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) { + nsIDsCalled = true + return &pb.AssignedIds{StartId: 1}, nil + }, + CommitOrAbortFn: func(ctx context.Context, tc *api.TxnContext) (*api.TxnContext, error) { + commitCalled = true + return &api.TxnContext{CommitTs: 200}, nil + }, + ApplyMutationsFn: func(ctx context.Context, m *pb.Mutations) (*api.TxnContext, error) { + mutationsCalled = true + return &api.TxnContext{}, nil + }, + } + + ctx := context.Background() + + _, _ = wrapper.AssignUIDs(ctx, &pb.Num{}) + require.True(t, assignUIDsCalled, "AssignUIDs should delegate to function") + + _, _ = wrapper.AssignTimestamps(ctx, &pb.Num{}) + require.True(t, timestampsCalled, "AssignTimestamps should delegate to function") + + _, _ = wrapper.AssignNsIDs(ctx, &pb.Num{}) + require.True(t, nsIDsCalled, "AssignNsIDs should delegate to function") + + _, _ = wrapper.CommitOrAbort(ctx, &api.TxnContext{}) + require.True(t, commitCalled, "CommitOrAbort should delegate to function") + + _, _ = wrapper.ApplyMutations(ctx, &pb.Mutations{}) + require.True(t, mutationsCalled, "ApplyMutations should delegate to function") +} + +func TestConcurrentEnableDisable(t *testing.T) { + resetGlobalState() + defer resetGlobalState() + + var wg sync.WaitGroup + iterations := 100 + + // Concurrent enables + for i := 0; i < iterations; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + cfg := &Config{ + DataDir: "/tmp/test", + CacheSizeMB: int64(i), + } + Enable(cfg) + }(i) + } + + // Concurrent disables + for i := 0; i < iterations; i++ { + wg.Add(1) + go func() { + defer wg.Done() + Disable() + }() + } + + // Concurrent reads + for i := 0; i < iterations; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _ = IsEnabled() + _ = GetConfig() + }() + } + + wg.Wait() + // Test passes if no race conditions detected +} + +func TestSetDefaultZeroHooks(t *testing.T) { + resetGlobalState() + defer resetGlobalState() + + mock1 := &mockZeroHooks{} + mock2 := &mockZeroHooks{} + + SetDefaultZeroHooks(mock1) + require.Equal(t, mock1, GetHooks()) + + SetDefaultZeroHooks(mock2) + require.Equal(t, mock2, GetHooks()) +} diff --git a/hooks/init.go b/hooks/init.go new file mode 100644 index 00000000000..8fb5d2727ae --- /dev/null +++ b/hooks/init.go @@ -0,0 +1,17 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package hooks + +// This package provides configuration and hooks for running Dgraph in embedded mode. +// The actual initialization functions that would create import cycles are intentionally +// left to be called directly by the host application using the +// individual packages (edgraph, worker, posting, schema, x). +// +// Usage: +// 1. Call hooks.Enable() with your ZeroHooks configuration +// 2. Initialize packages directly: edgraph.Init(), worker.State.InitStorage(), etc. +// 3. The hooks in this package will be called automatically by worker functions +// 4. Call hooks.Disable() when shutting down diff --git a/worker/default_zero_hooks.go b/worker/default_zero_hooks.go new file mode 100644 index 00000000000..482b2c27177 --- /dev/null +++ b/worker/default_zero_hooks.go @@ -0,0 +1,108 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package worker + +import ( + "context" + + "google.golang.org/grpc/metadata" + + "github.com/dgraph-io/dgo/v250/protos/api" + "github.com/dgraph-io/dgraph/v25/conn" + "github.com/dgraph-io/dgraph/v25/hooks" + "github.com/dgraph-io/dgraph/v25/protos/pb" + "github.com/dgraph-io/dgraph/v25/x" +) + +type defaultZeroHooks struct{} + +func init() { + hooks.SetDefaultZeroHooks(defaultZeroHooks{}) +} + +func (defaultZeroHooks) AssignUIDs(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) { + if num.Type == 0 { + num.Type = pb.Num_UID + } + + // Pass on the incoming metadata to the zero. Namespace from the metadata is required by zero. + if md, ok := metadata.FromIncomingContext(ctx); ok { + ctx = metadata.NewOutgoingContext(ctx, md) + } + + pl := groups().Leader(0) + if pl == nil { + return nil, conn.ErrNoConnection + } + + c := pb.NewZeroClient(pl.Get()) + return c.AssignIds(ctx, num) +} + +func (defaultZeroHooks) AssignTimestamps(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) { + pl := groups().connToZeroLeader() + if pl == nil { + return nil, conn.ErrNoConnection + } + + c := pb.NewZeroClient(pl.Get()) + return c.Timestamps(ctx, num) +} + +func (defaultZeroHooks) AssignNsIDs(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) { + num.Type = pb.Num_NS_ID + + pl := groups().Leader(0) + if pl == nil { + return nil, conn.ErrNoConnection + } + + c := pb.NewZeroClient(pl.Get()) + return c.AssignIds(ctx, num) +} + +func (defaultZeroHooks) CommitOrAbort(ctx context.Context, tc *api.TxnContext) (*api.TxnContext, error) { + pl := groups().Leader(0) + if pl == nil { + return nil, conn.ErrNoConnection + } + + // Do de-duplication before sending the request to zero. + tc.Keys = x.Unique(tc.Keys) + tc.Preds = x.Unique(tc.Preds) + + zc := pb.NewZeroClient(pl.Get()) + return zc.CommitOrAbort(ctx, tc) +} + +func (defaultZeroHooks) ApplyMutations(ctx context.Context, m *pb.Mutations) (*api.TxnContext, error) { + if groups().ServesGroup(m.GroupId) { + txnCtx := &api.TxnContext{} + return txnCtx, (&grpcWorker{}).proposeAndWait(ctx, txnCtx, m) + } + + pl := groups().Leader(m.GroupId) + if pl == nil { + return nil, conn.ErrNoConnection + } + + var tc *api.TxnContext + c := pb.NewWorkerClient(pl.Get()) + + ch := make(chan error, 1) + go func() { + var err error + tc, err = c.Mutate(ctx, m) + ch <- err + }() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case err := <-ch: + return tc, err + } +} diff --git a/worker/mutation.go b/worker/mutation.go index fdac2a41c1b..57740edaeb7 100644 --- a/worker/mutation.go +++ b/worker/mutation.go @@ -20,14 +20,13 @@ import ( "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" - "google.golang.org/grpc/metadata" "google.golang.org/protobuf/proto" "github.com/dgraph-io/badger/v4" "github.com/dgraph-io/badger/v4/y" "github.com/dgraph-io/dgo/v250" "github.com/dgraph-io/dgo/v250/protos/api" - "github.com/dgraph-io/dgraph/v25/conn" + "github.com/dgraph-io/dgraph/v25/hooks" "github.com/dgraph-io/dgraph/v25/posting" "github.com/dgraph-io/dgraph/v25/protos/pb" "github.com/dgraph-io/dgraph/v25/schema" @@ -598,49 +597,31 @@ func ValidateAndConvert(edge *pb.DirectedEdge, su *pb.SchemaUpdate) error { // AssignNsIdsOverNetwork sends a request to assign Namespace IDs to the current zero leader. func AssignNsIdsOverNetwork(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) { - pl := groups().Leader(0) - if pl == nil { - return nil, conn.ErrNoConnection - } - - con := pl.Get() - c := pb.NewZeroClient(con) - num.Type = pb.Num_NS_ID - return c.AssignIds(ctx, num) + h := hooks.GetHooks() + return h.AssignNsIDs(ctx, num) } // AssignUidsOverNetwork sends a request to assign UIDs from the current zero leader. func AssignUidsOverNetwork(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) { - // Pass on the incoming metadata to the zero. Namespace from the metadata is required by zero. - if md, ok := metadata.FromIncomingContext(ctx); ok { - ctx = metadata.NewOutgoingContext(ctx, md) - } - pl := groups().Leader(0) - if pl == nil { - return nil, conn.ErrNoConnection - } - - con := pl.Get() - c := pb.NewZeroClient(con) + h := hooks.GetHooks() num.Type = pb.Num_UID - return c.AssignIds(ctx, num) + return h.AssignUIDs(ctx, num) } // Timestamps sends a request to assign startTs for a new transaction to the current zero leader. func Timestamps(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) { - pl := groups().connToZeroLeader() - if pl == nil { - return nil, conn.ErrNoConnection - } - - con := pl.Get() - c := pb.NewZeroClient(con) - return c.Timestamps(ctx, num) + h := hooks.GetHooks() + return h.AssignTimestamps(ctx, num) } func fillTxnContext(tctx *api.TxnContext, startTs uint64, isErrored bool) { if txn := posting.Oracle().GetTxn(startTs); txn != nil { - txn.FillContext(tctx, groups().groupId(), isErrored) + // In embedded mode, use group 1 as the default + gid := uint32(1) + if !hooks.IsEnabled() { + gid = groups().groupId() + } + txn.FillContext(tctx, gid, isErrored) } // We do not need to fill linread mechanism anymore, because transaction // start ts is sufficient to wait for, to achieve lin reads. @@ -650,38 +631,9 @@ func fillTxnContext(tctx *api.TxnContext, startTs uint64, isErrored bool) { // the leader of the group gid for proposing. func proposeOrSend(ctx context.Context, gid uint32, m *pb.Mutations, chr chan res) { res := res{} - if groups().ServesGroup(gid) { - res.ctx = &api.TxnContext{} - res.err = (&grpcWorker{}).proposeAndWait(ctx, res.ctx, m) - chr <- res - return - } - - pl := groups().Leader(gid) - if pl == nil { - res.err = conn.ErrNoConnection - chr <- res - return - } - - var tc *api.TxnContext - c := pb.NewWorkerClient(pl.Get()) - - ch := make(chan error, 1) - go func() { - var err error - tc, err = c.Mutate(ctx, m) - ch <- err - }() - - select { - case <-ctx.Done(): - res.err = ctx.Err() - res.ctx = nil - case err := <-ch: - res.err = err - res.ctx = tc - } + + h := hooks.GetHooks() + res.ctx, res.err = h.ApplyMutations(ctx, m) chr <- res } @@ -895,18 +847,8 @@ func CommitOverNetwork(ctx context.Context, tc *api.TxnContext) (uint64, error) clientDiscard = true } - pl := groups().Leader(0) - if pl == nil { - return 0, conn.ErrNoConnection - } - - // Do de-duplication before sending the request to zero. - tc.Keys = x.Unique(tc.Keys) - tc.Preds = x.Unique(tc.Preds) - - zc := pb.NewZeroClient(pl.Get()) - tctx, err := zc.CommitOrAbort(ctx, tc) - + h := hooks.GetHooks() + tctx, err := h.CommitOrAbort(ctx, tc) if err != nil { span.AddEvent("Error in CommitOrAbort", trace.WithAttributes( attribute.String("error", err.Error()))) @@ -918,7 +860,6 @@ func CommitOverNetwork(ctx context.Context, tc *api.TxnContext) (uint64, error) if tctx.Aborted || tctx.CommitTs == 0 { if !clientDiscard { - // The server aborted the txn (not the client) ostats.Record(ctx, x.TxnAborts.M(1)) } return 0, dgo.ErrAborted