mirror of
https://github.com/PAMGuard/PAMGuard.git
synced 2024-11-22 07:02:29 +00:00
Google humpback whale deep learning classifier
Google's humpback whale deep learning classifier can now be imported. Updated TDisplayFX to make the data selection panes cleaner and clearer. Updated the TDisplayFX so that predicitons from deep learning models now have some display options e.g. changing colour.
This commit is contained in:
parent
eabc4bc3f0
commit
53b8ca9479
@ -6,7 +6,7 @@
|
||||
<attribute name="maven.pomderived" value="true"/>
|
||||
</attributes>
|
||||
</classpathentry>
|
||||
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER/org.eclipse.jdt.internal.debug.ui.launcher.StandardVMType/JavaSE-16">
|
||||
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER/org.eclipse.jdt.internal.debug.ui.launcher.StandardVMType/Java 17">
|
||||
<attributes>
|
||||
<attribute name="module" value="true"/>
|
||||
<attribute name="maven.pomderived" value="true"/>
|
||||
|
2
pom.xml
2
pom.xml
@ -329,7 +329,7 @@
|
||||
<dependency>
|
||||
<groupId>io.github.macster110</groupId>
|
||||
<artifactId>jdl4pam</artifactId>
|
||||
<version>0.0.93</version>
|
||||
<version>0.0.94</version>
|
||||
</dependency>
|
||||
|
||||
<!-- https://mvnrepository.com/artifact/gov.nist.math/jama -->
|
||||
|
@ -185,13 +185,14 @@ public class TDControlPaneFX extends Pane {
|
||||
private Node createGraphAddPane(){
|
||||
|
||||
//create button to add a graph
|
||||
Button addGraph=new PamButton();
|
||||
Button addGraph=new PamButton("Plot");
|
||||
// addGraph.setGraphic(PamGlyphDude.createPamGlyph(MaterialIcon.ADD, Color.WHITE, PamGuiManagerFX.iconSize));
|
||||
addGraph.setGraphic(PamGlyphDude.createPamIcon("mdi2p-plus", Color.WHITE, PamGuiManagerFX.iconSize));
|
||||
addGraph.setOnAction(new AddButton());
|
||||
|
||||
//create a menu button to remove graphs
|
||||
SplitMenuButton removeGraph=new SplitMenuButton();
|
||||
removeGraph.setText("Plot");
|
||||
// removeGraph.setGraphic(PamGlyphDude.createPamGlyph(MaterialIcon.REMOVE, Color.WHITE, PamGuiManagerFX.iconSize));
|
||||
removeGraph.setGraphic(PamGlyphDude.createPamIcon("mdi2m-minus", Color.WHITE, PamGuiManagerFX.iconSize));
|
||||
removeGraph.setOnAction(new RemoveLastGraph());
|
||||
|
@ -13,6 +13,7 @@ import pamViewFX.fxNodes.PamGridPane;
|
||||
import pamViewFX.fxNodes.PamHBox;
|
||||
import pamViewFX.fxNodes.PamTilePane;
|
||||
import pamViewFX.fxNodes.PamVBox;
|
||||
import pamViewFX.fxNodes.utilityPanes.PamToggleSwitch;
|
||||
import pamViewFX.fxNodes.utilsFX.PamUtilsFX;
|
||||
import javafx.beans.value.ChangeListener;
|
||||
import javafx.beans.value.ObservableValue;
|
||||
@ -28,6 +29,7 @@ import javafx.scene.control.Separator;
|
||||
import javafx.scene.control.TextField;
|
||||
import javafx.scene.control.Toggle;
|
||||
import javafx.scene.control.ToggleGroup;
|
||||
import javafx.scene.control.Tooltip;
|
||||
import javafx.scene.layout.Pane;
|
||||
import javafx.scene.paint.Color;
|
||||
import javafx.util.Duration;
|
||||
@ -128,20 +130,20 @@ public class TDDataSelPaneFX extends PamVBox {
|
||||
this.getChildren().clear();
|
||||
|
||||
//the data control pane.
|
||||
Label addRemoveLabel = new Label("Display Data");
|
||||
PamGuiManagerFX.titleFont2style(addRemoveLabel);
|
||||
//Label addRemoveLabel = new Label("Display Data");
|
||||
//PamGuiManagerFX.titleFont2style(addRemoveLabel);
|
||||
// addRemoveLabel.setFont(PamGuiManagerFX.titleFontSize2);
|
||||
|
||||
this.getChildren().add(addRemoveLabel);
|
||||
//this.getChildren().add(addRemoveLabel);
|
||||
this.getChildren().add(dataControlPane);
|
||||
|
||||
if (tdGraph.getDataList().size()>0) {
|
||||
Label yAxisLabel = new Label("Y-Axis");
|
||||
PamGuiManagerFX.titleFont2style(yAxisLabel);
|
||||
//Label yAxisLabel = new Label("Y-Axis");
|
||||
//PamGuiManagerFX.titleFont2style(yAxisLabel);
|
||||
// yAxisLabel.setFont(PamGuiManagerFX.titleFontSize2);
|
||||
//this.getChildren().add(new Separator());
|
||||
|
||||
this.getChildren().add(yAxisLabel);
|
||||
//this.getChildren().add(yAxisLabel);
|
||||
this.getChildren().add(createYAxisDataList());
|
||||
this.getChildren().add(axisMinMax);
|
||||
this.getChildren().add(new Separator());
|
||||
@ -160,9 +162,9 @@ public class TDDataSelPaneFX extends PamVBox {
|
||||
|
||||
PamHBox controlPane = new PamHBox();
|
||||
controlPane.setSpacing(5);
|
||||
controlPane.setAlignment(Pos.CENTER);
|
||||
controlPane.setAlignment(Pos.CENTER_LEFT);
|
||||
|
||||
addMenuButton = new MenuButton("Add");
|
||||
addMenuButton = new MenuButton("Data");
|
||||
//addMenuButton.setPrefHeight(PamGuiManagerFX.iconSize);
|
||||
// addMenuButton.setGraphic(PamGlyphDude.createPamGlyph(MaterialIcon.ADD, Color.WHITE, PamGuiManagerFX.iconSize));
|
||||
addMenuButton.setGraphic(PamGlyphDude.createPamIcon("mdi2p-plus", Color.WHITE, PamGuiManagerFX.iconSize));
|
||||
@ -172,7 +174,7 @@ public class TDDataSelPaneFX extends PamVBox {
|
||||
});
|
||||
TDControlPaneFX.createAddMenuItems(addMenuButton.getItems(), this.tdGraph);
|
||||
|
||||
removeMenuButton = new MenuButton("Remove");
|
||||
removeMenuButton = new MenuButton("Data");
|
||||
//removeMenuButton.setPrefHeight(PamGuiManagerFX.iconSize);
|
||||
// removeMenuButton.setGraphic(PamGlyphDude.createPamGlyph(MaterialIcon.REMOVE, Color.WHITE, PamGuiManagerFX.iconSize));
|
||||
removeMenuButton.setGraphic(PamGlyphDude.createPamIcon("mdi2m-minus", Color.WHITE, PamGuiManagerFX.iconSize));
|
||||
@ -188,6 +190,7 @@ public class TDDataSelPaneFX extends PamVBox {
|
||||
// button.setGraphic(PamGlyphDude.createPamGlyph(FontAwesomeIcon.COGS, Color.WHITE, PamGuiManagerFX.iconSize));
|
||||
button.setGraphic(PamGlyphDude.createPamIcon("mdi2c-cogs", Color.WHITE, PamGuiManagerFX.iconSize));
|
||||
controlPane.getChildren().addAll(addMenuButton, removeMenuButton, button);
|
||||
button.prefHeightProperty().bind(removeMenuButton.heightProperty());
|
||||
|
||||
button.setOnAction((action)->{
|
||||
showPopMenu(button);
|
||||
@ -459,9 +462,9 @@ public class TDDataSelPaneFX extends PamVBox {
|
||||
vBox.setPadding(new Insets(10,10,10,10));
|
||||
vBox.setSpacing(10);
|
||||
|
||||
CheckBox dataCB;
|
||||
PamToggleSwitch dataCB;
|
||||
for (int i=0; i<tdGraph.getDataList().size(); i++){
|
||||
dataCB = new CheckBox(tdGraph.getDataList().get(i).getDataName());
|
||||
dataCB = new PamToggleSwitch(tdGraph.getDataList().get(i).getDataName());
|
||||
dataCB.setSelected(tdGraph.getDataList().get(i).isShowing());
|
||||
final int list=i; //need a final variable for listener.
|
||||
dataCB.selectedProperty().addListener(new ChangeListener<Boolean>() {
|
||||
@ -471,6 +474,7 @@ public class TDDataSelPaneFX extends PamVBox {
|
||||
tdGraph.repaint(0);
|
||||
}
|
||||
});
|
||||
dataCB.setTooltip(new Tooltip(tdGraph.getDataList().get(i).getDataName()));
|
||||
vBox.getChildren().add(dataCB);
|
||||
}
|
||||
return vBox;
|
||||
|
@ -121,7 +121,7 @@ public class RawClipDataInfo extends GenericDataPlotInfo {
|
||||
tdProjector.addHoverData(new HoverData(path2D, pamDataUnit, 0, plotNumber));
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,46 @@
|
||||
package rawDeepLearningClassifier.dataPlotFX;
|
||||
|
||||
import java.io.Serializable;
|
||||
import PamModel.parametermanager.ManagedParameters;
|
||||
import PamModel.parametermanager.PamParameterSet;
|
||||
import javafx.scene.paint.Color;
|
||||
|
||||
/**
|
||||
* The parameters for the deep learning prediction plot.
|
||||
*
|
||||
* @author Jamie Macaulay
|
||||
*
|
||||
*/
|
||||
public class DLPredDisplayParams implements Serializable, Cloneable, ManagedParameters {
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
/**
|
||||
* The line infos.
|
||||
*/
|
||||
public LineInfo[] lineInfos;
|
||||
|
||||
@Override
|
||||
public PamParameterSet getParameterSet() {
|
||||
PamParameterSet ps = PamParameterSet.autoGenerate(this);
|
||||
return ps;
|
||||
}
|
||||
|
||||
/* (non-Javadoc)
|
||||
* @see java.lang.Object#clone()
|
||||
*/
|
||||
@Override
|
||||
protected DLPredDisplayParams clone() {
|
||||
try {
|
||||
return (DLPredDisplayParams) super.clone();
|
||||
} catch (CloneNotSupportedException e) {
|
||||
e.printStackTrace();
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
226
src/rawDeepLearningClassifier/dataPlotFX/DLPredictionPane.java
Normal file
226
src/rawDeepLearningClassifier/dataPlotFX/DLPredictionPane.java
Normal file
@ -0,0 +1,226 @@
|
||||
package rawDeepLearningClassifier.dataPlotFX;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
||||
import dataPlotsFX.layout.TDSettingsPane;
|
||||
import javafx.geometry.Insets;
|
||||
import javafx.geometry.Pos;
|
||||
import javafx.scene.Node;
|
||||
import javafx.scene.control.ColorPicker;
|
||||
import javafx.scene.control.Label;
|
||||
import javafx.scene.layout.Pane;
|
||||
import javafx.scene.paint.Color;
|
||||
import pamViewFX.fxGlyphs.PamSVGIcon;
|
||||
import pamViewFX.fxNodes.PamBorderPane;
|
||||
import pamViewFX.fxNodes.PamHBox;
|
||||
import pamViewFX.fxNodes.PamVBox;
|
||||
import pamViewFX.fxNodes.utilityPanes.PamToggleSwitch;
|
||||
import rawDeepLearningClassifier.dlClassification.DLClassName;
|
||||
|
||||
/*
|
||||
* Symbol Options for the annotation pane
|
||||
*/
|
||||
public class DLPredictionPane extends PamBorderPane implements TDSettingsPane {
|
||||
|
||||
/**
|
||||
* The main holder pane.
|
||||
*/
|
||||
private PamBorderPane mainPane;
|
||||
|
||||
/**
|
||||
* Make the icon.
|
||||
*/
|
||||
private Node icon = makeIcon();
|
||||
//private Node icon2 = makeIcon();
|
||||
|
||||
/**
|
||||
* Holds the panes with controls to change prediction colours for each class.
|
||||
*/
|
||||
private ArrayList<PredictionColourPane> colourPanes = new ArrayList<PredictionColourPane>();
|
||||
|
||||
/**
|
||||
* Reference to the prediction data info
|
||||
*/
|
||||
private DLPredictionPlotInfoFX dlPredictionPlotInfoFX;
|
||||
|
||||
/**
|
||||
* Holds all the prediction colour controls.
|
||||
*/
|
||||
private PamVBox predColHolder;
|
||||
|
||||
|
||||
public DLPredictionPane(DLPredictionPlotInfoFX dlPredictionPlotInfoFX) {
|
||||
this.dlPredictionPlotInfoFX=dlPredictionPlotInfoFX;
|
||||
mainPane = new PamBorderPane();
|
||||
mainPane.setCenter(predColHolder = new PamVBox());
|
||||
predColHolder.setSpacing(5);
|
||||
predColHolder.setPadding(new Insets(5,0,0,0));
|
||||
|
||||
if (dlPredictionPlotInfoFX.getDlControl().getDLModel()!=null) {
|
||||
layoutColourPanes( dlPredictionPlotInfoFX.getDlControl().getDLModel().getClassNames());
|
||||
}
|
||||
|
||||
this.getParams();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Node getHidingIcon() {
|
||||
return icon;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getShowingName() {
|
||||
return "DL Prediction";
|
||||
}
|
||||
|
||||
@Override
|
||||
public Node getShowingIcon() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Pane getPane() {
|
||||
return mainPane;
|
||||
}
|
||||
|
||||
public void setParams() {
|
||||
|
||||
if (dlPredictionPlotInfoFX.getDlControl().getDLModel()!=null) {
|
||||
//populate the prediction pane.
|
||||
DLClassName[] classNames = dlPredictionPlotInfoFX.getDlControl().getDLModel().getClassNames();
|
||||
|
||||
layoutColourPanes(classNames);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private void layoutColourPanes(DLClassName[] classNames){
|
||||
|
||||
//System.out.println("Class name map: " + dlPredictionPlotInfoFX.getDlControl().getDLParams().classNameMap);
|
||||
|
||||
ArrayList<PredictionColourPane> colourPanes = new ArrayList<PredictionColourPane>();
|
||||
|
||||
predColHolder.getChildren().clear();
|
||||
|
||||
for (int i=0; i<classNames.length; i++) {
|
||||
if (i<colourPanes.size()-1 && classNames[i].className.equals(this.colourPanes.get(i).getName())) {
|
||||
//probably the same
|
||||
colourPanes.add(this.colourPanes.get(i));
|
||||
}
|
||||
colourPanes.add(new PredictionColourPane(classNames[i].className, true, Color.BLACK));
|
||||
predColHolder.getChildren().add(colourPanes.get(i));
|
||||
}
|
||||
|
||||
this.colourPanes=colourPanes;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Get the parameters.
|
||||
*/
|
||||
private void getParams() {
|
||||
this.dlPredictionPlotInfoFX.getDlPredParams().lineInfos = new LineInfo[colourPanes.size()];
|
||||
for (int i=0; i<colourPanes.size(); i++) {
|
||||
this.dlPredictionPlotInfoFX.getDlPredParams().lineInfos[i] = new LineInfo(
|
||||
colourPanes.get(i).toggleSwitch.isSelected(), colourPanes.get(i).colourPicker.getValue());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* There are new settings. Repaints the graph.
|
||||
*/
|
||||
private void newSettings() {
|
||||
newSettings(0);
|
||||
}
|
||||
|
||||
/**
|
||||
* There are new settings. Repaints the graph.
|
||||
* @param milliswait
|
||||
*/
|
||||
private void newSettings(long milliswait) {
|
||||
getParams();
|
||||
this.dlPredictionPlotInfoFX.getTDGraph().repaint(milliswait);
|
||||
}
|
||||
|
||||
private Node makeIcon() {
|
||||
String resourcePath = "/Resources/modules/noun_Deep Learning_2486374.svg";
|
||||
try {
|
||||
PamSVGIcon iconMaker= new PamSVGIcon();
|
||||
//PamSVGIcon svgsprite = iconMaker.create(new File(getClass().getResource(resourcePath).toURI()), Color.WHITE);
|
||||
PamSVGIcon svgsprite = iconMaker.create(getClass().getResource(resourcePath).toURI().toURL(), Color.DODGERBLUE);
|
||||
|
||||
// svgsprite.getSpriteNode().setStyle("-fx-text-color: white");
|
||||
// svgsprite.getSpriteNode().setStyle("-fx-fill: white");
|
||||
svgsprite.setFitHeight(20);
|
||||
svgsprite.setFitWidth(20);
|
||||
return svgsprite.getSpriteNode();
|
||||
}
|
||||
catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Pane for
|
||||
* @author au671271
|
||||
*
|
||||
*/
|
||||
private class PredictionColourPane extends PamHBox {
|
||||
|
||||
/**
|
||||
* The toggle switch
|
||||
*/
|
||||
private PamToggleSwitch toggleSwitch;
|
||||
|
||||
/**
|
||||
* The colour picker
|
||||
*/
|
||||
private ColorPicker colourPicker;
|
||||
|
||||
private String name;
|
||||
|
||||
public PredictionColourPane(String name, boolean enabled, Color color) {
|
||||
|
||||
this.setSpacing(5);
|
||||
this.setAlignment(Pos.CENTER_LEFT);
|
||||
|
||||
this.toggleSwitch = new PamToggleSwitch(name);
|
||||
this.toggleSwitch.selectedProperty().addListener((obsVal, oldVal, newVal)->{
|
||||
newSettings();
|
||||
});
|
||||
toggleSwitch.setPrefWidth(120);
|
||||
|
||||
this.colourPicker = new ColorPicker();
|
||||
this.colourPicker.valueProperty().addListener((obsVal, oldVal, newval)->{
|
||||
newSettings();
|
||||
});
|
||||
|
||||
this.colourPicker.setPrefWidth(80);
|
||||
|
||||
|
||||
this.name =name;
|
||||
|
||||
this.getChildren().addAll(toggleSwitch, colourPicker);
|
||||
|
||||
|
||||
setParams(name, enabled, color);
|
||||
|
||||
}
|
||||
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
public void setParams(String name, boolean enabled, Color color) {
|
||||
this.toggleSwitch.setLabel(new Label(name));
|
||||
this.toggleSwitch.setSelected(enabled);
|
||||
this.colourPicker.setValue(color);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
@ -1,6 +1,8 @@
|
||||
package rawDeepLearningClassifier.dataPlotFX;
|
||||
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
import PamController.PamController;
|
||||
import PamUtils.Coordinate3d;
|
||||
import PamView.GeneralProjector;
|
||||
@ -13,10 +15,12 @@ import PamguardMVC.PamDataBlock;
|
||||
import PamguardMVC.PamDataUnit;
|
||||
import dataPlotsFX.TDManagedSymbolChooserFX;
|
||||
import dataPlotsFX.TDSymbolChooserFX;
|
||||
import dataPlotsFX.clickPlotFX.ClickDisplayParams;
|
||||
import dataPlotsFX.data.TDDataInfoFX;
|
||||
import dataPlotsFX.data.TDDataProviderFX;
|
||||
import dataPlotsFX.data.generic.GenericScaleInfo;
|
||||
import dataPlotsFX.layout.TDGraphFX;
|
||||
import dataPlotsFX.layout.TDSettingsPane;
|
||||
import dataPlotsFX.projector.TDProjectorFX;
|
||||
import javafx.geometry.Point2D;
|
||||
import javafx.scene.canvas.GraphicsContext;
|
||||
@ -28,7 +32,7 @@ import rawDeepLearningClassifier.dlClassification.DLDataUnit;
|
||||
/**
|
||||
* Plot the raw probability information.
|
||||
*
|
||||
* Plots multiple probability lines. Note that the deulat TDDataInfo can handle drawing lines for different channels
|
||||
* Plots multiple probability lines. Note that the TDDataInfo can handle drawing lines for different channels
|
||||
* but we are drawing for different prediction classes here.
|
||||
*
|
||||
* @author Jamie Macaulay
|
||||
@ -54,11 +58,27 @@ public class DLPredictionPlotInfoFX extends TDDataInfoFX {
|
||||
*/
|
||||
private GenericScaleInfo frequencyInfo;
|
||||
|
||||
/**
|
||||
* The default colour.
|
||||
*/
|
||||
//private Color color = Color.DODGERBLUE;
|
||||
|
||||
Color color = Color.DODGERBLUE;
|
||||
|
||||
/**
|
||||
* DL control.
|
||||
*/
|
||||
private DLControl dlControl;
|
||||
|
||||
/**
|
||||
* The DL prediction pane.
|
||||
*/
|
||||
private DLPredictionPane predictionSettingsPane;
|
||||
|
||||
/**
|
||||
* The display parameters.
|
||||
*/
|
||||
private DLPredDisplayParams dlPredParams = new DLPredDisplayParams();
|
||||
|
||||
|
||||
public DLPredictionPlotInfoFX(TDDataProviderFX tdDataProvider, DLControl dlContorl, TDGraphFX tdGraph, PamDataBlock pamDataBlock) {
|
||||
super(tdDataProvider, tdGraph, pamDataBlock);
|
||||
|
||||
@ -86,12 +106,13 @@ public class DLPredictionPlotInfoFX extends TDDataInfoFX {
|
||||
private Polygon drawFrequencyData(int plotNumber, PamDataUnit pamDataUnit, GraphicsContext g, double scrollStart,
|
||||
TDProjectorFX tdProjector, int type) {
|
||||
|
||||
|
||||
Color color = getColor(0).color;
|
||||
|
||||
DLDataUnit dataUnit = (DLDataUnit) pamDataUnit;
|
||||
|
||||
|
||||
g.setStroke(color);
|
||||
g.setLineDashes(null);
|
||||
|
||||
g.setStroke(color);
|
||||
|
||||
// double[] f = pamDataUnit.getFrequency();
|
||||
// if (f == null) {
|
||||
@ -102,7 +123,21 @@ public class DLPredictionPlotInfoFX extends TDDataInfoFX {
|
||||
// }
|
||||
|
||||
if (dataUnit.getPredicitionResult().isBinaryClassification()) {
|
||||
g.setFill(Color.color(this.color.getRed(), this.color.getGreen(), this.color.getBlue(),
|
||||
|
||||
//now which prediction is the highest.
|
||||
int index = -1;
|
||||
double max = Double.NEGATIVE_INFINITY;
|
||||
for (int i=0; i<dataUnit.getPredicitionResult().getPrediction().length; i++) {
|
||||
if (dataUnit.getPredicitionResult().getPrediction()[i]>max) {
|
||||
index = i;
|
||||
}
|
||||
}
|
||||
|
||||
//color for the highest index.
|
||||
Color predCol = getColor(index).color;
|
||||
|
||||
//prediciton has been classified so use a fill.
|
||||
g.setFill(Color.color(predCol.getRed(), predCol.getGreen(), predCol.getBlue(),
|
||||
(Math.min(0.7, dlControl.getDLParams().sampleHop/(double) dlControl.getDLParams().rawSampleSize))));
|
||||
// float[] prediciton = dataUnit.getPredicitionResult().getPrediction();
|
||||
// float max = -(Float.MAX_VALUE+1);
|
||||
@ -201,41 +236,49 @@ public class DLPredictionPlotInfoFX extends TDDataInfoFX {
|
||||
if (tC < -1000 || tC>tdProjector.getWidth()+1000) {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
//TODO -must sort out wrap
|
||||
//dlControl.getDLParams().sampleHop;
|
||||
|
||||
double dataPixel;
|
||||
Coordinate3d c;
|
||||
Color color = Color.color(this.color.getRed(), this.color.getGreen(), this.color.getBlue());
|
||||
Color color;
|
||||
for (int i=0; i<dataUnit.getPredicitionResult().getPrediction().length; i++) {
|
||||
|
||||
g.setStroke(color);
|
||||
color = Color.color(color.getRed()*0.8, color.getGreen()*0.8, color.getBlue()*0.8);
|
||||
if (getColor(i).enabled) {
|
||||
color = getColor(i).color;
|
||||
|
||||
// System.out.println("TDDataInfoFX: tc: "+tC+ " dataUnitTime: "+PamCalendar.formatTime(timeMillis)+" scrollStart: "
|
||||
// +PamCalendar.formatTime((long) scrollStart)+" (timeMillis-scrollStart)/1000. "+((timeMillis-scrollStart)/1000.));
|
||||
g.setStroke(color);
|
||||
g.setFill(color);
|
||||
|
||||
//brighten the colour up.
|
||||
//color = Color.color(color.getRed()*0.8, color.getGreen()*0.8, color.getBlue()*0.8);
|
||||
|
||||
//System.out.println("TDDataInfoFX: tc: "+tC+ " dataUnitTime: "+PamCalendar.formatTime(timeMillis)+" scrollStart: "
|
||||
//+PamCalendar.formatTime((long) scrollStart)+" (timeMillis-scrollStart)/1000. "+((timeMillis-scrollStart)/1000.));
|
||||
|
||||
|
||||
c = tdProjector.getCoord3d(timeMillis, dataUnit.getPredicitionResult().getPrediction()[i], 0);
|
||||
c = tdProjector.getCoord3d(timeMillis, dataUnit.getPredicitionResult().getPrediction()[i], 0);
|
||||
|
||||
dataPixel = tdProjector.getYPix(dataUnit.getPredicitionResult().getPrediction()[i]);
|
||||
dataPixel = tdProjector.getYPix(dataUnit.getPredicitionResult().getPrediction()[i]);
|
||||
|
||||
|
||||
if (lastUnits[i]==null) {
|
||||
lastUnits[i] = new Point2D(tC, dataPixel);
|
||||
return null;
|
||||
}
|
||||
else {
|
||||
if (tC>lastUnits[i].getX()) {
|
||||
//System.out.println("tC: " + tC + " lastUnits[i].getX(): " + lastUnits[i].getX());
|
||||
g.strokeLine(tC, dataPixel, lastUnits[i].getX(), lastUnits[i].getY());
|
||||
if (lastUnits[i]==null) {
|
||||
lastUnits[i] = new Point2D(tC, dataPixel);
|
||||
g.fillOval(tC, dataPixel, 5,5);
|
||||
return null;
|
||||
}
|
||||
else {
|
||||
if (tC>lastUnits[i].getX()) {
|
||||
//System.out.println("tC: " + tC + " lastUnits[i].getX(): " + lastUnits[i].getX());
|
||||
g.strokeLine(tC, dataPixel, lastUnits[i].getX(), lastUnits[i].getY());
|
||||
}
|
||||
lastUnits[i] = new Point2D(tC, dataPixel);
|
||||
}
|
||||
lastUnits[i] = new Point2D(tC, dataPixel);
|
||||
}
|
||||
|
||||
//getSymbolChooser().getPamSymbol(pamDataUnit,type).draw(g, new Point2D(tC, dataPixel));
|
||||
tdProjector.addHoverData(new HoverData(c , pamDataUnit, 0, plotNumber));
|
||||
//getSymbolChooser().getPamSymbol(pamDataUnit,type).draw(g, new Point2D(tC, dataPixel));
|
||||
tdProjector.addHoverData(new HoverData(c , pamDataUnit, 0, plotNumber));
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
@ -267,7 +310,6 @@ public class DLPredictionPlotInfoFX extends TDDataInfoFX {
|
||||
public void notifyChange(int changeType) {
|
||||
//System.out.println("Prediction NOTIFYMODELCHANGED: ");
|
||||
switch (changeType) {
|
||||
|
||||
case PamController.CHANGED_PROCESS_SETTINGS:
|
||||
lastUnits = null;
|
||||
break;
|
||||
@ -276,12 +318,72 @@ public class DLPredictionPlotInfoFX extends TDDataInfoFX {
|
||||
break;
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Get the color.
|
||||
* @param i - the prediction index
|
||||
* @return the color for that prediciton
|
||||
*/
|
||||
public LineInfo getColor(int i) {
|
||||
return this.dlPredParams.lineInfos[i];
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public Double getDataValue(PamDataUnit pamDataUnit) {
|
||||
// TODO Auto-generated method stub
|
||||
return null;
|
||||
}
|
||||
|
||||
/* (non-Javadoc)
|
||||
* @see dataPlots.data.TDDataInfo#getHidingDialogComponent()
|
||||
*/
|
||||
@Override
|
||||
public TDSettingsPane getGraphSettingsPane() {
|
||||
if (predictionSettingsPane==null) {
|
||||
predictionSettingsPane = new DLPredictionPane(this);
|
||||
}
|
||||
return predictionSettingsPane;
|
||||
}
|
||||
|
||||
|
||||
public DLControl getDlControl() {
|
||||
return this.dlControl;
|
||||
|
||||
}
|
||||
|
||||
/* (non-Javadoc)
|
||||
* @see dataPlots.data.TDDataInfo#getStoredSettings()
|
||||
*/
|
||||
@Override
|
||||
public Serializable getStoredSettings() {
|
||||
return dlPredParams;
|
||||
}
|
||||
|
||||
/* (non-Javadoc)
|
||||
* @see dataPlots.data.TDDataInfo#setStoredSettings(java.io.Serializable)
|
||||
*/
|
||||
@Override
|
||||
public boolean setStoredSettings(Serializable storedSettings) {
|
||||
if (ClickDisplayParams.class.isAssignableFrom(storedSettings.getClass())) {
|
||||
dlPredParams = (DLPredDisplayParams) storedSettings;
|
||||
updateSettings();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the DL prediction params.
|
||||
* @return the params
|
||||
*/
|
||||
public DLPredDisplayParams getDlPredParams() {
|
||||
return dlPredParams;
|
||||
}
|
||||
|
||||
|
||||
|
||||
private void updateSettings() {
|
||||
// TODO Auto-generated method stub
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -35,6 +35,6 @@ public class DLPredictionProvider extends TDDataProviderFX {
|
||||
}
|
||||
|
||||
public String getName() {
|
||||
return "Prediciton probability, " + dlControl.getUnitName();
|
||||
return "Prediction probability, " + dlControl.getUnitName();
|
||||
}
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ import rawDeepLearningClassifier.logging.DLAnnotationType;
|
||||
|
||||
|
||||
/**
|
||||
* The DL symbol modifier. Colours symbols by prediciton.
|
||||
* The DL symbol modifier. Colours symbols by prediction.
|
||||
*
|
||||
* @author Jamie Macaulay.
|
||||
*
|
||||
|
@ -20,7 +20,7 @@ import rawDeepLearningClassifier.dlClassification.DLClassName;
|
||||
|
||||
|
||||
/*
|
||||
* Symbol Options for the annotaion pane
|
||||
* Symbol Options for the annotation pane
|
||||
*/
|
||||
public class DLSymbolOptionPane extends StandardSymbolModifierPane {
|
||||
|
||||
@ -55,11 +55,8 @@ public class DLSymbolOptionPane extends StandardSymbolModifierPane {
|
||||
this.setBottom(createProbPane());
|
||||
this.dlSymbolModifier = (DLSymbolModifier) symbolModifer;
|
||||
initialised=true;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
private Pane createProbPane() {
|
||||
|
||||
PamVBox holder = new PamVBox();
|
||||
|
45
src/rawDeepLearningClassifier/dataPlotFX/LineInfo.java
Normal file
45
src/rawDeepLearningClassifier/dataPlotFX/LineInfo.java
Normal file
@ -0,0 +1,45 @@
|
||||
package rawDeepLearningClassifier.dataPlotFX;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
import javafx.scene.paint.Color;
|
||||
|
||||
/**
|
||||
* The line info.
|
||||
*
|
||||
* @author Jamie Macaulay
|
||||
*
|
||||
*/
|
||||
public class LineInfo implements Serializable, Cloneable{
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
public LineInfo(boolean enabled, Color color) {
|
||||
this.enabled = enabled;
|
||||
this.color = color;
|
||||
}
|
||||
|
||||
/**
|
||||
* The colour of the line
|
||||
*/
|
||||
public Color color = Color.DODGERBLUE;
|
||||
|
||||
/**
|
||||
* True if the line is enabled (shown on the plot)
|
||||
*/
|
||||
public boolean enabled = true;
|
||||
|
||||
@Override
|
||||
protected LineInfo clone() {
|
||||
try {
|
||||
return (LineInfo) super.clone();
|
||||
} catch (CloneNotSupportedException e) {
|
||||
e.printStackTrace();
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -367,7 +367,7 @@ public abstract class StandardModelPane extends SettingsPane<StandardModelParams
|
||||
pathLabel .setText(this.currentSelectedFile.getName());
|
||||
try {
|
||||
pathLabel.setTooltip(new Tooltip(this.currentSelectedFile.getPath()
|
||||
+ "\n" +" Processor " + Device.defaultDevice().toString()));
|
||||
+ "\n" +" Processor CPU " + Device.cpu() + " " + Device.gpu()));
|
||||
}
|
||||
catch (Exception e) {
|
||||
//sometimes get an error here for some reason
|
||||
|
@ -63,13 +63,13 @@ public abstract class DLModelWorker<T> {
|
||||
//data input into the model - a stack of spectrogram images.
|
||||
float[][][] transformedDataStack = new float[numChunks][][];
|
||||
|
||||
//geenrate the spectrogram stack.
|
||||
//generate the spectrogram stack.
|
||||
AudioData soundData;
|
||||
double[][] transformedData;
|
||||
double[][] transformedData2; //spec data
|
||||
double[] transformedData1; //waveform data
|
||||
for (int j=0; j<numChunks; j++) {
|
||||
|
||||
soundData = new AudioData(rawDataUnits.get(j).getRawData()[iChan], sampleRate);
|
||||
|
||||
// for (int i=0; i<modelTransforms.size(); i++) {
|
||||
// System.out.println("Transfrom type: " + modelTransforms.get(i).getDLTransformType());
|
||||
// }
|
||||
@ -83,20 +83,25 @@ public abstract class DLModelWorker<T> {
|
||||
DLTransform transform = modelTransforms.get(0);
|
||||
for (int i =0; i<modelTransforms.size(); i++) {
|
||||
transform = modelTransforms.get(i).transformData(transform);
|
||||
|
||||
//// //TEMP
|
||||
// //TEMP
|
||||
// if (transform instanceof FreqTransform) {
|
||||
// transformedData = ((FreqTransform) transform).getSpecTransfrom().getTransformedData();
|
||||
// System.out.println("DLModelWorker: transform : " + modelTransforms.get(i).getDLTransformType() + " "+ i + transformedData.length + " " + transformedData[0].length + " minmax: " + PamArrayUtils.minmax(transformedData)[0] + " " + PamArrayUtils.minmax(transformedData)[1]);
|
||||
// }
|
||||
}
|
||||
|
||||
//the transformed data
|
||||
transformedData = ((FreqTransform) transform).getSpecTransfrom().getTransformedData();
|
||||
|
||||
|
||||
transformedDataStack[j] = DLUtils.toFloatArray(transformedData);
|
||||
if (transform instanceof FreqTransform) {
|
||||
//add a spectrogram to the stacl
|
||||
transformedData2 = ((FreqTransform) transform).getSpecTransfrom().getTransformedData();
|
||||
transformedDataStack[j] = DLUtils.toFloatArray(transformedData2);
|
||||
|
||||
}
|
||||
else {
|
||||
//add wavefrom to the stack = we make the 2nd dimesnion 1.
|
||||
transformedData1 = ((WaveTransform) transform).getWaveData().getScaledSampleAmpliudes();
|
||||
transformedDataStack[j] = new float[1][transformedData1.length];
|
||||
transformedDataStack[j][0] = DLUtils.toFloatArray(transformedData1);
|
||||
}
|
||||
}
|
||||
|
||||
//run the model.
|
||||
@ -106,7 +111,6 @@ public abstract class DLModelWorker<T> {
|
||||
//System.out.println("Model out: " + PamArrayUtils.array2String(output, 2, ","));
|
||||
long time2 = System.currentTimeMillis();
|
||||
|
||||
|
||||
int numclasses = (int) (output.length/transformedDataStack.length);
|
||||
|
||||
// System.out.println(PamCalendar.formatDBDateTime(rawDataUnits.get(0).getTimeMilliseconds(), true) +
|
||||
@ -117,8 +121,6 @@ public abstract class DLModelWorker<T> {
|
||||
float[] prob;
|
||||
float[] classOut;
|
||||
for (int i=0; i<transformedDataStack.length; i++) {
|
||||
|
||||
|
||||
/**
|
||||
* This is super weird. Reading the documentation for copeOfRange the index from and index to are enclusive. So
|
||||
* to copy the first two elements indexfrom =0 and indexto = 1. But actually it seems that this should be indexfrom =0 and indexto =2.
|
||||
@ -128,13 +130,12 @@ public abstract class DLModelWorker<T> {
|
||||
|
||||
// System.out.println("The copyOfRange is: " + i*numclasses + " to " + ((i+1)*numclasses-1) + " class out len: " + classOut.length);
|
||||
|
||||
|
||||
if (enableSoftMax) {
|
||||
prob = new float[classOut.length];
|
||||
for (int j=0; j<classOut.length; j++) {
|
||||
//python code for this.
|
||||
// prob = torch.nn.functional.softmax(out).numpy()[n, 1]
|
||||
// pred = int(prob >= ARGS.threshold)
|
||||
//prob = torch.nn.functional.softmax(out).numpy()[n, 1]
|
||||
//pred = int(prob >= ARGS.threshold)
|
||||
//softmax function
|
||||
prob[j] = (float) DLUtils.softmax(classOut[j], classOut);
|
||||
//System.out.println("The probability is: " + j + ": " + prob[j]);
|
||||
@ -143,13 +144,12 @@ public abstract class DLModelWorker<T> {
|
||||
else {
|
||||
prob = classOut;
|
||||
}
|
||||
|
||||
|
||||
|
||||
//does this pass binary classification
|
||||
long timeEnd = System.nanoTime();
|
||||
|
||||
T modelResult = makeModelResult(prob, (timeEnd-timeStart)/1000/1000/1000);
|
||||
// soundSpotResult.setAnalysisTime((timeEnd-timeStart)/1000/1000/1000);
|
||||
//soundSpotResult.setAnalysisTime((timeEnd-timeStart)/1000/1000/1000);
|
||||
|
||||
modelResults.add(modelResult);
|
||||
}
|
||||
|
@ -168,7 +168,7 @@ public class GenericAdvPane extends SettingsPane<GenericModelParams> {
|
||||
shapeSpinners= new Spinner[4]; //set at for now but could be different in future?
|
||||
|
||||
for (int i=0; i<shapeSpinners.length; i++) {
|
||||
shapeSpinners[i] = new Spinner<Integer>(-1, Integer.MAX_VALUE, 10, 10);
|
||||
shapeSpinners[i] = new Spinner<Integer>(-1, Integer.MAX_VALUE, 10, 1);
|
||||
shapeSpinners[i] .setPrefWidth(80);
|
||||
shapeSpinners[i] .getStyleClass().add(Spinner.STYLE_CLASS_SPLIT_ARROWS_HORIZONTAL);
|
||||
shapeSpinners[i] .setEditable(true);
|
||||
@ -201,7 +201,7 @@ public class GenericAdvPane extends SettingsPane<GenericModelParams> {
|
||||
|
||||
for (int i=0; i<outShapeSpinners.length; i++) {
|
||||
|
||||
outShapeSpinners[i] = new Spinner<Integer>(-1, Integer.MAX_VALUE, 10, 10);
|
||||
outShapeSpinners[i] = new Spinner<Integer>(-1, Integer.MAX_VALUE, 10, 1);
|
||||
outShapeSpinners[i] .setPrefWidth(80);
|
||||
outShapeSpinners[i] .getStyleClass().add(Spinner.STYLE_CLASS_SPLIT_ARROWS_HORIZONTAL);
|
||||
outShapeSpinners[i] .setEditable(true);
|
||||
|
@ -2,6 +2,8 @@ package rawDeepLearningClassifier.dlClassification.genericModel;
|
||||
|
||||
import org.apache.commons.io.FilenameUtils;
|
||||
import org.jamdev.jdl4pam.genericmodel.GenericModel;
|
||||
import org.jamdev.jdl4pam.transforms.DLTransform;
|
||||
import org.jamdev.jdl4pam.transforms.FreqTransform;
|
||||
|
||||
import PamModel.PamModel;
|
||||
import PamModel.PamModel.PluginClassloader;
|
||||
@ -19,13 +21,28 @@ public class GenericModelWorker extends DLModelWorker<GenericPrediction> {
|
||||
/**
|
||||
* The generic model
|
||||
*/
|
||||
private GenericModel genericModel;
|
||||
private PamGenericModel genericModel;
|
||||
|
||||
/**
|
||||
* Frequency transform.
|
||||
*/
|
||||
private boolean freqTransform = true;
|
||||
|
||||
@Override
|
||||
public float[] runModel(float[][][] transformedDataStack) {
|
||||
//System.out.println("RUN GENERIC MODEL: " + transformedDataStack.length + " " + transformedDataStack[0].length + " " + transformedDataStack[0][0].length);
|
||||
//System.out.println("RUN GENERIC MODEL: " + transformedDataStack[0][0][0]);
|
||||
float[] results = genericModel.runModel(transformedDataStack);
|
||||
float[] results;
|
||||
if (freqTransform)
|
||||
results = genericModel.runModel2(transformedDataStack);
|
||||
else {
|
||||
//run a model if it is waveform info.
|
||||
float[][] waveStack = new float[transformedDataStack.length][];
|
||||
for (int i=0; i<waveStack.length; i++) {
|
||||
waveStack[i] = transformedDataStack[i][0];
|
||||
}
|
||||
results = genericModel.runModel1(waveStack);
|
||||
}
|
||||
//System.out.println("GENERIC MODEL RESULTS: " + results== null ? null : results.length);
|
||||
return results;
|
||||
}
|
||||
@ -39,7 +56,7 @@ public class GenericModelWorker extends DLModelWorker<GenericPrediction> {
|
||||
|
||||
@Override
|
||||
public void prepModel(StandardModelParams genericParams, DLControl dlControl) {
|
||||
ClassLoader origCL = Thread.currentThread().getContextClassLoader();
|
||||
//ClassLoader origCL = Thread.currentThread().getContextClassLoader();
|
||||
try {
|
||||
if (genericParams.modelPath==null) return;
|
||||
|
||||
@ -47,14 +64,24 @@ public class GenericModelWorker extends DLModelWorker<GenericPrediction> {
|
||||
// NOTE THAT THIS IS REQUIRED TO MAKE THIS MODULE RUN AS A PLUGIN WHEN THE CLASS FILES
|
||||
// ARE BUNDLED INTO A FATJAR, HOWEVER THIS WILL STOP THE PLUGIN FROM RUNNING AS A SEPARATE
|
||||
// PROJECT IN ECLIPSE. So while testing the code and debugging, make sure the
|
||||
if (DLControl.PLUGIN_BUILD) {
|
||||
PluginClassloader newCL = PamModel.getPamModel().getClassLoader();
|
||||
Thread.currentThread().setContextClassLoader(newCL);
|
||||
}
|
||||
// if (DLControl.PLUGIN_BUILD) {
|
||||
// PluginClassloader newCL = PamModel.getPamModel().getClassLoader();
|
||||
// Thread.currentThread().setContextClassLoader(newCL);
|
||||
// }
|
||||
|
||||
//first open the model and get the correct parameters.
|
||||
genericModel = new GenericModel(genericParams.modelPath);
|
||||
genericModel = new PamGenericModel(genericParams.modelPath);
|
||||
|
||||
//is this a waveform or a spectrogram model?
|
||||
DLTransform transform = genericParams.dlTransfroms.get(genericParams.dlTransfroms.size()-1);
|
||||
if (transform instanceof FreqTransform) {
|
||||
freqTransform = true;
|
||||
}
|
||||
else {
|
||||
freqTransform = false;
|
||||
}
|
||||
|
||||
//use softmax or not?
|
||||
String extension = FilenameUtils.getExtension(genericParams.modelPath);
|
||||
if (extension.equals("pb")) {
|
||||
//TensorFlow models don't need softmax?? Need to look into this more.
|
||||
@ -63,7 +90,6 @@ public class GenericModelWorker extends DLModelWorker<GenericPrediction> {
|
||||
else {
|
||||
this.setEnableSoftMax(true);
|
||||
}
|
||||
|
||||
|
||||
GenericModelParams genericModelParams = new GenericModelParams();
|
||||
|
||||
@ -78,8 +104,7 @@ public class GenericModelWorker extends DLModelWorker<GenericPrediction> {
|
||||
//WarnOnce.showWarning(null, "Model Load Error", "There was an error loading the model file.", WarnOnce.OK_OPTION);
|
||||
}
|
||||
|
||||
Thread.currentThread().setContextClassLoader(origCL);
|
||||
|
||||
//Thread.currentThread().setContextClassLoader(origCL);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -99,14 +124,14 @@ public class GenericModelWorker extends DLModelWorker<GenericPrediction> {
|
||||
|
||||
@Override
|
||||
public void closeModel() {
|
||||
// TODO Auto-generated method stub
|
||||
genericModel.getModel().close();
|
||||
}
|
||||
|
||||
/**
|
||||
* Generic model.
|
||||
* @return the generic model.
|
||||
*/
|
||||
public GenericModel getModel() {
|
||||
public PamGenericModel getModel() {
|
||||
return genericModel;
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,218 @@
|
||||
package rawDeepLearningClassifier.dlClassification.genericModel;
|
||||
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import org.apache.commons.io.FilenameUtils;
|
||||
import org.jamdev.jdl4pam.genericmodel.SpectrogramTranslator;
|
||||
import org.jamdev.jdl4pam.genericmodel.WaveformTranslator;
|
||||
|
||||
import ai.djl.MalformedModelException;
|
||||
import ai.djl.Model;
|
||||
import ai.djl.engine.Engine;
|
||||
import ai.djl.inference.Predictor;
|
||||
import ai.djl.ndarray.types.Shape;
|
||||
import ai.djl.translate.TranslateException;
|
||||
|
||||
|
||||
/**
|
||||
* The generic model. This is implemented in the JPAM library as Generic model
|
||||
* but having it here gives a little more control without requiring constant changes
|
||||
* the Maven dependencies.
|
||||
*
|
||||
* @author Jamie Macaulay
|
||||
*
|
||||
*/
|
||||
public class PamGenericModel {
|
||||
|
||||
|
||||
/**
|
||||
* The currently loaded model
|
||||
*/
|
||||
private Model model;
|
||||
|
||||
/**
|
||||
* The predictor for the model.
|
||||
*/
|
||||
Predictor<float[][][], float[]> specPredictor;
|
||||
|
||||
|
||||
/**
|
||||
* Predictor for the model for waveforms.
|
||||
*/
|
||||
Predictor<float[][], float[]> wavePredictor;
|
||||
|
||||
|
||||
/**
|
||||
* The input shape from the loaded model.
|
||||
*/
|
||||
private Shape inputShape = null;
|
||||
|
||||
/**
|
||||
* The output shape from the model.
|
||||
*/
|
||||
private Shape outShape = null;
|
||||
|
||||
|
||||
private SpectrogramTranslator specTranslator;
|
||||
|
||||
|
||||
private WaveformTranslator waveTranslator;
|
||||
|
||||
|
||||
|
||||
public PamGenericModel(String modelPath) throws MalformedModelException, IOException{
|
||||
|
||||
File file = new File(modelPath);
|
||||
|
||||
//String modelPath = "/Users/au671271/Google Drive/Aarhus_research/PAMGuard_bats_2020/deep_learning/BAT/models/bats_denmark/BAT_4ms_256ft_8hop_128_NOISEAUG_40000_100000_-100_0_256000_JAMIE.pk";
|
||||
|
||||
Path modelDir = Paths.get(file.getAbsoluteFile().getParent()); //the directory of the file (in case the file is local this should also return absolute directory)
|
||||
String modelName = file.getName();
|
||||
|
||||
String extension = FilenameUtils.getExtension(file.getAbsolutePath());
|
||||
|
||||
System.out.println("Generic Model: Available engines: " + Engine.getAllEngines());
|
||||
|
||||
Model model;
|
||||
switch (extension) {
|
||||
case "pb":
|
||||
model = Model.newInstance(modelPath, "TensorFlow");
|
||||
model.load(modelDir, modelName);
|
||||
break;
|
||||
case "py":
|
||||
model = Model.newInstance(modelName);
|
||||
model.load(modelDir, modelName);
|
||||
break;
|
||||
default:
|
||||
//will try to load a model automatically - problematic but let's see.
|
||||
model = Model.newInstance(modelPath);
|
||||
break;
|
||||
}
|
||||
|
||||
if (model == null) {
|
||||
System.err.println("Generic Model: Could not load model: " + modelPath);
|
||||
}
|
||||
|
||||
else {
|
||||
if (model!=null && model.describeInput()!=null) {
|
||||
System.out.println("Generic Model: Input: " + model.describeInput().toString());
|
||||
inputShape = model.describeInput().get(0).getValue();
|
||||
}
|
||||
if (model!=null && model.describeOutput()!=null) {
|
||||
System.out.println("Generic Model: Output: " + model.describeOutput().toString());
|
||||
outShape = model.describeOutput().get(0).getValue();
|
||||
}
|
||||
|
||||
this.model=model;
|
||||
|
||||
specTranslator = new SpectrogramTranslator(inputShape);
|
||||
|
||||
waveTranslator = new WaveformTranslator(model.describeInput());
|
||||
|
||||
//predictor for the model if using images as input
|
||||
specPredictor = model.newPredictor(specTranslator);
|
||||
|
||||
//predictor for the model if using
|
||||
wavePredictor = model.newPredictor(waveTranslator);
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the predictor for spectrogram images.
|
||||
* @return
|
||||
*/
|
||||
public Predictor<float[][][], float[]> getSpecPredictor() {
|
||||
return specPredictor;
|
||||
}
|
||||
|
||||
/***
|
||||
* Get the predictor for the waveform input.
|
||||
* @return the predictor for waveforms.
|
||||
*/
|
||||
public Predictor<float[][], float[]> getWavePredictor() {
|
||||
return wavePredictor;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the model shape for the input.
|
||||
* @return the input shape.
|
||||
*/
|
||||
public Shape getInputShape() {
|
||||
return inputShape;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the input shape.
|
||||
* @param inputShape - the input shape.
|
||||
*/
|
||||
public void setInputShape(Shape inputShape) {
|
||||
this.inputShape = inputShape;
|
||||
specTranslator.setShape(inputShape);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the output shape. The shape is null if the model does not specify shape.
|
||||
* @return the output shape.
|
||||
*/
|
||||
public Shape getOutShape() {
|
||||
return outShape;
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Run the model on spectrogram images
|
||||
* @param specImage - the spectrogram image [no. batches][image x][image y]
|
||||
* @return the results
|
||||
*/
|
||||
public float[] runModel2(float[][][] specImage) {
|
||||
try {
|
||||
float[] results = specPredictor.predict(specImage);
|
||||
//DLUtils.printArray(results);
|
||||
return results;
|
||||
} catch (TranslateException e) {
|
||||
System.out.println("Error on model: ");
|
||||
e.printStackTrace();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Run the model on a raw waveform data
|
||||
* @param specImage - waveform data [no. batches][samples]
|
||||
* @return the results
|
||||
*/
|
||||
public float[] runModel1(float[][] waveform) {
|
||||
try {
|
||||
float[] results = wavePredictor.predict(waveform);
|
||||
//DLUtils.printArray(results);
|
||||
return results;
|
||||
} catch (TranslateException e) {
|
||||
System.out.println("Error on model: ");
|
||||
e.printStackTrace();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
public Model getModel() {
|
||||
return model;
|
||||
}
|
||||
|
||||
public void setModel(Model model) {
|
||||
this.model = model;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
@ -109,6 +109,8 @@ public class RawDLSettingsPane extends SettingsPane<RawDLParams>{
|
||||
*/
|
||||
private HBox dataSelectorPane;
|
||||
|
||||
private Label infoLabel;
|
||||
|
||||
public RawDLSettingsPane(DLControl dlControl){
|
||||
super(null);
|
||||
this.dlControl=dlControl;
|
||||
@ -173,12 +175,18 @@ public class RawDLSettingsPane extends SettingsPane<RawDLParams>{
|
||||
windowLength.setPrefWidth(100);
|
||||
windowLength.getStyleClass().add(Spinner.STYLE_CLASS_SPLIT_ARROWS_HORIZONTAL);
|
||||
windowLength.setEditable(true);
|
||||
windowLength.valueProperty().addListener((obsVal, oldVal, newVal)->{
|
||||
setSegInfoLabel();
|
||||
});
|
||||
|
||||
hopLength = new PamSpinner<Integer>(0, Integer.MAX_VALUE, 10, 10000);
|
||||
hopLength.setPrefWidth(100);
|
||||
hopLength.getStyleClass().add(Spinner.STYLE_CLASS_SPLIT_ARROWS_HORIZONTAL);
|
||||
hopLength.setEditable(true);
|
||||
|
||||
hopLength.valueProperty().addListener((obsVal, oldVal, newVal)->{
|
||||
setSegInfoLabel();
|
||||
});
|
||||
|
||||
reMergeSeg = new PamSpinner<Integer>(0, Integer.MAX_VALUE, 1, 1);
|
||||
reMergeSeg.setPrefWidth(100);
|
||||
reMergeSeg.getStyleClass().add(Spinner.STYLE_CLASS_SPLIT_ARROWS_HORIZONTAL);
|
||||
@ -208,6 +216,9 @@ public class RawDLSettingsPane extends SettingsPane<RawDLParams>{
|
||||
segmenterGridPane.add(new Label("segments"), 2, 2);
|
||||
|
||||
vBox.getChildren().add(segmenterGridPane);
|
||||
|
||||
vBox.getChildren().add(infoLabel = new Label());
|
||||
|
||||
|
||||
Label label2 = new Label("Deep Learning Model");
|
||||
label2.setPadding(new Insets(5,0,0,0));
|
||||
@ -267,6 +278,23 @@ public class RawDLSettingsPane extends SettingsPane<RawDLParams>{
|
||||
return dataSelectorPane;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set extra information in the info label.
|
||||
*/
|
||||
private void setSegInfoLabel() {
|
||||
String text;
|
||||
if (sourcePane.getSource()==null) {
|
||||
text = String.format("Window - s Hop: - s (no source data)");
|
||||
}
|
||||
else {
|
||||
float sR = sourcePane.getSource().getSampleRate();
|
||||
double windowLenS = windowLength.getValue()/sR;
|
||||
double hopLengthS = hopLength.getValue()/sR;
|
||||
|
||||
text = String.format("Window %.3f s Hop: %.3f s", windowLenS, hopLengthS);
|
||||
}
|
||||
infoLabel.setText(text);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates pane allowing the user to change fine scale things such as error limits.
|
||||
@ -442,6 +470,8 @@ public class RawDLSettingsPane extends SettingsPane<RawDLParams>{
|
||||
setClassifierPane();
|
||||
|
||||
enableControls();
|
||||
|
||||
setSegInfoLabel();
|
||||
|
||||
}
|
||||
|
||||
|
@ -225,10 +225,14 @@ public abstract class DLTransformImage extends PamBorderPane{
|
||||
double nSamples = this.exampleSound.getSampleRate()*((timeSlider.getHighValue()-timeSlider.getLowValue())/1000.0);
|
||||
|
||||
|
||||
if (shape!=null) {
|
||||
if (shape!=null && shape.length==2) {
|
||||
int timeShape = (int) (shape[0]*(nSamples/(double) exampleSound.getWave().length));
|
||||
timeLabel.setText(String.format("Segment size %.0f (samples) Transform shape: [%d %d]", nSamples, timeShape, shape[1]));
|
||||
}
|
||||
else if (shape!=null && shape.length==1) {
|
||||
int timeShape = (int) (shape[0]*(nSamples/(double) exampleSound.getWave().length));
|
||||
timeLabel.setText(String.format("Segment size %.0f (samples) Transform shape: [%d]", nSamples, timeShape));
|
||||
}
|
||||
else {
|
||||
timeLabel.setText(String.format("Segment size %.0f (samples) Transform shape: NaN", nSamples));
|
||||
}
|
||||
@ -245,14 +249,24 @@ public abstract class DLTransformImage extends PamBorderPane{
|
||||
ArrayList<DLTransform> transforms = getDLTransforms();
|
||||
|
||||
if (transforms==null) return null;
|
||||
|
||||
FreqTransform freqTransform = ((FreqTransform) transforms.get(transforms.size()-1));
|
||||
|
||||
if (freqTransform.getSpecTransfrom()==null) return null;
|
||||
|
||||
double[][] data2D = freqTransform.getSpecTransfrom().getTransformedData();
|
||||
|
||||
long[] shape = new long[] {data2D.length, data2D[0].length};
|
||||
long[] shape;
|
||||
if (transforms.get(transforms.size()-1) instanceof FreqTransform) {
|
||||
FreqTransform freqTransform = ((FreqTransform) transforms.get(transforms.size()-1));
|
||||
|
||||
if (freqTransform.getSpecTransfrom()==null) return null;
|
||||
|
||||
double[][] data2D = freqTransform.getSpecTransfrom().getTransformedData();
|
||||
|
||||
shape = new long[] {data2D.length, data2D[0].length};
|
||||
}
|
||||
else {
|
||||
WaveTransform waveTransform = ((WaveTransform) transforms.get(transforms.size()-1));
|
||||
|
||||
double[] data = waveTransform.getWaveData().getScaledSampleAmpliudes();
|
||||
|
||||
shape = new long[] {data.length};
|
||||
}
|
||||
|
||||
return shape;
|
||||
}
|
||||
|
@ -40,8 +40,15 @@ public class DataTransformPaneFactory {
|
||||
DLTransformPane settingsPane = null;
|
||||
switch (dlTransfrom.getDLTransformType()) {
|
||||
case DECIMATE:
|
||||
settingsPane = new SimpleTransformPane((SimpleTransform) dlTransfrom, new String[]{"Sample rate "}, new String[]{"Hz. "});
|
||||
double sR = ((SimpleTransform) dlTransfrom).getParams()[0].doubleValue();
|
||||
double sR;
|
||||
if (((SimpleTransform) dlTransfrom).getParams()!=null){
|
||||
sR = ((SimpleTransform) dlTransfrom).getParams()[0].doubleValue();
|
||||
}
|
||||
else {
|
||||
sR = 20000;
|
||||
((SimpleTransform) dlTransfrom).setParams(new Number[] {Double.valueOf(sR)});
|
||||
}
|
||||
settingsPane = new SimpleTransformPane((SimpleTransform) dlTransfrom, new String[]{"Sample rate "}, new String[]{"Hz. "});
|
||||
((SimpleTransformPane) settingsPane).setSpinnerMinMaxValues(0, 100.0, Double.MAX_VALUE, sR>10000.0 ? 1000.0 : 100.0);
|
||||
break;
|
||||
case PREEMPHSIS:
|
||||
|
Loading…
Reference in New Issue
Block a user