mirror of
https://github.com/PAMGuard/PAMGuard.git
synced 2024-11-21 14:42:27 +00:00
Implement zip model for deep learning classifier.
Implement zip models which allow packaging of deep learning models with metadata\ Fix TDdsiplayFX spinner bug Bug fixes for deep learning module after testing Implement a validator to warn if users have selected all or no classes for binary classifcation.
This commit is contained in:
parent
98e7d55034
commit
45bf890a62
BIN
nullPamguardSettings_20171106_185953.psfx
Normal file
BIN
nullPamguardSettings_20171106_185953.psfx
Normal file
Binary file not shown.
17
pom.xml
17
pom.xml
@ -240,23 +240,6 @@
|
||||
<url>file://${project.basedir}/repo</url>
|
||||
</repository>
|
||||
|
||||
<!-- Repo for Jamie's code -->
|
||||
<!-- IMPORTANT: SOMEHOW THIS REPO INTERFERES WITH THE
|
||||
JAVAFX DEPENDENCIES. Don't understand how, but it
|
||||
does. If Maven is not downloading the JavaFX dependencies
|
||||
and is instead throwing errors, try commenting out this
|
||||
repository. Let Maven download the JavaFX files, and
|
||||
then uncomment this repo again.
|
||||
Same thing seems to be occurring if I try to update the postgresql library. -->
|
||||
<!-- <repository> -->
|
||||
<!-- <snapshots> -->
|
||||
<!-- <enabled>false</enabled> -->
|
||||
<!-- </snapshots> -->
|
||||
<!-- <id>central</id> -->
|
||||
<!-- <name>a0u0ltozdsehx-artifactory-primary-0-releases</name> -->
|
||||
<!-- <url>https://jmachund.jfrog.io/artifactory/jpam2</url> -->
|
||||
<!-- </repository> -->
|
||||
|
||||
<!-- Repo for netCDF -->
|
||||
<repository>
|
||||
<snapshots>
|
||||
|
@ -115,12 +115,12 @@ public class TDControlFX extends TDControl implements UserDisplayNodeFX {
|
||||
System.out.println("TDControldFX: parent datablock null");
|
||||
return dataBlocks;
|
||||
}
|
||||
|
||||
for (int i=0; i<tdDisplayController.getUserDisplayProcess().getNumMuiltiplexDataBlocks(); i++){
|
||||
dataBlock=this.tdDisplayController.getUserDisplayProcess().getMuiltiplexDataBlock(i);
|
||||
if (TDDataProviderRegisterFX.getInstance().findDataProvider(dataBlock)!=null){
|
||||
dataBlocks.add(dataBlock);
|
||||
}
|
||||
// System.out.println("TDControldFX: parent data block "+dataBlock.getDataName());
|
||||
}
|
||||
return dataBlocks;
|
||||
}
|
||||
|
@ -517,7 +517,7 @@ public class ClickPlotInfoFX extends TDDataInfoFX {
|
||||
|
||||
if (pamDataUnit.getSuperDetectionsCount()==0) return null;
|
||||
|
||||
ListIterator<PamDataUnit> iterator = this.getPamDataBlock().getListIterator(PamDataBlock.ITERATOR_END);
|
||||
ListIterator<PamDataUnit> iterator = this.getDataBlock().getListIterator(PamDataBlock.ITERATOR_END);
|
||||
PamDataUnit testDataUnit;
|
||||
|
||||
while (iterator.hasPrevious()) {
|
||||
|
@ -731,13 +731,6 @@ public abstract class TDDataInfoFX {
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the pamDataBlock
|
||||
*/
|
||||
protected PamDataBlock getPamDataBlock() {
|
||||
return pamDataBlock;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the tdDataProvider
|
||||
*/
|
||||
|
@ -80,6 +80,7 @@ public class TDDataProviderRegisterFX {
|
||||
*/
|
||||
synchronized public TDDataProviderFX findDataProvider(PamDataBlock pamDataBlock) {
|
||||
for (TDDataProviderFX aProvider:dataProviders) {
|
||||
// System.out.println("Data providers: " + aProvider.getName() + " | " + aProvider.getDataBlock() + " | " +pamDataBlock);
|
||||
if (aProvider.getDataBlock()==pamDataBlock) {
|
||||
return aProvider;
|
||||
}
|
||||
|
@ -294,7 +294,7 @@ public class GenericDataPlotInfo extends TDDataInfoFX {
|
||||
* do some which may be associated with other annotations ?
|
||||
*/
|
||||
protected void updateAvailability() {
|
||||
LocalisationInfo locInfo = getPamDataBlock().getLocalisationContents();
|
||||
LocalisationInfo locInfo = getDataBlock().getLocalisationContents();
|
||||
bearingScaleInfo.setAvailable(locInfo.hasLocContent(LocContents.HAS_BEARING));
|
||||
slantScaleInfo.setAvailable(locInfo.hasLocContent(LocContents.HAS_BEARING));
|
||||
}
|
||||
|
@ -23,8 +23,9 @@ import rawDeepLearningClassifier.dataPlotFX.LineInfo;
|
||||
|
||||
|
||||
/**
|
||||
* A Data Info which plots 1D line data i.e. usually used to plot continuous 1D data e.g. deep learning predictions,
|
||||
* Ishmael data, click trigger data.
|
||||
* A Data Info which plots 1D line data i.e. usually used to plot continuous 1D
|
||||
* data e.g. deep learning predictions, Ishmael data, click trigger data.
|
||||
*
|
||||
* @author Jamie Macaulay
|
||||
*
|
||||
*/
|
||||
@ -94,7 +95,7 @@ public abstract class GenericLinePlotInfo extends TDDataInfoFX {
|
||||
}
|
||||
|
||||
//use the center of the window for plotting
|
||||
double timeMillis = pamDataUnit.getTimeMilliseconds();
|
||||
double timeMillis = pamDataUnit.getTimeMilliseconds()+pamDataUnit.getDurationInMilliseconds()/2;
|
||||
double tC=tdProjector.getTimePix(timeMillis-scrollStart);
|
||||
|
||||
|
||||
|
@ -149,6 +149,12 @@ public class TDAcousticScroller extends AcousticScrollerFX implements PamSetting
|
||||
|
||||
}
|
||||
|
||||
private void setSpinnerValue(long millis) {
|
||||
spinnerCall = true;
|
||||
spinner.getValueFactory().setValue(millis);
|
||||
spinnerCall = false;
|
||||
}
|
||||
|
||||
boolean spinnerCall = false;
|
||||
|
||||
/**
|
||||
@ -177,9 +183,7 @@ public class TDAcousticScroller extends AcousticScrollerFX implements PamSetting
|
||||
|
||||
//add a listener so that the spinner changes of the visible amount chnages
|
||||
super.getScrollBarPane().visibleAmountProperty().addListener((obsVal, oldVal, newVal)->{
|
||||
spinnerCall = true;
|
||||
spinner.getValueFactory().setValue(newVal.longValue());
|
||||
spinnerCall = false;
|
||||
setSpinnerValue(newVal.longValue());
|
||||
});
|
||||
|
||||
//add a listener so the visible amount changes of the spinner changes value.
|
||||
@ -215,6 +219,13 @@ public class TDAcousticScroller extends AcousticScrollerFX implements PamSetting
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void setVisibleMillis(long visibleAmount) {
|
||||
super.setVisibleMillis(visibleAmount);
|
||||
setSpinnerValue(visibleAmount);
|
||||
}
|
||||
|
||||
|
||||
|
||||
private void createHolderPane() {
|
||||
|
||||
@ -624,6 +635,7 @@ public class TDAcousticScroller extends AcousticScrollerFX implements PamSetting
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public String getUnitName() {
|
||||
// TODO Auto-generated method stub
|
||||
|
@ -12,6 +12,7 @@ import javafx.scene.shape.Shape;
|
||||
|
||||
/**
|
||||
* Line which connects a plug or socket to a ConnectionRectangle or other ConnectionLine.
|
||||
*
|
||||
* @author Jamie Macaulay
|
||||
*
|
||||
*/
|
||||
|
@ -275,7 +275,7 @@ public class DLControl extends PamControlledUnit implements PamSettings {
|
||||
TDDataProviderRegisterFX.getInstance()
|
||||
.registerDataInfo(new DLDetectionPlotProvider(this, dlClassifyProcess.getDLDetectionDatablock()));
|
||||
TDDataProviderRegisterFX.getInstance()
|
||||
.registerDataInfo(new DLPredictionProvider(this, dlClassifyProcess.getDLDetectionDatablock()));
|
||||
.registerDataInfo(new DLPredictionProvider(this, dlClassifyProcess.getDLPredictionDataBlock()));
|
||||
|
||||
// register the DD display
|
||||
DDPlotRegister.getInstance()
|
||||
|
@ -94,7 +94,7 @@ public class DLPredictionPlotInfoFX extends GenericLinePlotInfo {
|
||||
if (getDlControl().getDLModel()!=null) {
|
||||
DLClassName[] classNames = getDlControl().getDLModel().getClassNames();
|
||||
|
||||
System.out.println("Class names are: !!! " + (classNames == null ? "null" : classNames.length));
|
||||
// System.out.println("Class names are: !!! " + (classNames == null ? "null" : classNames.length));
|
||||
|
||||
if (classNames!=null) {
|
||||
|
||||
@ -358,7 +358,7 @@ public class DLPredictionPlotInfoFX extends GenericLinePlotInfo {
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import dataPlotsFX.data.TDDataProviderFX;
|
||||
import dataPlotsFX.layout.TDGraphFX;
|
||||
import rawDeepLearningClassifier.DLControl;
|
||||
import rawDeepLearningClassifier.dlClassification.DLDetectionDataBlock;
|
||||
import rawDeepLearningClassifier.dlClassification.DLModelDataBlock;
|
||||
|
||||
/**
|
||||
* The DL detection plot provider.
|
||||
@ -22,10 +23,10 @@ public class DLPredictionProvider extends TDDataProviderFX {
|
||||
/**
|
||||
* The DL detection plot provider.
|
||||
* @param dlControl - reference to DL control.
|
||||
* @param dlDetectionDataBlock - the dl detection data block.
|
||||
* @param dlModelDataBlock - the dl detection data block.
|
||||
*/
|
||||
public DLPredictionProvider(DLControl dlControl, DLDetectionDataBlock dlDetectionDataBlock) {
|
||||
super(dlDetectionDataBlock);
|
||||
public DLPredictionProvider(DLControl dlControl, DLModelDataBlock dlModelDataBlock) {
|
||||
super(dlModelDataBlock);
|
||||
this.dlControl = dlControl;
|
||||
}
|
||||
|
||||
|
@ -2,12 +2,16 @@ package rawDeepLearningClassifier.dlClassification.animalSpot;
|
||||
|
||||
import java.io.File;
|
||||
import org.controlsfx.control.CheckComboBox;
|
||||
import org.controlsfx.control.IndexedCheckModel;
|
||||
import org.controlsfx.control.PopOver;
|
||||
import org.controlsfx.control.ToggleSwitch;
|
||||
|
||||
import PamController.SettingsPane;
|
||||
import PamUtils.PamArrayUtils;
|
||||
import PamView.dialog.warn.WarnOnce;
|
||||
import javafx.beans.property.SimpleIntegerProperty;
|
||||
import javafx.beans.value.ObservableValue;
|
||||
import javafx.collections.ListChangeListener.Change;
|
||||
import javafx.geometry.Insets;
|
||||
import javafx.geometry.Pos;
|
||||
import javafx.scene.Node;
|
||||
@ -25,6 +29,7 @@ import pamViewFX.fxNodes.PamGridPane;
|
||||
import pamViewFX.fxNodes.PamHBox;
|
||||
import pamViewFX.fxNodes.PamSpinner;
|
||||
import pamViewFX.fxNodes.PamVBox;
|
||||
import pamViewFX.validator.PamValidator;
|
||||
import rawDeepLearningClassifier.dlClassification.DLClassiferModel;
|
||||
|
||||
/**
|
||||
@ -102,6 +107,8 @@ public abstract class StandardModelPane extends SettingsPane<StandardModelParams
|
||||
*/
|
||||
private ProgressIndicator modelLoadIndicator;
|
||||
|
||||
PamValidator validator;
|
||||
|
||||
|
||||
|
||||
public StandardModelPane(DLClassiferModel soundSpotClassifier) {
|
||||
@ -191,10 +198,36 @@ public abstract class StandardModelPane extends SettingsPane<StandardModelParams
|
||||
gridPane.add(new Label(""), 2, 0);
|
||||
speciesIDBox = new CheckComboBox<String>();
|
||||
gridPane.add(speciesIDBox, 3, 0);
|
||||
speciesIDBox.setMaxWidth(100);
|
||||
speciesIDBox.setPrefWidth(100);
|
||||
//speciesIDBox.setMaxWidth(100);
|
||||
// speciesIDBox.setPrefWidth(100);
|
||||
speciesIDBox.prefHeightProperty().bind(detectionSpinner.heightProperty());
|
||||
|
||||
validator = new PamValidator();
|
||||
|
||||
final SimpleIntegerProperty checkItemsCount = new SimpleIntegerProperty();
|
||||
|
||||
speciesIDBox.getCheckModel().getCheckedItems().addListener((Change<? extends String> c)->{
|
||||
checkItemsCount.set(speciesIDBox.getCheckModel().getCheckedItems().size());
|
||||
});
|
||||
|
||||
|
||||
validator.createCheck()
|
||||
.dependsOn("species_box",checkItemsCount)
|
||||
.withMethod(c -> {
|
||||
int nChecked = c.get("species_box");
|
||||
|
||||
if (nChecked==speciesIDBox.getItems().size()) {
|
||||
c.warn("All output class are checked. If one of these classes is noise then PAMGuard will continually detect all sound data...");
|
||||
}
|
||||
|
||||
if (nChecked==0) {
|
||||
c.warn("No output classes are checked for binary classification. PAMGuard will save all prediction values but no detections will be generated");
|
||||
}
|
||||
})
|
||||
.decorates(speciesIDBox)
|
||||
.immediate();
|
||||
;
|
||||
|
||||
vBoxHolder = new PamVBox();
|
||||
vBoxHolder.setSpacing(5);
|
||||
vBoxHolder.getChildren().addAll(classiferInfoLabel, advSettings, classiferInfoLabel2, gridPane);
|
||||
|
@ -14,15 +14,18 @@ import org.jamdev.jdl4pam.ArchiveModel;
|
||||
import org.jamdev.jdl4pam.genericmodel.GenericModelParams;
|
||||
import org.jamdev.jdl4pam.transforms.DLTransform;
|
||||
import org.jamdev.jdl4pam.transforms.DLTransformsFactory;
|
||||
import org.jamdev.jdl4pam.transforms.FreqTransform;
|
||||
import org.jamdev.jdl4pam.transforms.jsonfile.DLTransformParser2;
|
||||
import org.jamdev.jdl4pam.transforms.jsonfile.DLTransformsParser;
|
||||
import org.json.JSONObject;
|
||||
|
||||
import PamUtils.PamArrayUtils;
|
||||
import PamView.dialog.warn.WarnOnce;
|
||||
import ai.djl.MalformedModelException;
|
||||
import rawDeepLearningClassifier.DLControl;
|
||||
import rawDeepLearningClassifier.dlClassification.animalSpot.StandardModelParams;
|
||||
import rawDeepLearningClassifier.dlClassification.genericModel.DLModelWorker;
|
||||
import rawDeepLearningClassifier.dlClassification.genericModel.GenericModelWorker;
|
||||
import rawDeepLearningClassifier.dlClassification.genericModel.GenericPrediction;
|
||||
|
||||
/**
|
||||
@ -33,7 +36,7 @@ import rawDeepLearningClassifier.dlClassification.genericModel.GenericPrediction
|
||||
* @author Jamie Macaulay
|
||||
*
|
||||
*/
|
||||
public class ArchiveModelWorker extends DLModelWorker<GenericPrediction> {
|
||||
public class ArchiveModelWorker extends GenericModelWorker {
|
||||
|
||||
|
||||
/**
|
||||
@ -57,6 +60,7 @@ public class ArchiveModelWorker extends DLModelWorker<GenericPrediction> {
|
||||
/**
|
||||
* Prepare the model
|
||||
*/
|
||||
@Override
|
||||
public void prepModel(StandardModelParams dlParams, DLControl dlControl) {
|
||||
//ClassLoader origCL = Thread.currentThread().getContextClassLoader();
|
||||
try {
|
||||
@ -97,6 +101,7 @@ public class ArchiveModelWorker extends DLModelWorker<GenericPrediction> {
|
||||
//read the JSON string from the the file.
|
||||
String jsonString = DLTransformsParser.readJSONString(new File(dlModel.getAudioReprFile()));
|
||||
|
||||
|
||||
//convert the JSON string to a parameters object.
|
||||
GenericModelParams modelParams = makeModelParams( jsonString);
|
||||
|
||||
@ -108,6 +113,7 @@ public class ArchiveModelWorker extends DLModelWorker<GenericPrediction> {
|
||||
|
||||
//HACK there seems to be some sort of bug in ketos where the params input shape is correct but the model input shape is wrong.
|
||||
if (dlModel.getInputShape()==null || !dlModel.getInputShape().equals(modelParams.defaultInputShape)) {
|
||||
System.out.println("Model input shape: " + modelParams.defaultInputShape);
|
||||
WarnOnce.showWarning("Model shape", "The model shape does not match the model metadata. \n Metadata shape will be used used.", WarnOnce.OK_OPTION);
|
||||
dlModel.setInputShape(modelParams.defaultInputShape);
|
||||
}
|
||||
@ -131,26 +137,35 @@ public class ArchiveModelWorker extends DLModelWorker<GenericPrediction> {
|
||||
|
||||
//only load new transforms if defaults are selected
|
||||
if (getModelTransforms()==null || dlParams.dlTransfroms==null || dlParams.useDefaultTransfroms) {
|
||||
System.out.println(" " + transforms);
|
||||
//System.out.println("SET MODEL TRANSFORMS: " + ketosDLParams.dlTransfroms + " " + ketosDLParams.useDefaultTransfroms);
|
||||
|
||||
//only set the transforms if they are null - otherwise handled elsewhere.
|
||||
setModelTransforms(transforms);
|
||||
dlParams.useDefaultTransfroms = true;
|
||||
dlParams.dlTransfroms=transforms;
|
||||
}
|
||||
else {
|
||||
//System.out.println("SET CURRENT TRANSFORMS: " + ketosDLParams.dlTransfroms + " " + ketosDLParams.useDefaultTransfroms);
|
||||
//use the old transforms.
|
||||
setModelTransforms(dlParams.dlTransfroms);
|
||||
}
|
||||
|
||||
//ketosDLParams.dlTransfroms = transforms; //this is done after prep model in the settings pane.
|
||||
dlParams.defaultSegmentLen = modelParams.segLen*1000.; //the segment length in microseconds.
|
||||
//ketosParams.classNames = new String[] {"Noise", "Right Whale"}; // FIXME;
|
||||
//set whether a wave or spectrogram model
|
||||
//this is important for setting the input stack into the model.
|
||||
setWaveFreqModel(dlParams);
|
||||
|
||||
// //enable softmax? - TODO
|
||||
// this.setEnableSoftMax(true);
|
||||
|
||||
//ketosDLParams.dlTransfroms = transforms; //this is done after prep model in the settings pane.
|
||||
dlParams.defaultSegmentLen = modelParams.segLen; //the segment length in microseconds.
|
||||
|
||||
if (modelParams.classNames!=null) {
|
||||
dlParams.numClasses = modelParams.classNames.length;
|
||||
}
|
||||
else {
|
||||
dlParams.numClasses = (int) modelParams.defaultOutputShape.get(1);
|
||||
}
|
||||
|
||||
//ok 0 the other values are not user selectable but this is. If we relaod the same model we probably want to keep it....
|
||||
//So this is a little bt of a hack but will probably be OK in most cases.
|
||||
//So this is a little bit of a hack but will probably be OK in most cases.
|
||||
if (dlParams.binaryClassification==null || dlParams.binaryClassification.length!=dlParams.numClasses) {
|
||||
dlParams.binaryClassification = new boolean[dlParams.numClasses];
|
||||
for (int i=0; i<dlParams.binaryClassification.length; i++) {
|
||||
@ -165,19 +180,6 @@ public class ArchiveModelWorker extends DLModelWorker<GenericPrediction> {
|
||||
//set the number of class names from the default output shape
|
||||
dlParams.numClasses = (int) modelParams.defaultOutputShape.get(1);
|
||||
}
|
||||
|
||||
// if (dlParams.classNames!=null) {
|
||||
// for (int i = 0; i<dlParams.classNames.length; i++) {
|
||||
// System.out.println("Class name " + i + " " + dlParams.classNames[i]);
|
||||
// }
|
||||
// }
|
||||
// ketosDLParams.classNames = dlControl.getClassNameManager().makeClassNames(ketosParams.classNames);
|
||||
// if (ketosParams.classNames!=null) {
|
||||
// for (int i = 0; i<ketosDLParams.classNames.length; i++) {
|
||||
// System.out.println("Class name " + i + " " + ketosDLParams.classNames[i].className + " ID " + ketosDLParams.classNames[i].ID );
|
||||
// }
|
||||
// }
|
||||
|
||||
}
|
||||
catch (Exception e) {
|
||||
dlModel=null;
|
||||
@ -195,13 +197,6 @@ public class ArchiveModelWorker extends DLModelWorker<GenericPrediction> {
|
||||
* @throws IOException
|
||||
*/
|
||||
public ArchiveModel loadModel(String currentPath2) throws MalformedModelException, IOException {
|
||||
//note the the model should have been check for compatibility beforehand
|
||||
File file = new File(currentPath2);
|
||||
|
||||
String model = getZipFilePath(file, ".py");
|
||||
if (model==null) model = getZipFilePath(file, ".pb");
|
||||
String settings = getZipFilePath(file, ".pdtf");
|
||||
|
||||
return new SimpleArchiveModel(new File(currentPath2));
|
||||
}
|
||||
|
||||
@ -225,11 +220,11 @@ public class ArchiveModelWorker extends DLModelWorker<GenericPrediction> {
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public float[] runModel(float[][][] transformedDataStack) {
|
||||
System.out.println("Model input: " + transformedDataStack.length + " " + transformedDataStack[0].length + " " + transformedDataStack[0][0].length);
|
||||
return dlModel.runModel(transformedDataStack);
|
||||
}
|
||||
// @Override
|
||||
// public float[] runModel(float[][][] transformedDataStack) {
|
||||
// System.out.println("Model input: " + transformedDataStack.length + " " + transformedDataStack[0].length + " " + transformedDataStack[0][0].length);
|
||||
// return dlModel.runModel(transformedDataStack);
|
||||
// }
|
||||
|
||||
|
||||
@Override
|
||||
@ -252,6 +247,7 @@ public class ArchiveModelWorker extends DLModelWorker<GenericPrediction> {
|
||||
* Get the currently loaded mode.
|
||||
* @return - the currently loaded mode.
|
||||
*/
|
||||
@Override
|
||||
public ArchiveModel getModel() {
|
||||
return dlModel;
|
||||
}
|
||||
|
@ -65,7 +65,7 @@ public class PamZipModelClassifier extends ArchiveModelClassifier {
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the first file within a zip folder that matches a pattern.
|
||||
* Find the first file within a zip folder that matches a pattern. This peaks into the zip file instead of decompressing it.
|
||||
* @param zipUri - uri to the zip file
|
||||
* @param filePattern - the file pattern to match - the file must contain this string.
|
||||
* @return null if no file found and the file pqth if the file is founf
|
||||
|
@ -28,7 +28,7 @@ public class SimpleArchiveModel extends ArchiveModel {
|
||||
@Override
|
||||
public String getAudioReprRelPath(String zipFolder) {
|
||||
try {
|
||||
System.out.println("SETTINGS PATH: " + getRelFilePath(zipFolder, ".pdtf"));
|
||||
//System.out.println("SETTINGS PATH: " + getRelFilePath(zipFolder, ".pdtf"));
|
||||
return getRelFilePath(zipFolder, ".pdtf");
|
||||
} catch (IOException e) {
|
||||
// TODO Auto-generated catch block
|
||||
@ -43,7 +43,7 @@ public class SimpleArchiveModel extends ArchiveModel {
|
||||
String model = null;
|
||||
model = getRelFilePath(zipFolder, ".pb");
|
||||
if (model==null) model = getRelFilePath(zipFolder, ".py");
|
||||
System.out.println("MODEL PATH: " +model);
|
||||
//System.out.println("MODEL PATH: " +model);
|
||||
return model;
|
||||
|
||||
} catch (IOException e) {
|
||||
@ -64,14 +64,16 @@ public class SimpleArchiveModel extends ArchiveModel {
|
||||
try (Stream<Path> walk = Files.walk(Paths.get(zipFolder))) {
|
||||
List<String> result = walk
|
||||
.filter(p -> !Files.isDirectory(p)) // not a directory
|
||||
.map(p -> p.toString().toLowerCase()) // convert path to string
|
||||
.map(p -> p.toString()) // convert path to string
|
||||
.filter(f -> f.endsWith(fileEnd)) // check end with
|
||||
.collect(Collectors.toList()); // collect all matched to a List
|
||||
|
||||
if (result.size()>0) {
|
||||
String firstFile = result.get(0);
|
||||
|
||||
String relative = new File(zipFolder).toURI().relativize(new File(firstFile).toURI()).getPath();
|
||||
//System.out.println("First file: " +firstFile);
|
||||
|
||||
String relative = firstFile.replace(zipFolder, "");
|
||||
|
||||
return relative;
|
||||
}
|
||||
|
@ -184,7 +184,6 @@ public class GenericAdvPane extends SettingsPane<GenericModelParams> {
|
||||
defualtShapeSwitch = new PamToggleSwitch("Use model default shape");
|
||||
defualtShapeSwitch.selectedProperty().addListener((obsval, oldval, newval)->{
|
||||
|
||||
System.out.println("Hello: deafult shape: " + newval);
|
||||
PamArrayUtils.printArray( currentInput.defaultShape);
|
||||
|
||||
if (newval) {
|
||||
|
@ -121,13 +121,12 @@ public class GenericDLClassifier extends StandardClassifierModel {
|
||||
|
||||
genericModelParams.dlTransfromParams=dlTransformParams;
|
||||
|
||||
if (genericModelParams.dlTransfromParams!=null) {
|
||||
System.out.println("Generic settings have been saved. : " + genericModelParams.dlTransfromParams.size());
|
||||
}
|
||||
else {
|
||||
System.out.println("Generic settings have been saved. : " + null);
|
||||
|
||||
}
|
||||
// if (genericModelParams.dlTransfromParams!=null) {
|
||||
// System.out.println("Generic settings have been saved. : " + genericModelParams.dlTransfromParams.size());
|
||||
// }
|
||||
// else {
|
||||
// System.out.println("Generic settings have been saved. : " + null);
|
||||
// }
|
||||
|
||||
return genericModelParams;
|
||||
}
|
||||
@ -211,6 +210,8 @@ public class GenericDLClassifier extends StandardClassifierModel {
|
||||
*/
|
||||
public static boolean isBinaryResult(GenericPrediction modelResult, StandardModelParams genericModelParams) {
|
||||
for (int i=0; i<modelResult.getPrediction().length; i++) {
|
||||
//System.out.println("Binary Classification: " + genericModelParams.binaryClassification.length);
|
||||
|
||||
if (modelResult.getPrediction()[i]>genericModelParams.threshold && genericModelParams.binaryClassification[i]) {
|
||||
// System.out.println("SoundSpotClassifier: prediciton: " + i + " passed threshold with val: " + modelResult.getPrediction()[i]);
|
||||
return true;
|
||||
|
@ -36,16 +36,18 @@ public class GenericModelWorker extends DLModelWorker<GenericPrediction> {
|
||||
// System.out.println("RUN GENERIC MODEL: " + transformedDataStack[0][0][0]);
|
||||
float[] results;
|
||||
if (freqTransform)
|
||||
results = genericModel.runModel(transformedDataStack);
|
||||
results = getModel().runModel(transformedDataStack);
|
||||
else {
|
||||
//run a model if it is waveform info.
|
||||
float[][] waveStack = new float[transformedDataStack.length][];
|
||||
for (int i=0; i<waveStack.length; i++) {
|
||||
waveStack[i] = transformedDataStack[i][0];
|
||||
}
|
||||
results = genericModel.runModel(waveStack);
|
||||
|
||||
//System.out.println("RUN GENERIC MODEL WAVE: " + transformedDataStack.length + " " + transformedDataStack[0].length);
|
||||
results = getModel().runModel(waveStack);
|
||||
}
|
||||
System.out.println("GENERIC MODEL RESULTS: " + results== null ? null : results.length);
|
||||
//System.out.println("GENERIC MODEL RESULTS: " + results== null ? null : results.length);
|
||||
return results;
|
||||
}
|
||||
|
||||
@ -97,13 +99,7 @@ public class GenericModelWorker extends DLModelWorker<GenericPrediction> {
|
||||
setModelTransforms(genericParams.dlTransfroms);
|
||||
|
||||
//is this a waveform or a spectrogram model?
|
||||
DLTransform transform = genericParams.dlTransfroms.get(genericParams.dlTransfroms.size()-1);
|
||||
if (transform instanceof FreqTransform) {
|
||||
freqTransform = true;
|
||||
}
|
||||
else {
|
||||
freqTransform = false;
|
||||
}
|
||||
setWaveFreqModel(genericParams);
|
||||
|
||||
//use softmax or not?
|
||||
String extension = FilenameUtils.getExtension(genericParams.modelPath);
|
||||
@ -131,6 +127,18 @@ public class GenericModelWorker extends DLModelWorker<GenericPrediction> {
|
||||
//Thread.currentThread().setContextClassLoader(origCL);
|
||||
}
|
||||
|
||||
|
||||
protected void setWaveFreqModel(StandardModelParams genericParams) {
|
||||
//is this a waveform or a spectrogram model?
|
||||
DLTransform transform = genericParams.dlTransfroms.get(genericParams.dlTransfroms.size()-1);
|
||||
if (transform instanceof FreqTransform) {
|
||||
freqTransform = true;
|
||||
}
|
||||
else {
|
||||
freqTransform = false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert a long[] array to a Long[] array.
|
||||
* @param inArr - the input long[] array.
|
||||
|
@ -32,6 +32,8 @@ public class KetosWorker2 extends ArchiveModelWorker {
|
||||
public GenericModelParams makeModelParams(String jsonString) {
|
||||
//ketos parameters are non standard and need a bit of extra work to get right.
|
||||
//This also deal with legacy paramters.
|
||||
return new KetosParams(jsonString);
|
||||
KetosParams params = new KetosParams(jsonString);
|
||||
// System.out.println(params);
|
||||
return params;
|
||||
}
|
||||
}
|
||||
|
@ -377,6 +377,7 @@ public class DLModelSelectPane extends PamBorderPane {
|
||||
//we are loading model from a file - anything can happen so put in a try catch.
|
||||
DLStatus status = currentClassifierModel.setModel(file);
|
||||
|
||||
|
||||
if (status.isError()) {
|
||||
System.err.println("Model load failed: " + currentClassifierModel.getModelStatus());
|
||||
currentClassifierModel=null;
|
||||
@ -437,7 +438,7 @@ public class DLModelSelectPane extends PamBorderPane {
|
||||
pathLabel.setText("No classifier model loaded: Select model");
|
||||
pathLabel.setTooltip(new Tooltip("Use the browse button/ URI botton to select a model or select a default model"));
|
||||
}
|
||||
if (currentClassifierModel.getModelStatus().isError()) {
|
||||
else if (currentClassifierModel.getModelStatus().isError()) {
|
||||
pathLabel.setGraphic(createErrorIcon(currentClassifierModel.getModelStatus()));
|
||||
pathLabel.setText(currentClassifierModel.getModelStatus().getName());
|
||||
pathLabel.setTooltip(new Tooltip(currentClassifierModel.getModelStatus().getDescription()));
|
||||
|
@ -20,8 +20,10 @@ import javafx.scene.control.Label;
|
||||
import javafx.scene.control.Spinner;
|
||||
import javafx.scene.control.Tooltip;
|
||||
import javafx.scene.control.Alert.AlertType;
|
||||
import javafx.scene.layout.ColumnConstraints;
|
||||
import javafx.scene.layout.HBox;
|
||||
import javafx.scene.layout.Pane;
|
||||
import javafx.scene.layout.Priority;
|
||||
import pamViewFX.PamGuiManagerFX;
|
||||
import pamViewFX.fxGlyphs.PamGlyphDude;
|
||||
import pamViewFX.fxNodes.PamBorderPane;
|
||||
@ -208,6 +210,14 @@ public class DLSettingsPane extends SettingsPane<RawDLParams>{
|
||||
});
|
||||
|
||||
PamGridPane segmenterGridPane = new PamGridPane();
|
||||
segmenterGridPane.setHgap(5);
|
||||
|
||||
ColumnConstraints col1 = new ColumnConstraints();
|
||||
col1.setHgrow( Priority.SOMETIMES );
|
||||
ColumnConstraints col2 = new ColumnConstraints();
|
||||
col2.setHgrow( Priority.ALWAYS );
|
||||
segmenterGridPane.getColumnConstraints().addAll(col1, col2 );
|
||||
|
||||
segmenterGridPane.add(new Label("Window length"), 0, 0);
|
||||
segmenterGridPane.add(windowLength, 1, 0);
|
||||
segmenterGridPane.add(new Label("samples"), 2, 0);
|
||||
|
@ -31,6 +31,8 @@ public class DLDataUnitDatagram implements DatagramProvider {
|
||||
public int addDatagramData(PamDataUnit dataUnit, float[] dataGramLine) {
|
||||
DLDataUnit dlDataUnit = (DLDataUnit) dataUnit;
|
||||
|
||||
if (dataGramLine==null || dataGramLine.length==0) return 0;
|
||||
|
||||
if (dlDataUnit.getPredicitionResult().getPrediction()!=null) {
|
||||
for (int i=0; i<dlDataUnit.getPredicitionResult().getPrediction().length; i++) {
|
||||
dataGramLine[i] += (float) dlDataUnit.getPredicitionResult().getPrediction()[i];
|
||||
|
Loading…
Reference in New Issue
Block a user