mirror of
https://github.com/PAMGuard/PAMGuard.git
synced 2024-11-21 22:52:22 +00:00
Getting click deep learning model
This commit is contained in:
parent
1174e371f0
commit
fc4c9495cf
4
.gitignore
vendored
4
.gitignore
vendored
@ -112,3 +112,7 @@ settings.xml
|
||||
.classpath
|
||||
.settings/org.eclipse.jdt.core.prefs
|
||||
.classpath
|
||||
.classpath
|
||||
.settings/org.eclipse.core.resources.prefs
|
||||
.settings/org.eclipse.jdt.core.prefs
|
||||
dependency-reduced-pom.xml
|
||||
|
@ -1,6 +1,5 @@
|
||||
eclipse.preferences.version=1
|
||||
encoding//src/rawDeepLearningClassifer/segmenter/SegmenterProcess.java=UTF-8
|
||||
encoding//src/test=UTF-8
|
||||
encoding//src/test/resources=UTF-8
|
||||
encoding/<project>=UTF-8
|
||||
encoding/src=UTF-8
|
||||
|
@ -1,9 +1,9 @@
|
||||
eclipse.preferences.version=1
|
||||
org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode=enabled
|
||||
org.eclipse.jdt.core.compiler.codegen.methodParameters=do not generate
|
||||
org.eclipse.jdt.core.compiler.codegen.targetPlatform=21
|
||||
org.eclipse.jdt.core.compiler.codegen.targetPlatform=18
|
||||
org.eclipse.jdt.core.compiler.codegen.unusedLocal=preserve
|
||||
org.eclipse.jdt.core.compiler.compliance=21
|
||||
org.eclipse.jdt.core.compiler.compliance=18
|
||||
org.eclipse.jdt.core.compiler.debug.lineNumber=generate
|
||||
org.eclipse.jdt.core.compiler.debug.localVariable=generate
|
||||
org.eclipse.jdt.core.compiler.debug.sourceFile=generate
|
||||
@ -13,4 +13,4 @@ org.eclipse.jdt.core.compiler.problem.enumIdentifier=error
|
||||
org.eclipse.jdt.core.compiler.problem.forbiddenReference=warning
|
||||
org.eclipse.jdt.core.compiler.problem.reportPreviewFeatures=warning
|
||||
org.eclipse.jdt.core.compiler.release=enabled
|
||||
org.eclipse.jdt.core.compiler.source=21
|
||||
org.eclipse.jdt.core.compiler.source=18
|
||||
|
@ -4,7 +4,7 @@
|
||||
<groupId>org.pamguard</groupId>
|
||||
<artifactId>Pamguard</artifactId>
|
||||
<name>Pamguard</name>
|
||||
<version>2.02.12</version>
|
||||
<version>2.02.11d</version>
|
||||
<description>Pamguard using Maven to control dependencies</description>
|
||||
<url>www.pamguard.org</url>
|
||||
<organization>
|
||||
|
@ -713,6 +713,34 @@ public class PamArrayUtils {
|
||||
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Normalise an array
|
||||
* @param arr - the array to normalise
|
||||
* @param scaleFactor - multiply the resulting array by a scale factor.
|
||||
* @return normalised copy of the array
|
||||
*/
|
||||
public static float[] normalise(float[] arr, double scaleFactor) {
|
||||
// //first find the sum of the square of the wave
|
||||
if (arr != null) {
|
||||
int n = arr.length;
|
||||
double sum = 0.0;
|
||||
|
||||
for (int i = 0; i < n; i++) {
|
||||
sum += arr[i] * arr[i];
|
||||
}
|
||||
sum = Math.pow(sum, 0.5);
|
||||
|
||||
float[] normArr=new float[arr.length];
|
||||
for (int i=0; i<normArr.length; i++) {
|
||||
normArr[i]=(float) (scaleFactor*arr[i]/sum);
|
||||
}
|
||||
return normArr;
|
||||
}
|
||||
else return null;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Flip a double array so that it is in the reverse order. Note the array is
|
||||
* cloned.
|
||||
@ -1274,6 +1302,7 @@ public class PamArrayUtils {
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
@ -1,5 +1,7 @@
|
||||
package rawDeepLearningClassifier.dlClassification.genericModel;
|
||||
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.ArrayList;
|
||||
|
||||
import org.jamdev.jdl4pam.genericmodel.GenericModel;
|
||||
@ -13,6 +15,13 @@ import org.jamdev.jdl4pam.transforms.WaveTransform;
|
||||
import org.jamdev.jdl4pam.utils.DLUtils;
|
||||
import org.jamdev.jpamutils.wavFiles.AudioData;
|
||||
|
||||
import PamUtils.PamArrayUtils;
|
||||
import rawDeepLearningClassifier.segmenter.GroupedRawData;
|
||||
import us.hebi.matlab.mat.format.Mat5;
|
||||
import us.hebi.matlab.mat.format.Mat5File;
|
||||
import us.hebi.matlab.mat.types.Matrix;
|
||||
import us.hebi.matlab.mat.types.Struct;
|
||||
|
||||
public class GenericModelTest {
|
||||
|
||||
public static void rightWhaleTest() {
|
||||
@ -98,9 +107,126 @@ public class GenericModelTest {
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
public static void clickDLTest() {
|
||||
|
||||
float SAMPLE_RATE = 500000;
|
||||
//relative paths to the resource folders.
|
||||
System.out.println("*****Click classification Deep Learning C*****");
|
||||
|
||||
//relative paths to the resource folders.
|
||||
String relModelPath = "./src/test/resources/rawDeepLearningClassifier/Generic/risso_click/uniform_model/saved_model.pb";
|
||||
String clicksPath = "./src/test/resources/rawDeepLearningClassifier/Generic/risso_click/clicks.mat";
|
||||
|
||||
Path path = Paths.get(relModelPath);
|
||||
|
||||
GenericModelWorker genericModelWorker = new GenericModelWorker();
|
||||
|
||||
GenericModelParams genericModelParams = new GenericModelParams();
|
||||
genericModelParams.modelPath = path.toAbsolutePath().normalize().toString();
|
||||
|
||||
//create the transforms.
|
||||
ArrayList<DLTransfromParams> dlTransformParamsArr = new ArrayList<DLTransfromParams>();
|
||||
//waveform transforms.
|
||||
dlTransformParamsArr.add(new SimpleTransformParams(DLTransformType.DECIMATE_SCIPY, 250000.));
|
||||
dlTransformParamsArr.add(new SimpleTransformParams(DLTransformType.PEAK_TRIM, 128, 1));
|
||||
dlTransformParamsArr.add(new SimpleTransformParams(DLTransformType.NORMALISE_WAV, 0., 1, AudioData.ZSCORE));
|
||||
|
||||
genericModelParams.dlTransfromParams = dlTransformParamsArr;
|
||||
genericModelParams.dlTransfroms = DLTransformsFactory.makeDLTransforms((ArrayList<DLTransfromParams>)genericModelParams.dlTransfromParams);
|
||||
|
||||
|
||||
//create the clicks.
|
||||
path = Paths.get(clicksPath);
|
||||
ArrayList<GroupedRawData> clicks = importClicks(path.toAbsolutePath().normalize().toString(), SAMPLE_RATE);
|
||||
|
||||
//prep the model
|
||||
genericModelWorker.prepModel(genericModelParams, null);
|
||||
|
||||
System.out.println("Model has loaded");
|
||||
|
||||
ArrayList<GroupedRawData> groupedData = new ArrayList<GroupedRawData>();
|
||||
|
||||
|
||||
float prediction = 0;
|
||||
|
||||
for (int i=0; i<clicks.size() ; i++) {
|
||||
groupedData = new ArrayList<GroupedRawData>();
|
||||
groupedData.add(clicks.get(i)); //TODO for loop
|
||||
|
||||
// System.out.println("Waveform input: " + groupedData.get(i).getRawData().length + " " + groupedData.get(i).getRawData()[0].length + " " + + groupedData.get(i).getRawData()[0][0]);
|
||||
|
||||
//RUN THE RAW MODEL
|
||||
// System.out.println("Min max before: ");
|
||||
// PamArrayUtils.printArray(PamArrayUtils.minmax(groupedData.get(i).getRawData()[0]));
|
||||
|
||||
// double[] wav = PamArrayUtils.normalise(groupedData.get(i).getRawData()[0]);
|
||||
//
|
||||
// System.out.println("Min max: ");
|
||||
// PamArrayUtils.printArray(PamArrayUtils.minmax(wav));
|
||||
// float[][] input1 = new float[][] {PamArrayUtils.double2Float(wav)};
|
||||
// float[] output1 = genericModelWorker.getModel().runModel(input1);
|
||||
// System.out.println("Output1: " );
|
||||
// PamArrayUtils.printArray(output1);
|
||||
|
||||
//RUN THROUGH THE GENERIC MODEL CLASSIIFER.
|
||||
ArrayList<StandardPrediction> genericPrediction = genericModelWorker.runModel(groupedData,SAMPLE_RATE, 0);
|
||||
|
||||
float[] output = genericPrediction.get(0).getPrediction();
|
||||
|
||||
System.out.println(String.format("Click %d Predicted output: %.6f true output: %.6f passed: %b", clicks.get(i).getUID(),
|
||||
output[0], prediction, output[0]>prediction*0.9 && output[0]<prediction*1.1));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Import a bunch of clicks from a .mat file
|
||||
*/
|
||||
public static ArrayList<GroupedRawData> importClicks(String filePath, float sR) {
|
||||
try {
|
||||
Mat5File mfr = Mat5.readFromFile(filePath);
|
||||
|
||||
// //get array of a name "my_array" from file
|
||||
Struct mlArrayRetrived = mfr.getStruct( "clickpreds" );
|
||||
|
||||
int numClicks= mlArrayRetrived.getNumCols();
|
||||
ArrayList<GroupedRawData> clicks = new ArrayList<GroupedRawData>(numClicks);
|
||||
|
||||
GroupedRawData clickData;
|
||||
for (int i=0; i<numClicks; i++) {
|
||||
Matrix clickWav= mlArrayRetrived.get("wave", i);
|
||||
|
||||
double[][] clickwaveform= PamArrayUtils.matrix2array(clickWav);
|
||||
|
||||
clickwaveform = PamArrayUtils.transposeMatrix(clickwaveform);
|
||||
//System.out.println("click: " + click[0].length + " num: " + numClicks);
|
||||
|
||||
Matrix clickUID= mlArrayRetrived.get("UID", i);
|
||||
Matrix clickmillis= mlArrayRetrived.get("millis", i);
|
||||
Matrix channelMap= mlArrayRetrived.get("channelMap", i);
|
||||
Matrix startSample= mlArrayRetrived.get("startSample", i);
|
||||
Matrix sampleDuration= mlArrayRetrived.get("sampleDuration", i);
|
||||
|
||||
clickData = new GroupedRawData(clickmillis.getLong(0), channelMap.getInt(0), startSample.getLong(0), sampleDuration.getLong(0), sampleDuration.getInt(0));
|
||||
clickData.setUID(clickUID.getLong(0));
|
||||
clickData.setRawData(clickwaveform);
|
||||
|
||||
clicks.add(clickData);
|
||||
}
|
||||
return clicks;
|
||||
}
|
||||
catch (Exception e) {
|
||||
// TODO Auto-generated catch block
|
||||
e.printStackTrace();
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* The bat Pytorch test.
|
||||
*/
|
||||
@ -110,7 +236,8 @@ public class GenericModelTest {
|
||||
}
|
||||
|
||||
public static void main(String args[]) {
|
||||
rightWhaleTest();
|
||||
// rightWhaleTest();
|
||||
clickDLTest();
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import org.apache.commons.io.FilenameUtils;
|
||||
import org.jamdev.jdl4pam.genericmodel.GenericModel;
|
||||
import org.jamdev.jdl4pam.transforms.DLTransform;
|
||||
import org.jamdev.jdl4pam.transforms.FreqTransform;
|
||||
import org.jamdev.jpamutils.JamArr;
|
||||
|
||||
import PamModel.PamModel;
|
||||
import PamModel.PamModel.PluginClassloader;
|
||||
@ -42,10 +43,12 @@ public class GenericModelWorker extends DLModelWorker<StandardPrediction> {
|
||||
//run a model if it is waveform info.
|
||||
float[][] waveStack = new float[transformedDataStack.length][];
|
||||
for (int i=0; i<waveStack.length; i++) {
|
||||
// waveStack[i] = PamArrayUtils.double2Float(JamArr.product(PamArrayUtils.float2Double(transformedDataStack[i][0]), 0.99));
|
||||
waveStack[i] = transformedDataStack[i][0];
|
||||
}
|
||||
|
||||
System.out.println("RUN GENERIC MODEL WAVE: " + waveStack.length + " " + waveStack[0].length + " " + waveStack[0][0]);
|
||||
// System.out.println("RUN GENERIC MODEL WAVE: " + waveStack.length + " " + waveStack[0].length + " " + waveStack[0][0]);
|
||||
// PamArrayUtils.printArray(waveStack[0]);
|
||||
results = getModel().runModel(waveStack);
|
||||
}
|
||||
// System.out.println("GENERIC MODEL RESULTS: " + (results== null ? null : results.length));
|
||||
|
@ -31,10 +31,10 @@ public class ClickDLTest {
|
||||
|
||||
float SAMPLE_RATE = 500000;
|
||||
//relative paths to the resource folders.
|
||||
System.out.println("*****Click classification Deep Learning*****");
|
||||
System.out.println("*****Click classification Deep Learning C*****");
|
||||
|
||||
//relative paths to the resource folders.
|
||||
String relModelPath = "./src/test/resources/rawDeepLearningClassifier/Generic/risso_click/best_model/saved_model.pb";
|
||||
String relModelPath = "./src/test/resources/rawDeepLearningClassifier/Generic/risso_click/updated_model/saved_model.pb";
|
||||
String clicksPath = "./src/test/resources/rawDeepLearningClassifier/Generic/risso_click/clicks.mat";
|
||||
|
||||
Path path = Paths.get(relModelPath);
|
||||
@ -62,6 +62,7 @@ public class ClickDLTest {
|
||||
//prep the model
|
||||
genericModelWorker.prepModel(genericModelParams, null);
|
||||
|
||||
System.out.println("Model has loaded");
|
||||
|
||||
ArrayList<GroupedRawData> groupedData = new ArrayList<GroupedRawData>();
|
||||
|
||||
@ -126,7 +127,7 @@ public class ClickDLTest {
|
||||
e.printStackTrace();
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because one or more lines are too long
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue
Block a user