Examples 2D

!pip install k3im --upgrade
import os
os.environ['KERAS_BACKEND'] = 'jax'
import keras
import numpy as np
# Dataset configuration: number of target classes and per-image input shape.
num_classes = 10
input_shape = (28, 28, 1)

# Fetch the MNIST train/test splits through the Keras dataset helper.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Cast pixels to float32 and rescale them to the [0, 1] range, then append a
# trailing channel axis so every image has shape (28, 28, 1).
x_train = np.expand_dims(x_train.astype("float32") / 255, axis=-1)
x_test = np.expand_dims(x_test.astype("float32") / 255, axis=-1)

print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

# One-hot encode the integer class labels (binary class matrices with
# `num_classes` columns).
y_train, y_test = (
    keras.utils.to_categorical(labels, num_classes)
    for labels in (y_train, y_test)
)
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples

# Training hyper-parameters shared by every example below.
batch_size = 128
epochs = 2


def train_model(model):
    """Compile, fit and evaluate ``model``, then print its test metrics.

    The model is compiled with the Adam optimizer, an accuracy metric and a
    categorical cross-entropy loss created with ``from_logits=True``. It is
    fitted on ``x_train`` / ``y_train`` using the module-level ``batch_size``
    and ``epochs`` with ``validation_split=0.1``, and finally evaluated on
    ``x_test`` / ``y_test``.

    Args:
        model: Object exposing ``compile``, ``fit`` and ``evaluate``
            (used here with Keras models).

    Returns:
        None. The first two entries of the evaluation result are printed as
        "Test loss:" and "Test accuracy:".
    """
    model.compile(
        loss=keras.losses.CategoricalCrossentropy(from_logits=True),
        optimizer="adam",
        metrics=["accuracy"],
    )
    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        validation_split=0.1,
    )
    # verbose=0 keeps the evaluation pass silent; only the summary is printed.
    results = model.evaluate(x_test, y_test, verbose=0)
    for label, position in (("Test loss:", 0), ("Test accuracy:", 1)):
        print(label, results[position])
