mirror of
https://github.com/PAMGuard/PAMGuard.git
synced 2024-11-22 07:02:29 +00:00
Implementing a data selector for the deep learning module
This commit is contained in:
parent
38dc7e6076
commit
15d72c8fd5
@ -76,7 +76,6 @@ public class ExportProcessDialog {
|
||||
|
||||
|
||||
public void createExportGroup() {
|
||||
|
||||
//clear current tasks.
|
||||
dlOfflineGroup.clearTasks();
|
||||
|
||||
@ -88,7 +87,6 @@ public class ExportProcessDialog {
|
||||
dlOfflineGroup.addTask(new ExportTask(dataBlocks.get(i), exportManager));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
////---Swing stuff----/// should not be here but this is how PG works.
|
||||
|
||||
@ -388,7 +386,7 @@ public class ExportProcessDialog {
|
||||
|
||||
|
||||
/**
|
||||
* Export taks group.
|
||||
* Export task
|
||||
*/
|
||||
class ExportTaskGroup extends OfflineTaskGroup {
|
||||
|
||||
|
@ -622,8 +622,9 @@ public class DLControl extends PamControlledUnit implements PamSettings {
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the classifier chooser.
|
||||
* @return the classifier chooser.Take it
|
||||
* Get the classifier chooser. The classifier chooser chooses which classifier use
|
||||
* based on a selected file or URL.
|
||||
* @return the classifier chooser.
|
||||
*/
|
||||
public DLClassifierChooser getDlClassifierChooser() {
|
||||
return dlClassifierChooser;
|
||||
|
35
src/rawDeepLearningClassifier/dataSelector/DLDataFilter.java
Normal file
35
src/rawDeepLearningClassifier/dataSelector/DLDataFilter.java
Normal file
@ -0,0 +1,35 @@
|
||||
package rawDeepLearningClassifier.dataSelector;
|
||||
|
||||
import PamguardMVC.PamDataUnit;
|
||||
import PamguardMVC.dataSelector.DataSelectParams;
|
||||
|
||||
/**
|
||||
* Score a data unit with a deep learning annotation. Note that this could be an
|
||||
* external data unit with an annotation e.g. a click, or deep learning
|
||||
* detections generated by the deep learning module.
|
||||
*
|
||||
* @author Jamie Macaulay
|
||||
*/
|
||||
public interface DLDataFilter {
|
||||
|
||||
/**
|
||||
* Score a data unit with deep learning annotations
|
||||
* @param dataUnit - the data unit to score
|
||||
* @return 0 to reject and >0 to accept.
|
||||
*/
|
||||
public int scoreDLData(PamDataUnit dataUnit);
|
||||
|
||||
/**
|
||||
* Get parameters for the data filter.
|
||||
* @return parameters for the data selectors.
|
||||
*/
|
||||
public DataSelectParams getParams();
|
||||
|
||||
/**
|
||||
* Set the parameters for the data filter.
|
||||
* @param params - the parameters to set.
|
||||
*/
|
||||
public void setParams(DataSelectParams params);
|
||||
|
||||
|
||||
}
|
@ -0,0 +1,34 @@
|
||||
package rawDeepLearningClassifier.dataSelector;
|
||||
|
||||
import PamguardMVC.PamDataBlock;
|
||||
import PamguardMVC.dataSelector.DataSelectParams;
|
||||
import PamguardMVC.dataSelector.DataSelector;
|
||||
import PamguardMVC.dataSelector.DataSelectorCreator;
|
||||
import clickDetector.dataSelector.ClickDataSelector;
|
||||
import rawDeepLearningClassifier.DLControl;
|
||||
|
||||
/**
|
||||
* Creates a data selector for the deep learning module.
|
||||
*
|
||||
* @author Jamie Macaulay
|
||||
*/
|
||||
public class DLDataSelectCreator extends DataSelectorCreator {
|
||||
|
||||
private DLControl dlcontrol;
|
||||
|
||||
public DLDataSelectCreator(DLControl dlcontrol, PamDataBlock pamDataBlock) {
|
||||
super(pamDataBlock);
|
||||
this.dlcontrol = dlcontrol;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataSelectParams createNewParams(String name) {
|
||||
return new DLDataSelectorParams();
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataSelector createDataSelector(String selectorName, boolean allowScores, String selectorType) {
|
||||
return new DLDataSelector(dlcontrol, this.getPamDataBlock(), selectorName, allowScores, selectorType);
|
||||
}
|
||||
|
||||
}
|
120
src/rawDeepLearningClassifier/dataSelector/DLDataSelector.java
Normal file
120
src/rawDeepLearningClassifier/dataSelector/DLDataSelector.java
Normal file
@ -0,0 +1,120 @@
|
||||
package rawDeepLearningClassifier.dataSelector;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import PamView.dialog.PamDialogPanel;
|
||||
import PamguardMVC.PamDataBlock;
|
||||
import PamguardMVC.PamDataUnit;
|
||||
import PamguardMVC.dataSelector.DataSelectParams;
|
||||
import PamguardMVC.dataSelector.DataSelector;
|
||||
import pamViewFX.fxSettingsPanes.DynamicSettingsPane;
|
||||
import rawDeepLearningClassifier.DLControl;
|
||||
|
||||
/**
|
||||
* Data selector of DL data units. Note that data selectors are really data
|
||||
* unit specific and not annotation specific.
|
||||
* <p>
|
||||
* The data selector can have different types of data selectors which can
|
||||
* depend on the classifer used and user choice.
|
||||
* <p>
|
||||
* Note that this is slightly different from DLPredicitoDecision
|
||||
* as it deals with data units that may have a more than one prediction.
|
||||
* i.e.
|
||||
*/
|
||||
public class DLDataSelector extends DataSelector {
|
||||
|
||||
|
||||
/**
|
||||
* Swing panel for the data selector.
|
||||
*/
|
||||
private DLSelectPanel dlSelectPanel;
|
||||
|
||||
|
||||
/**
|
||||
* FX panel for the data selector.
|
||||
*/
|
||||
private DLSelectPaneFX dlSelectPaneFX;
|
||||
|
||||
/**
|
||||
* Data filter for filtering.
|
||||
*/
|
||||
private List<DLDataFilter> dataFilters = new ArrayList<DLDataFilter> ();
|
||||
|
||||
|
||||
private DLDataSelectorParams dlDataSelectParams;
|
||||
|
||||
/**
|
||||
* Constructor for the data selector.
|
||||
* @param dlcontrol - reference to the DLControl.
|
||||
* @param pamDataBlock - the data block.
|
||||
* @param selectorName - the selector name.
|
||||
* @param allowScores - allow all the scores.
|
||||
* @param selectorType - the selector type.
|
||||
*/
|
||||
public DLDataSelector(DLControl dlcontrol, PamDataBlock pamDataBlock, String selectorName, boolean allowScores, String selectorType) {
|
||||
super(pamDataBlock, selectorName, allowScores);
|
||||
/****New data filters go here****/
|
||||
dataFilters.add(new DLPredictionFilter(dlcontrol));
|
||||
}
|
||||
|
||||
@Override
|
||||
public PamDialogPanel getDialogPanel() {
|
||||
if (dlSelectPanel == null) {
|
||||
dlSelectPanel = new DLSelectPanel(this);
|
||||
}
|
||||
return dlSelectPanel;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DynamicSettingsPane<Boolean> getDialogPaneFX() {
|
||||
if (dlSelectPaneFX == null) {
|
||||
dlSelectPaneFX = new DLSelectPaneFX(this);
|
||||
}
|
||||
return dlSelectPaneFX;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setParams(DataSelectParams dataSelectParams) {
|
||||
|
||||
if (dataSelectParams instanceof DLDataSelectorParams) {
|
||||
dlDataSelectParams = (DLDataSelectorParams) dataSelectParams;
|
||||
|
||||
checkDataFilterParamsSize(dlDataSelectParams);
|
||||
|
||||
//set paramters for all data filters.
|
||||
for (int i=0; i<dlDataSelectParams.dataSelectorParams.length; i++) {
|
||||
dataFilters.get(i).setParams((DLDataSelectorParams) dataSelectParams);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Check that the data filters are the corret size.
|
||||
* @param dataSelectParams - the apramters to set.
|
||||
*/
|
||||
private void checkDataFilterParamsSize(DLDataSelectorParams dataSelectParams) {
|
||||
if (dataSelectParams.dataSelectorParams==null || dataSelectParams.dataSelectorParams.length!=dataFilters.size()) {
|
||||
dataSelectParams.dataSelectorParams = new DataSelectParams[dataFilters.size()];
|
||||
for (int i=0; i<dataSelectParams.dataSelectorParams.length; i++) {
|
||||
dataSelectParams.dataSelectorParams[i] = dataFilters.get(i).getParams();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataSelectParams getParams() {
|
||||
for (int i=0; i<dlDataSelectParams.dataSelectorParams.length; i++) {
|
||||
dlDataSelectParams.dataSelectorParams[i] = dataFilters.get(i).getParams();
|
||||
}
|
||||
return dlDataSelectParams;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double scoreData(PamDataUnit pamDataUnit) {
|
||||
int score = dataFilters.get(dlDataSelectParams.dataSelectorIndex).scoreDLData(pamDataUnit);
|
||||
return score>=0 ? 1 : 0;
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,22 @@
|
||||
package rawDeepLearningClassifier.dataSelector;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
import PamguardMVC.dataSelector.DataSelectParams;
|
||||
|
||||
/**
|
||||
* Paramters for the DL data seelctor
|
||||
*/
|
||||
public class DLDataSelectorParams extends DataSelectParams implements Serializable, Cloneable{
|
||||
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
public int dataSelectorIndex = 0;
|
||||
|
||||
/**
|
||||
* List of data selector parameters for different deep learning selectors.
|
||||
*/
|
||||
public DataSelectParams[] dataSelectorParams;
|
||||
|
||||
|
||||
}
|
@ -0,0 +1,107 @@
|
||||
package rawDeepLearningClassifier.dataSelector;
|
||||
|
||||
|
||||
import PamguardMVC.PamDataUnit;
|
||||
import PamguardMVC.dataSelector.DataSelectParams;
|
||||
import rawDeepLearningClassifier.DLControl;
|
||||
import rawDeepLearningClassifier.dlClassification.DLDetection;
|
||||
import rawDeepLearningClassifier.dlClassification.PredictionResult;
|
||||
|
||||
/**
|
||||
* A data filter which filters data by the maximum prediction value
|
||||
* for different classes.
|
||||
*
|
||||
* @author Jamie Macaulay
|
||||
*/
|
||||
public class DLPredictionFilter implements DLDataFilter {
|
||||
|
||||
/**
|
||||
* Reference to the DLControl
|
||||
*/
|
||||
private DLControl dlcontrol;
|
||||
|
||||
/**
|
||||
* The filter parameters
|
||||
*/
|
||||
private DLPredictionFilterParams filterParams = new DLPredictionFilterParams();
|
||||
|
||||
public DLPredictionFilter(DLControl dlcontrol) {
|
||||
this.dlcontrol = dlcontrol;
|
||||
checkParamsClass() ;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public int scoreDLData(PamDataUnit dataUnit) {
|
||||
|
||||
DLDetection dlDetection = (DLDetection) dataUnit;
|
||||
|
||||
//iterate through all results and check that at least one class passes data selection.
|
||||
float[] results;
|
||||
int maxClassIndex = -1;
|
||||
int maxPred = -1;
|
||||
|
||||
//get the maximum prediction index which passes the minimum threshold
|
||||
for (PredictionResult modelResult: dlDetection.getModelResults()) {
|
||||
results = modelResult.getPrediction();
|
||||
for (int j=0; j<results.length; j++) {
|
||||
if (filterParams.classSelect[j] && results[j]>filterParams.minClassPredicton[j]) {
|
||||
if (results[j]> maxPred) maxClassIndex = j;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return maxClassIndex;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataSelectParams getParams() {
|
||||
checkParamsClass() ;
|
||||
return filterParams;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setParams(DataSelectParams params) {
|
||||
this.filterParams = ((DLPredictionFilterParams) params).clone();
|
||||
checkParamsClass();
|
||||
|
||||
}
|
||||
|
||||
private void checkParamsClass() {
|
||||
if (filterParams.classSelect==null || dlcontrol.getNumClasses()!=filterParams.classSelect.length) {
|
||||
filterParams.classSelect = new boolean[dlcontrol.getNumClasses()];
|
||||
filterParams.minClassPredicton = new double[dlcontrol.getNumClasses()];
|
||||
|
||||
//set default so data selector does not always prevent all classes from showing.
|
||||
for (int i=0; i<filterParams.classSelect.length; i++) {
|
||||
filterParams.classSelect[i]=true;
|
||||
filterParams.minClassPredicton[i]=0.4;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// /**
|
||||
// * Get the index of the highest prediciton value a list of results.
|
||||
// * @param predictions - index of the highest prediction within a matrix of predicitons.
|
||||
// * @return an index of the hghest predictions.
|
||||
// */
|
||||
// public static int[] getBestClass(List<PredictionResult> predictions) {
|
||||
//
|
||||
// float[][] results = new float[predictions.size()][];
|
||||
//
|
||||
// //A detection might have multiple prediction results, i.e. predictions are a matrix. Need
|
||||
// //to iterate through all the predictions and then work out whihc is the maximum. That index is then then]
|
||||
// //class colour.
|
||||
// int i=0;
|
||||
// for (PredictionResult modelResult: predictions) {
|
||||
// results[i] = modelResult.getPrediction();
|
||||
// i++;
|
||||
// }
|
||||
//
|
||||
// int[] indexBest = PamArrayUtils.maxPos(results);
|
||||
//
|
||||
// return indexBest;
|
||||
// }
|
||||
|
||||
}
|
@ -0,0 +1,35 @@
|
||||
package rawDeepLearningClassifier.dataSelector;
|
||||
|
||||
import PamguardMVC.dataSelector.DataSelectParams;
|
||||
|
||||
/**
|
||||
* Parameters for filtering by the minimum prediciton value.
|
||||
*/
|
||||
public class DLPredictionFilterParams extends DataSelectParams {
|
||||
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
/**
|
||||
* An array indicating which classes are to be used in data selection
|
||||
*/
|
||||
public boolean[] classSelect;
|
||||
|
||||
/**
|
||||
* The minimum class prediction.
|
||||
*/
|
||||
public double[] minClassPredicton;
|
||||
|
||||
/**
|
||||
* Clone the parameters.
|
||||
*/
|
||||
public DLPredictionFilterParams clone() {
|
||||
try {
|
||||
return (DLPredictionFilterParams) super.clone();
|
||||
} catch (CloneNotSupportedException e) {
|
||||
// TODO Auto-generated catch block
|
||||
e.printStackTrace();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,47 @@
|
||||
package rawDeepLearningClassifier.dataSelector;
|
||||
|
||||
import javafx.scene.Node;
|
||||
import javafx.scene.control.Label;
|
||||
import pamViewFX.fxSettingsPanes.DynamicSettingsPane;
|
||||
|
||||
/**
|
||||
* JavaFX pane for the deep learning data selector.
|
||||
*
|
||||
* @author Jamie Macaulay
|
||||
*/
|
||||
public class DLSelectPaneFX extends DynamicSettingsPane<Boolean>{
|
||||
|
||||
public DLSelectPaneFX(Object ownerWindow) {
|
||||
super(ownerWindow);
|
||||
// TODO Auto-generated constructor stub
|
||||
}
|
||||
|
||||
@Override
|
||||
public Boolean getParams(Boolean currParams) {
|
||||
// TODO Auto-generated method stub
|
||||
return currParams;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setParams(Boolean input) {
|
||||
// TODO Auto-generated method stub
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return "Deep Learning Data Selector:";
|
||||
}
|
||||
|
||||
@Override
|
||||
public Node getContentNode() {
|
||||
return new Label("Hello data selector");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void paneInitialized() {
|
||||
// TODO Auto-generated method stub
|
||||
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,34 @@
|
||||
package rawDeepLearningClassifier.dataSelector;
|
||||
|
||||
import javax.swing.JComponent;
|
||||
|
||||
import PamView.dialog.PamDialogPanel;
|
||||
|
||||
/**
|
||||
* Swing panel for the deep learning data selector.
|
||||
*/
|
||||
public class DLSelectPanel implements PamDialogPanel {
|
||||
|
||||
public DLSelectPanel(DLDataSelector dlDataSelector) {
|
||||
// TODO Auto-generated constructor stub
|
||||
}
|
||||
|
||||
@Override
|
||||
public JComponent getDialogComponent() {
|
||||
// TODO Auto-generated method stub
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setParams() {
|
||||
// TODO Auto-generated method stub
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean getParams() {
|
||||
// TODO Auto-generated method stub
|
||||
return false;
|
||||
}
|
||||
|
||||
}
|
@ -115,7 +115,7 @@ public interface DLClassiferModel {
|
||||
|
||||
/**
|
||||
* Get the parameters which can be serialized from transforms.
|
||||
* @param dlTransfroms- the dl transforms.
|
||||
* @param dlTransfroms- the DL transforms.
|
||||
*/
|
||||
public static ArrayList<DLTransfromParams> getDLTransformParams(ArrayList<DLTransform> dlTransfroms) {
|
||||
ArrayList<DLTransfromParams> dlTransformParams = new ArrayList<DLTransfromParams>();
|
||||
|
@ -415,7 +415,7 @@ public class DLClassifyProcess extends PamInstantProcess {
|
||||
else {
|
||||
/****Add annotation to existing data unit (e.g. click, clip or other RawDataHolder)****/
|
||||
//Need to go by the parent data unit for merging data not the segments. Note that we may still add multiple
|
||||
//predicitions to a single data unit depending on how many segments it contains.
|
||||
//predictions to a single data unit depending on how many segments it contains.
|
||||
|
||||
//System.out.println("New model data " + pamRawData.getParentDataUnit().getUID() + " " + groupDataBuffer[i].size() + " " + modelResultDataBuffer[i].size());
|
||||
|
||||
|
@ -3,7 +3,10 @@ package rawDeepLearningClassifier.dlClassification;
|
||||
import PamView.GroupedDataSource;
|
||||
import PamView.GroupedSourceParameters;
|
||||
import PamguardMVC.AcousticDataBlock;
|
||||
import PamguardMVC.dataSelector.DataSelectorCreator;
|
||||
import clickTrainDetector.dataselector.CTDataSelectCreator;
|
||||
import rawDeepLearningClassifier.DLControl;
|
||||
import rawDeepLearningClassifier.dataSelector.DLDataSelectCreator;
|
||||
import rawDeepLearningClassifier.tethys.DLSpeciesManager;
|
||||
import rawDeepLearningClassifier.tethys.DLTethysDataProvider;
|
||||
import tethys.TethysControl;
|
||||
@ -18,11 +21,31 @@ import tethys.species.DataBlockSpeciesManager;
|
||||
*/
|
||||
public class DLDetectionDataBlock extends AcousticDataBlock<DLDetection> implements GroupedDataSource {
|
||||
|
||||
/**
|
||||
* Reference to the deep learning classifier process.
|
||||
*/
|
||||
private DLClassifyProcess dlClassifyProcess;
|
||||
|
||||
/**
|
||||
* Reference to the Tethys data provider for the deep learning module (provides standardised metadata).
|
||||
*/
|
||||
private DLTethysDataProvider dlTethysDataProvider;
|
||||
|
||||
/**
|
||||
* A species manager for metadata.
|
||||
*/
|
||||
private DLSpeciesManager dlSpeciesManager;
|
||||
|
||||
/**
|
||||
* Reference to the deep learning control.
|
||||
*/
|
||||
private DLControl dlControl;
|
||||
|
||||
/**
|
||||
* Data selector for the deep learning detections.
|
||||
*/
|
||||
private DLDataSelectCreator dlDataSelectCreator;
|
||||
|
||||
public DLDetectionDataBlock(String dataName, DLClassifyProcess parentProcess, int channelMap) {
|
||||
super(DLDetection.class, dataName, parentProcess, channelMap);
|
||||
this.dlClassifyProcess = parentProcess;
|
||||
@ -51,4 +74,17 @@ public class DLDetectionDataBlock extends AcousticDataBlock<DLDetection> impleme
|
||||
}
|
||||
|
||||
|
||||
/* (non-Javadoc)
|
||||
* @see PamguardMVC.PamDataBlock#getDataSelectCreator()
|
||||
*/
|
||||
@Override
|
||||
public synchronized DataSelectorCreator getDataSelectCreator() {
|
||||
if (dlDataSelectCreator == null) {
|
||||
dlDataSelectCreator = new DLDataSelectCreator(dlControl, this);
|
||||
}
|
||||
return dlDataSelectCreator;
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
@ -0,0 +1,26 @@
|
||||
package rawDeepLearningClassifier.dlClassification;
|
||||
|
||||
/**
|
||||
* Decides whether a prediction has passed a threshold to be used to create a new
|
||||
* data unit.
|
||||
* <p>
|
||||
* Note that the majority of the time this will be a simple test of the value of
|
||||
* predictions of a model but there will be cases when a classifier implements a
|
||||
* more complex system. For example, a implementation could save a buffer of predictions
|
||||
* so that previous predictions inform the latest prediction. Or results may include
|
||||
* some sort of object detection components and frequency bounds etc could be used for
|
||||
* classification.
|
||||
*
|
||||
*
|
||||
*/
|
||||
public interface DLPredictionDecision {
|
||||
|
||||
/**
|
||||
* Check whether a single prediction passes a binary classifier. Prediction which pass decision will be
|
||||
* passed on to create new data units.
|
||||
* @param result - the prediciton result to test.
|
||||
* @return true if the result is passed.
|
||||
*/
|
||||
public boolean isBinaryResult(PredictionResult result);
|
||||
|
||||
}
|
@ -0,0 +1,49 @@
|
||||
package rawDeepLearningClassifier.dlClassification;
|
||||
|
||||
import rawDeepLearningClassifier.dlClassification.animalSpot.StandardModelParams;
|
||||
|
||||
/**
|
||||
* Make a decision based on a simple binary threshold for a prediction.
|
||||
*/
|
||||
public class SimpleDLDecision implements DLPredictionDecision {
|
||||
|
||||
/**
|
||||
* Reference to the parameters.
|
||||
*/
|
||||
private StandardModelParams params;
|
||||
|
||||
|
||||
@Override
|
||||
public boolean isBinaryResult(PredictionResult modelResult) {
|
||||
return isBinaryResult(modelResult, getParams()) ;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Check whether a model passes a binary test...
|
||||
* @param modelResult - the model results
|
||||
* @return the model results.
|
||||
*/
|
||||
private static boolean isBinaryResult(PredictionResult modelResult, StandardModelParams genericModelParams) {
|
||||
for (int i=0; i<modelResult.getPrediction().length; i++) {
|
||||
//System.out.println("Binary Classification: " + genericModelParams.binaryClassification.length);
|
||||
|
||||
if (modelResult.getPrediction()[i]>genericModelParams.threshold && genericModelParams.binaryClassification[i]) {
|
||||
// System.out.println("SoundSpotClassifier: prediciton: " + i + " passed threshold with val: " + modelResult.getPrediction()[i]);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
public StandardModelParams getParams() {
|
||||
return params;
|
||||
}
|
||||
|
||||
|
||||
public void setParams( StandardModelParams genericModelParams) {
|
||||
this.params = genericModelParams;
|
||||
}
|
||||
|
||||
}
|
@ -45,6 +45,12 @@ public abstract class StandardClassifierModel implements DLClassiferModel, PamSe
|
||||
*/
|
||||
private TaskThread workerThread;
|
||||
|
||||
/**
|
||||
* Makes a binary decision on whether a prediction result should go on
|
||||
* to be part of a data unit.
|
||||
*/
|
||||
private SimpleDLDecision simpleDLDecision = new SimpleDLDecision();
|
||||
|
||||
|
||||
public StandardClassifierModel(DLControl dlControl) {
|
||||
this.dlControl=dlControl;
|
||||
@ -84,7 +90,7 @@ public abstract class StandardClassifierModel implements DLClassiferModel, PamSe
|
||||
|
||||
for (int i =0; i<modelResult.size(); i++) {
|
||||
modelResult.get(i).setClassNameID(GenericDLClassifier.getClassNameIDs(getDLParams()));
|
||||
modelResult.get(i).setBinaryClassification(GenericDLClassifier.isBinaryResult(modelResult.get(i), getDLParams()));
|
||||
modelResult.get(i).setBinaryClassification(isDecision(modelResult.get(i), getDLParams()));
|
||||
modelResult.get(i).setTimeMillis(groupedRawData.get(i).getTimeMilliseconds());
|
||||
|
||||
}
|
||||
@ -106,7 +112,7 @@ public abstract class StandardClassifierModel implements DLClassiferModel, PamSe
|
||||
|
||||
@Override
|
||||
public void prepModel() {
|
||||
System.out.println("STANDARD CLASSIFIER MODEL PREP MODEL! !!!: " + getDLParams().modelPath);
|
||||
// System.out.println("STANDARD CLASSIFIER MODEL PREP MODEL! !!!: " + getDLParams().modelPath);
|
||||
// StandardModelParams oldParams = getDLParams().clone();
|
||||
|
||||
getDLWorker().prepModel(getDLParams(), dlControl);
|
||||
@ -221,7 +227,8 @@ public abstract class StandardClassifierModel implements DLClassiferModel, PamSe
|
||||
* @return true if a threshold has been met.
|
||||
*/
|
||||
public boolean isDecision(StandardPrediction modelResult, StandardModelParams modelParmas) {
|
||||
return isBinaryResult(modelResult, modelParmas);
|
||||
simpleDLDecision.setParams(modelParmas);
|
||||
return simpleDLDecision.isBinaryResult(modelResult);
|
||||
}
|
||||
|
||||
|
||||
@ -241,22 +248,22 @@ public abstract class StandardClassifierModel implements DLClassiferModel, PamSe
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Check whether a model passes a binary test...
|
||||
* @param modelResult - the model results
|
||||
* @return the model results.
|
||||
*/
|
||||
public static boolean isBinaryResult(StandardPrediction modelResult, StandardModelParams genericModelParams) {
|
||||
for (int i=0; i<modelResult.getPrediction().length; i++) {
|
||||
//System.out.println("Binary Classification: " + genericModelParams.binaryClassification.length);
|
||||
|
||||
if (modelResult.getPrediction()[i]>genericModelParams.threshold && genericModelParams.binaryClassification[i]) {
|
||||
// System.out.println("SoundSpotClassifier: prediciton: " + i + " passed threshold with val: " + modelResult.getPrediction()[i]);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
// /**
|
||||
// * Check whether a model passes a binary test...
|
||||
// * @param modelResult - the model results
|
||||
// * @return the model results.
|
||||
// */
|
||||
// public static boolean isBinaryResult(StandardPrediction modelResult, StandardModelParams genericModelParams) {
|
||||
// for (int i=0; i<modelResult.getPrediction().length; i++) {
|
||||
// //System.out.println("Binary Classification: " + genericModelParams.binaryClassification.length);
|
||||
//
|
||||
// if (modelResult.getPrediction()[i]>genericModelParams.threshold && genericModelParams.binaryClassification[i]) {
|
||||
// // System.out.println("SoundSpotClassifier: prediciton: " + i + " passed threshold with val: " + modelResult.getPrediction()[i]);
|
||||
// return true;
|
||||
// }
|
||||
// }
|
||||
// return false;
|
||||
// }
|
||||
|
||||
|
||||
@Override
|
||||
|
Loading…
Reference in New Issue
Block a user