Skip to content

Commit 66d0704

Browse files
ConsuelitaConsuelita
authored andcommitted
TrainingData Bugs
1 parent 9147d39 commit 66d0704

3 files changed

Lines changed: 177 additions & 98 deletions

File tree

AdversaryLabSwift/Controllers/FileController.swift

Lines changed: 53 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import ZIPFoundation
1414

1515
class FileController
1616
{
17-
let trainingDataFilename = "TrainingResults"
17+
let trainingDataFilename = "TrainingResults.json"
1818

1919
func saveModel(classifier: MLClassifier,
2020
classifierMetadata: MLModelMetadata,
@@ -220,53 +220,57 @@ class FileController
220220

221221
func unpack(adversaryURL: URL) -> URL?
222222
{
223+
guard let appDirectory = getAdversarySupportDirectory()
224+
else
225+
{
226+
print("Failed to save the file, could not find the application document directory.")
227+
return nil
228+
}
229+
223230
let modelGroupName = "temp"
231+
let temporaryDirURL = appDirectory.appendingPathComponent(modelGroupName, isDirectory: true)
224232

225-
if let appDirectory = prepareDirectory(groupName: modelGroupName)
233+
if FileManager.default.fileExists(atPath: temporaryDirURL.path)
234+
{
235+
try? FileManager.default.removeItem(at: temporaryDirURL)
236+
}
237+
238+
do
226239
{
227-
let temporaryDirURL = appDirectory.appendingPathComponent(modelGroupName, isDirectory: true)
240+
try FileManager.default.unzipItem(at: adversaryURL, to: temporaryDirURL, progress: nil, preferredEncoding: nil)
228241

229-
do
242+
let fileURLS = try FileManager.default.contentsOfDirectory(at: temporaryDirURL, includingPropertiesForKeys: nil, options: .skipsHiddenFiles)
243+
244+
print("\nUnzipped item at: \(adversaryURL.path)\nto: \(temporaryDirURL.path)")
245+
246+
if fileURLS.count == 1, fileURLS[0].hasDirectoryPath
230247
{
231-
try FileManager.default.unzipItem(at: adversaryURL, to: temporaryDirURL, progress: nil, preferredEncoding: nil)
248+
print("Unpacked model files to: \(fileURLS[0])")
249+
loadTrainingData(from: fileURLS[0])
232250

233-
let fileURLS = try FileManager.default.contentsOfDirectory(at: temporaryDirURL, includingPropertiesForKeys: nil, options: .skipsHiddenFiles)
234-
235-
print("\nUnzipped item at: \(adversaryURL.path)\nto: \(temporaryDirURL.path)")
236-
237-
if fileURLS.count == 1, fileURLS[0].hasDirectoryPath
238-
{
239-
print("Unpacked model files to: \(fileURLS[0])")
240-
loadTrainingData(from: fileURLS[0])
241-
242-
return fileURLS[0]
243-
}
244-
else
251+
return fileURLS[0]
252+
}
253+
else
254+
{
255+
for fileURL in fileURLS
245256
{
246-
for fileURL in fileURLS
257+
print("\nFound file: \(fileURL.path)")
258+
if fileURL.pathExtension == "mlmodel"
247259
{
248-
print("\nFound file: \(fileURL.path)")
249-
if fileURL.pathExtension == "mlmodel"
250-
{
251-
print("\nFound an mlm file in the chosen directory: \(fileURL)")
252-
}
260+
print("\nFound an mlm file in the chosen directory: \(fileURL)")
253261
}
254-
255-
print("Unpacked model files to: \(temporaryDirURL)")
256-
loadTrainingData(from: temporaryDirURL)
257-
258-
return temporaryDirURL
259262
}
260-
}
261-
catch let unzipError
262-
{
263-
print("\nError unzipping item at \(adversaryURL) to \(temporaryDirURL): \n\(unzipError)")
264263

265-
return nil
264+
print("Unpacked model files to: \(temporaryDirURL)")
265+
loadTrainingData(from: temporaryDirURL)
266+
267+
return temporaryDirURL
266268
}
267269
}
268-
else
270+
catch let unzipError
269271
{
272+
print("\nError unzipping item at \(adversaryURL) to \(temporaryDirURL): \n\(unzipError)")
273+
270274
return nil
271275
}
272276

@@ -284,17 +288,21 @@ class FileController
284288

285289
let groupURL = appDirectory.appendingPathComponent(groupName)
286290

287-
if !fileManager.fileExists(atPath: groupURL.path)
291+
guard fileManager.fileExists(atPath: groupURL.path)
292+
else
288293
{
289-
do
290-
{
291-
_ = try fileManager.createDirectory(at: groupURL, withIntermediateDirectories: true, attributes: nil)
292-
}
293-
catch let directoryError
294-
{
295-
print("\nError creating group directory: \(directoryError)")
296-
return nil
297-
}
294+
print("Group directory does not exist.")
295+
return nil
296+
}
297+
298+
do
299+
{
300+
_ = try fileManager.createDirectory(at: groupURL, withIntermediateDirectories: true, attributes: nil)
301+
}
302+
catch let directoryError
303+
{
304+
print("\nError creating group directory: \(directoryError)")
305+
return nil
298306
}
299307

300308
return appDirectory
@@ -315,6 +323,7 @@ class FileController
315323
{
316324
let trainingBytes = try Data(contentsOf: fileUrl)
317325
trainingData = try decoder.decode(TrainingData.self, from: trainingBytes)
326+
return
318327
}
319328
catch let decodeError
320329
{

AdversaryLabSwift/Helpers/Alerts.swift

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,27 @@
88

99
import Cocoa
1010

11-
// TODO: Call this when there is no appropriate data to be processed in the rdb file
12-
func showNoDataAlert()
11+
/// Call this when there is no appropriate data to be processed
12+
func showNoDataAlert(completion:@escaping (_ completion:Bool) -> Void)
1313
{
1414
let alert = NSAlert()
1515
alert.messageText = "Not enough packets to process"
1616
alert.informativeText = "There is not enough valid data in the selected database file."
17-
alert.runModal()
17+
let result = alert.runModal()
18+
19+
if result.rawValue == 0
20+
{
21+
if let selectedFileURL = showRethinkFileAlert()
22+
{
23+
FileController().loadSongFile(fileURL: selectedFileURL, completion: completion)
24+
}
25+
}
26+
else
27+
{
28+
print("Alert result: \(result)")
29+
completion(false)
30+
return
31+
}
1832
}
1933

2034
func showNoBlockedConnectionsAlert()

AdversaryLabSwift/ViewController.swift

Lines changed: 107 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -261,42 +261,30 @@ class ViewController: NSViewController, NSTabViewDelegate, ChartViewDelegate
261261
// Identify which tab we need to update
262262
guard let identifier = tabView.selectedTabViewItem?.identifier as? String,
263263
let currentTab = TabIds(rawValue: identifier)
264-
else { return }
264+
else
265+
{
266+
stopActivityIndicator()
267+
return
268+
269+
}
265270

266271
switch currentTab
267272
{
268273
case .TestMode:
269274
runTest()
270-
case .TrainingMode: // In Training mode we need a name so we can save the model files
271-
guard connectionGroupData.aConnectionData.connections.count > 6, connectionGroupData.bConnectionData.connections.count > 6
272-
else
273-
{
274-
showNoDataAlert()
275-
return
276-
}
277-
278-
if let name = showNameModelAlert()
279-
{
280-
print("Time to analyze some things.")
281-
configModel.modelName = name
282-
connectionInspector.analyzeConnections(configModel: configModel, resetTrainingData: true, resetTestingData: false)
283-
updateProgressIndicator()
284-
}
285-
else
286-
{
287-
sender.state = .off
288-
refreshDBUI()
289-
}
290-
291-
275+
case .TrainingMode:
276+
runTraining()
292277
case .DataMode:
293278
print("Data mode selected. Nothing to do here.")
279+
stopActivityIndicator()
280+
return
294281
}
295282
}
296283
else
297284
{
298285
print("Pause bot engage!! 🤖")
299286
updateProgressIndicator()
287+
stopActivityIndicator()
300288
}
301289
}
302290

@@ -348,10 +336,19 @@ class ViewController: NSViewController, NSTabViewDelegate, ChartViewDelegate
348336
configModel.modelName = modelName
349337

350338
// Unpack to a temporary directory
351-
modelDirectoryURL = FileController().unpack(adversaryURL: selectedURL)
352-
runTest()
339+
if let maybeModelDirectory = FileController().unpack(adversaryURL: selectedURL)
340+
{
341+
modelDirectoryURL = maybeModelDirectory
342+
runTest()
343+
}
344+
else
345+
{
346+
print("🚨 Failed to unpack the selected adversary file. 🚨")
347+
}
353348
}
349+
354350
loadDataButton.isEnabled = true
351+
stopActivityIndicator()
355352

356353
case .TrainingMode:
357354
DispatchQueue.main.async {
@@ -390,6 +387,14 @@ class ViewController: NSViewController, NSTabViewDelegate, ChartViewDelegate
390387
}
391388
}
392389

