Getting click deep learning model

This commit is contained in:
Jamie Mac 2024-08-07 06:11:24 +01:00
parent 1174e371f0
commit fc4c9495cf
18 changed files with 201 additions and 17 deletions

4
.gitignore vendored
View File

@ -112,3 +112,7 @@ settings.xml
.classpath .classpath
.settings/org.eclipse.jdt.core.prefs .settings/org.eclipse.jdt.core.prefs
.classpath .classpath
.classpath
.settings/org.eclipse.core.resources.prefs
.settings/org.eclipse.jdt.core.prefs
dependency-reduced-pom.xml

View File

@ -1,6 +1,5 @@
eclipse.preferences.version=1 eclipse.preferences.version=1
encoding//src/rawDeepLearningClassifer/segmenter/SegmenterProcess.java=UTF-8 encoding//src/rawDeepLearningClassifer/segmenter/SegmenterProcess.java=UTF-8
encoding//src/test=UTF-8
encoding//src/test/resources=UTF-8 encoding//src/test/resources=UTF-8
encoding/<project>=UTF-8 encoding/<project>=UTF-8
encoding/src=UTF-8 encoding/src=UTF-8

View File

@ -1,9 +1,9 @@
eclipse.preferences.version=1 eclipse.preferences.version=1
org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode=enabled org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode=enabled
org.eclipse.jdt.core.compiler.codegen.methodParameters=do not generate org.eclipse.jdt.core.compiler.codegen.methodParameters=do not generate
org.eclipse.jdt.core.compiler.codegen.targetPlatform=21 org.eclipse.jdt.core.compiler.codegen.targetPlatform=18
org.eclipse.jdt.core.compiler.codegen.unusedLocal=preserve org.eclipse.jdt.core.compiler.codegen.unusedLocal=preserve
org.eclipse.jdt.core.compiler.compliance=21 org.eclipse.jdt.core.compiler.compliance=18
org.eclipse.jdt.core.compiler.debug.lineNumber=generate org.eclipse.jdt.core.compiler.debug.lineNumber=generate
org.eclipse.jdt.core.compiler.debug.localVariable=generate org.eclipse.jdt.core.compiler.debug.localVariable=generate
org.eclipse.jdt.core.compiler.debug.sourceFile=generate org.eclipse.jdt.core.compiler.debug.sourceFile=generate
@ -13,4 +13,4 @@ org.eclipse.jdt.core.compiler.problem.enumIdentifier=error
org.eclipse.jdt.core.compiler.problem.forbiddenReference=warning org.eclipse.jdt.core.compiler.problem.forbiddenReference=warning
org.eclipse.jdt.core.compiler.problem.reportPreviewFeatures=warning org.eclipse.jdt.core.compiler.problem.reportPreviewFeatures=warning
org.eclipse.jdt.core.compiler.release=enabled org.eclipse.jdt.core.compiler.release=enabled
org.eclipse.jdt.core.compiler.source=21 org.eclipse.jdt.core.compiler.source=18

View File

@ -4,7 +4,7 @@
<groupId>org.pamguard</groupId> <groupId>org.pamguard</groupId>
<artifactId>Pamguard</artifactId> <artifactId>Pamguard</artifactId>
<name>Pamguard</name> <name>Pamguard</name>
<version>2.02.12</version> <version>2.02.11d</version>
<description>Pamguard using Maven to control dependencies</description> <description>Pamguard using Maven to control dependencies</description>
<url>www.pamguard.org</url> <url>www.pamguard.org</url>
<organization> <organization>

View File

