Examples 1D

!pip install k3im --upgrade

import os
os.environ['KERAS_BACKEND'] = 'tensorflow'
Requirement already satisfied: k3im in /usr/local/lib/python3.10/dist-packages (0.0.7)
Requirement already satisfied: keras>=3.0.0 in /usr/local/lib/python3.10/dist-packages (from k3im) (3.0.1)
Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from keras>=3.0.0->k3im) (1.4.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from keras>=3.0.0->k3im) (1.23.5)
Requirement already satisfied: rich in /usr/local/lib/python3.10/dist-packages (from keras>=3.0.0->k3im) (13.7.0)
Requirement already satisfied: namex in /usr/local/lib/python3.10/dist-packages (from keras>=3.0.0->k3im) (0.0.7)
Requirement already satisfied: h5py in /usr/local/lib/python3.10/dist-packages (from keras>=3.0.0->k3im) (3.9.0)
Requirement already satisfied: dm-tree in /usr/local/lib/python3.10/dist-packages (from keras>=3.0.0->k3im) (0.1.8)
Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras>=3.0.0->k3im) (3.0.0)
Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras>=3.0.0->k3im) (2.16.1)
Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich->keras>=3.0.0->k3im) (0.1.2)

import numpy as np
import keras


def readucr(filename):
    data = np.loadtxt(filename, delimiter="\t")
    y = data[:, 0]
    x = data[:, 1:]
    return x, y.astype(int)


root_url = "https://raw.githubusercontent.com/hfawaz/cd-diagram/master/FordA/"

x_train, y_train = readucr(root_url + "FordA_TRAIN.tsv")
x_test, y_test = readucr(root_url + "FordA_TEST.tsv")

x_train = x_train.reshape((x_train.shape[0], x_train.shape[1], 1))
x_test = x_test.reshape((x_test.shape[0], x_test.shape[1], 1))

n_classes = len(np.unique(y_train))

idx = np.random.permutation(len(x_train))
x_train = x_train[idx]
y_train = y_train[idx]

y_train[y_train == -1] = 0
y_test[y_test == -1] = 0
x_train.shape
(3601, 500, 1)
def train_model(model):
    model.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.Adam(learning_rate=1e-4),
        metrics=["sparse_categorical_accuracy"],
    )


    model.fit(
        x_train,
        y_train,
        validation_split=0.2,
        epochs=2,
        batch_size=64,
    )
    model.evaluate(x_test, y_test, verbose=1)
