Added code to Rocca for training/testing classifier (#114)

* allow Rocca to run without classifiers

Fixed bug that threw an error if no classifier files were specified in Rocca Params dialog

* add rocca switch to enable dev mode

currently only shows/hides extra buttons in the Params dialog, but will
extend to more options in the future

* Fix memory issue with RoccaContourDataBlocks not being released for
garbage collection

Set RoccaContourDataBlock objects to null and stop PamObserver Timer to
force release

* Fix problem tracing whistles in Rocca spectrogram pop-up

Whistle and raw data were being cleared before the user had time to trace out the whistle, causing PAMGuard to throw an exception.  Both were already being cleared when the pop-up window is closed, so no need to do it here.

* updated for training/testing classifiers
This commit is contained in:
m2oswald 2023-11-15 09:17:49 +00:00 committed by GitHub
parent d5f504dcd1
commit d7c4c278d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 303 additions and 62 deletions

View File

@ -32,6 +32,7 @@ import java.io.IOException;
import java.util.EnumMap; import java.util.EnumMap;
import javax.swing.JFileChooser; import javax.swing.JFileChooser;
import javax.swing.filechooser.FileNameExtensionFilter;
import PamUtils.PamCalendar; import PamUtils.PamCalendar;
@ -126,31 +127,34 @@ public class RoccaClassifyThis {
/** the field in the RoccaContourStats object which contains all the stats measures */ /** the field in the RoccaContourStats object which contains all the stats measures */
private EnumMap<RoccaContourStats.ParamIndx, Double> contourStats; private EnumMap<RoccaContourStats.ParamIndx, Double> contourStats;
private String dirIn; /**
* Constructor used when allowing user to select training dataset
/** the input filename */ * */
private String csvIn;
/** the input file */
private File statsFileIn;
/** the output filename */
private String csvOut;
/** the output file */
private File statsFileOut;
/** Constructor */
public RoccaClassifyThis(RoccaProcess roccaProcess) { public RoccaClassifyThis(RoccaProcess roccaProcess) {
File statsFileIn = getTheFile();
// initialize the BufferedReader if (statsFileIn!=null) {
BufferedReader inputFile = null; runTheClassifier(statsFileIn, roccaProcess);
}
}
/**
* Constructor when we pass in the training dataset
*/
public RoccaClassifyThis() {
}
/**
* Ask the user to select the file containing the testing dataset
*
* @return File the csv file containing the testing dataset
*/
public File getTheFile() {
// set the directory // set the directory
// this.dirIn = new String("C:\\Users\\Mike\\Documents\\Work\\Java\\EclipseWorkspace\\testing\\RoccaClassifyThis_testing"); // this.dirIn = new String("C:\\Users\\Mike\\Documents\\Work\\Java\\EclipseWorkspace\\testing\\RoccaClassifyThis_testing");
// this.dirIn = new String("C:\\Users\\Mike\\Documents\\Work\\Tom\\Atlantic Classifier\\manual 2-stage data"); // this.dirIn = new String("C:\\Users\\Mike\\Documents\\Work\\Tom\\Atlantic Classifier\\manual 2-stage data");
// this.dirIn = new String("C:\\Users\\Mike\\Documents\\Work\\Tom\\Hawaii dataset problems"); // this.dirIn = new String("C:\\Users\\Mike\\Documents\\Work\\Tom\\Hawaii dataset problems");
this.dirIn = new String("C:\\Users\\SCANS\\Documents\\Work\\Biowaves\\ONR classifier"); // this.dirIn = new String("C:\\Users\\SCANS\\Documents\\Work\\Biowaves\\ONR classifier");
// Define the input and output filenames // Define the input and output filenames
// Hard-coded for now. To Do: query the user for the filename // Hard-coded for now. To Do: query the user for the filename
@ -158,35 +162,54 @@ public class RoccaClassifyThis {
// this.csvIn = new String("Manual_5sp_April 9 2013.csv"); // this.csvIn = new String("Manual_5sp_April 9 2013.csv");
// this.csvIn = new String("CombinedContourStats-fixed.csv"); // this.csvIn = new String("CombinedContourStats-fixed.csv");
// this.csvOut = new String("RoccaContourStatsReclassified.csv"); // this.csvOut = new String("RoccaContourStatsReclassified.csv");
this.csvIn = new String("Atl_TestDFNoTrain_Call_W_160831.csv"); // this.csvIn = new String("Atl_TestDFNoTrain_Call_W_160831.csv");
statsFileIn = new File(dirIn, csvIn); // statsFileIn = new File(dirIn, csvIn);
this.csvOut = new String("Atl_TestDFNoTrain_Call_W_160829-classified.csv"); // this.csvOut = new String("Atl_TestDFNoTrain_Call_W_160829-classified.csv");
statsFileOut = new File(dirIn, csvOut); // statsFileOut = new File(dirIn, csvOut);
// let the user select the arff file
JFileChooser fileChooser = new JFileChooser();
fileChooser.setDialogTitle("Select spreadsheet to recalculate...");
fileChooser.setFileHidingEnabled(true);
fileChooser.setApproveButtonText("Select");
fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY);
FileNameExtensionFilter restrict = new FileNameExtensionFilter("Only .csv files", "csv");
fileChooser.addChoosableFileFilter(restrict);
int state = fileChooser.showOpenDialog(null);
File statsFileIn = null;
if (state == JFileChooser.APPROVE_OPTION) {
// load the file
statsFileIn = fileChooser.getSelectedFile();
return statsFileIn;
} else {
return null;
}
}
/**
* Run the classifier
* @param statsFileIn the File containing the testing dataset
* @param roccaProcess the RoccaProcess instance
*/
public void runTheClassifier(File statsFileIn, RoccaProcess roccaProcess) {
// JFileChooser fileChooser = new JFileChooser(); int index = statsFileIn.getAbsolutePath().lastIndexOf(".");
// fileChooser.setDialogTitle("Select spreadsheet to recalculate..."); String csvOut = statsFileIn.getAbsolutePath().substring(0,index) + "-classified.csv";
// fileChooser.setFileHidingEnabled(true); File statsFileOut = new File(csvOut);
// fileChooser.setApproveButtonText("Select");
// fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY);
//
// int state = fileChooser.showOpenDialog(this.dirIn);
// if (state == JFileChooser.APPROVE_OPTION) {
// load the classifier // load the classifier
System.out.println("Loading classifier..."); System.out.println("Loading classifier...");
roccaProcess.setClassifierLoaded roccaProcess.setClassifierLoaded
(roccaProcess.roccaClassifier.setUpClassifier()); (roccaProcess.roccaClassifier.setUpClassifier());
// initialize the BufferedReader
BufferedReader inputFile = null;
// open the input file // open the input file
try { try {
System.out.println("Opening input file "+statsFileIn); System.out.println("Opening input file "+statsFileIn);
@ -263,12 +286,45 @@ public class RoccaClassifyThis {
contourStats.put(RoccaContourStats.ParamIndx.FREQPOSSLOPEMEAN, Double.parseDouble(dataArray[34])); contourStats.put(RoccaContourStats.ParamIndx.FREQPOSSLOPEMEAN, Double.parseDouble(dataArray[34]));
contourStats.put(RoccaContourStats.ParamIndx.FREQNEGSLOPEMEAN, Double.parseDouble(dataArray[35])); contourStats.put(RoccaContourStats.ParamIndx.FREQNEGSLOPEMEAN, Double.parseDouble(dataArray[35]));
contourStats.put(RoccaContourStats.ParamIndx.FREQSLOPERATIO, Double.parseDouble(dataArray[36])); contourStats.put(RoccaContourStats.ParamIndx.FREQSLOPERATIO, Double.parseDouble(dataArray[36]));
contourStats.put(RoccaContourStats.ParamIndx.FREQBEGSWEEP, Double.parseDouble(dataArray[37]));
//contourStats.put(RoccaContourStats.ParamIndx.FREQBEGUP, Double.parseDouble(dataArray[38])); // Note that we have to modify the FREQBEGSWEEP value. Weka is trained with the FREQBEGSWEEP param
//contourStats.put(RoccaContourStats.ParamIndx.FREQBEGDWN, Double.parseDouble(dataArray[39])); // as -1=down, 0=flat and 1=up, and that would be how the test data comes through as well. HOWEVER,
contourStats.put(RoccaContourStats.ParamIndx.FREQENDSWEEP, Double.parseDouble(dataArray[40])); // Weka assumes that for nominal parameters, the value is the index location (0,1 or 2) and NOT the actual trained
//contourStats.put(RoccaContourStats.ParamIndx.FREQENDUP, Double.parseDouble(dataArray[41])); // value (-1,0 or 1). So if the whistle has a down sweep, Weka needs the FREQBEGSWEEP value to be 0 indicating the
//contourStats.put(RoccaContourStats.ParamIndx.FREQENDDWN, Double.parseDouble(dataArray[42])); // first location in the array (which was 'down'). If it was up, the value would need to be 2 indicating the third
// location in the array (which was 'up').
// Ideally we would map the values in the test data to the positions in the training array, but as a quick and
// dirty hack we'll simply add 1 to the value since the difference between the nominal values (-1,0,1) and the
/// index positions (0,1,2) is an offset of 1
// Note also that we don't have to do the same thing for FREQBEGUP and FREQBEGDWN since, by coincidence, the training
// values of 0 and 1 happen to match the index locations of 0 and 1
//contourStats.put(RoccaContourStats.ParamIndx.FREQBEGSWEEP, Double.parseDouble(dataArray[37]));
double tempVal = Double.parseDouble(dataArray[37]);
tempVal++;
contourStats.put(RoccaContourStats.ParamIndx.FREQBEGSWEEP, tempVal);
contourStats.put(RoccaContourStats.ParamIndx.FREQBEGUP, Double.parseDouble(dataArray[38]));
contourStats.put(RoccaContourStats.ParamIndx.FREQBEGDWN, Double.parseDouble(dataArray[39]));
// Note that we have to modify the FREQENDSWEEP value. Weka is trained with the FREQENDSWEEP param
// as -1=down, 0=flat and 1=up, and that would be how the test data comes through as well. HOWEVER,
// Weka assumes that for nominal parameters, the value is the index location (0,1 or 2) and NOT the actual trained
// value (-1,0 or 1). So if the whistle has a down sweep, Weka needs the FREQENDSWEEP value to be 0 indicating the
// first location in the array (which was 'down'). If it was up, the value would need to be 2 indicating the third
// location in the array (which was 'up').
// Ideally we would map the values in the test data to the positions in the training array, but as a quick and
// dirty hack we'll simply add 1 to the value since the difference between the nominal values (-1,0,1) and the
/// index positions (0,1,2) is an offset of 1
// Note also that we don't have to do the same thing for FREQENDUP and FREQENDDWN since, by coincidence, the training
// values of 0 and 1 happen to match the index locations of 0 and 1
//contourStats.put(RoccaContourStats.ParamIndx.FREQENDSWEEP, Double.parseDouble(dataArray[40]));
tempVal = Double.parseDouble(dataArray[40]);
tempVal++;
contourStats.put(RoccaContourStats.ParamIndx.FREQENDSWEEP, tempVal);
contourStats.put(RoccaContourStats.ParamIndx.FREQENDUP, Double.parseDouble(dataArray[41]));
contourStats.put(RoccaContourStats.ParamIndx.FREQENDDWN, Double.parseDouble(dataArray[42]));
// end of hack
contourStats.put(RoccaContourStats.ParamIndx.NUMSWEEPSUPDWN, Double.parseDouble(dataArray[43])); contourStats.put(RoccaContourStats.ParamIndx.NUMSWEEPSUPDWN, Double.parseDouble(dataArray[43]));
contourStats.put(RoccaContourStats.ParamIndx.NUMSWEEPSDWNUP, Double.parseDouble(dataArray[44])); contourStats.put(RoccaContourStats.ParamIndx.NUMSWEEPSDWNUP, Double.parseDouble(dataArray[44]));
contourStats.put(RoccaContourStats.ParamIndx.NUMSWEEPSUPFLAT, Double.parseDouble(dataArray[45])); contourStats.put(RoccaContourStats.ParamIndx.NUMSWEEPSUPFLAT, Double.parseDouble(dataArray[45]));
@ -285,8 +341,8 @@ public class RoccaClassifyThis {
contourStats.put(RoccaContourStats.ParamIndx.INFLMEANDELTA, Double.parseDouble(dataArray[56])); contourStats.put(RoccaContourStats.ParamIndx.INFLMEANDELTA, Double.parseDouble(dataArray[56]));
contourStats.put(RoccaContourStats.ParamIndx.INFLSTDDEVDELTA, Double.parseDouble(dataArray[57])); contourStats.put(RoccaContourStats.ParamIndx.INFLSTDDEVDELTA, Double.parseDouble(dataArray[57]));
contourStats.put(RoccaContourStats.ParamIndx.INFLMEDIANDELTA, Double.parseDouble(dataArray[58])); contourStats.put(RoccaContourStats.ParamIndx.INFLMEDIANDELTA, Double.parseDouble(dataArray[58]));
contourStats.put(RoccaContourStats.ParamIndx.INFLDUR, Double.parseDouble(dataArray[59])); //contourStats.put(RoccaContourStats.ParamIndx.INFLDUR, Double.parseDouble(dataArray[59]));
contourStats.put(RoccaContourStats.ParamIndx.STEPDUR, Double.parseDouble(dataArray[60])); //contourStats.put(RoccaContourStats.ParamIndx.STEPDUR, Double.parseDouble(dataArray[60]));
// Run the classifier // Run the classifier
roccaProcess.roccaClassifier.classifyContour2(rcdb); roccaProcess.roccaClassifier.classifyContour2(rcdb);

View File

@ -169,6 +169,7 @@ public class RoccaParametersDialog extends PamDialog implements ActionListener,
JButton classifier2Button; JButton classifier2Button;
JButton recalcButton; JButton recalcButton;
JButton reclassifyButton; JButton reclassifyButton;
JButton trainThenTestButton;
JButton clearClassifier; JButton clearClassifier;
JComboBox<DefaultComboBoxModel<Vector<String>>> stage1Classes; JComboBox<DefaultComboBoxModel<Vector<String>>> stage1Classes;
DefaultComboBoxModel<Vector<String>> stage1ClassModel; DefaultComboBoxModel<Vector<String>> stage1ClassModel;
@ -513,6 +514,10 @@ public class RoccaParametersDialog extends PamDialog implements ActionListener,
reclassifyButton.addActionListener(this); reclassifyButton.addActionListener(this);
reclassifyButton.setToolTipText("Load the whistle data from the contour stats output file, and run it through the current Classifier"); reclassifyButton.setToolTipText("Load the whistle data from the contour stats output file, and run it through the current Classifier");
reclassifyButton.setVisible(true); reclassifyButton.setVisible(true);
trainThenTestButton = new JButton("Train then Test");
trainThenTestButton.addActionListener(this);
trainThenTestButton.setToolTipText("Train a classifier on a set of training data, then test it with a set of testing data");
trainThenTestButton.setVisible(true);
// ******** THIS LINES CONTROLS THE VISIBILITY ******** // ******** THIS LINES CONTROLS THE VISIBILITY ********
if (RoccaDev.isEnabled()) { if (RoccaDev.isEnabled()) {
@ -528,13 +533,15 @@ public class RoccaParametersDialog extends PamDialog implements ActionListener,
extraPanelLayout.createParallelGroup(GroupLayout.Alignment.LEADING) extraPanelLayout.createParallelGroup(GroupLayout.Alignment.LEADING)
.addGroup(extraPanelLayout.createSequentialGroup() .addGroup(extraPanelLayout.createSequentialGroup()
.addComponent(recalcButton) .addComponent(recalcButton)
.addComponent(reclassifyButton)) .addComponent(reclassifyButton)
.addComponent(trainThenTestButton))
); );
extraPanelLayout.setVerticalGroup( extraPanelLayout.setVerticalGroup(
extraPanelLayout.createSequentialGroup() extraPanelLayout.createSequentialGroup()
.addGroup(extraPanelLayout.createParallelGroup(GroupLayout.Alignment.BASELINE) .addGroup(extraPanelLayout.createParallelGroup(GroupLayout.Alignment.BASELINE)
.addComponent(recalcButton) .addComponent(recalcButton)
.addComponent(reclassifyButton)) .addComponent(reclassifyButton)
.addComponent(trainThenTestButton))
); );
classifierPanel.add(extraButtonsSubPanel); classifierPanel.add(extraButtonsSubPanel);
@ -892,7 +899,9 @@ public class RoccaParametersDialog extends PamDialog implements ActionListener,
} else if (e.getSource() == recalcButton) { } else if (e.getSource() == recalcButton) {
RoccaFixParams recalc = new RoccaFixParams(roccaControl.roccaProcess); RoccaFixParams recalc = new RoccaFixParams(roccaControl.roccaProcess);
} else if (e.getSource() == reclassifyButton) { } else if (e.getSource() == reclassifyButton) {
RoccaClassifyThisEvent reclassify = new RoccaClassifyThisEvent(roccaControl.roccaProcess); RoccaClassifyThis reclassify = new RoccaClassifyThis(roccaControl.roccaProcess);
} else if (e.getSource() == trainThenTestButton) {
RoccaTrainThenTest trainThenTest = new RoccaTrainThenTest(roccaControl.roccaProcess);
} else if (e.getSource() == fftButton) { } else if (e.getSource() == fftButton) {
roccaParameters.setUseFFT(true); roccaParameters.setUseFFT(true);
this.enableTheCorrectSource(); this.enableTheCorrectSource();

View File

@ -145,6 +145,7 @@ public class RoccaRFModel implements java.io.Serializable {
} catch (Exception ex) { } catch (Exception ex) {
System.err.println("1st Classification failed: " + ex.getMessage()); System.err.println("1st Classification failed: " + ex.getMessage());
ex.printStackTrace();
rcdb.setClassifiedAs("Err"); rcdb.setClassifiedAs("Err");
} }
} }

View File

@ -24,10 +24,14 @@
package rocca; package rocca;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader; import java.io.FileReader;
import java.util.Date; import java.util.Date;
import java.util.Enumeration; import java.util.Enumeration;
import javax.swing.JFileChooser;
import javax.swing.filechooser.FileNameExtensionFilter;
import weka.classifiers.trees.RandomForest; import weka.classifiers.trees.RandomForest;
import weka.core.Instances; import weka.core.Instances;
import weka.core.SerializationHelper; import weka.core.SerializationHelper;
@ -42,13 +46,64 @@ import weka.core.SerializationHelper;
*/ */
public class RoccaTrainClassifier { public class RoccaTrainClassifier {
/**
* Standalone implementation
*
* @param args
*/
public static void main(String[] args) { public static void main(String[] args) {
RoccaTrainClassifier rtc = new RoccaTrainClassifier();
File arffFile = rtc.getArff();
if (arffFile!=null) {
String modelName = rtc.trainClassifier(arffFile);
}
}
/**
* Let user choose arff file training dataset
*
* @return File the arff file containing the training dataset
*/
public File getArff() {
// String arffFile = "C:\\Users\\SCANS\\Documents\\Work\\Biowaves\\ONR classifier\\TP_TrainEvtDF_170408";
// let the user select the arff file
JFileChooser fileChooser = new JFileChooser();
fileChooser.setDialogTitle("Select arff file containing training data");
fileChooser.setFileHidingEnabled(true);
fileChooser.setApproveButtonText("Select");
fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY);
FileNameExtensionFilter restrict = new FileNameExtensionFilter("Only .arff files", "arff");
fileChooser.addChoosableFileFilter(restrict);
File arffFile;
int state = fileChooser.showOpenDialog(null);
if (state == JFileChooser.APPROVE_OPTION) {
// load the file
arffFile = fileChooser.getSelectedFile();
return arffFile;
} else {
return null;
}
}
/**
* Actual code to train the classifier
*
*/
public String trainClassifier(File arffFile) {
RandomForest model = new RandomForest (); RandomForest model = new RandomForest ();
Instances trainData = null; Instances trainData = null;
String arffFile = "C:\\Users\\SCANS\\Documents\\Work\\Biowaves\\ONR classifier\\TP_TrainEvtDF_170408";
// load the ARFF file containing the training set // load the ARFF file containing the training set
System.out.println("Loading data..."); System.out.println("Loading data..." + arffFile.getAbsolutePath());
try { try {
trainData = new Instances trainData = new Instances
(new BufferedReader (new BufferedReader
@ -56,10 +111,13 @@ public class RoccaTrainClassifier {
// ("C:\\Users\\Mike\\Documents\\Work\\Java\\WEKA\\allwhists 12 vars 8sp update 1-28-10.arff"))); // ("C:\\Users\\Mike\\Documents\\Work\\Java\\WEKA\\allwhists 12 vars 8sp update 1-28-10.arff")));
// ("C:\\Users\\Mike\\Documents\\Work\\Java\\WEKA\\weka vs R\\ETP_orcawale_whists2 modified-subset110perspecies-no_harm_ratios.arff"))); // ("C:\\Users\\Mike\\Documents\\Work\\Java\\WEKA\\weka vs R\\ETP_orcawale_whists2 modified-subset110perspecies-no_harm_ratios.arff")));
// ("C:\\Users\\SCANS\\Documents\\Work\\Biowaves\\ONR classifier\\Atl_TrainDF_Event_160829.arff"))); // ("C:\\Users\\SCANS\\Documents\\Work\\Biowaves\\ONR classifier\\Atl_TrainDF_Event_160829.arff")));
(arffFile + ".arff"))); // (arffFile + ".arff")));
(arffFile)));
trainData.setClassIndex(trainData.numAttributes()-1); trainData.setClassIndex(trainData.numAttributes()-1);
} catch (Exception ex) { } catch (Exception ex) {
System.out.println("Error Loading..."); System.out.println("Error Loading...");
ex.printStackTrace();
return null;
} }
// set the classifier parameters // set the classifier parameters
@ -78,6 +136,8 @@ public class RoccaTrainClassifier {
model.setOptions(options); model.setOptions(options);
} catch (Exception ex) { } catch (Exception ex) {
System.out.println("Error setting options..."); System.out.println("Error setting options...");
ex.printStackTrace();
return null;
} }
// train the classifier // train the classifier
@ -90,23 +150,29 @@ public class RoccaTrainClassifier {
new Date()); new Date());
} catch (Exception ex) { } catch (Exception ex) {
System.out.println("Error training classifier..."); System.out.println("Error training classifier...");
ex.printStackTrace();
return null;
} }
// save the classifier // save the classifier
String[] curOptions = model.getOptions(); // String[] curOptions = model.getOptions();
Enumeration test = model.listOptions(); // Enumeration test = model.listOptions();
System.out.println("Saving Classifier...");
Instances header = new Instances(trainData,0); Instances header = new Instances(trainData,0);
int index = arffFile.getAbsolutePath().lastIndexOf(".");
String modelName = arffFile.getAbsolutePath().substring(0,index) + ".model";
System.out.println("Saving Classifier..." + modelName);
try { try {
SerializationHelper.writeAll SerializationHelper.writeAll
// ("C:\\Users\\Mike\\Documents\\Work\\Java\\WEKA\\weka vs R\\RF_8sp_54att_110whistle-subset.model", // ("C:\\Users\\Mike\\Documents\\Work\\Java\\WEKA\\weka vs R\\RF_8sp_54att_110whistle-subset.model",
(arffFile + ".model", // (arffFile + ".model",
(modelName,
new Object[]{model,header}); new Object[]{model,header});
System.out.println("Finished!");
return modelName;
} catch (Exception ex) { } catch (Exception ex) {
System.out.println("Error saving classifier..."); System.out.println("Error saving classifier...");
ex.printStackTrace();
} }
return null;
System.out.println("Finished!");
} }
} }

