Getting dolphin deep learning click detector working

This commit is contained in:
Jamie Mac 2025-03-20 16:29:36 +00:00
parent 83ca569531
commit 1374a0bd26
3 changed files with 156 additions and 33 deletions

View File

@ -1381,7 +1381,25 @@ public class PamArrayUtils {
}
/**
* Convert a matrix to a
* @param matrix - the MAT file matrix
* @return double[][] array of results
*/
public static float[][] matrix2arrayF(Matrix matrix) {
if (matrix==null) return null;
float[][] arrayOut = new float[matrix.getNumRows()][];
float[] arrayRow;
for (int i=0; i<matrix.getNumRows(); i++) {
arrayRow=new float[matrix.getNumCols()];
for (int j=0; j<matrix.getNumCols(); j++) {
arrayRow[j] = matrix.getFloat(i, j);
}
arrayOut[i]=arrayRow;
}
return arrayOut;
}

View File

@ -49,7 +49,7 @@ public class GenericModelWorker extends DLModelWorker<StandardPrediction> {
waveStack[i] = transformedDataStack[i][0];
}
// System.out.println("RUN GENERIC MODEL WAVE: " + waveStack.length + " " + waveStack[0].length + " " + waveStack[0][0] + " " + PamArrayUtils.max( waveStack[0]));
System.out.println("RUN GENERIC MODEL WAVE: " + waveStack.length + " " + waveStack[0].length + " " + waveStack[0][0] + " " + PamArrayUtils.max( waveStack[0]));
// PamArrayUtils.printArray(waveStack[0]);
results = getModel().runModel(waveStack);
}

View File