@ -713,6 +713,34 @@ public class PamArrayUtils {
} }
/**
* Normalise an array
* @param arr - the array to normalise
* @param scaleFactor - multiply the resulting array by a scale factor.
* @return normalised copy of the array
*/
public static float[] normalise(float[] arr, double scaleFactor) {
// //first find the sum of the square of the wave
if (arr != null) {
int n = arr.length;
double sum = 0.0;
for (int i = 0; i < n; i++) {
sum += arr[i] * arr[i];
}
sum = Math.pow(sum, 0.5);
float[] normArr=new float[arr.length];
for (int i=0; i<normArr.length; i++) {
normArr[i]=(float) (scaleFactor*arr[i]/sum);
}
return normArr;
}
else return null;
}
/** /**
* Flip a double array so that it is in the reverse order. Note the array is * Flip a double array so that it is in the reverse order. Note the array is
* cloned. * cloned.
@ -1274,6 +1302,7 @@ public class PamArrayUtils {
} }

View File

@ -1,5 +1,7 @@
package rawDeepLearningClassifier.dlClassification.genericModel; package rawDeepLearningClassifier.dlClassification.genericModel;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList; import java.util.ArrayList;
import org.jamdev.jdl4pam.genericmodel.GenericModel; import org.jamdev.jdl4pam.genericmodel.GenericModel;
@ -13,6 +15,13 @@ import org.jamdev.jdl4pam.transforms.WaveTransform;
import org.jamdev.jdl4pam.utils.DLUtils; import org.jamdev.jdl4pam.utils.DLUtils;
import org.jamdev.jpamutils.wavFiles.AudioData; import org.jamdev.jpamutils.wavFiles.AudioData;
import PamUtils.PamArrayUtils;
import rawDeepLearningClassifier.segmenter.GroupedRawData;
import us.hebi.matlab.mat.format.Mat5;
import us.hebi.matlab.mat.format.Mat5File;
import us.hebi.matlab.mat.types.Matrix;
import us.hebi.matlab.mat.types.Struct;
public class GenericModelTest { public class GenericModelTest {
public static void rightWhaleTest() { public static void rightWhaleTest() {
@ -98,9 +107,126 @@ public class GenericModelTest {
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); e.printStackTrace();
} }
} }
public static void clickDLTest() {
float SAMPLE_RATE = 500000;
//relative paths to the resource folders.
System.out.println("*****Click classification Deep Learning C*****");
//relative paths to the resource folders.
String relModelPath = "./src/test/resources/rawDeepLearningClassifier/Generic/risso_click/uniform_model/saved_model.pb";
String clicksPath = "./src/test/resources/rawDeepLearningClassifier/Generic/risso_click/clicks.mat";
Path path = Paths.get(relModelPath);
GenericModelWorker genericModelWorker = new GenericModelWorker();
GenericModelParams genericModelParams = new GenericModelParams();
genericModelParams.modelPath = path.toAbsolutePath().normalize().toString();
//create the transforms.
ArrayList<DLTransfromParams> dlTransformParamsArr = new ArrayList<DLTransfromParams>();
//waveform transforms.
dlTransformParamsArr.add(new SimpleTransformParams(DLTransformType.DECIMATE_SCIPY, 250000.));
dlTransformParamsArr.add(new SimpleTransformParams(DLTransformType.PEAK_TRIM, 128, 1));
dlTransformParamsArr.add(new SimpleTransformParams(DLTransformType.NORMALISE_WAV, 0., 1, AudioData.ZSCORE));
genericModelParams.dlTransfromParams = dlTransformParamsArr;
genericModelParams.dlTransfroms = DLTransformsFactory.makeDLTransforms((ArrayList<DLTransfromParams>)genericModelParams.dlTransfromParams);
//create the clicks.
path = Paths.get(clicksPath);
ArrayList<GroupedRawData> clicks = importClicks(path.toAbsolutePath().normalize().toString(), SAMPLE_RATE);
//prep the model
genericModelWorker.prepModel(genericModelParams, null);
System.out.println("Model has loaded");
ArrayList<GroupedRawData> groupedData = new ArrayList<GroupedRawData>();
float prediction = 0;
for (int i=0; i<clicks.size() ; i++) {
groupedData = new ArrayList<GroupedRawData>();
groupedData.add(clicks.get(i)); //TODO for loop
// System.out.println("Waveform input: " + groupedData.get(i).getRawData().length + " " + groupedData.get(i).getRawData()[0].length + " " + + groupedData.get(i).getRawData()[0][0]);
//RUN THE RAW MODEL
// System.out.println("Min max before: ");
// PamArrayUtils.printArray(PamArrayUtils.minmax(groupedData.get(i).getRawData()[0]));
// double[] wav = PamArrayUtils.normalise(groupedData.get(i).getRawData()[0]);
//
// System.out.println("Min max: ");
// PamArrayUtils.printArray(PamArrayUtils.minmax(wav));
// float[][] input1 = new float[][] {PamArrayUtils.double2Float(wav)};
// float[] output1 = genericModelWorker.getModel().runModel(input1);
// System.out.println("Output1: " );
// PamArrayUtils.printArray(output1);
//RUN THROUGH THE GENERIC MODEL CLASSIIFER.
ArrayList<StandardPrediction> genericPrediction = genericModelWorker.runModel(groupedData,SAMPLE_RATE, 0);
float[] output = genericPrediction.get(0).getPrediction();
System.out.println(String.format("Click %d Predicted output: %.6f true output: %.6f passed: %b", clicks.get(i).getUID(),
output[0], prediction, output[0]>prediction*0.9 && output[0]<prediction*1.1));
}
}
/**
* Import a bunch of clicks from a .mat file
*/
public static ArrayList<GroupedRawData> importClicks(String filePath, float sR) {
try {
Mat5File mfr = Mat5.readFromFile(filePath);
// //get array of a name "my_array" from file
Struct mlArrayRetrived = mfr.getStruct( "clickpreds" );
int numClicks= mlArrayRetrived.getNumCols();
ArrayList<GroupedRawData> clicks = new ArrayList<GroupedRawData>(numClicks);
GroupedRawData clickData;
for (int i=0; i<numClicks; i++) {
Matrix clickWav= mlArrayRetrived.get("wave", i);
double[][] clickwaveform= PamArrayUtils.matrix2array(clickWav);
clickwaveform = PamArrayUtils.transposeMatrix(clickwaveform);
//System.out.println("click: " + click[0].length + " num: " + numClicks);
Matrix clickUID= mlArrayRetrived.get("UID", i);
Matrix clickmillis= mlArrayRetrived.get("millis", i);
Matrix channelMap= mlArrayRetrived.get("channelMap", i);
Matrix startSample= mlArrayRetrived.get("startSample", i);
Matrix sampleDuration= mlArrayRetrived.get("sampleDuration", i);
clickData = new GroupedRawData(clickmillis.getLong(0), channelMap.getInt(0), startSample.getLong(0), sampleDuration.getLong(0), sampleDuration.getInt(0));
clickData.setUID(clickUID.getLong(0));
clickData.setRawData(clickwaveform);
clicks.add(clickData);
}
return clicks;
}
catch (Exception e) {
// TODO Auto-generated catch block
e.printStackTrace();
return null;
}
}
/** /**
* The bat Pytorch test. * The bat Pytorch test.
*/ */
@ -110,7 +236,8 @@ public class GenericModelTest {
} }
public static void main(String args[]) { public static void main(String args[]) {
rightWhaleTest(); // rightWhaleTest();
clickDLTest();
} }
} }

