diff --git a/src/rocca/RoccaClassifyThis.java b/src/rocca/RoccaClassifyThis.java index 1a2d56b2..87987247 100644 --- a/src/rocca/RoccaClassifyThis.java +++ b/src/rocca/RoccaClassifyThis.java @@ -32,6 +32,7 @@ import java.io.IOException; import java.util.EnumMap; import javax.swing.JFileChooser; +import javax.swing.filechooser.FileNameExtensionFilter; import PamUtils.PamCalendar; @@ -126,31 +127,34 @@ public class RoccaClassifyThis { /** the field in the RoccaContourStats object which contains all the stats measures */ private EnumMap contourStats; - private String dirIn; - - /** the input filename */ - private String csvIn; - - /** the input file */ - private File statsFileIn; - - /** the output filename */ - private String csvOut; - - /** the output file */ - private File statsFileOut; - - /** Constructor */ + /** + * Constructor used when allowing user to select training dataset + * */ public RoccaClassifyThis(RoccaProcess roccaProcess) { - - // initialize the BufferedReader - BufferedReader inputFile = null; - + File statsFileIn = getTheFile(); + if (statsFileIn!=null) { + runTheClassifier(statsFileIn, roccaProcess); + } + } + + /** + * Constructor when we pass in the training dataset + */ + public RoccaClassifyThis() { + } + + + /** + * Ask the user to select the file containing the testing dataset + * + * @return File the csv file containing the testing dataset + */ + public File getTheFile() { // set the directory // this.dirIn = new String("C:\\Users\\Mike\\Documents\\Work\\Java\\EclipseWorkspace\\testing\\RoccaClassifyThis_testing"); // this.dirIn = new String("C:\\Users\\Mike\\Documents\\Work\\Tom\\Atlantic Classifier\\manual 2-stage data"); // this.dirIn = new String("C:\\Users\\Mike\\Documents\\Work\\Tom\\Hawaii dataset problems"); - this.dirIn = new String("C:\\Users\\SCANS\\Documents\\Work\\Biowaves\\ONR classifier"); +// this.dirIn = new String("C:\\Users\\SCANS\\Documents\\Work\\Biowaves\\ONR classifier"); // Define the input and output filenames // Hard-coded for now. To Do: query the user for the filename @@ -158,35 +162,54 @@ public class RoccaClassifyThis { // this.csvIn = new String("Manual_5sp_April 9 2013.csv"); // this.csvIn = new String("CombinedContourStats-fixed.csv"); // this.csvOut = new String("RoccaContourStatsReclassified.csv"); - this.csvIn = new String("Atl_TestDFNoTrain_Call_W_160831.csv"); - statsFileIn = new File(dirIn, csvIn); - this.csvOut = new String("Atl_TestDFNoTrain_Call_W_160829-classified.csv"); - statsFileOut = new File(dirIn, csvOut); +// this.csvIn = new String("Atl_TestDFNoTrain_Call_W_160831.csv"); +// statsFileIn = new File(dirIn, csvIn); +// this.csvOut = new String("Atl_TestDFNoTrain_Call_W_160829-classified.csv"); +// statsFileOut = new File(dirIn, csvOut); + // let the user select the arff file + JFileChooser fileChooser = new JFileChooser(); + fileChooser.setDialogTitle("Select spreadsheet to recalculate..."); + fileChooser.setFileHidingEnabled(true); + fileChooser.setApproveButtonText("Select"); + fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY); + FileNameExtensionFilter restrict = new FileNameExtensionFilter("Only .csv files", "csv"); + fileChooser.addChoosableFileFilter(restrict); + + int state = fileChooser.showOpenDialog(null); + File statsFileIn = null; + if (state == JFileChooser.APPROVE_OPTION) { + + // load the file + statsFileIn = fileChooser.getSelectedFile(); + return statsFileIn; + + } else { + return null; + } + } + + + /** + * Run the classifier + * @param statsFileIn the File containing the testing dataset + * @param roccaProcess the RoccaProcess instance + */ + public void runTheClassifier(File statsFileIn, RoccaProcess roccaProcess) { -// JFileChooser fileChooser = new JFileChooser(); -// fileChooser.setDialogTitle("Select spreadsheet to recalculate..."); -// fileChooser.setFileHidingEnabled(true); -// fileChooser.setApproveButtonText("Select"); -// fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY); -// -// int state = fileChooser.showOpenDialog(this.dirIn); -// if (state == JFileChooser.APPROVE_OPTION) { - - - - - - - - - + int index = statsFileIn.getAbsolutePath().lastIndexOf("."); + String csvOut = statsFileIn.getAbsolutePath().substring(0,index) + "-classified.csv"; + File statsFileOut = new File(csvOut); + // load the classifier System.out.println("Loading classifier..."); roccaProcess.setClassifierLoaded (roccaProcess.roccaClassifier.setUpClassifier()); + // initialize the BufferedReader + BufferedReader inputFile = null; + // open the input file try { System.out.println("Opening input file "+statsFileIn); @@ -263,12 +286,45 @@ public class RoccaClassifyThis { contourStats.put(RoccaContourStats.ParamIndx.FREQPOSSLOPEMEAN, Double.parseDouble(dataArray[34])); contourStats.put(RoccaContourStats.ParamIndx.FREQNEGSLOPEMEAN, Double.parseDouble(dataArray[35])); contourStats.put(RoccaContourStats.ParamIndx.FREQSLOPERATIO, Double.parseDouble(dataArray[36])); - contourStats.put(RoccaContourStats.ParamIndx.FREQBEGSWEEP, Double.parseDouble(dataArray[37])); - //contourStats.put(RoccaContourStats.ParamIndx.FREQBEGUP, Double.parseDouble(dataArray[38])); - //contourStats.put(RoccaContourStats.ParamIndx.FREQBEGDWN, Double.parseDouble(dataArray[39])); - contourStats.put(RoccaContourStats.ParamIndx.FREQENDSWEEP, Double.parseDouble(dataArray[40])); - //contourStats.put(RoccaContourStats.ParamIndx.FREQENDUP, Double.parseDouble(dataArray[41])); - //contourStats.put(RoccaContourStats.ParamIndx.FREQENDDWN, Double.parseDouble(dataArray[42])); + + // Note that we have to modify the FREQBEGSWEEP value. Weka is trained with the FREQBEGSWEEP param + // as -1=down, 0=flat and 1=up, and that would be how the test data comes through as well. HOWEVER, + // Weka assumes that for nominal parameters, the value is the index location (0,1 or 2) and NOT the actual trained + // value (-1,0 or 1). So if the whistle has a down sweep, Weka needs the FREQBEGSWEEP value to be 0 indicating the + // first location in the array (which was 'down'). If it was up, the value would need to be 2 indicating the third + // location in the array (which was 'up'). + // Ideally we would map the values in the test data to the positions in the training array, but as a quick and + // dirty hack we'll simply add 1 to the value since the difference between the nominal values (-1,0,1) and the + /// index positions (0,1,2) is an offset of 1 + // Note also that we don't have to do the same thing for FREQBEGUP and FREQBEGDWN since, by coincidence, the training + // values of 0 and 1 happen to match the index locations of 0 and 1 + //contourStats.put(RoccaContourStats.ParamIndx.FREQBEGSWEEP, Double.parseDouble(dataArray[37])); + double tempVal = Double.parseDouble(dataArray[37]); + tempVal++; + contourStats.put(RoccaContourStats.ParamIndx.FREQBEGSWEEP, tempVal); + contourStats.put(RoccaContourStats.ParamIndx.FREQBEGUP, Double.parseDouble(dataArray[38])); + contourStats.put(RoccaContourStats.ParamIndx.FREQBEGDWN, Double.parseDouble(dataArray[39])); + + // Note that we have to modify the FREQENDSWEEP value. Weka is trained with the FREQENDSWEEP param + // as -1=down, 0=flat and 1=up, and that would be how the test data comes through as well. HOWEVER, + // Weka assumes that for nominal parameters, the value is the index location (0,1 or 2) and NOT the actual trained + // value (-1,0 or 1). So if the whistle has a down sweep, Weka needs the FREQENDSWEEP value to be 0 indicating the + // first location in the array (which was 'down'). If it was up, the value would need to be 2 indicating the third + // location in the array (which was 'up'). + // Ideally we would map the values in the test data to the positions in the training array, but as a quick and + // dirty hack we'll simply add 1 to the value since the difference between the nominal values (-1,0,1) and the + /// index positions (0,1,2) is an offset of 1 + // Note also that we don't have to do the same thing for FREQENDUP and FREQENDDWN since, by coincidence, the training + // values of 0 and 1 happen to match the index locations of 0 and 1 + //contourStats.put(RoccaContourStats.ParamIndx.FREQENDSWEEP, Double.parseDouble(dataArray[40])); + tempVal = Double.parseDouble(dataArray[40]); + tempVal++; + contourStats.put(RoccaContourStats.ParamIndx.FREQENDSWEEP, tempVal); + contourStats.put(RoccaContourStats.ParamIndx.FREQENDUP, Double.parseDouble(dataArray[41])); + contourStats.put(RoccaContourStats.ParamIndx.FREQENDDWN, Double.parseDouble(dataArray[42])); + // end of hack + + contourStats.put(RoccaContourStats.ParamIndx.NUMSWEEPSUPDWN, Double.parseDouble(dataArray[43])); contourStats.put(RoccaContourStats.ParamIndx.NUMSWEEPSDWNUP, Double.parseDouble(dataArray[44])); contourStats.put(RoccaContourStats.ParamIndx.NUMSWEEPSUPFLAT, Double.parseDouble(dataArray[45])); @@ -285,8 +341,8 @@ public class RoccaClassifyThis { contourStats.put(RoccaContourStats.ParamIndx.INFLMEANDELTA, Double.parseDouble(dataArray[56])); contourStats.put(RoccaContourStats.ParamIndx.INFLSTDDEVDELTA, Double.parseDouble(dataArray[57])); contourStats.put(RoccaContourStats.ParamIndx.INFLMEDIANDELTA, Double.parseDouble(dataArray[58])); - contourStats.put(RoccaContourStats.ParamIndx.INFLDUR, Double.parseDouble(dataArray[59])); - contourStats.put(RoccaContourStats.ParamIndx.STEPDUR, Double.parseDouble(dataArray[60])); + //contourStats.put(RoccaContourStats.ParamIndx.INFLDUR, Double.parseDouble(dataArray[59])); + //contourStats.put(RoccaContourStats.ParamIndx.STEPDUR, Double.parseDouble(dataArray[60])); // Run the classifier roccaProcess.roccaClassifier.classifyContour2(rcdb); diff --git a/src/rocca/RoccaParametersDialog.java b/src/rocca/RoccaParametersDialog.java index 2fd1f136..fb74da29 100644 --- a/src/rocca/RoccaParametersDialog.java +++ b/src/rocca/RoccaParametersDialog.java @@ -169,6 +169,7 @@ public class RoccaParametersDialog extends PamDialog implements ActionListener, JButton classifier2Button; JButton recalcButton; JButton reclassifyButton; + JButton trainThenTestButton; JButton clearClassifier; JComboBox>> stage1Classes; DefaultComboBoxModel> stage1ClassModel; @@ -513,6 +514,10 @@ public class RoccaParametersDialog extends PamDialog implements ActionListener, reclassifyButton.addActionListener(this); reclassifyButton.setToolTipText("Load the whistle data from the contour stats output file, and run it through the current Classifier"); reclassifyButton.setVisible(true); + trainThenTestButton = new JButton("Train then Test"); + trainThenTestButton.addActionListener(this); + trainThenTestButton.setToolTipText("Train a classifier on a set of training data, then test it with a set of testing data"); + trainThenTestButton.setVisible(true); // ******** THIS LINES CONTROLS THE VISIBILITY ******** if (RoccaDev.isEnabled()) { @@ -528,13 +533,15 @@ public class RoccaParametersDialog extends PamDialog implements ActionListener, extraPanelLayout.createParallelGroup(GroupLayout.Alignment.LEADING) .addGroup(extraPanelLayout.createSequentialGroup() .addComponent(recalcButton) - .addComponent(reclassifyButton)) + .addComponent(reclassifyButton) + .addComponent(trainThenTestButton)) ); extraPanelLayout.setVerticalGroup( extraPanelLayout.createSequentialGroup() .addGroup(extraPanelLayout.createParallelGroup(GroupLayout.Alignment.BASELINE) .addComponent(recalcButton) - .addComponent(reclassifyButton)) + .addComponent(reclassifyButton) + .addComponent(trainThenTestButton)) ); classifierPanel.add(extraButtonsSubPanel); @@ -892,7 +899,9 @@ public class RoccaParametersDialog extends PamDialog implements ActionListener, } else if (e.getSource() == recalcButton) { RoccaFixParams recalc = new RoccaFixParams(roccaControl.roccaProcess); } else if (e.getSource() == reclassifyButton) { - RoccaClassifyThisEvent reclassify = new RoccaClassifyThisEvent(roccaControl.roccaProcess); + RoccaClassifyThis reclassify = new RoccaClassifyThis(roccaControl.roccaProcess); + } else if (e.getSource() == trainThenTestButton) { + RoccaTrainThenTest trainThenTest = new RoccaTrainThenTest(roccaControl.roccaProcess); } else if (e.getSource() == fftButton) { roccaParameters.setUseFFT(true); this.enableTheCorrectSource(); diff --git a/src/rocca/RoccaRFModel.java b/src/rocca/RoccaRFModel.java index bcf74f4d..e7c9b208 100644 --- a/src/rocca/RoccaRFModel.java +++ b/src/rocca/RoccaRFModel.java @@ -145,6 +145,7 @@ public class RoccaRFModel implements java.io.Serializable { } catch (Exception ex) { System.err.println("1st Classification failed: " + ex.getMessage()); + ex.printStackTrace(); rcdb.setClassifiedAs("Err"); } } diff --git a/src/rocca/RoccaTrainClassifier.java b/src/rocca/RoccaTrainClassifier.java index 09b2c6c7..70acc07d 100644 --- a/src/rocca/RoccaTrainClassifier.java +++ b/src/rocca/RoccaTrainClassifier.java @@ -24,10 +24,14 @@ package rocca; import java.io.BufferedReader; +import java.io.File; import java.io.FileReader; import java.util.Date; import java.util.Enumeration; +import javax.swing.JFileChooser; +import javax.swing.filechooser.FileNameExtensionFilter; + import weka.classifiers.trees.RandomForest; import weka.core.Instances; import weka.core.SerializationHelper; @@ -42,13 +46,64 @@ import weka.core.SerializationHelper; */ public class RoccaTrainClassifier { + + /** + * Standalone implementation + * + * @param args + */ public static void main(String[] args) { + + RoccaTrainClassifier rtc = new RoccaTrainClassifier(); + File arffFile = rtc.getArff(); + if (arffFile!=null) { + String modelName = rtc.trainClassifier(arffFile); + } + } + + + /** + * Let user choose arff file training dataset + * + * @return File the arff file containing the training dataset + */ + public File getArff() { +// String arffFile = "C:\\Users\\SCANS\\Documents\\Work\\Biowaves\\ONR classifier\\TP_TrainEvtDF_170408"; + + // let the user select the arff file + JFileChooser fileChooser = new JFileChooser(); + fileChooser.setDialogTitle("Select arff file containing training data"); + fileChooser.setFileHidingEnabled(true); + fileChooser.setApproveButtonText("Select"); + fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY); + FileNameExtensionFilter restrict = new FileNameExtensionFilter("Only .arff files", "arff"); + fileChooser.addChoosableFileFilter(restrict); + File arffFile; + + int state = fileChooser.showOpenDialog(null); + if (state == JFileChooser.APPROVE_OPTION) { + + // load the file + arffFile = fileChooser.getSelectedFile(); + return arffFile; + + } else { + return null; + } + } + + + /** + * Actual code to train the classifier + * + */ + public String trainClassifier(File arffFile) { + RandomForest model = new RandomForest (); Instances trainData = null; - String arffFile = "C:\\Users\\SCANS\\Documents\\Work\\Biowaves\\ONR classifier\\TP_TrainEvtDF_170408"; // load the ARFF file containing the training set - System.out.println("Loading data..."); + System.out.println("Loading data..." + arffFile.getAbsolutePath()); try { trainData = new Instances (new BufferedReader @@ -56,10 +111,13 @@ public class RoccaTrainClassifier { // ("C:\\Users\\Mike\\Documents\\Work\\Java\\WEKA\\allwhists 12 vars 8sp update 1-28-10.arff"))); // ("C:\\Users\\Mike\\Documents\\Work\\Java\\WEKA\\weka vs R\\ETP_orcawale_whists2 modified-subset110perspecies-no_harm_ratios.arff"))); // ("C:\\Users\\SCANS\\Documents\\Work\\Biowaves\\ONR classifier\\Atl_TrainDF_Event_160829.arff"))); - (arffFile + ".arff"))); +// (arffFile + ".arff"))); + (arffFile))); trainData.setClassIndex(trainData.numAttributes()-1); } catch (Exception ex) { System.out.println("Error Loading..."); + ex.printStackTrace(); + return null; } // set the classifier parameters @@ -78,6 +136,8 @@ public class RoccaTrainClassifier { model.setOptions(options); } catch (Exception ex) { System.out.println("Error setting options..."); + ex.printStackTrace(); + return null; } // train the classifier @@ -90,23 +150,29 @@ public class RoccaTrainClassifier { new Date()); } catch (Exception ex) { System.out.println("Error training classifier..."); + ex.printStackTrace(); + return null; } // save the classifier - String[] curOptions = model.getOptions(); - Enumeration test = model.listOptions(); - System.out.println("Saving Classifier..."); +// String[] curOptions = model.getOptions(); +// Enumeration test = model.listOptions(); Instances header = new Instances(trainData,0); + int index = arffFile.getAbsolutePath().lastIndexOf("."); + String modelName = arffFile.getAbsolutePath().substring(0,index) + ".model"; + System.out.println("Saving Classifier..." + modelName); try { SerializationHelper.writeAll // ("C:\\Users\\Mike\\Documents\\Work\\Java\\WEKA\\weka vs R\\RF_8sp_54att_110whistle-subset.model", - (arffFile + ".model", +// (arffFile + ".model", + (modelName, new Object[]{model,header}); + System.out.println("Finished!"); + return modelName; } catch (Exception ex) { System.out.println("Error saving classifier..."); + ex.printStackTrace(); } - - System.out.println("Finished!"); + return null; } - } diff --git a/src/rocca/RoccaTrainThenTest.java b/src/rocca/RoccaTrainThenTest.java new file mode 100644 index 00000000..7af2486f --- /dev/null +++ b/src/rocca/RoccaTrainThenTest.java @@ -0,0 +1,109 @@ +/* + * PAMGUARD - Passive Acoustic Monitoring GUARDianship. + * To assist in the Detection Classification and Localisation + * of marine mammals (cetaceans). + * + * Copyright (C) 2006 + * + * This program is free software; you can redistribute it and/or + * modify it under the terms of the GNU General Public License + * as published by the Free Software Foundation; either version 3 + * of the License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + */ + +package rocca; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileReader; +import java.io.IOException; + +import javax.swing.JFileChooser; +import javax.swing.filechooser.FileNameExtensionFilter; + +public class RoccaTrainThenTest { + + RoccaTrainClassifier roccaTrainClassifier; + + RoccaClassifyThis roccaClassifyThis; + + + /** + * Main Constructor + * @param roccaProcess + */ + public RoccaTrainThenTest(RoccaProcess roccaProcess) { + + + // let the user select the csv file containing the training and testing dataset(s) + JFileChooser fileChooser = new JFileChooser(); + fileChooser.setDialogTitle("Select csv file with the training/testing pairs"); + fileChooser.setFileHidingEnabled(true); + fileChooser.setApproveButtonText("Select"); + fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY); + FileNameExtensionFilter restrict = new FileNameExtensionFilter("Only .csv files", "csv"); + fileChooser.addChoosableFileFilter(restrict); + + int state = fileChooser.showOpenDialog(null); + if (state == JFileChooser.APPROVE_OPTION) { + + // load the file + try { + File csvDataPairs = fileChooser.getSelectedFile(); + BufferedReader br = new BufferedReader(new FileReader(csvDataPairs)); + String curPath = csvDataPairs.getParent(); + + // main loop + // read through the csv file one line at a time. The first column should contain the training dataset filename, + // and the second column the testing dataset filename. Paths should be relative to the path containing + // the csv file + String line = ""; + String splitBy = ","; + while ((line=br.readLine())!=null) { + + String[] filenames = line.split(splitBy); + + // train the classifier + File arffFile = new File(curPath + File.separator + filenames[0]); + roccaTrainClassifier = new RoccaTrainClassifier(); + String modelName = roccaTrainClassifier.trainClassifier(arffFile); + if (modelName == null) { + System.out.println("ERROR: could not create classifier model from "+arffFile); + continue; + } + + // set the classifier as the current one in RoccaParameters + roccaProcess.roccaControl.roccaParameters.setRoccaClassifierModelFilename(new File(modelName)); + + + // test the classifier with the testing dataset + File testFile = new File(curPath + File.separator + filenames[1]); + roccaClassifyThis = new RoccaClassifyThis(); + roccaClassifyThis.runTheClassifier(testFile, roccaProcess); + + } + + + } catch (FileNotFoundException e) { + e.printStackTrace(); + return; + } catch (IOException e) { + e.printStackTrace(); + return; + } + } else { + return; + } + } + +}