View File

@ -0,0 +1,109 @@
/*
* PAMGUARD - Passive Acoustic Monitoring GUARDianship.
* To assist in the Detection Classification and Localisation
* of marine mammals (cetaceans).
*
* Copyright (C) 2006
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License
* as published by the Free Software Foundation; either version 3
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/
package rocca;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import javax.swing.JFileChooser;
import javax.swing.filechooser.FileNameExtensionFilter;
public class RoccaTrainThenTest {
RoccaTrainClassifier roccaTrainClassifier;
RoccaClassifyThis roccaClassifyThis;
/**
* Main Constructor
* @param roccaProcess
*/
public RoccaTrainThenTest(RoccaProcess roccaProcess) {
// let the user select the csv file containing the training and testing dataset(s)
JFileChooser fileChooser = new JFileChooser();
fileChooser.setDialogTitle("Select csv file with the training/testing pairs");
fileChooser.setFileHidingEnabled(true);
fileChooser.setApproveButtonText("Select");
fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY);
FileNameExtensionFilter restrict = new FileNameExtensionFilter("Only .csv files", "csv");
fileChooser.addChoosableFileFilter(restrict);
int state = fileChooser.showOpenDialog(null);
if (state == JFileChooser.APPROVE_OPTION) {
// load the file
try {
File csvDataPairs = fileChooser.getSelectedFile();
BufferedReader br = new BufferedReader(new FileReader(csvDataPairs));
String curPath = csvDataPairs.getParent();
// main loop
// read through the csv file one line at a time. The first column should contain the training dataset filename,
// and the second column the testing dataset filename. Paths should be relative to the path containing
// the csv file
String line = "";
String splitBy = ",";
while ((line=br.readLine())!=null) {
String[] filenames = line.split(splitBy);
// train the classifier
File arffFile = new File(curPath + File.separator + filenames[0]);
roccaTrainClassifier = new RoccaTrainClassifier();
String modelName = roccaTrainClassifier.trainClassifier(arffFile);
if (modelName == null) {
System.out.println("ERROR: could not create classifier model from "+arffFile);
continue;
}
// set the classifier as the current one in RoccaParameters
roccaProcess.roccaControl.roccaParameters.setRoccaClassifierModelFilename(new File(modelName));
// test the classifier with the testing dataset
File testFile = new File(curPath + File.separator + filenames[1]);
roccaClassifyThis = new RoccaClassifyThis();
roccaClassifyThis.runTheClassifier(testFile, roccaProcess);
}
} catch (FileNotFoundException e) {
e.printStackTrace();
return;
} catch (IOException e) {
e.printStackTrace();
return;
}
} else {
return;
}
}
}