Implementing a data selector for the deep learning module

This commit is contained in:
Jamie Mac 2024-07-10 21:40:53 +01:00
parent 38dc7e6076
commit 15d72c8fd5
17 changed files with 578 additions and 27 deletions

View File

@ -76,7 +76,6 @@ public class ExportProcessDialog {
public void createExportGroup() { public void createExportGroup() {
//clear current tasks. //clear current tasks.
dlOfflineGroup.clearTasks(); dlOfflineGroup.clearTasks();
@ -88,7 +87,6 @@ public class ExportProcessDialog {
dlOfflineGroup.addTask(new ExportTask(dataBlocks.get(i), exportManager)); dlOfflineGroup.addTask(new ExportTask(dataBlocks.get(i), exportManager));
} }
} }
} }
////---Swing stuff----/// should not be here but this is how PG works. ////---Swing stuff----/// should not be here but this is how PG works.
@ -388,7 +386,7 @@ public class ExportProcessDialog {
/** /**
* Export taks group. * Export task
*/ */
class ExportTaskGroup extends OfflineTaskGroup { class ExportTaskGroup extends OfflineTaskGroup {

View File

@ -622,8 +622,9 @@ public class DLControl extends PamControlledUnit implements PamSettings {
} }
/** /**
* Get the classifier chooser. * Get the classifier chooser. The classifier chooser chooses which classifier use
* @return the classifier chooser.Take it * based on a selected file or URL.
* @return the classifier chooser.
*/ */
public DLClassifierChooser getDlClassifierChooser() { public DLClassifierChooser getDlClassifierChooser() {
return dlClassifierChooser; return dlClassifierChooser;

View File

@ -361,7 +361,7 @@ public class DLSymbolOptionPane extends StandardSymbolModifierPane {
classNameBox2.getSelectionModel().select(Math.max(symbolOptions.classIndex2, 0)); classNameBox2.getSelectionModel().select(Math.max(symbolOptions.classIndex2, 0));
int index = symbolOptions.classIndex2>=0? symbolOptions.classIndex2 : 0; int index = symbolOptions.classIndex2>=0 ? symbolOptions.classIndex2 : 0;
if (symbolOptions.classColors==null) { if (symbolOptions.classColors==null) {
symbolOptions.setDefaultClassColors(nClass); symbolOptions.setDefaultClassColors(nClass);

View File

@ -0,0 +1,35 @@
package rawDeepLearningClassifier.dataSelector;
import PamguardMVC.PamDataUnit;
import PamguardMVC.dataSelector.DataSelectParams;
/**
* Score a data unit with a deep learning annotation. Note that this could be an
* external data unit with an annotation e.g. a click, or deep learning
* detections generated by the deep learning module.
*
* @author Jamie Macaulay
*/
public interface DLDataFilter {
/**
* Score a data unit with deep learning annotations
* @param dataUnit - the data unit to score
* @return 0 to reject and >0 to accept.
*/
public int scoreDLData(PamDataUnit dataUnit);
/**
* Get parameters for the data filter.
* @return parameters for the data selectors.
*/
public DataSelectParams getParams();
/**
* Set the parameters for the data filter.
* @param params - the parameters to set.
*/
public void setParams(DataSelectParams params);
}

View File

@ -0,0 +1,34 @@
package rawDeepLearningClassifier.dataSelector;
import PamguardMVC.PamDataBlock;
import PamguardMVC.dataSelector.DataSelectParams;
import PamguardMVC.dataSelector.DataSelector;
import PamguardMVC.dataSelector.DataSelectorCreator;
import clickDetector.dataSelector.ClickDataSelector;
import rawDeepLearningClassifier.DLControl;
/**
* Creates a data selector for the deep learning module.
*
* @author Jamie Macaulay
*/
public class DLDataSelectCreator extends DataSelectorCreator {
private DLControl dlcontrol;
public DLDataSelectCreator(DLControl dlcontrol, PamDataBlock pamDataBlock) {
super(pamDataBlock);
this.dlcontrol = dlcontrol;
}
@Override
public DataSelectParams createNewParams(String name) {
return new DLDataSelectorParams();
}
@Override
public DataSelector createDataSelector(String selectorName, boolean allowScores, String selectorType) {
return new DLDataSelector(dlcontrol, this.getPamDataBlock(), selectorName, allowScores, selectorType);
}
}

View File

@ -0,0 +1,120 @@
package rawDeepLearningClassifier.dataSelector;
import java.util.ArrayList;
import java.util.List;
import PamView.dialog.PamDialogPanel;
import PamguardMVC.PamDataBlock;
import PamguardMVC.PamDataUnit;
import PamguardMVC.dataSelector.DataSelectParams;
import PamguardMVC.dataSelector.DataSelector;
import pamViewFX.fxSettingsPanes.DynamicSettingsPane;
import rawDeepLearningClassifier.DLControl;
/**
* Data selector of DL data units. Note that data selectors are really data
* unit specific and not annotation specific.
* <p>
* The data selector can have different types of data selectors which can
* depend on the classifer used and user choice.
* <p>
* Note that this is slightly different from DLPredicitoDecision
* as it deals with data units that may have a more than one prediction.
* i.e.
*/
public class DLDataSelector extends DataSelector {
/**
* Swing panel for the data selector.
*/
private DLSelectPanel dlSelectPanel;
/**
* FX panel for the data selector.
*/
private DLSelectPaneFX dlSelectPaneFX;
/**
* Data filter for filtering.
*/
private List<DLDataFilter> dataFilters = new ArrayList<DLDataFilter> ();
private DLDataSelectorParams dlDataSelectParams;
/**
* Constructor for the data selector.
* @param dlcontrol - reference to the DLControl.
* @param pamDataBlock - the data block.
* @param selectorName - the selector name.
* @param allowScores - allow all the scores.
* @param selectorType - the selector type.
*/
public DLDataSelector(DLControl dlcontrol, PamDataBlock pamDataBlock, String selectorName, boolean allowScores, String selectorType) {
super(pamDataBlock, selectorName, allowScores);
/****New data filters go here****/
dataFilters.add(new DLPredictionFilter(dlcontrol));
}
@Override
public PamDialogPanel getDialogPanel() {
if (dlSelectPanel == null) {
dlSelectPanel = new DLSelectPanel(this);
}
return dlSelectPanel;
}
@Override
public DynamicSettingsPane<Boolean> getDialogPaneFX() {
if (dlSelectPaneFX == null) {
dlSelectPaneFX = new DLSelectPaneFX(this);
}
return dlSelectPaneFX;
}
@Override
public void setParams(DataSelectParams dataSelectParams) {
if (dataSelectParams instanceof DLDataSelectorParams) {
dlDataSelectParams = (DLDataSelectorParams) dataSelectParams;
checkDataFilterParamsSize(dlDataSelectParams);
//set paramters for all data filters.
for (int i=0; i<dlDataSelectParams.dataSelectorParams.length; i++) {
dataFilters.get(i).setParams((DLDataSelectorParams) dataSelectParams);
}
}
}
/**
* Check that the data filters are the corret size.
* @param dataSelectParams - the apramters to set.
*/
private void checkDataFilterParamsSize(DLDataSelectorParams dataSelectParams) {
if (dataSelectParams.dataSelectorParams==null || dataSelectParams.dataSelectorParams.length!=dataFilters.size()) {
dataSelectParams.dataSelectorParams = new DataSelectParams[dataFilters.size()];
for (int i=0; i<dataSelectParams.dataSelectorParams.length; i++) {
dataSelectParams.dataSelectorParams[i] = dataFilters.get(i).getParams();
}
}
}
@Override
public DataSelectParams getParams() {
for (int i=0; i<dlDataSelectParams.dataSelectorParams.length; i++) {
dlDataSelectParams.dataSelectorParams[i] = dataFilters.get(i).getParams();
}
return dlDataSelectParams;
}
@Override
public double scoreData(PamDataUnit pamDataUnit) {
int score = dataFilters.get(dlDataSelectParams.dataSelectorIndex).scoreDLData(pamDataUnit);
return score>=0 ? 1 : 0;
}
}

View File

@ -0,0 +1,22 @@
package rawDeepLearningClassifier.dataSelector;
import java.io.Serializable;
import PamguardMVC.dataSelector.DataSelectParams;
/**
* Paramters for the DL data seelctor
*/
public class DLDataSelectorParams extends DataSelectParams implements Serializable, Cloneable{
private static final long serialVersionUID = 1L;
public int dataSelectorIndex = 0;
/**
* List of data selector parameters for different deep learning selectors.
*/
public DataSelectParams[] dataSelectorParams;
}

View File

@ -0,0 +1,107 @@
package rawDeepLearningClassifier.dataSelector;
import PamguardMVC.PamDataUnit;
import PamguardMVC.dataSelector.DataSelectParams;
import rawDeepLearningClassifier.DLControl;
import rawDeepLearningClassifier.dlClassification.DLDetection;
import rawDeepLearningClassifier.dlClassification.PredictionResult;
/**
* A data filter which filters data by the maximum prediction value
* for different classes.
*
* @author Jamie Macaulay
*/
public class DLPredictionFilter implements DLDataFilter {
/**
* Reference to the DLControl
*/
private DLControl dlcontrol;
/**
* The filter parameters
*/
private DLPredictionFilterParams filterParams = new DLPredictionFilterParams();
public DLPredictionFilter(DLControl dlcontrol) {
this.dlcontrol = dlcontrol;
checkParamsClass() ;
}
@Override
public int scoreDLData(PamDataUnit dataUnit) {
DLDetection dlDetection = (DLDetection) dataUnit;
//iterate through all results and check that at least one class passes data selection.
float[] results;
int maxClassIndex = -1;
int maxPred = -1;
//get the maximum prediction index which passes the minimum threshold
for (PredictionResult modelResult: dlDetection.getModelResults()) {
results = modelResult.getPrediction();
for (int j=0; j<results.length; j++) {
if (filterParams.classSelect[j] && results[j]>filterParams.minClassPredicton[j]) {
if (results[j]> maxPred) maxClassIndex = j;
}
}
}
return maxClassIndex;
}
@Override
public DataSelectParams getParams() {
checkParamsClass() ;
return filterParams;
}
@Override
public void setParams(DataSelectParams params) {
this.filterParams = ((DLPredictionFilterParams) params).clone();
checkParamsClass();
}
private void checkParamsClass() {
if (filterParams.classSelect==null || dlcontrol.getNumClasses()!=filterParams.classSelect.length) {
filterParams.classSelect = new boolean[dlcontrol.getNumClasses()];
filterParams.minClassPredicton = new double[dlcontrol.getNumClasses()];
//set default so data selector does not always prevent all classes from showing.
for (int i=0; i<filterParams.classSelect.length; i++) {
filterParams.classSelect[i]=true;
filterParams.minClassPredicton[i]=0.4;
}
}
}
// /**
// * Get the index of the highest prediciton value a list of results.
// * @param predictions - index of the highest prediction within a matrix of predicitons.
// * @return an index of the hghest predictions.
// */
// public static int[] getBestClass(List<PredictionResult> predictions) {
//
// float[][] results = new float[predictions.size()][];
//
// //A detection might have multiple prediction results, i.e. predictions are a matrix. Need
// //to iterate through all the predictions and then work out whihc is the maximum. That index is then then]
// //class colour.
// int i=0;
// for (PredictionResult modelResult: predictions) {
// results[i] = modelResult.getPrediction();
// i++;
// }
//
// int[] indexBest = PamArrayUtils.maxPos(results);
//
// return indexBest;
// }
}

View File

@ -0,0 +1,35 @@
package rawDeepLearningClassifier.dataSelector;
import PamguardMVC.dataSelector.DataSelectParams;
/**
* Parameters for filtering by the minimum prediciton value.
*/
public class DLPredictionFilterParams extends DataSelectParams {
private static final long serialVersionUID = 1L;
/**
* An array indicating which classes are to be used in data selection
*/
public boolean[] classSelect;
/**
* The minimum class prediction.
*/
public double[] minClassPredicton;
/**
* Clone the parameters.
*/
public DLPredictionFilterParams clone() {
try {
return (DLPredictionFilterParams) super.clone();
} catch (CloneNotSupportedException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
return null;
}
}

View File

@ -0,0 +1,47 @@
package rawDeepLearningClassifier.dataSelector;
import javafx.scene.Node;
import javafx.scene.control.Label;
import pamViewFX.fxSettingsPanes.DynamicSettingsPane;
/**
* JavaFX pane for the deep learning data selector.
*
* @author Jamie Macaulay
*/
public class DLSelectPaneFX extends DynamicSettingsPane<Boolean>{
public DLSelectPaneFX(Object ownerWindow) {
super(ownerWindow);
// TODO Auto-generated constructor stub
}
@Override
public Boolean getParams(Boolean currParams) {
// TODO Auto-generated method stub
return currParams;
}
@Override
public void setParams(Boolean input) {
// TODO Auto-generated method stub
}
@Override
public String getName() {
return "Deep Learning Data Selector:";
}
@Override
public Node getContentNode() {
return new Label("Hello data selector");
}
@Override
public void paneInitialized() {
// TODO Auto-generated method stub
}
}

View File

@ -0,0 +1,34 @@
package rawDeepLearningClassifier.dataSelector;
import javax.swing.JComponent;
import PamView.dialog.PamDialogPanel;
/**
* Swing panel for the deep learning data selector.
*/
public class DLSelectPanel implements PamDialogPanel {
public DLSelectPanel(DLDataSelector dlDataSelector) {
// TODO Auto-generated constructor stub
}
@Override
public JComponent getDialogComponent() {
// TODO Auto-generated method stub
return null;
}
@Override
public void setParams() {
// TODO Auto-generated method stub
}
@Override
public boolean getParams() {
// TODO Auto-generated method stub
return false;
}
}

View File

@ -115,7 +115,7 @@ public interface DLClassiferModel {
/** /**
* Get the parameters which can be serialized from transforms. * Get the parameters which can be serialized from transforms.
* @param dlTransfroms- the dl transforms. * @param dlTransfroms- the DL transforms.
*/ */
public static ArrayList<DLTransfromParams> getDLTransformParams(ArrayList<DLTransform> dlTransfroms) { public static ArrayList<DLTransfromParams> getDLTransformParams(ArrayList<DLTransform> dlTransfroms) {
ArrayList<DLTransfromParams> dlTransformParams = new ArrayList<DLTransfromParams>(); ArrayList<DLTransfromParams> dlTransformParams = new ArrayList<DLTransfromParams>();

View File

@ -415,7 +415,7 @@ public class DLClassifyProcess extends PamInstantProcess {
else { else {
/****Add annotation to existing data unit (e.g. click, clip or other RawDataHolder)****/ /****Add annotation to existing data unit (e.g. click, clip or other RawDataHolder)****/
//Need to go by the parent data unit for merging data not the segments. Note that we may still add multiple //Need to go by the parent data unit for merging data not the segments. Note that we may still add multiple
//predicitions to a single data unit depending on how many segments it contains. //predictions to a single data unit depending on how many segments it contains.
//System.out.println("New model data " + pamRawData.getParentDataUnit().getUID() + " " + groupDataBuffer[i].size() + " " + modelResultDataBuffer[i].size()); //System.out.println("New model data " + pamRawData.getParentDataUnit().getUID() + " " + groupDataBuffer[i].size() + " " + modelResultDataBuffer[i].size());

View File

@ -3,7 +3,10 @@ package rawDeepLearningClassifier.dlClassification;
import PamView.GroupedDataSource; import PamView.GroupedDataSource;
import PamView.GroupedSourceParameters; import PamView.GroupedSourceParameters;
import PamguardMVC.AcousticDataBlock; import PamguardMVC.AcousticDataBlock;
import PamguardMVC.dataSelector.DataSelectorCreator;
import clickTrainDetector.dataselector.CTDataSelectCreator;
import rawDeepLearningClassifier.DLControl; import rawDeepLearningClassifier.DLControl;
import rawDeepLearningClassifier.dataSelector.DLDataSelectCreator;
import rawDeepLearningClassifier.tethys.DLSpeciesManager; import rawDeepLearningClassifier.tethys.DLSpeciesManager;
import rawDeepLearningClassifier.tethys.DLTethysDataProvider; import rawDeepLearningClassifier.tethys.DLTethysDataProvider;
import tethys.TethysControl; import tethys.TethysControl;
@ -18,10 +21,30 @@ import tethys.species.DataBlockSpeciesManager;
*/ */
public class DLDetectionDataBlock extends AcousticDataBlock<DLDetection> implements GroupedDataSource { public class DLDetectionDataBlock extends AcousticDataBlock<DLDetection> implements GroupedDataSource {
/**
* Reference to the deep learning classifier process.
*/
private DLClassifyProcess dlClassifyProcess; private DLClassifyProcess dlClassifyProcess;
/**
* Reference to the Tethys data provider for the deep learning module (provides standardised metadata).
*/
private DLTethysDataProvider dlTethysDataProvider; private DLTethysDataProvider dlTethysDataProvider;
/**
* A species manager for metadata.
*/
private DLSpeciesManager dlSpeciesManager; private DLSpeciesManager dlSpeciesManager;
/**
* Reference to the deep learning control.
*/
private DLControl dlControl; private DLControl dlControl;
/**
* Data selector for the deep learning detections.
*/
private DLDataSelectCreator dlDataSelectCreator;
public DLDetectionDataBlock(String dataName, DLClassifyProcess parentProcess, int channelMap) { public DLDetectionDataBlock(String dataName, DLClassifyProcess parentProcess, int channelMap) {
super(DLDetection.class, dataName, parentProcess, channelMap); super(DLDetection.class, dataName, parentProcess, channelMap);
@ -49,6 +72,19 @@ public class DLDetectionDataBlock extends AcousticDataBlock<DLDetection> impleme
} }
return dlSpeciesManager; return dlSpeciesManager;
} }
/* (non-Javadoc)
* @see PamguardMVC.PamDataBlock#getDataSelectCreator()
*/
@Override
public synchronized DataSelectorCreator getDataSelectCreator() {
if (dlDataSelectCreator == null) {
dlDataSelectCreator = new DLDataSelectCreator(dlControl, this);
}
return dlDataSelectCreator;
}
} }

View File

@ -0,0 +1,26 @@
package rawDeepLearningClassifier.dlClassification;
/**
* Decides whether a prediction has passed a threshold to be used to create a new
* data unit.
* <p>
* Note that the majority of the time this will be a simple test of the value of
* predictions of a model but there will be cases when a classifier implements a
* more complex system. For example, a implementation could save a buffer of predictions
* so that previous predictions inform the latest prediction. Or results may include
* some sort of object detection components and frequency bounds etc could be used for
* classification.
*
*
*/
public interface DLPredictionDecision {
/**
* Check whether a single prediction passes a binary classifier. Prediction which pass decision will be
* passed on to create new data units.
* @param result - the prediciton result to test.
* @return true if the result is passed.
*/
public boolean isBinaryResult(PredictionResult result);
}

View File

@ -0,0 +1,49 @@
package rawDeepLearningClassifier.dlClassification;
import rawDeepLearningClassifier.dlClassification.animalSpot.StandardModelParams;
/**
* Make a decision based on a simple binary threshold for a prediction.
*/
public class SimpleDLDecision implements DLPredictionDecision {
/**
* Reference to the parameters.
*/
private StandardModelParams params;
@Override
public boolean isBinaryResult(PredictionResult modelResult) {
return isBinaryResult(modelResult, getParams()) ;
}
/**
* Check whether a model passes a binary test...
* @param modelResult - the model results
* @return the model results.
*/
private static boolean isBinaryResult(PredictionResult modelResult, StandardModelParams genericModelParams) {
for (int i=0; i<modelResult.getPrediction().length; i++) {
//System.out.println("Binary Classification: " + genericModelParams.binaryClassification.length);
if (modelResult.getPrediction()[i]>genericModelParams.threshold && genericModelParams.binaryClassification[i]) {
// System.out.println("SoundSpotClassifier: prediciton: " + i + " passed threshold with val: " + modelResult.getPrediction()[i]);
return true;
}
}
return false;
}
public StandardModelParams getParams() {
return params;
}
public void setParams( StandardModelParams genericModelParams) {
this.params = genericModelParams;
}
}

View File

@ -44,6 +44,12 @@ public abstract class StandardClassifierModel implements DLClassiferModel, PamSe
* in real time without slowing down the rest of PAMGaurd. * in real time without slowing down the rest of PAMGaurd.
*/ */
private TaskThread workerThread; private TaskThread workerThread;
/**
* Makes a binary decision on whether a prediction result should go on
* to be part of a data unit.
*/
private SimpleDLDecision simpleDLDecision = new SimpleDLDecision();
public StandardClassifierModel(DLControl dlControl) { public StandardClassifierModel(DLControl dlControl) {
@ -84,7 +90,7 @@ public abstract class StandardClassifierModel implements DLClassiferModel, PamSe
for (int i =0; i<modelResult.size(); i++) { for (int i =0; i<modelResult.size(); i++) {
modelResult.get(i).setClassNameID(GenericDLClassifier.getClassNameIDs(getDLParams())); modelResult.get(i).setClassNameID(GenericDLClassifier.getClassNameIDs(getDLParams()));
modelResult.get(i).setBinaryClassification(GenericDLClassifier.isBinaryResult(modelResult.get(i), getDLParams())); modelResult.get(i).setBinaryClassification(isDecision(modelResult.get(i), getDLParams()));
modelResult.get(i).setTimeMillis(groupedRawData.get(i).getTimeMilliseconds()); modelResult.get(i).setTimeMillis(groupedRawData.get(i).getTimeMilliseconds());
} }
@ -106,7 +112,7 @@ public abstract class StandardClassifierModel implements DLClassiferModel, PamSe
@Override @Override
public void prepModel() { public void prepModel() {
System.out.println("STANDARD CLASSIFIER MODEL PREP MODEL! !!!: " + getDLParams().modelPath); // System.out.println("STANDARD CLASSIFIER MODEL PREP MODEL! !!!: " + getDLParams().modelPath);
// StandardModelParams oldParams = getDLParams().clone(); // StandardModelParams oldParams = getDLParams().clone();
getDLWorker().prepModel(getDLParams(), dlControl); getDLWorker().prepModel(getDLParams(), dlControl);
@ -221,7 +227,8 @@ public abstract class StandardClassifierModel implements DLClassiferModel, PamSe
* @return true if a threshold has been met. * @return true if a threshold has been met.
*/ */
public boolean isDecision(StandardPrediction modelResult, StandardModelParams modelParmas) { public boolean isDecision(StandardPrediction modelResult, StandardModelParams modelParmas) {
return isBinaryResult(modelResult, modelParmas); simpleDLDecision.setParams(modelParmas);
return simpleDLDecision.isBinaryResult(modelResult);
} }
@ -241,22 +248,22 @@ public abstract class StandardClassifierModel implements DLClassiferModel, PamSe
/** // /**
* Check whether a model passes a binary test... // * Check whether a model passes a binary test...
* @param modelResult - the model results // * @param modelResult - the model results
* @return the model results. // * @return the model results.
*/ // */
public static boolean isBinaryResult(StandardPrediction modelResult, StandardModelParams genericModelParams) { // public static boolean isBinaryResult(StandardPrediction modelResult, StandardModelParams genericModelParams) {
for (int i=0; i<modelResult.getPrediction().length; i++) { // for (int i=0; i<modelResult.getPrediction().length; i++) {
//System.out.println("Binary Classification: " + genericModelParams.binaryClassification.length); // //System.out.println("Binary Classification: " + genericModelParams.binaryClassification.length);
//
if (modelResult.getPrediction()[i]>genericModelParams.threshold && genericModelParams.binaryClassification[i]) { // if (modelResult.getPrediction()[i]>genericModelParams.threshold && genericModelParams.binaryClassification[i]) {
// System.out.println("SoundSpotClassifier: prediciton: " + i + " passed threshold with val: " + modelResult.getPrediction()[i]); // // System.out.println("SoundSpotClassifier: prediciton: " + i + " passed threshold with val: " + modelResult.getPrediction()[i]);
return true; // return true;
} // }
} // }
return false; // return false;
} // }
@Override @Override