Updates to deep learning module

This commit is contained in:
Jamie Mac 2023-05-05 11:03:18 +01:00
parent aa0526a2d8
commit 353914cb46
40 changed files with 984 additions and 795 deletions

View File

@ -0,0 +1,17 @@
<?xml version="1.0" encoding="ISO-8859-1"?><PAMGUARD>
<VERSIONINFO Created="7 April 2023 at 08:06:41 UTC" MinAspVersion="2.35.23" Release="CORE" Version="2.02.07b"/>
<MODULES>
<MODULE>
<BinFilePrefix>Data_Acquisition_Sound_Acquisition_Raw_Data_</BinFilePrefix>
<lastUID>0</lastUID>
</MODULE>
<MODULE>
<BinFilePrefix>Deep_Learning_Classifier_Deep_Learning_Classifier_DL_Model_Data_</BinFilePrefix>
<lastUID>2000047</lastUID>
</MODULE>
<MODULE>
<BinFilePrefix>Deep_Learning_Classifier_Deep_Learning_Classifier_DL_detection_</BinFilePrefix>
<lastUID>2000000</lastUID>
</MODULE>
</MODULES>
</PAMGUARD>

View File

@ -350,7 +350,7 @@
<dependency>
<groupId>io.github.macster110</groupId>
<artifactId>jdl4pam</artifactId>
<version>0.0.94</version>
<version>0.0.96</version>
</dependency>
<!-- https://mvnrepository.com/artifact/gov.nist.math/jama -->

View File