View File

@ -6,6 +6,7 @@ import org.apache.commons.io.FilenameUtils;
import org.jamdev.jdl4pam.genericmodel.GenericModel; import org.jamdev.jdl4pam.genericmodel.GenericModel;
import org.jamdev.jdl4pam.transforms.DLTransform; import org.jamdev.jdl4pam.transforms.DLTransform;
import org.jamdev.jdl4pam.transforms.FreqTransform; import org.jamdev.jdl4pam.transforms.FreqTransform;
import org.jamdev.jpamutils.JamArr;
import PamModel.PamModel; import PamModel.PamModel;
import PamModel.PamModel.PluginClassloader; import PamModel.PamModel.PluginClassloader;
@ -42,10 +43,12 @@ public class GenericModelWorker extends DLModelWorker<StandardPrediction> {
//run a model if it is waveform info. //run a model if it is waveform info.
float[][] waveStack = new float[transformedDataStack.length][]; float[][] waveStack = new float[transformedDataStack.length][];
for (int i=0; i<waveStack.length; i++) { for (int i=0; i<waveStack.length; i++) {
// waveStack[i] = PamArrayUtils.double2Float(JamArr.product(PamArrayUtils.float2Double(transformedDataStack[i][0]), 0.99));
waveStack[i] = transformedDataStack[i][0]; waveStack[i] = transformedDataStack[i][0];
} }
System.out.println("RUN GENERIC MODEL WAVE: " + waveStack.length + " " + waveStack[0].length + " " + waveStack[0][0]); // System.out.println("RUN GENERIC MODEL WAVE: " + waveStack.length + " " + waveStack[0].length + " " + waveStack[0][0]);
// PamArrayUtils.printArray(waveStack[0]);
results = getModel().runModel(waveStack); results = getModel().runModel(waveStack);
} }
// System.out.println("GENERIC MODEL RESULTS: " + (results== null ? null : results.length)); // System.out.println("GENERIC MODEL RESULTS: " + (results== null ? null : results.length));

View File

@ -31,10 +31,10 @@ public class ClickDLTest {
float SAMPLE_RATE = 500000; float SAMPLE_RATE = 500000;
//relative paths to the resource folders. //relative paths to the resource folders.
System.out.println("*****Click classification Deep Learning*****"); System.out.println("*****Click classification Deep Learning C*****");
//relative paths to the resource folders. //relative paths to the resource folders.
String relModelPath = "./src/test/resources/rawDeepLearningClassifier/Generic/risso_click/best_model/saved_model.pb"; String relModelPath = "./src/test/resources/rawDeepLearningClassifier/Generic/risso_click/updated_model/saved_model.pb";
String clicksPath = "./src/test/resources/rawDeepLearningClassifier/Generic/risso_click/clicks.mat"; String clicksPath = "./src/test/resources/rawDeepLearningClassifier/Generic/risso_click/clicks.mat";
Path path = Paths.get(relModelPath); Path path = Paths.get(relModelPath);
@ -62,6 +62,7 @@ public class ClickDLTest {
//prep the model //prep the model
genericModelWorker.prepModel(genericModelParams, null); genericModelWorker.prepModel(genericModelParams, null);
System.out.println("Model has loaded");
ArrayList<GroupedRawData> groupedData = new ArrayList<GroupedRawData>(); ArrayList<GroupedRawData> groupedData = new ArrayList<GroupedRawData>();
@ -126,7 +127,7 @@ public class ClickDLTest {
e.printStackTrace(); e.printStackTrace();
return null; return null;
} }
} }
} }

File diff suppressed because one or more lines are too long