mirror of
https://github.com/PAMGuard/PAMGuard.git
synced 2024-11-22 07:02:29 +00:00
Implementing a data selector for the deep learning module
This commit is contained in:
parent
38dc7e6076
commit
15d72c8fd5
@ -76,7 +76,6 @@ public class ExportProcessDialog {
|
|||||||
|
|
||||||
|
|
||||||
public void createExportGroup() {
|
public void createExportGroup() {
|
||||||
|
|
||||||
//clear current tasks.
|
//clear current tasks.
|
||||||
dlOfflineGroup.clearTasks();
|
dlOfflineGroup.clearTasks();
|
||||||
|
|
||||||
@ -88,7 +87,6 @@ public class ExportProcessDialog {
|
|||||||
dlOfflineGroup.addTask(new ExportTask(dataBlocks.get(i), exportManager));
|
dlOfflineGroup.addTask(new ExportTask(dataBlocks.get(i), exportManager));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
////---Swing stuff----/// should not be here but this is how PG works.
|
////---Swing stuff----/// should not be here but this is how PG works.
|
||||||
|
|
||||||
@ -388,7 +386,7 @@ public class ExportProcessDialog {
|
|||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Export taks group.
|
* Export task
|
||||||
*/
|
*/
|
||||||
class ExportTaskGroup extends OfflineTaskGroup {
|
class ExportTaskGroup extends OfflineTaskGroup {
|
||||||
|
|
||||||
|
@ -622,8 +622,9 @@ public class DLControl extends PamControlledUnit implements PamSettings {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the classifier chooser.
|
* Get the classifier chooser. The classifier chooser chooses which classifier use
|
||||||
* @return the classifier chooser.Take it
|
* based on a selected file or URL.
|
||||||
|
* @return the classifier chooser.
|
||||||
*/
|
*/
|
||||||
public DLClassifierChooser getDlClassifierChooser() {
|
public DLClassifierChooser getDlClassifierChooser() {
|
||||||
return dlClassifierChooser;
|
return dlClassifierChooser;
|
||||||
|
@ -361,7 +361,7 @@ public class DLSymbolOptionPane extends StandardSymbolModifierPane {
|
|||||||
classNameBox2.getSelectionModel().select(Math.max(symbolOptions.classIndex2, 0));
|
classNameBox2.getSelectionModel().select(Math.max(symbolOptions.classIndex2, 0));
|
||||||
|
|
||||||
|
|
||||||
int index = symbolOptions.classIndex2>=0? symbolOptions.classIndex2 : 0;
|
int index = symbolOptions.classIndex2>=0 ? symbolOptions.classIndex2 : 0;
|
||||||
|
|
||||||
if (symbolOptions.classColors==null) {
|
if (symbolOptions.classColors==null) {
|
||||||
symbolOptions.setDefaultClassColors(nClass);
|
symbolOptions.setDefaultClassColors(nClass);
|
||||||
|
35
src/rawDeepLearningClassifier/dataSelector/DLDataFilter.java
Normal file
35
src/rawDeepLearningClassifier/dataSelector/DLDataFilter.java
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
package rawDeepLearningClassifier.dataSelector;
|
||||||
|
|
||||||
|
import PamguardMVC.PamDataUnit;
|
||||||
|
import PamguardMVC.dataSelector.DataSelectParams;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Score a data unit with a deep learning annotation. Note that this could be an
|
||||||
|
* external data unit with an annotation e.g. a click, or deep learning
|
||||||
|
* detections generated by the deep learning module.
|
||||||
|
*
|
||||||
|
* @author Jamie Macaulay
|
||||||
|
*/
|
||||||
|
public interface DLDataFilter {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Score a data unit with deep learning annotations
|
||||||
|
* @param dataUnit - the data unit to score
|
||||||
|
* @return 0 to reject and >0 to accept.
|
||||||
|
*/
|
||||||
|
public int scoreDLData(PamDataUnit dataUnit);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get parameters for the data filter.
|
||||||
|
* @return parameters for the data selectors.
|
||||||
|
*/
|
||||||
|
public DataSelectParams getParams();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set the parameters for the data filter.
|
||||||
|
* @param params - the parameters to set.
|
||||||
|
*/
|
||||||
|
public void setParams(DataSelectParams params);
|
||||||
|
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,34 @@
|
|||||||
|
package rawDeepLearningClassifier.dataSelector;
|
||||||
|
|
||||||
|
import PamguardMVC.PamDataBlock;
|
||||||
|
import PamguardMVC.dataSelector.DataSelectParams;
|
||||||
|
import PamguardMVC.dataSelector.DataSelector;
|
||||||
|
import PamguardMVC.dataSelector.DataSelectorCreator;
|
||||||
|
import clickDetector.dataSelector.ClickDataSelector;
|
||||||
|
import rawDeepLearningClassifier.DLControl;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a data selector for the deep learning module.
|
||||||
|
*
|
||||||
|
* @author Jamie Macaulay
|
||||||
|
*/
|
||||||
|
public class DLDataSelectCreator extends DataSelectorCreator {
|
||||||
|
|
||||||
|
private DLControl dlcontrol;
|
||||||
|
|
||||||
|
public DLDataSelectCreator(DLControl dlcontrol, PamDataBlock pamDataBlock) {
|
||||||
|
super(pamDataBlock);
|
||||||
|
this.dlcontrol = dlcontrol;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DataSelectParams createNewParams(String name) {
|
||||||
|
return new DLDataSelectorParams();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DataSelector createDataSelector(String selectorName, boolean allowScores, String selectorType) {
|
||||||
|
return new DLDataSelector(dlcontrol, this.getPamDataBlock(), selectorName, allowScores, selectorType);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
120
src/rawDeepLearningClassifier/dataSelector/DLDataSelector.java
Normal file
120
src/rawDeepLearningClassifier/dataSelector/DLDataSelector.java
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
package rawDeepLearningClassifier.dataSelector;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import PamView.dialog.PamDialogPanel;
|
||||||
|
import PamguardMVC.PamDataBlock;
|
||||||
|
import PamguardMVC.PamDataUnit;
|
||||||
|
import PamguardMVC.dataSelector.DataSelectParams;
|
||||||
|
import PamguardMVC.dataSelector.DataSelector;
|
||||||
|
import pamViewFX.fxSettingsPanes.DynamicSettingsPane;
|
||||||
|
import rawDeepLearningClassifier.DLControl;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Data selector of DL data units. Note that data selectors are really data
|
||||||
|
* unit specific and not annotation specific.
|
||||||
|
* <p>
|
||||||
|
* The data selector can have different types of data selectors which can
|
||||||
|
* depend on the classifer used and user choice.
|
||||||
|
* <p>
|
||||||
|
* Note that this is slightly different from DLPredicitoDecision
|
||||||
|
* as it deals with data units that may have a more than one prediction.
|
||||||
|
* i.e.
|
||||||
|
*/
|
||||||
|
public class DLDataSelector extends DataSelector {
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Swing panel for the data selector.
|
||||||
|
*/
|
||||||
|
private DLSelectPanel dlSelectPanel;
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* FX panel for the data selector.
|
||||||
|
*/
|
||||||
|
private DLSelectPaneFX dlSelectPaneFX;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Data filter for filtering.
|
||||||
|
*/
|
||||||
|
private List<DLDataFilter> dataFilters = new ArrayList<DLDataFilter> ();
|
||||||
|
|
||||||
|
|
||||||
|
private DLDataSelectorParams dlDataSelectParams;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructor for the data selector.
|
||||||
|
* @param dlcontrol - reference to the DLControl.
|
||||||
|
* @param pamDataBlock - the data block.
|
||||||
|
* @param selectorName - the selector name.
|
||||||
|
* @param allowScores - allow all the scores.
|
||||||
|
* @param selectorType - the selector type.
|
||||||
|
*/
|
||||||
|
public DLDataSelector(DLControl dlcontrol, PamDataBlock pamDataBlock, String selectorName, boolean allowScores, String selectorType) {
|
||||||
|
super(pamDataBlock, selectorName, allowScores);
|
||||||
|
/****New data filters go here****/
|
||||||
|
dataFilters.add(new DLPredictionFilter(dlcontrol));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public PamDialogPanel getDialogPanel() {
|
||||||
|
if (dlSelectPanel == null) {
|
||||||
|
dlSelectPanel = new DLSelectPanel(this);
|
||||||
|
}
|
||||||
|
return dlSelectPanel;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DynamicSettingsPane<Boolean> getDialogPaneFX() {
|
||||||
|
if (dlSelectPaneFX == null) {
|
||||||
|
dlSelectPaneFX = new DLSelectPaneFX(this);
|
||||||
|
}
|
||||||
|
return dlSelectPaneFX;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setParams(DataSelectParams dataSelectParams) {
|
||||||
|
|
||||||
|
if (dataSelectParams instanceof DLDataSelectorParams) {
|
||||||
|
dlDataSelectParams = (DLDataSelectorParams) dataSelectParams;
|
||||||
|
|
||||||
|
checkDataFilterParamsSize(dlDataSelectParams);
|
||||||
|
|
||||||
|
//set paramters for all data filters.
|
||||||
|
for (int i=0; i<dlDataSelectParams.dataSelectorParams.length; i++) {
|
||||||
|
dataFilters.get(i).setParams((DLDataSelectorParams) dataSelectParams);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check that the data filters are the corret size.
|
||||||
|
* @param dataSelectParams - the apramters to set.
|
||||||
|
*/
|
||||||
|
private void checkDataFilterParamsSize(DLDataSelectorParams dataSelectParams) {
|
||||||
|
if (dataSelectParams.dataSelectorParams==null || dataSelectParams.dataSelectorParams.length!=dataFilters.size()) {
|
||||||
|
dataSelectParams.dataSelectorParams = new DataSelectParams[dataFilters.size()];
|
||||||
|
for (int i=0; i<dataSelectParams.dataSelectorParams.length; i++) {
|
||||||
|
dataSelectParams.dataSelectorParams[i] = dataFilters.get(i).getParams();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DataSelectParams getParams() {
|
||||||
|
for (int i=0; i<dlDataSelectParams.dataSelectorParams.length; i++) {
|
||||||
|
dlDataSelectParams.dataSelectorParams[i] = dataFilters.get(i).getParams();
|
||||||
|
}
|
||||||
|
return dlDataSelectParams;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double scoreData(PamDataUnit pamDataUnit) {
|
||||||
|
int score = dataFilters.get(dlDataSelectParams.dataSelectorIndex).scoreDLData(pamDataUnit);
|
||||||
|
return score>=0 ? 1 : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,22 @@
|
|||||||
|
package rawDeepLearningClassifier.dataSelector;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
|
||||||
|
import PamguardMVC.dataSelector.DataSelectParams;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Paramters for the DL data seelctor
|
||||||
|
*/
|
||||||
|
public class DLDataSelectorParams extends DataSelectParams implements Serializable, Cloneable{
|
||||||
|
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
public int dataSelectorIndex = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* List of data selector parameters for different deep learning selectors.
|
||||||
|
*/
|
||||||
|
public DataSelectParams[] dataSelectorParams;
|
||||||
|
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,107 @@
|
|||||||
|
package rawDeepLearningClassifier.dataSelector;
|
||||||
|
|
||||||
|
|
||||||
|
import PamguardMVC.PamDataUnit;
|
||||||
|
import PamguardMVC.dataSelector.DataSelectParams;
|
||||||
|
import rawDeepLearningClassifier.DLControl;
|
||||||
|
import rawDeepLearningClassifier.dlClassification.DLDetection;
|
||||||
|
import rawDeepLearningClassifier.dlClassification.PredictionResult;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A data filter which filters data by the maximum prediction value
|
||||||
|
* for different classes.
|
||||||
|
*
|
||||||
|
* @author Jamie Macaulay
|
||||||
|
*/
|
||||||
|
public class DLPredictionFilter implements DLDataFilter {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reference to the DLControl
|
||||||
|
*/
|
||||||
|
private DLControl dlcontrol;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The filter parameters
|
||||||
|
*/
|
||||||
|
private DLPredictionFilterParams filterParams = new DLPredictionFilterParams();
|
||||||
|
|
||||||
|
public DLPredictionFilter(DLControl dlcontrol) {
|
||||||
|
this.dlcontrol = dlcontrol;
|
||||||
|
checkParamsClass() ;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int scoreDLData(PamDataUnit dataUnit) {
|
||||||
|
|
||||||
|
DLDetection dlDetection = (DLDetection) dataUnit;
|
||||||
|
|
||||||
|
//iterate through all results and check that at least one class passes data selection.
|
||||||
|
float[] results;
|
||||||
|
int maxClassIndex = -1;
|
||||||
|
int maxPred = -1;
|
||||||
|
|
||||||
|
//get the maximum prediction index which passes the minimum threshold
|
||||||
|
for (PredictionResult modelResult: dlDetection.getModelResults()) {
|
||||||
|
results = modelResult.getPrediction();
|
||||||
|
for (int j=0; j<results.length; j++) {
|
||||||
|
if (filterParams.classSelect[j] && results[j]>filterParams.minClassPredicton[j]) {
|
||||||
|
if (results[j]> maxPred) maxClassIndex = j;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return maxClassIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DataSelectParams getParams() {
|
||||||
|
checkParamsClass() ;
|
||||||
|
return filterParams;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setParams(DataSelectParams params) {
|
||||||
|
this.filterParams = ((DLPredictionFilterParams) params).clone();
|
||||||
|
checkParamsClass();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private void checkParamsClass() {
|
||||||
|
if (filterParams.classSelect==null || dlcontrol.getNumClasses()!=filterParams.classSelect.length) {
|
||||||
|
filterParams.classSelect = new boolean[dlcontrol.getNumClasses()];
|
||||||
|
filterParams.minClassPredicton = new double[dlcontrol.getNumClasses()];
|
||||||
|
|
||||||
|
//set default so data selector does not always prevent all classes from showing.
|
||||||
|
for (int i=0; i<filterParams.classSelect.length; i++) {
|
||||||
|
filterParams.classSelect[i]=true;
|
||||||
|
filterParams.minClassPredicton[i]=0.4;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// /**
|
||||||
|
// * Get the index of the highest prediciton value a list of results.
|
||||||
|
// * @param predictions - index of the highest prediction within a matrix of predicitons.
|
||||||
|
// * @return an index of the hghest predictions.
|
||||||
|
// */
|
||||||
|
// public static int[] getBestClass(List<PredictionResult> predictions) {
|
||||||
|
//
|
||||||
|
// float[][] results = new float[predictions.size()][];
|
||||||
|
//
|
||||||
|
// //A detection might have multiple prediction results, i.e. predictions are a matrix. Need
|
||||||
|
// //to iterate through all the predictions and then work out whihc is the maximum. That index is then then]
|
||||||
|
// //class colour.
|
||||||
|
// int i=0;
|
||||||
|
// for (PredictionResult modelResult: predictions) {
|
||||||
|
// results[i] = modelResult.getPrediction();
|
||||||
|
// i++;
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// int[] indexBest = PamArrayUtils.maxPos(results);
|
||||||
|
//
|
||||||
|
// return indexBest;
|
||||||
|
// }
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,35 @@
|
|||||||
|
package rawDeepLearningClassifier.dataSelector;
|
||||||
|
|
||||||
|
import PamguardMVC.dataSelector.DataSelectParams;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parameters for filtering by the minimum prediciton value.
|
||||||
|
*/
|
||||||
|
public class DLPredictionFilterParams extends DataSelectParams {
|
||||||
|
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An array indicating which classes are to be used in data selection
|
||||||
|
*/
|
||||||
|
public boolean[] classSelect;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The minimum class prediction.
|
||||||
|
*/
|
||||||
|
public double[] minClassPredicton;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clone the parameters.
|
||||||
|
*/
|
||||||
|
public DLPredictionFilterParams clone() {
|
||||||
|
try {
|
||||||
|
return (DLPredictionFilterParams) super.clone();
|
||||||
|
} catch (CloneNotSupportedException e) {
|
||||||
|
// TODO Auto-generated catch block
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,47 @@
|
|||||||
|
package rawDeepLearningClassifier.dataSelector;
|
||||||
|
|
||||||
|
import javafx.scene.Node;
|
||||||
|
import javafx.scene.control.Label;
|
||||||
|
import pamViewFX.fxSettingsPanes.DynamicSettingsPane;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* JavaFX pane for the deep learning data selector.
|
||||||
|
*
|
||||||
|
* @author Jamie Macaulay
|
||||||
|
*/
|
||||||
|
public class DLSelectPaneFX extends DynamicSettingsPane<Boolean>{
|
||||||
|
|
||||||
|
public DLSelectPaneFX(Object ownerWindow) {
|
||||||
|
super(ownerWindow);
|
||||||
|
// TODO Auto-generated constructor stub
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Boolean getParams(Boolean currParams) {
|
||||||
|
// TODO Auto-generated method stub
|
||||||
|
return currParams;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setParams(Boolean input) {
|
||||||
|
// TODO Auto-generated method stub
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getName() {
|
||||||
|
return "Deep Learning Data Selector:";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Node getContentNode() {
|
||||||
|
return new Label("Hello data selector");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void paneInitialized() {
|
||||||
|
// TODO Auto-generated method stub
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,34 @@
|
|||||||
|
package rawDeepLearningClassifier.dataSelector;
|
||||||
|
|
||||||
|
import javax.swing.JComponent;
|
||||||
|
|
||||||
|
import PamView.dialog.PamDialogPanel;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Swing panel for the deep learning data selector.
|
||||||
|
*/
|
||||||
|
public class DLSelectPanel implements PamDialogPanel {
|
||||||
|
|
||||||
|
public DLSelectPanel(DLDataSelector dlDataSelector) {
|
||||||
|
// TODO Auto-generated constructor stub
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public JComponent getDialogComponent() {
|
||||||
|
// TODO Auto-generated method stub
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setParams() {
|
||||||
|
// TODO Auto-generated method stub
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean getParams() {
|
||||||
|
// TODO Auto-generated method stub
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -115,7 +115,7 @@ public interface DLClassiferModel {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the parameters which can be serialized from transforms.
|
* Get the parameters which can be serialized from transforms.
|
||||||
* @param dlTransfroms- the dl transforms.
|
* @param dlTransfroms- the DL transforms.
|
||||||
*/
|
*/
|
||||||
public static ArrayList<DLTransfromParams> getDLTransformParams(ArrayList<DLTransform> dlTransfroms) {
|
public static ArrayList<DLTransfromParams> getDLTransformParams(ArrayList<DLTransform> dlTransfroms) {
|
||||||
ArrayList<DLTransfromParams> dlTransformParams = new ArrayList<DLTransfromParams>();
|
ArrayList<DLTransfromParams> dlTransformParams = new ArrayList<DLTransfromParams>();
|
||||||
|
@ -415,7 +415,7 @@ public class DLClassifyProcess extends PamInstantProcess {
|
|||||||
else {
|
else {
|
||||||
/****Add annotation to existing data unit (e.g. click, clip or other RawDataHolder)****/
|
/****Add annotation to existing data unit (e.g. click, clip or other RawDataHolder)****/
|
||||||
//Need to go by the parent data unit for merging data not the segments. Note that we may still add multiple
|
//Need to go by the parent data unit for merging data not the segments. Note that we may still add multiple
|
||||||
//predicitions to a single data unit depending on how many segments it contains.
|
//predictions to a single data unit depending on how many segments it contains.
|
||||||
|
|
||||||
//System.out.println("New model data " + pamRawData.getParentDataUnit().getUID() + " " + groupDataBuffer[i].size() + " " + modelResultDataBuffer[i].size());
|
//System.out.println("New model data " + pamRawData.getParentDataUnit().getUID() + " " + groupDataBuffer[i].size() + " " + modelResultDataBuffer[i].size());
|
||||||
|
|
||||||
|
@ -3,7 +3,10 @@ package rawDeepLearningClassifier.dlClassification;
|
|||||||
import PamView.GroupedDataSource;
|
import PamView.GroupedDataSource;
|
||||||
import PamView.GroupedSourceParameters;
|
import PamView.GroupedSourceParameters;
|
||||||
import PamguardMVC.AcousticDataBlock;
|
import PamguardMVC.AcousticDataBlock;
|
||||||
|
import PamguardMVC.dataSelector.DataSelectorCreator;
|
||||||
|
import clickTrainDetector.dataselector.CTDataSelectCreator;
|
||||||
import rawDeepLearningClassifier.DLControl;
|
import rawDeepLearningClassifier.DLControl;
|
||||||
|
import rawDeepLearningClassifier.dataSelector.DLDataSelectCreator;
|
||||||
import rawDeepLearningClassifier.tethys.DLSpeciesManager;
|
import rawDeepLearningClassifier.tethys.DLSpeciesManager;
|
||||||
import rawDeepLearningClassifier.tethys.DLTethysDataProvider;
|
import rawDeepLearningClassifier.tethys.DLTethysDataProvider;
|
||||||
import tethys.TethysControl;
|
import tethys.TethysControl;
|
||||||
@ -18,10 +21,30 @@ import tethys.species.DataBlockSpeciesManager;
|
|||||||
*/
|
*/
|
||||||
public class DLDetectionDataBlock extends AcousticDataBlock<DLDetection> implements GroupedDataSource {
|
public class DLDetectionDataBlock extends AcousticDataBlock<DLDetection> implements GroupedDataSource {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reference to the deep learning classifier process.
|
||||||
|
*/
|
||||||
private DLClassifyProcess dlClassifyProcess;
|
private DLClassifyProcess dlClassifyProcess;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reference to the Tethys data provider for the deep learning module (provides standardised metadata).
|
||||||
|
*/
|
||||||
private DLTethysDataProvider dlTethysDataProvider;
|
private DLTethysDataProvider dlTethysDataProvider;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A species manager for metadata.
|
||||||
|
*/
|
||||||
private DLSpeciesManager dlSpeciesManager;
|
private DLSpeciesManager dlSpeciesManager;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reference to the deep learning control.
|
||||||
|
*/
|
||||||
private DLControl dlControl;
|
private DLControl dlControl;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Data selector for the deep learning detections.
|
||||||
|
*/
|
||||||
|
private DLDataSelectCreator dlDataSelectCreator;
|
||||||
|
|
||||||
public DLDetectionDataBlock(String dataName, DLClassifyProcess parentProcess, int channelMap) {
|
public DLDetectionDataBlock(String dataName, DLClassifyProcess parentProcess, int channelMap) {
|
||||||
super(DLDetection.class, dataName, parentProcess, channelMap);
|
super(DLDetection.class, dataName, parentProcess, channelMap);
|
||||||
@ -49,6 +72,19 @@ public class DLDetectionDataBlock extends AcousticDataBlock<DLDetection> impleme
|
|||||||
}
|
}
|
||||||
return dlSpeciesManager;
|
return dlSpeciesManager;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/* (non-Javadoc)
|
||||||
|
* @see PamguardMVC.PamDataBlock#getDataSelectCreator()
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public synchronized DataSelectorCreator getDataSelectCreator() {
|
||||||
|
if (dlDataSelectCreator == null) {
|
||||||
|
dlDataSelectCreator = new DLDataSelectCreator(dlControl, this);
|
||||||
|
}
|
||||||
|
return dlDataSelectCreator;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,26 @@
|
|||||||
|
package rawDeepLearningClassifier.dlClassification;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Decides whether a prediction has passed a threshold to be used to create a new
|
||||||
|
* data unit.
|
||||||
|
* <p>
|
||||||
|
* Note that the majority of the time this will be a simple test of the value of
|
||||||
|
* predictions of a model but there will be cases when a classifier implements a
|
||||||
|
* more complex system. For example, a implementation could save a buffer of predictions
|
||||||
|
* so that previous predictions inform the latest prediction. Or results may include
|
||||||
|
* some sort of object detection components and frequency bounds etc could be used for
|
||||||
|
* classification.
|
||||||
|
*
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public interface DLPredictionDecision {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check whether a single prediction passes a binary classifier. Prediction which pass decision will be
|
||||||
|
* passed on to create new data units.
|
||||||
|
* @param result - the prediciton result to test.
|
||||||
|
* @return true if the result is passed.
|
||||||
|
*/
|
||||||
|
public boolean isBinaryResult(PredictionResult result);
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,49 @@
|
|||||||
|
package rawDeepLearningClassifier.dlClassification;
|
||||||
|
|
||||||
|
import rawDeepLearningClassifier.dlClassification.animalSpot.StandardModelParams;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Make a decision based on a simple binary threshold for a prediction.
|
||||||
|
*/
|
||||||
|
public class SimpleDLDecision implements DLPredictionDecision {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reference to the parameters.
|
||||||
|
*/
|
||||||
|
private StandardModelParams params;
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isBinaryResult(PredictionResult modelResult) {
|
||||||
|
return isBinaryResult(modelResult, getParams()) ;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check whether a model passes a binary test...
|
||||||
|
* @param modelResult - the model results
|
||||||
|
* @return the model results.
|
||||||
|
*/
|
||||||
|
private static boolean isBinaryResult(PredictionResult modelResult, StandardModelParams genericModelParams) {
|
||||||
|
for (int i=0; i<modelResult.getPrediction().length; i++) {
|
||||||
|
//System.out.println("Binary Classification: " + genericModelParams.binaryClassification.length);
|
||||||
|
|
||||||
|
if (modelResult.getPrediction()[i]>genericModelParams.threshold && genericModelParams.binaryClassification[i]) {
|
||||||
|
// System.out.println("SoundSpotClassifier: prediciton: " + i + " passed threshold with val: " + modelResult.getPrediction()[i]);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public StandardModelParams getParams() {
|
||||||
|
return params;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public void setParams( StandardModelParams genericModelParams) {
|
||||||
|
this.params = genericModelParams;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -44,6 +44,12 @@ public abstract class StandardClassifierModel implements DLClassiferModel, PamSe
|
|||||||
* in real time without slowing down the rest of PAMGaurd.
|
* in real time without slowing down the rest of PAMGaurd.
|
||||||
*/
|
*/
|
||||||
private TaskThread workerThread;
|
private TaskThread workerThread;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Makes a binary decision on whether a prediction result should go on
|
||||||
|
* to be part of a data unit.
|
||||||
|
*/
|
||||||
|
private SimpleDLDecision simpleDLDecision = new SimpleDLDecision();
|
||||||
|
|
||||||
|
|
||||||
public StandardClassifierModel(DLControl dlControl) {
|
public StandardClassifierModel(DLControl dlControl) {
|
||||||
@ -84,7 +90,7 @@ public abstract class StandardClassifierModel implements DLClassiferModel, PamSe
|
|||||||
|
|
||||||
for (int i =0; i<modelResult.size(); i++) {
|
for (int i =0; i<modelResult.size(); i++) {
|
||||||
modelResult.get(i).setClassNameID(GenericDLClassifier.getClassNameIDs(getDLParams()));
|
modelResult.get(i).setClassNameID(GenericDLClassifier.getClassNameIDs(getDLParams()));
|
||||||
modelResult.get(i).setBinaryClassification(GenericDLClassifier.isBinaryResult(modelResult.get(i), getDLParams()));
|
modelResult.get(i).setBinaryClassification(isDecision(modelResult.get(i), getDLParams()));
|
||||||
modelResult.get(i).setTimeMillis(groupedRawData.get(i).getTimeMilliseconds());
|
modelResult.get(i).setTimeMillis(groupedRawData.get(i).getTimeMilliseconds());
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -106,7 +112,7 @@ public abstract class StandardClassifierModel implements DLClassiferModel, PamSe
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void prepModel() {
|
public void prepModel() {
|
||||||
System.out.println("STANDARD CLASSIFIER MODEL PREP MODEL! !!!: " + getDLParams().modelPath);
|
// System.out.println("STANDARD CLASSIFIER MODEL PREP MODEL! !!!: " + getDLParams().modelPath);
|
||||||
// StandardModelParams oldParams = getDLParams().clone();
|
// StandardModelParams oldParams = getDLParams().clone();
|
||||||
|
|
||||||
getDLWorker().prepModel(getDLParams(), dlControl);
|
getDLWorker().prepModel(getDLParams(), dlControl);
|
||||||
@ -221,7 +227,8 @@ public abstract class StandardClassifierModel implements DLClassiferModel, PamSe
|
|||||||
* @return true if a threshold has been met.
|
* @return true if a threshold has been met.
|
||||||
*/
|
*/
|
||||||
public boolean isDecision(StandardPrediction modelResult, StandardModelParams modelParmas) {
|
public boolean isDecision(StandardPrediction modelResult, StandardModelParams modelParmas) {
|
||||||
return isBinaryResult(modelResult, modelParmas);
|
simpleDLDecision.setParams(modelParmas);
|
||||||
|
return simpleDLDecision.isBinaryResult(modelResult);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -241,22 +248,22 @@ public abstract class StandardClassifierModel implements DLClassiferModel, PamSe
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
// /**
|
||||||
* Check whether a model passes a binary test...
|
// * Check whether a model passes a binary test...
|
||||||
* @param modelResult - the model results
|
// * @param modelResult - the model results
|
||||||
* @return the model results.
|
// * @return the model results.
|
||||||
*/
|
// */
|
||||||
public static boolean isBinaryResult(StandardPrediction modelResult, StandardModelParams genericModelParams) {
|
// public static boolean isBinaryResult(StandardPrediction modelResult, StandardModelParams genericModelParams) {
|
||||||
for (int i=0; i<modelResult.getPrediction().length; i++) {
|
// for (int i=0; i<modelResult.getPrediction().length; i++) {
|
||||||
//System.out.println("Binary Classification: " + genericModelParams.binaryClassification.length);
|
// //System.out.println("Binary Classification: " + genericModelParams.binaryClassification.length);
|
||||||
|
//
|
||||||
if (modelResult.getPrediction()[i]>genericModelParams.threshold && genericModelParams.binaryClassification[i]) {
|
// if (modelResult.getPrediction()[i]>genericModelParams.threshold && genericModelParams.binaryClassification[i]) {
|
||||||
// System.out.println("SoundSpotClassifier: prediciton: " + i + " passed threshold with val: " + modelResult.getPrediction()[i]);
|
// // System.out.println("SoundSpotClassifier: prediciton: " + i + " passed threshold with val: " + modelResult.getPrediction()[i]);
|
||||||
return true;
|
// return true;
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
return false;
|
// return false;
|
||||||
}
|
// }
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
Loading…
Reference in New Issue
Block a user