mirror of
https://github.com/PAMGuard/PAMGuard.git
synced 2024-11-21 22:52:22 +00:00
Added code to Rocca for training/testing classifier (#114)
* allow Rocca to run without classifiers Fixed bug that threw an error if no classifier files were specified in Rocca Params dialog * add rocca switch to enable dev mode currently only shows/hides extra buttons in the Params dialog, but will extend to more options in the future * Fix memory issue with RoccaContourDataBlocks not being released for garbage collection Set RoccaContourDataBlock objects to null and stop PamObserver Timer to force release * Fix problem tracing whistles in Rocca spectrogram pop-up Whistle and raw data were being cleared before the user had time to trace out the whistle, causing PAMGuard to throw an exception. Both were already being cleared when the pop-up window is closed, so no need to do it here. * updated for training/testing classifiers
This commit is contained in:
parent
d5f504dcd1
commit
d7c4c278d8
@ -32,6 +32,7 @@ import java.io.IOException;
|
||||
import java.util.EnumMap;
|
||||
|
||||
import javax.swing.JFileChooser;
|
||||
import javax.swing.filechooser.FileNameExtensionFilter;
|
||||
|
||||
import PamUtils.PamCalendar;
|
||||
|
||||
@ -126,31 +127,34 @@ public class RoccaClassifyThis {
|
||||
/** the field in the RoccaContourStats object which contains all the stats measures */
|
||||
private EnumMap<RoccaContourStats.ParamIndx, Double> contourStats;
|
||||
|
||||
private String dirIn;
|
||||
|
||||
/** the input filename */
|
||||
private String csvIn;
|
||||
|
||||
/** the input file */
|
||||
private File statsFileIn;
|
||||
|
||||
/** the output filename */
|
||||
private String csvOut;
|
||||
|
||||
/** the output file */
|
||||
private File statsFileOut;
|
||||
|
||||
/** Constructor */
|
||||
/**
|
||||
* Constructor used when allowing user to select training dataset
|
||||
* */
|
||||
public RoccaClassifyThis(RoccaProcess roccaProcess) {
|
||||
|
||||
// initialize the BufferedReader
|
||||
BufferedReader inputFile = null;
|
||||
|
||||
File statsFileIn = getTheFile();
|
||||
if (statsFileIn!=null) {
|
||||
runTheClassifier(statsFileIn, roccaProcess);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor when we pass in the training dataset
|
||||
*/
|
||||
public RoccaClassifyThis() {
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Ask the user to select the file containing the testing dataset
|
||||
*
|
||||
* @return File the csv file containing the testing dataset
|
||||
*/
|
||||
public File getTheFile() {
|
||||
// set the directory
|
||||
// this.dirIn = new String("C:\\Users\\Mike\\Documents\\Work\\Java\\EclipseWorkspace\\testing\\RoccaClassifyThis_testing");
|
||||
// this.dirIn = new String("C:\\Users\\Mike\\Documents\\Work\\Tom\\Atlantic Classifier\\manual 2-stage data");
|
||||
// this.dirIn = new String("C:\\Users\\Mike\\Documents\\Work\\Tom\\Hawaii dataset problems");
|
||||
this.dirIn = new String("C:\\Users\\SCANS\\Documents\\Work\\Biowaves\\ONR classifier");
|
||||
// this.dirIn = new String("C:\\Users\\SCANS\\Documents\\Work\\Biowaves\\ONR classifier");
|
||||
|
||||
// Define the input and output filenames
|
||||
// Hard-coded for now. To Do: query the user for the filename
|
||||
@ -158,35 +162,54 @@ public class RoccaClassifyThis {
|
||||
// this.csvIn = new String("Manual_5sp_April 9 2013.csv");
|
||||
// this.csvIn = new String("CombinedContourStats-fixed.csv");
|
||||
// this.csvOut = new String("RoccaContourStatsReclassified.csv");
|
||||
this.csvIn = new String("Atl_TestDFNoTrain_Call_W_160831.csv");
|
||||
statsFileIn = new File(dirIn, csvIn);
|
||||
this.csvOut = new String("Atl_TestDFNoTrain_Call_W_160829-classified.csv");
|
||||
statsFileOut = new File(dirIn, csvOut);
|
||||
// this.csvIn = new String("Atl_TestDFNoTrain_Call_W_160831.csv");
|
||||
// statsFileIn = new File(dirIn, csvIn);
|
||||
// this.csvOut = new String("Atl_TestDFNoTrain_Call_W_160829-classified.csv");
|
||||
// statsFileOut = new File(dirIn, csvOut);
|
||||
|
||||
// let the user select the arff file
|
||||
JFileChooser fileChooser = new JFileChooser();
|
||||
fileChooser.setDialogTitle("Select spreadsheet to recalculate...");
|
||||
fileChooser.setFileHidingEnabled(true);
|
||||
fileChooser.setApproveButtonText("Select");
|
||||
fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY);
|
||||
FileNameExtensionFilter restrict = new FileNameExtensionFilter("Only .csv files", "csv");
|
||||
fileChooser.addChoosableFileFilter(restrict);
|
||||
|
||||
int state = fileChooser.showOpenDialog(null);
|
||||
File statsFileIn = null;
|
||||
if (state == JFileChooser.APPROVE_OPTION) {
|
||||
|
||||
// load the file
|
||||
statsFileIn = fileChooser.getSelectedFile();
|
||||
return statsFileIn;
|
||||
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Run the classifier
|
||||
* @param statsFileIn the File containing the testing dataset
|
||||
* @param roccaProcess the RoccaProcess instance
|
||||
*/
|
||||
public void runTheClassifier(File statsFileIn, RoccaProcess roccaProcess) {
|
||||
|
||||
// JFileChooser fileChooser = new JFileChooser();
|
||||
// fileChooser.setDialogTitle("Select spreadsheet to recalculate...");
|
||||
// fileChooser.setFileHidingEnabled(true);
|
||||
// fileChooser.setApproveButtonText("Select");
|
||||
// fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY);
|
||||
//
|
||||
// int state = fileChooser.showOpenDialog(this.dirIn);
|
||||
// if (state == JFileChooser.APPROVE_OPTION) {
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
int index = statsFileIn.getAbsolutePath().lastIndexOf(".");
|
||||
String csvOut = statsFileIn.getAbsolutePath().substring(0,index) + "-classified.csv";
|
||||
File statsFileOut = new File(csvOut);
|
||||
|
||||
|
||||
// load the classifier
|
||||
System.out.println("Loading classifier...");
|
||||
roccaProcess.setClassifierLoaded
|
||||
(roccaProcess.roccaClassifier.setUpClassifier());
|
||||
|
||||
// initialize the BufferedReader
|
||||
BufferedReader inputFile = null;
|
||||
|
||||
// open the input file
|
||||
try {
|
||||
System.out.println("Opening input file "+statsFileIn);
|
||||
@ -263,12 +286,45 @@ public class RoccaClassifyThis {
|
||||
contourStats.put(RoccaContourStats.ParamIndx.FREQPOSSLOPEMEAN, Double.parseDouble(dataArray[34]));
|
||||
contourStats.put(RoccaContourStats.ParamIndx.FREQNEGSLOPEMEAN, Double.parseDouble(dataArray[35]));
|
||||
contourStats.put(RoccaContourStats.ParamIndx.FREQSLOPERATIO, Double.parseDouble(dataArray[36]));
|
||||
contourStats.put(RoccaContourStats.ParamIndx.FREQBEGSWEEP, Double.parseDouble(dataArray[37]));
|
||||
//contourStats.put(RoccaContourStats.ParamIndx.FREQBEGUP, Double.parseDouble(dataArray[38]));
|
||||
//contourStats.put(RoccaContourStats.ParamIndx.FREQBEGDWN, Double.parseDouble(dataArray[39]));
|
||||
contourStats.put(RoccaContourStats.ParamIndx.FREQENDSWEEP, Double.parseDouble(dataArray[40]));
|
||||
//contourStats.put(RoccaContourStats.ParamIndx.FREQENDUP, Double.parseDouble(dataArray[41]));
|
||||
//contourStats.put(RoccaContourStats.ParamIndx.FREQENDDWN, Double.parseDouble(dataArray[42]));
|
||||
|
||||
// Note that we have to modify the FREQBEGSWEEP value. Weka is trained with the FREQBEGSWEEP param
|
||||
// as -1=down, 0=flat and 1=up, and that would be how the test data comes through as well. HOWEVER,
|
||||
// Weka assumes that for nominal parameters, the value is the index location (0,1 or 2) and NOT the actual trained
|
||||
// value (-1,0 or 1). So if the whistle has a down sweep, Weka needs the FREQBEGSWEEP value to be 0 indicating the
|
||||
// first location in the array (which was 'down'). If it was up, the value would need to be 2 indicating the third
|
||||
// location in the array (which was 'up').
|
||||
// Ideally we would map the values in the test data to the positions in the training array, but as a quick and
|
||||
// dirty hack we'll simply add 1 to the value since the difference between the nominal values (-1,0,1) and the
|
||||
/// index positions (0,1,2) is an offset of 1
|
||||
// Note also that we don't have to do the same thing for FREQBEGUP and FREQBEGDWN since, by coincidence, the training
|
||||
// values of 0 and 1 happen to match the index locations of 0 and 1
|
||||
//contourStats.put(RoccaContourStats.ParamIndx.FREQBEGSWEEP, Double.parseDouble(dataArray[37]));
|
||||
double tempVal = Double.parseDouble(dataArray[37]);
|
||||
tempVal++;
|
||||
contourStats.put(RoccaContourStats.ParamIndx.FREQBEGSWEEP, tempVal);
|
||||
contourStats.put(RoccaContourStats.ParamIndx.FREQBEGUP, Double.parseDouble(dataArray[38]));
|
||||
contourStats.put(RoccaContourStats.ParamIndx.FREQBEGDWN, Double.parseDouble(dataArray[39]));
|
||||
|
||||
// Note that we have to modify the FREQENDSWEEP value. Weka is trained with the FREQENDSWEEP param
|
||||
// as -1=down, 0=flat and 1=up, and that would be how the test data comes through as well. HOWEVER,
|
||||
// Weka assumes that for nominal parameters, the value is the index location (0,1 or 2) and NOT the actual trained
|
||||
// value (-1,0 or 1). So if the whistle has a down sweep, Weka needs the FREQENDSWEEP value to be 0 indicating the
|
||||
// first location in the array (which was 'down'). If it was up, the value would need to be 2 indicating the third
|
||||
// location in the array (which was 'up').
|
||||
// Ideally we would map the values in the test data to the positions in the training array, but as a quick and
|
||||
// dirty hack we'll simply add 1 to the value since the difference between the nominal values (-1,0,1) and the
|
||||
/// index positions (0,1,2) is an offset of 1
|
||||
// Note also that we don't have to do the same thing for FREQENDUP and FREQENDDWN since, by coincidence, the training
|
||||
// values of 0 and 1 happen to match the index locations of 0 and 1
|
||||
//contourStats.put(RoccaContourStats.ParamIndx.FREQENDSWEEP, Double.parseDouble(dataArray[40]));
|
||||
tempVal = Double.parseDouble(dataArray[40]);
|
||||
tempVal++;
|
||||
contourStats.put(RoccaContourStats.ParamIndx.FREQENDSWEEP, tempVal);
|
||||
contourStats.put(RoccaContourStats.ParamIndx.FREQENDUP, Double.parseDouble(dataArray[41]));
|
||||
contourStats.put(RoccaContourStats.ParamIndx.FREQENDDWN, Double.parseDouble(dataArray[42]));
|
||||
// end of hack
|
||||
|
||||
|
||||
contourStats.put(RoccaContourStats.ParamIndx.NUMSWEEPSUPDWN, Double.parseDouble(dataArray[43]));
|
||||
contourStats.put(RoccaContourStats.ParamIndx.NUMSWEEPSDWNUP, Double.parseDouble(dataArray[44]));
|
||||
contourStats.put(RoccaContourStats.ParamIndx.NUMSWEEPSUPFLAT, Double.parseDouble(dataArray[45]));
|
||||
@ -285,8 +341,8 @@ public class RoccaClassifyThis {
|
||||
contourStats.put(RoccaContourStats.ParamIndx.INFLMEANDELTA, Double.parseDouble(dataArray[56]));
|
||||
contourStats.put(RoccaContourStats.ParamIndx.INFLSTDDEVDELTA, Double.parseDouble(dataArray[57]));
|
||||
contourStats.put(RoccaContourStats.ParamIndx.INFLMEDIANDELTA, Double.parseDouble(dataArray[58]));
|
||||
contourStats.put(RoccaContourStats.ParamIndx.INFLDUR, Double.parseDouble(dataArray[59]));
|
||||
contourStats.put(RoccaContourStats.ParamIndx.STEPDUR, Double.parseDouble(dataArray[60]));
|
||||
//contourStats.put(RoccaContourStats.ParamIndx.INFLDUR, Double.parseDouble(dataArray[59]));
|
||||
//contourStats.put(RoccaContourStats.ParamIndx.STEPDUR, Double.parseDouble(dataArray[60]));
|
||||
|
||||
// Run the classifier
|
||||
roccaProcess.roccaClassifier.classifyContour2(rcdb);
|
||||
|
@ -169,6 +169,7 @@ public class RoccaParametersDialog extends PamDialog implements ActionListener,
|
||||
JButton classifier2Button;
|
||||
JButton recalcButton;
|
||||
JButton reclassifyButton;
|
||||
JButton trainThenTestButton;
|
||||
JButton clearClassifier;
|
||||
JComboBox<DefaultComboBoxModel<Vector<String>>> stage1Classes;
|
||||
DefaultComboBoxModel<Vector<String>> stage1ClassModel;
|
||||
@ -513,6 +514,10 @@ public class RoccaParametersDialog extends PamDialog implements ActionListener,
|
||||
reclassifyButton.addActionListener(this);
|
||||
reclassifyButton.setToolTipText("Load the whistle data from the contour stats output file, and run it through the current Classifier");
|
||||
reclassifyButton.setVisible(true);
|
||||
trainThenTestButton = new JButton("Train then Test");
|
||||
trainThenTestButton.addActionListener(this);
|
||||
trainThenTestButton.setToolTipText("Train a classifier on a set of training data, then test it with a set of testing data");
|
||||
trainThenTestButton.setVisible(true);
|
||||
|
||||
// ******** THIS LINES CONTROLS THE VISIBILITY ********
|
||||
if (RoccaDev.isEnabled()) {
|
||||
@ -528,13 +533,15 @@ public class RoccaParametersDialog extends PamDialog implements ActionListener,
|
||||
extraPanelLayout.createParallelGroup(GroupLayout.Alignment.LEADING)
|
||||
.addGroup(extraPanelLayout.createSequentialGroup()
|
||||
.addComponent(recalcButton)
|
||||
.addComponent(reclassifyButton))
|
||||
.addComponent(reclassifyButton)
|
||||
.addComponent(trainThenTestButton))
|
||||
);
|
||||
extraPanelLayout.setVerticalGroup(
|
||||
extraPanelLayout.createSequentialGroup()
|
||||
.addGroup(extraPanelLayout.createParallelGroup(GroupLayout.Alignment.BASELINE)
|
||||
.addComponent(recalcButton)
|
||||
.addComponent(reclassifyButton))
|
||||
.addComponent(reclassifyButton)
|
||||
.addComponent(trainThenTestButton))
|
||||
);
|
||||
classifierPanel.add(extraButtonsSubPanel);
|
||||
|
||||
@ -892,7 +899,9 @@ public class RoccaParametersDialog extends PamDialog implements ActionListener,
|
||||
} else if (e.getSource() == recalcButton) {
|
||||
RoccaFixParams recalc = new RoccaFixParams(roccaControl.roccaProcess);
|
||||
} else if (e.getSource() == reclassifyButton) {
|
||||
RoccaClassifyThisEvent reclassify = new RoccaClassifyThisEvent(roccaControl.roccaProcess);
|
||||
RoccaClassifyThis reclassify = new RoccaClassifyThis(roccaControl.roccaProcess);
|
||||
} else if (e.getSource() == trainThenTestButton) {
|
||||
RoccaTrainThenTest trainThenTest = new RoccaTrainThenTest(roccaControl.roccaProcess);
|
||||
} else if (e.getSource() == fftButton) {
|
||||
roccaParameters.setUseFFT(true);
|
||||
this.enableTheCorrectSource();
|
||||
|
@ -145,6 +145,7 @@ public class RoccaRFModel implements java.io.Serializable {
|
||||
|
||||
} catch (Exception ex) {
|
||||
System.err.println("1st Classification failed: " + ex.getMessage());
|
||||
ex.printStackTrace();
|
||||
rcdb.setClassifiedAs("Err");
|
||||
}
|
||||
}
|
||||
|
@ -24,10 +24,14 @@
|
||||
package rocca;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.File;
|
||||
import java.io.FileReader;
|
||||
import java.util.Date;
|
||||
import java.util.Enumeration;
|
||||
|
||||
import javax.swing.JFileChooser;
|
||||
import javax.swing.filechooser.FileNameExtensionFilter;
|
||||
|
||||
import weka.classifiers.trees.RandomForest;
|
||||
import weka.core.Instances;
|
||||
import weka.core.SerializationHelper;
|
||||
@ -42,13 +46,64 @@ import weka.core.SerializationHelper;
|
||||
*/
|
||||
public class RoccaTrainClassifier {
|
||||
|
||||
|
||||
/**
|
||||
* Standalone implementation
|
||||
*
|
||||
* @param args
|
||||
*/
|
||||
public static void main(String[] args) {
|
||||
|
||||
RoccaTrainClassifier rtc = new RoccaTrainClassifier();
|
||||
File arffFile = rtc.getArff();
|
||||
if (arffFile!=null) {
|
||||
String modelName = rtc.trainClassifier(arffFile);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Let user choose arff file training dataset
|
||||
*
|
||||
* @return File the arff file containing the training dataset
|
||||
*/
|
||||
public File getArff() {
|
||||
// String arffFile = "C:\\Users\\SCANS\\Documents\\Work\\Biowaves\\ONR classifier\\TP_TrainEvtDF_170408";
|
||||
|
||||
// let the user select the arff file
|
||||
JFileChooser fileChooser = new JFileChooser();
|
||||
fileChooser.setDialogTitle("Select arff file containing training data");
|
||||
fileChooser.setFileHidingEnabled(true);
|
||||
fileChooser.setApproveButtonText("Select");
|
||||
fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY);
|
||||
FileNameExtensionFilter restrict = new FileNameExtensionFilter("Only .arff files", "arff");
|
||||
fileChooser.addChoosableFileFilter(restrict);
|
||||
File arffFile;
|
||||
|
||||
int state = fileChooser.showOpenDialog(null);
|
||||
if (state == JFileChooser.APPROVE_OPTION) {
|
||||
|
||||
// load the file
|
||||
arffFile = fileChooser.getSelectedFile();
|
||||
return arffFile;
|
||||
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Actual code to train the classifier
|
||||
*
|
||||
*/
|
||||
public String trainClassifier(File arffFile) {
|
||||
|
||||
RandomForest model = new RandomForest ();
|
||||
Instances trainData = null;
|
||||
String arffFile = "C:\\Users\\SCANS\\Documents\\Work\\Biowaves\\ONR classifier\\TP_TrainEvtDF_170408";
|
||||
|
||||
// load the ARFF file containing the training set
|
||||
System.out.println("Loading data...");
|
||||
System.out.println("Loading data..." + arffFile.getAbsolutePath());
|
||||
try {
|
||||
trainData = new Instances
|
||||
(new BufferedReader
|
||||
@ -56,10 +111,13 @@ public class RoccaTrainClassifier {
|
||||
// ("C:\\Users\\Mike\\Documents\\Work\\Java\\WEKA\\allwhists 12 vars 8sp update 1-28-10.arff")));
|
||||
// ("C:\\Users\\Mike\\Documents\\Work\\Java\\WEKA\\weka vs R\\ETP_orcawale_whists2 modified-subset110perspecies-no_harm_ratios.arff")));
|
||||
// ("C:\\Users\\SCANS\\Documents\\Work\\Biowaves\\ONR classifier\\Atl_TrainDF_Event_160829.arff")));
|
||||
(arffFile + ".arff")));
|
||||
// (arffFile + ".arff")));
|
||||
(arffFile)));
|
||||
trainData.setClassIndex(trainData.numAttributes()-1);
|
||||
} catch (Exception ex) {
|
||||
System.out.println("Error Loading...");
|
||||
ex.printStackTrace();
|
||||
return null;
|
||||
}
|
||||
|
||||
// set the classifier parameters
|
||||
@ -78,6 +136,8 @@ public class RoccaTrainClassifier {
|
||||
model.setOptions(options);
|
||||
} catch (Exception ex) {
|
||||
System.out.println("Error setting options...");
|
||||
ex.printStackTrace();
|
||||
return null;
|
||||
}
|
||||
|
||||
// train the classifier
|
||||
@ -90,23 +150,29 @@ public class RoccaTrainClassifier {
|
||||
new Date());
|
||||
} catch (Exception ex) {
|
||||
System.out.println("Error training classifier...");
|
||||
ex.printStackTrace();
|
||||
return null;
|
||||
}
|
||||
|
||||
// save the classifier
|
||||
String[] curOptions = model.getOptions();
|
||||
Enumeration test = model.listOptions();
|
||||
System.out.println("Saving Classifier...");
|
||||
// String[] curOptions = model.getOptions();
|
||||
// Enumeration test = model.listOptions();
|
||||
Instances header = new Instances(trainData,0);
|
||||
int index = arffFile.getAbsolutePath().lastIndexOf(".");
|
||||
String modelName = arffFile.getAbsolutePath().substring(0,index) + ".model";
|
||||
System.out.println("Saving Classifier..." + modelName);
|
||||
try {
|
||||
SerializationHelper.writeAll
|
||||
// ("C:\\Users\\Mike\\Documents\\Work\\Java\\WEKA\\weka vs R\\RF_8sp_54att_110whistle-subset.model",
|
||||
(arffFile + ".model",
|
||||
// (arffFile + ".model",
|
||||
(modelName,
|
||||
new Object[]{model,header});
|
||||
System.out.println("Finished!");
|
||||
return modelName;
|
||||
} catch (Exception ex) {
|
||||
System.out.println("Error saving classifier...");
|
||||
ex.printStackTrace();
|
||||
}
|
||||
|
||||
System.out.println("Finished!");
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
|
109
src/rocca/RoccaTrainThenTest.java
Normal file
109
src/rocca/RoccaTrainThenTest.java
Normal file
@ -0,0 +1,109 @@
|
||||
/*
|
||||
* PAMGUARD - Passive Acoustic Monitoring GUARDianship.
|
||||
* To assist in the Detection Classification and Localisation
|
||||
* of marine mammals (cetaceans).
|
||||
*
|
||||
* Copyright (C) 2006
|
||||
*
|
||||
* This program is free software; you can redistribute it and/or
|
||||
* modify it under the terms of the GNU General Public License
|
||||
* as published by the Free Software Foundation; either version 3
|
||||
* of the License, or (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program; if not, write to the Free Software
|
||||
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
||||
*/
|
||||
|
||||
package rocca;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.FileReader;
|
||||
import java.io.IOException;
|
||||
|
||||
import javax.swing.JFileChooser;
|
||||
import javax.swing.filechooser.FileNameExtensionFilter;
|
||||
|
||||
public class RoccaTrainThenTest {
|
||||
|
||||
RoccaTrainClassifier roccaTrainClassifier;
|
||||
|
||||
RoccaClassifyThis roccaClassifyThis;
|
||||
|
||||
|
||||
/**
|
||||
* Main Constructor
|
||||
* @param roccaProcess
|
||||
*/
|
||||
public RoccaTrainThenTest(RoccaProcess roccaProcess) {
|
||||
|
||||
|
||||
// let the user select the csv file containing the training and testing dataset(s)
|
||||
JFileChooser fileChooser = new JFileChooser();
|
||||
fileChooser.setDialogTitle("Select csv file with the training/testing pairs");
|
||||
fileChooser.setFileHidingEnabled(true);
|
||||
fileChooser.setApproveButtonText("Select");
|
||||
fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY);
|
||||
FileNameExtensionFilter restrict = new FileNameExtensionFilter("Only .csv files", "csv");
|
||||
fileChooser.addChoosableFileFilter(restrict);
|
||||
|
||||
int state = fileChooser.showOpenDialog(null);
|
||||
if (state == JFileChooser.APPROVE_OPTION) {
|
||||
|
||||
// load the file
|
||||
try {
|
||||
File csvDataPairs = fileChooser.getSelectedFile();
|
||||
BufferedReader br = new BufferedReader(new FileReader(csvDataPairs));
|
||||
String curPath = csvDataPairs.getParent();
|
||||
|
||||
// main loop
|
||||
// read through the csv file one line at a time. The first column should contain the training dataset filename,
|
||||
// and the second column the testing dataset filename. Paths should be relative to the path containing
|
||||
// the csv file
|
||||
String line = "";
|
||||
String splitBy = ",";
|
||||
while ((line=br.readLine())!=null) {
|
||||
|
||||
String[] filenames = line.split(splitBy);
|
||||
|
||||
// train the classifier
|
||||
File arffFile = new File(curPath + File.separator + filenames[0]);
|
||||
roccaTrainClassifier = new RoccaTrainClassifier();
|
||||
String modelName = roccaTrainClassifier.trainClassifier(arffFile);
|
||||
if (modelName == null) {
|
||||
System.out.println("ERROR: could not create classifier model from "+arffFile);
|
||||
continue;
|
||||
}
|
||||
|
||||
// set the classifier as the current one in RoccaParameters
|
||||
roccaProcess.roccaControl.roccaParameters.setRoccaClassifierModelFilename(new File(modelName));
|
||||
|
||||
|
||||
// test the classifier with the testing dataset
|
||||
File testFile = new File(curPath + File.separator + filenames[1]);
|
||||
roccaClassifyThis = new RoccaClassifyThis();
|
||||
roccaClassifyThis.runTheClassifier(testFile, roccaProcess);
|
||||
|
||||
}
|
||||
|
||||
|
||||
} catch (FileNotFoundException e) {
|
||||
e.printStackTrace();
|
||||
return;
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user