Ketos model fixes

Fixed bug in ketos models were, if no transforms were present, the model would not load. Also fixed an error when no class names were provided.
This commit is contained in:
Jamie Mac 2022-01-14 09:39:49 +00:00
parent 490daf4481
commit fd31d94aa9
6 changed files with 46 additions and 25 deletions

View File

@ -502,7 +502,7 @@ public class BinaryOutputStream {
*/
public synchronized boolean storeData(DataOutputStream outputStream, int objectId, DataUnitBaseData baseData, BinaryObjectData binaryObjectData) {
if (lastObjectType == BinaryTypes.MODULE_FOOTER) {
System.out.printf("Storing binary object type %d in file %s with no module header\n", objectId, outputStream.toString());
System.out.printf("Storing binary object type %d in file %s with no module header\n", objectId, outputStream == null ? null : outputStream.toString());
}
byte[] data = binaryObjectData.getData();
int objectLength = binaryObjectData.getDataLength();

View File

@ -421,11 +421,16 @@ public abstract class StandardModelPane extends SettingsPane<StandardModelParams
//get class names from the paramClone as these may have been set by a loaded model
//instea of a use changing a control.
//instead of a using changing a control.
currParams.classNames = paramsClone.classNames;
currParams.numClasses = paramsClone.numClasses;
currParams.useDefaultSegLen = usedefaultSeg.isSelected();
//System.out.println("Get CLASS NAMES: currParams.classNames: " + currParams.classNames + " Num classes " + currParams.numClasses);
return currParams;
}

View File

@ -49,6 +49,12 @@ public class StandardModelParams implements Serializable, Cloneable {
*/
public int numClasses = 0;
/**
* The class names. e.g. porpoise, noise, bat
*/
public DLClassName[] classNames;
/**
* List of transforms for the raw data e.g. filtering, spectrogram, spectrogram normalisation etc.
* This is only used for saving serialised settings
@ -74,10 +80,7 @@ public class StandardModelParams implements Serializable, Cloneable {
public boolean useDefaultSegLen = false;
/**
* The class names. e.g. porpoise, noise, bat
*/
public DLClassName[] classNames;
/**
* Which classes to apply binary classification to.

View File

@ -224,7 +224,6 @@ public class KetosClassifier implements DLClassiferModel, PamSettings {
* @param ketosDLParams - the Ketos parameters.
*/
public KetosDLParams getKetosParams() {
// TODO Auto-generated method stub
return ketosDLParams;
}

View File

@ -13,6 +13,8 @@ import org.jamdev.jdl4pam.transforms.WaveTransform;
import org.jamdev.jdl4pam.transforms.jsonfile.DLTransformsParser;
import org.jamdev.jdl4pam.utils.DLUtils;
import org.jamdev.jpamutils.wavFiles.AudioData;
import org.jamdev.jpamutils.JamArr;
/**
* Test the Ketos models in PAMGuard.
@ -33,8 +35,9 @@ public class KetosModelTest {
//File file = new File("/Volumes/GoogleDrive-108005893101854397430/My Drive/PAMGuard_dev/Deep_Learning/Meridian/orca/kw_detector_v11_5s.ktpb");
//the wav file to test.
String wavFilePath = "/Volumes/GoogleDrive/My Drive/PAMGuard_dev/Deep_Learning/Meridian/right_whales/for_pamguard/input.wav";
//String wavFilePath = "/Volumes/GoogleDrive/My Drive/PAMGuard_dev/Deep_Learning/Meridian/right_whales/for_pamguard/input.wav";
String wavFilePath = "/Volumes/GoogleDrive-108005893101854397430/My Drive/PAMGuard_dev/Deep_Learning/Meridian/humpback_whales/wav/5353.210403161502.wav";
try {
//the ketos model.
KetosModel ketosModel = new KetosModel(file);
@ -44,12 +47,14 @@ public class KetosModelTest {
//get the audio representation file.
KetosParams ketosParams = new KetosParams(jsonString);
ketosParams.defaultOutputShape = ketosModel.getOutShape();
//System.out.println(ketosParams.toString());
System.out.println("Output shape" + ketosParams.defaultOutputShape);
//Open wav files.
AudioData soundData = DLUtils.loadWavFile(wavFilePath);
soundData = soundData.trim(0, (int) (soundData.getSampleRate()*3.0));
soundData = soundData.trim(0, (int) (soundData.getSampleRate()*3.52));
//generate the transforms.
ArrayList<DLTransform> transforms = DLTransformsFactory.makeDLTransforms(ketosParams.dlTransforms);
@ -83,6 +88,7 @@ public class KetosModelTest {
for (int j=0; j<nStack; j++) {
data[j] = DLUtils.toFloatArray(transformedData);
}
System.out.println("Input len: " + data.length + " " + data[0].length + " " + data[0][0].length);
output = ketosModel.runModel(data);
long time2 = System.currentTimeMillis();
System.out.println("Time to run model: " + (time2-time1) + " ms");

View File

@ -70,7 +70,12 @@ public class KetosWorker extends DLModelWorker<KetosResult> {
//convert the JSON string to a parameters object.
KetosParams ketosParams = new KetosParams(jsonString);
//important to add this for Ketos models because the JSON string does not necessarily contain and output shape.
//System.out.println("----Default output shape: " + ketosParams.defaultOutputShape + " " + ketosModel.getOutShape());
if (ketosParams.defaultOutputShape==null) {
ketosParams.defaultOutputShape = ketosModel.getOutShape();
}
///HACK here for now to fix an issue with dB and Ketos transforms having zero length somehow...
@ -99,7 +104,7 @@ public class KetosWorker extends DLModelWorker<KetosResult> {
ketosDLParams.defaultSegmentLen = ketosParams.seglen*1000.; //the segment length in microseconds.
//ketosParams.classNames = new String[] {"Noise", "Right Whale"}; // FIXME;
ketosDLParams.numClasses = (int) ketosModel.getOutShape().get(1);
//ok 0 the other values are not user selectable but this is. If we relaod the same model we probably want to keep it....
//So this is a little bt of a hack but will probably be OK in most cases.
if (ketosDLParams.binaryClassification==null || ketosDLParams.binaryClassification.length!=ketosDLParams.numClasses) {
@ -110,18 +115,21 @@ public class KetosWorker extends DLModelWorker<KetosResult> {
}
// if (dlParams.classNames!=null) {
// for (int i = 0; i<dlParams.classNames.length; i++) {
// System.out.println("Class name " + i + " " + dlParams.classNames[i]);
// }
// }
// ketosDLParams.classNames = dlControl.getClassNameManager().makeClassNames(ketosParams.classNames);
//
// if (ketosParams.classNames!=null) {
// for (int i = 0; i<ketosDLParams.classNames.length; i++) {
// System.out.println("Class name " + i + " " + ketosDLParams.classNames[i].className + " ID " + ketosDLParams.classNames[i].ID );
// }
// }
// if (dlParams.classNames!=null) {
// for (int i = 0; i<dlParams.classNames.length; i++) {
// System.out.println("Class name " + i + " " + dlParams.classNames[i]);
// }
// }
// ketosDLParams.classNames = dlControl.getClassNameManager().makeClassNames(ketosParams.classNames);
// if (ketosParams.classNames!=null) {
// for (int i = 0; i<ketosDLParams.classNames.length; i++) {
// System.out.println("Class name " + i + " " + ketosDLParams.classNames[i].className + " ID " + ketosDLParams.classNames[i].ID );
// }
// }
}
catch (Exception e) {