
min2net.model.DeepConvNet


Table of contents

  1. min2net.model.DeepConvNet
    1. About DeepConvNet
    2. DeepConvNet class
    3. Build method
    4. Fit method
    5. Predict method
    6. Example

About DeepConvNet

If you use the DeepConvNet model in your research, please cite the following paper:

@article{hbm23730,
  author = {Schirrmeister, Robin Tibor and 
            Springenberg, Jost Tobias and 
            Fiederer, Lukas Dominique Josef and 
            Glasstetter, Martin and 
            Eggensperger, Katharina and 
            Tangermann, Michael and 
            Hutter, Frank and 
            Burgard, Wolfram and 
            Ball, Tonio},
  title = {Deep learning with convolutional neural networks for EEG decoding and visualization},
  journal = {Human Brain Mapping},
  volume = {38},
  number = {11},
  pages = {5391-5420},
  year = {2017},
  keywords = {electroencephalography, EEG analysis, machine learning, end‐to‐end learning, brain–machine interface, brain–computer interface, model interpretability, brain mapping},
  doi = {10.1002/hbm.23730},
  url = {https://onlinelibrary.wiley.com/doi/abs/10.1002/hbm.23730}
}

DeepConvNet class

Configures the model for training. Based on tf.keras.Model.

min2net.model.DeepConvNet()

Arguments:

| Arguments | Description | Default |
| --- | --- | --- |
| input_shape | tuple of integers. (1, #channel, #time_point) | (1, 20, 400) |
| num_class | int. Number of classes. | 2 |
| loss | str (name of objective function), objective function, or tf.keras.losses.Loss instance. | 'sparse_categorical_crossentropy' |
| epochs | int. Number of epochs to train the model. | 200 |
| batch_size | int or None. Number of samples per gradient update. | 100 |
| optimizer | str (name of optimizer) or optimizer instance. See tf.keras.optimizers. | Adam(beta_1=0.9, beta_2=0.999, epsilon=1e-08) |
| lr | float. Initial learning rate. | 0.01 |
| min_lr | float. Lower bound on the learning rate. See tf.keras.callbacks.ReduceLROnPlateau. | 0.01 |
| factor | float. Factor by which the learning rate will be reduced. See tf.keras.callbacks.ReduceLROnPlateau. | 0.25 |
| patience | int. Number of epochs with no improvement after which the learning rate will be reduced. See tf.keras.callbacks.ReduceLROnPlateau. | 10 |
| es_patience | int. Number of epochs with no improvement after which training will be stopped. See tf.keras.callbacks.EarlyStopping. | 20 |
| verbose | 0 or 1. Verbosity mode. 0 = silent, 1 = progress bar. | 1 |
| log_path | str. Path to save the model and logs. | 'logs' |
| model_name | str. Prefix used when saving the model. | 'DeepConvNet' |
| **kwargs | Keyword arguments used to replace model variables such as kernLength, F1, D, F2, norm_rate, dropout_rate, f1_average, data_format, shuffle, metrics, monitor, mode, save_best_only, save_weight_only, seed, and class_balancing (see the example below). | - |
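
For example, these model variables can be overridden at construction time. The snippet below is only an illustrative sketch: the argument names come from the list above, but the chosen values are assumptions rather than recommended settings.

from min2net.model import DeepConvNet

model = DeepConvNet(input_shape=(1, 20, 400),    # (1, #channel, #time_point)
                    num_class=2,
                    dropout_rate=0.5,             # model variable overridden via **kwargs
                    shuffle=True,
                    seed=42,
                    log_path='logs',
                    model_name='DeepConvNet')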

Build method

Build the model, grouping layers into an object with training and inference features.

DeepConvNet.build()
    Keras implementation of the Deep Convolutional Network as described in
    Schirrmeister et al. (2017), Human Brain Mapping.

    This implementation assumes the input is a 2-second EEG signal sampled at
    128 Hz, as opposed to signals sampled at 250 Hz as described in the original
    paper. We also perform temporal convolutions of length (1, 5) as opposed
    to (1, 10) due to this sampling rate difference.

    Note that we use the max_norm constraint on all convolutional layers, as
    well as the classification layer. We also change the defaults for the
    BatchNormalization layer. We used this based on a personal communication
    with the original authors.

                     ours            original paper
    pool_size        (1, 2)          (1, 3)
    strides          (1, 2)          (1, 3)
    conv filters     (1, 5)          (1, 10)
Returns: Model (tf.keras.Model): Model object
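
Because build() returns a standard tf.keras.Model, the architecture can be inspected with the usual Keras utilities. A minimal sketch, assuming TensorFlow is available in the environment:

from min2net.model import DeepConvNet

model = DeepConvNet(input_shape=(1, 20, 400), num_class=2)
keras_model = model.build()    # tf.keras.Model object
keras_model.summary()          # prints the layer-by-layer architecture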


Fit method

Fit the model according to the given training and validation data. This method was implemented based on tf.keras.Model.fit(). The model weights and logs will be saved at 'log_path'.

DeepConvNet.fit(X_train, 
                y_train, 
                X_val, 
                y_val)

Arguments:

| Arguments | Description |
| --- | --- |
| X_train | ndarray. Training EEG signals, shape (#trial, #depth, #channel, #time_point). |
| y_train | ndarray. Labels of the training set, shape (#trial). |
| X_val | ndarray. Validation EEG signals, shape (#trial, #depth, #channel, #time_point). |
| y_val | ndarray. Labels of the validation set, shape (#trial). |
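
As an end-to-end sketch, randomly generated arrays in the shapes listed above can be used to check that training runs; the trial counts and labels below are placeholders for real EEG data, not part of the min2net API.

import numpy as np
from min2net.model import DeepConvNet

# Placeholder data in shape (#trial, #depth, #channel, #time_point) with depth = 1
X_train = np.random.randn(80, 1, 20, 400)
y_train = np.random.randint(0, 2, size=80)    # labels, shape (#trial)
X_val = np.random.randn(20, 1, 20, 400)
y_val = np.random.randint(0, 2, size=20)

model = DeepConvNet(input_shape=(1, 20, 400), num_class=2)
model.fit(X_train, y_train, X_val, y_val)     # weights and logs are saved at log_path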

Predict method

Generates output predictions and computes the loss value and metric values for the model in test mode on the input samples. This method was implemented based on tf.keras.Model.predict() and tf.keras.Model.evaluate().

DeepConvNet.predict(X_test, 
                    y_test)

Arguments:

| Arguments | Description |
| --- | --- |
| X_test | ndarray. Testing EEG signals, shape (#trial, #depth, #channel, #time_point). |
| y_test | ndarray. Labels of the test set, shape (#trial). |

Returns:

  • Y: dictionary of {y_true, y_pred}
  • evaluation: dictionary of {loss, accuracy, f1_score}
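
A short sketch of how the returned dictionaries can be consumed, assuming the keys listed above:

Y, evaluation = model.predict(X_test, y_test)

y_true = Y['y_true']    # ground-truth labels
y_pred = Y['y_pred']    # predicted labels
print('loss:', evaluation['loss'])
print('accuracy:', evaluation['accuracy'])
print('f1_score:', evaluation['f1_score'])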

Example

from min2net.model import DeepConvNet
import numpy as np

# X_train, y_train, X_val, y_val, X_test, y_test are EEG trials and labels with
# shapes (#trial, 1, #channel, #time_point) and (#trial), as described above.
model = DeepConvNet(input_shape=(1, 20, 400), num_class=2, dropout_rate=0.25, shuffle=True)
model.fit(X_train, y_train, X_val, y_val)

Y, evaluation = model.predict(X_test, y_test)