diff --git a/writers/batchwriter/batchwriter.go b/writers/batchwriter/batchwriter.go index ad85b9fa9f..66c8f742bc 100644 --- a/writers/batchwriter/batchwriter.go +++ b/writers/batchwriter/batchwriter.go @@ -281,7 +281,7 @@ func (w *BatchWriter) Write(ctx context.Context, msgs <-chan message.WriteMessag w.deleteStaleMessages = append(w.deleteStaleMessages, m) l := int64(len(w.deleteStaleMessages)) w.deleteStaleLock.Unlock() - if w.batchSize > 0 && l > w.batchSize { + if w.isLimitReached(l) { if err := w.flushDeleteStaleTables(ctx); err != nil { return err } @@ -301,7 +301,7 @@ func (w *BatchWriter) Write(ctx context.Context, msgs <-chan message.WriteMessag w.deleteRecordMessages = append(w.deleteRecordMessages, m) l := int64(len(w.deleteRecordMessages)) w.deleteRecordLock.Unlock() - if w.batchSize > 0 && l > w.batchSize { + if w.isLimitReached(l) { if err := w.flushDeleteRecordTables(ctx); err != nil { return err } @@ -328,7 +328,7 @@ func (w *BatchWriter) Write(ctx context.Context, msgs <-chan message.WriteMessag w.migrateTableMessages = append(w.migrateTableMessages, m) l := int64(len(w.migrateTableMessages)) w.migrateTableLock.Unlock() - if w.batchSize > 0 && l > w.batchSize { + if w.isLimitReached(l) { if err := w.flushMigrateTables(ctx); err != nil { return err } @@ -338,6 +338,12 @@ func (w *BatchWriter) Write(ctx context.Context, msgs <-chan message.WriteMessag return nil } +func (w *BatchWriter) isLimitReached(rowCount int64) bool { + limit := batch.CappedAt(0, w.batchSize) + limit.AddRows(rowCount) + return limit.ReachedLimit() +} + func (w *BatchWriter) startWorker(_ context.Context, msg *message.WriteInsert) error { w.workersLock.RLock() md := msg.Record.Schema().Metadata()