from k3im.cait_1d import CAiT_1DModel
model = CAiT_1DModel(
    seq_len=500,
    patch_size=20,
    num_classes=n_classes,
    dim=64,
    dim_head=32,
    mlp_dim=64,
    depth=2,
    cls_depth=2,
    heads=4,
    channels=1,
    dropout_rate=0.0,
)
model.summary()
Model: "functional_5"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)               Output Shape               Param #  Connected to               ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ input_layer (InputLayer)  │ (None, 500, 1)         │          0 │ -                          │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ reshape (Reshape)         │ (None, 25, 20)         │          0 │ input_layer[0][0]          │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization       │ (None, 25, 20)         │         40 │ reshape[0][0]              │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense (Dense)             │ (None, 25, 64)         │      1,344 │ layer_normalization[0][0]  │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_1     │ (None, 25, 64)         │        128 │ dense[0][0]                │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add (Add)                 │ (None, 25, 64)         │          0 │ layer_normalization_1[0][ │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ multi_head_attention      │ (None, 25, 64)         │     33,216 │ add[0][0], add[0][0]       │
│ (MultiHeadAttention)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_1 (Add)               │ (None, 25, 64)         │          0 │ add[0][0],                 │
│                           │                        │            │ multi_head_attention[0][0] │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ sequential (Sequential)   │ (None, 25, 64)         │      8,448 │ add_1[0][0]                │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_2 (Add)               │ (None, 25, 64)         │          0 │ add_1[0][0],               │
│                           │                        │            │ sequential[0][0]           │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ multi_head_attention_1    │ (None, 25, 64)         │     33,216 │ add_2[0][0], add_2[0][0]   │
│ (MultiHeadAttention)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_3 (Add)               │ (None, 25, 64)         │          0 │ add_2[0][0],               │
│                           │                        │            │ multi_head_attention_1[0]… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ sequential_1 (Sequential) │ (None, 25, 64)         │      8,448 │ add_3[0][0]                │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_4 (Add)               │ (None, 25, 64)         │          0 │ add_3[0][0],               │
│                           │                        │            │ sequential_1[0][0]         │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_4     │ (None, 25, 64)         │        128 │ add_4[0][0]                │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ cls__token (CLS_Token)    │ [(None, 26, 64),       │         64 │ layer_normalization_4[0][ │
│                           │ (None, 1, 64)]         │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ concatenate (Concatenate) │ (None, 26, 64)         │          0 │ cls__token[0][1],          │
│                           │                        │            │ layer_normalization_4[0][ │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ multi_head_attention_2    │ (None, 1, 64)          │     33,216 │ cls__token[0][1],          │
│ (MultiHeadAttention)      │                        │            │ concatenate[0][0]          │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_5 (Add)               │ (None, 1, 64)          │          0 │ cls__token[0][1],          │
│                           │                        │            │ multi_head_attention_2[0]… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ sequential_2 (Sequential) │ (None, 1, 64)          │      8,448 │ add_5[0][0]                │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_6 (Add)               │ (None, 1, 64)          │          0 │ add_5[0][0],               │
│                           │                        │            │ sequential_2[0][0]         │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ concatenate_1             │ (None, 26, 64)         │          0 │ add_6[0][0],               │
│ (Concatenate)             │                        │            │ layer_normalization_4[0][ │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ multi_head_attention_3    │ (None, 1, 64)          │     33,216 │ add_6[0][0],               │
│ (MultiHeadAttention)      │                        │            │ concatenate_1[0][0]        │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_7 (Add)               │ (None, 1, 64)          │          0 │ add_6[0][0],               │
│                           │                        │            │ multi_head_attention_3[0]… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ sequential_3 (Sequential) │ (None, 1, 64)          │      8,448 │ add_7[0][0]                │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_8 (Add)               │ (None, 1, 64)          │          0 │ add_7[0][0],               │
│                           │                        │            │ sequential_3[0][0]         │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_7     │ (None, 1, 64)          │        128 │ add_8[0][0]                │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ squeeze (Squeeze)         │ (None, 64)             │          0 │ layer_normalization_7[0][ │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_9 (Dense)           │ (None, 2)              │        130 │ squeeze[0][0]              │
└───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
 Total params: 168,618 (658.66 KB)
 Trainable params: 168,618 (658.66 KB)
 Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
45/45 ━━━━━━━━━━━━━━━━━━━━ 30s 101ms/step - loss: 0.7743 - sparse_categorical_accuracy: 0.5116 - val_loss: 0.6783 - val_sparse_categorical_accuracy: 0.4840
Epoch 2/2
45/45 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - loss: 0.6495 - sparse_categorical_accuracy: 0.5963 - val_loss: 0.5901 - val_sparse_categorical_accuracy: 0.6768
42/42 ━━━━━━━━━━━━━━━━━━━━ 4s 53ms/step - loss: 0.5515 - sparse_categorical_accuracy: 0.7196

from k3im.cct_1d import CCT_1DModel
model = CCT_1DModel(
    input_shape=(500, 1),
    num_heads=4,
    projection_dim=154,
    kernel_size=10,
    stride=15,
    padding=5,
    transformer_units=[154],
    stochastic_depth_rate=0.5,
    transformer_layers=1,
    num_classes=n_classes,
    positional_emb=False,
)
model.summary()
Model: "functional_8"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)               Output Shape               Param #  Connected to               ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ input_layer_5             │ (None, 500, 1)         │          0 │ -                          │
│ (InputLayer)              │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ cct_tokenizer1d           │ (None, 6, 154)         │     99,200 │ input_layer_5[0][0]        │
│ (CCTTokenizer1D)          │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_8     │ (None, 6, 154)         │        308 │ cct_tokenizer1d[0][0]      │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ multi_head_attention_4    │ (None, 6, 154)         │    381,458 │ layer_normalization_8[0][ │
│ (MultiHeadAttention)      │                        │            │ layer_normalization_8[0][ │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ stochastic_depth          │ (None, 6, 154)         │          0 │ multi_head_attention_4[0]… │
│ (StochasticDepth)         │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_9 (Add)               │ (None, 6, 154)         │          0 │ stochastic_depth[0][0],    │
│                           │                        │            │ cct_tokenizer1d[0][0]      │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_9     │ (None, 6, 154)         │        308 │ add_9[0][0]                │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_10 (Dense)          │ (None, 6, 154)         │     23,870 │ layer_normalization_9[0][ │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dropout_5 (Dropout)       │ (None, 6, 154)         │          0 │ dense_10[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ stochastic_depth_1        │ (None, 6, 154)         │          0 │ dropout_5[0][0]            │
│ (StochasticDepth)         │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_10 (Add)              │ (None, 6, 154)         │          0 │ stochastic_depth_1[0][0],  │
│                           │                        │            │ add_9[0][0]                │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_10    │ (None, 6, 154)         │        308 │ add_10[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ sequence_pooling          │ (None, 154)            │        155 │ layer_normalization_10[0]… │
│ (SequencePooling)         │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_12 (Dense)          │ (None, 2)              │        310 │ sequence_pooling[0][0]     │
└───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
 Total params: 505,917 (1.93 MB)
 Trainable params: 505,917 (1.93 MB)
 Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
45/45 ━━━━━━━━━━━━━━━━━━━━ 16s 56ms/step - loss: 0.7375 - sparse_categorical_accuracy: 0.5405 - val_loss: 0.6278 - val_sparse_categorical_accuracy: 0.6422
Epoch 2/2
45/45 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.5604 - sparse_categorical_accuracy: 0.7140 - val_loss: 0.6223 - val_sparse_categorical_accuracy: 0.6727
42/42 ━━━━━━━━━━━━━━━━━━━━ 2s 24ms/step - loss: 0.5649 - sparse_categorical_accuracy: 0.6980

from k3im.convmixer_1d import ConvMixer1DModel
from k3im.convmixer_1d import ConvMixer1DModel
model = ConvMixer1DModel(seq_len=500,
    n_features=1,
    filters=128,
    depth=4,
    kernel_size=15,
    patch_size=4,
    num_classes=n_classes,)
model.summary()
Model: "functional_10"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)               Output Shape               Param #  Connected to               ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ input_layer_7             │ (None, 500, 1)         │          0 │ -                          │
│ (InputLayer)              │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ conv1d_2 (Conv1D)         │ (None, 125, 128)       │        640 │ input_layer_7[0][0]        │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ activation (Activation)   │ (None, 125, 128)       │          0 │ conv1d_2[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ batch_normalization       │ (None, 125, 128)       │        512 │ activation[0][0]           │
│ (BatchNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ depthwise_conv1d          │ (None, 125, 128)       │      2,048 │ batch_normalization[0][0]  │
│ (DepthwiseConv1D)         │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ activation_1 (Activation) │ (None, 125, 128)       │          0 │ depthwise_conv1d[0][0]     │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ batch_normalization_1     │ (None, 125, 128)       │        512 │ activation_1[0][0]         │
│ (BatchNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_11 (Add)              │ (None, 125, 128)       │          0 │ batch_normalization_1[0][ │
│                           │                        │            │ batch_normalization[0][0]  │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ conv1d_3 (Conv1D)         │ (None, 125, 128)       │     16,512 │ add_11[0][0]               │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ activation_2 (Activation) │ (None, 125, 128)       │          0 │ conv1d_3[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ batch_normalization_2     │ (None, 125, 128)       │        512 │ activation_2[0][0]         │
│ (BatchNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ depthwise_conv1d_1        │ (None, 125, 128)       │      2,048 │ batch_normalization_2[0][ │
│ (DepthwiseConv1D)         │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ activation_3 (Activation) │ (None, 125, 128)       │          0 │ depthwise_conv1d_1[0][0]   │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ batch_normalization_3     │ (None, 125, 128)       │        512 │ activation_3[0][0]         │
│ (BatchNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_12 (Add)              │ (None, 125, 128)       │          0 │ batch_normalization_3[0][ │
│                           │                        │            │ batch_normalization_2[0][ │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ conv1d_4 (Conv1D)         │ (None, 125, 128)       │     16,512 │ add_12[0][0]               │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ activation_4 (Activation) │ (None, 125, 128)       │          0 │ conv1d_4[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ batch_normalization_4     │ (None, 125, 128)       │        512 │ activation_4[0][0]         │
│ (BatchNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ depthwise_conv1d_2        │ (None, 125, 128)       │      2,048 │ batch_normalization_4[0][ │
│ (DepthwiseConv1D)         │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ activation_5 (Activation) │ (None, 125, 128)       │          0 │ depthwise_conv1d_2[0][0]   │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ batch_normalization_5     │ (None, 125, 128)       │        512 │ activation_5[0][0]         │
│ (BatchNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_13 (Add)              │ (None, 125, 128)       │          0 │ batch_normalization_5[0][ │
│                           │                        │            │ batch_normalization_4[0][ │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ conv1d_5 (Conv1D)         │ (None, 125, 128)       │     16,512 │ add_13[0][0]               │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ activation_6 (Activation) │ (None, 125, 128)       │          0 │ conv1d_5[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ batch_normalization_6     │ (None, 125, 128)       │        512 │ activation_6[0][0]         │
│ (BatchNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ depthwise_conv1d_3        │ (None, 125, 128)       │      2,048 │ batch_normalization_6[0][ │
│ (DepthwiseConv1D)         │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ activation_7 (Activation) │ (None, 125, 128)       │          0 │ depthwise_conv1d_3[0][0]   │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ batch_normalization_7     │ (None, 125, 128)       │        512 │ activation_7[0][0]         │
│ (BatchNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_14 (Add)              │ (None, 125, 128)       │          0 │ batch_normalization_7[0][ │
│                           │                        │            │ batch_normalization_6[0][ │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ conv1d_6 (Conv1D)         │ (None, 125, 128)       │     16,512 │ add_14[0][0]               │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ activation_8 (Activation) │ (None, 125, 128)       │          0 │ conv1d_6[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ batch_normalization_8     │ (None, 125, 128)       │        512 │ activation_8[0][0]         │
│ (BatchNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ global_average_pooling1d  │ (None, 128)            │          0 │ batch_normalization_8[0][ │
│ (GlobalAveragePooling1D)  │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_13 (Dense)          │ (None, 2)              │        258 │ global_average_pooling1d[ │
└───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
 Total params: 79,746 (311.51 KB)
 Trainable params: 77,442 (302.51 KB)
 Non-trainable params: 2,304 (9.00 KB)
train_model(model)
Epoch 1/2
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 62ms/step - loss: 0.5393 - sparse_categorical_accuracy: 0.7676 - val_loss: 0.7343 - val_sparse_categorical_accuracy: 0.4785
Epoch 2/2
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.2279 - sparse_categorical_accuracy: 0.9343 - val_loss: 0.8505 - val_sparse_categorical_accuracy: 0.4785
42/42 ━━━━━━━━━━━━━━━━━━━━ 2s 21ms/step - loss: 0.7988 - sparse_categorical_accuracy: 0.5299

from k3im.eanet_1d import EANet1DModel
model = EANet1DModel(
    seq_len=500,
    patch_size=20,
    num_classes=n_classes,
    dim=96,
    depth=3,
    heads=32,
    mlp_dim=64,
    dim_coefficient=2,
    attention_dropout=0.0,
    channels=1,
)
model.summary()
Model: "functional_15"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)               Output Shape               Param #  Connected to               ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ input_layer_8             │ (None, 500, 1)         │          0 │ -                          │
│ (InputLayer)              │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ reshape_1 (Reshape)       │ (None, 25, 20)         │          0 │ input_layer_8[0][0]        │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_11    │ (None, 25, 20)         │         40 │ reshape_1[0][0]            │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_14 (Dense)          │ (None, 25, 96)         │      2,016 │ layer_normalization_11[0]… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_12    │ (None, 25, 96)         │        192 │ dense_14[0][0]             │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_15 (Add)              │ (None, 25, 96)         │          0 │ layer_normalization_12[0]… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_15 (Dense)          │ (None, 25, 192)        │     18,624 │ add_15[0][0]               │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ reshape_2 (Reshape)       │ (None, 25, 64, 3)      │          0 │ dense_15[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ transpose (Transpose)     │ (None, 64, 25, 3)      │          0 │ reshape_2[0][0]            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_16 (Dense)          │ (None, 64, 25, 48)     │        192 │ transpose[0][0]            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ softmax_5 (Softmax)       │ (None, 64, 25, 48)     │          0 │ dense_16[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ lambda (Lambda)           │ (None, 64, 25, 48)     │          0 │ softmax_5[0][0]            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dropout_6 (Dropout)       │ (None, 64, 25, 48)     │          0 │ lambda[0][0]               │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_17 (Dense)          │ (None, 64, 25, 3)      │        147 │ dropout_6[0][0]            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ transpose_1 (Transpose)   │ (None, 25, 64, 3)      │          0 │ dense_17[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ reshape_3 (Reshape)       │ (None, 25, 192)        │          0 │ transpose_1[0][0]          │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_18 (Dense)          │ (None, 25, 96)         │     18,528 │ reshape_3[0][0]            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dropout_7 (Dropout)       │ (None, 25, 96)         │          0 │ dense_18[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_16 (Add)              │ (None, 25, 96)         │          0 │ add_15[0][0],              │
│                           │                        │            │ dropout_7[0][0]            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ sequential_5 (Sequential) │ (None, 25, 96)         │     12,640 │ add_16[0][0]               │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_17 (Add)              │ (None, 25, 96)         │          0 │ add_16[0][0],              │
│                           │                        │            │ sequential_5[0][0]         │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_21 (Dense)          │ (None, 25, 192)        │     18,624 │ add_17[0][0]               │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ reshape_4 (Reshape)       │ (None, 25, 64, 3)      │          0 │ dense_21[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ transpose_2 (Transpose)   │ (None, 64, 25, 3)      │          0 │ reshape_4[0][0]            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_22 (Dense)          │ (None, 64, 25, 48)     │        192 │ transpose_2[0][0]          │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ softmax_6 (Softmax)       │ (None, 64, 25, 48)     │          0 │ dense_22[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ lambda_1 (Lambda)         │ (None, 64, 25, 48)     │          0 │ softmax_6[0][0]            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dropout_8 (Dropout)       │ (None, 64, 25, 48)     │          0 │ lambda_1[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_23 (Dense)          │ (None, 64, 25, 3)      │        147 │ dropout_8[0][0]            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ transpose_3 (Transpose)   │ (None, 25, 64, 3)      │          0 │ dense_23[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ reshape_5 (Reshape)       │ (None, 25, 192)        │          0 │ transpose_3[0][0]          │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_24 (Dense)          │ (None, 25, 96)         │     18,528 │ reshape_5[0][0]            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dropout_9 (Dropout)       │ (None, 25, 96)         │          0 │ dense_24[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_18 (Add)              │ (None, 25, 96)         │          0 │ add_17[0][0],              │
│                           │                        │            │ dropout_9[0][0]            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ sequential_6 (Sequential) │ (None, 25, 96)         │     12,640 │ add_18[0][0]               │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_19 (Add)              │ (None, 25, 96)         │          0 │ add_18[0][0],              │
│                           │                        │            │ sequential_6[0][0]         │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_27 (Dense)          │ (None, 25, 192)        │     18,624 │ add_19[0][0]               │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ reshape_6 (Reshape)       │ (None, 25, 64, 3)      │          0 │ dense_27[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ transpose_4 (Transpose)   │ (None, 64, 25, 3)      │          0 │ reshape_6[0][0]            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_28 (Dense)          │ (None, 64, 25, 48)     │        192 │ transpose_4[0][0]          │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ softmax_7 (Softmax)       │ (None, 64, 25, 48)     │          0 │ dense_28[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ lambda_2 (Lambda)         │ (None, 64, 25, 48)     │          0 │ softmax_7[0][0]            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dropout_10 (Dropout)      │ (None, 64, 25, 48)     │          0 │ lambda_2[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_29 (Dense)          │ (None, 64, 25, 3)      │        147 │ dropout_10[0][0]           │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ transpose_5 (Transpose)   │ (None, 25, 64, 3)      │          0 │ dense_29[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ reshape_7 (Reshape)       │ (None, 25, 192)        │          0 │ transpose_5[0][0]          │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_30 (Dense)          │ (None, 25, 96)         │     18,528 │ reshape_7[0][0]            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dropout_11 (Dropout)      │ (None, 25, 96)         │          0 │ dense_30[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_20 (Add)              │ (None, 25, 96)         │          0 │ add_19[0][0],              │
│                           │                        │            │ dropout_11[0][0]           │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ sequential_7 (Sequential) │ (None, 25, 96)         │     12,640 │ add_20[0][0]               │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_21 (Add)              │ (None, 25, 96)         │          0 │ add_20[0][0],              │
│                           │                        │            │ sequential_7[0][0]         │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_16    │ (None, 25, 96)         │        192 │ add_21[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ avg_pool                  │ (None, 96)             │          0 │ layer_normalization_16[0]… │
│ (GlobalAveragePooling1D)  │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_33 (Dense)          │ (None, 2)              │        194 │ avg_pool[0][0]             │
└───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
 Total params: 153,027 (597.76 KB)
 Trainable params: 153,027 (597.76 KB)
 Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
45/45 ━━━━━━━━━━━━━━━━━━━━ 19s 92ms/step - loss: 0.7292 - sparse_categorical_accuracy: 0.4982 - val_loss: 0.6846 - val_sparse_categorical_accuracy: 0.5465
Epoch 2/2
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 19ms/step - loss: 0.6754 - sparse_categorical_accuracy: 0.5374 - val_loss: 0.6517 - val_sparse_categorical_accuracy: 0.6533
42/42 ━━━━━━━━━━━━━━━━━━━━ 2s 27ms/step - loss: 0.6384 - sparse_categorical_accuracy: 0.6737

from k3im.gmlp_1d import gMLP1DModel
model = gMLP1DModel(seq_len=500, patch_size=20, num_classes=n_classes, dim=64, depth=4, channels=1, dropout_rate=0.0)
model.summary()
Model: "functional_21"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓
┃ Layer (type)                        Output Shape                       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩
│ input_layer_12 (InputLayer)        │ (None, 500, 1)                │           0 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ reshape_8 (Reshape)                │ (None, 25, 20)                │           0 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ layer_normalization_17             │ (None, 25, 20)                │          40 │
│ (LayerNormalization)               │                               │             │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ dense_34 (Dense)                   │ (None, 25, 64)                │       1,344 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ layer_normalization_18             │ (None, 25, 64)                │         128 │
│ (LayerNormalization)               │                               │             │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ add_22 (Add)                       │ (None, 25, 64)                │           0 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ g_mlp_layer (gMLPLayer)            │ (None, 25, 64)                │      13,386 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ g_mlp_layer_1 (gMLPLayer)          │ (None, 25, 64)                │      13,386 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ g_mlp_layer_2 (gMLPLayer)          │ (None, 25, 64)                │      13,386 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ g_mlp_layer_3 (gMLPLayer)          │ (None, 25, 64)                │      13,386 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ avg_pool (GlobalAveragePooling1D)  │ (None, 64)                    │           0 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ dense_47 (Dense)                   │ (None, 2)                     │         130 │
└────────────────────────────────────┴───────────────────────────────┴─────────────┘
 Total params: 55,186 (215.57 KB)
 Trainable params: 55,186 (215.57 KB)
 Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 78ms/step - loss: 0.7447 - sparse_categorical_accuracy: 0.5254 - val_loss: 0.6844 - val_sparse_categorical_accuracy: 0.5520
Epoch 2/2
45/45 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - loss: 0.6457 - sparse_categorical_accuracy: 0.6318 - val_loss: 0.6526 - val_sparse_categorical_accuracy: 0.6144
42/42 ━━━━━━━━━━━━━━━━━━━━ 3s 50ms/step - loss: 0.6393 - sparse_categorical_accuracy: 0.6402

from k3im.mlp_mixer_1d import Mixer1DModel
model = Mixer1DModel(seq_len=500, patch_size=20, num_classes=n_classes, dim=64, depth=4, channels=1, dropout_rate=0.0)
train_model(model)
Epoch 1/2
45/45 ━━━━━━━━━━━━━━━━━━━━ 18s 67ms/step - loss: 0.8218 - sparse_categorical_accuracy: 0.5014 - val_loss: 0.7365 - val_sparse_categorical_accuracy: 0.4979
Epoch 2/2
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 10ms/step - loss: 0.7060 - sparse_categorical_accuracy: 0.5569 - val_loss: 0.6784 - val_sparse_categorical_accuracy: 0.5603
42/42 ━━━━━━━━━━━━━━━━━━━━ 2s 25ms/step - loss: 0.6731 - sparse_categorical_accuracy: 0.5918

from k3im.simple_vit_1d import SimpleViT1DModel
model = SimpleViT1DModel(seq_len=500,
    patch_size=20,
    num_classes=n_classes,
    dim=32,
    depth=3,
    heads=8,
    mlp_dim=64,
    channels=1,
    dim_head=64)
train_model(model)
Epoch 1/2
45/45 ━━━━━━━━━━━━━━━━━━━━ 19s 139ms/step - loss: 0.7022 - sparse_categorical_accuracy: 0.5038 - val_loss: 0.6673 - val_sparse_categorical_accuracy: 0.5853
Epoch 2/2
45/45 ━━━━━━━━━━━━━━━━━━━━ 10s 16ms/step - loss: 0.6571 - sparse_categorical_accuracy: 0.5832 - val_loss: 0.6531 - val_sparse_categorical_accuracy: 0.5243
42/42 ━━━━━━━━━━━━━━━━━━━━ 2s 30ms/step - loss: 0.6289 - sparse_categorical_accuracy: 0.5326

from k3im.vit_1d import ViT1DModel
model = ViT1DModel(seq_len=500,
    patch_size=20,
    num_classes=n_classes,
    dim=32,
    depth=3,
    heads=8,
    mlp_dim=64,
    channels=1,
    dim_head=64)
train_model(model)
Epoch 1/2
45/45 ━━━━━━━━━━━━━━━━━━━━ 16s 72ms/step - loss: 0.8892 - sparse_categorical_accuracy: 0.5182 - val_loss: 0.6662 - val_sparse_categorical_accuracy: 0.5742
Epoch 2/2
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 13ms/step - loss: 0.6322 - sparse_categorical_accuracy: 0.6353 - val_loss: 0.6459 - val_sparse_categorical_accuracy: 0.5825
42/42 ━━━━━━━━━━━━━━━━━━━━ 3s 35ms/step - loss: 0.6078 - sparse_categorical_accuracy: 0.6317