From c1f4ba9e376ed00c08b8296fc752ee9a0907c4d3 Mon Sep 17 00:00:00 2001 From: Jamie Mac Date: Fri, 3 May 2024 16:29:37 +0100 Subject: [PATCH] More work on delphinID classifier but not quite there yet. --- src/rawDeepLearningClassifier/DLControl.java | 1 + .../dlClassification/DLClassifyProcess.java | 184 +++++++++++------- .../StandardClassifierModel.java | 2 +- .../delphinID/DelphinIDClassifier.java | 2 +- .../delphinID/DelphinIDWorker.java | 5 +- .../genericModel/GenericDLClassifier.java | 2 +- .../genericModel/GenericModelWorker.java | 6 +- .../segmenter/SegmenterDetectionGroup.java | 6 +- .../segmenter/SegmenterGroupDataBlock.java | 2 + .../segmenter/SegmenterProcess.java | 118 +++++++++-- 10 files changed, 240 insertions(+), 88 deletions(-) diff --git a/src/rawDeepLearningClassifier/DLControl.java b/src/rawDeepLearningClassifier/DLControl.java index fda39d72..3b2ad27f 100644 --- a/src/rawDeepLearningClassifier/DLControl.java +++ b/src/rawDeepLearningClassifier/DLControl.java @@ -228,6 +228,7 @@ public class DLControl extends PamControlledUnit implements PamSettings { // classify the raw data segments. addPamProcess(dlClassifyProcess = new DLClassifyProcess(this, segmenterProcess.getSegmenterDataBlock())); + dlClassifyProcess.addMultiPlexDataBlock(segmenterProcess.getSegmenteGrouprDataBlock()); //manages the names assigned to different output classes. dlClassNameManager = new DLClassNameManager(this); diff --git a/src/rawDeepLearningClassifier/dlClassification/DLClassifyProcess.java b/src/rawDeepLearningClassifier/dlClassification/DLClassifyProcess.java index fe515cb1..bb0f8945 100644 --- a/src/rawDeepLearningClassifier/dlClassification/DLClassifyProcess.java +++ b/src/rawDeepLearningClassifier/dlClassification/DLClassifyProcess.java @@ -20,6 +20,7 @@ import rawDeepLearningClassifier.logging.DLAnnotation; import rawDeepLearningClassifier.logging.DLAnnotationType; import rawDeepLearningClassifier.segmenter.GroupedRawData; import rawDeepLearningClassifier.segmenter.SegmenterDataBlock; +import rawDeepLearningClassifier.segmenter.SegmenterDetectionGroup; /** * The deep learning classification process. This takes a segment of raw data from the segmenter. @@ -35,7 +36,7 @@ import rawDeepLearningClassifier.segmenter.SegmenterDataBlock; */ public class DLClassifyProcess extends PamInstantProcess { - + /** * Holds all model results but no other information */ @@ -64,15 +65,13 @@ public class DLClassifyProcess extends PamInstantProcess { /** * The DL buffer */ - private ArrayList classificationBuffer; - + private ArrayList classificationBuffer; /** * The DL annotation type. */ private DLAnnotationType dlAnnotationType; - /** * The last parent data for grouped data. This is used to ensure that DLDetections * correspond to the raw chunk of data from a parent detection e.g. a click detection. @@ -82,8 +81,8 @@ public class DLClassifyProcess extends PamInstantProcess { public DLClassifyProcess(DLControl dlControl, SegmenterDataBlock parentDataBlock) { super(dlControl); - - + + this.setParentDataBlock(parentDataBlock); // this.setParentDataBlock(parentDataBlock); @@ -117,7 +116,7 @@ public class DLClassifyProcess extends PamInstantProcess { overlayGraphics.setDetectionData(true); dlDetectionDataBlock.setOverlayDraw(overlayGraphics); - classificationBuffer = new ArrayList(); + classificationBuffer = new ArrayList(); //the process name. setProcessName("Deep Learning Classifier"); @@ -140,7 +139,7 @@ public class DLClassifyProcess extends PamInstantProcess { System.err.println("Raw Deep Learning Classifier: The grouped source parameters were null." + " A new instance has been created: Possible de-serialization error."); } - + //important for downstream processes such as the bearing localiser. dlModelResultDataBlock.setChannelMap(dlControl.getDLParams().groupedSourceParams.getChannelBitmap()); @@ -166,7 +165,7 @@ public class DLClassifyProcess extends PamInstantProcess { public void prepareProcess() { setupClassifierProcess(); } - + /** * called for every process once the system model has been created. @@ -194,36 +193,71 @@ public class DLClassifyProcess extends PamInstantProcess { */ @Override public void newData(PamObservable obs, PamDataUnit pamRawData) { +// System.out.println("NEW SEGMENTER DATA"); - //the raw data units should appear in sequential channel order - GroupedRawData rawDataUnit = (GroupedRawData) pamRawData; - - if (checkGroupData(rawDataUnit)) { - //check whether the classification buffer is full. If it is then run - if (isClassificationBufferFull(classificationBuffer, rawDataUnit)) { - - //first call run model to clear out the classification buffer if needs be - runModel(); + if (pamRawData instanceof SegmenterDetectionGroup) { + if (classificationBuffer.size()>=1) { + System.out.println("RUN THE MODEL FOR WHISTLES: "); + runDetectionGroupModel(); classificationBuffer.clear(); } - - classificationBuffer.add(rawDataUnit); + else { + classificationBuffer.add(pamRawData); + } } -// System.out.println("New raw data in: chan: " + PamUtils.getSingleChannel(pamRawData.getChannelBitmap()) + -// " Size: " + pamRawData.getSampleDuration() + " first sample: " + rawDataUnit.getRawData()[0][0] -// + "Parent UID: " + rawDataUnit.getParentDataUnit().getUID()); + + if (pamRawData instanceof GroupedRawData) { + //the raw data units should appear in sequential channel order + GroupedRawData rawDataUnit = (GroupedRawData) pamRawData; + + if (checkGroupData(rawDataUnit)) { + //check whether the classification buffer is full. If it is then run + if (isRawClassificationBufferFull(classificationBuffer, rawDataUnit)) { + + //first call run model to clear out the classification buffer if needs be + runRawModel(); + classificationBuffer.clear(); + } + + classificationBuffer.add(rawDataUnit); + + } + } + // System.out.println("New raw data in: chan: " + PamUtils.getSingleChannel(pamRawData.getChannelBitmap()) + + // " Size: " + pamRawData.getSampleDuration() + " first sample: " + rawDataUnit.getRawData()[0][0] + // + "Parent UID: " + rawDataUnit.getParentDataUnit().getUID()); } - + + + /** + * Run a model for which the input is a detection group. + */ + private void runDetectionGroupModel() { + if (classificationBuffer.size()<=0) return; + ArrayList classificationBufferTemp = (ArrayList) classificationBuffer.clone(); + + ArrayList modelResults = this.dlControl.getDLModel().runModel(classificationBufferTemp); + + for (int i=0; i classificationBufferTemp = (ArrayList) classificationBuffer.clone(); - ArrayList modelResults = this.dlControl.getDLModel().runModel(classificationBuffer); + ArrayList modelResults = this.dlControl.getDLModel().runModel(classificationBufferTemp); if (modelResults==null) { return; //there has been a problem @@ -235,18 +269,18 @@ public class DLClassifyProcess extends PamInstantProcess { // System.out.println("Compare Times: " + PamCalendar.formatDBDateTime(modelResults.get(i).getTimeMillis(), true) + // " " + PamCalendar.formatDBDateTime(classificationBufferTemp.get(i).getTimeMilliseconds(), true) + " " + // modelResults.get(i).getPrediction()[1]); - newModelResult(modelResults.get(i), classificationBufferTemp.get(i)); + newRawModelResult(modelResults.get(i), classificationBufferTemp.get(i)); } } } /** - * Check whether the buffer is full and the results should be passed to the classification model. + * Check whether the buffer is full and the results should be passed to the classification model if we are using GrpoupDataUnits * @param classificationBuffer2 - the classification buffer. * @param rawDataUnit - the next raw data unit to add to the buffer. * @return true if the buffer is full. */ - private boolean isClassificationBufferFull(ArrayList classificationBuffer2, GroupedRawData rawDataUnit) { + private boolean isRawClassificationBufferFull(ArrayList classificationBuffer2, GroupedRawData rawDataUnit) { if (classificationBuffer2.size()==0) return false; @@ -254,16 +288,16 @@ public class DLClassifyProcess extends PamInstantProcess { //1) It's over a max time //2) Contains different parent data units (if not from raw data). - GroupedRawData lastUnit = classificationBuffer2.get(classificationBuffer2.size()-1); + GroupedRawData lastUnit = (GroupedRawData) classificationBuffer2.get(classificationBuffer2.size()-1); if (!(lastUnit.getParentDataUnit() instanceof RawDataUnit) && lastUnit.getParentDataUnit()!=rawDataUnit.getParentDataUnit()) { //there is a new parent data unit. return true; } - //get the start time. Use min value instead of first data just in case units ar enot in order. + //get the start time. Use min value instead of first data just in case units are not in order. long min = Long.MAX_VALUE; - for (GroupedRawData groupedRawData: classificationBuffer2) { + for (PamDataUnit groupedRawData: classificationBuffer2) { if (groupedRawData.getTimeMilliseconds()0) { - //System.out.println("Save click annotation to " + lastParentDataUnit[i].getUID()); + //System.out.println("Save click annotation to " + lastParentDataUnit[i].getUID()); addDLAnnotation(dataUnit,groupDataBuffer[i],modelResultDataBuffer[i]); lastParentDataUnit[i]=null; clearBuffer(i); @@ -532,15 +573,22 @@ public class DLClassifyProcess extends PamInstantProcess { @Override public void pamStart() { - // TODO Auto-generated method stub - System.out.println("PREP MODEL:"); +// System.out.println("PREP MODEL:"); this.dlControl.getDLModel().prepModel(); } @Override public void pamStop() { - runModel(); //make sure to run the last data in the buffer. - + //make sure to run the last data in the buffer. + if (this.classificationBuffer.size()>0) { + if (classificationBuffer.get(0) instanceof GroupedRawData) { + runRawModel(); //raw data or raw data units + } + if (classificationBuffer.get(0) instanceof SegmenterDetectionGroup) { + runDetectionGroupModel(); //any other data units. + } + } + //21/11/2022 - it seems like this causes a memory leak when models are reopened and closed every file... //this.dlControl.getDLModel().closeModel(); } diff --git a/src/rawDeepLearningClassifier/dlClassification/StandardClassifierModel.java b/src/rawDeepLearningClassifier/dlClassification/StandardClassifierModel.java index 0fbce3de..b6138e44 100644 --- a/src/rawDeepLearningClassifier/dlClassification/StandardClassifierModel.java +++ b/src/rawDeepLearningClassifier/dlClassification/StandardClassifierModel.java @@ -271,7 +271,7 @@ public abstract class StandardClassifierModel implements DLClassiferModel, PamSe */ protected void newResult(StandardPrediction modelResult, PamDataUnit groupedRawData) { if (groupedRawData instanceof GroupedRawData) { - this.dlControl.getDLClassifyProcess().newModelResult(modelResult, (GroupedRawData) groupedRawData); + this.dlControl.getDLClassifyProcess().newRawModelResult(modelResult, (GroupedRawData) groupedRawData); } } // diff --git a/src/rawDeepLearningClassifier/dlClassification/delphinID/DelphinIDClassifier.java b/src/rawDeepLearningClassifier/dlClassification/delphinID/DelphinIDClassifier.java index b55c8655..0a7e2813 100644 --- a/src/rawDeepLearningClassifier/dlClassification/delphinID/DelphinIDClassifier.java +++ b/src/rawDeepLearningClassifier/dlClassification/delphinID/DelphinIDClassifier.java @@ -91,7 +91,7 @@ public class DelphinIDClassifier extends StandardClassifierModel { @Override public boolean isDecision(StandardPrediction modelResult, StandardModelParams modelParmas) { //TODO - //DelphinID uses a different decision making process to most of the standard classifiers which just pass a binary threhsoild. + //DelphinID uses a different decision making process to most of the standard classifiers which just pass a binary threshold. return false; } diff --git a/src/rawDeepLearningClassifier/dlClassification/delphinID/DelphinIDWorker.java b/src/rawDeepLearningClassifier/dlClassification/delphinID/DelphinIDWorker.java index b3ef293d..c411cce1 100644 --- a/src/rawDeepLearningClassifier/dlClassification/delphinID/DelphinIDWorker.java +++ b/src/rawDeepLearningClassifier/dlClassification/delphinID/DelphinIDWorker.java @@ -49,8 +49,8 @@ public class DelphinIDWorker extends ArchiveModelWorker { this.setModel(null); // set model to null to make sure nothing works and errors are thrown } - dlParams.binaryClassification = new boolean[dlParams.classNames.length]; - for (int i=0; i { @Override public float[] runModel(float[][][] transformedDataStack) { - //System.out.println("RUN GENERIC MODEL: " + transformedDataStack.length + " " + transformedDataStack[0].length + " " + transformedDataStack[0][0].length); + System.out.println("RUN GENERIC MODEL: " + transformedDataStack.length + " " + transformedDataStack[0].length + " " + transformedDataStack[0][0].length); // System.out.println("RUN GENERIC MODEL: " + transformedDataStack[0][0][0]); float[] results; if (freqTransform) @@ -47,7 +48,8 @@ public class GenericModelWorker extends DLModelWorker { //System.out.println("RUN GENERIC MODEL WAVE: " + waveStack.length + " " + waveStack[0].length + " " + waveStack[0][0]); results = getModel().runModel(waveStack); } - //System.out.println("GENERIC MODEL RESULTS: " + results== null ? null : results.length); + System.out.println("GENERIC MODEL RESULTS: " + (results== null ? null : results.length)); + PamArrayUtils.printArray(results); return results; } diff --git a/src/rawDeepLearningClassifier/segmenter/SegmenterDetectionGroup.java b/src/rawDeepLearningClassifier/segmenter/SegmenterDetectionGroup.java index 93bbb7e3..fe218fbf 100644 --- a/src/rawDeepLearningClassifier/segmenter/SegmenterDetectionGroup.java +++ b/src/rawDeepLearningClassifier/segmenter/SegmenterDetectionGroup.java @@ -28,7 +28,7 @@ public class SegmenterDetectionGroup extends GroupDetection { * @param timeMilliseconds - this is the start of the SEGMENT - Note that the * @param channelBitmap - channels of all detections * @param startSample - the stratSample of the SEGMENT. - * @param duration - the duration of the SEGMENT. + * @param duration - the duration of the SEGMENT in milliseconds. */ public SegmenterDetectionGroup(long timeMilliseconds, int channelBitmap, long startSample, double duration) { super(timeMilliseconds, channelBitmap, startSample, (long) duration); @@ -52,5 +52,9 @@ public class SegmenterDetectionGroup extends GroupDetection { return segDuration; } + public long getSegmentEndMillis() { + return (long) (segMillis+segDuration); + } + } diff --git a/src/rawDeepLearningClassifier/segmenter/SegmenterGroupDataBlock.java b/src/rawDeepLearningClassifier/segmenter/SegmenterGroupDataBlock.java index 3bda33fd..3a69c23a 100644 --- a/src/rawDeepLearningClassifier/segmenter/SegmenterGroupDataBlock.java +++ b/src/rawDeepLearningClassifier/segmenter/SegmenterGroupDataBlock.java @@ -12,6 +12,8 @@ public class SegmenterGroupDataBlock extends PamDataBlock=segStart && whistleStart=segStart && whistleEnd