Finished testing suite for DelphinID

This commit is contained in:
Jamie Mac 2024-04-30 17:06:11 +01:00
parent 9ba3e7e3f7
commit fa6991eb80
8 changed files with 147 additions and 24 deletions

View File

@ -719,6 +719,17 @@ public class PamArrayUtils {
}
}
/**
* Print an array to the console.
* @param array to print
*/
public static void printArray(float[] array) {
if (array==null) System.out.println("null");
for (int i=0; i<array.length; i++) {
System.out.println(i + ": " + array[i]);
}
}
/**
* Print an array to the console with no index numbers
@ -806,14 +817,17 @@ public class PamArrayUtils {
/**
* Check whether there are duplicates within an array
* @param the array.
* @return true if there are duplicates.
* @return true if the array is unique
*/
public static boolean unique(double[] array) {
boolean duplicates=false;
boolean duplicates=true;
for (int j=0;j<array.length;j++)
for (int k=j+1;k<array.length;k++)
if (k!=j && array[k] == array[j])
duplicates=true;
if (k!=j && array[k] == array[j]) {
//if
duplicates=false;
return duplicates;
}
return duplicates;
}

View File

@ -234,9 +234,10 @@ public class DLControl extends PamControlledUnit implements PamSettings {
//manages default models
defaultModelManager = new DLDefaultModelManager(this);
//manages downloading models
modelDownloadManager = new DLDownloadManager();
// add storage options etc.
dlBinaryDataSource = new DLResultBinarySource(dlClassifyProcess);

View File

@ -128,6 +128,7 @@ public class ArchiveModelWorker extends GenericModelWorker {
// }
//generate the transforms from the KetosParams objects.
System.out.println(modelParams.dlTransforms);
ArrayList<DLTransform> transforms = DLTransformsFactory.makeDLTransforms(modelParams.dlTransforms);
// ///HACK here for now to fix an issue with dB and Ketos transforms having zero length somehow...

View File

@ -3,16 +3,19 @@ package rawDeepLearningClassifier.dlClassification.delphinID;
import java.io.IOException;
import java.util.ArrayList;
import org.jamdev.jdl4pam.utils.DLMatFile;
import PamUtils.PamArrayUtils;
import PamguardMVC.DataUnitBaseData;
import PamguardMVC.PamDataUnit;
import rawDeepLearningClassifier.dlClassification.animalSpot.StandardModelParams;
import rawDeepLearningClassifier.dlClassification.genericModel.StandardPrediction;
import rawDeepLearningClassifier.segmenter.SegmenterDetectionGroup;
import us.hebi.matlab.mat.format.Mat5;
import us.hebi.matlab.mat.format.Mat5File;
import us.hebi.matlab.mat.types.MatFile;
import us.hebi.matlab.mat.types.Matrix;
import us.hebi.matlab.mat.types.Struct;
import whistleClassifier.WhistleContour;
import whistlesAndMoans.AbstractWhistleDataUnit;
@ -24,10 +27,10 @@ import whistlesAndMoans.AbstractWhistleDataUnit;
*/
public class DelphinIDTest {
public static DelphinIDWorker prepDelphinIDModel(String modelPath) {
public static DelphinIDWorkerTest prepDelphinIDModel(String modelPath) {
//create the delphinID worker.
DelphinIDWorker delphinIDWorker = new DelphinIDWorker();
DelphinIDWorkerTest delphinIDWorker = new DelphinIDWorkerTest();
StandardModelParams params = new StandardModelParams();
params.modelPath = modelPath;
@ -134,7 +137,7 @@ public class DelphinIDTest {
SegmenterDetectionGroup whistleGroup;
while (segStart<endTime){
whistleGroup = new SegmenterDetectionGroup(segStart, 1, segEnd, segHop);
whistleGroup = new SegmenterDetectionGroup(segStart, 1, segEnd, segLen);
for (AbstractWhistleDataUnit whislte: whistles) {
whistleStart = whislte.getTimeMilliseconds();
@ -148,6 +151,8 @@ public class DelphinIDTest {
}
group.add(whistleGroup);
// System.out.println("SegStart: " + (segStart - dataStartMillis));
segStart = (long) (segStart+segHop);
segEnd = (long) (segStart+segLen);
@ -185,6 +190,27 @@ public class DelphinIDTest {
}
public static class DelphinIDWorkerTest extends DelphinIDWorker {
private float[][][] lastModelInput;
public float[][][] dataUnits2ModelInput(ArrayList<? extends PamDataUnit> dataUnits, float sampleRate, int iChan){
float[][][] data = super.dataUnits2ModelInput(dataUnits, sampleRate, iChan);
this.lastModelInput = data;
return data;
}
public float[][][] getLastModelInput() {
return lastModelInput;
}
}
/**
* Main class for running the test.
@ -199,15 +225,20 @@ public class DelphinIDTest {
long dataStartMillis = 1340212413000L;
//path to the .mat containing whistle contours.
String whistleContourPath = "D:\\Dropbox\\PAMGuard_dev\\Deep_Learning\\delphinID\\testencounter415\\whistle_contours.mat";
String whistleContourPath = "D:/Dropbox/PAMGuard_dev/Deep_Learning/delphinID/testencounter415/whistle_contours.mat";
//the path to the model
String modelPath = "D:/Dropbox/PAMGuard_dev/Deep_Learning/delphinID/testencounter415/whistle_4s_encounter415.zip";
//the path to the model
String matImageSave = "C:/Users/Jamie Macaulay/MATLAB Drive/MATLAB/PAMGUARD/deep_learning/delphinID/whistleimages.mat";
//create MatFile for saving the image data to.
MatFile matFile = Mat5.newMatFile();
//get the whislte contours form a .mat file.
ArrayList<AbstractWhistleDataUnit> whistleContours = getWhistleContoursMAT(whistleContourPath);
//segment the whistle detections
ArrayList<SegmenterDetectionGroup> segments = segmentWhsitleData(whistleContours, dataStartMillis,
segLen, segHop);
@ -216,7 +247,13 @@ public class DelphinIDTest {
System.out.println("Segment " + i + " contains " + segments.get(i).getSubDetectionsCount() + " whistles");
}
DelphinIDWorker model = prepDelphinIDModel(modelPath);
//prepare the model - this loads the zip file and loads the correct transforms.
DelphinIDWorkerTest model = prepDelphinIDModel(modelPath);
model.setEnableSoftMax(false);
//initialise strcuture for image data
Struct imageStruct = Mat5.newStruct(segments.size(), 1);
for (int i=0; i<segments.size(); i++) {
@ -228,16 +265,35 @@ public class DelphinIDTest {
ArrayList<StandardPrediction> predicition = model.runModel(aSegment, sampleRate, 1);
float[] output = predicition.get(0).getPrediction();
System.out.println();
System.out.print("Segment: " +(aSegment.get(0).getSegmentStartMillis()-dataStartMillis)/1000.);
for (int j=0; j<output.length; j++) {
System.out.print(" " + output[j]);
}
Matrix image = DLMatFile.array2Matrix(PamArrayUtils.float2Double(model.getLastModelInput()[0]));
imageStruct.set("image", i, image);
imageStruct.set("startmillis", i, Mat5.newScalar(aSegment.get(0).getSegmentStartMillis()));
imageStruct.set("startseconds", i, Mat5.newScalar((aSegment.get(0).getSegmentStartMillis()-dataStartMillis)/1000.));
imageStruct.set("prediction", i, DLMatFile.array2Matrix(PamArrayUtils.float2Double(output)));
}
matFile.addArray("whistle_images", imageStruct);
// Serialize to disk using default configurations
try {
Mat5.writeToFile(matFile,matImageSave);
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
// for (int i=0; i<whistleContours.size(); i++) {
// System.out.println("Whislte: " + i);
// PamArrayUtils.printArray(whistleContours.get(i).getFreqsHz());
// }
}
}

View File

@ -109,8 +109,8 @@ public class DelphinIDWorker extends ArchiveModelWorker {
//contained within the SegmenterDetectionGroup unit.
Whistles2Image whistles2Image = new Whistles2Image(whistleGroups.get(j), whistleImageParams);
//add the transform to the current list.
this.getModelTransforms().add(0, whistles2Image);
//set the spec transform
((FreqTransform) this.getModelTransforms().get(0)).setSpecTransfrom(whistles2Image.getSpecTransfrom());
//process all the transforms.
DLTransform transform = modelTransforms.get(0);

View File

@ -2,11 +2,13 @@ package rawDeepLearningClassifier.dlClassification.delphinID;
import java.awt.Color;
import java.awt.image.BufferedImage;
import java.awt.image.Raster;
import java.util.ArrayList;
import org.jamdev.jdl4pam.transforms.FreqTransform;
import org.jamdev.jpamutils.spectrogram.SpecTransform;
import PamUtils.PamArrayUtils;
import rawDeepLearningClassifier.segmenter.SegmenterDetectionGroup;
import whistlesAndMoans.AbstractWhistleDataUnit;
@ -61,17 +63,21 @@ public class Whistles2Image extends FreqTransform {
ArrayList<double[][]> points = whistContours2Points(whistleGroup);
//does not work becaue it has to be on the AWT thread.
BufferedImage canvas = makeScatterImage(points, size, new double[]{0, whistleGroup.getDurationInMilliseconds()}, freqLimits, 5.);
BufferedImage canvas = makeScatterImage(points, size, new double[]{0, whistleGroup.getSegmentDuration()/1000.}, freqLimits, 5.);
double[][] imaged = new double[(int) size[0]][(int) size[1]];
int color;
float[] color = new float[3];
Raster raster = canvas.getData();
for (int i=0; i<imaged.length; i++) {
for (int j=0; j<imaged[0].length; j++) {
color = canvas.getRGB(i, j);
imaged[i][j] = color;
color = raster.getPixel(i, j, color);
imaged[i][j] = color[0]/255.; //normalize
}
}
//
// System.out.println("Original image: ");
// PamArrayUtils.printArray(imaged);
specTransform.setSpecData(imaged);
specTransform.setSampleRate((float) (freqLimits[1]*2));
@ -90,15 +96,36 @@ public class Whistles2Image extends FreqTransform {
AbstractWhistleDataUnit whistleContour;
long segStart = whistleGroup.getTimeMilliseconds();
long segStart = whistleGroup.getSegmentStartMillis();
long segEnd = (long) (whistleGroup.getSegmentStartMillis() + whistleGroup.getSegmentDuration());
// for (int i=0; i<whistleGroup.getSubDetectionsCount(); i++) {
// whistleContour = (AbstractWhistleDataUnit) whistleGroup.getSubDetection(i);
//
// long whistleStart = whistleContour.getTimeMilliseconds();
// long whistleEnd = (long) (whistleContour.getTimeMilliseconds() + whistleContour.getDurationInMilliseconds());
//
// if ((whistleStart>=segStart && whistleStart<segEnd) || ((whistleEnd>=segStart && whistleEnd<segEnd))){
// //some part of the whistle is in the segment.
// System.out.println("Whistle in group? true");
//
// }
// else {
// System.out.println("Whistle in group? false!!!");
// }
//
// }
for (int i=0; i<whistleGroup.getSubDetectionsCount(); i++) {
whistleContour = (AbstractWhistleDataUnit) whistleGroup.getSubDetection(i);
// System.out.println("Whistle start time: " + (segStart - whistleContour.getTimeMilliseconds())/1000. + " end: " + (segStart - whistleContour.getTimeMilliseconds() + whistleContour.getDurationInMilliseconds())/1000.);
double[][] contourD = new double[whistleContour.getSliceCount()][2];
for (int j=0; j<whistleContour.getSliceCount(); j++) {
contourD[j][0] = (segStart - whistleContour.getTimeMilliseconds())/1000. + whistleContour.getTimesInSeconds()[i];
contourD[j][0] = (whistleContour.getTimeMilliseconds()-segStart)/1000. + whistleContour.getTimesInSeconds()[j];
contourD[j][1] = whistleContour.getFreqsHz()[j];
}
contours.add(contourD);
@ -158,7 +185,9 @@ public class Whistles2Image extends FreqTransform {
//Calculate x and y in pixels.
x = ((points.get(j)[i][0]-xlims[0])/(xlims[1]-xlims[0]))*size[0];
y = ((points.get(j)[i][0]-xlims[0])/(xlims[1]-xlims[0]))*size[0];
y = ((points.get(j)[i][1]-ylims[0])/(ylims[1]-ylims[0]))*size[1];
// System.out.println("Fill oval: x" + x + " y: " + y + " time: " + points.get(j)[i][0]);
canvas.getGraphics().fillOval((int) (x+markerSize/2),(int) (y-markerSize/2), (int) markerSize,(int) markerSize);
}

View File

@ -21,5 +21,5 @@ public class SegmenterDataBlock extends PamDataBlock<GroupedRawData> {
//need this to notify classifier in viewer mode.
return true;
}
}

View File

@ -10,6 +10,17 @@ import PamguardMVC.PamDataUnit;
*/
public class SegmenterDetectionGroup extends GroupDetection<PamDataUnit> {
/**
* The duration of the segment in millis.
*/
private double segDuration;
/**
* The start time fo the segment in millis.
*/
private long segMillis;
/**
* Constructor for a group of detections within a detection. Note that some
* longer detections (e.g. whistles) may have sections outside the segment.
@ -22,6 +33,8 @@ public class SegmenterDetectionGroup extends GroupDetection<PamDataUnit> {
public SegmenterDetectionGroup(long timeMilliseconds, int channelBitmap, long startSample, double duration) {
super(timeMilliseconds, channelBitmap, startSample, (long) duration);
this.setDurationInMilliseconds(duration);
this.segMillis =timeMilliseconds;
this.segDuration = duration;
}
@Override
@ -29,6 +42,15 @@ public class SegmenterDetectionGroup extends GroupDetection<PamDataUnit> {
//segmetns share sub detections
return true;
}
public long getSegmentStartMillis() {
return segMillis;
}
public double getSegmentDuration() {
return segDuration;
}
}