390+
func stopActivityIndicator()
391+
{
392+
DispatchQueue.main.async
393+
{
394+
self.activityIndicator.stopAnimation(nil)
395+
}
396+
}
397+
393398
// MARK: - Charts
394399

395400
// func chartValueSelected(_ chartView: ChartViewBase, entry: ChartDataEntry, highlight: Highlight)
@@ -657,44 +662,95 @@ class ViewController: NSViewController, NSTabViewDelegate, ChartViewDelegate
657662
{
658663
configModel.processingEnabled = true
659664

660-
guard connectionGroupData.bConnectionData.connections.count > 1
661-
else
665+
if connectionGroupData.bConnectionData.connections.count < 1
662666
{
663-
showNoBlockedConnectionsAlert()
664-
return
667+
// Prompt the user to select a data file to load
668+
showNoDataAlert
669+
{ (dataLoaded) in
670+
671+
if dataLoaded
672+
{
673+
self.runTest()
674+
}
675+
else
676+
{
677+
self.stopActivityIndicator()
678+
return
679+
}
680+
}
665681
}
666-
667-
// Make sure that we have gotten an Adversary file and unpacked it to a temporary directory
668-
// TODO: Delete this directory on program exit
669-
if modelDirectoryURL == nil
682+
else
670683
{
671-
// Get the user to select the correct .adversary file
672-
if let selectedURL = showSelectAdversaryFileAlert()
684+
// Make sure that we have gotten an Adversary file and unpacked it to a temporary directory
685+
// TODO: Delete this directory on program exit
686+
if modelDirectoryURL == nil
673687
{
674-
// Model Group Name should be the same as the directory
675-
modelName = selectedURL.deletingPathExtension().lastPathComponent
676-
677-
// Unpack to a temporary directory
678-
modelDirectoryURL = FileController().unpack(adversaryURL: selectedURL)
679-
runTest()
688+
// Get the user to select the correct .adversary file
689+
if let selectedURL = showSelectAdversaryFileAlert()
690+
{
691+
// Model Group Name should be the same as the directory
692+
modelName = selectedURL.deletingPathExtension().lastPathComponent
693+
694+
// Unpack to a temporary directory
695+
modelDirectoryURL = FileController().unpack(adversaryURL: selectedURL)
696+
runTest()
697+
}
698+
else
699+
{
700+
processPacketsButton.state = .off
701+
stopActivityIndicator()
702+
return
703+
}
680704
}
681-
else
705+
706+
if !modelDirectoryURL!.hasDirectoryPath
682707
{
683-
processPacketsButton.state = .off
708+
// Unpack to a temporary directory
709+
modelDirectoryURL = FileController().unpack(adversaryURL: modelDirectoryURL!)
684710
}
685711

686-
return
712+
configModel.modelName = modelDirectoryURL!.deletingPathExtension().lastPathComponent
713+
connectionInspector.analyzeConnections(configModel: configModel, resetTrainingData: false, resetTestingData: true)
714+
updateProgressIndicator()
715+
stopActivityIndicator()
687716
}
688-
689-
if !modelDirectoryURL!.hasDirectoryPath
717+
}
718+
719+
func runTraining()
720+
{
721+
// In Training mode we need a name so we can save the model files
722+
if connectionGroupData.aConnectionData.connections.count < 6, connectionGroupData.bConnectionData.connections.count < 6
690723
{
691-
// Unpack to a temporary directory
692-
modelDirectoryURL = FileController().unpack(adversaryURL: modelDirectoryURL!)
724+
showNoDataAlert
725+
{
726+
(dataLoaded) in
727+
728+
if dataLoaded
729+
{
730+
self.runTraining()
731+
}
732+
else
733+
{
734+
self.stopActivityIndicator()
735+
return
736+
}
737+
}
693738
}
694739

695-
configModel.modelName = modelDirectoryURL!.deletingPathExtension().lastPathComponent
696-
connectionInspector.analyzeConnections(configModel: configModel, resetTrainingData: false, resetTestingData: true)
697-
updateProgressIndicator()
740+
if let name = showNameModelAlert()
741+
{
742+
print("Time to analyze some things.")
743+
configModel.modelName = name
744+
connectionInspector.analyzeConnections(configModel: configModel, resetTrainingData: true, resetTestingData: false)
745+
updateProgressIndicator()
746+
stopActivityIndicator()
747+
return
748+
}
749+
else
750+
{
751+
stopActivityIndicator()
752+
return
753+
}
698754
}
699755

700756
// MARK: - Alerts

0 commit comments

Comments
 (0)