@ -7,6 +7,7 @@ import java.util.ArrayList;
import org.jamdev.jdl4pam.transforms.DLTransformsFactory;
import org.jamdev.jdl4pam.transforms.DLTransfromParams;
import org.jamdev.jdl4pam.transforms.SimpleTransformParams;
import org.jamdev.jdl4pam.utils.DLMatFile;
import org.jamdev.jpamutils.wavFiles.AudioData;
import org.jamdev.jdl4pam.transforms.DLTransform.DLTransformType;
import org.junit.jupiter.api.Test;
@ -27,10 +28,114 @@ import us.hebi.matlab.mat.types.Struct;
*/
public class ClickDLTest {
/**
* Test just one click
*/
@Test
public void aclickDLTest() {
System.out.println("CLickDLTest: Single click test");
float SAMPLE_RATE = 500000;
//relative paths to the resource folders.
System.out.println("*****Click classification Deep Learning C*****");
//relative paths to the resource folders.
String relModelPath = "/Users/jdjm/Library/CloudStorage/Dropbox/PAMGuard_dev/Deep_Learning/click_classifier_Thomas/model_v2/model_pb/saved_model.pb";
String clicksPath = "/Users/jdjm/Library/CloudStorage/Dropbox/PAMGuard_dev/Deep_Learning/click_classifier_Thomas/model_v2/example_2000021.mat";
//load the click up
// load the click data up.
Path clkPath = Paths.get(clicksPath);
PredGroupedRawData clickData = null;
try {
Mat5File mfr = Mat5.readFromFile(clkPath.toAbsolutePath().normalize().toString());
// //get array of a name "my_array" from file
Struct mlArrayRetrived = mfr.getStruct( "newStruct" );
Matrix clickWavM = mlArrayRetrived.get("wave", 0);
Matrix modelInputM= mlArrayRetrived.get("wave_pad", 0);
double[][] clickWaveform= PamArrayUtils.matrix2array(clickWavM);
//get the raw model input so we can test the model directly.
double[][] pythonModelInput= PamArrayUtils.matrix2array(modelInputM);
pythonModelInput = PamArrayUtils.transposeMatrix(pythonModelInput);
float[] pythonModelInputF = PamArrayUtils.double2Float(pythonModelInput[0]);
Matrix clickUID= mlArrayRetrived.get("UID", 0);
Matrix pred= mlArrayRetrived.get("pred", 0);
//create a click object whihc we can pass through transforms etc.
clickData = new PredGroupedRawData(0L, 1, 0, clickWaveform[0].length, clickWaveform[0].length);
clickData.setUID(clickUID.getLong(0));
clickData.setRawData(clickWaveform);
clickData.setPrediction(new double[] {pred.getDouble(0)});
// load the model up
Path path = Paths.get(relModelPath);
GenericModelWorker genericModelWorker = new GenericModelWorker();
GenericModelParams genericModelParams = new GenericModelParams();
genericModelParams.modelPath = path.toAbsolutePath().normalize().toString();
//create the transforms.
ArrayList<DLTransfromParams> dlTransformParamsArr = new ArrayList<DLTransfromParams>();
//waveform transforms.
dlTransformParamsArr.add(new SimpleTransformParams(DLTransformType.DECIMATE_SCIPY, 248000.));
dlTransformParamsArr.add(new SimpleTransformParams(DLTransformType.PEAK_TRIM, 64, 1));
dlTransformParamsArr.add(new SimpleTransformParams(DLTransformType.NORMALISE_WAV, 0., 1, AudioData.ZSCORE));
genericModelParams.dlTransfromParams = dlTransformParamsArr;
genericModelParams.dlTransfroms = DLTransformsFactory.makeDLTransforms((ArrayList<DLTransfromParams>)genericModelParams.dlTransfromParams);
//create the clicks.
path = Paths.get(clicksPath);
//prep the model
genericModelWorker.prepModel(genericModelParams, null);
ArrayList<GroupedRawData> groupedData = new ArrayList<GroupedRawData>();
groupedData.add(clickData);
System.out.println("Waveform input: " + groupedData.get(0).getRawData().length + " " + groupedData.get(0).getRawData()[0].length);
ArrayList<StandardPrediction> genericPrediction = genericModelWorker.runModel(groupedData,SAMPLE_RATE, 0);
System.out.println("PAMGuard input len: " + pythonModelInputF.length);
float[] outputPAMGuard = genericPrediction.get(0).getPrediction();
System.out.println("Model output PAMGuard: " + outputPAMGuard[0]);
//RUN THE RAW MODEL with Python transformed input
System.out.println("Python input len: " + pythonModelInputF.length);
float[] outPutPython = genericModelWorker.getModel().runModel(new float[][] {pythonModelInputF});
System.out.println("Model output Python: " + outPutPython[0]);
}
catch (Exception e) {
// TODO Auto-generated catch block
e.printStackTrace();
return;
}
}
@Test
public void clickDLTest() {
public void clicksDLTest() {
float SAMPLE_RATE = 500000;
//relative paths to the resource folders.
System.out.println("*****Click classification Deep Learning C*****");
@ -46,7 +151,7 @@ public class ClickDLTest {
GenericModelParams genericModelParams = new GenericModelParams();
genericModelParams.modelPath = path.toAbsolutePath().normalize().toString();
//create the transforms.
ArrayList<DLTransfromParams> dlTransformParamsArr = new ArrayList<DLTransfromParams>();
@ -54,48 +159,48 @@ public class ClickDLTest {
dlTransformParamsArr.add(new SimpleTransformParams(DLTransformType.DECIMATE_SCIPY, 248000.));
dlTransformParamsArr.add(new SimpleTransformParams(DLTransformType.PEAK_TRIM, 128, 1));
dlTransformParamsArr.add(new SimpleTransformParams(DLTransformType.NORMALISE_WAV, 0., 1, AudioData.ZSCORE));
genericModelParams.dlTransfromParams = dlTransformParamsArr;
genericModelParams.dlTransfroms = DLTransformsFactory.makeDLTransforms((ArrayList<DLTransfromParams>)genericModelParams.dlTransfromParams);
//create the clicks.
path = Paths.get(clicksPath);
ArrayList<PredGroupedRawData> clicks = importClicks(path.toAbsolutePath().normalize().toString(), SAMPLE_RATE);
//prep the model
genericModelWorker.prepModel(genericModelParams, null);
System.out.println("Model has loaded");
ArrayList<GroupedRawData> groupedData = new ArrayList<GroupedRawData>();
for (int i=0; i<1; i++) {
float prediction = (float) clicks.get(i).getPrediction()[0];
groupedData.add(clicks.get(i)); //TODO for loop
//System.out.println("Waveform input: " + groupedData.get(i).getRawData().length + " " + groupedData.get(i).getRawData()[0].length);
ArrayList<StandardPrediction> genericPrediction = genericModelWorker.runModel(groupedData,SAMPLE_RATE, 0);
float[] output = genericPrediction.get(i).getPrediction();
System.out.println(String.format("Click %d Predicted output: %.2f true output: %.2f passed: %b", clicks.get(i).getUID(),
output[0], prediction, output[0]>prediction*0.9 && output[0]<prediction*1.1));
}
}
/**
* Import a bunch of clicks from a .mat file
*/
public static ArrayList<PredGroupedRawData> importClicks(String filePath, float sR) {
try {
Mat5File mfr = Mat5.readFromFile(filePath);
Mat5File mfr = Mat5.readFromFile(filePath);
// //get array of a name "my_array" from file
Struct mlArrayRetrived = mfr.getStruct( "clickpreds" );
@ -106,12 +211,12 @@ public class ClickDLTest {
PredGroupedRawData clickData;
for (int i=0; i<numClicks; i++) {
Matrix clickWav= mlArrayRetrived.get("wave", i);
double[][] clickwaveform= PamArrayUtils.matrix2array(clickWav);
clickwaveform = PamArrayUtils.transposeMatrix(clickwaveform);
//System.out.println("click: " + click[0].length + " num: " + numClicks);
Matrix clickUID= mlArrayRetrived.get("UID", i);
Matrix clickmillis= mlArrayRetrived.get("millis", i);
Matrix channelMap= mlArrayRetrived.get("channelMap", i);
@ -123,10 +228,10 @@ public class ClickDLTest {
clickData.setUID(clickUID.getLong(0));
clickData.setRawData(clickwaveform);
clickData.setPrediction(new double[] {pred.getDouble(0), pred.getDouble(1)});
clicks.add(clickData);
}
return clicks;
}
catch (Exception e) {
@ -135,11 +240,11 @@ public class ClickDLTest {
return null;
}
}
public static class PredGroupedRawData extends GroupedRawData {
private double[] prediction;
public double[] getPrediction() {
return prediction;
}
@ -151,10 +256,10 @@ public class ClickDLTest {
public PredGroupedRawData(long timeMilliseconds, int channelBitmap, long startSample, long duration, int samplesize) {
super(timeMilliseconds, channelBitmap, startSample, duration, samplesize);
}
}
}