|
| 1 | +/* |
| 2 | + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. |
| 3 | + * SPDX-License-Identifier: Apache-2.0 |
| 4 | + */ |
| 5 | + |
| 6 | +package hooks |
| 7 | + |
| 8 | +import ( |
| 9 | + "context" |
| 10 | + "sync" |
| 11 | + "sync/atomic" |
| 12 | + "testing" |
| 13 | + |
| 14 | + "github.com/dgraph-io/dgo/v250/protos/api" |
| 15 | + "github.com/dgraph-io/dgraph/v25/protos/pb" |
| 16 | + "github.com/stretchr/testify/require" |
| 17 | +) |
| 18 | + |
| 19 | +// mockZeroHooks is a test implementation of ZeroHooks |
| 20 | +type mockZeroHooks struct { |
| 21 | + assignUIDsCalled bool |
| 22 | + assignTimestampsCalled bool |
| 23 | + assignNsIDsCalled bool |
| 24 | + commitOrAbortCalled bool |
| 25 | + applyMutationsCalled bool |
| 26 | +} |
| 27 | + |
| 28 | +func (m *mockZeroHooks) AssignUIDs(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) { |
| 29 | + m.assignUIDsCalled = true |
| 30 | + return &pb.AssignedIds{StartId: 1, EndId: 10}, nil |
| 31 | +} |
| 32 | + |
| 33 | +func (m *mockZeroHooks) AssignTimestamps(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) { |
| 34 | + m.assignTimestampsCalled = true |
| 35 | + return &pb.AssignedIds{StartId: 100}, nil |
| 36 | +} |
| 37 | + |
| 38 | +func (m *mockZeroHooks) AssignNsIDs(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) { |
| 39 | + m.assignNsIDsCalled = true |
| 40 | + return &pb.AssignedIds{StartId: 1}, nil |
| 41 | +} |
| 42 | + |
| 43 | +func (m *mockZeroHooks) CommitOrAbort(ctx context.Context, tc *api.TxnContext) (*api.TxnContext, error) { |
| 44 | + m.commitOrAbortCalled = true |
| 45 | + return &api.TxnContext{CommitTs: 200}, nil |
| 46 | +} |
| 47 | + |
| 48 | +func (m *mockZeroHooks) ApplyMutations(ctx context.Context, mut *pb.Mutations) (*api.TxnContext, error) { |
| 49 | + m.applyMutationsCalled = true |
| 50 | + return &api.TxnContext{}, nil |
| 51 | +} |
| 52 | + |
| 53 | +func resetGlobalState() { |
| 54 | + mu.Lock() |
| 55 | + defer mu.Unlock() |
| 56 | + enabled.Store(false) |
| 57 | + globalConfig.Store(nil) |
| 58 | + // Store a true nil by using a new atomic.Value (zero value has nil) |
| 59 | + defaultZeroHooks = atomic.Value{} |
| 60 | +} |
| 61 | + |
| 62 | +func TestEnableDisable(t *testing.T) { |
| 63 | + resetGlobalState() |
| 64 | + defer resetGlobalState() |
| 65 | + |
| 66 | + require.False(t, IsEnabled(), "should start disabled") |
| 67 | + require.Nil(t, GetConfig(), "config should be nil when disabled") |
| 68 | + |
| 69 | + cfg := &Config{ |
| 70 | + DataDir: "/tmp/test", |
| 71 | + CacheSizeMB: 128, |
| 72 | + } |
| 73 | + Enable(cfg) |
| 74 | + |
| 75 | + require.True(t, IsEnabled(), "should be enabled after Enable()") |
| 76 | + require.NotNil(t, GetConfig(), "config should not be nil after Enable()") |
| 77 | + require.Equal(t, "/tmp/test", GetConfig().DataDir) |
| 78 | + require.Equal(t, int64(128), GetConfig().CacheSizeMB) |
| 79 | + |
| 80 | + Disable() |
| 81 | + |
| 82 | + require.False(t, IsEnabled(), "should be disabled after Disable()") |
| 83 | + require.Nil(t, GetConfig(), "config should be nil after Disable()") |
| 84 | +} |
| 85 | + |
| 86 | +func TestGetHooksWithCustomConfig(t *testing.T) { |
| 87 | + resetGlobalState() |
| 88 | + defer resetGlobalState() |
| 89 | + |
| 90 | + mock := &mockZeroHooks{} |
| 91 | + cfg := &Config{ |
| 92 | + ZeroHooks: mock, |
| 93 | + } |
| 94 | + Enable(cfg) |
| 95 | + |
| 96 | + hooks := GetHooks() |
| 97 | + require.NotNil(t, hooks) |
| 98 | + require.Equal(t, mock, hooks, "should return custom hooks from config") |
| 99 | +} |
| 100 | + |
| 101 | +func TestGetHooksWithDefaultHooks(t *testing.T) { |
| 102 | + resetGlobalState() |
| 103 | + defer resetGlobalState() |
| 104 | + |
| 105 | + mock := &mockZeroHooks{} |
| 106 | + SetDefaultZeroHooks(mock) |
| 107 | + |
| 108 | + hooks := GetHooks() |
| 109 | + require.NotNil(t, hooks) |
| 110 | + require.Equal(t, mock, hooks, "should return default hooks when no config") |
| 111 | +} |
| 112 | + |
| 113 | +func TestGetHooksPanicsWhenNoHooksConfigured(t *testing.T) { |
| 114 | + resetGlobalState() |
| 115 | + defer resetGlobalState() |
| 116 | + |
| 117 | + require.Panics(t, func() { |
| 118 | + GetHooks() |
| 119 | + }, "should panic when no hooks configured") |
| 120 | +} |
| 121 | + |
| 122 | +func TestGetHooksCustomOverridesDefault(t *testing.T) { |
| 123 | + resetGlobalState() |
| 124 | + defer resetGlobalState() |
| 125 | + |
| 126 | + defaultMock := &mockZeroHooks{} |
| 127 | + SetDefaultZeroHooks(defaultMock) |
| 128 | + |
| 129 | + customMock := &mockZeroHooks{} |
| 130 | + cfg := &Config{ |
| 131 | + ZeroHooks: customMock, |
| 132 | + } |
| 133 | + Enable(cfg) |
| 134 | + |
| 135 | + hooks := GetHooks() |
| 136 | + require.Equal(t, customMock, hooks, "custom hooks should override default") |
| 137 | +} |
| 138 | + |
| 139 | +func TestZeroHooksFnsWrapper(t *testing.T) { |
| 140 | + var assignUIDsCalled, timestampsCalled, nsIDsCalled, commitCalled, mutationsCalled bool |
| 141 | + |
| 142 | + wrapper := ZeroHooksFns{ |
| 143 | + AssignUIDsFn: func(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) { |
| 144 | + assignUIDsCalled = true |
| 145 | + return &pb.AssignedIds{StartId: 1}, nil |
| 146 | + }, |
| 147 | + AssignTimestampsFn: func(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) { |
| 148 | + timestampsCalled = true |
| 149 | + return &pb.AssignedIds{StartId: 100}, nil |
| 150 | + }, |
| 151 | + AssignNsIDsFn: func(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) { |
| 152 | + nsIDsCalled = true |
| 153 | + return &pb.AssignedIds{StartId: 1}, nil |
| 154 | + }, |
| 155 | + CommitOrAbortFn: func(ctx context.Context, tc *api.TxnContext) (*api.TxnContext, error) { |
| 156 | + commitCalled = true |
| 157 | + return &api.TxnContext{CommitTs: 200}, nil |
| 158 | + }, |
| 159 | + ApplyMutationsFn: func(ctx context.Context, m *pb.Mutations) (*api.TxnContext, error) { |
| 160 | + mutationsCalled = true |
| 161 | + return &api.TxnContext{}, nil |
| 162 | + }, |
| 163 | + } |
| 164 | + |
| 165 | + ctx := context.Background() |
| 166 | + |
| 167 | + _, _ = wrapper.AssignUIDs(ctx, &pb.Num{}) |
| 168 | + require.True(t, assignUIDsCalled, "AssignUIDs should delegate to function") |
| 169 | + |
| 170 | + _, _ = wrapper.AssignTimestamps(ctx, &pb.Num{}) |
| 171 | + require.True(t, timestampsCalled, "AssignTimestamps should delegate to function") |
| 172 | + |
| 173 | + _, _ = wrapper.AssignNsIDs(ctx, &pb.Num{}) |
| 174 | + require.True(t, nsIDsCalled, "AssignNsIDs should delegate to function") |
| 175 | + |
| 176 | + _, _ = wrapper.CommitOrAbort(ctx, &api.TxnContext{}) |
| 177 | + require.True(t, commitCalled, "CommitOrAbort should delegate to function") |
| 178 | + |
| 179 | + _, _ = wrapper.ApplyMutations(ctx, &pb.Mutations{}) |
| 180 | + require.True(t, mutationsCalled, "ApplyMutations should delegate to function") |
| 181 | +} |
| 182 | + |
| 183 | +func TestConcurrentEnableDisable(t *testing.T) { |
| 184 | + resetGlobalState() |
| 185 | + defer resetGlobalState() |
| 186 | + |
| 187 | + var wg sync.WaitGroup |
| 188 | + iterations := 100 |
| 189 | + |
| 190 | + // Concurrent enables |
| 191 | + for i := 0; i < iterations; i++ { |
| 192 | + wg.Add(1) |
| 193 | + go func(i int) { |
| 194 | + defer wg.Done() |
| 195 | + cfg := &Config{ |
| 196 | + DataDir: "/tmp/test", |
| 197 | + CacheSizeMB: int64(i), |
| 198 | + } |
| 199 | + Enable(cfg) |
| 200 | + }(i) |
| 201 | + } |
| 202 | + |
| 203 | + // Concurrent disables |
| 204 | + for i := 0; i < iterations; i++ { |
| 205 | + wg.Add(1) |
| 206 | + go func() { |
| 207 | + defer wg.Done() |
| 208 | + Disable() |
| 209 | + }() |
| 210 | + } |
| 211 | + |
| 212 | + // Concurrent reads |
| 213 | + for i := 0; i < iterations; i++ { |
| 214 | + wg.Add(1) |
| 215 | + go func() { |
| 216 | + defer wg.Done() |
| 217 | + _ = IsEnabled() |
| 218 | + _ = GetConfig() |
| 219 | + }() |
| 220 | + } |
| 221 | + |
| 222 | + wg.Wait() |
| 223 | + // Test passes if no race conditions detected |
| 224 | +} |
| 225 | + |
| 226 | +func TestSetDefaultZeroHooks(t *testing.T) { |
| 227 | + resetGlobalState() |
| 228 | + defer resetGlobalState() |
| 229 | + |
| 230 | + mock1 := &mockZeroHooks{} |
| 231 | + mock2 := &mockZeroHooks{} |
| 232 | + |
| 233 | + SetDefaultZeroHooks(mock1) |
| 234 | + require.Equal(t, mock1, GetHooks()) |
| 235 | + |
| 236 | + SetDefaultZeroHooks(mock2) |
| 237 | + require.Equal(t, mock2, GetHooks()) |
| 238 | +} |
0 commit comments