from k3im.cait import CaiTModel # jax ✅, tensorflow ✅, torch ✅
model = CaiTModel(
    image_size=(28, 28),
    patch_size=(7, 7),
Model: "functional_5"
┃ Layer (type)               Output Shape               Param #  Connected to               ┃
│ input_layer (InputLayer)  │ (None, 28, 28, 1)      │          0 │ -                          │
│ extract_patches           │ (None, 4, 4, 49)       │          0 │ input_layer[0][0]          │
│ (ExtractPatches)          │                        │            │                            │
│ reshape (Reshape)         │ (None, 16, 49)         │          0 │ extract_patches[0][0]      │
│ layer_normalization       │ (None, 16, 49)         │         98 │ reshape[0][0]              │
│ (LayerNormalization)      │                        │            │                            │
│ dense (Dense)             │ (None, 16, 32)         │      1,600 │ layer_normalization[0][0]  │
│ layer_normalization_1     │ (None, 16, 32)         │         64 │ dense[0][0]                │
│ (LayerNormalization)      │                        │            │                            │
│ position_emb              │ (None, 16, 32)         │        512 │ layer_normalization_1[0][ │
│ (PositionEmb)             │                        │            │                            │
│ multi_head_attention      │ (None, 16, 32)         │     67,104 │ position_emb[0][0],        │
│ (MultiHeadAttention)      │                        │            │ position_emb[0][0]         │
│ add (Add)                 │ (None, 16, 32)         │          0 │ position_emb[0][0],        │
│                           │                        │            │ multi_head_attention[0][0] │
│ sequential (Sequential)   │ (None, 16, 32)         │      4,256 │ add[0][0]                  │
│ add_1 (Add)               │ (None, 16, 32)         │          0 │ add[0][0],                 │
│                           │                        │            │ sequential[0][0]           │
│ multi_head_attention_1    │ (None, 16, 32)         │     67,104 │ add_1[0][0], add_1[0][0]   │
│ (MultiHeadAttention)      │                        │            │                            │
│ add_2 (Add)               │ (None, 16, 32)         │          0 │ add_1[0][0],               │
│                           │                        │            │ multi_head_attention_1[0]… │
│ sequential_1 (Sequential) │ (None, 16, 32)         │      4,256 │ add_2[0][0]                │
│ add_3 (Add)               │ (None, 16, 32)         │          0 │ add_2[0][0],               │
│                           │                        │            │ sequential_1[0][0]         │
│ layer_normalization_4     │ (None, 16, 32)         │         64 │ add_3[0][0]                │
│ (LayerNormalization)      │                        │            │                            │
│ cls__token (CLS_Token)    │ [(None, 17, 32),       │         32 │ layer_normalization_4[0][ │
│                           │ (None, 1, 32)]         │            │                            │
│ concatenate (Concatenate) │ (None, 17, 32)         │          0 │ cls__token[0][1],          │
│                           │                        │            │ layer_normalization_4[0][ │
│ multi_head_attention_2    │ (None, 1, 32)          │     67,104 │ cls__token[0][1],          │
│ (MultiHeadAttention)      │                        │            │ concatenate[0][0]          │
│ add_4 (Add)               │ (None, 1, 32)          │          0 │ cls__token[0][1],          │
│                           │                        │            │ multi_head_attention_2[0]… │
│ sequential_2 (Sequential) │ (None, 1, 32)          │      4,256 │ add_4[0][0]                │
│ add_5 (Add)               │ (None, 1, 32)          │          0 │ add_4[0][0],               │
│                           │                        │            │ sequential_2[0][0]         │
│ concatenate_1             │ (None, 17, 32)         │          0 │ add_5[0][0],               │
│ (Concatenate)             │                        │            │ layer_normalization_4[0][ │
│ multi_head_attention_3    │ (None, 1, 32)          │     67,104 │ add_5[0][0],               │
│ (MultiHeadAttention)      │                        │            │ concatenate_1[0][0]        │
│ add_6 (Add)               │ (None, 1, 32)          │          0 │ add_5[0][0],               │
│                           │                        │            │ multi_head_attention_3[0]… │
│ sequential_3 (Sequential) │ (None, 1, 32)          │      4,256 │ add_6[0][0]                │
│ add_7 (Add)               │ (None, 1, 32)          │          0 │ add_6[0][0],               │
│                           │                        │            │ sequential_3[0][0]         │
│ layer_normalization_7     │ (None, 1, 32)          │         64 │ add_7[0][0]                │
│ (LayerNormalization)      │                        │            │                            │
│ squeeze (Squeeze)         │ (None, 32)             │          0 │ layer_normalization_7[0][ │
│ dense_9 (Dense)           │ (None, 10)             │        330 │ squeeze[0][0]              │
 Total params: 288,204 (1.10 MB)
 Trainable params: 288,204 (1.10 MB)
 Non-trainable params: 0 (0.00 B)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 22s 52ms/step - accuracy: 0.5326 - loss: 1.3825 - val_accuracy: 0.9125 - val_loss: 0.2950
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 19s 45ms/step - accuracy: 0.9129 - loss: 0.2858 - val_accuracy: 0.9593 - val_loss: 0.1494
Test loss: 0.16203546524047852
Test accuracy: 0.9521999955177307

from k3im.cct import CCT  # jax ✅, tensorflow ✅, torch ✅

model = CCT(
    transformer_units=[16, 32],
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 14s 33ms/step - accuracy: 0.6643 - loss: 1.0071 - val_accuracy: 0.9318 - val_loss: 0.2200
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 13s 32ms/step - accuracy: 0.9292 - loss: 0.2308 - val_accuracy: 0.9532 - val_loss: 0.1575
Test loss: 0.1650298684835434
Test accuracy: 0.947700023651123

from k3im.convmixer import ConvMixer  # jax ✅; tensorflow: something is not right; torch: something is not right

model = ConvMixer(
    image_size=28, filters=64, depth=8, kernel_size=3, patch_size=2, num_classes=10, num_channels=1
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 25s 59ms/step - accuracy: 0.6054 - loss: 1.3247 - val_accuracy: 0.1113 - val_loss: 9747.7539
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 30s 70ms/step - accuracy: 0.8339 - loss: 0.5841 - val_accuracy: 0.1113 - val_loss: 1693.4771
Test loss: 1714.0731201171875
Test accuracy: 0.10279999673366547

from k3im.cross_vit import CrossViT # jax ✅, tensorflow ✅, torch ✅
model = CrossViT(
Model: "functional_21"
┃ Layer (type)               Output Shape               Param #  Connected to               ┃
│ input_layer_8             │ (None, 28, 28, 1)      │          0 │ -                          │
│ (InputLayer)              │                        │            │                            │
│ extract_patches_2         │ (None, 4, 4, 49)       │          0 │ input_layer_8[0][0]        │
│ (ExtractPatches)          │                        │            │                            │
│ reshape_2 (Reshape)       │ (None, 16, 49)         │          0 │ extract_patches_2[0][0]    │
│ layer_normalization_15    │ (None, 16, 49)         │         98 │ reshape_2[0][0]            │
│ (LayerNormalization)      │                        │            │                            │
│ dense_18 (Dense)          │ (None, 16, 42)         │      2,100 │ layer_normalization_15[0]… │
│ layer_normalization_16    │ (None, 16, 42)         │         84 │ dense_18[0][0]             │
│ (LayerNormalization)      │                        │            │                            │
│ extract_patches_1         │ (None, 7, 7, 16)       │          0 │ input_layer_8[0][0]        │
│ (ExtractPatches)          │                        │            │                            │
│ cls__token_2 (CLS_Token)  │ [(None, 17, 42),       │         42 │ layer_normalization_16[0]… │
│                           │ (None, 1, 42)]         │            │                            │
│ reshape_1 (Reshape)       │ (None, 49, 16)         │          0 │ extract_patches_1[0][0]    │
│ position_emb_2            │ (None, 17, 42)         │        714 │ cls__token_2[0][0]         │
│ (PositionEmb)             │                        │            │                            │
│ layer_normalization_13    │ (None, 49, 16)         │         32 │ reshape_1[0][0]            │
│ (LayerNormalization)      │                        │            │                            │
│ dropout_11 (Dropout)      │ (None, 17, 42)         │          0 │ position_emb_2[0][0]       │
│ dense_17 (Dense)          │ (None, 49, 32)         │        544 │ layer_normalization_13[0]… │
│ multi_head_attention_7    │ (None, 17, 42)         │     98,538 │ dropout_11[0][0],          │
│ (MultiHeadAttention)      │                        │            │ dropout_11[0][0]           │
│ layer_normalization_14    │ (None, 49, 32)         │         64 │ dense_17[0][0]             │
│ (LayerNormalization)      │                        │            │                            │
│ add_22 (Add)              │ (None, 17, 42)         │          0 │ dropout_11[0][0],          │
│                           │                        │            │ multi_head_attention_7[0]… │
│ cls__token_1 (CLS_Token)  │ [(None, 50, 32),       │         32 │ layer_normalization_14[0]… │
│                           │ (None, 1, 32)]         │            │                            │
│ sequential_6 (Sequential) │ (None, 17, 42)         │      7,266 │ add_22[0][0]               │
│ position_emb_1            │ (None, 50, 32)         │      1,600 │ cls__token_1[0][0]         │
│ (PositionEmb)             │                        │            │                            │
│ add_23 (Add)              │ (None, 17, 42)         │          0 │ add_22[0][0],              │
│                           │                        │            │ sequential_6[0][0]         │
│ dropout_10 (Dropout)      │ (None, 50, 32)         │          0 │ position_emb_1[0][0]       │
│ multi_head_attention_8    │ (None, 17, 42)         │     98,538 │ add_23[0][0], add_23[0][0] │
│ (MultiHeadAttention)      │                        │            │                            │
│ multi_head_attention_6    │ (None, 50, 32)         │     58,720 │ dropout_10[0][0],          │
│ (MultiHeadAttention)      │                        │            │ dropout_10[0][0]           │
│ add_24 (Add)              │ (None, 17, 42)         │          0 │ add_23[0][0],              │
│                           │                        │            │ multi_head_attention_8[0]… │
│ add_20 (Add)              │ (None, 50, 32)         │          0 │ dropout_10[0][0],          │
│                           │                        │            │ multi_head_attention_6[0]… │
│ sequential_7 (Sequential) │ (None, 17, 42)         │      7,266 │ add_24[0][0]               │
│ sequential_5 (Sequential) │ (None, 50, 32)         │      3,216 │ add_20[0][0]               │
│ add_25 (Add)              │ (None, 17, 42)         │          0 │ add_24[0][0],              │
│                           │                        │            │ sequential_7[0][0]         │
│ add_21 (Add)              │ (None, 50, 32)         │          0 │ add_20[0][0],              │
│                           │                        │            │ sequential_5[0][0]         │
│ layer_normalization_21    │ (None, 17, 42)         │         84 │ add_25[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ layer_normalization_18    │ (None, 50, 32)         │         64 │ add_21[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ dropout_16 (Dropout)      │ (None, 17, 42)         │          0 │ layer_normalization_21[0]… │
│ dropout_13 (Dropout)      │ (None, 50, 32)         │          0 │ layer_normalization_18[0]… │
│ get_item_2 (GetItem)      │ (None, 1, 42)          │          0 │ dropout_16[0][0]           │
│ get_item_1 (GetItem)      │ (None, 49, 32)         │          0 │ dropout_13[0][0]           │
│ dense_27 (Dense)          │ (None, 1, 32)          │      1,376 │ get_item_2[0][0]           │
│ concatenate_4             │ (None, 50, 32)         │          0 │ dense_27[0][0],            │
│ (Concatenate)             │                        │            │ get_item_1[0][0]           │
│ multi_head_attention_11   │ (None, 1, 32)          │     67,104 │ dense_27[0][0],            │
│ (MultiHeadAttention)      │                        │            │ concatenate_4[0][0]        │
│ add_29 (Add)              │ (None, 1, 32)          │          0 │ dense_27[0][0],            │
│                           │                        │            │ multi_head_attention_11[0… │
│ get_item (GetItem)        │ (None, 1, 32)          │          0 │ dropout_13[0][0]           │
│ concatenate_5             │ (None, 50, 32)         │          0 │ add_29[0][0],              │
│ (Concatenate)             │                        │            │ get_item_1[0][0]           │
│ dense_25 (Dense)          │ (None, 1, 42)          │      1,386 │ get_item[0][0]             │
│ get_item_3 (GetItem)      │ (None, 16, 42)         │          0 │ dropout_16[0][0]           │
│ multi_head_attention_12   │ (None, 1, 32)          │     67,104 │ add_29[0][0],              │
│ (MultiHeadAttention)      │                        │            │ concatenate_5[0][0]        │
│ concatenate_2             │ (None, 17, 42)         │          0 │ dense_25[0][0],            │
│ (Concatenate)             │                        │            │ get_item_3[0][0]           │
│ add_30 (Add)              │ (None, 1, 32)          │          0 │ add_29[0][0],              │
│                           │                        │            │ multi_head_attention_12[0… │
│ multi_head_attention_9    │ (None, 1, 42)          │     87,594 │ dense_25[0][0],            │
│ (MultiHeadAttention)      │                        │            │ concatenate_2[0][0]        │
│ layer_normalization_23    │ (None, 1, 32)          │         64 │ add_30[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ add_26 (Add)              │ (None, 1, 42)          │          0 │ dense_25[0][0],            │
│                           │                        │            │ multi_head_attention_9[0]… │
│ dropout_22 (Dropout)      │ (None, 1, 32)          │          0 │ layer_normalization_23[0]… │
│ concatenate_3             │ (None, 17, 42)         │          0 │ add_26[0][0],              │
│ (Concatenate)             │                        │            │ get_item_3[0][0]           │
│ dense_28 (Dense)          │ (None, 1, 42)          │      1,386 │ dropout_22[0][0]           │
│ multi_head_attention_10   │ (None, 1, 42)          │     87,594 │ add_26[0][0],              │
│ (MultiHeadAttention)      │                        │            │ concatenate_3[0][0]        │
│ add_31 (Add)              │ (None, 1, 42)          │          0 │ dense_28[0][0],            │
│                           │                        │            │ get_item_2[0][0]           │
│ add_27 (Add)              │ (None, 1, 42)          │          0 │ add_26[0][0],              │
│                           │                        │            │ multi_head_attention_10[0… │
│ concatenate_7             │ (None, 17, 42)         │          0 │ add_31[0][0],              │
│ (Concatenate)             │                        │            │ get_item_3[0][0]           │
│ layer_normalization_22    │ (None, 1, 42)          │         84 │ add_27[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ multi_head_attention_14   │ (None, 17, 42)         │     98,538 │ concatenate_7[0][0],       │
│ (MultiHeadAttention)      │                        │            │ concatenate_7[0][0]        │
│ dropout_19 (Dropout)      │ (None, 1, 42)          │          0 │ layer_normalization_22[0]… │
│ add_34 (Add)              │ (None, 17, 42)         │          0 │ concatenate_7[0][0],       │
│                           │                        │            │ multi_head_attention_14[0… │
│ dense_26 (Dense)          │ (None, 1, 32)          │      1,376 │ dropout_19[0][0]           │
│ sequential_9 (Sequential) │ (None, 17, 42)         │      7,266 │ add_34[0][0]               │
│ add_28 (Add)              │ (None, 1, 32)          │          0 │ dense_26[0][0],            │
│                           │                        │            │ get_item[0][0]             │
│ add_35 (Add)              │ (None, 17, 42)         │          0 │ add_34[0][0],              │
│                           │                        │            │ sequential_9[0][0]         │
│ concatenate_6             │ (None, 50, 32)         │          0 │ add_28[0][0],              │
│ (Concatenate)             │                        │            │ get_item_1[0][0]           │
│ multi_head_attention_15   │ (None, 17, 42)         │     98,538 │ add_35[0][0], add_35[0][0] │
│ (MultiHeadAttention)      │                        │            │                            │
│ multi_head_attention_13   │ (None, 50, 32)         │     58,720 │ concatenate_6[0][0],       │
│ (MultiHeadAttention)      │                        │            │ concatenate_6[0][0]        │
│ add_36 (Add)              │ (None, 17, 42)         │          0 │ add_35[0][0],              │
│                           │                        │            │ multi_head_attention_15[0… │
│ add_32 (Add)              │ (None, 50, 32)         │          0 │ concatenate_6[0][0],       │
│                           │                        │            │ multi_head_attention_13[0… │
│ sequential_10             │ (None, 17, 42)         │      7,266 │ add_36[0][0]               │
│ (Sequential)              │                        │            │                            │
│ sequential_8 (Sequential) │ (None, 50, 32)         │      3,216 │ add_32[0][0]               │
│ add_37 (Add)              │ (None, 17, 42)         │          0 │ add_36[0][0],              │
│                           │                        │            │ sequential_10[0][0]        │
│ add_33 (Add)              │ (None, 50, 32)         │          0 │ add_32[0][0],              │
│                           │                        │            │ sequential_8[0][0]         │
│ layer_normalization_28    │ (None, 17, 42)         │         84 │ add_37[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ layer_normalization_25    │ (None, 50, 32)         │         64 │ add_33[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ dropout_27 (Dropout)      │ (None, 17, 42)         │          0 │ layer_normalization_28[0]… │
│ dropout_24 (Dropout)      │ (None, 50, 32)         │          0 │ layer_normalization_25[0]… │
│ get_item_6 (GetItem)      │ (None, 1, 42)          │          0 │ dropout_27[0][0]           │
│ get_item_5 (GetItem)      │ (None, 49, 32)         │          0 │ dropout_24[0][0]           │
│ dense_37 (Dense)          │ (None, 1, 32)          │      1,376 │ get_item_6[0][0]           │
│ concatenate_10            │ (None, 50, 32)         │          0 │ dense_37[0][0],            │
│ (Concatenate)             │                        │            │ get_item_5[0][0]           │
│ multi_head_attention_18   │ (None, 1, 32)          │     67,104 │ dense_37[0][0],            │
│ (MultiHeadAttention)      │                        │            │ concatenate_10[0][0]       │
│ get_item_4 (GetItem)      │ (None, 1, 32)          │          0 │ dropout_24[0][0]           │
│ add_41 (Add)              │ (None, 1, 32)          │          0 │ dense_37[0][0],            │
│                           │                        │            │ multi_head_attention_18[0… │
│ dense_35 (Dense)          │ (None, 1, 42)          │      1,386 │ get_item_4[0][0]           │
│ get_item_7 (GetItem)      │ (None, 16, 42)         │          0 │ dropout_27[0][0]           │
│ concatenate_11            │ (None, 50, 32)         │          0 │ add_41[0][0],              │
│ (Concatenate)             │                        │            │ get_item_5[0][0]           │
│ concatenate_8             │ (None, 17, 42)         │          0 │ dense_35[0][0],            │
│ (Concatenate)             │                        │            │ get_item_7[0][0]           │
│ multi_head_attention_19   │ (None, 1, 32)          │     67,104 │ add_41[0][0],              │
│ (MultiHeadAttention)      │                        │            │ concatenate_11[0][0]       │
│ multi_head_attention_16   │ (None, 1, 42)          │     87,594 │ dense_35[0][0],            │
│ (MultiHeadAttention)      │                        │            │ concatenate_8[0][0]        │
│ add_42 (Add)              │ (None, 1, 32)          │          0 │ add_41[0][0],              │
│                           │                        │            │ multi_head_attention_19[0… │
│ add_38 (Add)              │ (None, 1, 42)          │          0 │ dense_35[0][0],            │
│                           │                        │            │ multi_head_attention_16[0… │
│ layer_normalization_30    │ (None, 1, 32)          │         64 │ add_42[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ concatenate_9             │ (None, 17, 42)         │          0 │ add_38[0][0],              │
│ (Concatenate)             │                        │            │ get_item_7[0][0]           │
│ dropout_33 (Dropout)      │ (None, 1, 32)          │          0 │ layer_normalization_30[0]… │
│ multi_head_attention_17   │ (None, 1, 42)          │     87,594 │ add_38[0][0],              │
│ (MultiHeadAttention)      │                        │            │ concatenate_9[0][0]        │
│ dense_38 (Dense)          │ (None, 1, 42)          │      1,386 │ dropout_33[0][0]           │
│ add_39 (Add)              │ (None, 1, 42)          │          0 │ add_38[0][0],              │
│                           │                        │            │ multi_head_attention_17[0… │
│ add_43 (Add)              │ (None, 1, 42)          │          0 │ dense_38[0][0],            │
│                           │                        │            │ get_item_6[0][0]           │
│ layer_normalization_29    │ (None, 1, 42)          │         84 │ add_39[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ concatenate_13            │ (None, 17, 42)         │          0 │ add_43[0][0],              │
│ (Concatenate)             │                        │            │ get_item_7[0][0]           │
│ dropout_30 (Dropout)      │ (None, 1, 42)          │          0 │ layer_normalization_29[0]… │
│ multi_head_attention_21   │ (None, 17, 42)         │     98,538 │ concatenate_13[0][0],      │
│ (MultiHeadAttention)      │                        │            │ concatenate_13[0][0]       │
│ dense_36 (Dense)          │ (None, 1, 32)          │      1,376 │ dropout_30[0][0]           │
│ add_46 (Add)              │ (None, 17, 42)         │          0 │ concatenate_13[0][0],      │
│                           │                        │            │ multi_head_attention_21[0… │
│ add_40 (Add)              │ (None, 1, 32)          │          0 │ dense_36[0][0],            │
│                           │                        │            │ get_item_4[0][0]           │
│ sequential_12             │ (None, 17, 42)         │      7,266 │ add_46[0][0]               │
│ (Sequential)              │                        │            │                            │
│ concatenate_12            │ (None, 50, 32)         │          0 │ add_40[0][0],              │
│ (Concatenate)             │                        │            │ get_item_5[0][0]           │
│ add_47 (Add)              │ (None, 17, 42)         │          0 │ add_46[0][0],              │
│                           │                        │            │ sequential_12[0][0]        │
│ multi_head_attention_20   │ (None, 50, 32)         │     58,720 │ concatenate_12[0][0],      │
│ (MultiHeadAttention)      │                        │            │ concatenate_12[0][0]       │
│ multi_head_attention_22   │ (None, 17, 42)         │     98,538 │ add_47[0][0], add_47[0][0] │
│ (MultiHeadAttention)      │                        │            │                            │
│ add_44 (Add)              │ (None, 50, 32)         │          0 │ concatenate_12[0][0],      │
│                           │                        │            │ multi_head_attention_20[0… │
│ add_48 (Add)              │ (None, 17, 42)         │          0 │ add_47[0][0],              │
│                           │                        │            │ multi_head_attention_22[0… │
│ sequential_11             │ (None, 50, 32)         │      3,216 │ add_44[0][0]               │
│ (Sequential)              │                        │            │                            │
│ sequential_13             │ (None, 17, 42)         │      7,266 │ add_48[0][0]               │
│ (Sequential)              │                        │            │                            │
│ add_45 (Add)              │ (None, 50, 32)         │          0 │ add_44[0][0],              │
│                           │                        │            │ sequential_11[0][0]        │
│ add_49 (Add)              │ (None, 17, 42)         │          0 │ add_48[0][0],              │
│                           │                        │            │ sequential_13[0][0]        │
│ layer_normalization_32    │ (None, 50, 32)         │         64 │ add_45[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ layer_normalization_35    │ (None, 17, 42)         │         84 │ add_49[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ dropout_35 (Dropout)      │ (None, 50, 32)         │          0 │ layer_normalization_32[0]… │
│ dropout_38 (Dropout)      │ (None, 17, 42)         │          0 │ layer_normalization_35[0]… │
│ get_item_8 (GetItem)      │ (None, 1, 32)          │          0 │ dropout_35[0][0]           │
│ get_item_10 (GetItem)     │ (None, 1, 42)          │          0 │ dropout_38[0][0]           │
│ dense_45 (Dense)          │ (None, 1, 42)          │      1,386 │ get_item_8[0][0]           │
│ get_item_11 (GetItem)     │ (None, 16, 42)         │          0 │ dropout_38[0][0]           │
│ get_item_9 (GetItem)      │ (None, 49, 32)         │          0 │ dropout_35[0][0]           │
│ dense_47 (Dense)          │ (None, 1, 32)          │      1,376 │ get_item_10[0][0]          │
│ concatenate_14            │ (None, 17, 42)         │          0 │ dense_45[0][0],            │
│ (Concatenate)             │                        │            │ get_item_11[0][0]          │
│ concatenate_16            │ (None, 50, 32)         │          0 │ dense_47[0][0],            │
│ (Concatenate)             │                        │            │ get_item_9[0][0]           │
│ multi_head_attention_23   │ (None, 1, 42)          │     87,594 │ dense_45[0][0],            │
│ (MultiHeadAttention)      │                        │            │ concatenate_14[0][0]       │
│ multi_head_attention_25   │ (None, 1, 32)          │     67,104 │ dense_47[0][0],            │
│ (MultiHeadAttention)      │                        │            │ concatenate_16[0][0]       │
│ add_50 (Add)              │ (None, 1, 42)          │          0 │ dense_45[0][0],            │
│                           │                        │            │ multi_head_attention_23[0… │
│ add_53 (Add)              │ (None, 1, 32)          │          0 │ dense_47[0][0],            │
│                           │                        │            │ multi_head_attention_25[0… │
│ concatenate_15            │ (None, 17, 42)         │          0 │ add_50[0][0],              │
│ (Concatenate)             │                        │            │ get_item_11[0][0]          │
│ concatenate_17            │ (None, 50, 32)         │          0 │ add_53[0][0],              │
│ (Concatenate)             │                        │            │ get_item_9[0][0]           │
│ multi_head_attention_24   │ (None, 1, 42)          │     87,594 │ add_50[0][0],              │
│ (MultiHeadAttention)      │                        │            │ concatenate_15[0][0]       │
│ multi_head_attention_26   │ (None, 1, 32)          │     67,104 │ add_53[0][0],              │
│ (MultiHeadAttention)      │                        │            │ concatenate_17[0][0]       │
│ add_51 (Add)              │ (None, 1, 42)          │          0 │ add_50[0][0],              │
│                           │                        │            │ multi_head_attention_24[0… │
│ add_54 (Add)              │ (None, 1, 32)          │          0 │ add_53[0][0],              │
│                           │                        │            │ multi_head_attention_26[0… │
│ layer_normalization_36    │ (None, 1, 42)          │         84 │ add_51[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ layer_normalization_37    │ (None, 1, 32)          │         64 │ add_54[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ dropout_41 (Dropout)      │ (None, 1, 42)          │          0 │ layer_normalization_36[0]… │
│ dropout_44 (Dropout)      │ (None, 1, 32)          │          0 │ layer_normalization_37[0]… │
│ dense_46 (Dense)          │ (None, 1, 32)          │      1,376 │ dropout_41[0][0]           │
│ dense_48 (Dense)          │ (None, 1, 42)          │      1,386 │ dropout_44[0][0]           │
│ add_52 (Add)              │ (None, 1, 32)          │          0 │ dense_46[0][0],            │
│                           │                        │            │ get_item_8[0][0]           │
│ add_55 (Add)              │ (None, 1, 42)          │          0 │ dense_48[0][0],            │
│                           │                        │            │ get_item_10[0][0]          │
│ concatenate_18            │ (None, 50, 32)         │          0 │ add_52[0][0],              │
│ (Concatenate)             │                        │            │ get_item_9[0][0]           │
│ concatenate_19            │ (None, 17, 42)         │          0 │ add_55[0][0],              │
│ (Concatenate)             │                        │            │ get_item_11[0][0]          │
│ get_item_12 (GetItem)     │ (None, 32)             │          0 │ concatenate_18[0][0]       │
│ get_item_13 (GetItem)     │ (None, 42)             │          0 │ concatenate_19[0][0]       │
│ layer_normalization_38    │ (None, 32)             │         64 │ get_item_12[0][0]          │
│ (LayerNormalization)      │                        │            │                            │
│ layer_normalization_39    │ (None, 42)             │         84 │ get_item_13[0][0]          │
│ (LayerNormalization)      │                        │            │                            │
│ dense_49 (Dense)          │ (None, 10)             │        330 │ layer_normalization_38[0]… │
│ dense_50 (Dense)          │ (None, 10)             │        430 │ layer_normalization_39[0]… │
│ add_56 (Add)              │ (None, 10)             │          0 │ dense_49[0][0],            │
│                           │                        │            │ dense_50[0][0]             │
 Total params: 1,772,498 (6.76 MB)
 Trainable params: 1,772,498 (6.76 MB)
 Non-trainable params: 0 (0.00 B)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 76s 179ms/step - accuracy: 0.5782 - loss: 1.2357 - val_accuracy: 0.9185 - val_loss: 0.2657
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 73s 173ms/step - accuracy: 0.8943 - loss: 0.3341 - val_accuracy: 0.9500 - val_loss: 0.1754
Test loss: 0.18887896835803986
Test accuracy: 0.9437000155448914

from k3im.deepvit import DeepViT # jax ✅, tensorflow ✅
model = DeepViT(image_size=28,
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 13s 30ms/step - accuracy: 0.6203 - loss: 1.1394 - val_accuracy: 0.9172 - val_loss: 0.2594
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 13s 30ms/step - accuracy: 0.9250 - loss: 0.2398 - val_accuracy: 0.9567 - val_loss: 0.1421
Test loss: 0.16465066373348236
Test accuracy: 0.9492999911308289

from k3im.eanet import EANet # jax ✅, tensorflow ✅, torch ✅
model = EANet(
Model: "functional_27"
┃ Layer (type)               Output Shape               Param #  Connected to               ┃
│ input_layer_21            │ (None, 28, 28, 1)      │          0 │ -                          │
│ (InputLayer)              │                        │            │                            │
│ patch_extract             │ (None, 16, 49)         │          0 │ input_layer_21[0][0]       │
│ (PatchExtract)            │                        │            │                            │
│ patch_embedding           │ (None, 16, 64)         │      4,224 │ patch_extract[0][0]        │
│ (PatchEmbedding)          │                        │            │                            │
│ layer_normalization_46    │ (None, 16, 64)         │        128 │ patch_embedding[0][0]      │
│ (LayerNormalization)      │                        │            │                            │
│ dense_58 (Dense)          │ (None, 16, 128)        │      8,320 │ layer_normalization_46[0]… │
│ reshape_4 (Reshape)       │ (None, 16, 32, 4)      │          0 │ dense_58[0][0]             │
│ transpose (Transpose)     │ (None, 32, 16, 4)      │          0 │ reshape_4[0][0]            │
│ dense_59 (Dense)          │ (None, 32, 16, 32)     │        160 │ transpose[0][0]            │
│ softmax_29 (Softmax)      │ (None, 32, 16, 32)     │          0 │ dense_59[0][0]             │
│ lambda (Lambda)           │ (None, 32, 16, 32)     │          0 │ softmax_29[0][0]           │
│ dropout_52 (Dropout)      │ (None, 32, 16, 32)     │          0 │ lambda[0][0]               │
│ dense_60 (Dense)          │ (None, 32, 16, 4)      │        132 │ dropout_52[0][0]           │
│ transpose_1 (Transpose)   │ (None, 16, 32, 4)      │          0 │ dense_60[0][0]             │
│ reshape_5 (Reshape)       │ (None, 16, 128)        │          0 │ transpose_1[0][0]          │
│ dense_61 (Dense)          │ (None, 16, 64)         │      8,256 │ reshape_5[0][0]            │
│ dropout_53 (Dropout)      │ (None, 16, 64)         │          0 │ dense_61[0][0]             │
│ add_61 (Add)              │ (None, 16, 64)         │          0 │ dropout_53[0][0],          │
│                           │                        │            │ patch_embedding[0][0]      │
│ layer_normalization_47    │ (None, 16, 64)         │        128 │ add_61[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ dense_62 (Dense)          │ (None, 16, 32)         │      2,080 │ layer_normalization_47[0]… │
│ dropout_54 (Dropout)      │ (None, 16, 32)         │          0 │ dense_62[0][0]             │
│ dense_63 (Dense)          │ (None, 16, 64)         │      2,112 │ dropout_54[0][0]           │
│ dropout_55 (Dropout)      │ (None, 16, 64)         │          0 │ dense_63[0][0]             │
│ add_62 (Add)              │ (None, 16, 64)         │          0 │ dropout_55[0][0],          │
│                           │                        │            │ add_61[0][0]               │
│ layer_normalization_48    │ (None, 16, 64)         │        128 │ add_62[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ dense_64 (Dense)          │ (None, 16, 128)        │      8,320 │ layer_normalization_48[0]… │
│ reshape_6 (Reshape)       │ (None, 16, 32, 4)      │          0 │ dense_64[0][0]             │
│ transpose_2 (Transpose)   │ (None, 32, 16, 4)      │          0 │ reshape_6[0][0]            │
│ dense_65 (Dense)          │ (None, 32, 16, 32)     │        160 │ transpose_2[0][0]          │
│ softmax_30 (Softmax)      │ (None, 32, 16, 32)     │          0 │ dense_65[0][0]             │
│ lambda_1 (Lambda)         │ (None, 32, 16, 32)     │          0 │ softmax_30[0][0]           │
│ dropout_56 (Dropout)      │ (None, 32, 16, 32)     │          0 │ lambda_1[0][0]             │
│ dense_66 (Dense)          │ (None, 32, 16, 4)      │        132 │ dropout_56[0][0]           │
│ transpose_3 (Transpose)   │ (None, 16, 32, 4)      │          0 │ dense_66[0][0]             │
│ reshape_7 (Reshape)       │ (None, 16, 128)        │          0 │ transpose_3[0][0]          │
│ dense_67 (Dense)          │ (None, 16, 64)         │      8,256 │ reshape_7[0][0]            │
│ dropout_57 (Dropout)      │ (None, 16, 64)         │          0 │ dense_67[0][0]             │
│ add_63 (Add)              │ (None, 16, 64)         │          0 │ dropout_57[0][0],          │
│                           │                        │            │ add_62[0][0]               │
│ layer_normalization_49    │ (None, 16, 64)         │        128 │ add_63[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ dense_68 (Dense)          │ (None, 16, 32)         │      2,080 │ layer_normalization_49[0]… │
│ dropout_58 (Dropout)      │ (None, 16, 32)         │          0 │ dense_68[0][0]             │
│ dense_69 (Dense)          │ (None, 16, 64)         │      2,112 │ dropout_58[0][0]           │
│ dropout_59 (Dropout)      │ (None, 16, 64)         │          0 │ dense_69[0][0]             │
│ add_64 (Add)              │ (None, 16, 64)         │          0 │ dropout_59[0][0],          │
│                           │                        │            │ add_63[0][0]               │
│ global_average_pooling1d  │ (None, 64)             │          0 │ add_64[0][0]               │
│ (GlobalAveragePooling1D)  │                        │            │                            │
│ dense_70 (Dense)          │ (None, 10)             │        650 │ global_average_pooling1d[… │
 Total params: 47,506 (185.57 KB)
 Trainable params: 47,506 (185.57 KB)
 Non-trainable params: 0 (0.00 B)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 11s 26ms/step - accuracy: 0.3873 - loss: 1.7322 - val_accuracy: 0.8285 - val_loss: 0.5884
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 11s 26ms/step - accuracy: 0.7580 - loss: 0.7342 - val_accuracy: 0.8965 - val_loss: 0.3452
Test loss: 0.4003435969352722
Test accuracy: 0.8769000172615051

from k3im.fnet import FNetModel # jax ✅, tensorflow ✅, torch ✅
model = FNetModel(
Model: "functional_31"
┃ Layer (type)                        Output Shape                       Param # ┃
│ input_layer_22 (InputLayer)        │ (None, 28, 28, 1)             │           0 │
│ patches (Patches)                  │ (None, 16, 49)                │           0 │
│ dense_71 (Dense)                   │ (None, 16, 64)                │       3,200 │
│ f_net_layer (FNetLayer)            │ (None, 16, 64)                │       8,576 │
│ f_net_layer_1 (FNetLayer)          │ (None, 16, 64)                │       8,576 │
│ global_average_pooling1d_1         │ (None, 64)                    │           0 │
│ (GlobalAveragePooling1D)           │                               │             │
│ dropout_62 (Dropout)               │ (None, 64)                    │           0 │
│ dense_76 (Dense)                   │ (None, 10)                    │         650 │
 Total params: 21,002 (82.04 KB)
 Trainable params: 21,002 (82.04 KB)
 Non-trainable params: 0 (0.00 B)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 7s 16ms/step - accuracy: 0.3982 - loss: 1.7336 - val_accuracy: 0.8240 - val_loss: 0.5880
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 6s 15ms/step - accuracy: 0.7478 - loss: 0.7953 - val_accuracy: 0.8858 - val_loss: 0.3823
Test loss: 0.43270713090896606
Test accuracy: 0.8708000183105469

from k3im.focalnet import focalnet_kid # jax ✅, tensorflow ✅, torch ✅
model = focalnet_kid(img_size=28, in_channels=1, num_classes=10)
Model: "functional_33"
┃ Layer (type)               Output Shape               Param #  Connected to               ┃
│ input_layer_25            │ (None, 28, 28, 1)      │          0 │ -                          │
│ (InputLayer)              │                        │            │                            │
│ patch_embed.proj (Conv2D) │ (None, 7, 7, 96)       │      1,632 │ input_layer_25[0][0]       │
│ reshape_8 (Reshape)       │ (None, 49, 96)         │          0 │ patch_embed.proj[0][0]     │
│ patch_embed.norm          │ (None, 49, 96)         │        192 │ reshape_8[0][0]            │
│ (LayerNormalization)      │                        │            │                            │
│ dropout_63 (Dropout)      │ (None, 49, 96)         │          0 │ patch_embed.norm[0][0]     │
│ layers.0.blocks.0.norm1   │ (None, 49, 96)         │        192 │ dropout_63[0][0]           │
│ (LayerNormalization)      │                        │            │                            │
│ reshape_9 (Reshape)       │ (None, 7, 7, 96)       │          0 │ layers.0.blocks.0.norm1[0… │
│ layers.0.blocks.0.modula… │ (None, 7, 7, 96)       │     40,803 │ reshape_9[0][0]            │
│ (FocalModulation)         │                        │            │                            │
│ reshape_10 (Reshape)      │ (None, 49, 96)         │          0 │ layers.0.blocks.0.modulat… │
│ stochastic_depth_4        │ (None, 49, 96)         │          0 │ reshape_10[0][0]           │
│ (StochasticDepth)         │                        │            │                            │
│ add_65 (Add)              │ (None, 49, 96)         │          0 │ dropout_63[0][0],          │
│                           │                        │            │ stochastic_depth_4[0][0]   │
│ reshape_11 (Reshape)      │ (None, 7, 7, 96)       │          0 │ add_65[0][0]               │
│ layers.0.blocks.0.norm2   │ (None, 7, 7, 96)       │        192 │ reshape_11[0][0]           │
│ (LayerNormalization)      │                        │            │                            │
│ layers.0.blocks.0.mlp.fc1 │ (None, 7, 7, 384)      │     37,248 │ layers.0.blocks.0.norm2[0… │
│ (Dense)                   │                        │            │                            │
│ dropout_65 (Dropout)      │ (None, 7, 7, 384)      │          0 │ layers.0.blocks.0.mlp.fc1… │
│ layers.0.blocks.0.mlp.fc2 │ (None, 7, 7, 96)       │     36,960 │ dropout_65[0][0]           │
│ (Dense)                   │                        │            │                            │
│ dropout_66 (Dropout)      │ (None, 7, 7, 96)       │          0 │ layers.0.blocks.0.mlp.fc2… │
│ stochastic_depth_5        │ (None, 7, 7, 96)       │          0 │ dropout_66[0][0]           │
│ (StochasticDepth)         │                        │            │                            │
│ add_66 (Add)              │ (None, 7, 7, 96)       │          0 │ stochastic_depth_5[0][0],  │
│                           │                        │            │ reshape_11[0][0]           │
│ reshape_12 (Reshape)      │ (None, 49, 96)         │          0 │ add_66[0][0]               │
│ norm (LayerNormalization) │ (None, 49, 96)         │        192 │ reshape_12[0][0]           │
│ global_average_pooling1d… │ (None, 96)             │          0 │ norm[0][0]                 │
│ (GlobalAveragePooling1D)  │                        │            │                            │
│ flatten (Flatten)         │ (None, 96)             │          0 │ global_average_pooling1d_… │
│ head (Dense)              │ (None, 10)             │        970 │ flatten[0][0]              │
 Total params: 118,381 (462.43 KB)
 Trainable params: 118,381 (462.43 KB)
 Non-trainable params: 0 (0.00 B)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 10s 23ms/step - accuracy: 0.5703 - loss: 1.2913 - val_accuracy: 0.9475 - val_loss: 0.1807
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 10s 24ms/step - accuracy: 0.9437 - loss: 0.1907 - val_accuracy: 0.9697 - val_loss: 0.1060
Test loss: 0.1195862889289856
Test accuracy: 0.9613999724388123

from k3im.gmlp import gMLPModel # jax ✅, tensorflow ✅, torch ✅
model = gMLPModel(
Model: "functional_39"
┃ Layer (type)                        Output Shape                       Param # ┃
│ input_layer_26 (InputLayer)        │ (None, 28, 28, 1)             │           0 │
│ patches_1 (Patches)                │ (None, 16, 49)                │           0 │
│ dense_77 (Dense)                   │ (None, 16, 32)                │       1,600 │
│ g_mlp_layer (gMLPLayer)            │ (None, 16, 32)                │       3,568 │
│ g_mlp_layer_1 (gMLPLayer)          │ (None, 16, 32)                │       3,568 │
│ g_mlp_layer_2 (gMLPLayer)          │ (None, 16, 32)                │       3,568 │
│ g_mlp_layer_3 (gMLPLayer)          │ (None, 16, 32)                │       3,568 │
│ global_average_pooling1d_3         │ (None, 32)                    │           0 │
│ (GlobalAveragePooling1D)           │                               │             │
│ dropout_71 (Dropout)               │ (None, 32)                    │           0 │
│ dense_90 (Dense)                   │ (None, 10)                    │         330 │
 Total params: 16,202 (63.29 KB)
 Trainable params: 16,202 (63.29 KB)
 Non-trainable params: 0 (0.00 B)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 11s 26ms/step - accuracy: 0.2239 - loss: 2.1543 - val_accuracy: 0.7037 - val_loss: 0.8542
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 11s 26ms/step - accuracy: 0.5964 - loss: 1.1702 - val_accuracy: 0.8968 - val_loss: 0.3385
Test loss: 0.3988232910633087
Test accuracy: 0.8780999779701233

from k3im.mlp_mixer import MixerModel # jax ✅, tensorflow ✅, torch ✅
model = MixerModel(

Model: "functional_49"
┃ Layer (type)               Output Shape               Param #  Connected to               ┃
│ input_layer_31            │ (None, 28, 28, 1)      │          0 │ -                          │
│ (InputLayer)              │                        │            │                            │
│ patches_2 (Patches)       │ (None, 16, 49)         │          0 │ input_layer_31[0][0]       │
│ dense_91 (Dense)          │ (None, 16, 32)         │      1,600 │ patches_2[0][0]            │
│ position_embedding        │ (None, 16, 32)         │        512 │ dense_91[0][0]             │
│ (PositionEmbedding)       │                        │            │                            │
│ add_67 (Add)              │ (None, 16, 32)         │          0 │ dense_91[0][0],            │
│                           │                        │            │ position_embedding[0][0]   │
│ mlp_mixer_layer           │ (None, 16, 32)         │      1,680 │ add_67[0][0]               │
│ (MLPMixerLayer)           │                        │            │                            │
│ mlp_mixer_layer_1         │ (None, 16, 32)         │      1,680 │ mlp_mixer_layer[0][0]      │
│ (MLPMixerLayer)           │                        │            │                            │
│ mlp_mixer_layer_2         │ (None, 16, 32)         │      1,680 │ mlp_mixer_layer_1[0][0]    │
│ (MLPMixerLayer)           │                        │            │                            │
│ mlp_mixer_layer_3         │ (None, 16, 32)         │      1,680 │ mlp_mixer_layer_2[0][0]    │
│ (MLPMixerLayer)           │                        │            │                            │
│ global_average_pooling1d… │ (None, 32)             │          0 │ mlp_mixer_layer_3[0][0]    │
│ (GlobalAveragePooling1D)  │                        │            │                            │
│ dropout_80 (Dropout)      │ (None, 32)             │          0 │ global_average_pooling1d_… │
│ dense_108 (Dense)         │ (None, 10)             │        330 │ dropout_80[0][0]           │
 Total params: 9,162 (35.79 KB)
 Trainable params: 9,162 (35.79 KB)
 Non-trainable params: 0 (0.00 B)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 13s 30ms/step - accuracy: 0.2378 - loss: 2.1815 - val_accuracy: 0.7837 - val_loss: 0.6431
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 13s 31ms/step - accuracy: 0.6868 - loss: 0.9226 - val_accuracy: 0.8920 - val_loss: 0.3507
Test loss: 0.4144611358642578
Test accuracy: 0.8709999918937683

from k3im.simple_vit import SimpleViT # jax ✅, tensorflow ✅, torch ✅
model = SimpleViT(
    image_size=(28, 28),
    patch_size=(7, 7),
Model: "functional_53"
┃ Layer (type)               Output Shape               Param #  Connected to               ┃
│ input_layer_40            │ (None, 28, 28, 1)      │          0 │ -                          │
│ (InputLayer)              │                        │            │                            │
│ extract_patches_4         │ (None, 4, 4, 49)       │          0 │ input_layer_40[0][0]       │
│ (ExtractPatches)          │                        │            │                            │
│ reshape_13 (Reshape)      │ (None, 16, 49)         │          0 │ extract_patches_4[0][0]    │
│ layer_normalization_66    │ (None, 16, 49)         │         98 │ reshape_13[0][0]           │
│ (LayerNormalization)      │                        │            │                            │
│ dense_109 (Dense)         │ (None, 16, 32)         │      1,600 │ layer_normalization_66[0]… │
│ layer_normalization_67    │ (None, 16, 32)         │         64 │ dense_109[0][0]            │
│ (LayerNormalization)      │                        │            │                            │
│ add_68 (Add)              │ (None, 16, 32)         │          0 │ layer_normalization_67[0]… │
│ multi_head_attention_29   │ (None, 16, 32)         │     33,568 │ add_68[0][0], add_68[0][0] │
│ (MultiHeadAttention)      │                        │            │                            │
│ add_69 (Add)              │ (None, 16, 32)         │          0 │ add_68[0][0],              │
│                           │                        │            │ multi_head_attention_29[0… │
│ sequential_30             │ (None, 16, 32)         │      4,321 │ add_69[0][0]               │
│ (Sequential)              │                        │            │                            │
│ add_70 (Add)              │ (None, 16, 32)         │          0 │ add_69[0][0],              │
│                           │                        │            │ sequential_30[0][0]        │
│ multi_head_attention_30   │ (None, 16, 32)         │     33,568 │ add_70[0][0], add_70[0][0] │
│ (MultiHeadAttention)      │                        │            │                            │
│ add_71 (Add)              │ (None, 16, 32)         │          0 │ add_70[0][0],              │
│                           │                        │            │ multi_head_attention_30[0… │
│ sequential_31             │ (None, 16, 32)         │      4,321 │ add_71[0][0]               │
│ (Sequential)              │                        │            │                            │
│ add_72 (Add)              │ (None, 16, 32)         │          0 │ add_71[0][0],              │
│                           │                        │            │ sequential_31[0][0]        │
│ layer_normalization_70    │ (None, 16, 32)         │         64 │ add_72[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ avg_pool                  │ (None, 32)             │          0 │ layer_normalization_70[0]… │
│ (GlobalAveragePooling1D)  │                        │            │                            │
│ dense_114 (Dense)         │ (None, 10)             │        330 │ avg_pool[0][0]             │
 Total params: 77,934 (304.43 KB)
 Trainable params: 77,934 (304.43 KB)
 Non-trainable params: 0 (0.00 B)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 10s 24ms/step - accuracy: 0.5580 - loss: 1.3136 - val_accuracy: 0.8928 - val_loss: 0.3520
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 10s 24ms/step - accuracy: 0.8881 - loss: 0.3604 - val_accuracy: 0.9362 - val_loss: 0.2134
Test loss: 0.2578793466091156
Test accuracy: 0.9205999970436096

from k3im.simple_vit_with_fft import SimpleViTFFT # jax ✅, tensorflow ✅, torch ✅
model = SimpleViTFFT(image_size=28, patch_size=7, freq_patch_size=7, num_classes=num_classes, dim=32, depth=2,
                     heads=8, mlp_dim=64, channels=1,
                     dim_head = 16)
Model: "functional_57"
┃ Layer (type)               Output Shape               Param #  Connected to               ┃
│ input_layer_43            │ (None, 28, 28, 1)      │          0 │ -                          │
│ (InputLayer)              │                        │            │                            │
│ lambda_2 (Lambda)         │ (None, 28, 28, 2)      │          0 │ input_layer_43[0][0]       │
│ extract_patches_5         │ (None, 4, 4, 49)       │          0 │ input_layer_43[0][0]       │
│ (ExtractPatches)          │                        │            │                            │
│ extract_patches_6         │ (None, 4, 4, 98)       │          0 │ lambda_2[0][0]             │
│ (ExtractPatches)          │                        │            │                            │
│ reshape_14 (Reshape)      │ (None, 16, 49)         │          0 │ extract_patches_5[0][0]    │
│ reshape_15 (Reshape)      │ (None, 16, 98)         │          0 │ extract_patches_6[0][0]    │
│ layer_normalization_71    │ (None, 16, 49)         │         98 │ reshape_14[0][0]           │
│ (LayerNormalization)      │                        │            │                            │
│ layer_normalization_73    │ (None, 16, 98)         │        196 │ reshape_15[0][0]           │
│ (LayerNormalization)      │                        │            │                            │
│ dense_115 (Dense)         │ (None, 16, 32)         │      1,600 │ layer_normalization_71[0]… │
│ dense_116 (Dense)         │ (None, 16, 32)         │      3,168 │ layer_normalization_73[0]… │
│ layer_normalization_72    │ (None, 16, 32)         │         64 │ dense_115[0][0]            │
│ (LayerNormalization)      │                        │            │                            │
│ layer_normalization_74    │ (None, 16, 32)         │         64 │ dense_116[0][0]            │
│ (LayerNormalization)      │                        │            │                            │
│ add_73 (Add)              │ (None, 16, 32)         │          0 │ layer_normalization_72[0]… │
│ add_74 (Add)              │ (None, 16, 32)         │          0 │ layer_normalization_74[0]… │
│ concatenate_20            │ (None, 32, 32)         │          0 │ add_73[0][0], add_74[0][0] │
│ (Concatenate)             │                        │            │                            │
│ multi_head_attention_31   │ (None, 32, 32)         │     16,800 │ concatenate_20[0][0],      │
│ (MultiHeadAttention)      │                        │            │ concatenate_20[0][0]       │
│ add_75 (Add)              │ (None, 32, 32)         │          0 │ concatenate_20[0][0],      │
│                           │                        │            │ multi_head_attention_31[0… │
│ sequential_32             │ (None, 32, 32)         │      4,256 │ add_75[0][0]               │
│ (Sequential)              │                        │            │                            │
│ add_76 (Add)              │ (None, 32, 32)         │          0 │ add_75[0][0],              │
│                           │                        │            │ sequential_32[0][0]        │
│ multi_head_attention_32   │ (None, 32, 32)         │     16,800 │ add_76[0][0], add_76[0][0] │
│ (MultiHeadAttention)      │                        │            │                            │
│ add_77 (Add)              │ (None, 32, 32)         │          0 │ add_76[0][0],              │
│                           │                        │            │ multi_head_attention_32[0… │
│ sequential_33             │ (None, 32, 32)         │      4,256 │ add_77[0][0]               │
│ (Sequential)              │                        │            │                            │
│ add_78 (Add)              │ (None, 32, 32)         │          0 │ add_77[0][0],              │
│                           │                        │            │ sequential_33[0][0]        │
│ layer_normalization_77    │ (None, 32, 32)         │         64 │ add_78[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ avg_pool                  │ (None, 32)             │          0 │ layer_normalization_77[0]… │
│ (GlobalAveragePooling1D)  │                        │            │                            │
│ dense_121 (Dense)         │ (None, 10)             │        330 │ avg_pool[0][0]             │
 Total params: 47,696 (186.31 KB)
 Trainable params: 47,696 (186.31 KB)
 Non-trainable params: 0 (0.00 B)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 12s 29ms/step - accuracy: 0.5238 - loss: 1.4078 - val_accuracy: 0.9082 - val_loss: 0.3177
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 12s 29ms/step - accuracy: 0.9021 - loss: 0.3226 - val_accuracy: 0.9453 - val_loss: 0.1842
Test loss: 0.2127855271100998
Test accuracy: 0.9369999766349792

from k3im.simple_vit_with_register_tokens import SimpleViT_RT # jax ✅, tensorflow ✅, torch ✅
model = SimpleViT_RT(image_size=28,
Model: "functional_61"
┃ Layer (type)               Output Shape               Param #  Connected to               ┃
│ input_layer_46            │ (None, 28, 28, 1)      │          0 │ -                          │
│ (InputLayer)              │                        │            │                            │
│ extract_patches_7         │ (None, 4, 4, 49)       │          0 │ input_layer_46[0][0]       │
│ (ExtractPatches)          │                        │            │                            │
│ reshape_16 (Reshape)      │ (None, 16, 49)         │          0 │ extract_patches_7[0][0]    │
│ layer_normalization_78    │ (None, 16, 49)         │         98 │ reshape_16[0][0]           │
│ (LayerNormalization)      │                        │            │                            │
│ dense_122 (Dense)         │ (None, 16, 32)         │      1,600 │ layer_normalization_78[0]… │
│ layer_normalization_79    │ (None, 16, 32)         │         64 │ dense_122[0][0]            │
│ (LayerNormalization)      │                        │            │                            │
│ add_79 (Add)              │ (None, 16, 32)         │          0 │ layer_normalization_79[0]… │
│ register_tokens           │ (None, 20, 32)         │        128 │ add_79[0][0]               │
│ (RegisterTokens)          │                        │            │                            │
│ multi_head_attention_33   │ (None, 20, 32)         │     33,568 │ register_tokens[0][0],     │
│ (MultiHeadAttention)      │                        │            │ register_tokens[0][0]      │
│ add_80 (Add)              │ (None, 20, 32)         │          0 │ register_tokens[0][0],     │
│                           │                        │            │ multi_head_attention_33[0… │
│ sequential_34             │ (None, 20, 32)         │      4,256 │ add_80[0][0]               │
│ (Sequential)              │                        │            │                            │
│ add_81 (Add)              │ (None, 20, 32)         │          0 │ add_80[0][0],              │
│                           │                        │            │ sequential_34[0][0]        │
│ multi_head_attention_34   │ (None, 20, 32)         │     33,568 │ add_81[0][0], add_81[0][0] │
│ (MultiHeadAttention)      │                        │            │                            │
│ add_82 (Add)              │ (None, 20, 32)         │          0 │ add_81[0][0],              │
│                           │                        │            │ multi_head_attention_34[0… │
│ sequential_35             │ (None, 20, 32)         │      4,256 │ add_82[0][0]               │
│ (Sequential)              │                        │            │                            │
│ add_83 (Add)              │ (None, 20, 32)         │          0 │ add_82[0][0],              │
│                           │                        │            │ sequential_35[0][0]        │
│ layer_normalization_82    │ (None, 20, 32)         │         64 │ add_83[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ split (Split)             │ [(None, 16, 32),       │          0 │ layer_normalization_82[0]… │
│                           │ (None, 4, 32)]         │            │                            │
│ avg_pool                  │ (None, 32)             │          0 │ split[0][0]                │
│ (GlobalAveragePooling1D)  │                        │            │                            │
│ dense_127 (Dense)         │ (None, 10)             │        330 │ avg_pool[0][0]             │
 Total params: 77,932 (304.42 KB)
 Trainable params: 77,932 (304.42 KB)
 Non-trainable params: 0 (0.00 B)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 11s 26ms/step - accuracy: 0.5336 - loss: 1.3876 - val_accuracy: 0.8742 - val_loss: 0.4016
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 11s 26ms/step - accuracy: 0.8742 - loss: 0.4034 - val_accuracy: 0.9272 - val_loss: 0.2538
Test loss: 0.2974727153778076
Test accuracy: 0.909600019454956

from k3im.swint import SwinTModel # jax ✅, tensorflow ✅, torch ✅
model = SwinTModel(
Model: "functional_65"
┃ Layer (type)                        Output Shape                       Param # ┃
│ input_layer_49 (InputLayer)        │ (None, 28, 28, 1)             │           0 │
│ extract_patches_8 (ExtractPatches) │ (None, 4, 4, 49)              │           0 │
│ reshape_17 (Reshape)               │ (None, 16, 49)                │           0 │
│ layer_normalization_83             │ (None, 16, 49)                │          98 │
│ (LayerNormalization)               │                               │             │
│ dense_128 (Dense)                  │ (None, 16, 32)                │       1,600 │
│ swin_transformer (SwinTransformer) │ (None, 16, 32)                │       5,096 │
│ swin_transformer_1                 │ (None, 16, 32)                │       5,352 │
│ (SwinTransformer)                  │                               │             │
│ patch_merging (PatchMerging)       │ (None, 4, 64)                 │       8,192 │
│ global_average_pooling1d_5         │ (None, 64)                    │           0 │
│ (GlobalAveragePooling1D)           │                               │             │
│ dense_138 (Dense)                  │ (None, 10)                    │         650 │
 Total params: 20,988 (82.98 KB)
 Trainable params: 20,220 (78.98 KB)
 Non-trainable params: 768 (4.00 KB)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 11s 26ms/step - accuracy: 0.6364 - loss: 1.1143 - val_accuracy: 0.9118 - val_loss: 0.2969
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 11s 26ms/step - accuracy: 0.8750 - loss: 0.3952 - val_accuracy: 0.9358 - val_loss: 0.2215
Test loss: 0.26952508091926575
Test accuracy: 0.9175999760627747

from k3im.token_learner import ViTokenLearner # jax ✅, tensorflow ✅, torch ✅
model = ViTokenLearner(image_size=28,
    pool="mean", use_token_learner=True)
Model: "functional_72"
┃ Layer (type)               Output Shape               Param #  Connected to               ┃
│ input_layer_52            │ (None, 28, 28, 1)      │          0 │ -                          │
│ (InputLayer)              │                        │            │                            │
│ conv2d_11 (Conv2D)        │ (None, 4, 4, 64)       │      3,200 │ input_layer_52[0][0]       │
│ reshape_18 (Reshape)      │ (None, 16, 64)         │          0 │ conv2d_11[0][0]            │
│ patch_encoder             │ (None, 16, 64)         │      1,024 │ reshape_18[0][0]           │
│ (PatchEncoder)            │                        │            │                            │
│ dropout_95 (Dropout)      │ (None, 16, 64)         │          0 │ patch_encoder[0][0]        │
│ layer_normalization_88    │ (None, 16, 64)         │        128 │ dropout_95[0][0]           │
│ (LayerNormalization)      │                        │            │                            │
│ multi_head_attention_35   │ (None, 16, 64)         │     66,368 │ layer_normalization_88[0]… │
│ (MultiHeadAttention)      │                        │            │ layer_normalization_88[0]… │
│ add_84 (Add)              │ (None, 16, 64)         │          0 │ multi_head_attention_35[0… │
│                           │                        │            │ dropout_95[0][0]           │
│ layer_normalization_89    │ (None, 16, 64)         │        128 │ add_84[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ sequential_38             │ (None, 16, 64)         │      4,320 │ layer_normalization_89[0]… │
│ (Sequential)              │                        │            │                            │
│ add_85 (Add)              │ (None, 16, 64)         │          0 │ sequential_38[0][0],       │
│                           │                        │            │ add_84[0][0]               │
│ layer_normalization_91    │ (None, 16, 64)         │        128 │ add_85[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ multi_head_attention_36   │ (None, 16, 64)         │     66,368 │ layer_normalization_91[0]… │
│ (MultiHeadAttention)      │                        │            │ layer_normalization_91[0]… │
│ add_86 (Add)              │ (None, 16, 64)         │          0 │ multi_head_attention_36[0… │
│                           │                        │            │ add_85[0][0]               │
│ layer_normalization_92    │ (None, 16, 64)         │        128 │ add_86[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ sequential_39             │ (None, 16, 64)         │      4,320 │ layer_normalization_92[0]… │
│ (Sequential)              │                        │            │                            │
│ add_87 (Add)              │ (None, 16, 64)         │          0 │ sequential_39[0][0],       │
│                           │                        │            │ add_86[0][0]               │
│ layer_normalization_94    │ (None, 16, 64)         │        128 │ add_87[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ multi_head_attention_37   │ (None, 16, 64)         │     66,368 │ layer_normalization_94[0]… │
│ (MultiHeadAttention)      │                        │            │ layer_normalization_94[0]… │
│ add_88 (Add)              │ (None, 16, 64)         │          0 │ multi_head_attention_37[0… │
│                           │                        │            │ add_87[0][0]               │
│ layer_normalization_95    │ (None, 16, 64)         │        128 │ add_88[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ sequential_40             │ (None, 16, 64)         │      4,320 │ layer_normalization_95[0]… │
│ (Sequential)              │                        │            │                            │
│ add_89 (Add)              │ (None, 16, 64)         │          0 │ sequential_40[0][0],       │
│                           │                        │            │ add_88[0][0]               │
│ reshape_19 (Reshape)      │ (None, 4, 4, 64)       │          0 │ add_89[0][0]               │
│ layer_normalization_97    │ (None, 4, 4, 64)       │        128 │ reshape_19[0][0]           │
│ (LayerNormalization)      │                        │            │                            │
│ sequential_41             │ (None, 2, 16)          │      1,260 │ layer_normalization_97[0]… │
│ (Sequential)              │                        │            │                            │
│ expand_dims (ExpandDims)  │ (None, 2, 16, 1)       │          0 │ sequential_41[0][0]        │
│ reshape_21 (Reshape)      │ (None, 1, 16, 64)      │          0 │ reshape_19[0][0]           │
│ multiply (Multiply)       │ (None, 2, 16, 64)      │          0 │ expand_dims[0][0],         │
│                           │                        │            │ reshape_21[0][0]           │
│ mean (Mean)               │ (None, 2, 64)          │          0 │ multiply[0][0]             │
│ layer_normalization_98    │ (None, 2, 64)          │        128 │ mean[0][0]                 │
│ (LayerNormalization)      │                        │            │                            │
│ multi_head_attention_38   │ (None, 2, 64)          │     66,368 │ layer_normalization_98[0]… │
│ (MultiHeadAttention)      │                        │            │ layer_normalization_98[0]… │
│ add_90 (Add)              │ (None, 2, 64)          │          0 │ multi_head_attention_38[0… │
│                           │                        │            │ mean[0][0]                 │
│ layer_normalization_99    │ (None, 2, 64)          │        128 │ add_90[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ sequential_42             │ (None, 2, 64)          │      4,320 │ layer_normalization_99[0]… │
│ (Sequential)              │                        │            │                            │
│ add_91 (Add)              │ (None, 2, 64)          │          0 │ sequential_42[0][0],       │
│                           │                        │            │ add_90[0][0]               │
│ layer_normalization_101   │ (None, 2, 64)          │        128 │ add_91[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ avg_pool                  │ (None, 64)             │          0 │ layer_normalization_101[0… │
│ (GlobalAveragePooling1D)  │                        │            │                            │
│ dense_147 (Dense)         │ (None, 10)             │        650 │ avg_pool[0][0]             │
 Total params: 290,166 (1.11 MB)
 Trainable params: 290,166 (1.11 MB)
 Non-trainable params: 0 (0.00 B)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 24s 56ms/step - accuracy: 0.3959 - loss: 2.0571 - val_accuracy: 0.7192 - val_loss: 1.7431
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 23s 55ms/step - accuracy: 0.7304 - loss: 1.7321 - val_accuracy: 0.8182 - val_loss: 1.6422
Test loss: 1.654205560684204
Test accuracy: 0.8080999851226807

from k3im.vit import ViT # jax ✅, tensorflow ✅, torch ✅
model = ViT(
    image_size=(28, 28),
    patch_size=(7, 7),
Model: "functional_76"
┃ Layer (type)               Output Shape               Param #  Connected to               ┃
│ input_layer_58            │ (None, 28, 28, 1)      │          0 │ -                          │
│ (InputLayer)              │                        │            │                            │
│ extract_patches_9         │ (None, 4, 4, 49)       │          0 │ input_layer_58[0][0]       │
│ (ExtractPatches)          │                        │            │                            │
│ reshape_22 (Reshape)      │ (None, 16, 49)         │          0 │ extract_patches_9[0][0]    │
│ layer_normalization_102   │ (None, 16, 49)         │         98 │ reshape_22[0][0]           │
│ (LayerNormalization)      │                        │            │                            │
│ dense_148 (Dense)         │ (None, 16, 32)         │      1,600 │ layer_normalization_102[0… │
│ layer_normalization_103   │ (None, 16, 32)         │         64 │ dense_148[0][0]            │
│ (LayerNormalization)      │                        │            │                            │
│ class_token_position_emb  │ (None, 17, 32)         │        576 │ layer_normalization_103[0… │
│ (ClassTokenPositionEmb)   │                        │            │                            │
│ multi_head_attention_39   │ (None, 17, 32)         │     33,568 │ class_token_position_emb[ │
│ (MultiHeadAttention)      │                        │            │ class_token_position_emb[ │
│ add_92 (Add)              │ (None, 17, 32)         │          0 │ class_token_position_emb[ │
│                           │                        │            │ multi_head_attention_39[0… │
│ sequential_43             │ (None, 17, 32)         │      4,321 │ add_92[0][0]               │
│ (Sequential)              │                        │            │                            │
│ add_93 (Add)              │ (None, 17, 32)         │          0 │ add_92[0][0],              │
│                           │                        │            │ sequential_43[0][0]        │
│ multi_head_attention_40   │ (None, 17, 32)         │     33,568 │ add_93[0][0], add_93[0][0] │
│ (MultiHeadAttention)      │                        │            │                            │
│ add_94 (Add)              │ (None, 17, 32)         │          0 │ add_93[0][0],              │
│                           │                        │            │ multi_head_attention_40[0… │
│ sequential_44             │ (None, 17, 32)         │      4,321 │ add_94[0][0]               │
│ (Sequential)              │                        │            │                            │
│ add_95 (Add)              │ (None, 17, 32)         │          0 │ add_94[0][0],              │
│                           │                        │            │ sequential_44[0][0]        │
│ layer_normalization_106   │ (None, 17, 32)         │         64 │ add_95[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
│ max_pool                  │ (None, 32)             │          0 │ layer_normalization_106[0… │
│ (GlobalAveragePooling1D)  │                        │            │                            │
│ dense_153 (Dense)         │ (None, 10)             │        330 │ max_pool[0][0]             │
 Total params: 78,510 (306.68 KB)
 Trainable params: 78,510 (306.68 KB)
 Non-trainable params: 0 (0.00 B)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 15s 36ms/step - accuracy: 0.5635 - loss: 1.3064 - val_accuracy: 0.8988 - val_loss: 0.3343
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 12s 28ms/step - accuracy: 0.9013 - loss: 0.3197 - val_accuracy: 0.9478 - val_loss: 0.1735
Test loss: 0.18965917825698853
Test accuracy: 0.942799985408783

from k3im.vit_with_patch_dropout import SimpleViTPD # jax ✅, tensorflow ✅, torch ✅
model = SimpleViTPD(
Model: "functional_82"
┃ Layer (type)               Output Shape               Param #  Connected to               ┃
│ input_layer_61            │ (None, 28, 28, 1)      │          0 │ -                          │
│ (InputLayer)              │                        │            │                            │
│ extract_patches_10        │ (None, 4, 4, 49)       │          0 │ input_layer_61[0][0]       │
│ (ExtractPatches)          │                        │            │                            │
│ reshape_23 (Reshape)      │ (None, 16, 49)         │          0 │ extract_patches_10[0][0]   │
│ layer_normalization_107   │ (None, 16, 49)         │         98 │ reshape_23[0][0]           │
│ (LayerNormalization)      │                        │            │                            │
│ dense_154 (Dense)         │ (None, 16, 32)         │      1,600 │ layer_normalization_107[0… │
│ layer_normalization_108   │ (None, 16, 32)         │         64 │ dense_154[0][0]            │
│ (LayerNormalization)      │                        │            │                            │
│ add_96 (Add)              │ (None, 16, 32)         │          0 │ layer_normalization_108[0… │
│ spatial_dropout1d         │ (None, 16, 32)         │          0 │ add_96[0][0]               │
│ (SpatialDropout1D)        │                        │            │                            │
│ multi_head_attention_41   │ (None, 16, 32)         │     16,800 │ spatial_dropout1d[0][0],   │
│ (MultiHeadAttention)      │                        │            │ spatial_dropout1d[0][0]    │
│ add_97 (Add)              │ (None, 16, 32)         │          0 │ spatial_dropout1d[0][0],   │
│                           │                        │            │ multi_head_attention_41[0… │
│ sequential_45             │ (None, 16, 32)         │      2,826 │ add_97[0][0]               │
│ (Sequential)              │                        │            │                            │
│ add_98 (Add)              │ (None, 16, 32)         │          0 │ add_97[0][0],              │
│                           │                        │            │ sequential_45[0][0]        │
│ multi_head_attention_42   │ (None, 16, 32)         │     16,800 │ add_98[0][0], add_98[0][0] │
│ (MultiHeadAttention)      │                        │            │                            │
│ add_99 (Add)              │ (None, 16, 32)         │          0 │ add_98[0][0],              │
│                           │                        │            │ multi_head_attention_42[0… │
│ sequential_46             │ (None, 16, 32)         │      2,826 │ add_99[0][0]               │
│ (Sequential)              │                        │            │                            │
│ add_100 (Add)             │ (None, 16, 32)         │          0 │ add_99[0][0],              │
│                           │                        │            │ sequential_46[0][0]        │
│ multi_head_attention_43   │ (None, 16, 32)         │     16,800 │ add_100[0][0],             │
│ (MultiHeadAttention)      │                        │            │ add_100[0][0]              │
│ add_101 (Add)             │ (None, 16, 32)         │          0 │ add_100[0][0],             │
│                           │                        │            │ multi_head_attention_43[0… │
│ sequential_47             │ (None, 16, 32)         │      2,826 │ add_101[0][0]              │
│ (Sequential)              │                        │            │                            │
│ add_102 (Add)             │ (None, 16, 32)         │          0 │ add_101[0][0],             │
│                           │                        │            │ sequential_47[0][0]        │
│ multi_head_attention_44   │ (None, 16, 32)         │     16,800 │ add_102[0][0],             │
│ (MultiHeadAttention)      │                        │            │ add_102[0][0]              │
│ add_103 (Add)             │ (None, 16, 32)         │          0 │ add_102[0][0],             │
│                           │                        │            │ multi_head_attention_44[0… │
│ sequential_48             │ (None, 16, 32)         │      2,826 │ add_103[0][0]              │
│ (Sequential)              │                        │            │                            │
│ add_104 (Add)             │ (None, 16, 32)         │          0 │ add_103[0][0],             │
│                           │                        │            │ sequential_48[0][0]        │
│ layer_normalization_113   │ (None, 16, 32)         │         64 │ add_104[0][0]              │
│ (LayerNormalization)      │                        │            │                            │
│ avg_pool                  │ (None, 32)             │          0 │ layer_normalization_113[0… │
│ (GlobalAveragePooling1D)  │                        │            │                            │
│ dense_163 (Dense)         │ (None, 10)             │        330 │ avg_pool[0][0]             │
 Total params: 80,660 (315.08 KB)
 Trainable params: 80,660 (315.08 KB)
 Non-trainable params: 0 (0.00 B)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 16s 39ms/step - accuracy: 0.4712 - loss: 1.5234 - val_accuracy: 0.8593 - val_loss: 0.4529
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 17s 39ms/step - accuracy: 0.8093 - loss: 0.5927 - val_accuracy: 0.9195 - val_loss: 0.2688
Test loss: 0.30448290705680847
Test accuracy: 0.9017000198364258