@@ -208,12 +208,35 @@ func main() {
208208
209209 migrationContext .SetConnectionCharset (* charset )
210210
211- if migrationContext .AlterStatement == "" {
211+ if migrationContext .AlterStatement == "" && ! migrationContext . Revert {
212212 log .Fatal ("--alter must be provided and statement must not be empty" )
213213 }
214214 parser := sql .NewParserFromAlterStatement (migrationContext .AlterStatement )
215215 migrationContext .AlterStatementOptions = parser .GetAlterStatementOptions ()
216216
217+ if migrationContext .Revert {
218+ if migrationContext .Resume {
219+ log .Fatal ("--revert cannot be used with --resume" )
220+ }
221+ if migrationContext .OldTableName == "" {
222+ migrationContext .Log .Fatalf ("--revert must be called with --old-table" )
223+ }
224+
225+ // options irrelevant to revert mode
226+ if migrationContext .AlterStatement != "" {
227+ log .Warning ("--alter was provided with --revert, it will be ignored" )
228+ }
229+ if migrationContext .AttemptInstantDDL {
230+ log .Warning ("--attempt-instant-ddl was provided with --revert, it will be ignored" )
231+ }
232+ if migrationContext .IncludeTriggers {
233+ log .Warning ("--include-triggers was provided with --revert, it will be ignored" )
234+ }
235+ if migrationContext .DiscardForeignKeys {
236+ log .Warning ("--discard-foreign-keys was provided with --revert, it will be ignored" )
237+ }
238+ }
239+
217240 if migrationContext .DatabaseName == "" {
218241 if parser .HasExplicitSchema () {
219242 migrationContext .DatabaseName = parser .GetExplicitSchema ()
@@ -293,10 +316,6 @@ func main() {
293316 migrationContext .Log .Fatalf ("--checkpoint-seconds should be >=10" )
294317 }
295318
296- if migrationContext .Revert && migrationContext .OldTableName == "" {
297- migrationContext .Log .Fatalf ("--revert must be called with --old-table" )
298- }
299-
300319 switch * cutOver {
301320 case "atomic" , "default" , "" :
302321 migrationContext .CutOverType = base .CutOverAtomic
@@ -353,15 +372,16 @@ func main() {
353372 acceptSignals (migrationContext )
354373
355374 migrator := logic .NewMigrator (migrationContext , AppVersion )
375+ var err error
356376 if migrationContext .Revert {
357- if err := migrator .Revert (); err != nil {
358- migrationContext .Log .Fatale (err )
359- }
377+ err = migrator .Revert ()
360378 } else {
361- if err := migrator .Migrate (); err != nil {
362- migrator .ExecOnFailureHook ()
363- migrationContext .Log .Fatale (err )
364- }
379+ err = migrator .Migrate ()
380+ }
381+
382+ if err != nil {
383+ migrator .ExecOnFailureHook ()
384+ migrationContext .Log .Fatale (err )
365385 }
366386 fmt .Fprintln (os .Stdout , "# Done" )
367387}
0 commit comments