Unit tests for multi-species classifier and deep learning message improvements

Jamie Mac 2025-02-14 15:14:57 +00:00
parent 02b9cc5000
commit e3b9094675
9 changed files with 255 additions and 80 deletions

@ -6,11 +6,12 @@ package rawDeepLearningClassifier;
public enum DLStatus {
FILE_NULL("The input file is null",
"The loaded file was null. If the file was download it may not have downloaded properly.", ErrorLevel.ERROR),
"The loaded file was null. If the file was downloaded, it may not have downloaded properly.", ErrorLevel.ERROR),
//generic error when something has gone wrong.
MODEL_LOAD_FAIL("The model failed to load",
" The model failed to load - this could be because it is incompatible with PAMGuard or an uknown file format.",
" The model failed to load - this could be because it is incompatible with PAMGuard or an unknown file format.",
ErrorLevel.ERROR),
MODEL_LOAD_SUCCESS("The model loaded", " The model successfully loaded", ErrorLevel.NO_ERROR),
@ -43,7 +44,10 @@ public enum DLStatus {
MODEL_ENGINE_FAIL("Deep learning engine failure", "The deep learning engine was not able to load. It is likely PAMGuard could not download the relevant libraries. Visit www.pamguard.org/deeplearning for a workaround.", ErrorLevel.ERROR),
MODEL_META_FAIL("Model metadata failure", "PAMGuard could not locte the metadata within the model. if this is a custom model ensure the associated .pgdl file is present and JSON valid", ErrorLevel.ERROR);
MODEL_META_FAIL("Model metadata failure", "PAMGuard could not locate the metadata within the model. If this is a custom model, ensure the associated .pgdl file is present and the JSON is valid", ErrorLevel.ERROR),
MODEL_FILE_EXISTS("The input file does not exist",
"The selected model file does not exist on your computer. This can happen if you use a configuration from a different computer or have moved the file.", ErrorLevel.ERROR);
/**

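The new MODEL_FILE_EXISTS status is returned by the file-existence guards this commit adds to the model workers further down. The pattern in those workers is simply the following (here params stands in for the worker-specific parameter object such as soundSpotParams or dlParams, with java.io.File imported):

if (params.modelPath == null) {
	return DLStatus.FILE_NULL;
}
if (!new File(params.modelPath).exists()) {
	//the configured model file is missing on this machine
	return DLStatus.MODEL_FILE_EXISTS;
}
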
@ -1,6 +1,8 @@
package rawDeepLearningClassifier.defaultModels;
import java.util.ArrayList;
import PamModel.SMRUEnable;
import rawDeepLearningClassifier.DLControl;
/**
@ -27,8 +29,10 @@ public class DLDefaultModelManager {
defaultModels.add(new RightWhaleModel1());
defaultModels.add(new HumpbackWhaleGoogle());
defaultModels.add(new HumpbackWhaleAtlantic());
if (SMRUEnable.isDevEnable()) {
defaultModels.add(new MultiSpeciesGoogle());
}
}
/**

@ -1,5 +1,6 @@
package rawDeepLearningClassifier.dlClassification.animalSpot;
import java.io.File;
import java.nio.file.Paths;
import org.jamdev.jdl4pam.animalSpot.AnimalSpotModel;
@ -50,6 +51,14 @@ public class SoundSpotWorker extends DLModelWorker<StandardPrediction> {
try {
if (soundSpotParams.modelPath==null) {
return DLStatus.FILE_NULL;
}
if (!new File(soundSpotParams.modelPath).exists()) {
return DLStatus.MODEL_FILE_EXISTS;
}
// // get the plugin class loader and set it as the context class loader
// // NOTE THAT THIS IS REQUIRED TO MAKE THIS MODULE RUN AS A PLUGIN WHEN THE CLASS FILES
// // ARE BUNDLED INTO A FATJAR, HOWEVER THIS WILL STOP THE PLUGIN FROM RUNNING AS A SEPARATE

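The commented-out note above refers to the standard context-class-loader swap. As an illustrative sketch of that generic JDK pattern only (not this module's exact code):

ClassLoader origCL = Thread.currentThread().getContextClassLoader();
try {
	//use the loader that loaded the plugin class so bundled resources resolve
	//when the class files are packed into a fat JAR
	Thread.currentThread().setContextClassLoader(this.getClass().getClassLoader());
	//... load the model here ...
} finally {
	//always restore the original loader so the rest of PAMGuard is unaffected
	Thread.currentThread().setContextClassLoader(origCL);
}
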
@ -58,6 +58,15 @@ public class ArchiveModelWorker extends GenericModelWorker {
//ClassLoader origCL = Thread.currentThread().getContextClassLoader();
try {
if (dlParams.modelPath==null) {
return DLStatus.FILE_NULL;
}
if (!new File(dlParams.modelPath).exists()) {
System.out.println("FILE DOES NOT EXISTS!! " + dlParams.modelPath);
return DLStatus.MODEL_FILE_EXISTS;
}
// get the plugin class loader and set it as the context class loader
// NOTE THAT THIS IS REQUIRED TO MAKE THIS MODULE RUN AS A PLUGIN WHEN THE CLASS FILES
// ARE BUNDLED INTO A FATJAR, HOWEVER THIS WILL STOP THE PLUGIN FROM RUNNING AS A SEPARATE

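Since the commit is adding unit tests, the new missing-file guard also lends itself to a small test. A hedged sketch, in which StandardModelParams and checkModelFile(...) are hypothetical stand-ins for the worker's real parameter class and for the method containing the guard above (the actual names may differ):

@Test
public void missingModelFileTest() {
	ArchiveModelWorker worker = new ArchiveModelWorker();
	StandardModelParams dlParams = new StandardModelParams(); //hypothetical params class
	dlParams.modelPath = "./src/test/resources/does_not_exist.zip"; //deliberately missing file
	//the guard should report a missing file rather than a generic load failure
	assertEquals(DLStatus.MODEL_FILE_EXISTS, worker.checkModelFile(dlParams)); //hypothetical method
}
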
@ -378,7 +378,7 @@ public class DLModelSelectPane extends PamBorderPane {
File fileChck = new File(file);
if (!fileChck.exists()) {
return DLStatus.MODEL_LOAD_FAIL;
return DLStatus.MODEL_FILE_EXISTS;
}
//we are loading model from a file - anything can happen so put in a try catch.

@ -1,27 +1,30 @@
package test.rawDeepLearningClassifier;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import org.jamdev.jdl4pam.genericmodel.GenericModel;
import org.jamdev.jdl4pam.transforms.DLTransform.DLTransformType;
import org.jamdev.jdl4pam.transforms.SimpleTransformParams;
import org.jamdev.jdl4pam.utils.DLMatFile;
import org.jamdev.jdl4pam.utils.DLUtils;
import org.jamdev.jpamutils.wavFiles.AudioData;
import org.junit.jupiter.api.Test;
import PamUtils.PamArrayUtils;
import rawDeepLearningClassifier.defaultModels.HumpbackWhaleAtlantic;
import rawDeepLearningClassifier.defaultModels.MultiSpeciesGoogle;
import rawDeepLearningClassifier.dlClassification.genericModel.GenericModelParams;
import rawDeepLearningClassifier.dlClassification.genericModel.GenericModelWorker;
import rawDeepLearningClassifier.dlClassification.genericModel.StandardPrediction;
import rawDeepLearningClassifier.segmenter.GroupedRawData;
import us.hebi.matlab.mat.format.Mat5;
import us.hebi.matlab.mat.types.MatFile;
import us.hebi.matlab.mat.types.Matrix;
import us.hebi.matlab.mat.types.Struct;
/**
@ -49,7 +52,7 @@ public class PamZipDLClassifierTest {
/**
* Test Google's humpback whale model.
* Test Google's multi-species model against outputs obtained in Python.
*/
@Test
public void multiSpeciesGoogleTest() {
@ -58,9 +61,12 @@ public class PamZipDLClassifierTest {
//relative paths to the resource folders.
String relModelPath = "./src/test/resources/rawDeepLearningClassifier/Generic/multi-species-Google/multispecies-whale-tensorflow2-default-v2/saved_model.pb";
String relWavPath = "./src/test/resources/rawDeepLearningClassifier/Generic/multi-species-Google/Cross_24kHz.wav";
String resultsPath = "./src/test/resources/rawDeepLearningClassifier/Generic/multi-species-Google/Cross_24kHz_scores.csv";
Path path = Paths.get(relModelPath);
GenericModelWorker genericModelWorker = new GenericModelWorker();
@ -77,31 +83,26 @@ public class PamZipDLClassifierTest {
System.out.println(genericModelParams);
//TEMP
try {
//load audio
AudioData soundData = DLUtils.loadWavFile(wavFilePath);
// double[] amplitudes = soundData.getScaledSampleAmplitudes();
//
//
// float[] ampltiudesf = PamArrayUtils.double2Float(amplitudes);
// float[] ampltiudesfchunk = new float[120000];
// int offset = 24000;
// System.arraycopy(ampltiudesf, 0+offset, ampltiudesfchunk, 0, 120000);
//
// // System.out.println("MAX AMPLITUDE: " + PamArrayUtils.max(amplitudes));
//
// GenericModel genericModel = new GenericModel( genericModelParams.modelPath) ;
// float[][] input = new float[][]{ampltiudesfchunk};
// float[] output2 = genericModel.runModel(input);
//
// System.out.println("----TEST OUT----");
// PamArrayUtils.printArray(output2);
// double[] amplitudes = soundData.getScaledSampleAmplitudes();
//
//
// float[] ampltiudesf = PamArrayUtils.double2Float(amplitudes);
// float[] ampltiudesfchunk = new float[120000];
// int offset = 24000;
// System.arraycopy(ampltiudesf, 0+offset, ampltiudesfchunk, 0, 120000);
//
// // System.out.println("MAX AMPLITUDE: " + PamArrayUtils.max(amplitudes));
//
// GenericModel genericModel = new GenericModel( genericModelParams.modelPath) ;
// float[][] input = new float[][]{ampltiudesfchunk};
// float[] output2 = genericModel.runModel(input);
//
// System.out.println("----TEST OUT----");
// PamArrayUtils.printArray(output2);
//prep the model
genericModelWorker.prepModel(genericModelParams, null);
@ -115,6 +116,8 @@ public class PamZipDLClassifierTest {
int ind=0;
int startChunk=0;
//each line is a list of predictions for each class
int chunkSize=120000;
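//120000 samples at the 24 kHz wav sample rate is a 5 second segment.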
int classIndex = 1;
while((line = br.readLine()) != null){
@ -127,8 +130,6 @@ public class PamZipDLClassifierTest {
predictions[i] = Double.valueOf(data[i]);
}
//each line is a list of prediction for each class;
int chunkSize=120000;
GroupedRawData groupedRawData = new GroupedRawData(0, 1, 0, chunkSize, chunkSize);
groupedRawData.copyRawData(soundData.getScaledSampleAmplitudes(), startChunk, chunkSize, 0);
@ -166,4 +167,152 @@ public class PamZipDLClassifierTest {
}
/**
* Test Google's multi-species model by feeding it two wav files, one which has already been upsampled to 24 kHz and another which needs to be upsampled
* to 24 kHz by the transforms. Verify that the upsampling works.
*/
@Test
public void multiSpeciesGoogleTest2() {
System.out.println("*****Generic DL: google-multi-species test 2*****");
//relative paths to the resource folders.
String relModelPath = "./src/test/resources/rawDeepLearningClassifier/Generic/multi-species-Google/multispecies-whale-tensorflow2-default-v2/saved_model.pb";
//paths to the same file at two different sample rates
String relWavPath = "./src/test/resources/rawDeepLearningClassifier/Generic/multi-species-Google/NOPP6_EST_20090329_121500.wav";
String relWavPath2 = "./src/test/resources/rawDeepLearningClassifier/Generic/multi-species-Google/NOPP6_EST_20090329_121500_upsample.wav";
// String relWavPath2 = "./src/test/resources/rawDeepLearningClassifier/Generic/multi-species-Google/NOPP6_EST_20090329_121500.wav";
String matFileOut = "/Users/jdjm/MATLAB-Drive/MATLAB/PAMGUARD/deep_learning/google_multi_species/google_multi_species.mat";
//hold the paths in an array
String[] soundFiles = new String[] {relWavPath, relWavPath2};
Path path = Paths.get(relModelPath);
GenericModelWorker genericModelWorker = new GenericModelWorker();
GenericModelParams genericModelParams = new GenericModelParams();
genericModelParams.modelPath = path.toAbsolutePath().normalize().toString();
MultiSpeciesGoogle multiSpeciesGoogle = new MultiSpeciesGoogle();
multiSpeciesGoogle.setParams(genericModelParams);
genericModelParams.dlTransfromParams.set(0, new SimpleTransformParams(DLTransformType.DECIMATE, 24000.));
System.out.println(genericModelParams);
double segSize = 5.; //five second segment length.
double segHop = 5.; //five second hop size.
int classIndex = 5; //Atlantic right whale - just for printing output.
//create MatFile for saving the model input and output data to.
MatFile matFile = Mat5.newMatFile();
//prep the model
genericModelWorker.prepModel(genericModelParams, null);
ArrayList<float[][]> fileOutputs = new ArrayList<float[][]>();
for (int i=0; i<soundFiles.length; i++) {
try {
path = Paths.get(soundFiles[i]);
String wavFilePath = path.toAbsolutePath().normalize().toString();
//load audio
AudioData soundData = DLUtils.loadWavFile(wavFilePath);
// if (i==1) {
// soundData=soundData.interpolate(24000);
// }
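//number of whole hop-length (segHop second) chunks in the file.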
int nseg = (int) (soundData.samples.length/(segHop*soundData.sampleRate));
float[][] outputs = new float[nseg][];
int startChunk=0;
//initialise MATLAB structure for the model input and output data
Struct waveStruct = Mat5.newStruct(nseg, 1);
//segment length in samples (segSize seconds at the file's sample rate)
int chunkSize=(int) (segSize*soundData.sampleRate);
int ind = 0;
System.out.println("Generic DL: google-multi-species test: processing file: " + i + " chunkSize: " + chunkSize + " nseg " + nseg);
while(startChunk<(soundData.samples.length-chunkSize)){
GroupedRawData groupedRawData = new GroupedRawData(0, 1, 0, chunkSize, chunkSize);
groupedRawData.copyRawData(soundData.getScaledSampleAmplitudes(), startChunk, chunkSize, 0);
//System.out.println("MAX AMPLITUDE: " + PamArrayUtils.max(groupedRawData.getRawData()[0]));
ArrayList<GroupedRawData> groupedData = new ArrayList<GroupedRawData>();
groupedData.add(groupedRawData);
ArrayList<StandardPrediction> genericPrediction = genericModelWorker.runModel(groupedData, soundData.sampleRate, 0);
float[] output = genericPrediction.get(0).getPrediction();
//----TEST output for MATLAB----
float[][][] dataOut = genericModelWorker.dataUnits2ModelInput(groupedData, soundData.sampleRate, 0);
float[] waveIn = dataOut[0][0];
Matrix modelinput = DLMatFile.array2Matrix(PamArrayUtils.float2Double(waveIn));
Matrix modeloutput = DLMatFile.array2Matrix(PamArrayUtils.float2Double(output));
waveStruct.set("modelinput", ind, modelinput);
waveStruct.set("startseconds", ind, Mat5.newScalar(startChunk/soundData.sampleRate));
waveStruct.set("prediction", ind, modeloutput);
// System.out.println(String.format("File %d Chunk %d %d Predicted output: %.5f ", i, ind, startChunk,
// output[classIndex]));
outputs[ind] = output;
ind++;
//PamArrayUtils.printArray(output);
startChunk+=(int) (segHop*soundData.sampleRate); //five second hop
}
matFile.addArray(("file_" + i + "_outputs"), waveStruct);
matFile.addArray(("file_" + i), Mat5.newString(wavFilePath));
fileOutputs.add(outputs);
} catch (Exception e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
//now compare the outputs from the two files
for (int i=0; i<fileOutputs.get(0).length; i++) {
if (fileOutputs.get(0)[i]==null) continue;
System.out.println(String.format("Chunk %d File 0 - %.3f File - 1 %.3f diff %.3f", i,
fileOutputs.get(0)[i][classIndex], fileOutputs.get(1)[i][classIndex],
(Math.abs(fileOutputs.get(0)[i][classIndex]-fileOutputs.get(1)[i][classIndex]))));
}
if (matFileOut!=null) {
System.out.println("Writing mat file");
// Serialize to disk using default configurations
try {
Mat5.writeToFile(matFile,matFileOut);
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
}
}
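
The comparison loop in multiSpeciesGoogleTest2 only prints the per-chunk differences. A hedged sketch of how it could assert them instead, with an arbitrary illustrative tolerance of 0.05 (a real tolerance would need to be chosen from the observed results):

for (int i=0; i<fileOutputs.get(0).length; i++) {
	if (fileOutputs.get(0)[i]==null || fileOutputs.get(1)[i]==null) continue;
	//the pre-upsampled and transform-upsampled files should score near-identically
	assertEquals(fileOutputs.get(0)[i][classIndex], fileOutputs.get(1)[i][classIndex], 0.05);
}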