Examples 2D
!pip install k3im --upgrade
Requirement already satisfied: k3im in /usr/local/lib/python3.10/dist-packages (0.0.6)
Requirement already satisfied: keras>=3.0.0 in /usr/local/lib/python3.10/dist-packages (from k3im) (3.0.1)
Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from keras>=3.0.0->k3im) (1.4.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from keras>=3.0.0->k3im) (1.23.5)
Requirement already satisfied: rich in /usr/local/lib/python3.10/dist-packages (from keras>=3.0.0->k3im) (13.7.0)
Requirement already satisfied: namex in /usr/local/lib/python3.10/dist-packages (from keras>=3.0.0->k3im) (0.0.7)
Requirement already satisfied: h5py in /usr/local/lib/python3.10/dist-packages (from keras>=3.0.0->k3im) (3.9.0)
Requirement already satisfied: dm-tree in /usr/local/lib/python3.10/dist-packages (from keras>=3.0.0->k3im) (0.1.8)
Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras>=3.0.0->k3im) (3.0.0)
Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras>=3.0.0->k3im) (2.16.1)
Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich->keras>=3.0.0->k3im) (0.1.2)
import os
os.environ['KERAS_BACKEND'] = 'jax'
import keras
import numpy as np
# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")
# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples
batch_size = 128
epochs = 2
def train_model(model):
model.compile(loss=keras.losses.CategoricalCrossentropy(from_logits=True), optimizer="adam", metrics=["accuracy"])
model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)
score = model.evaluate(x_test, y_test, verbose=0)
print("Test loss:", score[0])
print("Test accuracy:", score[1])
from k3im.cait import CaiTModel # jax ✅, tensorflow ✅, torch ✅
model = CaiTModel(
image_size=(28, 28),
patch_size=(7, 7),
num_classes=10,
dim=32,
depth=2,
heads=8,
mlp_dim=64,
cls_depth=2,
channels=1,
dim_head=64,
)
model.summary()
Model: "functional_5"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ input_layer (InputLayer) │ (None, 28, 28, 1) │ 0 │ - │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ extract_patches │ (None, 4, 4, 49) │ 0 │ input_layer[0][0] │ │ (ExtractPatches) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape (Reshape) │ (None, 16, 49) │ 0 │ extract_patches[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization │ (None, 16, 49) │ 98 │ reshape[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense (Dense) │ (None, 16, 32) │ 1,600 │ layer_normalization[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_1 │ (None, 16, 32) │ 64 │ dense[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ position_emb │ (None, 16, 32) │ 512 │ layer_normalization_1[0][… │ │ (PositionEmb) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention │ (None, 16, 32) │ 67,104 │ position_emb[0][0], │ │ (MultiHeadAttention) │ │ │ position_emb[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add (Add) │ (None, 16, 32) │ 0 │ position_emb[0][0], │ │ │ │ │ multi_head_attention[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential (Sequential) │ (None, 16, 32) │ 4,256 │ add[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_1 (Add) │ (None, 16, 32) │ 0 │ add[0][0], │ │ │ │ │ sequential[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_1 │ (None, 16, 32) │ 67,104 │ add_1[0][0], add_1[0][0] │ │ (MultiHeadAttention) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_2 (Add) │ (None, 16, 32) │ 0 │ add_1[0][0], │ │ │ │ │ multi_head_attention_1[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_1 (Sequential) │ (None, 16, 32) │ 4,256 │ add_2[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_3 (Add) │ (None, 16, 32) │ 0 │ add_2[0][0], │ │ │ │ │ sequential_1[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_4 │ (None, 16, 32) │ 64 │ add_3[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ cls__token (CLS_Token) │ [(None, 17, 32), │ 32 │ layer_normalization_4[0][… │ │ │ (None, 1, 32)] │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ concatenate (Concatenate) │ (None, 17, 32) │ 0 │ cls__token[0][1], │ │ │ │ │ layer_normalization_4[0][… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_2 │ (None, 1, 32) │ 67,104 │ cls__token[0][1], │ │ (MultiHeadAttention) │ │ │ concatenate[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_4 (Add) │ (None, 1, 32) │ 0 │ cls__token[0][1], │ │ │ │ │ multi_head_attention_2[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_2 (Sequential) │ (None, 1, 32) │ 4,256 │ add_4[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_5 (Add) │ (None, 1, 32) │ 0 │ add_4[0][0], │ │ │ │ │ sequential_2[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ concatenate_1 │ (None, 17, 32) │ 0 │ add_5[0][0], │ │ (Concatenate) │ │ │ layer_normalization_4[0][… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_3 │ (None, 1, 32) │ 67,104 │ add_5[0][0], │ │ (MultiHeadAttention) │ │ │ concatenate_1[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_6 (Add) │ (None, 1, 32) │ 0 │ add_5[0][0], │ │ │ │ │ multi_head_attention_3[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_3 (Sequential) │ (None, 1, 32) │ 4,256 │ add_6[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_7 (Add) │ (None, 1, 32) │ 0 │ add_6[0][0], │ │ │ │ │ sequential_3[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_7 │ (None, 1, 32) │ 64 │ add_7[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ squeeze (Squeeze) │ (None, 32) │ 0 │ layer_normalization_7[0][… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_9 (Dense) │ (None, 10) │ 330 │ squeeze[0][0] │ └───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
Total params: 288,204 (1.10 MB)
Trainable params: 288,204 (1.10 MB)
Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 22s 52ms/step - accuracy: 0.5326 - loss: 1.3825 - val_accuracy: 0.9125 - val_loss: 0.2950
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 19s 45ms/step - accuracy: 0.9129 - loss: 0.2858 - val_accuracy: 0.9593 - val_loss: 0.1494
Test loss: 0.16203546524047852
Test accuracy: 0.9521999955177307
from k3im.cct import CCT # jax ✅, tensorflow ✅, torch ✅
model = CCT(
input_shape=input_shape,
num_heads=8,
projection_dim=32,
kernel_size=3,
stride=3,
padding=2,
transformer_units=[16, 32],
stochastic_depth_rate=0.6,
transformer_layers=2,
num_classes=num_classes,
positional_emb=False,
)
train_model(model)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 14s 33ms/step - accuracy: 0.6643 - loss: 1.0071 - val_accuracy: 0.9318 - val_loss: 0.2200
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 13s 32ms/step - accuracy: 0.9292 - loss: 0.2308 - val_accuracy: 0.9532 - val_loss: 0.1575
Test loss: 0.1650298684835434
Test accuracy: 0.947700023651123
from k3im.convmixer import ConvMixer # jax ✅ # tf something not right # something aint right with torch
model = ConvMixer(
image_size=28, filters=64, depth=8, kernel_size=3, patch_size=2, num_classes=10, num_channels=1
)
train_model(model)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 25s 59ms/step - accuracy: 0.6054 - loss: 1.3247 - val_accuracy: 0.1113 - val_loss: 9747.7539
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 30s 70ms/step - accuracy: 0.8339 - loss: 0.5841 - val_accuracy: 0.1113 - val_loss: 1693.4771
Test loss: 1714.0731201171875
Test accuracy: 0.10279999673366547
from k3im.cross_vit import CrossViT # jax ✅, tensorflow ✅, torch ✅
model = CrossViT(
image_size=28,
num_classes=10,
sm_dim=32,
lg_dim=42,
channels=1,
sm_patch_size=4,
sm_enc_depth=1,
sm_enc_heads=8,
sm_enc_mlp_dim=48,
sm_enc_dim_head=56,
lg_patch_size=7,
lg_enc_depth=2,
lg_enc_heads=8,
lg_enc_mlp_dim=84,
lg_enc_dim_head=72,
cross_attn_depth=2,
cross_attn_heads=8,
cross_attn_dim_head=64,
depth=3,
dropout=0.1,
emb_dropout=0.1
)
model.summary()
Model: "functional_21"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ input_layer_8 │ (None, 28, 28, 1) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ extract_patches_2 │ (None, 4, 4, 49) │ 0 │ input_layer_8[0][0] │ │ (ExtractPatches) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_2 (Reshape) │ (None, 16, 49) │ 0 │ extract_patches_2[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_15 │ (None, 16, 49) │ 98 │ reshape_2[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_18 (Dense) │ (None, 16, 42) │ 2,100 │ layer_normalization_15[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_16 │ (None, 16, 42) │ 84 │ dense_18[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ extract_patches_1 │ (None, 7, 7, 16) │ 0 │ input_layer_8[0][0] │ │ (ExtractPatches) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ cls__token_2 (CLS_Token) │ [(None, 17, 42), │ 42 │ layer_normalization_16[0]… │ │ │ (None, 1, 42)] │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_1 (Reshape) │ (None, 49, 16) │ 0 │ extract_patches_1[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ position_emb_2 │ (None, 17, 42) │ 714 │ cls__token_2[0][0] │ │ (PositionEmb) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_13 │ (None, 49, 16) │ 32 │ reshape_1[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_11 (Dropout) │ (None, 17, 42) │ 0 │ position_emb_2[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_17 (Dense) │ (None, 49, 32) │ 544 │ layer_normalization_13[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_7 │ (None, 17, 42) │ 98,538 │ dropout_11[0][0], │ │ (MultiHeadAttention) │ │ │ dropout_11[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_14 │ (None, 49, 32) │ 64 │ dense_17[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_22 (Add) │ (None, 17, 42) │ 0 │ dropout_11[0][0], │ │ │ │ │ multi_head_attention_7[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ cls__token_1 (CLS_Token) │ [(None, 50, 32), │ 32 │ layer_normalization_14[0]… │ │ │ (None, 1, 32)] │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_6 (Sequential) │ (None, 17, 42) │ 7,266 │ add_22[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ position_emb_1 │ (None, 50, 32) │ 1,600 │ cls__token_1[0][0] │ │ (PositionEmb) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_23 (Add) │ (None, 17, 42) │ 0 │ add_22[0][0], │ │ │ │ │ sequential_6[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_10 (Dropout) │ (None, 50, 32) │ 0 │ position_emb_1[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_8 │ (None, 17, 42) │ 98,538 │ add_23[0][0], add_23[0][0] │ │ (MultiHeadAttention) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_6 │ (None, 50, 32) │ 58,720 │ dropout_10[0][0], │ │ (MultiHeadAttention) │ │ │ dropout_10[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_24 (Add) │ (None, 17, 42) │ 0 │ add_23[0][0], │ │ │ │ │ multi_head_attention_8[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_20 (Add) │ (None, 50, 32) │ 0 │ dropout_10[0][0], │ │ │ │ │ multi_head_attention_6[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_7 (Sequential) │ (None, 17, 42) │ 7,266 │ add_24[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_5 (Sequential) │ (None, 50, 32) │ 3,216 │ add_20[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_25 (Add) │ (None, 17, 42) │ 0 │ add_24[0][0], │ │ │ │ │ sequential_7[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_21 (Add) │ (None, 50, 32) │ 0 │ add_20[0][0], │ │ │ │ │ sequential_5[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_21 │ (None, 17, 42) │ 84 │ add_25[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_18 │ (None, 50, 32) │ 64 │ add_21[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_16 (Dropout) │ (None, 17, 42) │ 0 │ layer_normalization_21[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_13 (Dropout) │ (None, 50, 32) │ 0 │ layer_normalization_18[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ get_item_2 (GetItem) │ (None, 1, 42) │ 0 │ dropout_16[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ get_item_1 (GetItem) │ (None, 49, 32) │ 0 │ dropout_13[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_27 (Dense) │ (None, 1, 32) │ 1,376 │ get_item_2[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ concatenate_4 │ (None, 50, 32) │ 0 │ dense_27[0][0], │ │ (Concatenate) │ │ │ get_item_1[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_11 │ (None, 1, 32) │ 67,104 │ dense_27[0][0], │ │ (MultiHeadAttention) │ │ │ concatenate_4[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_29 (Add) │ (None, 1, 32) │ 0 │ dense_27[0][0], │ │ │ │ │ multi_head_attention_11[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ get_item (GetItem) │ (None, 1, 32) │ 0 │ dropout_13[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ concatenate_5 │ (None, 50, 32) │ 0 │ add_29[0][0], │ │ (Concatenate) │ │ │ get_item_1[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_25 (Dense) │ (None, 1, 42) │ 1,386 │ get_item[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ get_item_3 (GetItem) │ (None, 16, 42) │ 0 │ dropout_16[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_12 │ (None, 1, 32) │ 67,104 │ add_29[0][0], │ │ (MultiHeadAttention) │ │ │ concatenate_5[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ concatenate_2 │ (None, 17, 42) │ 0 │ dense_25[0][0], │ │ (Concatenate) │ │ │ get_item_3[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_30 (Add) │ (None, 1, 32) │ 0 │ add_29[0][0], │ │ │ │ │ multi_head_attention_12[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_9 │ (None, 1, 42) │ 87,594 │ dense_25[0][0], │ │ (MultiHeadAttention) │ │ │ concatenate_2[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_23 │ (None, 1, 32) │ 64 │ add_30[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_26 (Add) │ (None, 1, 42) │ 0 │ dense_25[0][0], │ │ │ │ │ multi_head_attention_9[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_22 (Dropout) │ (None, 1, 32) │ 0 │ layer_normalization_23[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ concatenate_3 │ (None, 17, 42) │ 0 │ add_26[0][0], │ │ (Concatenate) │ │ │ get_item_3[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_28 (Dense) │ (None, 1, 42) │ 1,386 │ dropout_22[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_10 │ (None, 1, 42) │ 87,594 │ add_26[0][0], │ │ (MultiHeadAttention) │ │ │ concatenate_3[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_31 (Add) │ (None, 1, 42) │ 0 │ dense_28[0][0], │ │ │ │ │ get_item_2[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_27 (Add) │ (None, 1, 42) │ 0 │ add_26[0][0], │ │ │ │ │ multi_head_attention_10[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ concatenate_7 │ (None, 17, 42) │ 0 │ add_31[0][0], │ │ (Concatenate) │ │ │ get_item_3[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_22 │ (None, 1, 42) │ 84 │ add_27[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_14 │ (None, 17, 42) │ 98,538 │ concatenate_7[0][0], │ │ (MultiHeadAttention) │ │ │ concatenate_7[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_19 (Dropout) │ (None, 1, 42) │ 0 │ layer_normalization_22[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_34 (Add) │ (None, 17, 42) │ 0 │ concatenate_7[0][0], │ │ │ │ │ multi_head_attention_14[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_26 (Dense) │ (None, 1, 32) │ 1,376 │ dropout_19[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_9 (Sequential) │ (None, 17, 42) │ 7,266 │ add_34[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_28 (Add) │ (None, 1, 32) │ 0 │ dense_26[0][0], │ │ │ │ │ get_item[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_35 (Add) │ (None, 17, 42) │ 0 │ add_34[0][0], │ │ │ │ │ sequential_9[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ concatenate_6 │ (None, 50, 32) │ 0 │ add_28[0][0], │ │ (Concatenate) │ │ │ get_item_1[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_15 │ (None, 17, 42) │ 98,538 │ add_35[0][0], add_35[0][0] │ │ (MultiHeadAttention) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_13 │ (None, 50, 32) │ 58,720 │ concatenate_6[0][0], │ │ (MultiHeadAttention) │ │ │ concatenate_6[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_36 (Add) │ (None, 17, 42) │ 0 │ add_35[0][0], │ │ │ │ │ multi_head_attention_15[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_32 (Add) │ (None, 50, 32) │ 0 │ concatenate_6[0][0], │ │ │ │ │ multi_head_attention_13[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_10 │ (None, 17, 42) │ 7,266 │ add_36[0][0] │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_8 (Sequential) │ (None, 50, 32) │ 3,216 │ add_32[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_37 (Add) │ (None, 17, 42) │ 0 │ add_36[0][0], │ │ │ │ │ sequential_10[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_33 (Add) │ (None, 50, 32) │ 0 │ add_32[0][0], │ │ │ │ │ sequential_8[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_28 │ (None, 17, 42) │ 84 │ add_37[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_25 │ (None, 50, 32) │ 64 │ add_33[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_27 (Dropout) │ (None, 17, 42) │ 0 │ layer_normalization_28[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_24 (Dropout) │ (None, 50, 32) │ 0 │ layer_normalization_25[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ get_item_6 (GetItem) │ (None, 1, 42) │ 0 │ dropout_27[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ get_item_5 (GetItem) │ (None, 49, 32) │ 0 │ dropout_24[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_37 (Dense) │ (None, 1, 32) │ 1,376 │ get_item_6[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ concatenate_10 │ (None, 50, 32) │ 0 │ dense_37[0][0], │ │ (Concatenate) │ │ │ get_item_5[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_18 │ (None, 1, 32) │ 67,104 │ dense_37[0][0], │ │ (MultiHeadAttention) │ │ │ concatenate_10[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ get_item_4 (GetItem) │ (None, 1, 32) │ 0 │ dropout_24[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_41 (Add) │ (None, 1, 32) │ 0 │ dense_37[0][0], │ │ │ │ │ multi_head_attention_18[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_35 (Dense) │ (None, 1, 42) │ 1,386 │ get_item_4[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ get_item_7 (GetItem) │ (None, 16, 42) │ 0 │ dropout_27[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ concatenate_11 │ (None, 50, 32) │ 0 │ add_41[0][0], │ │ (Concatenate) │ │ │ get_item_5[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ concatenate_8 │ (None, 17, 42) │ 0 │ dense_35[0][0], │ │ (Concatenate) │ │ │ get_item_7[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_19 │ (None, 1, 32) │ 67,104 │ add_41[0][0], │ │ (MultiHeadAttention) │ │ │ concatenate_11[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_16 │ (None, 1, 42) │ 87,594 │ dense_35[0][0], │ │ (MultiHeadAttention) │ │ │ concatenate_8[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_42 (Add) │ (None, 1, 32) │ 0 │ add_41[0][0], │ │ │ │ │ multi_head_attention_19[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_38 (Add) │ (None, 1, 42) │ 0 │ dense_35[0][0], │ │ │ │ │ multi_head_attention_16[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_30 │ (None, 1, 32) │ 64 │ add_42[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ concatenate_9 │ (None, 17, 42) │ 0 │ add_38[0][0], │ │ (Concatenate) │ │ │ get_item_7[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_33 (Dropout) │ (None, 1, 32) │ 0 │ layer_normalization_30[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_17 │ (None, 1, 42) │ 87,594 │ add_38[0][0], │ │ (MultiHeadAttention) │ │ │ concatenate_9[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_38 (Dense) │ (None, 1, 42) │ 1,386 │ dropout_33[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_39 (Add) │ (None, 1, 42) │ 0 │ add_38[0][0], │ │ │ │ │ multi_head_attention_17[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_43 (Add) │ (None, 1, 42) │ 0 │ dense_38[0][0], │ │ │ │ │ get_item_6[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_29 │ (None, 1, 42) │ 84 │ add_39[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ concatenate_13 │ (None, 17, 42) │ 0 │ add_43[0][0], │ │ (Concatenate) │ │ │ get_item_7[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_30 (Dropout) │ (None, 1, 42) │ 0 │ layer_normalization_29[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_21 │ (None, 17, 42) │ 98,538 │ concatenate_13[0][0], │ │ (MultiHeadAttention) │ │ │ concatenate_13[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_36 (Dense) │ (None, 1, 32) │ 1,376 │ dropout_30[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_46 (Add) │ (None, 17, 42) │ 0 │ concatenate_13[0][0], │ │ │ │ │ multi_head_attention_21[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_40 (Add) │ (None, 1, 32) │ 0 │ dense_36[0][0], │ │ │ │ │ get_item_4[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_12 │ (None, 17, 42) │ 7,266 │ add_46[0][0] │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ concatenate_12 │ (None, 50, 32) │ 0 │ add_40[0][0], │ │ (Concatenate) │ │ │ get_item_5[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_47 (Add) │ (None, 17, 42) │ 0 │ add_46[0][0], │ │ │ │ │ sequential_12[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_20 │ (None, 50, 32) │ 58,720 │ concatenate_12[0][0], │ │ (MultiHeadAttention) │ │ │ concatenate_12[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_22 │ (None, 17, 42) │ 98,538 │ add_47[0][0], add_47[0][0] │ │ (MultiHeadAttention) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_44 (Add) │ (None, 50, 32) │ 0 │ concatenate_12[0][0], │ │ │ │ │ multi_head_attention_20[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_48 (Add) │ (None, 17, 42) │ 0 │ add_47[0][0], │ │ │ │ │ multi_head_attention_22[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_11 │ (None, 50, 32) │ 3,216 │ add_44[0][0] │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_13 │ (None, 17, 42) │ 7,266 │ add_48[0][0] │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_45 (Add) │ (None, 50, 32) │ 0 │ add_44[0][0], │ │ │ │ │ sequential_11[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_49 (Add) │ (None, 17, 42) │ 0 │ add_48[0][0], │ │ │ │ │ sequential_13[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_32 │ (None, 50, 32) │ 64 │ add_45[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_35 │ (None, 17, 42) │ 84 │ add_49[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_35 (Dropout) │ (None, 50, 32) │ 0 │ layer_normalization_32[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_38 (Dropout) │ (None, 17, 42) │ 0 │ layer_normalization_35[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ get_item_8 (GetItem) │ (None, 1, 32) │ 0 │ dropout_35[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ get_item_10 (GetItem) │ (None, 1, 42) │ 0 │ dropout_38[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_45 (Dense) │ (None, 1, 42) │ 1,386 │ get_item_8[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ get_item_11 (GetItem) │ (None, 16, 42) │ 0 │ dropout_38[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ get_item_9 (GetItem) │ (None, 49, 32) │ 0 │ dropout_35[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_47 (Dense) │ (None, 1, 32) │ 1,376 │ get_item_10[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ concatenate_14 │ (None, 17, 42) │ 0 │ dense_45[0][0], │ │ (Concatenate) │ │ │ get_item_11[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ concatenate_16 │ (None, 50, 32) │ 0 │ dense_47[0][0], │ │ (Concatenate) │ │ │ get_item_9[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_23 │ (None, 1, 42) │ 87,594 │ dense_45[0][0], │ │ (MultiHeadAttention) │ │ │ concatenate_14[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_25 │ (None, 1, 32) │ 67,104 │ dense_47[0][0], │ │ (MultiHeadAttention) │ │ │ concatenate_16[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_50 (Add) │ (None, 1, 42) │ 0 │ dense_45[0][0], │ │ │ │ │ multi_head_attention_23[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_53 (Add) │ (None, 1, 32) │ 0 │ dense_47[0][0], │ │ │ │ │ multi_head_attention_25[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ concatenate_15 │ (None, 17, 42) │ 0 │ add_50[0][0], │ │ (Concatenate) │ │ │ get_item_11[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ concatenate_17 │ (None, 50, 32) │ 0 │ add_53[0][0], │ │ (Concatenate) │ │ │ get_item_9[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_24 │ (None, 1, 42) │ 87,594 │ add_50[0][0], │ │ (MultiHeadAttention) │ │ │ concatenate_15[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_26 │ (None, 1, 32) │ 67,104 │ add_53[0][0], │ │ (MultiHeadAttention) │ │ │ concatenate_17[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_51 (Add) │ (None, 1, 42) │ 0 │ add_50[0][0], │ │ │ │ │ multi_head_attention_24[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_54 (Add) │ (None, 1, 32) │ 0 │ add_53[0][0], │ │ │ │ │ multi_head_attention_26[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_36 │ (None, 1, 42) │ 84 │ add_51[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_37 │ (None, 1, 32) │ 64 │ add_54[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_41 (Dropout) │ (None, 1, 42) │ 0 │ layer_normalization_36[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_44 (Dropout) │ (None, 1, 32) │ 0 │ layer_normalization_37[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_46 (Dense) │ (None, 1, 32) │ 1,376 │ dropout_41[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_48 (Dense) │ (None, 1, 42) │ 1,386 │ dropout_44[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_52 (Add) │ (None, 1, 32) │ 0 │ dense_46[0][0], │ │ │ │ │ get_item_8[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_55 (Add) │ (None, 1, 42) │ 0 │ dense_48[0][0], │ │ │ │ │ get_item_10[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ concatenate_18 │ (None, 50, 32) │ 0 │ add_52[0][0], │ │ (Concatenate) │ │ │ get_item_9[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ concatenate_19 │ (None, 17, 42) │ 0 │ add_55[0][0], │ │ (Concatenate) │ │ │ get_item_11[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ get_item_12 (GetItem) │ (None, 32) │ 0 │ concatenate_18[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ get_item_13 (GetItem) │ (None, 42) │ 0 │ concatenate_19[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_38 │ (None, 32) │ 64 │ get_item_12[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_39 │ (None, 42) │ 84 │ get_item_13[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_49 (Dense) │ (None, 10) │ 330 │ layer_normalization_38[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_50 (Dense) │ (None, 10) │ 430 │ layer_normalization_39[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_56 (Add) │ (None, 10) │ 0 │ dense_49[0][0], │ │ │ │ │ dense_50[0][0] │ └───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
Total params: 1,772,498 (6.76 MB)
Trainable params: 1,772,498 (6.76 MB)
Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 76s 179ms/step - accuracy: 0.5782 - loss: 1.2357 - val_accuracy: 0.9185 - val_loss: 0.2657
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 73s 173ms/step - accuracy: 0.8943 - loss: 0.3341 - val_accuracy: 0.9500 - val_loss: 0.1754
Test loss: 0.18887896835803986
Test accuracy: 0.9437000155448914
from k3im.deepvit import DeepViT # jax ✅, tensorflow ✅
model = DeepViT(image_size=28,
patch_size=7,
num_classes=10,
dim=64,
depth=2,
heads=8,
mlp_dim=84,
pool="cls",
channels=1,
dim_head=64,
dropout=0.0,
emb_dropout=0.0)
train_model(model)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 13s 30ms/step - accuracy: 0.6203 - loss: 1.1394 - val_accuracy: 0.9172 - val_loss: 0.2594
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 13s 30ms/step - accuracy: 0.9250 - loss: 0.2398 - val_accuracy: 0.9567 - val_loss: 0.1421
Test loss: 0.16465066373348236
Test accuracy: 0.9492999911308289
from k3im.eanet import EANet # jax ✅, tensorflow ✅, torch ✅
model = EANet(
input_shape=input_shape,
patch_size=7,
embedding_dim=64,
num_transformer_blocks=2,
mlp_dim=32,
num_heads=16,
dim_coefficient=2,
attention_dropout=0.5,
projection_dropout=0.5,
num_classes=10,
)
model.summary()
Model: "functional_27"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ input_layer_21 │ (None, 28, 28, 1) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ patch_extract │ (None, 16, 49) │ 0 │ input_layer_21[0][0] │ │ (PatchExtract) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ patch_embedding │ (None, 16, 64) │ 4,224 │ patch_extract[0][0] │ │ (PatchEmbedding) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_46 │ (None, 16, 64) │ 128 │ patch_embedding[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_58 (Dense) │ (None, 16, 128) │ 8,320 │ layer_normalization_46[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_4 (Reshape) │ (None, 16, 32, 4) │ 0 │ dense_58[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ transpose (Transpose) │ (None, 32, 16, 4) │ 0 │ reshape_4[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_59 (Dense) │ (None, 32, 16, 32) │ 160 │ transpose[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ softmax_29 (Softmax) │ (None, 32, 16, 32) │ 0 │ dense_59[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ lambda (Lambda) │ (None, 32, 16, 32) │ 0 │ softmax_29[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_52 (Dropout) │ (None, 32, 16, 32) │ 0 │ lambda[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_60 (Dense) │ (None, 32, 16, 4) │ 132 │ dropout_52[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ transpose_1 (Transpose) │ (None, 16, 32, 4) │ 0 │ dense_60[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_5 (Reshape) │ (None, 16, 128) │ 0 │ transpose_1[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_61 (Dense) │ (None, 16, 64) │ 8,256 │ reshape_5[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_53 (Dropout) │ (None, 16, 64) │ 0 │ dense_61[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_61 (Add) │ (None, 16, 64) │ 0 │ dropout_53[0][0], │ │ │ │ │ patch_embedding[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_47 │ (None, 16, 64) │ 128 │ add_61[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_62 (Dense) │ (None, 16, 32) │ 2,080 │ layer_normalization_47[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_54 (Dropout) │ (None, 16, 32) │ 0 │ dense_62[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_63 (Dense) │ (None, 16, 64) │ 2,112 │ dropout_54[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_55 (Dropout) │ (None, 16, 64) │ 0 │ dense_63[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_62 (Add) │ (None, 16, 64) │ 0 │ dropout_55[0][0], │ │ │ │ │ add_61[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_48 │ (None, 16, 64) │ 128 │ add_62[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_64 (Dense) │ (None, 16, 128) │ 8,320 │ layer_normalization_48[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_6 (Reshape) │ (None, 16, 32, 4) │ 0 │ dense_64[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ transpose_2 (Transpose) │ (None, 32, 16, 4) │ 0 │ reshape_6[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_65 (Dense) │ (None, 32, 16, 32) │ 160 │ transpose_2[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ softmax_30 (Softmax) │ (None, 32, 16, 32) │ 0 │ dense_65[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ lambda_1 (Lambda) │ (None, 32, 16, 32) │ 0 │ softmax_30[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_56 (Dropout) │ (None, 32, 16, 32) │ 0 │ lambda_1[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_66 (Dense) │ (None, 32, 16, 4) │ 132 │ dropout_56[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ transpose_3 (Transpose) │ (None, 16, 32, 4) │ 0 │ dense_66[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_7 (Reshape) │ (None, 16, 128) │ 0 │ transpose_3[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_67 (Dense) │ (None, 16, 64) │ 8,256 │ reshape_7[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_57 (Dropout) │ (None, 16, 64) │ 0 │ dense_67[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_63 (Add) │ (None, 16, 64) │ 0 │ dropout_57[0][0], │ │ │ │ │ add_62[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_49 │ (None, 16, 64) │ 128 │ add_63[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_68 (Dense) │ (None, 16, 32) │ 2,080 │ layer_normalization_49[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_58 (Dropout) │ (None, 16, 32) │ 0 │ dense_68[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_69 (Dense) │ (None, 16, 64) │ 2,112 │ dropout_58[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_59 (Dropout) │ (None, 16, 64) │ 0 │ dense_69[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_64 (Add) │ (None, 16, 64) │ 0 │ dropout_59[0][0], │ │ │ │ │ add_63[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ global_average_pooling1d │ (None, 64) │ 0 │ add_64[0][0] │ │ (GlobalAveragePooling1D) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_70 (Dense) │ (None, 10) │ 650 │ global_average_pooling1d[… │ └───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
Total params: 47,506 (185.57 KB)
Trainable params: 47,506 (185.57 KB)
Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 11s 26ms/step - accuracy: 0.3873 - loss: 1.7322 - val_accuracy: 0.8285 - val_loss: 0.5884
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 11s 26ms/step - accuracy: 0.7580 - loss: 0.7342 - val_accuracy: 0.8965 - val_loss: 0.3452
Test loss: 0.4003435969352722
Test accuracy: 0.8769000172615051
from k3im.fnet import FNetModel # jax ✅, tensorflow ✅, torch ✅
model = FNetModel(
image_size=28,
patch_size=7,
embedding_dim=64,
num_blocks=2,
dropout_rate=0.4,
num_classes=10,
positional_encoding=False,
num_channels=1,
)
model.summary()
Model: "functional_31"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩ │ input_layer_22 (InputLayer) │ (None, 28, 28, 1) │ 0 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ patches (Patches) │ (None, 16, 49) │ 0 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ dense_71 (Dense) │ (None, 16, 64) │ 3,200 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ f_net_layer (FNetLayer) │ (None, 16, 64) │ 8,576 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ f_net_layer_1 (FNetLayer) │ (None, 16, 64) │ 8,576 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ global_average_pooling1d_1 │ (None, 64) │ 0 │ │ (GlobalAveragePooling1D) │ │ │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ dropout_62 (Dropout) │ (None, 64) │ 0 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ dense_76 (Dense) │ (None, 10) │ 650 │ └────────────────────────────────────┴───────────────────────────────┴─────────────┘
Total params: 21,002 (82.04 KB)
Trainable params: 21,002 (82.04 KB)
Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 7s 16ms/step - accuracy: 0.3982 - loss: 1.7336 - val_accuracy: 0.8240 - val_loss: 0.5880
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 6s 15ms/step - accuracy: 0.7478 - loss: 0.7953 - val_accuracy: 0.8858 - val_loss: 0.3823
Test loss: 0.43270713090896606
Test accuracy: 0.8708000183105469
from k3im.focalnet import focalnet_kid # jax ✅, tensorflow ✅, torch ✅
model = focalnet_kid(img_size=28, in_channels=1, num_classes=10)
model.summary()
Model: "functional_33"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ input_layer_25 │ (None, 28, 28, 1) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ patch_embed.proj (Conv2D) │ (None, 7, 7, 96) │ 1,632 │ input_layer_25[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_8 (Reshape) │ (None, 49, 96) │ 0 │ patch_embed.proj[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ patch_embed.norm │ (None, 49, 96) │ 192 │ reshape_8[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_63 (Dropout) │ (None, 49, 96) │ 0 │ patch_embed.norm[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layers.0.blocks.0.norm1 │ (None, 49, 96) │ 192 │ dropout_63[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_9 (Reshape) │ (None, 7, 7, 96) │ 0 │ layers.0.blocks.0.norm1[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layers.0.blocks.0.modula… │ (None, 7, 7, 96) │ 40,803 │ reshape_9[0][0] │ │ (FocalModulation) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_10 (Reshape) │ (None, 49, 96) │ 0 │ layers.0.blocks.0.modulat… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ stochastic_depth_4 │ (None, 49, 96) │ 0 │ reshape_10[0][0] │ │ (StochasticDepth) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_65 (Add) │ (None, 49, 96) │ 0 │ dropout_63[0][0], │ │ │ │ │ stochastic_depth_4[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_11 (Reshape) │ (None, 7, 7, 96) │ 0 │ add_65[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layers.0.blocks.0.norm2 │ (None, 7, 7, 96) │ 192 │ reshape_11[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layers.0.blocks.0.mlp.fc1 │ (None, 7, 7, 384) │ 37,248 │ layers.0.blocks.0.norm2[0… │ │ (Dense) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_65 (Dropout) │ (None, 7, 7, 384) │ 0 │ layers.0.blocks.0.mlp.fc1… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layers.0.blocks.0.mlp.fc2 │ (None, 7, 7, 96) │ 36,960 │ dropout_65[0][0] │ │ (Dense) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_66 (Dropout) │ (None, 7, 7, 96) │ 0 │ layers.0.blocks.0.mlp.fc2… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ stochastic_depth_5 │ (None, 7, 7, 96) │ 0 │ dropout_66[0][0] │ │ (StochasticDepth) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_66 (Add) │ (None, 7, 7, 96) │ 0 │ stochastic_depth_5[0][0], │ │ │ │ │ reshape_11[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_12 (Reshape) │ (None, 49, 96) │ 0 │ add_66[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ norm (LayerNormalization) │ (None, 49, 96) │ 192 │ reshape_12[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ global_average_pooling1d… │ (None, 96) │ 0 │ norm[0][0] │ │ (GlobalAveragePooling1D) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ flatten (Flatten) │ (None, 96) │ 0 │ global_average_pooling1d_… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ head (Dense) │ (None, 10) │ 970 │ flatten[0][0] │ └───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
Total params: 118,381 (462.43 KB)
Trainable params: 118,381 (462.43 KB)
Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 10s 23ms/step - accuracy: 0.5703 - loss: 1.2913 - val_accuracy: 0.9475 - val_loss: 0.1807
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 10s 24ms/step - accuracy: 0.9437 - loss: 0.1907 - val_accuracy: 0.9697 - val_loss: 0.1060
Test loss: 0.1195862889289856
Test accuracy: 0.9613999724388123
from k3im.gmlp import gMLPModel # jax ✅, tensorflow ✅, torch ✅
model = gMLPModel(
image_size=28,
patch_size=7,
embedding_dim=32,
num_blocks=4,
dropout_rate=0.5,
num_classes=num_classes,
positional_encoding=False,
num_channels=1,
)
model.summary()
Model: "functional_39"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩ │ input_layer_26 (InputLayer) │ (None, 28, 28, 1) │ 0 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ patches_1 (Patches) │ (None, 16, 49) │ 0 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ dense_77 (Dense) │ (None, 16, 32) │ 1,600 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ g_mlp_layer (gMLPLayer) │ (None, 16, 32) │ 3,568 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ g_mlp_layer_1 (gMLPLayer) │ (None, 16, 32) │ 3,568 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ g_mlp_layer_2 (gMLPLayer) │ (None, 16, 32) │ 3,568 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ g_mlp_layer_3 (gMLPLayer) │ (None, 16, 32) │ 3,568 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ global_average_pooling1d_3 │ (None, 32) │ 0 │ │ (GlobalAveragePooling1D) │ │ │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ dropout_71 (Dropout) │ (None, 32) │ 0 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ dense_90 (Dense) │ (None, 10) │ 330 │ └────────────────────────────────────┴───────────────────────────────┴─────────────┘
Total params: 16,202 (63.29 KB)
Trainable params: 16,202 (63.29 KB)
Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 11s 26ms/step - accuracy: 0.2239 - loss: 2.1543 - val_accuracy: 0.7037 - val_loss: 0.8542
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 11s 26ms/step - accuracy: 0.5964 - loss: 1.1702 - val_accuracy: 0.8968 - val_loss: 0.3385
Test loss: 0.3988232910633087
Test accuracy: 0.8780999779701233
from k3im.mlp_mixer import MixerModel # jax ✅, tensorflow ✅, torch ✅
model = MixerModel(
image_size=28,
patch_size=7,
embedding_dim=32,
num_blocks=4,
dropout_rate=0.5,
num_classes=num_classes,
positional_encoding=True,
num_channels=1,
)
model.summary()
Model: "functional_49"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ input_layer_31 │ (None, 28, 28, 1) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ patches_2 (Patches) │ (None, 16, 49) │ 0 │ input_layer_31[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_91 (Dense) │ (None, 16, 32) │ 1,600 │ patches_2[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ position_embedding │ (None, 16, 32) │ 512 │ dense_91[0][0] │ │ (PositionEmbedding) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_67 (Add) │ (None, 16, 32) │ 0 │ dense_91[0][0], │ │ │ │ │ position_embedding[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ mlp_mixer_layer │ (None, 16, 32) │ 1,680 │ add_67[0][0] │ │ (MLPMixerLayer) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ mlp_mixer_layer_1 │ (None, 16, 32) │ 1,680 │ mlp_mixer_layer[0][0] │ │ (MLPMixerLayer) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ mlp_mixer_layer_2 │ (None, 16, 32) │ 1,680 │ mlp_mixer_layer_1[0][0] │ │ (MLPMixerLayer) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ mlp_mixer_layer_3 │ (None, 16, 32) │ 1,680 │ mlp_mixer_layer_2[0][0] │ │ (MLPMixerLayer) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ global_average_pooling1d… │ (None, 32) │ 0 │ mlp_mixer_layer_3[0][0] │ │ (GlobalAveragePooling1D) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_80 (Dropout) │ (None, 32) │ 0 │ global_average_pooling1d_… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_108 (Dense) │ (None, 10) │ 330 │ dropout_80[0][0] │ └───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
Total params: 9,162 (35.79 KB)
Trainable params: 9,162 (35.79 KB)
Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 13s 30ms/step - accuracy: 0.2378 - loss: 2.1815 - val_accuracy: 0.7837 - val_loss: 0.6431
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 13s 31ms/step - accuracy: 0.6868 - loss: 0.9226 - val_accuracy: 0.8920 - val_loss: 0.3507
Test loss: 0.4144611358642578
Test accuracy: 0.8709999918937683
from k3im.simple_vit import SimpleViT # jax ✅, tensorflow ✅, torch ✅
model = SimpleViT(
image_size=(28, 28),
patch_size=(7, 7),
num_classes=num_classes,
dim=32,
depth=2,
heads=8,
mlp_dim=65,
channels=1,
dim_head=32,
pool="mean",
)
model.summary()
Model: "functional_53"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ input_layer_40 │ (None, 28, 28, 1) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ extract_patches_4 │ (None, 4, 4, 49) │ 0 │ input_layer_40[0][0] │ │ (ExtractPatches) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_13 (Reshape) │ (None, 16, 49) │ 0 │ extract_patches_4[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_66 │ (None, 16, 49) │ 98 │ reshape_13[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_109 (Dense) │ (None, 16, 32) │ 1,600 │ layer_normalization_66[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_67 │ (None, 16, 32) │ 64 │ dense_109[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_68 (Add) │ (None, 16, 32) │ 0 │ layer_normalization_67[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_29 │ (None, 16, 32) │ 33,568 │ add_68[0][0], add_68[0][0] │ │ (MultiHeadAttention) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_69 (Add) │ (None, 16, 32) │ 0 │ add_68[0][0], │ │ │ │ │ multi_head_attention_29[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_30 │ (None, 16, 32) │ 4,321 │ add_69[0][0] │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_70 (Add) │ (None, 16, 32) │ 0 │ add_69[0][0], │ │ │ │ │ sequential_30[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_30 │ (None, 16, 32) │ 33,568 │ add_70[0][0], add_70[0][0] │ │ (MultiHeadAttention) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_71 (Add) │ (None, 16, 32) │ 0 │ add_70[0][0], │ │ │ │ │ multi_head_attention_30[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_31 │ (None, 16, 32) │ 4,321 │ add_71[0][0] │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_72 (Add) │ (None, 16, 32) │ 0 │ add_71[0][0], │ │ │ │ │ sequential_31[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_70 │ (None, 16, 32) │ 64 │ add_72[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ avg_pool │ (None, 32) │ 0 │ layer_normalization_70[0]… │ │ (GlobalAveragePooling1D) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_114 (Dense) │ (None, 10) │ 330 │ avg_pool[0][0] │ └───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
Total params: 77,934 (304.43 KB)
Trainable params: 77,934 (304.43 KB)
Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 10s 24ms/step - accuracy: 0.5580 - loss: 1.3136 - val_accuracy: 0.8928 - val_loss: 0.3520
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 10s 24ms/step - accuracy: 0.8881 - loss: 0.3604 - val_accuracy: 0.9362 - val_loss: 0.2134
Test loss: 0.2578793466091156
Test accuracy: 0.9205999970436096
from k3im.simple_vit_with_fft import SimpleViTFFT # jax ✅, tensorflow ✅, torch ✅
model = SimpleViTFFT(image_size=28, patch_size=7, freq_patch_size=7, num_classes=num_classes, dim=32, depth=2,
heads=8, mlp_dim=64, channels=1,
dim_head = 16)
model.summary()
Model: "functional_57"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ input_layer_43 │ (None, 28, 28, 1) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ lambda_2 (Lambda) │ (None, 28, 28, 2) │ 0 │ input_layer_43[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ extract_patches_5 │ (None, 4, 4, 49) │ 0 │ input_layer_43[0][0] │ │ (ExtractPatches) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ extract_patches_6 │ (None, 4, 4, 98) │ 0 │ lambda_2[0][0] │ │ (ExtractPatches) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_14 (Reshape) │ (None, 16, 49) │ 0 │ extract_patches_5[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_15 (Reshape) │ (None, 16, 98) │ 0 │ extract_patches_6[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_71 │ (None, 16, 49) │ 98 │ reshape_14[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_73 │ (None, 16, 98) │ 196 │ reshape_15[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_115 (Dense) │ (None, 16, 32) │ 1,600 │ layer_normalization_71[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_116 (Dense) │ (None, 16, 32) │ 3,168 │ layer_normalization_73[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_72 │ (None, 16, 32) │ 64 │ dense_115[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_74 │ (None, 16, 32) │ 64 │ dense_116[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_73 (Add) │ (None, 16, 32) │ 0 │ layer_normalization_72[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_74 (Add) │ (None, 16, 32) │ 0 │ layer_normalization_74[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ concatenate_20 │ (None, 32, 32) │ 0 │ add_73[0][0], add_74[0][0] │ │ (Concatenate) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_31 │ (None, 32, 32) │ 16,800 │ concatenate_20[0][0], │ │ (MultiHeadAttention) │ │ │ concatenate_20[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_75 (Add) │ (None, 32, 32) │ 0 │ concatenate_20[0][0], │ │ │ │ │ multi_head_attention_31[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_32 │ (None, 32, 32) │ 4,256 │ add_75[0][0] │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_76 (Add) │ (None, 32, 32) │ 0 │ add_75[0][0], │ │ │ │ │ sequential_32[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_32 │ (None, 32, 32) │ 16,800 │ add_76[0][0], add_76[0][0] │ │ (MultiHeadAttention) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_77 (Add) │ (None, 32, 32) │ 0 │ add_76[0][0], │ │ │ │ │ multi_head_attention_32[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_33 │ (None, 32, 32) │ 4,256 │ add_77[0][0] │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_78 (Add) │ (None, 32, 32) │ 0 │ add_77[0][0], │ │ │ │ │ sequential_33[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_77 │ (None, 32, 32) │ 64 │ add_78[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ avg_pool │ (None, 32) │ 0 │ layer_normalization_77[0]… │ │ (GlobalAveragePooling1D) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_121 (Dense) │ (None, 10) │ 330 │ avg_pool[0][0] │ └───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
Total params: 47,696 (186.31 KB)
Trainable params: 47,696 (186.31 KB)
Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 12s 29ms/step - accuracy: 0.5238 - loss: 1.4078 - val_accuracy: 0.9082 - val_loss: 0.3177
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 12s 29ms/step - accuracy: 0.9021 - loss: 0.3226 - val_accuracy: 0.9453 - val_loss: 0.1842
Test loss: 0.2127855271100998
Test accuracy: 0.9369999766349792
from k3im.simple_vit_with_register_tokens import SimpleViT_RT # jax ✅, tensorflow ✅, torch ✅
model = SimpleViT_RT(image_size=28,
patch_size=7,
num_classes=num_classes,
dim=32,
depth=2,
heads=4,
mlp_dim=64,
num_register_tokens=4,
channels=1,
dim_head=64,)
model.summary()
Model: "functional_61"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ input_layer_46 │ (None, 28, 28, 1) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ extract_patches_7 │ (None, 4, 4, 49) │ 0 │ input_layer_46[0][0] │ │ (ExtractPatches) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_16 (Reshape) │ (None, 16, 49) │ 0 │ extract_patches_7[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_78 │ (None, 16, 49) │ 98 │ reshape_16[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_122 (Dense) │ (None, 16, 32) │ 1,600 │ layer_normalization_78[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_79 │ (None, 16, 32) │ 64 │ dense_122[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_79 (Add) │ (None, 16, 32) │ 0 │ layer_normalization_79[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ register_tokens │ (None, 20, 32) │ 128 │ add_79[0][0] │ │ (RegisterTokens) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_33 │ (None, 20, 32) │ 33,568 │ register_tokens[0][0], │ │ (MultiHeadAttention) │ │ │ register_tokens[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_80 (Add) │ (None, 20, 32) │ 0 │ register_tokens[0][0], │ │ │ │ │ multi_head_attention_33[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_34 │ (None, 20, 32) │ 4,256 │ add_80[0][0] │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_81 (Add) │ (None, 20, 32) │ 0 │ add_80[0][0], │ │ │ │ │ sequential_34[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_34 │ (None, 20, 32) │ 33,568 │ add_81[0][0], add_81[0][0] │ │ (MultiHeadAttention) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_82 (Add) │ (None, 20, 32) │ 0 │ add_81[0][0], │ │ │ │ │ multi_head_attention_34[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_35 │ (None, 20, 32) │ 4,256 │ add_82[0][0] │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_83 (Add) │ (None, 20, 32) │ 0 │ add_82[0][0], │ │ │ │ │ sequential_35[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_82 │ (None, 20, 32) │ 64 │ add_83[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ split (Split) │ [(None, 16, 32), │ 0 │ layer_normalization_82[0]… │ │ │ (None, 4, 32)] │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ avg_pool │ (None, 32) │ 0 │ split[0][0] │ │ (GlobalAveragePooling1D) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_127 (Dense) │ (None, 10) │ 330 │ avg_pool[0][0] │ └───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
Total params: 77,932 (304.42 KB)
Trainable params: 77,932 (304.42 KB)
Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 11s 26ms/step - accuracy: 0.5336 - loss: 1.3876 - val_accuracy: 0.8742 - val_loss: 0.4016
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 11s 26ms/step - accuracy: 0.8742 - loss: 0.4034 - val_accuracy: 0.9272 - val_loss: 0.2538
Test loss: 0.2974727153778076
Test accuracy: 0.909600019454956
from k3im.swint import SwinTModel # jax ✅, tensorflow ✅, torch ✅
model = SwinTModel(
img_size=28,
patch_size=7,
embed_dim=32,
num_heads=4,
window_size=4,
num_mlp=4,
qkv_bias=True,
dropout_rate=0.2,
shift_size=2,
num_classes=num_classes,
in_channels=1,
)
model.summary()
Model: "functional_65"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩ │ input_layer_49 (InputLayer) │ (None, 28, 28, 1) │ 0 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ extract_patches_8 (ExtractPatches) │ (None, 4, 4, 49) │ 0 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ reshape_17 (Reshape) │ (None, 16, 49) │ 0 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ layer_normalization_83 │ (None, 16, 49) │ 98 │ │ (LayerNormalization) │ │ │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ dense_128 (Dense) │ (None, 16, 32) │ 1,600 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ swin_transformer (SwinTransformer) │ (None, 16, 32) │ 5,096 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ swin_transformer_1 │ (None, 16, 32) │ 5,352 │ │ (SwinTransformer) │ │ │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ patch_merging (PatchMerging) │ (None, 4, 64) │ 8,192 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ global_average_pooling1d_5 │ (None, 64) │ 0 │ │ (GlobalAveragePooling1D) │ │ │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ dense_138 (Dense) │ (None, 10) │ 650 │ └────────────────────────────────────┴───────────────────────────────┴─────────────┘
Total params: 20,988 (82.98 KB)
Trainable params: 20,220 (78.98 KB)
Non-trainable params: 768 (4.00 KB)
train_model(model)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 11s 26ms/step - accuracy: 0.6364 - loss: 1.1143 - val_accuracy: 0.9118 - val_loss: 0.2969
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 11s 26ms/step - accuracy: 0.8750 - loss: 0.3952 - val_accuracy: 0.9358 - val_loss: 0.2215
Test loss: 0.26952508091926575
Test accuracy: 0.9175999760627747
from k3im.token_learner import ViTokenLearner # jax check with jax ✅, tensorflow ✅, torch ✅
model = ViTokenLearner(image_size=28,
patch_size=7,
num_classes=10,
dim=64,
depth=4,
heads=4,
mlp_dim=32,
token_learner_units=2,
channels=1,
dim_head=64,
dropout_rate=0.,
pool="mean", use_token_learner=True)
model.summary()
Model: "functional_72"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ input_layer_52 │ (None, 28, 28, 1) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ conv2d_11 (Conv2D) │ (None, 4, 4, 64) │ 3,200 │ input_layer_52[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_18 (Reshape) │ (None, 16, 64) │ 0 │ conv2d_11[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ patch_encoder │ (None, 16, 64) │ 1,024 │ reshape_18[0][0] │ │ (PatchEncoder) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_95 (Dropout) │ (None, 16, 64) │ 0 │ patch_encoder[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_88 │ (None, 16, 64) │ 128 │ dropout_95[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_35 │ (None, 16, 64) │ 66,368 │ layer_normalization_88[0]… │ │ (MultiHeadAttention) │ │ │ layer_normalization_88[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_84 (Add) │ (None, 16, 64) │ 0 │ multi_head_attention_35[0… │ │ │ │ │ dropout_95[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_89 │ (None, 16, 64) │ 128 │ add_84[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_38 │ (None, 16, 64) │ 4,320 │ layer_normalization_89[0]… │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_85 (Add) │ (None, 16, 64) │ 0 │ sequential_38[0][0], │ │ │ │ │ add_84[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_91 │ (None, 16, 64) │ 128 │ add_85[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_36 │ (None, 16, 64) │ 66,368 │ layer_normalization_91[0]… │ │ (MultiHeadAttention) │ │ │ layer_normalization_91[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_86 (Add) │ (None, 16, 64) │ 0 │ multi_head_attention_36[0… │ │ │ │ │ add_85[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_92 │ (None, 16, 64) │ 128 │ add_86[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_39 │ (None, 16, 64) │ 4,320 │ layer_normalization_92[0]… │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_87 (Add) │ (None, 16, 64) │ 0 │ sequential_39[0][0], │ │ │ │ │ add_86[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_94 │ (None, 16, 64) │ 128 │ add_87[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_37 │ (None, 16, 64) │ 66,368 │ layer_normalization_94[0]… │ │ (MultiHeadAttention) │ │ │ layer_normalization_94[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_88 (Add) │ (None, 16, 64) │ 0 │ multi_head_attention_37[0… │ │ │ │ │ add_87[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_95 │ (None, 16, 64) │ 128 │ add_88[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_40 │ (None, 16, 64) │ 4,320 │ layer_normalization_95[0]… │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_89 (Add) │ (None, 16, 64) │ 0 │ sequential_40[0][0], │ │ │ │ │ add_88[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_19 (Reshape) │ (None, 4, 4, 64) │ 0 │ add_89[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_97 │ (None, 4, 4, 64) │ 128 │ reshape_19[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_41 │ (None, 2, 16) │ 1,260 │ layer_normalization_97[0]… │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ expand_dims (ExpandDims) │ (None, 2, 16, 1) │ 0 │ sequential_41[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_21 (Reshape) │ (None, 1, 16, 64) │ 0 │ reshape_19[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multiply (Multiply) │ (None, 2, 16, 64) │ 0 │ expand_dims[0][0], │ │ │ │ │ reshape_21[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ mean (Mean) │ (None, 2, 64) │ 0 │ multiply[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_98 │ (None, 2, 64) │ 128 │ mean[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_38 │ (None, 2, 64) │ 66,368 │ layer_normalization_98[0]… │ │ (MultiHeadAttention) │ │ │ layer_normalization_98[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_90 (Add) │ (None, 2, 64) │ 0 │ multi_head_attention_38[0… │ │ │ │ │ mean[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_99 │ (None, 2, 64) │ 128 │ add_90[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_42 │ (None, 2, 64) │ 4,320 │ layer_normalization_99[0]… │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_91 (Add) │ (None, 2, 64) │ 0 │ sequential_42[0][0], │ │ │ │ │ add_90[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_101 │ (None, 2, 64) │ 128 │ add_91[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ avg_pool │ (None, 64) │ 0 │ layer_normalization_101[0… │ │ (GlobalAveragePooling1D) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_147 (Dense) │ (None, 10) │ 650 │ avg_pool[0][0] │ └───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
Total params: 290,166 (1.11 MB)
Trainable params: 290,166 (1.11 MB)
Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 24s 56ms/step - accuracy: 0.3959 - loss: 2.0571 - val_accuracy: 0.7192 - val_loss: 1.7431
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 23s 55ms/step - accuracy: 0.7304 - loss: 1.7321 - val_accuracy: 0.8182 - val_loss: 1.6422
Test loss: 1.654205560684204
Test accuracy: 0.8080999851226807
from k3im.vit import ViT # jax ✅, tensorflow ✅, torch ✅
model = ViT(
image_size=(28, 28),
patch_size=(7, 7),
num_classes=num_classes,
dim=32,
depth=2,
heads=8,
mlp_dim=65,
channels=1,
dim_head=32,
pool="mean",
)
model.summary()
Model: "functional_76"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ input_layer_58 │ (None, 28, 28, 1) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ extract_patches_9 │ (None, 4, 4, 49) │ 0 │ input_layer_58[0][0] │ │ (ExtractPatches) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_22 (Reshape) │ (None, 16, 49) │ 0 │ extract_patches_9[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_102 │ (None, 16, 49) │ 98 │ reshape_22[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_148 (Dense) │ (None, 16, 32) │ 1,600 │ layer_normalization_102[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_103 │ (None, 16, 32) │ 64 │ dense_148[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ class_token_position_emb │ (None, 17, 32) │ 576 │ layer_normalization_103[0… │ │ (ClassTokenPositionEmb) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_39 │ (None, 17, 32) │ 33,568 │ class_token_position_emb[… │ │ (MultiHeadAttention) │ │ │ class_token_position_emb[… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_92 (Add) │ (None, 17, 32) │ 0 │ class_token_position_emb[… │ │ │ │ │ multi_head_attention_39[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_43 │ (None, 17, 32) │ 4,321 │ add_92[0][0] │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_93 (Add) │ (None, 17, 32) │ 0 │ add_92[0][0], │ │ │ │ │ sequential_43[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_40 │ (None, 17, 32) │ 33,568 │ add_93[0][0], add_93[0][0] │ │ (MultiHeadAttention) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_94 (Add) │ (None, 17, 32) │ 0 │ add_93[0][0], │ │ │ │ │ multi_head_attention_40[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_44 │ (None, 17, 32) │ 4,321 │ add_94[0][0] │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_95 (Add) │ (None, 17, 32) │ 0 │ add_94[0][0], │ │ │ │ │ sequential_44[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_106 │ (None, 17, 32) │ 64 │ add_95[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ max_pool │ (None, 32) │ 0 │ layer_normalization_106[0… │ │ (GlobalAveragePooling1D) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_153 (Dense) │ (None, 10) │ 330 │ max_pool[0][0] │ └───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
Total params: 78,510 (306.68 KB)
Trainable params: 78,510 (306.68 KB)
Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 15s 36ms/step - accuracy: 0.5635 - loss: 1.3064 - val_accuracy: 0.8988 - val_loss: 0.3343
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 12s 28ms/step - accuracy: 0.9013 - loss: 0.3197 - val_accuracy: 0.9478 - val_loss: 0.1735
Test loss: 0.18965917825698853
Test accuracy: 0.942799985408783
from k3im.vit_with_patch_dropout import SimpleViTPD # jax ✅, tensorflow ✅, torch ✅
model = SimpleViTPD(
image_size=28,
patch_size=7,
num_classes=10,
dim=32,
depth=4,
heads=8,
mlp_dim=42,
patch_dropout=0.25,
channels=1,
dim_head=16,
pool="mean",
)
model.summary()
Model: "functional_82"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ input_layer_61 │ (None, 28, 28, 1) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ extract_patches_10 │ (None, 4, 4, 49) │ 0 │ input_layer_61[0][0] │ │ (ExtractPatches) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_23 (Reshape) │ (None, 16, 49) │ 0 │ extract_patches_10[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_107 │ (None, 16, 49) │ 98 │ reshape_23[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_154 (Dense) │ (None, 16, 32) │ 1,600 │ layer_normalization_107[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_108 │ (None, 16, 32) │ 64 │ dense_154[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_96 (Add) │ (None, 16, 32) │ 0 │ layer_normalization_108[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ spatial_dropout1d │ (None, 16, 32) │ 0 │ add_96[0][0] │ │ (SpatialDropout1D) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_41 │ (None, 16, 32) │ 16,800 │ spatial_dropout1d[0][0], │ │ (MultiHeadAttention) │ │ │ spatial_dropout1d[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_97 (Add) │ (None, 16, 32) │ 0 │ spatial_dropout1d[0][0], │ │ │ │ │ multi_head_attention_41[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_45 │ (None, 16, 32) │ 2,826 │ add_97[0][0] │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_98 (Add) │ (None, 16, 32) │ 0 │ add_97[0][0], │ │ │ │ │ sequential_45[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_42 │ (None, 16, 32) │ 16,800 │ add_98[0][0], add_98[0][0] │ │ (MultiHeadAttention) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_99 (Add) │ (None, 16, 32) │ 0 │ add_98[0][0], │ │ │ │ │ multi_head_attention_42[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_46 │ (None, 16, 32) │ 2,826 │ add_99[0][0] │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_100 (Add) │ (None, 16, 32) │ 0 │ add_99[0][0], │ │ │ │ │ sequential_46[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_43 │ (None, 16, 32) │ 16,800 │ add_100[0][0], │ │ (MultiHeadAttention) │ │ │ add_100[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_101 (Add) │ (None, 16, 32) │ 0 │ add_100[0][0], │ │ │ │ │ multi_head_attention_43[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_47 │ (None, 16, 32) │ 2,826 │ add_101[0][0] │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_102 (Add) │ (None, 16, 32) │ 0 │ add_101[0][0], │ │ │ │ │ sequential_47[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_44 │ (None, 16, 32) │ 16,800 │ add_102[0][0], │ │ (MultiHeadAttention) │ │ │ add_102[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_103 (Add) │ (None, 16, 32) │ 0 │ add_102[0][0], │ │ │ │ │ multi_head_attention_44[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_48 │ (None, 16, 32) │ 2,826 │ add_103[0][0] │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_104 (Add) │ (None, 16, 32) │ 0 │ add_103[0][0], │ │ │ │ │ sequential_48[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_113 │ (None, 16, 32) │ 64 │ add_104[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ avg_pool │ (None, 32) │ 0 │ layer_normalization_113[0… │ │ (GlobalAveragePooling1D) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_163 (Dense) │ (None, 10) │ 330 │ avg_pool[0][0] │ └───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
Total params: 80,660 (315.08 KB)
Trainable params: 80,660 (315.08 KB)
Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 16s 39ms/step - accuracy: 0.4712 - loss: 1.5234 - val_accuracy: 0.8593 - val_loss: 0.4529
Epoch 2/2
422/422 ━━━━━━━━━━━━━━━━━━━━ 17s 39ms/step - accuracy: 0.8093 - loss: 0.5927 - val_accuracy: 0.9195 - val_loss: 0.2688
Test loss: 0.30448290705680847
Test accuracy: 0.9017000198364258