Unit tests for multi species classifier and deep learning message improvements

This commit is contained in:
Jamie Mac 2025-02-14 15:14:57 +00:00
parent 02b9cc5000
commit e3b9094675
9 changed files with 255 additions and 80 deletions

View File

@ -6,11 +6,12 @@ package rawDeepLearningClassifier;
public enum DLStatus {
FILE_NULL("The input file is null",
"The loaded file was null. If the file was download it may not have downloaded properly.", ErrorLevel.ERROR),
"The loaded file was null. If the file was downloaded, it may not have downloaded properly.", ErrorLevel.ERROR),
//generic error when something has gone wrong.
MODEL_LOAD_FAIL("The model failed to load",
" The model failed to load - this could be because it is incompatible with PAMGuard or an uknown file format.",
" The model failed to load - this could be because it is incompatible with PAMGuard or an unknown file format.",
ErrorLevel.ERROR),
MODEL_LOAD_SUCCESS("The model loaded", " The model successfully loaded", ErrorLevel.NO_ERROR),
@ -43,7 +44,10 @@ public enum DLStatus {
MODEL_ENGINE_FAIL("Deep learning engine failure", "The deep learning engine was not able to load. It is likely PAMGuard could not download the relevant libraries. Visit www.pamguard.org/deeplearning for a work around.", ErrorLevel.ERROR),
MODEL_META_FAIL("Model metadata failure", "PAMGuard could not locte the metadata within the model. if this is a custom model ensure the associated .pgdl file is present and JSON valid", ErrorLevel.ERROR);
MODEL_META_FAIL("Model metadata failure", "PAMGuard could not locate the metadata within the model. If this is a custom model ensure the associated .pgdl file is present and JSON valid", ErrorLevel.ERROR),
MODEL_FILE_EXISTS("The input file does not exist",
"The selected model file does not exist on your computer. This can happen if you use a configuration on a different computer or have moved the file", ErrorLevel.ERROR),;
/**

View File

@ -1,6 +1,8 @@
package rawDeepLearningClassifier.defaultModels;
import java.util.ArrayList;
import PamModel.SMRUEnable;
import rawDeepLearningClassifier.DLControl;
/**
@ -27,7 +29,9 @@ public class DLDefaultModelManager {
defaultModels.add(new RightWhaleModel1());
defaultModels.add(new HumpbackWhaleGoogle());
defaultModels.add(new HumpbackWhaleAtlantic());
defaultModels.add(new MultiSpeciesGoogle());
if (SMRUEnable.isDevEnable()) {
defaultModels.add(new MultiSpeciesGoogle());
}
}

View File

@ -1,5 +1,6 @@
package rawDeepLearningClassifier.dlClassification.animalSpot;
import java.io.File;
import java.nio.file.Paths;
import org.jamdev.jdl4pam.animalSpot.AnimalSpotModel;
@ -49,6 +50,14 @@ public class SoundSpotWorker extends DLModelWorker<StandardPrediction> {
//System.out.println("prepModel: " + soundSpotParams.useDefaultTransfroms);
try {
if (soundSpotParams.modelPath==null) {
return DLStatus.FILE_NULL;
}
if (!new File(soundSpotParams.modelPath).exists()) {
return DLStatus.MODEL_FILE_EXISTS;
}
// // get the plugin class loader and set it as the context class loader
// // NOTE THAT THIS IS REQUIRED TO MAKE THIS MODULE RUN AS A PLUGIN WHEN THE CLASS FILES

View File

@ -57,6 +57,15 @@ public class ArchiveModelWorker extends GenericModelWorker {
public synchronized DLStatus prepModel(StandardModelParams dlParams, DLControl dlControl) {
//ClassLoader origCL = Thread.currentThread().getContextClassLoader();
try {
if (dlParams.modelPath==null) {
return DLStatus.FILE_NULL;
}
if (!new File(dlParams.modelPath).exists()) {
System.out.println("FILE DOES NOT EXIST!! " + dlParams.modelPath);
return DLStatus.MODEL_FILE_EXISTS;
}
// get the plugin class loader and set it as the context class loader
// NOTE THAT THIS IS REQUIRED TO MAKE THIS MODULE RUN AS A PLUGIN WHEN THE CLASS FILES

View File

@ -378,7 +378,7 @@ public class DLModelSelectPane extends PamBorderPane {
File fileChck = new File(file);
if (!fileChck.exists()) {
return DLStatus.MODEL_LOAD_FAIL;
return DLStatus.MODEL_FILE_EXISTS;
}
//we are loading model from a file - anything can happen so put in a try catch.

View File

@ -1,27 +1,30 @@
package test.rawDeepLearningClassifier;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import org.jamdev.jdl4pam.genericmodel.GenericModel;
import org.jamdev.jdl4pam.transforms.DLTransform.DLTransformType;
import org.jamdev.jdl4pam.transforms.SimpleTransformParams;
import org.jamdev.jdl4pam.utils.DLMatFile;
import org.jamdev.jdl4pam.utils.DLUtils;
import org.jamdev.jpamutils.wavFiles.AudioData;
import org.junit.jupiter.api.Test;
import PamUtils.PamArrayUtils;
import rawDeepLearningClassifier.defaultModels.HumpbackWhaleAtlantic;
import rawDeepLearningClassifier.defaultModels.MultiSpeciesGoogle;
import rawDeepLearningClassifier.dlClassification.genericModel.GenericModelParams;
import rawDeepLearningClassifier.dlClassification.genericModel.GenericModelWorker;
import rawDeepLearningClassifier.dlClassification.genericModel.StandardPrediction;
import rawDeepLearningClassifier.segmenter.GroupedRawData;
import us.hebi.matlab.mat.format.Mat5;
import us.hebi.matlab.mat.types.MatFile;
import us.hebi.matlab.mat.types.Matrix;
import us.hebi.matlab.mat.types.Struct;
/**
@ -37,11 +40,11 @@ public class PamZipDLClassifierTest {
// public void blueWhaleKooguTest() {
// //relative paths to the resource folders.
// String relModelPath = "./src/test/resources/rawDeepLearningClassifier/PamZip/blue_whale_24.zip";
//
//
// //the zip classifier is the same as the
// String relWavPath = "./src/test/resources/rawDeepLearningClassifier/Koogu/blue_whale_24/20190527_190000.wav";
// String relMatPath = "./src/test/resources/rawDeepLearningClassifier/Koogu/blue_whale_24/rawScores_20190527_190000.mat";
//
//
// //metadata says it should be used with Koogu classifier.
// KooguDLClassifierTest.runKooguClassifier( relModelPath, relWavPath, relMatPath);
// }
@ -49,17 +52,145 @@ public class PamZipDLClassifierTest {
/**
* Test Google's humpback whale model.
* Test google's multi-species model against outputs obtained in Python.
*/
@Test
public void multiSpeciesGoogleTest() {
System.out.println("*****Generic DL: google-multi-species test*****");
//relative paths to the resource folders.
String relModelPath = "./src/test/resources/rawDeepLearningClassifier/Generic/multi-species-Google/multispecies-whale-tensorflow2-default-v2/saved_model.pb";
String relWavPath = "./src/test/resources/rawDeepLearningClassifier/Generic/multi-species-Google/Cross_24kHz.wav";
String resultsPath = "./src/test/resources/rawDeepLearningClassifier/Generic/multi-species-Google/Cross_24kHz_scores.csv";
Path path = Paths.get(relModelPath);
GenericModelWorker genericModelWorker = new GenericModelWorker();
GenericModelParams genericModelParams = new GenericModelParams();
genericModelParams.modelPath = path.toAbsolutePath().normalize().toString();
path = Paths.get(relWavPath);
String wavFilePath = path.toAbsolutePath().normalize().toString();
MultiSpeciesGoogle multiSpeciesGoogle = new MultiSpeciesGoogle();
multiSpeciesGoogle.setParams(genericModelParams);
System.out.println(genericModelParams);
try {
//load audio
AudioData soundData = DLUtils.loadWavFile(wavFilePath);
// double[] amplitudes = soundData.getScaledSampleAmplitudes();
//
//
// float[] ampltiudesf = PamArrayUtils.double2Float(amplitudes);
// float[] ampltiudesfchunk = new float[120000];
// int offset = 24000;
// System.arraycopy(ampltiudesf, 0+offset, ampltiudesfchunk, 0, 120000);
//
// // System.out.println("MAX AMPLITUDE: " + PamArrayUtils.max(amplitudes));
//
// GenericModel genericModel = new GenericModel( genericModelParams.modelPath) ;
// float[][] input = new float[][]{ampltiudesfchunk};
// float[] output2 = genericModel.runModel(input);
//
// System.out.println("----TEST OUT----");
// PamArrayUtils.printArray(output2);
//prep the model
genericModelWorker.prepModel(genericModelParams, null);
//load true predictions file.
File file = new File(resultsPath);
FileReader fr = new FileReader(file);
BufferedReader br = new BufferedReader(fr);
String line;
int ind=0;
int startChunk=0;
//each line is a list of prediction for each class;
int chunkSize=120000;
int classIndex = 1;
while((line = br.readLine()) != null){
if (ind>0) {
//read the data from the text file
String[] data = line.split(",");
double[] predictions = new double[data.length];
for (int i=0; i<data.length; i++) {
predictions[i] = Double.valueOf(data[i]);
}
GroupedRawData groupedRawData = new GroupedRawData(0, 1, 0, chunkSize, chunkSize);
groupedRawData.copyRawData(soundData.getScaledSampleAmplitudes(), startChunk, chunkSize, 0);
// System.out.println("MAX AMPLITUDE: " + PamArrayUtils.max(groupedRawData.getRawData()[0]));
ArrayList<GroupedRawData> groupedData = new ArrayList<GroupedRawData>();
groupedData.add(groupedRawData);
ArrayList<StandardPrediction> genericPrediction = genericModelWorker.runModel(groupedData, soundData.sampleRate, 0);
float[] output = genericPrediction.get(0).getPrediction();
boolean passed = (output[classIndex]>predictions[classIndex]-0.05 && output[classIndex]<predictions[classIndex]+0.05);
System.out.println(String.format("Chunk %d %d Predicted output: %.5f true output: %.5f passed: %b -- sum %.2f ", ind, startChunk,
output[classIndex], predictions[classIndex], passed, PamArrayUtils.sum(output)));
//PamArrayUtils.printArray(output);
//allow 10% tolerance to take account of slight differences in Java input.
//assertTrue(passed); //humpback whale detection
startChunk+=24000; //one second step
}
ind++;
}
} catch (Exception e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
/**
* Test Google's multi-species model by feeding it two wav files, one which is upsampled to 24kHz and another which needs to be upsampled by the transforms
* to 24kHz. Verify that the upsampling works.
*/
@Test
public void multiSpeciesGoogleTest() {
public void multiSpeciesGoogleTest2() {
System.out.println("*****Generic DL: Humpback whale test*****");
System.out.println("*****Generic DL: google-multi-species test 2*****");
//relative paths to the resource folders.
String relModelPath = "./src/test/resources/rawDeepLearningClassifier/Generic/multi-species-Google/multispecies-whale-tensorflow2-default-v2/saved_model.pb";
String relWavPath = "./src/test/resources/rawDeepLearningClassifier/Generic/multi-species-Google/Cross_24kHz.wav";
String resultsPath = "./src/test/resources/rawDeepLearningClassifier/Generic/multi-species-Google/Cross_24kHz_scores.csv";
//path to the same file at different sample rates
String relWavPath = "./src/test/resources/rawDeepLearningClassifier/Generic/multi-species-Google/NOPP6_EST_20090329_121500.wav";
String relWavPath2 = "./src/test/resources/rawDeepLearningClassifier/Generic/multi-species-Google/NOPP6_EST_20090329_121500_upsample.wav";
// String relWavPath2 = "./src/test/resources/rawDeepLearningClassifier/Generic/multi-species-Google/NOPP6_EST_20090329_121500.wav";
String matFileOut = "/Users/jdjm/MATLAB-Drive/MATLAB/PAMGUARD/deep_learning/google_multi_species/google_multi_species.mat";
//hold the paths in an array
String[] soundFiles = new String[] {relWavPath, relWavPath2};
Path path = Paths.get(relModelPath);
@ -69,100 +200,118 @@ public class PamZipDLClassifierTest {
genericModelParams.modelPath = path.toAbsolutePath().normalize().toString();
path = Paths.get(relWavPath);
String wavFilePath = path.toAbsolutePath().normalize().toString();
MultiSpeciesGoogle multiSpeciesGoogle = new MultiSpeciesGoogle();
multiSpeciesGoogle.setParams(genericModelParams);
genericModelParams.dlTransfromParams.set(0, new SimpleTransformParams(DLTransformType.DECIMATE, 24000.));
System.out.println(genericModelParams);
//TEMP
double segSize = 5.; //five second segment size.
double segHop = 5.; //five second hop size.
int classIndex = 5; //Right whale atlantic - just for output
//create MatFile for saving the image data to.
MatFile matFile = Mat5.newMatFile();
try {
//load audio
AudioData soundData = DLUtils.loadWavFile(wavFilePath);
// double[] amplitudes = soundData.getScaledSampleAmplitudes();
//
//
// float[] ampltiudesf = PamArrayUtils.double2Float(amplitudes);
// float[] ampltiudesfchunk = new float[120000];
// int offset = 24000;
// System.arraycopy(ampltiudesf, 0+offset, ampltiudesfchunk, 0, 120000);
//
// // System.out.println("MAX AMPLITUDE: " + PamArrayUtils.max(amplitudes));
//
// GenericModel genericModel = new GenericModel( genericModelParams.modelPath) ;
// float[][] input = new float[][]{ampltiudesfchunk};
// float[] output2 = genericModel.runModel(input);
//
// System.out.println("----TEST OUT----");
// PamArrayUtils.printArray(output2);
//prep the model
genericModelWorker.prepModel(genericModelParams, null);
ArrayList<float[][]> fileOutputs = new ArrayList<float[][]>();
for (int i=0; i<soundFiles.length; i++) {
try {
//prep the model
genericModelWorker.prepModel(genericModelParams, null);
path = Paths.get(soundFiles[i]);
String wavFilePath = path.toAbsolutePath().normalize().toString();
//load audio
AudioData soundData = DLUtils.loadWavFile(wavFilePath);
// if (i==1) {
// soundData=soundData.interpolate(24000);
// }
//load true predictions file.
File file = new File(resultsPath);
FileReader fr = new FileReader(file);
BufferedReader br = new BufferedReader(fr);
String line;
int ind=0;
int nseg = (int) (soundData.samples.length/(segHop*soundData.sampleRate));
float[][] outputs = new float[nseg][];
int startChunk=0;
int startChunk=0;
//initialise structure for image data
Struct waveStruct = Mat5.newStruct(nseg, 1);
int classIndex = 1;
while((line = br.readLine()) != null){
if (ind>0) {
//read the data from the text file
String[] data = line.split(",");
//each line is a list of prediction for each class;
int chunkSize=(int) (segSize*soundData.sampleRate);
int ind = 0;
System.out.println("Generic DL: google-multi-species test: processing file: " + i + " chunkSize: " + chunkSize + " nseg " + nseg);
double[] predictions = new double[data.length];
for (int i=0; i<data.length; i++) {
predictions[i] = Double.valueOf(data[i]);
}
//each line is a list of prediction for each class;
int chunkSize=120000;
while(startChunk<(soundData.samples.length-chunkSize)){
GroupedRawData groupedRawData = new GroupedRawData(0, 1, 0, chunkSize, chunkSize);
groupedRawData.copyRawData(soundData.getScaledSampleAmplitudes(), startChunk, chunkSize, 0);
// System.out.println("MAX AMPLITUDE: " + PamArrayUtils.max(groupedRawData.getRawData()[0]));
//System.out.println("MAX AMPLITUDE: " + PamArrayUtils.max(groupedRawData.getRawData()[0]));
ArrayList<GroupedRawData> groupedData = new ArrayList<GroupedRawData>();
groupedData.add(groupedRawData);
ArrayList<StandardPrediction> genericPrediction = genericModelWorker.runModel(groupedData, soundData.sampleRate, 0);
ArrayList<StandardPrediction> genericPrediction = genericModelWorker.runModel(groupedData, soundData.sampleRate, 0);
float[] output = genericPrediction.get(0).getPrediction();
//----TEST output for MATLAB----
float[][][] dataOut = genericModelWorker.dataUnits2ModelInput(groupedData, soundData.sampleRate, 0);
float[] waveIn = dataOut[0][0];
boolean passed = (output[classIndex]>predictions[classIndex]-0.05 && output[classIndex]<predictions[classIndex]+0.05);
Matrix modelinput = DLMatFile.array2Matrix(PamArrayUtils.float2Double(waveIn));
Matrix modeloutput = DLMatFile.array2Matrix(PamArrayUtils.float2Double(output));
System.out.println(String.format("Chunk %d %d Predicted output: %.5f true output: %.5f passed: %b -- sum %.2f ", ind, startChunk,
output[classIndex], predictions[classIndex], passed, PamArrayUtils.sum(output)));
waveStruct.set("modelinput", ind, modelinput);
waveStruct.set("startseconds", ind, Mat5.newScalar(startChunk/soundData.sampleRate));
waveStruct.set("prediction", ind, modeloutput);
// System.out.println(String.format("File %d Chunk %d %d Predicted output: %.5f ", i, ind, startChunk,
// output[classIndex]));
outputs[ind] = output;
ind++;
//PamArrayUtils.printArray(output);
//allow 10% tolerance to take account of slight differences in Java input.
//assertTrue(passed); //humpback whale detection
startChunk+=24000; //one second step
startChunk+=(int) (segHop*soundData.sampleRate); //step forward by the hop size (segHop seconds)
}
ind++;
matFile.addArray(("file_" + i + "_outputs"), waveStruct);
matFile.addArray(("file_" + i), Mat5.newString(wavFilePath));
fileOutputs.add(outputs);
} catch (Exception e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
} catch (Exception e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
//now compare the outputs
for (int i=0; i<fileOutputs.get(0).length; i++) {
if (fileOutputs.get(0)[i]==null) continue;
System.out.println(String.format("Chunk %d File 0 - %.3f File - 1 %.3f diff %.3f", i,
fileOutputs.get(0)[i][classIndex], fileOutputs.get(1)[i][classIndex],
(Math.abs(fileOutputs.get(0)[i][classIndex]-fileOutputs.get(1)[i][classIndex]))));
}
if (matFileOut!=null) {
System.out.println("Writing mat file");
// Serialize to disk using default configurations
try {
Mat5.writeToFile(matFile,matFileOut);
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
}