Specifying the HyDe-CNN architecture
The Python scripts containing the code to train HyDe-CNN can be found in the files `train_*_cnn.py`. They can be run using the following commands:
# Here divergence_scaling is either 0.5, 1.0, or 2.0 coalescent units
# Minimum dXY network
python3 train_min_cnn.py --coal_units <divergence_scaling>
# Mean dXY network
python3 train_mean_cnn.py --coal_units <divergence_scaling>
# Minimum+mean dXY network
python3 train_min-mean_cnn.py --coal_units <divergence_scaling>
TensorFlow imports
import tensorflow.keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (
Dense,
Dropout,
Flatten
)
from tensorflow.keras.layers import (
Conv2D,
AveragePooling2D
)
from tensorflow.keras.callbacks import (
EarlyStopping,
ModelCheckpoint
)
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.optimizers import Adam
Specifying the network architecture
# Build the HyDe-CNN classifier: four Conv2D -> AveragePooling2D -> Dropout
# stages with an increasing number of filters (12, 24, 36, 48), followed by a
# flattened dense head.
model = Sequential()

# Only the first layer is told the input shape; it is read from axes 1-3 of
# the training array (axis 0 is not part of the per-example shape).
input_shape = (xtrain.shape[1], xtrain.shape[2], xtrain.shape[3])

for stage, n_filters in enumerate((12, 24, 36, 48)):
    # Every stage uses the same kernel, pooling and dropout settings; only the
    # filter count changes, and only the first convolution gets input_shape.
    conv_options = {'input_shape': input_shape} if stage == 0 else {}
    model.add(
        Conv2D(
            n_filters, kernel_size=(4, 2),
            activation='relu',
            **conv_options
        )
    )
    model.add(
        AveragePooling2D(
            pool_size=(2, 1),
            strides=(2, 1)
        )
    )
    model.add(Dropout(0.2))

# Dense head: flatten the feature maps, one hidden layer, then a four-unit
# softmax output trained with categorical cross-entropy.
model.add(Flatten())
model.add(Dense(60, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(4, activation='softmax'))

model.compile(
    loss=categorical_crossentropy,
    optimizer=Adam(),
    metrics=['accuracy']
)

# summary() prints the layer table itself and returns None, so wrapping it in
# print() would emit an extra "None" line after the table.
model.summary()
# Stop training when the validation loss stops improving (default patience).
early_stopping = EarlyStopping(monitor='val_loss')

# Write the model to disk only when the validation loss improves; the file
# name is built from `cu`, which is defined outside this snippet.
checkpoint_path = 'hyde_cnn_min_{}.mdl'.format(cu)
best_checkpoint = ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='val_loss',
    save_best_only=True
)

callbacks = [early_stopping, best_checkpoint]

# Train for at most 10 epochs in mini-batches of 32, evaluating on the
# validation split after each epoch.
model.fit(
    x=xtrain,
    y=ytrain,
    batch_size=32,
    epochs=10,
    verbose=1,
    callbacks=callbacks,
    validation_data=(xval, yval)
)