@ -9,7 +9,6 @@ public class ArraySettingsPane extends SettingsPane<PamArray >{
public ArraySettingsPane() {
super(null);
// TODO Auto-generated constructor stub
}
@Override
@ -26,13 +25,11 @@ public class ArraySettingsPane extends SettingsPane<PamArray >{
@Override
public String getName() {
// TODO Auto-generated method stub
return "Array Parameters";
}
@Override
public Node getContentNode() {
// TODO Auto-generated method stub
return new Label("TODO: The Array Manager needs an FX GUI");
}

View File

@ -0,0 +1,99 @@
package Array.layoutFX;
import clickDetector.layoutFX.clickClassifiers.ClickTypeProperty;
import javafx.collections.ObservableList;
import javafx.scene.control.Dialog;
import javafx.scene.control.TableCell;
import javafx.scene.control.TableColumn;
import javafx.scene.control.cell.CheckBoxTableCell;
import pamViewFX.fxNodes.PamBorderPane;
import pamViewFX.fxNodes.table.TableSettingsPane;
/**
* A pane for setting up hydrophones. Note that this is entirely separate from PAMGuard so can be used in
* other projects.
*
* @author Jamie Macaulay
*
*/
public class BasicArrayPane extends PamBorderPane {
/**
* Class which extends TableSettingsPane and creates a sliding pane instead of a dialog when an item is added.
* @author Jamie Macaulay
*
*/
class ClickClassifierTable extends TableSettingsPane<ClickTypeProperty> {
public ClickClassifierTable(ObservableList<ClickTypeProperty> data) {
super(data);
//need to set up all the rows.
TableColumn<ClickTypeProperty,String> icon = new TableColumn<ClickTypeProperty,String>("Name");
icon.setCellValueFactory(cellData -> cellData.getValue().name);
icon.setEditable(false);
icon.setCellFactory((tableColumn) -> {
TableCell<ClickTypeProperty, String> tableCell = new TableCell<ClickTypeProperty, String>();
return tableCell;
});
// TableColumn<ClickTypeProperty,String> name = new TableColumn<ClickTypeProperty,String>("Name");
// name.setCellValueFactory(cellData -> cellData.getValue().name);
// name.setEditable(true);
TableColumn<ClickTypeProperty,Number> code = new TableColumn<ClickTypeProperty,Number>("Species Code");
code.setCellValueFactory(cellData -> cellData.getValue().code);
TableColumn<ClickTypeProperty,Boolean> checkCol = new TableColumn<>("Enable");
checkCol.setCellValueFactory( cellData -> cellData.getValue().enableClassifier);
checkCol.setCellFactory(CheckBoxTableCell.forTableColumn(checkCol));
checkCol.setEditable(true);
checkCol.setMaxWidth( 100 );
checkCol.setMinWidth( 100 );
TableColumn<ClickTypeProperty,Boolean> discard = new TableColumn<>("Discard");
discard.setCellValueFactory( cellData -> cellData.getValue().discardClassifier);
discard.setCellFactory(CheckBoxTableCell.forTableColumn(checkCol));
discard.setEditable(true);
discard.setMaxWidth( 100 );
discard.setMinWidth( 100 );
getTableView().setEditable(true);
getTableView().getColumns().addAll(checkCol, icon, code, discard);
}
@Override
public void dialogClosed(ClickTypeProperty data) {
// TODO Auto-generated method stub
}
@Override
public Dialog<ClickTypeProperty> createSettingsDialog(ClickTypeProperty data) {
//we do not use dialogs here- sliding pane instead.
// setClassifierPane(data);
// showFlipPane(true);
return null;
}
@Override
public void editData(ClickTypeProperty data){
// setClassifierPane(data);
// showFlipPane(true);
}
@Override
public void createNewData(){
//create a new classifier.
// clickClassifiers.add(createClickTypeProperty());
}
}
}

View File

@ -6,6 +6,7 @@ import java.net.URI;
/**
* Class for breaking a file name down into it's constituent
* parts.
*
* @author Doug Gillespie
*
*/

View File

@ -43,11 +43,9 @@ public class SweepClassifierPaneFX extends BasicIdentifierPaneFX {
public void setClassifierPane(ClickTypeProperty clickTypeProperty){
SweepClassifierSetPaneFX sweepPane=new SweepClassifierSetPaneFX(sweepClickClassifier);
//make it so the title of the pane is the same as the name as the classifier
getFlipPane().getAdvLabel().textProperty().unbind();
getFlipPane().getAdvLabel().textProperty().bind(sweepPane.getNameTextProperty());
getFlipPane().getAdvLabel().graphicProperty().bind(sweepPane.getNameGraphicProperty());

View File

@ -77,7 +77,7 @@ import rawDeepLearningClassifier.segmenter.SegmenterProcess;
* AnimalSpot is a framework for training acoustic deep learning
* models using Pytorch. Users can load a .py model which contains embedded
* metadata so that PMAGuard knows the exact transforms required for the model
* input. This makes deployin models in PAMGuard very easy - users require little
* input. This makes deploying models in PAMGuard very easy - users require little
* or no experience to get this working.
* <p>
* <li>Ketos</li>
@ -118,6 +118,7 @@ public class DLControl extends PamControlledUnit implements PamSettings {
* List of different deep learning models that are available.
*/
private ArrayList<DLClassiferModel> dlModels = new ArrayList<DLClassiferModel>();
/**
* The settings pane.
@ -184,7 +185,8 @@ public class DLControl extends PamControlledUnit implements PamSettings {
/**
* Figure out which model type has been imported.
*/
private DLClassifierChooser dlClassifierChooser;
private DLClassifierChooser dlClassifierChooser;
/**
@ -198,6 +200,7 @@ public class DLControl extends PamControlledUnit implements PamSettings {
PamRawDataBlock rawDataBlock = PamController.getInstance()
.getRawDataBlock(rawDLParmas.groupedSourceParams.getDataSource());
/**
* In the latest release of djl (0.11.0) there is a bug with the dll's of tensorflow and
* pytorch. If tensorflow is loaded before pytorch there is a conglict in dll's and
@ -230,9 +233,12 @@ public class DLControl extends PamControlledUnit implements PamSettings {
/***** Add new deep learning models here ****/
dlModels.add(new GenericDLClassifier(this));
dlModels.add(new SoundSpotClassifier(this));
dlModels.add(new KetosClassifier(this));
//it is important the Generic Model is last because we need to check
//for PG metadata in all other models before resorting to manually
//setting up a model.
dlModels.add(new GenericDLClassifier(this));
// dlModels.add(new DummyClassifier());
// dlModels.add(new OrcaSpotClassifier(this)); //removed soon.
@ -258,6 +264,9 @@ public class DLControl extends PamControlledUnit implements PamSettings {
// serialized
if (rawDLParmas.classNameMap == null)
rawDLParmas.classNameMap = new ArrayList<DLClassName>();
//create the classiifer chooser.
dlClassifierChooser = new DLClassifierChooser(this);
// ensure everything is updated.
updateParams(rawDLParmas);
@ -278,7 +287,12 @@ public class DLControl extends PamControlledUnit implements PamSettings {
* @return the current deep learning model.
*/
public DLClassiferModel getDLModel() {
return dlModels.get(rawDLParmas.modelSelection);
if (this.rawDLParmas.modelSelection<0 || this.rawDLParmas.modelSelection>=dlModels.size()) {
return null;
}
else {
return dlModels.get(this.rawDLParmas.modelSelection);
}
}
/**

View File

@ -22,10 +22,10 @@ public class RawDLParams implements Serializable, Cloneable {
/**
* The currently selected Deep Learning model.
* (Models are now automatically selected)
* (Models are now automatically selected).
* -1 means no model selected.
*/
@Deprecated
public int modelSelection = 0;
public int modelSelection = -1;
/**
* The current model URI. The deep learning model must have some sort of external file to run.

View File

@ -82,13 +82,19 @@ public class DLPredictionPlotInfoFX extends GenericLinePlotInfo {
frequencyInfo = new GenericScaleInfo(0, 1, ParameterType.FREQUENCY, ParameterUnits.HZ);
DLClassName[] classNames = getDlControl().getDLModel().getClassNames();
System.out.println("Class names are: !!! " + (classNames == null ? "null" : classNames.length));
if (classNames!=null) {
//make sure this is initialised otherwise the plot won't work when first created.
if (dlPredParams.lineInfos==null ) dlPredParams.lineInfos = new LineInfo[classNames.length];
for (int i=0; i<classNames.length; i++) {
if (dlPredParams.lineInfos[i]==null) {
dlPredParams.lineInfos[i] = new LineInfo(true, Color.rgb(0, 0, 255%(i*30 + 50)));
//make sure this is initialised otherwise the plot won't work when first created.
if (dlPredParams.lineInfos==null ) dlPredParams.lineInfos = new LineInfo[classNames.length];
for (int i=0; i<classNames.length; i++) {
if (dlPredParams.lineInfos[i]==null) {
dlPredParams.lineInfos[i] = new LineInfo(true, Color.rgb(0, 0, 255%(i*30 + 50)));
}
}
}
addScaleInfo(probabilityScaleInfo);

View File

@ -11,7 +11,6 @@ import PamguardMVC.DataUnitBaseData;
import PamguardMVC.PamDataUnit;
import PamguardMVC.PamInstantProcess;
import PamguardMVC.PamObservable;
import PamguardMVC.PamProcess;
import binaryFileStorage.DataUnitFileInformation;
import rawDeepLearningClassifier.DLControl;
import rawDeepLearningClassifier.RawDLParams;

View File

@ -38,9 +38,9 @@ public abstract class DLTaskThread extends Thread {
public void stopTaskThread() {
run.set(false);
//Clean up daemon.
if (dlModelWorker!=null) {
dlModelWorker.closeModel();
}
// if (dlModelWorker!=null) {
// dlModelWorker.closeModel();
// }
dlModelWorker = null;
}

View File

@ -0,0 +1,291 @@
package rawDeepLearningClassifier.dlClassification;
import java.io.File;
import java.net.MalformedURLException;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import org.codehaus.plexus.util.FileUtils;
import PamController.PamSettings;
import PamDetection.RawDataUnit;
import PamUtils.PamArrayUtils;
import PamUtils.PamCalendar;
import javafx.stage.FileChooser.ExtensionFilter;
import rawDeepLearningClassifier.DLControl;
import rawDeepLearningClassifier.dlClassification.animalSpot.StandardModelParams;
import rawDeepLearningClassifier.dlClassification.genericModel.DLModelWorker;
import rawDeepLearningClassifier.dlClassification.genericModel.GenericDLClassifier;
import rawDeepLearningClassifier.dlClassification.genericModel.GenericPrediction;
import rawDeepLearningClassifier.layoutFX.RawDLSettingsPane;
import rawDeepLearningClassifier.segmenter.SegmenterProcess.GroupedRawData;
import warnings.PamWarning;
import warnings.WarningSystem;
/**
* A useful abstract class for standard models which are a file or URL that is loaded, have a UI and
* utilise PAMSettings to save settings state.
*/
public abstract class StandardClassifierModel implements DLClassiferModel, PamSettings {
protected DLControl dlControl;
/**
* True to force the classifier to use a queue - used for simulating real time operation.
*/
private boolean forceQueue = false;
/**
* The worker thread has a buffer so that Standard models can be run
* in real time without slowing down the rest of PAMGaurd.
*/
private TaskThread workerThread;
public StandardClassifierModel(DLControl dlControl) {
this.dlControl=dlControl;
}
/**
* Sound spot warning.
*/
PamWarning dlClassifierWarning = new PamWarning(getName(), "",2);
@Override
public ArrayList<? extends PredictionResult> runModel(ArrayList<GroupedRawData> groupedRawData) {
if (getDLWorker().isModelNull()) return null;
// System.out.println("SoundSpotClassifier: PamCalendar.isSoundFile(): "
// + PamCalendar.isSoundFile() + " " + (PamCalendar.isSoundFile() && !forceQueue));
/**
* If a sound file is being analysed then Ketos can go as slow as it wants. if used in real time
* then there is a buffer with a maximum queue size.
*/
if ((PamCalendar.isSoundFile() && !forceQueue) || dlControl.isViewer()) {
//run the model
ArrayList<GenericPrediction> modelResult = getDLWorker().runModel(groupedRawData,
groupedRawData.get(0).getParentDataBlock().getSampleRate(), 0);
if (modelResult==null) {
dlClassifierWarning.setWarningMessage(getName() + " deep learning model returned null");
WarningSystem.getWarningSystem().addWarning(dlClassifierWarning);
return null;
}
for (int i =0; i<modelResult.size(); i++) {
modelResult.get(i).setClassNameID(GenericDLClassifier.getClassNameIDs(getDLParams()));
modelResult.get(i).setBinaryClassification(GenericDLClassifier.isBinaryResult(modelResult.get(i), getDLParams()));
modelResult.get(i).setTimeMillis(groupedRawData.get(i).getTimeMilliseconds());
}
return modelResult; //returns to the classifier.
}
else {
//add to a buffer if in real time.
if (workerThread.getQueue().size()>DLModelWorker.MAX_QUEUE_SIZE) {
//we are not doing well - clear the buffer
workerThread.getQueue().clear();
}
workerThread.getQueue().add(groupedRawData);
}
return null;
}
@Override
public void prepModel() {
//System.out.println("PrepModel! !!!");
StandardModelParams oldParams = getDLParams().clone();
getDLWorker().prepModel(getDLParams(), dlControl);
if (getDLWorker().isModelNull()) {
dlClassifierWarning.setWarningMessage("There is no loaded " + getName() + " classifier model. " + getName() + " disabled.");
WarningSystem.getWarningSystem().addWarning(dlClassifierWarning);
}
if ((!PamCalendar.isSoundFile() || forceQueue) && !dlControl.isViewer()) {
//for real time only
if (workerThread!=null) {
workerThread.stopTaskThread();
}
workerThread = new TaskThread(getDLWorker());
workerThread.setPriority(Thread.MAX_PRIORITY);
workerThread.start();
}
}
/**
* Get the sound spot worker.
* @return the sound spot worker.
*/
public abstract DLModelWorker<GenericPrediction> getDLWorker();
/**
* Get the sound spot worker.
* @return the sound spot worker.
*/
public abstract StandardModelParams getDLParams();
public DLControl getDLControl() {
return dlControl;
}
@Override
public int getNumClasses() {
return this.getDLParams().numClasses;
}
@Override
public DLClassName[] getClassNames() {
return getDLParams().classNames;
}
@Override
public boolean checkModelOK() {
return !getDLWorker().isModelNull();
}
@Override
public void setModel(URI uri) {
//will change the params if we do not clone.
StandardModelParams.setModel(uri, this.getDLParams());
this.prepModel();
}
/**
* The task thread.
* @author Jamie Macaulay
*
*/
public class TaskThread extends DLTaskThread {
public TaskThread(DLModelWorker soundSpotWorker) {
super(soundSpotWorker);
}
@Override
public void newDLResult(GenericPrediction soundSpotResult, GroupedRawData groupedRawData) {
soundSpotResult.setClassNameID(GenericDLClassifier.getClassNameIDs(getDLParams()));
soundSpotResult.setBinaryClassification(GenericDLClassifier.isBinaryResult(soundSpotResult, getDLParams()));
newResult(soundSpotResult, groupedRawData);
}
}
@Override
public void closeModel() {
getDLWorker().closeModel();
}
/**
* Send a new result form the thread queue to the process.
* @param modelResult - the model result;
* @param groupedRawData - the grouped raw data.
*/
protected void newResult(GenericPrediction modelResult, GroupedRawData groupedRawData) {
this.dlControl.getDLClassifyProcess().newModelResult(modelResult, groupedRawData);
}
@Override
public ArrayList<PamWarning> checkSettingsOK() {
return checkSettingsOK(getDLParams(), dlControl);
}
/**
* Get the number of samples for microseconds. Based on the sample rate of the parent data block.
*/
public double millis2Samples(double millis) {
//System.out.println("Samplerate: " + this.dlControl.getSegmenter().getSampleRate() );
return millis*this.dlControl.getSegmenter().getSampleRate()/1000.0;
}
/**
* Get raw settings pane
* @return the setting pane.
*/
public RawDLSettingsPane getRawSettingsPane() {
return this.dlControl.getSettingsPane();
}
public boolean isModelExtensions(URI uri){
// System.out.println("Check: " + getName() + " extensions");
String url;
try {
url = uri.toURL().getPath();
if (url.contains(".")) {
String extension = FileUtils.getExtension(url);
// System.out.println("Check: " + getName() + " file extension " + extension);
List<ExtensionFilter> pExtensions = this.getModelUI().getModelFileExtensions();
for (ExtensionFilter extF : pExtensions) {
for (String ext : extF.getExtensions()) {
// System.out.println(getName() + " Extensions: " + ext + " " + extension + " " + ext.equals(extension.substring(2).trim()) + " " + extension.substring(2).trim());
if (extension.equals(ext.substring(2))) {
return true;
}
}
}
}
return false;
} catch (MalformedURLException e) {
e.printStackTrace();
}
return false;
}
public static ArrayList<PamWarning> checkSettingsOK(StandardModelParams genericModelParams, DLControl dlControl) {
// TODO - check if model is null.
// check that classifier is selected if continous files.
//
ArrayList<PamWarning> warnings = new ArrayList<PamWarning>();
File file = new File(genericModelParams.modelPath);
if (genericModelParams.modelPath == null || !file.isFile()) {
warnings.add(
new PamWarning("Generic classifier", "There is no model loaded - the classifier will not run", 2));
// if no model then this is ionly message needed for the generic classifier.
return warnings;
}
// if continous data is selected and all classes are false then this is a
// potential mistake...
if (dlControl.getSettingsPane().getSelectedParentDataBlock().getUnitClass() == RawDataUnit.class
&& PamArrayUtils.isAllFalse(genericModelParams.binaryClassification)) {
warnings.add(new PamWarning("Generic classifier",
"There are no prediction classes selected for classification. "
+ "Predicitons for each segment will be saved but there will be no detections generated",
1));
}
return warnings;
}
}

View File

@ -2,26 +2,21 @@ package rawDeepLearningClassifier.dlClassification.animalSpot;
import java.io.Serializable;
import java.net.URI;
import java.util.ArrayList;
import org.jamdev.jdl4pam.animalSpot.AnimalSpotModel;
import org.jamdev.jdl4pam.transforms.DLTransformsFactory;
import org.jamdev.jdl4pam.transforms.DLTransfromParams;
import PamController.PamControlledUnitSettings;
import PamController.PamSettingManager;
import PamController.PamSettings;
import PamUtils.PamCalendar;
import rawDeepLearningClassifier.DLControl;
import rawDeepLearningClassifier.dlClassification.DLClassName;
import rawDeepLearningClassifier.dlClassification.DLClassiferModel;
import rawDeepLearningClassifier.dlClassification.DLTaskThread;
import rawDeepLearningClassifier.dlClassification.StandardClassifierModel;
import rawDeepLearningClassifier.dlClassification.genericModel.DLModelWorker;
import rawDeepLearningClassifier.dlClassification.genericModel.GenericDLClassifier;
import rawDeepLearningClassifier.dlClassification.genericModel.GenericPrediction;
import rawDeepLearningClassifier.layoutFX.DLCLassiferModelUI;
import rawDeepLearningClassifier.layoutFX.RawDLSettingsPane;
import rawDeepLearningClassifier.segmenter.SegmenterProcess.GroupedRawData;
import warnings.PamWarning;
import warnings.WarningSystem;
/**
* A deep learning classifier running models wihc have been created using the
@ -42,12 +37,7 @@ import warnings.WarningSystem;
* @author JamieMacaulay
*
*/
public class SoundSpotClassifier implements DLClassiferModel, PamSettings {
/**
* Reference to the control.
*/
private DLControl dlControl;
public class SoundSpotClassifier extends StandardClassifierModel {
/**
* The user interface for sound spot.
@ -63,87 +53,127 @@ public class SoundSpotClassifier implements DLClassiferModel, PamSettings {
/**
* The deep learning model worker.
*/
private SoundSpotWorker soundSpotWorker;
private DLModelWorker<GenericPrediction> soundSpotWorker;
/**
* True to force the classifier to use a queue - used for simulating real time operation.
*/
private boolean forceQueue = false;
/**
* Sound spot warning.
*/
PamWarning soundSpotWarning = new PamWarning("AnimalSpot_Classifier", "",2);
/**
* Runs the deep leanring on a different thread with a data buffer. For real time only.
*/
private DLTaskThread workerThread;
public SoundSpotClassifier(DLControl dlControl) {
this.dlControl=dlControl;
super(dlControl);
this.soundSpotParmas = new StandardModelParams();
this.soundSpotUI= new SoundSpotUI(this);
//load the previous settings
PamSettingManager.getInstance().registerSettings(this);
System.out.println("LOADED CLASS NAMES: currParams.classNames: " + soundSpotParmas.classNames);
// System.out.println("LOADED CLASS NAMES: currParams.classNames: " + soundSpotParmas.classNames);
}
@Override
public ArrayList<SoundSpotResult> runModel(ArrayList<GroupedRawData> groupedRawData) {
public String getName() {
return "AnimalSpot";
}
// System.out.println("SoundSpotClassifier: PamCalendar.isSoundFile(): "
// + PamCalendar.isSoundFile() + " " + (PamCalendar.isSoundFile() && !forceQueue));
/**
* If a sound file is being analysed then SoundSpot can go as slow as it wants. if used in real time
* then there is a buffer with a maximum queue size.
*/
if ((PamCalendar.isSoundFile() && !forceQueue) || dlControl.isViewer()) {
//run the model
ArrayList<SoundSpotResult> modelResult = getSoundSpotWorker().runModel(groupedRawData,
groupedRawData.get(0).getParentDataBlock().getSampleRate(), 0);
@Override
public DLCLassiferModelUI getModelUI() {
return soundSpotUI;
}
for (int i =0; i<modelResult.size(); i++) {
modelResult.get(i).setClassNameID(getClassNameIDs());
modelResult.get(i).setBinaryClassification(isBinaryResult(modelResult.get(i)));
}
@Override
public Serializable getDLModelSettings() {
return soundSpotParmas;
}
return modelResult; //returns to the classifier.
@Override
public String getUnitName() {
return dlControl.getUnitName()+"_SoundSpot";
}
@Override
public String getUnitType() {
return dlControl.getUnitType()+"_SoundSpot";
}
@Override
public Serializable getSettingsReference() {
if (soundSpotParmas==null) {
soundSpotParmas = new StandardModelParams();
}
else {
//add to a buffer if in real time.
if (workerThread.getQueue().size()>DLModelWorker.MAX_QUEUE_SIZE) {
//we are not doing well - clear the buffer
workerThread.getQueue().clear();
ArrayList<DLTransfromParams> dlTransformParams = DLClassiferModel.getDLTransformParams(soundSpotParmas.dlTransfroms);
soundSpotParmas.dlTransfromParams=dlTransformParams;
//System.out.println("SoundSpot have been saved. : " + soundSpotParmas.classNames);
return soundSpotParmas;
}
@Override
public long getSettingsVersion() {
return StandardModelParams.serialVersionUID;
}
@Override
public boolean restoreSettings(PamControlledUnitSettings pamControlledUnitSettings) {
StandardModelParams newParameters = (StandardModelParams) pamControlledUnitSettings.getSettings();
if (newParameters!=null) {
soundSpotParmas = newParameters.clone();
//System.out.println("SoundSpot have been restored. : " + soundSpotParmas.classNames);
if (soundSpotParmas.dlTransfromParams!=null) {
soundSpotParmas.dlTransfroms = DLTransformsFactory.makeDLTransforms((ArrayList<DLTransfromParams>) soundSpotParmas.dlTransfromParams);
}
workerThread.getQueue().add(groupedRawData);
}
return null;
else soundSpotParmas = new StandardModelParams();
return true;
}
/**
* Check whether a model passes a binary test...
* @param modelResult - the model results
* @return the model results.
* Get the sound spot parameters.
* @return sound spot parameters.
*/
private boolean isBinaryResult(GenericPrediction modelResult) {
for (int i=0; i<modelResult.getPrediction().length; i++) {
if (modelResult.getPrediction()[i]>soundSpotParmas.threshold && soundSpotParmas.binaryClassification[i]) {
//System.out.println("SoundSpotClassifier: prediciton: " + i + " passed threshold with val: " + modelResult.getPrediction()[i]);
return true;
}
}
return false;
public StandardModelParams getSoundSpotParams() {
return soundSpotParmas;
}
/**
* Get the sound spot worker.
* @return the sound spot worker.
* Set the sound spot parameters.
* @param the params to set
*/
SoundSpotWorker getSoundSpotWorker() {
public void setSoundSpotParams(StandardModelParams soundSpotParmas) {
this.soundSpotParmas=soundSpotParmas;
}
@Override
public boolean isModelType(URI uri) {
//TODO need to be more sophisticated here.
if (super.isModelExtensions(uri)) {
//we have a PyTorch model but is it animal spot.
try {
AnimalSpotModel soundSpotModel = new AnimalSpotModel(uri.getPath());
if (soundSpotModel!=null && soundSpotModel.getExtraFiles()!=null) {
return true;
}
return false;
}
catch (Exception e) {
return false;
}
}
return false;
}
@Override
public DLModelWorker<GenericPrediction> getDLWorker() {
if (soundSpotWorker==null) {
soundSpotWorker = new SoundSpotWorker();
}
@ -151,27 +181,33 @@ public class SoundSpotClassifier implements DLClassiferModel, PamSettings {
}
/**
* The task thread.
* @author Jamie Macaulay
*
*/
public class TaskThread extends DLTaskThread {
TaskThread(DLModelWorker soundSpotWorker) {
super(soundSpotWorker);
}
@Override
public void newDLResult(GenericPrediction soundSpotResult, GroupedRawData groupedRawData) {
soundSpotResult.setClassNameID(getClassNameIDs());
soundSpotResult.setBinaryClassification(isBinaryResult(soundSpotResult));
newResult(soundSpotResult, groupedRawData);
}
@Override
public StandardModelParams getDLParams() {
return soundSpotParmas;
}
// /**
// * Check whether a model passes a binary test...
// * @param modelResult - the model results
// * @return the model results.
// */
// private boolean isBinaryResult(GenericPrediction modelResult) {
// for (int i=0; i<modelResult.getPrediction().length; i++) {
// if (modelResult.getPrediction()[i]>soundSpotParmas.threshold && soundSpotParmas.binaryClassification[i]) {
// //System.out.println("SoundSpotClassifier: prediciton: " + i + " passed threshold with val: " + modelResult.getPrediction()[i]);
// return true;
// }
// }
// return false;
// }
// public class TaskThread extends Thread {
//
// private AtomicBoolean run = new AtomicBoolean(true);
@ -225,196 +261,19 @@ public class SoundSpotClassifier implements DLClassiferModel, PamSettings {
//
// }
/**
* Get the class name IDs
* @return an array of class name IDs
*/
private short[] getClassNameIDs() {
if (soundSpotParmas.classNames==null || soundSpotParmas.classNames.length<=0) return null;
short[] nameIDs = new short[soundSpotParmas.classNames.length];
for (int i = 0 ; i<soundSpotParmas.classNames.length; i++) {
nameIDs[i] = soundSpotParmas.classNames[i].ID;
}
return nameIDs;
}
// /**
// * Get the class name IDs
// * @return an array of class name IDs
// */
// private short[] getClassNameIDs() {
// if (soundSpotParmas.classNames==null || soundSpotParmas.classNames.length<=0) return null;
// short[] nameIDs = new short[soundSpotParmas.classNames.length];
// for (int i = 0 ; i<soundSpotParmas.classNames.length; i++) {
// nameIDs[i] = soundSpotParmas.classNames[i].ID;
// }
// return nameIDs;
// }
/**
* Send a new result form the thread queue to the process.
* @param modelResult - the model result;
* @param groupedRawData - the grouped raw data.
*/
protected void newResult(GenericPrediction modelResult, GroupedRawData groupedRawData) {
this.dlControl.getDLClassifyProcess().newModelResult(modelResult, groupedRawData);
}
@Override
public void prepModel() {
//System.out.println("PrepModel! !!!");
StandardModelParams oldParams = soundSpotParmas.clone();
getSoundSpotWorker().prepModel(soundSpotParmas, dlControl);
if (getSoundSpotWorker().getModel()==null) {
soundSpotWarning.setWarningMessage("There is no loaded classifier model. AnimalSpot disabled.");
WarningSystem.getWarningSystem().addWarning(soundSpotWarning);
}
if ((!PamCalendar.isSoundFile() || forceQueue) && !dlControl.isViewer()) {
//for real time only
if (workerThread!=null) {
workerThread.stopTaskThread();
}
workerThread = new TaskThread(getSoundSpotWorker());
workerThread.setPriority(Thread.MAX_PRIORITY);
workerThread.start();
}
}
@Override
public void closeModel() {
}
@Override
public String getName() {
return "AnimalSpot";
}
@Override
public DLCLassiferModelUI getModelUI() {
return soundSpotUI;
}
@Override
public Serializable getDLModelSettings() {
return soundSpotParmas;
}
@Override
public String getUnitName() {
return dlControl.getUnitName()+"_SoundSpot";
}
@Override
public String getUnitType() {
return dlControl.getUnitType()+"_SoundSpot";
}
@Override
public Serializable getSettingsReference() {
if (soundSpotParmas==null) {
soundSpotParmas = new StandardModelParams();
}
ArrayList<DLTransfromParams> dlTransformParams = DLClassiferModel.getDLTransformParams(soundSpotParmas.dlTransfroms);
soundSpotParmas.dlTransfromParams=dlTransformParams;
//System.out.println("SoundSpot have been saved. : " + soundSpotParmas.classNames);
return soundSpotParmas;
}
@Override
public long getSettingsVersion() {
return StandardModelParams.serialVersionUID;
}
@Override
public boolean restoreSettings(PamControlledUnitSettings pamControlledUnitSettings) {
StandardModelParams newParameters = (StandardModelParams) pamControlledUnitSettings.getSettings();
if (newParameters!=null) {
soundSpotParmas = newParameters.clone();
//System.out.println("SoundSpot have been restored. : " + soundSpotParmas.classNames);
if (soundSpotParmas.dlTransfromParams!=null) {
soundSpotParmas.dlTransfroms = DLTransformsFactory.makeDLTransforms((ArrayList<DLTransfromParams>) soundSpotParmas.dlTransfromParams);
}
}
else soundSpotParmas = new StandardModelParams();
return true;
}
/**
* Get the sound spot parameters.
* @return sound spot parameters.
*/
public StandardModelParams getSoundSpotParams() {
return soundSpotParmas;
}
/**
* Set the sound spot parameters.
* @param the params to set
*/
public void setSoundSpotParams(StandardModelParams soundSpotParmas) {
this.soundSpotParmas=soundSpotParmas;
}
/**
* Get raw settings pane
* @return the setting pane.
*/
public RawDLSettingsPane getRawSettingsPane() {
return this.dlControl.getSettingsPane();
}
/**
* Get the number of samples for microseconds. Based on the sample rate of the parent data block.
*/
public double millis2Samples(double millis) {
//System.out.println("Samplerate: " + this.dlControl.getSegmenter().getSampleRate() );
return millis*this.dlControl.getSegmenter().getSampleRate()/1000.0;
}
@Override
public int getNumClasses() {
return this.soundSpotParmas.numClasses;
}
@Override
public DLClassName[] getClassNames() {
return soundSpotParmas.classNames;
}
public DLControl getDLControl() {
return dlControl;
}
@Override
public boolean checkModelOK() {
return getSoundSpotWorker().getModel()!=null;
}
@Override
public ArrayList<PamWarning> checkSettingsOK() {
return GenericDLClassifier.checkSettingsOK(soundSpotParmas, dlControl);
}
@Override
public boolean isModelType(URI uri) {
// TODO Auto-generated method stub
return false;
}
@Override
public void setModel(URI model) {
// TODO Auto-generated method stub
}
}

View File

@ -11,20 +11,14 @@ public class SoundSpotModelPane extends StandardModelPane {
* Reference to the currently selected sound spot classifier.
*/
private SoundSpotClassifier soundSpotClassifier;
/**
* The extension filter for sound spot models.
*/
private ArrayList<ExtensionFilter> extensionFilters;
public SoundSpotModelPane(SoundSpotClassifier soundSpotClassifier) {
super(soundSpotClassifier);
// TODO Auto-generated constructor stub
this.soundSpotClassifier=soundSpotClassifier;
extensionFilters = new ArrayList<ExtensionFilter> ();
extensionFilters.add(new ExtensionFilter("Pytorch Model", "*.pk"));
}
@ -58,9 +52,9 @@ public class SoundSpotModelPane extends StandardModelPane {
* Note that the model prep will determine whether new transforms need to be loaded from the
* model or to use the existing transforms in the settings.
*/
soundSpotClassifier.getSoundSpotWorker().prepModel(params, soundSpotClassifier.getDLControl());
soundSpotClassifier.getDLWorker().prepModel(params, soundSpotClassifier.getDLControl());
//get the model tansforms calculated from the model by SoundSpoyWorker and apply them to our temporary params clone.
getParamsClone().dlTransfroms = this.soundSpotClassifier.getSoundSpotWorker().getModelTransforms();
getParamsClone().dlTransfroms = this.soundSpotClassifier.getDLWorker().getModelTransforms();
// if (getParamsClone().defaultSegmentLen!=null) {
@ -73,9 +67,4 @@ public class SoundSpotModelPane extends StandardModelPane {
}
@Override
public ArrayList<ExtensionFilter> getExtensionFilters() {
return extensionFilters;
}
}

View File

@ -1,8 +1,12 @@
package rawDeepLearningClassifier.dlClassification.animalSpot;
import java.util.ArrayList;
import java.util.List;
import javax.swing.JPanel;
import PamController.SettingsPane;
import javafx.stage.FileChooser.ExtensionFilter;
import rawDeepLearningClassifier.layoutFX.DLCLassiferModelUI;
@ -24,12 +28,21 @@ public class SoundSpotUI implements DLCLassiferModelUI {
*/
private SoundSpotClassifier soundSpotClassifier;
/**
* The extension filter for sound spot models.
*/
private ArrayList<ExtensionFilter> extensionFilters;
/**
* SondSpot classifier.
* @param soundSpotClassifier
*/
public SoundSpotUI(SoundSpotClassifier soundSpotClassifier) {
this.soundSpotClassifier=soundSpotClassifier;
extensionFilters = new ArrayList<ExtensionFilter> ();
extensionFilters.add(new ExtensionFilter("Pytorch Model", "*.pk"));
}
@Override
@ -59,4 +72,15 @@ public class SoundSpotUI implements DLCLassiferModelUI {
return null;
}
/**
* Get a list of extension fitlers for the file dialog.
* e.g.
* new ExtensionFilter("Pytorch Model", "*.pk")
* @return a list of extension fitlers for the file dialog.
*/
@Override
public List<ExtensionFilter> getModelFileExtensions() {
return extensionFilters;
}
}

View File

@ -1,12 +1,13 @@
package rawDeepLearningClassifier.dlClassification.animalSpot;
import java.nio.file.Paths;
import org.jamdev.jdl4pam.animalSpot.AnimalSpotModel;
import org.jamdev.jdl4pam.animalSpot.AnimalSpotParams;
import PamModel.PamModel;
import PamModel.PamModel.PluginClassloader;
import rawDeepLearningClassifier.DLControl;
import rawDeepLearningClassifier.dlClassification.genericModel.DLModelWorker;
import rawDeepLearningClassifier.dlClassification.genericModel.GenericPrediction;
/**
@ -18,13 +19,16 @@ import rawDeepLearningClassifier.dlClassification.genericModel.DLModelWorker;
* @author Jamie Macaulay
*
*/
public class SoundSpotWorker extends DLModelWorker<SoundSpotResult> {
public class SoundSpotWorker extends DLModelWorker<GenericPrediction> {
/**
* Sound spot model.
*/
private AnimalSpotModel soundSpotModel;
private AnimalSpotModel soundSpotModel;
private String currentPath;
/**
@ -39,7 +43,7 @@ public class SoundSpotWorker extends DLModelWorker<SoundSpotResult> {
*/
public void prepModel(StandardModelParams soundSpotParams, DLControl dlControl) {
//ClassLoader origCL = Thread.currentThread().getContextClassLoader();
//System.out.println("prepModel: " + soundSpotParams.useDefaultTransfroms);
try {
@ -52,9 +56,13 @@ public class SoundSpotWorker extends DLModelWorker<SoundSpotResult> {
// PluginClassloader newCL = PamModel.getPamModel().getClassLoader();
// Thread.currentThread().setContextClassLoader(newCL);
// }
if (soundSpotModel==null || currentPath ==null || !Paths.get(currentPath).equals(Paths.get(soundSpotParams.modelPath))) {
//System.out.println("Sound spot path: " + soundSpotParams.modelPath);
//first open the model and get the correct parameters.
soundSpotModel = new AnimalSpotModel(soundSpotParams.modelPath);
this.currentPath = soundSpotParams.modelPath;
//first open the model and get the correct parameters.
soundSpotModel = new AnimalSpotModel(soundSpotParams.modelPath);
}
}
catch (Exception e) {
e.printStackTrace();
@ -75,7 +83,7 @@ public class SoundSpotWorker extends DLModelWorker<SoundSpotResult> {
//use the old transforms.
setModelTransforms(soundSpotParams.dlTransfroms);
}
soundSpotParams.defaultSegmentLen = dlParams.seglen; //the segment length in microseconds.
soundSpotParams.numClasses = dlParams.classNames.length;
@ -132,7 +140,8 @@ public class SoundSpotWorker extends DLModelWorker<SoundSpotResult> {
* Destroy the model.
*/
public void closeModel() {
//TODO
//TODO - need to be able to access model in JPAM API.
this.currentPath = null;
}
@ -145,4 +154,10 @@ public class SoundSpotWorker extends DLModelWorker<SoundSpotResult> {
}
@Override
public boolean isModelNull() {
return soundSpotModel==null;
}
}

View File

@ -1,8 +1,6 @@
package rawDeepLearningClassifier.dlClassification.animalSpot;
import java.io.File;
import java.util.ArrayList;
import org.controlsfx.control.CheckComboBox;
import org.controlsfx.control.PopOver;
import org.controlsfx.control.ToggleSwitch;
@ -18,7 +16,6 @@ import javafx.scene.control.Spinner;
import javafx.scene.control.Alert.AlertType;
import javafx.scene.text.Font;
import javafx.scene.text.FontWeight;
import javafx.stage.FileChooser.ExtensionFilter;
import pamViewFX.PamGuiManagerFX;
import pamViewFX.fxGlyphs.PamGlyphDude;
import pamViewFX.fxNodes.PamBorderPane;
@ -203,14 +200,6 @@ public abstract class StandardModelPane extends SettingsPane<StandardModelParams
}
/**
* Get a list of extension fitlers for the file dialog.
* e.g.
* new ExtensionFilter("Pytorch Model", "*.pk")
* @return a list of extension fitlers for the file dialog.
*/
public abstract ArrayList<ExtensionFilter> getExtensionFilters();
/**
* The default segment len changed.
*/
@ -278,7 +267,7 @@ public abstract class StandardModelPane extends SettingsPane<StandardModelParams
@Override
public StandardModelParams getParams(StandardModelParams currParams) {
if (currentSelectedFile==null) {
//uuurgh need to sort this out with FX stuff
WarnOnce.showWarningFX(null, "No Model File", "There is no model file selected in the path: Please select a compatible model" , AlertType.ERROR);
@ -296,13 +285,9 @@ public abstract class StandardModelPane extends SettingsPane<StandardModelParams
currParams.threshold = detectionSpinner.getValue();
// currParams.useCUDA = useCuda.isSelected();
//System.out.println("StandardModelParams 1: " + currParams);
// System.out.println("StandardModelParams : this.paramsClone.numClasses " + this.paramsClone.numClasses);
//System.out.println("StandardModelParams 2: " + currParams.useDefaultTransfroms);
//System.out.println("StandardModelParams 2: " + currParams.useDefaultTransfroms);
boolean[] speciesClass = new boolean[this.paramsClone.numClasses];
for (int i=0; i< speciesClass.length; i++) {
@ -315,15 +300,13 @@ public abstract class StandardModelPane extends SettingsPane<StandardModelParams
currParams = (StandardModelParams) this.getAdvSettingsPane().getParams(currParams);
System.out.println("GET advanced params: ");
//get class names from the paramClone as these may have been set by a loaded model
//instead of a using changing a control.
currParams.classNames = paramsClone.classNames;
currParams.numClasses = paramsClone.numClasses;
if (paramsClone.classNames == null && speciesIDBox.getItems()!=null) {
if ((paramsClone.classNames == null || paramsClone.classNames.length<=0) && speciesIDBox.getItems()!=null) {
String[] classNames = new String[speciesIDBox.getItems().size()];
for (int i=0; i<speciesIDBox.getItems().size(); i++) {
@ -331,6 +314,10 @@ public abstract class StandardModelPane extends SettingsPane<StandardModelParams
}
currParams.classNames = this.dlClassifierModel.getDLControl().getClassNameManager().makeClassNames(classNames);
}
// System.out.println("GET CLASS NAMES: currParams.classNames: " + currParams.classNames + " " +
// (currParams.classNames!=null? currParams.classNames.length: 0 + " " + currParams.numClasses));
currParams.useDefaultSegLen = usedefaultSeg.isSelected();
@ -371,13 +358,12 @@ public abstract class StandardModelPane extends SettingsPane<StandardModelParams
private void setClassNames(StandardModelParams currParams) {
speciesIDBox.getItems().clear();
System.out.println("SET CLASS NAMES: currParams.classNames: " + currParams.classNames + " " + (currParams.classNames!=null ? currParams.classNames.length: 0) + " " + currParams.numClasses);
int classNamesLen = 0;
if (currParams.classNames!=null) classNamesLen = currParams.classNames.length;
//System.out.println("SET CLASS NAMES: currParams.classNames: " + currParams.classNames + " " + classNamesLen + " " + currParams.numClasses);
for (int i=0; i<Math.max(classNamesLen, currParams.numClasses); i++) {
if (currParams.classNames!=null && currParams.classNames.length>i && currParams.classNames[i]!=null) {

View File

@ -1,6 +1,8 @@
package rawDeepLearningClassifier.dlClassification.animalSpot;
import java.io.File;
import java.io.Serializable;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;
@ -98,6 +100,18 @@ public class StandardModelParams implements Serializable, Cloneable {
*/
public int exampleSoundIndex = 0;
/**
* Set the URI in standard model params.
* @param uri - the uri to a model
* @param standarModelParams - the standard model params.
* @return the standard model params.
*/
public static StandardModelParams setModel(URI uri, StandardModelParams standarModelParams) {
standarModelParams.modelPath = new File(uri).getPath();
return standarModelParams;
}
@Override
public StandardModelParams clone() {
StandardModelParams newParams = null;

View File

@ -112,7 +112,7 @@ public abstract class DLModelWorker<T> {
float[] output = null;
long time1 = System.currentTimeMillis();
output = runModel(transformedDataStack);
//System.out.println("Model out: " + PamArrayUtils.array2String(output, 2, ","));
// System.out.println("Model out: " + PamArrayUtils.array2String(output, 2, ","));
long time2 = System.currentTimeMillis();
int numclasses = (int) (output.length/transformedDataStack.length);
@ -169,6 +169,13 @@ public abstract class DLModelWorker<T> {
}
public abstract float[] runModel(float[][][] transformedDataStack);
/**
* Check whether a model is null or not.
* @return true of the model is null.
*/
public abstract boolean isModelNull();
public abstract T makeModelResult(float[] prob, double time);

View File

@ -1,6 +1,5 @@
package rawDeepLearningClassifier.dlClassification.genericModel;
import java.io.File;
import java.io.Serializable;
import java.net.URI;
import java.util.ArrayList;
@ -8,19 +7,13 @@ import org.jamdev.jdl4pam.transforms.DLTransformsFactory;
import org.jamdev.jdl4pam.transforms.DLTransfromParams;
import PamController.PamControlledUnitSettings;
import PamController.PamSettingManager;
import PamController.PamSettings;
import PamDetection.RawDataUnit;
import PamUtils.PamArrayUtils;
import PamUtils.PamCalendar;
import rawDeepLearningClassifier.DLControl;
import rawDeepLearningClassifier.dlClassification.DLClassName;
import rawDeepLearningClassifier.dlClassification.DLClassiferModel;
import rawDeepLearningClassifier.dlClassification.DLTaskThread;
import rawDeepLearningClassifier.dlClassification.PredictionResult;
import rawDeepLearningClassifier.dlClassification.StandardClassifierModel;
import rawDeepLearningClassifier.dlClassification.animalSpot.StandardModelParams;
import rawDeepLearningClassifier.segmenter.SegmenterProcess.GroupedRawData;
import warnings.PamWarning;
import warnings.WarningSystem;
/**
@ -30,7 +23,7 @@ import warnings.WarningSystem;
*
*
*/
public class GenericDLClassifier implements DLClassiferModel, PamSettings {
public class GenericDLClassifier extends StandardClassifierModel {
/**
@ -53,20 +46,10 @@ public class GenericDLClassifier implements DLClassiferModel, PamSettings {
*/
private GenericModelWorker genericModelWorker;
/**
* Sound spot warning.
*/
PamWarning genericModelWarning = new PamWarning("Generic deep learning classifier", "",2);
private boolean forceQueue;
/**
* Generic Worker thread for real time
*/
private GenericTaskThread workerThread;
public GenericDLClassifier(DLControl dlControl) {
super(dlControl);
this.dlControl=dlControl;
genericModelUI = new GenericModelUI(this);
@ -84,94 +67,6 @@ public class GenericDLClassifier implements DLClassiferModel, PamSettings {
}
@Override
public void prepModel() {
//System.out.println("PrepModel! !!!");
genericModelWorker.prepModel(genericModelParams, dlControl);
//set cusotm transforms in the model.
if (genericModelParams.dlTransfromParams!=null) {
//important to remkae transforms from params
genericModelParams.dlTransfroms = DLTransformsFactory.makeDLTransforms((ArrayList<DLTransfromParams>)genericModelParams.dlTransfromParams);
}
genericModelWorker.setModelTransforms(genericModelParams.dlTransfroms);
if (genericModelWorker.getModel()==null) {
genericModelWarning.setWarningMessage("There is no loaded deep learning model. Generic model classifier disabled.");
WarningSystem.getWarningSystem().addWarning(genericModelWarning);
}
if ((!PamCalendar.isSoundFile() || forceQueue) && !dlControl.isViewer()) {
//for real time only
if (workerThread!=null) {
workerThread.stopTaskThread();
}
workerThread = new GenericTaskThread(genericModelWorker);
workerThread.setPriority(Thread.MAX_PRIORITY);
workerThread.start();
}
}
@Override
public ArrayList<? extends PredictionResult> runModel(ArrayList<GroupedRawData> groupedRawData) {
if (genericModelWorker.getModel()==null) return null;
// System.out.println("SoundSpotClassifier: PamCalendar.isSoundFile(): "
// + PamCalendar.isSoundFile() + " " + (PamCalendar.isSoundFile() && !forceQueue));
/**
* If a sound file is being analysed then SoundSpot can go as slow as it wants. if used in real time
* then there is a buffer with a maximum queue size.
*/
if ((PamCalendar.isSoundFile() && !forceQueue) || dlControl.isViewer()) {
//run the model
//System.out.println("GENERICDLCLASSIFIER: Run here!!!");
ArrayList<GenericPrediction> modelResult = getGenericDLWorker().runModel(groupedRawData,
groupedRawData.get(0).getParentDataBlock().getSampleRate(), 0);
//System.out.println("GENERICDLCLASSIFIER: RESULTS!!!: " + modelResult);
if (modelResult==null) {
genericModelWarning.setWarningMessage("Generic deep learning model returned null");
WarningSystem.getWarningSystem().addWarning(genericModelWarning);
return null;
}
for (int i =0; i<modelResult.size(); i++) {
modelResult.get(i).setClassNameID(getClassNameIDs(genericModelParams));
modelResult.get(i).setBinaryClassification(isBinaryResult(modelResult.get(i), genericModelParams));
modelResult.get(i).setTimeMillis(groupedRawData.get(i).getTimeMilliseconds());
}
return modelResult; //returns to the classifier.
}
else {
//add to a buffer if in real time.
if (workerThread.getQueue().size()>DLModelWorker.MAX_QUEUE_SIZE) {
//we are not doing well - clear the buffer
workerThread.getQueue().clear();
}
workerThread.getQueue().add(groupedRawData);
}
return null;
}
@Override
public void closeModel() {
//very important to prevent memory leak for long term processing.
if (genericModelWorker.getModel()!=null && genericModelWorker.getModel().getModel()!=null) {
//System.out.println("CLOSE GENERNIC MODEL");
genericModelWorker.getModel().getModel().close();
}
}
@Override
public int getNumClasses() {
@ -286,10 +181,6 @@ public class GenericDLClassifier implements DLClassiferModel, PamSettings {
}
public void newModelSelected(File file) {
// TODO Auto-generated method stub
}
@Override
@ -297,25 +188,7 @@ public class GenericDLClassifier implements DLClassiferModel, PamSettings {
return genericModelWorker.getModel()!=null;
}
/**
* The task thread.
* @author Jamie Macaulay
*
*/
public class GenericTaskThread extends DLTaskThread {
GenericTaskThread(DLModelWorker soundSpotWorker) {
super(soundSpotWorker);
}
@Override
public void newDLResult(GenericPrediction modelResult, GroupedRawData groupedRawData) {
modelResult.setClassNameID(getClassNameIDs(genericModelParams));
modelResult.setBinaryClassification(isBinaryResult(modelResult, genericModelParams));
newResult(modelResult, groupedRawData);
}
}
/**
* Send a new result form the thread queue to the process.
@ -362,44 +235,23 @@ public class GenericDLClassifier implements DLClassiferModel, PamSettings {
return checkSettingsOK(genericModelParams, dlControl);
}
public static ArrayList<PamWarning> checkSettingsOK(StandardModelParams genericModelParams, DLControl dlControl) {
//TODO - check if model is null.
//check that classifier is selected if continous files.
//
ArrayList<PamWarning> warnings = new ArrayList<PamWarning>();
File file = new File(genericModelParams.modelPath);
if (genericModelParams.modelPath == null || !file.isFile()) {
warnings.add(new PamWarning("Generic classifier", "There is no model loaded - the classifier will not run", 2));
//if no model then this is ionly message needed for the generic classifier.
return warnings;
}
//if continous data is selected and all classes are false then this is a potential mistake...
if (dlControl.getSettingsPane().getSelectedParentDataBlock().getUnitClass() == RawDataUnit.class
&& PamArrayUtils.isAllFalse(genericModelParams.binaryClassification)) {
warnings.add(new PamWarning("Generic classifier", "There are no prediction classes selected for classification. "
+ "Predicitons for each segment will be saved but there will be no detections generated", 1));
}
return warnings;
}
@Override
public boolean isModelType(URI uri) {
// TODO Auto-generated method stub
return false;
return super.isModelExtensions(uri);
}
@Override
public DLModelWorker<GenericPrediction> getDLWorker() {
return this.genericModelWorker;
}
@Override
public void setModel(URI model) {
// TODO Auto-generated method stub
public StandardModelParams getDLParams() {
return this.genericModelParams;
}

View File

@ -1,9 +1,6 @@
package rawDeepLearningClassifier.dlClassification.genericModel;
import java.io.File;
import java.util.ArrayList;
import javafx.stage.FileChooser.ExtensionFilter;
import pamViewFX.fxNodes.PamButton;
import rawDeepLearningClassifier.dlClassification.animalSpot.StandardModelPane;
import rawDeepLearningClassifier.dlClassification.animalSpot.StandardModelParams;
@ -17,10 +14,6 @@ import rawDeepLearningClassifier.dlClassification.animalSpot.StandardModelParams
*/
public class GenericModelPane extends StandardModelPane {
/**
* The extension filter for sound spot models.
*/
private ArrayList<ExtensionFilter> extensionFilters;
private GenericAdvPane advPane;
@ -35,13 +28,6 @@ public class GenericModelPane extends StandardModelPane {
this.genericDLClassifier = genericDLClassifier;
//must add an additional import settings button.
extensionFilters = new ArrayList<ExtensionFilter>();
//import the settings holder
extensionFilters.add(new ExtensionFilter("TensorFlow Model", "*.pb"));
extensionFilters.add(new ExtensionFilter("Pytorch Model", "*.pk"));
//this.getVBoxHolder().getChildren().add(2, new Label("Classifier Settings"));
usedefaultSeg.setDisable(true);
defaultSegBox.setVisible(false);
@ -91,7 +77,6 @@ public class GenericModelPane extends StandardModelPane {
@Override
public void newModelSelected(File file) {
this.setCurrentSelectedFile(file);
this.genericDLClassifier.newModelSelected(file);
//this.setParamsClone(new GenericModelParams());
//prep the model with current parameters;
@ -105,10 +90,6 @@ public class GenericModelPane extends StandardModelPane {
}
@Override
public ArrayList<ExtensionFilter> getExtensionFilters() {
return extensionFilters;
}
}

View File

@ -1,5 +1,8 @@
package rawDeepLearningClassifier.dlClassification.genericModel;
import java.util.ArrayList;
import java.util.List;
import javax.swing.JPanel;
import PamController.SettingsPane;
@ -7,6 +10,7 @@ import PamView.dialog.warn.WarnOnce;
import PamView.dialog.warn.WarnOnceDialog;
import javafx.application.Platform;
import javafx.scene.control.Alert.AlertType;
import javafx.stage.FileChooser.ExtensionFilter;
import rawDeepLearningClassifier.dlClassification.animalSpot.StandardModelParams;
import rawDeepLearningClassifier.layoutFX.DLCLassiferModelUI;
@ -28,6 +32,11 @@ public class GenericModelUI implements DLCLassiferModelUI {
* The sound spot classifier.
*/
private GenericDLClassifier genericModelClassifier;
/**
* The extension filter for sound spot models.
*/
private ArrayList<ExtensionFilter> extensionFilters;
/**
* SondSpot classifier.
@ -35,6 +44,13 @@ public class GenericModelUI implements DLCLassiferModelUI {
*/
public GenericModelUI(GenericDLClassifier soundSpotClassifier) {
this.genericModelClassifier=soundSpotClassifier;
//must add an additional import settings button.
extensionFilters = new ArrayList<ExtensionFilter>();
//import the settings holder
extensionFilters.add(new ExtensionFilter("TensorFlow Model", "*.pb"));
extensionFilters.add(new ExtensionFilter("Pytorch Model", "*.pk"));
}
@Override
@ -68,4 +84,8 @@ public class GenericModelUI implements DLCLassiferModelUI {
}
@Override
public List<ExtensionFilter> getModelFileExtensions() {
return extensionFilters;
}
}

View File

@ -78,9 +78,13 @@ public class GenericModelWorker extends DLModelWorker<GenericPrediction> {
// }
//first open the model and get the correct parameters.
//21/11/2022 - Added a null and filename check here to stop the mdoel reloading everytime PAMGuard hits a new file or
//21/11/2022 - Added a null and filename check here to stop the model reloading everytime PAMGuard hits a new file or
//is stopped or started - this was causing a memory leak.
if (genericModel==null || !Paths.get(genericModel.getModel().getName()).equals(Paths.get(genericParams.modelPath))) {
if (genericModel!=null && genericModel.getModel()!=null) {
genericModel.getModel().close();
}
//System.out.println(Paths.get(genericParams.modelPath));
genericModel = new PamGenericModel(genericParams.modelPath);
@ -88,6 +92,8 @@ public class GenericModelWorker extends DLModelWorker<GenericPrediction> {
//System.out.println(genericModel.getModel().getModelPath().getFileName());
}
setModelTransforms(genericParams.dlTransfroms);
//is this a waveform or a spectrogram model?
DLTransform transform = genericParams.dlTransfroms.get(genericParams.dlTransfroms.size()-1);
if (transform instanceof FreqTransform) {
@ -140,7 +146,10 @@ public class GenericModelWorker extends DLModelWorker<GenericPrediction> {
@Override
public void closeModel() {
genericModel.getModel().close();
// can be very important to prevent memory leak for long term processing.
if (genericModel!=null && genericModel.getModel()!=null) {
genericModel.getModel().close();
}
}
/**
@ -151,4 +160,9 @@ public class GenericModelWorker extends DLModelWorker<GenericPrediction> {
return genericModel;
}
@Override
public boolean isModelNull() {
return genericModel==null;
}
}

View File

@ -9,20 +9,13 @@ import org.jamdev.jdl4pam.transforms.DLTransfromParams;
import PamController.PamControlledUnitSettings;
import PamController.PamSettingManager;
import PamController.PamSettings;
import PamUtils.PamCalendar;
import rawDeepLearningClassifier.DLControl;
import rawDeepLearningClassifier.dlClassification.DLClassName;
import rawDeepLearningClassifier.dlClassification.DLClassiferModel;
import rawDeepLearningClassifier.dlClassification.DLTaskThread;
import rawDeepLearningClassifier.dlClassification.PredictionResult;
import rawDeepLearningClassifier.dlClassification.StandardClassifierModel;
import rawDeepLearningClassifier.dlClassification.animalSpot.StandardModelParams;
import rawDeepLearningClassifier.dlClassification.genericModel.DLModelWorker;
import rawDeepLearningClassifier.dlClassification.genericModel.GenericDLClassifier;
import rawDeepLearningClassifier.dlClassification.genericModel.GenericPrediction;
import rawDeepLearningClassifier.layoutFX.DLCLassiferModelUI;
import rawDeepLearningClassifier.segmenter.SegmenterProcess.GroupedRawData;
import warnings.PamWarning;
import warnings.WarningSystem;
/**
* Classifier which uses deep learning models from Meridian's Ketos framework.
@ -36,12 +29,9 @@ import warnings.WarningSystem;
* @author Jamie Macaulay
*
*/
public class KetosClassifier implements DLClassiferModel, PamSettings {
public class KetosClassifier extends StandardClassifierModel {
/**
* Reference to the DL contro..
*/
private DLControl dlControl;
/**
* Paramters for a Ketos classifier.
@ -59,27 +49,12 @@ public class KetosClassifier implements DLClassiferModel, PamSettings {
*/
private KetosWorker ketosWorker;
/**
* True to force the classifier to use a queue - can be used for simulating real time operation.
*/
private boolean forceQueue = false;
/**
* Sound spot warning.
*/
PamWarning ketosWarning = new PamWarning("Ketos_Classifier", "",2);
/**
* The Ketos worker thread has a buffer so that Ketos models can be run
* in real time without dslowing down the rest of PAMGaurd.
*/
private KetosThread workerThread;
/**
* The ketos classifier.
*/
public KetosClassifier(DLControl dlControl) {
super(dlControl);
this.dlControl=dlControl;
this.ketosDLParams = new KetosDLParams();
this.ketosUI= new KetosUI(this);
@ -88,115 +63,6 @@ public class KetosClassifier implements DLClassiferModel, PamSettings {
}
@Override
public ArrayList<? extends PredictionResult> runModel(ArrayList<GroupedRawData> groupedRawData) {
if (getKetosWorker().getModel()==null) return null;
// System.out.println("SoundSpotClassifier: PamCalendar.isSoundFile(): "
// + PamCalendar.isSoundFile() + " " + (PamCalendar.isSoundFile() && !forceQueue));
/**
* If a sound file is being analysed then Ketos can go as slow as it wants. if used in real time
* then there is a buffer with a maximum queue size.
*/
if ((PamCalendar.isSoundFile() && !forceQueue) || dlControl.isViewer()) {
//run the model
ArrayList<KetosResult> modelResult = getKetosWorker().runModel(groupedRawData,
groupedRawData.get(0).getParentDataBlock().getSampleRate(), 0);
if (modelResult==null) {
ketosWarning.setWarningMessage("Generic deep learning model returned null");
WarningSystem.getWarningSystem().addWarning(ketosWarning);
return null;
}
for (int i =0; i<modelResult.size(); i++) {
modelResult.get(i).setClassNameID(GenericDLClassifier.getClassNameIDs(ketosDLParams));
modelResult.get(i).setBinaryClassification(GenericDLClassifier.isBinaryResult(modelResult.get(i), ketosDLParams));
modelResult.get(i).setTimeMillis(groupedRawData.get(i).getTimeMilliseconds());
}
return modelResult; //returns to the classifier.
}
else {
//add to a buffer if in real time.
if (workerThread.getQueue().size()>DLModelWorker.MAX_QUEUE_SIZE) {
//we are not doing well - clear the buffer
workerThread.getQueue().clear();
}
workerThread.getQueue().add(groupedRawData);
}
return null;
}
@Override
public void prepModel() {
//System.out.println("PrepModel! !!!");
getKetosWorker().prepModel(ketosDLParams, dlControl);
if (!ketosDLParams.useDefaultTransfroms) {
//set custom transforms in the model.
getKetosWorker().setModelTransforms(ketosDLParams.dlTransfroms);
}
if (getKetosWorker().getModel()==null) {
ketosWarning.setWarningMessage("There is no loaded classifier model. Ketos disabled.");
WarningSystem.getWarningSystem().addWarning(ketosWarning);
}
if ((!PamCalendar.isSoundFile() || forceQueue) && !dlControl.isViewer()) {
//for real time only
if (workerThread!=null) {
workerThread.stopTaskThread();
}
workerThread = new KetosThread(getKetosWorker());
workerThread.setPriority(Thread.MAX_PRIORITY);
workerThread.start();
}
}
/**
* The task thread to run Ketos classifier in real time.
*
* @author Jamie Macaulay
*
*/
public class KetosThread extends DLTaskThread {
KetosThread(DLModelWorker soundSpotWorker) {
super(soundSpotWorker);
}
@Override
public void newDLResult(GenericPrediction soundSpotResult, GroupedRawData groupedRawData) {
soundSpotResult.setClassNameID(GenericDLClassifier.getClassNameIDs(ketosDLParams));
soundSpotResult.setBinaryClassification(GenericDLClassifier.isBinaryResult(soundSpotResult, ketosDLParams));
newResult(soundSpotResult, groupedRawData);
}
}
/**
* Send a new result form the thread queue to the process.
* @param modelResult - the model result;
* @param groupedRawData - the grouped raw data.
*/
protected void newResult(GenericPrediction modelResult, GroupedRawData groupedRawData) {
this.dlControl.getDLClassifyProcess().newModelResult(modelResult, groupedRawData);
}
@Override
public void closeModel() {
// TODO Auto-generated method stub
}
@Override
public String getName() {
return "Ketos";
@ -206,33 +72,19 @@ public class KetosClassifier implements DLClassiferModel, PamSettings {
public DLCLassiferModelUI getModelUI() {
return this.ketosUI;
}
@Override
public Serializable getDLModelSettings() {
public DLModelWorker<GenericPrediction> getDLWorker() {
return getKetosWorker();
}
@Override
public StandardModelParams getDLParams() {
return ketosDLParams;
}
@Override
public int getNumClasses() {
return ketosDLParams.numClasses;
}
@Override
public DLClassName[] getClassNames() {
//System.out.println("Ketos Model: " + ketosDLParams.numClasses);
return ketosDLParams.classNames;
}
@Override
public DLControl getDLControl() {
return dlControl;
}
@Override
public boolean checkModelOK() {
return getKetosWorker().getModel()!=null;
}
/**
* Get the parameters for the Ketos classifier.
* @param ketosDLParams - the Ketos parameters.
@ -249,6 +101,11 @@ public class KetosClassifier implements DLClassiferModel, PamSettings {
this.ketosDLParams = ketosDLParams;
}
@Override
public Serializable getDLModelSettings() {
return ketosDLParams;
}
/**
* Get the KetosWorker. this handles loading and running the Ketos model.
* @return the Ketos worker.
@ -309,23 +166,12 @@ public class KetosClassifier implements DLClassiferModel, PamSettings {
else ketosDLParams = new KetosDLParams();
return true;
}
@Override
public ArrayList<PamWarning> checkSettingsOK() {
return GenericDLClassifier.checkSettingsOK(ketosDLParams, dlControl);
}
@Override
public boolean isModelType(URI uri) {
// TODO Auto-generated method stub
return false;
}
@Override
public void setModel(URI model) {
ketosDLParams.modelPath = model.getPath();
//Ketos is easy because there are not many files with a .ktpb extension.
return super.isModelExtensions(uri);
}
}

View File

@ -24,15 +24,9 @@ public class KetosModelPane extends StandardModelPane {
super(soundSpotClassifier);
this.ketosClassifier = soundSpotClassifier;
extensionFilters = new ArrayList<ExtensionFilter>();
//import the settings holder
extensionFilters.add(new ExtensionFilter("Ketos Model", "*.ktpb"));
}
@Override
public ArrayList<ExtensionFilter> getExtensionFilters() {
return extensionFilters;
}
@Override
public void newModelSelected(File file) {

View File

@ -30,20 +30,26 @@ public class KetosModelTest {
public static void main(String[] args) {
//test on a right whale.
//File file = new File("/Volumes/GoogleDrive/My Drive/PAMGuard_dev/Deep_Learning/Meridian/right_whales/for_pamguard/narw.ktpb");
// File file = new File("/Users/au671271/Library/CloudStorage/GoogleDrive-macster110@gmail.com/My Drive/PAMGuard_dev/Deep_Learning/Ketos/right_whales/for_pamguard/narw.ktpb");
// File file = new File("/Volumes/GoogleDrive-108005893101854397430/My Drive/PAMGuard_dev/Deep_Learning/Meridian/humpback_whales/SOCAL_Mn_Network.ktpb");
//File file = new File("/Volumes/GoogleDrive-108005893101854397430/My Drive/PAMGuard_dev/Deep_Learning/Meridian/orca/kw_detector_v11_5s.ktpb");
//the wav file to test.
//String wavFilePath = "/Volumes/GoogleDrive/My Drive/PAMGuard_dev/Deep_Learning/Meridian/right_whales/for_pamguard/input.wav";
// String wavFilePath = "/Users/au671271/Library/CloudStorage/GoogleDrive-macster110@gmail.com/My Drive/PAMGuard_dev/Deep_Learning/Ketos/right_whales/for_pamguard/input.wav";
// String wavFilePath = "/Volumes/GoogleDrive-108005893101854397430/My Drive/PAMGuard_dev/Deep_Learning/Meridian/humpback_whales/wav/5353.210403161502.wav";
// double windowSize = 3.52;
// double[] window = new double[]{0., 3.0};
// //Minke model
// File file = new File("/Users/au671271/Desktop/Minke_test/Minke_Network_12s.ktpb");
// String wavFilePath = "/Users/au671271/Desktop/Minke_test/1705_FLAC_1705_20171106_185953_253.wav";
// double windowSize = 12;
//Minke model
File file = new File("/Users/au671271/Desktop/Minke_test/Minke_Network_12s.ktpb");
String wavFilePath = "/Users/au671271/Desktop/Minke_test/1705_FLAC_1705_20171106_185953_253.wav";
double windowSize = 12;
//
File file = new File("/Users/au671271/Library/CloudStorage/GoogleDrive-macster110@gmail.com/My Drive/PAMGuard_dev/Deep_Learning/Ketos/narw_2/hallo-kw-det_v1_test/hallo-kw-det_v1.ktpb");
String wavFilePath = "/Users/au671271/Library/CloudStorage/GoogleDrive-macster110@gmail.com/My Drive/PAMGuard_dev/Deep_Learning/Ketos/narw_2/hallo-kw-det_v1_test/audio/jasco_reduced.wav";
// double[] window = new double[]{10., 15.0176};
double[] window = new double[]{45, 50.0176};
try {
//the ketos model.
@ -58,22 +64,38 @@ public class KetosModelTest {
//System.out.println(ketosParams.toString());
System.out.println("Output shape" + ketosParams.defaultOutputShape);
System.out.println("Input shape" + ketosParams.defaultInputShape);
//28-04-2023 seems like there is a BUG in ketos where the input shape reported by the model is incorrect.
ketosModel.setInputShape(ketosParams.defaultInputShape);
//Open wav files.
AudioData soundData = DLUtils.loadWavFile(wavFilePath);
soundData = soundData.trim(0, (int) (soundData.getSampleRate()*windowSize));
soundData = soundData.trim((int) (soundData.getSampleRate()*window[0]), (int) (soundData.getSampleRate()*window[1]));
System.out.println("Input sample rate is " + soundData.getSampleRate());
//generate the transforms.
ArrayList<DLTransform> transforms = DLTransformsFactory.makeDLTransforms(ketosParams.dlTransforms);
((WaveTransform) transforms.get(0)).setWaveData(soundData);
DLTransform transform = transforms.get(0);
for (int i=0; i<transforms.size(); i++) {
// System.out.println(transforms);
for (int i=0; i<ketosParams.dlTransforms.size(); i++) {
// try {
// System.out.println("Transform: " + ketosParams.dlTransforms.get(i));
// }
// catch(Exception e) {
// e.printStackTrace();
// }
transform = transforms.get(i).transformData(transform);
// if (i==1) {
// transfromedData = DLMatFile.array2Matrix(((FreqTransform) transform).getSpecTransfrom().getTransformedData());
// }

View File

@ -1,8 +1,12 @@
package rawDeepLearningClassifier.dlClassification.ketos;
import java.util.ArrayList;
import java.util.List;
import javax.swing.JPanel;
import PamController.SettingsPane;
import javafx.stage.FileChooser.ExtensionFilter;
import rawDeepLearningClassifier.dlClassification.animalSpot.StandardModelParams;
import rawDeepLearningClassifier.layoutFX.DLCLassiferModelUI;
@ -23,6 +27,8 @@ public class KetosUI implements DLCLassiferModelUI {
* The sound spot classifier.
*/
private KetosClassifier ketosClassifier;
private ArrayList<ExtensionFilter> extensionFilters;
/**
* SondSpot classifier.
@ -30,6 +36,10 @@ public class KetosUI implements DLCLassiferModelUI {
*/
public KetosUI(KetosClassifier ketosClassifier) {
this.ketosClassifier=ketosClassifier;
extensionFilters = new ArrayList<ExtensionFilter>();
//import the settings holder
extensionFilters.add(new ExtensionFilter("Ketos Model", "*.ktpb"));
}
@Override
@ -44,7 +54,6 @@ public class KetosUI implements DLCLassiferModelUI {
@Override
public void getParams() {
KetosDLParams ketosParams = (KetosDLParams) getSettingsPane().getParams(ketosClassifier.getKetosParams());
System.out.println("KetosUI PARMAS: " + ketosParams.dlTransfromParams);
ketosClassifier.setKetosParams(ketosParams);
}
@ -62,4 +71,11 @@ public class KetosUI implements DLCLassiferModelUI {
return null;
}
@Override
public List<ExtensionFilter> getModelFileExtensions() {
return extensionFilters;
}
}

View File

@ -11,12 +11,14 @@ import org.jamdev.jdl4pam.transforms.DLTransform.DLTransformType;
import org.jamdev.jdl4pam.transforms.DLTransformsFactory;
import org.jamdev.jdl4pam.transforms.jsonfile.DLTransformsParser;
import PamModel.PamModel;
import PamModel.PamModel.PluginClassloader;
import PamView.dialog.PamDialog;
import PamView.dialog.warn.WarnOnce;
import rawDeepLearningClassifier.DLControl;
import rawDeepLearningClassifier.dlClassification.animalSpot.StandardModelParams;
import rawDeepLearningClassifier.dlClassification.genericModel.DLModelWorker;
import rawDeepLearningClassifier.dlClassification.genericModel.PamGenericModel;
import rawDeepLearningClassifier.dlClassification.genericModel.GenericPrediction;
import ai.djl.ndarray.types.Shape;
/**
*
@ -26,7 +28,7 @@ import rawDeepLearningClassifier.dlClassification.genericModel.PamGenericModel;
* @author Jamie Macaulay
*
*/
public class KetosWorker extends DLModelWorker<KetosResult> {
public class KetosWorker extends DLModelWorker<GenericPrediction> {
/**
@ -66,6 +68,12 @@ public class KetosWorker extends DLModelWorker<KetosResult> {
//21/11/2022 - Added a null and filename check here to stop the mdoel reloading everytime PAMGuard hits a new file or
//is stopped or started - this was causing a memory leak.
if (ketosModel==null || currentPath ==null || !Paths.get(currentPath).equals(Paths.get(ketosDLParams.modelPath))) {
//TODO
// if (ketosModel!=null && ketosModel.getModel()!=null) {
// ketosModel.getModel().close();
// }
//System.out.println(Paths.get(genericParams.modelPath));
this.currentPath = ketosDLParams.modelPath;
@ -93,8 +101,13 @@ public class KetosWorker extends DLModelWorker<KetosResult> {
ketosParams.defaultOutputShape = ketosModel.getOutShape();
}
//HACK there seems to be some sort of bug in ketos where the params input shape is correct but the model input shape is wrong.
if (ketosModel.getInputShape()==null || !ketosModel.getInputShape().equals(ketosParams.defaultInputShape)) {
WarnOnce.showWarning("Model shape", "The model shape does not match the model metadata. \n Metadata shape will be used used.", WarnOnce.OK_OPTION);
ketosModel.setInputShape(ketosParams.defaultInputShape);
}
///HACK here for now to fix an issue with dB and Ketos transforms having zero length somehow...
for (int i=0; i<ketosParams.dlTransforms.size(); i++) {
if (ketosParams.dlTransforms.get(i).dltransfromType == DLTransformType.SPEC2DB) {
@ -114,14 +127,14 @@ public class KetosWorker extends DLModelWorker<KetosResult> {
//only load new transforms if defaults are selected
if (getModelTransforms()==null || ketosDLParams.dlTransfroms==null || ketosDLParams.useDefaultTransfroms) {
System.out.println(" " + transforms);
System.out.println("SET MODEL TRANSFORMS: " + ketosDLParams.dlTransfroms + " " + ketosDLParams.useDefaultTransfroms);
//System.out.println("SET MODEL TRANSFORMS: " + ketosDLParams.dlTransfroms + " " + ketosDLParams.useDefaultTransfroms);
//only set the transforms if they are null - otherwise handled elsewhere.
setModelTransforms(transforms);
ketosDLParams.useDefaultTransfroms = true;
}
else {
System.out.println("SET CURRENT TRANSFORMS: " + ketosDLParams.dlTransfroms + " " + ketosDLParams.useDefaultTransfroms);
//System.out.println("SET CURRENT TRANSFORMS: " + ketosDLParams.dlTransfroms + " " + ketosDLParams.useDefaultTransfroms);
//use the old transforms.
setModelTransforms(ketosDLParams.dlTransfroms);
}
@ -129,8 +142,12 @@ public class KetosWorker extends DLModelWorker<KetosResult> {
//ketosDLParams.dlTransfroms = transforms; //this is done after prep model in the settings pane.
ketosDLParams.defaultSegmentLen = ketosParams.seglen*1000.; //the segment length in microseconds.
//ketosParams.classNames = new String[] {"Noise", "Right Whale"}; // FIXME;
ketosDLParams.numClasses = (int) ketosModel.getOutShape().get(1);
ketosDLParams.numClasses = (int) ketosModel.getOutShape().get(1);
/*****
//ok 0 the other values are not user selectable but this is. If we relaod the same model we probably want to keep it....
//So this is a little bt of a hack but will probably be OK in most cases.
if (ketosDLParams.binaryClassification==null || ketosDLParams.binaryClassification.length!=ketosDLParams.numClasses) {
@ -186,6 +203,7 @@ public class KetosWorker extends DLModelWorker<KetosResult> {
*/
public void closeModel() {
//TODO
this.currentPath = null;
}
@ -197,5 +215,10 @@ public class KetosWorker extends DLModelWorker<KetosResult> {
return ketosModel;
}
@Override
public boolean isModelNull() {
return ketosModel==null;
}
}

View File

@ -1,8 +1,11 @@
package rawDeepLearningClassifier.dlClassification.orcaSpot;
import java.util.List;
import javax.swing.SwingUtilities;
import PamController.SettingsPane;
import javafx.stage.FileChooser.ExtensionFilter;
import rawDeepLearningClassifier.layoutFX.DLCLassiferModelUI;
/**
@ -70,4 +73,10 @@ public class OrcaSpotClassifierUI implements DLCLassiferModelUI {
return orcaSpotSidePanel;
}
@Override
public List<ExtensionFilter> getModelFileExtensions() {
// TODO Auto-generated method stub
return null;
}
}

View File

@ -1,7 +1,9 @@
package rawDeepLearningClassifier.layoutFX;
import javax.swing.JPanel;
import java.util.List;
import javax.swing.JPanel;
import javafx.stage.FileChooser.ExtensionFilter;
import PamController.SettingsPane;
/**
@ -27,6 +29,13 @@ public interface DLCLassiferModelUI {
*/
public void setParams();
/**
* If using a file dialog to search for
* @return the file extensions (if any) for this type of classifier
*/
public List<ExtensionFilter> getModelFileExtensions();
/**
* Get a side panel specific to the classifier.
* @return the side panel.

View File

@ -7,6 +7,7 @@ import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import org.controlsfx.control.PopOver;
@ -42,7 +43,7 @@ import rawDeepLearningClassifier.dlClassification.DefaultModels.DefualtModel;
/**
* A pane which allows users to select a model.
*
* Models could be potnetially selected from
* Models could be potentially selected from
* 1) A file (implemented)
* 2) A URL (not implemented)
* 3) A default list of models (not implemented. )
@ -61,7 +62,7 @@ public class DLModelSelectPane extends PamBorderPane {
/**
* Currently selected file.
*/
private File currentSelectedFile = new File(System.getProperty("user.home"));
protected URI currentSelectedFile = new File(System.getProperty("user.home")).toURI();
/**
* The label showing the path to the file.
@ -87,7 +88,7 @@ public class DLModelSelectPane extends PamBorderPane {
/**
* The current classifier model.
*/
private DLClassiferModel currentClassifierModel;
DLClassiferModel currentClassifierModel;
/**
* The default models.
@ -99,11 +100,14 @@ public class DLModelSelectPane extends PamBorderPane {
*/
private PopOver urlPopOver;
private TextField uriTextField;
private TextField uriTextField;
private RawDLSettingsPane rawDLSettingsPane;
public DLModelSelectPane(DLControl dlControl) {
this.dlControl=dlControl;
public DLModelSelectPane(RawDLSettingsPane rawDLSettingsPane) {
this.rawDLSettingsPane=rawDLSettingsPane;
this.dlControl=rawDLSettingsPane.getDLControl();
this.setCenter(createDLSelectPane());
//the directory chooser.
fileChooser = new FileChooser();
@ -140,9 +144,9 @@ public class DLModelSelectPane extends PamBorderPane {
fileChooser.getExtensionFilters().addAll(getExtensionFilters());
Path path = currentSelectedFile.toPath();
Path path = Paths.get(currentSelectedFile);
if(path!=null && Files.exists(path, LinkOption.NOFOLLOW_LINKS)) {
fileChooser.setInitialDirectory(new File(currentSelectedFile.getParent()));
fileChooser.setInitialDirectory(new File(new File(currentSelectedFile).getParent()));
}
else {
fileChooser.setInitialDirectory(new File(System.getProperty("user.home")));
@ -154,40 +158,12 @@ public class DLModelSelectPane extends PamBorderPane {
return;
}
modelLoadIndicator.setVisible(true);
pathLabel.setText("Loading model...");
loadNewModel(file.toURI());
// separate non-FX thread - load the model
//on a separate thread so we can show a moving load
//bar on the FX thread. Otherwise the GUI locks up
//whilst stuff is loaded.
new Thread() {
// runnable for that thread
public void run() {
try {
newModelSelected(file.toURI());
Thread.sleep(5000);
} catch (InterruptedException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
Platform.runLater(new Runnable() {
public void run() {
modelLoadIndicator.setVisible(false);
updatePathLabel();
}
});
}
}.start();
});
PamVBox urlBox = new PamVBox();
urlBox.setPadding(new Insets(5,5,5,5));
urlBox.setSpacing(5);
@ -252,30 +228,89 @@ public class DLModelSelectPane extends PamBorderPane {
return hBox;
}
/**
* Load a new model on a seperate thread.
* @param uri - the uri to the model.
*/
public void loadNewModel(URI uri) {
// separate non-FX thread - load the model
//on a separate thread so we can show a moving load
//bar on the FX thread. Otherwise the GUI locks up
//whilst stuff is loaded.
if (uri==null) return;
pathLabel.setText("Loading model...");
modelLoadIndicator.setVisible(true);
new Thread() {
// runnable for that thread
public void run() {
try {
newModelSelected(uri);
currentSelectedFile = uri;
Thread.sleep(1000); //just show the user something happened if model loading is rapid.
} catch (InterruptedException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
Platform.runLater(new Runnable() {
public void run() {
try {
rawDLSettingsPane.setClassifierPane();
modelLoadIndicator.setVisible(false);
updatePathLabel();
}
catch (Exception e) {
e.printStackTrace();
}
}
});
}
}.start();
}
/**
* A new model has been selected
* @param
*/
private void newModelSelected(URI file) {
this.currentClassifierModel = this.dlControl.getDlClassifierChooser().selectClassiferModel(file);
if (currentClassifierModel==null) {
currentClassifierModel.setModel(file);
if (file == null) {
return;
}
this.currentClassifierModel = this.dlControl.getDlClassifierChooser().selectClassiferModel(file);
System.out.println("New classifier model selected!: " + currentClassifierModel);
if (currentClassifierModel!=null) {
currentClassifierModel.setModel(file);
currentClassifierModel.prepModel();
}
}
/**
* Update the path label and tool tip text;
*/
private void updatePathLabel() {
if (currentClassifierModel == null ? true : !currentClassifierModel.checkModelOK()) {
protected void updatePathLabel() {
//System.out.println("Update path label: " + currentClassifierModel.checkModelOK());
if (currentClassifierModel == null) {
pathLabel.setText("No classifier model loaded: Select model");
pathLabel.setTooltip(new Tooltip("Use the browse button/ URI botton to select a model or select a default model"));
}
else if (!currentClassifierModel.checkModelOK()) {
pathLabel.setText("The model could not be loaded?");
pathLabel.setTooltip(new Tooltip("Use the browse button/ URI botton to select a model or select a default model"));
}
else {
pathLabel .setText(this.currentSelectedFile.getName());
pathLabel .setText(new File(this.currentSelectedFile).getName());
try {
pathLabel.setTooltip(new Tooltip(this.currentSelectedFile.getPath()
+ "\n" +" Processor CPU " + Device.cpu() + " " + Device.gpu()));
@ -295,10 +330,23 @@ public class DLModelSelectPane extends PamBorderPane {
*/
public ArrayList<ExtensionFilter> getExtensionFilters() {
ArrayList<String> extensionFilters = new ArrayList<String>();
for (DLClassiferModel dlModel: dlControl.getDLModels()) {
if (dlModel.getModelUI()!=null) {
for (ExtensionFilter extFilter: dlModel.getModelUI().getModelFileExtensions()){
extensionFilters.addAll(extFilter.getExtensions());
}
}
}
//Now we don't really want lots of extension filters
ArrayList<ExtensionFilter> dlExtFilter = new ArrayList<ExtensionFilter>();
dlExtFilter.add(new ExtensionFilter("Deep Learning Models", extensionFilters));
return null;
return dlExtFilter ;
}

View File

@ -35,7 +35,7 @@ public class DLSidePanelSwing implements PamSidePanel {
*/
public void setupPanel() {
//System.out
if (dlControl.getDLModel().getModelUI()!=null && dlControl.getDLModel().getModelUI().getSidePanel()!=null) {
if (dlControl.getDLModel()!=null && dlControl.getDLModel().getModelUI()!=null && dlControl.getDLModel().getModelUI().getSidePanel()!=null) {
mainPanel.add(dlControl.getDLModel().getModelUI().getSidePanel(), BorderLayout.CENTER);
// mainPanel.add(new JLabel("Hello"), BorderLayout.WEST);
mainPanel.validate();
@ -44,6 +44,7 @@ public class DLSidePanelSwing implements PamSidePanel {
//blank
mainPanel.removeAll();
}
}
@Override

View File

@ -1,5 +1,6 @@
package rawDeepLearningClassifier.layoutFX;
import java.net.URI;
import java.util.ArrayList;
import org.controlsfx.control.PopOver;
@ -117,11 +118,10 @@ public class RawDLSettingsPane extends SettingsPane<RawDLParams>{
private Object flipPane;
private PopupControl advLabel;
private DLModelSelectPane modelSelectPane;
/**
* The currently loaded classifier model
*/
private DLClassiferModel classifierModel = null;
public RawDLSettingsPane(DLControl dlControl){
super(null);
@ -139,7 +139,7 @@ public class RawDLSettingsPane extends SettingsPane<RawDLParams>{
mainPane=new PamBorderPane();
mainPane.setCenter(createDLPane());
mainPane.setPadding(new Insets(5,5,5,5));
mainPane.setMinHeight(400);
mainPane.setMinHeight(430);
mainPane.setMaxWidth(MAX_WIDTH);
mainPane.setPrefWidth(MAX_WIDTH);
//this.getAdvPane().setMaxWidth(MAX_WIDTH);
@ -240,7 +240,7 @@ public class RawDLSettingsPane extends SettingsPane<RawDLParams>{
/**
* Pane which allows users to select a model type.
*/
DLModelSelectPane selectPane = new DLModelSelectPane(dlControl);
modelSelectPane = new DLModelSelectPane(this);
// //add the possible deep learning models.
// dlModelBox= new ComboBox<String>();
@ -264,7 +264,7 @@ public class RawDLSettingsPane extends SettingsPane<RawDLParams>{
classifierPane = new PamBorderPane();
vBox.getChildren().addAll(selectPane, classifierPane);
vBox.getChildren().addAll(modelSelectPane, classifierPane);
return vBox;
@ -370,12 +370,14 @@ public class RawDLSettingsPane extends SettingsPane<RawDLParams>{
/**
* Set the classifier pane.
*/
private void setClassifierPane() {
protected void setClassifierPane() {
//set the classifier Pane.class
System.out.println("Set CLASSIFIER PANE: " + modelSelectPane.currentClassifierModel);
if (classifierModel!=null && classifierModel.getModelUI()!=null) {
classifierPane.setCenter(classifierModel.getModelUI().getSettingsPane().getContentNode());
classifierModel.getModelUI().setParams();
if (modelSelectPane.currentClassifierModel!=null && modelSelectPane.currentClassifierModel.getModelUI()!=null) {
classifierPane.setCenter(modelSelectPane.currentClassifierModel.getModelUI().getSettingsPane().getContentNode());
modelSelectPane.currentClassifierModel.getModelUI().setParams();
}
else {
classifierPane.setCenter(null);
@ -410,6 +412,14 @@ public class RawDLSettingsPane extends SettingsPane<RawDLParams>{
currParams.rawSampleSize = windowLength.getValue();
currParams.sampleHop = hopLength.getValue();
currParams.maxMergeHops = reMergeSeg.getValue();
if (modelSelectPane.currentClassifierModel == null) {
currParams.modelSelection = -1;
}
else {
currParams.modelSelection = dlControl.getDLModels().indexOf((modelSelectPane.currentClassifierModel));
}
// //update any changes
@ -432,8 +442,18 @@ public class RawDLSettingsPane extends SettingsPane<RawDLParams>{
if (dlControl.getDataSelector()!=null) {
dlControl.getDataSelector().getDialogPaneFX().getParams(true);
}
//need to make sure we call get params for the current model when the oK button is pressed.
if (this.modelSelectPane.currentClassifierModel!=null) {
if (this.modelSelectPane.currentClassifierModel.getModelUI()!=null) {
this.modelSelectPane.currentClassifierModel.getModelUI().getParams();
}
}
currParams.modelURI = this.modelSelectPane.currentSelectedFile;
return currParams;
}
@ -473,7 +493,9 @@ public class RawDLSettingsPane extends SettingsPane<RawDLParams>{
dlControl.createDataSelector(sourcePane.getSource());
//set the classifier model.
classifierModel = dlControl.getDLModel();
if (currParams.modelURI !=null) {
modelSelectPane.currentClassifierModel = dlControl.getDLModel();
}
// dlModelBox.getSelectionModel().select(currParams.modelSelection);
@ -490,6 +512,14 @@ public class RawDLSettingsPane extends SettingsPane<RawDLParams>{
enableControls();
setSegInfoLabel();
//set up the model and the custom pane if necessary.
this.modelSelectPane.loadNewModel(currParams.modelURI);
//this.modelSelectPane.updatePathLabel();
this.setClassifierPane();
sourcePane.getChannelValidator().validate();
}
@ -519,6 +549,14 @@ public class RawDLSettingsPane extends SettingsPane<RawDLParams>{
return sourcePane.getSource();
}
/**
* Get the DLControl associated with the pane.
* @return a reference to the DLControl.
*/
public DLControl getDLControl() {
return dlControl;
}
}

View File

@ -216,7 +216,7 @@ public class SimpleTransformPane extends DLTransformPane {
//System.out.println("Transform type: " + simpleTransform.getDLTransformType() + " " + simpleTransform.getParams().length + " " + spinners.size());
for (int i=0; i<spinners.size(); i++) {
//spinners.get(i).getValueFactory().setValue(simpleTransform.getParams()[i] );
System.out.println("Set params: " + input.getDLTransformType() + " param val: " + simpleTransform.getParams()[i] + " " + (simpleTransform.getParams()[i] instanceof Float));
//System.out.println("Set params: " + input.getDLTransformType() + " param val: " + simpleTransform.getParams()[i] + " " + (simpleTransform.getParams()[i] instanceof Float));
if (simpleTransform.getParams()[i] instanceof Float || simpleTransform.getParams()[i] instanceof Double) {
//System.out.println("Double: simpleTransform.getParams()[i]: " + simpleTransform.getParams()[i] + " " +spinners.get(i).getValueFactory());
spinners.get(i).getValueFactory().setValue(simpleTransform.getParams()[i].doubleValue());

View File

@ -94,7 +94,7 @@ public class DLOfflineTask extends OfflineTask<PamDataUnit<?,?>>{
@Override
public void completeTask() {
//dlControl.setNotifyProcesses(false);
this.dlControl.getDLModel().closeModel();
//this.dlControl.getDLModel().closeModel();
dlControl.update(MTClassifierControl.PROCESSING_END);
}