Examples 3D/Space-Time
!pip install k3im medmnist -qq --upgrade
import os
os.environ['KERAS_BACKEND'] = 'jax'
import keras
import medmnist
import numpy as np
import tensorflow as tf # For data processes only
DATASET_NAME = "organmnist3d"
BATCH_SIZE = 32
AUTO = tf.data.AUTOTUNE
num_classes = 11
def download_and_prepare_dataset(data_info: dict):
"""Utility function to download the dataset.
Arguments:
data_info (dict): Dataset metadata.
"""
data_path = keras.utils.get_file(origin=data_info["url"], md5_hash=data_info["MD5"])
with np.load(data_path) as data:
# Get videos
train_videos = data["train_images"]
valid_videos = data["val_images"]
test_videos = data["test_images"]
# Get labels
train_labels = data["train_labels"].flatten()
valid_labels = data["val_labels"].flatten()
test_labels = data["test_labels"].flatten()
return (
(train_videos, train_labels),
(valid_videos, valid_labels),
(test_videos, test_labels),
)
# Get the metadata of the dataset
info = medmnist.INFO[DATASET_NAME]
# Get the dataset
prepared_dataset = download_and_prepare_dataset(info)
(train_videos, train_labels) = prepared_dataset[0]
(valid_videos, valid_labels) = prepared_dataset[1]
(test_videos, test_labels) = prepared_dataset[2]
@tf.function
def preprocess(frames: tf.Tensor, label: tf.Tensor):
"""Preprocess the frames tensors and parse the labels."""
# Preprocess images
frames = tf.image.convert_image_dtype(
frames[
..., tf.newaxis
], # The new axis is to help for further processing with Conv3D layers
tf.float32,
)
# Parse label
label = tf.cast(label, tf.float32)
return frames, label
def prepare_dataloader(
videos: np.ndarray,
labels: np.ndarray,
loader_type: str = "train",
batch_size: int = BATCH_SIZE,
):
"""Utility function to prepare the dataloader."""
dataset = tf.data.Dataset.from_tensor_slices((videos, labels))
if loader_type == "train":
dataset = dataset.shuffle(BATCH_SIZE * 2)
dataloader = (
dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
.batch(batch_size)
.prefetch(tf.data.AUTOTUNE)
)
return dataloader
trainloader = prepare_dataloader(train_videos, train_labels, "train")
validloader = prepare_dataloader(valid_videos, valid_labels, "valid")
testloader = prepare_dataloader(test_videos, test_labels, "test")
batch_size = 32
epochs = 2
def train_model(model):
model.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer="adam", metrics=["accuracy"])
model.fit(trainloader, epochs=epochs, validation_data=validloader)
score = model.evaluate(testloader)
print("Test loss:", score[0])
print("Test accuracy:", score[1])
from k3im.cait_3d import CAiT3DModel # fixed jax ✅
model = CAiT3DModel(
image_size=28,
image_patch_size=7,
frames=28,
frame_patch_size=7,
num_classes=11,
dim=128,
depth=3,
cls_depth=3,
heads=8,
mlp_dim=84,
channels=1,
dim_head=64,
)
model.summary()
Model: "functional_15"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ input_layer_6 │ (None, 28, 28, 28, 1) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape (Reshape) │ (None, 7, 4, 7, 4, 7, │ 0 │ input_layer_6[0][0] │ │ │ 4, 1) │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ transpose (Transpose) │ (None, 4, 4, 4, 7, 7, │ 0 │ reshape[0][0] │ │ │ 7, 1) │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_1 (Reshape) │ (None, 4, 4, 4, 343) │ 0 │ transpose[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization │ (None, 4, 4, 4, 343) │ 686 │ reshape_1[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_2 (Dense) │ (None, 4, 4, 4, 128) │ 44,032 │ layer_normalization[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_1 │ (None, 4, 4, 4, 128) │ 256 │ dense_2[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_2 (Reshape) │ (None, 64, 128) │ 0 │ layer_normalization_1[0][… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention │ (None, 64, 128) │ 263,808 │ reshape_2[0][0], │ │ (MultiHeadAttention) │ │ │ reshape_2[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_4 (Add) │ (None, 64, 128) │ 0 │ reshape_2[0][0], │ │ │ │ │ multi_head_attention[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_4 (Sequential) │ (None, 64, 128) │ 21,972 │ add_4[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_5 (Add) │ (None, 64, 128) │ 0 │ add_4[0][0], │ │ │ │ │ sequential_4[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_1 │ (None, 64, 128) │ 263,808 │ add_5[0][0], add_5[0][0] │ │ (MultiHeadAttention) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_6 (Add) │ (None, 64, 128) │ 0 │ add_5[0][0], │ │ │ │ │ multi_head_attention_1[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_5 (Sequential) │ (None, 64, 128) │ 21,972 │ add_6[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_7 (Add) │ (None, 64, 128) │ 0 │ add_6[0][0], │ │ │ │ │ sequential_5[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_2 │ (None, 64, 128) │ 263,808 │ add_7[0][0], add_7[0][0] │ │ (MultiHeadAttention) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_8 (Add) │ (None, 64, 128) │ 0 │ add_7[0][0], │ │ │ │ │ multi_head_attention_2[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_6 (Sequential) │ (None, 64, 128) │ 21,972 │ add_8[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_9 (Add) │ (None, 64, 128) │ 0 │ add_8[0][0], │ │ │ │ │ sequential_6[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_5 │ (None, 64, 128) │ 256 │ add_9[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ cls__token (CLS_Token) │ [(None, 65, 128), │ 128 │ layer_normalization_5[0][… │ │ │ (None, 1, 128)] │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ concatenate (Concatenate) │ (None, 65, 128) │ 0 │ cls__token[0][1], │ │ │ │ │ layer_normalization_5[0][… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_3 │ (None, 1, 128) │ 263,808 │ cls__token[0][1], │ │ (MultiHeadAttention) │ │ │ concatenate[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_10 (Add) │ (None, 1, 128) │ 0 │ cls__token[0][1], │ │ │ │ │ multi_head_attention_3[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_7 (Sequential) │ (None, 1, 128) │ 21,972 │ add_10[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_11 (Add) │ (None, 1, 128) │ 0 │ add_10[0][0], │ │ │ │ │ sequential_7[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ concatenate_1 │ (None, 65, 128) │ 0 │ add_11[0][0], │ │ (Concatenate) │ │ │ layer_normalization_5[0][… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_4 │ (None, 1, 128) │ 263,808 │ add_11[0][0], │ │ (MultiHeadAttention) │ │ │ concatenate_1[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_12 (Add) │ (None, 1, 128) │ 0 │ add_11[0][0], │ │ │ │ │ multi_head_attention_4[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_8 (Sequential) │ (None, 1, 128) │ 21,972 │ add_12[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_13 (Add) │ (None, 1, 128) │ 0 │ add_12[0][0], │ │ │ │ │ sequential_8[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ concatenate_2 │ (None, 65, 128) │ 0 │ add_13[0][0], │ │ (Concatenate) │ │ │ layer_normalization_5[0][… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_5 │ (None, 1, 128) │ 263,808 │ add_13[0][0], │ │ (MultiHeadAttention) │ │ │ concatenate_2[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_14 (Add) │ (None, 1, 128) │ 0 │ add_13[0][0], │ │ │ │ │ multi_head_attention_5[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_9 (Sequential) │ (None, 1, 128) │ 21,972 │ add_14[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_15 (Add) │ (None, 1, 128) │ 0 │ add_14[0][0], │ │ │ │ │ sequential_9[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_9 │ (None, 1, 128) │ 256 │ add_15[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ squeeze (Squeeze) │ (None, 128) │ 0 │ layer_normalization_9[0][… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_15 (Dense) │ (None, 11) │ 1,419 │ squeeze[0][0] │ └───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
Total params: 1,761,713 (6.72 MB)
Trainable params: 1,761,713 (6.72 MB)
Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 29s 596ms/step - accuracy: 0.4146 - loss: 1.8926 - val_accuracy: 0.9255 - val_loss: 0.2574
Epoch 2/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 1s 22ms/step - accuracy: 0.7658 - loss: 0.7132 - val_accuracy: 0.9068 - val_loss: 0.2550
20/20 ━━━━━━━━━━━━━━━━━━━━ 5s 268ms/step - accuracy: 0.7239 - loss: 1.0073
Test loss: 0.9272577166557312
Test accuracy: 0.7180327773094177
from k3im.cct_3d import CCT3DModel # jax ✅, tensorflow ✅, torch ✅
model = CCT3DModel(input_shape=(28, 28, 28, 1),
num_heads=4,
projection_dim=64,
kernel_size=4,
stride=4,
padding=2,
transformer_units=[16, 64],
stochastic_depth_rate=0.6,
transformer_layers=2,
num_classes=num_classes,
positional_emb=False,)
model.summary()
Model: "functional_18"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ input_layer_13 │ (None, 28, 28, 28, 1) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ cct_tokenizer3d │ (None, 27, 64) │ 266,240 │ input_layer_13[0][0] │ │ (CCTTokenizer3D) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_10 │ (None, 27, 64) │ 128 │ cct_tokenizer3d[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_6 │ (None, 27, 64) │ 66,368 │ layer_normalization_10[0]… │ │ (MultiHeadAttention) │ │ │ layer_normalization_10[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ stochastic_depth │ (None, 27, 64) │ 0 │ multi_head_attention_6[0]… │ │ (StochasticDepth) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_16 (Add) │ (None, 27, 64) │ 0 │ stochastic_depth[0][0], │ │ │ │ │ cct_tokenizer3d[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_11 │ (None, 27, 64) │ 128 │ add_16[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_16 (Dense) │ (None, 27, 16) │ 1,040 │ layer_normalization_11[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_7 (Dropout) │ (None, 27, 16) │ 0 │ dense_16[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_17 (Dense) │ (None, 27, 64) │ 1,088 │ dropout_7[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_8 (Dropout) │ (None, 27, 64) │ 0 │ dense_17[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ stochastic_depth_1 │ (None, 27, 64) │ 0 │ dropout_8[0][0] │ │ (StochasticDepth) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_17 (Add) │ (None, 27, 64) │ 0 │ stochastic_depth_1[0][0], │ │ │ │ │ add_16[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_12 │ (None, 27, 64) │ 128 │ add_17[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_7 │ (None, 27, 64) │ 66,368 │ layer_normalization_12[0]… │ │ (MultiHeadAttention) │ │ │ layer_normalization_12[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ stochastic_depth_2 │ (None, 27, 64) │ 0 │ multi_head_attention_7[0]… │ │ (StochasticDepth) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_18 (Add) │ (None, 27, 64) │ 0 │ stochastic_depth_2[0][0], │ │ │ │ │ add_17[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_13 │ (None, 27, 64) │ 128 │ add_18[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_18 (Dense) │ (None, 27, 16) │ 1,040 │ layer_normalization_13[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_10 (Dropout) │ (None, 27, 16) │ 0 │ dense_18[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_19 (Dense) │ (None, 27, 64) │ 1,088 │ dropout_10[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_11 (Dropout) │ (None, 27, 64) │ 0 │ dense_19[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ stochastic_depth_3 │ (None, 27, 64) │ 0 │ dropout_11[0][0] │ │ (StochasticDepth) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_19 (Add) │ (None, 27, 64) │ 0 │ stochastic_depth_3[0][0], │ │ │ │ │ add_18[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_14 │ (None, 27, 64) │ 128 │ add_19[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequence_pooling │ (None, 64) │ 65 │ layer_normalization_14[0]… │ │ (SequencePooling) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_21 (Dense) │ (None, 11) │ 715 │ sequence_pooling[0][0] │ └───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
Total params: 404,652 (1.54 MB)
Trainable params: 404,652 (1.54 MB)
Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 18s 311ms/step - accuracy: 0.1076 - loss: 2.4914 - val_accuracy: 0.1801 - val_loss: 2.2091
Epoch 2/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - accuracy: 0.2319 - loss: 2.0499 - val_accuracy: 0.4099 - val_loss: 1.4763
20/20 ━━━━━━━━━━━━━━━━━━━━ 3s 134ms/step - accuracy: 0.3311 - loss: 1.6405
Test loss: 1.6255520582199097
Test accuracy: 0.34590163826942444
from k3im.convmixer_3d import ConvMixer3DModel # jax ✅, # tensorflow ✅, #torch fail
model = ConvMixer3DModel(image_size=28,
num_frames=28,
filters=32,
depth=2,
kernel_size=4,
kernel_depth=3,
patch_size=3,
patch_depth=3,
num_classes=10,
num_channels=1)
model.summary()
Model: "functional_22"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ input_layer_15 │ (None, 28, 28, 28, 1) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ rescaling_2 (Rescaling) │ (None, 28, 28, 28, 1) │ 0 │ input_layer_15[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ conv3d_16 (Conv3D) │ (None, 9, 9, 9, 32) │ 896 │ rescaling_2[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ activation_10 │ (None, 9, 9, 9, 32) │ 0 │ conv3d_16[0][0] │ │ (Activation) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ batch_normalization_10 │ (None, 9, 9, 9, 32) │ 128 │ activation_10[0][0] │ │ (BatchNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ conv2_plus1d_4 │ (None, 9, 9, 9, 32) │ 19,520 │ batch_normalization_10[0]… │ │ (Conv2Plus1D) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ activation_11 │ (None, 9, 9, 9, 32) │ 0 │ conv2_plus1d_4[0][0] │ │ (Activation) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ batch_normalization_11 │ (None, 9, 9, 9, 32) │ 128 │ activation_11[0][0] │ │ (BatchNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_20 (Add) │ (None, 9, 9, 9, 32) │ 0 │ batch_normalization_11[0]… │ │ │ │ │ batch_normalization_10[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ conv3d_19 (Conv3D) │ (None, 9, 9, 9, 32) │ 1,056 │ add_20[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ activation_12 │ (None, 9, 9, 9, 32) │ 0 │ conv3d_19[0][0] │ │ (Activation) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ batch_normalization_12 │ (None, 9, 9, 9, 32) │ 128 │ activation_12[0][0] │ │ (BatchNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ conv2_plus1d_5 │ (None, 9, 9, 9, 32) │ 19,520 │ batch_normalization_12[0]… │ │ (Conv2Plus1D) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ activation_13 │ (None, 9, 9, 9, 32) │ 0 │ conv2_plus1d_5[0][0] │ │ (Activation) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ batch_normalization_13 │ (None, 9, 9, 9, 32) │ 128 │ activation_13[0][0] │ │ (BatchNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_21 (Add) │ (None, 9, 9, 9, 32) │ 0 │ batch_normalization_13[0]… │ │ │ │ │ batch_normalization_12[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ conv3d_22 (Conv3D) │ (None, 9, 9, 9, 32) │ 1,056 │ add_21[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ activation_14 │ (None, 9, 9, 9, 32) │ 0 │ conv3d_22[0][0] │ │ (Activation) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ batch_normalization_14 │ (None, 9, 9, 9, 32) │ 128 │ activation_14[0][0] │ │ (BatchNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ global_average_pooling3d… │ (None, 32) │ 0 │ batch_normalization_14[0]… │ │ (GlobalAveragePooling3D) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_22 (Dense) │ (None, 10) │ 330 │ global_average_pooling3d_… │ └───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
Total params: 43,018 (168.04 KB)
Trainable params: 42,698 (166.79 KB)
Non-trainable params: 320 (1.25 KB)
train_model(model)
Epoch 1/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 8s 151ms/step - accuracy: 0.2606 - loss: 1.7970 - val_accuracy: 0.0994 - val_loss: 2.1118
Epoch 2/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 1s 21ms/step - accuracy: 0.4494 - loss: 1.4487 - val_accuracy: 0.0994 - val_loss: 2.1147
20/20 ━━━━━━━━━━━━━━━━━━━━ 1s 66ms/step - accuracy: 0.0999 - loss: 2.0564
Test loss: 2.065363645553589
Test accuracy: 0.11311475187540054
from k3im.eanet3d import EANet3DModel # jax ✅, tensorflow ✅, torch ✅
model = EANet3DModel(
image_size=28,
image_patch_size=7,
frames=28,
frame_patch_size=7,
num_classes=num_classes,
dim=64,
depth=2,
heads=4,
mlp_dim=32,
channels=1,
dim_coefficient=4,
projection_dropout=0.0,
attention_dropout=0,
)
model.summary()
Model: "functional_26"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ input_layer_18 │ (None, 28, 28, 28, 1) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_3 (Reshape) │ (None, 7, 4, 7, 4, 7, │ 0 │ input_layer_18[0][0] │ │ │ 4, 1) │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ transpose_1 (Transpose) │ (None, 4, 4, 4, 7, 7, │ 0 │ reshape_3[0][0] │ │ │ 7, 1) │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_4 (Reshape) │ (None, 4, 4, 4, 343) │ 0 │ transpose_1[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_15 │ (None, 4, 4, 4, 343) │ 686 │ reshape_4[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_23 (Dense) │ (None, 4, 4, 4, 64) │ 22,016 │ layer_normalization_15[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_16 │ (None, 4, 4, 4, 64) │ 128 │ dense_23[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_5 (Reshape) │ (None, 64, 64) │ 0 │ layer_normalization_16[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_24 (Dense) │ (None, 64, 256) │ 16,640 │ reshape_5[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_6 (Reshape) │ (None, 64, 16, 16) │ 0 │ dense_24[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ transpose_2 (Transpose) │ (None, 16, 64, 16) │ 0 │ reshape_6[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_25 (Dense) │ (None, 16, 64, 16) │ 272 │ transpose_2[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ softmax_8 (Softmax) │ (None, 16, 64, 16) │ 0 │ dense_25[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ lambda (Lambda) │ (None, 16, 64, 16) │ 0 │ softmax_8[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_12 (Dropout) │ (None, 16, 64, 16) │ 0 │ lambda[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_26 (Dense) │ (None, 16, 64, 16) │ 272 │ dropout_12[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ transpose_3 (Transpose) │ (None, 64, 16, 16) │ 0 │ dense_26[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_7 (Reshape) │ (None, 64, 256) │ 0 │ transpose_3[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_27 (Dense) │ (None, 64, 64) │ 16,448 │ reshape_7[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_13 (Dropout) │ (None, 64, 64) │ 0 │ dense_27[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_22 (Add) │ (None, 64, 64) │ 0 │ reshape_5[0][0], │ │ │ │ │ dropout_13[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_13 │ (None, 64, 64) │ 4,320 │ add_22[0][0] │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_23 (Add) │ (None, 64, 64) │ 0 │ add_22[0][0], │ │ │ │ │ sequential_13[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_30 (Dense) │ (None, 64, 256) │ 16,640 │ add_23[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_8 (Reshape) │ (None, 64, 16, 16) │ 0 │ dense_30[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ transpose_4 (Transpose) │ (None, 16, 64, 16) │ 0 │ reshape_8[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_31 (Dense) │ (None, 16, 64, 16) │ 272 │ transpose_4[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ softmax_9 (Softmax) │ (None, 16, 64, 16) │ 0 │ dense_31[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ lambda_1 (Lambda) │ (None, 16, 64, 16) │ 0 │ softmax_9[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_14 (Dropout) │ (None, 16, 64, 16) │ 0 │ lambda_1[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_32 (Dense) │ (None, 16, 64, 16) │ 272 │ dropout_14[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ transpose_5 (Transpose) │ (None, 64, 16, 16) │ 0 │ dense_32[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_9 (Reshape) │ (None, 64, 256) │ 0 │ transpose_5[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_33 (Dense) │ (None, 64, 64) │ 16,448 │ reshape_9[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dropout_15 (Dropout) │ (None, 64, 64) │ 0 │ dense_33[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_24 (Add) │ (None, 64, 64) │ 0 │ add_23[0][0], │ │ │ │ │ dropout_15[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_14 │ (None, 64, 64) │ 4,320 │ add_24[0][0] │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_25 (Add) │ (None, 64, 64) │ 0 │ add_24[0][0], │ │ │ │ │ sequential_14[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_19 │ (None, 64, 64) │ 128 │ add_25[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ avg_pool │ (None, 64) │ 0 │ layer_normalization_19[0]… │ │ (GlobalAveragePooling1D) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_36 (Dense) │ (None, 11) │ 715 │ avg_pool[0][0] │ └───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
Total params: 99,577 (388.97 KB)
Trainable params: 99,577 (388.97 KB)
Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 12s 248ms/step - accuracy: 0.3390 - loss: 2.0201 - val_accuracy: 0.9130 - val_loss: 0.5298
Epoch 2/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 1s 22ms/step - accuracy: 0.7829 - loss: 0.8192 - val_accuracy: 0.9441 - val_loss: 0.2836
20/20 ━━━━━━━━━━━━━━━━━━━━ 3s 133ms/step - accuracy: 0.7109 - loss: 0.9908
Test loss: 0.9404497146606445
Test accuracy: 0.7180327773094177
from k3im.fnet_3d import FNet3DModel # jax ✅ , tensorflow ✅, torch ✅
model = FNet3DModel(
image_size=28,
image_patch_size=7,
frames=28,
frame_patch_size=7,
num_classes=11,
dim=256,
depth=5,
hidden_units=256,
dropout_rate=0.6,
channels=1,
)
model.summary()
Model: "functional_33"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩ │ input_layer_21 (InputLayer) │ (None, 28, 28, 28, 1) │ 0 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ reshape_10 (Reshape) │ (None, 7, 4, 7, 4, 7, 4, 1) │ 0 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ transpose_6 (Transpose) │ (None, 4, 4, 4, 7, 7, 7, 1) │ 0 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ reshape_11 (Reshape) │ (None, 4, 4, 4, 343) │ 0 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ layer_normalization_20 │ (None, 4, 4, 4, 343) │ 686 │ │ (LayerNormalization) │ │ │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ dense_37 (Dense) │ (None, 4, 4, 4, 256) │ 88,064 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ layer_normalization_21 │ (None, 4, 4, 4, 256) │ 512 │ │ (LayerNormalization) │ │ │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ reshape_12 (Reshape) │ (None, 64, 256) │ 0 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ f_net_layer (FNetLayer) │ (None, 64, 256) │ 132,608 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ f_net_layer_1 (FNetLayer) │ (None, 64, 256) │ 132,608 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ f_net_layer_2 (FNetLayer) │ (None, 64, 256) │ 132,608 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ f_net_layer_3 (FNetLayer) │ (None, 64, 256) │ 132,608 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ f_net_layer_4 (FNetLayer) │ (None, 64, 256) │ 132,608 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ avg_pool (GlobalAveragePooling1D) │ (None, 256) │ 0 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ dense_48 (Dense) │ (None, 11) │ 2,827 │ └────────────────────────────────────┴───────────────────────────────┴─────────────┘
Total params: 755,129 (2.88 MB)
Trainable params: 755,129 (2.88 MB)
Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 26s 681ms/step - accuracy: 0.1565 - loss: 2.3726 - val_accuracy: 0.6149 - val_loss: 1.6513
Epoch 2/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - accuracy: 0.6299 - loss: 1.4855 - val_accuracy: 0.7391 - val_loss: 0.7485
20/20 ━━━━━━━━━━━━━━━━━━━━ 3s 134ms/step - accuracy: 0.5232 - loss: 1.4321
Test loss: 1.3722052574157715
Test accuracy: 0.5278688669204712
from k3im.gmlp_3d import gMLP3DModel # jax ✅ , tensorflow ✅, torch ✅
model = gMLP3DModel(
image_size=28,
image_patch_size=7,
frames=28,
frame_patch_size=7,
num_classes=11,
dim=32,
depth=4,
hidden_units=32,
dropout_rate=0.4,
channels=1,
)
train_model(model)
Epoch 1/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 23s 552ms/step - accuracy: 0.2921 - loss: 2.1528 - val_accuracy: 0.7329 - val_loss: 0.9436
Epoch 2/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.7036 - loss: 1.1000 - val_accuracy: 0.9193 - val_loss: 0.4642
20/20 ━━━━━━━━━━━━━━━━━━━━ 3s 134ms/step - accuracy: 0.6815 - loss: 1.1396
Test loss: 1.0670228004455566
Test accuracy: 0.685245931148529
from k3im.mlp_mixer_3d import MLPMixer3DModel # jax ✅, tensorflow ✅, torch ✅
model = MLPMixer3DModel(
image_size=28,
image_patch_size=7,
frames=28,
frame_patch_size=7,
num_classes=num_classes,
dim=32,
depth=4,
hidden_units=32,
dropout_rate=0.4,
channels=1,
)
model.summary()
Model: "functional_49"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩ │ input_layer_32 (InputLayer) │ (None, 28, 28, 28, 1) │ 0 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ reshape_16 (Reshape) │ (None, 7, 4, 7, 4, 7, 4, 1) │ 0 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ transpose_8 (Transpose) │ (None, 4, 4, 4, 7, 7, 7, 1) │ 0 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ reshape_17 (Reshape) │ (None, 4, 4, 4, 343) │ 0 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ layer_normalization_42 │ (None, 4, 4, 4, 343) │ 686 │ │ (LayerNormalization) │ │ │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ dense_63 (Dense) │ (None, 4, 4, 4, 32) │ 11,008 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ layer_normalization_43 │ (None, 4, 4, 4, 32) │ 64 │ │ (LayerNormalization) │ │ │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ reshape_18 (Reshape) │ (None, 64, 32) │ 0 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ mlp_mixer_layer (MLPMixerLayer) │ (None, 64, 32) │ 12,576 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ mlp_mixer_layer_1 (MLPMixerLayer) │ (None, 64, 32) │ 12,576 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ mlp_mixer_layer_2 (MLPMixerLayer) │ (None, 64, 32) │ 12,576 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ mlp_mixer_layer_3 (MLPMixerLayer) │ (None, 64, 32) │ 12,576 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ avg_pool (GlobalAveragePooling1D) │ (None, 32) │ 0 │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ dense_80 (Dense) │ (None, 11) │ 363 │ └────────────────────────────────────┴───────────────────────────────┴─────────────┘
Total params: 62,425 (243.85 KB)
Trainable params: 62,425 (243.85 KB)
Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 24s 655ms/step - accuracy: 0.3077 - loss: 2.1705 - val_accuracy: 0.8944 - val_loss: 0.3976
Epoch 2/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - accuracy: 0.7351 - loss: 0.8639 - val_accuracy: 0.9317 - val_loss: 0.2207
20/20 ━━━━━━━━━━━━━━━━━━━━ 1s 66ms/step - accuracy: 0.7119 - loss: 1.0056
Test loss: 0.9401963353157043
Test accuracy: 0.7147541046142578
from k3im.simple_vit_3d import SimpleViT3DModel # jax ✅, tensorflow ✅, torch ✅
model = SimpleViT3DModel(
image_size=28,
image_patch_size=7,
frames=28,
frame_patch_size=7,
num_classes=num_classes,
dim=32,
depth=2,
heads=4,
mlp_dim=32,
channels=1,
dim_head=64,
)
model.summary()
Model: "functional_53"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ input_layer_41 │ (None, 28, 28, 28, 1) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_19 (Reshape) │ (None, 7, 4, 7, 4, 7, │ 0 │ input_layer_41[0][0] │ │ │ 4, 1) │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ transpose_9 (Transpose) │ (None, 4, 4, 4, 7, 7, │ 0 │ reshape_19[0][0] │ │ │ 7, 1) │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_20 (Reshape) │ (None, 4, 4, 4, 343) │ 0 │ transpose_9[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_48 │ (None, 4, 4, 4, 343) │ 686 │ reshape_20[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_81 (Dense) │ (None, 4, 4, 4, 32) │ 11,008 │ layer_normalization_48[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_49 │ (None, 4, 4, 4, 32) │ 64 │ dense_81[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_21 (Reshape) │ (None, 64, 32) │ 0 │ layer_normalization_49[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_8 │ (None, 64, 32) │ 33,568 │ reshape_21[0][0], │ │ (MultiHeadAttention) │ │ │ reshape_21[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_26 (Add) │ (None, 64, 32) │ 0 │ reshape_21[0][0], │ │ │ │ │ multi_head_attention_8[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_32 │ (None, 64, 32) │ 2,176 │ add_26[0][0] │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_27 (Add) │ (None, 64, 32) │ 0 │ add_26[0][0], │ │ │ │ │ sequential_32[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_9 │ (None, 64, 32) │ 33,568 │ add_27[0][0], add_27[0][0] │ │ (MultiHeadAttention) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_28 (Add) │ (None, 64, 32) │ 0 │ add_27[0][0], │ │ │ │ │ multi_head_attention_9[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_33 │ (None, 64, 32) │ 2,176 │ add_28[0][0] │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_29 (Add) │ (None, 64, 32) │ 0 │ add_28[0][0], │ │ │ │ │ sequential_33[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_52 │ (None, 64, 32) │ 64 │ add_29[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ avg_pool │ (None, 32) │ 0 │ layer_normalization_52[0]… │ │ (GlobalAveragePooling1D) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_86 (Dense) │ (None, 11) │ 363 │ avg_pool[0][0] │ └───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
Total params: 83,673 (326.85 KB)
Trainable params: 83,673 (326.85 KB)
Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 13s 296ms/step - accuracy: 0.3410 - loss: 2.0619 - val_accuracy: 0.8571 - val_loss: 0.7134
Epoch 2/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - accuracy: 0.7661 - loss: 0.9615 - val_accuracy: 0.9006 - val_loss: 0.4115
20/20 ━━━━━━━━━━━━━━━━━━━━ 1s 66ms/step - accuracy: 0.6486 - loss: 1.1380
Test loss: 1.0664023160934448
Test accuracy: 0.6672131419181824
from k3im.vit_3d import ViT3DModel # jax ✅, tensorflow ✅, torch ✅
model = ViT3DModel(
image_size=28,
image_patch_size=7,
frames=28,
frame_patch_size=7,
num_classes=num_classes,
dim=32,
depth=2,
heads=4,
mlp_dim=32,
pool='cls',
channels=1,
dim_head=64,
)
model.summary()
Model: "functional_57"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ input_layer_44 │ (None, 28, 28, 28, 1) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_22 (Reshape) │ (None, 7, 4, 7, 4, 7, │ 0 │ input_layer_44[0][0] │ │ │ 4, 1) │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ transpose_10 (Transpose) │ (None, 4, 4, 4, 7, 7, │ 0 │ reshape_22[0][0] │ │ │ 7, 1) │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_23 (Reshape) │ (None, 4, 4, 4, 343) │ 0 │ transpose_10[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_53 │ (None, 4, 4, 4, 343) │ 686 │ reshape_23[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_87 (Dense) │ (None, 4, 4, 4, 32) │ 11,008 │ layer_normalization_53[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_54 │ (None, 4, 4, 4, 32) │ 64 │ dense_87[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ reshape_24 (Reshape) │ (None, 64, 32) │ 0 │ layer_normalization_54[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ class_token_position_emb │ (None, 65, 32) │ 2,112 │ reshape_24[0][0] │ │ (ClassTokenPositionEmb) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_10 │ (None, 65, 32) │ 33,568 │ class_token_position_emb[… │ │ (MultiHeadAttention) │ │ │ class_token_position_emb[… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_30 (Add) │ (None, 65, 32) │ 0 │ class_token_position_emb[… │ │ │ │ │ multi_head_attention_10[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_34 │ (None, 65, 32) │ 2,176 │ add_30[0][0] │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_31 (Add) │ (None, 65, 32) │ 0 │ add_30[0][0], │ │ │ │ │ sequential_34[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ multi_head_attention_11 │ (None, 65, 32) │ 33,568 │ add_31[0][0], add_31[0][0] │ │ (MultiHeadAttention) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_32 (Add) │ (None, 65, 32) │ 0 │ add_31[0][0], │ │ │ │ │ multi_head_attention_11[0… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ sequential_35 │ (None, 65, 32) │ 2,176 │ add_32[0][0] │ │ (Sequential) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ add_33 (Add) │ (None, 65, 32) │ 0 │ add_32[0][0], │ │ │ │ │ sequential_35[0][0] │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ layer_normalization_57 │ (None, 65, 32) │ 64 │ add_33[0][0] │ │ (LayerNormalization) │ │ │ │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ get_item (GetItem) │ (None, 32) │ 0 │ layer_normalization_57[0]… │ ├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤ │ dense_92 (Dense) │ (None, 11) │ 363 │ get_item[0][0] │ └───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
Total params: 85,785 (335.10 KB)
Trainable params: 85,785 (335.10 KB)
Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 13s 262ms/step - accuracy: 0.3574 - loss: 1.9998 - val_accuracy: 0.8696 - val_loss: 0.5885
Epoch 2/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.7141 - loss: 0.9330 - val_accuracy: 0.9255 - val_loss: 0.3551
20/20 ━━━━━━━━━━━━━━━━━━━━ 1s 63ms/step - accuracy: 0.6813 - loss: 1.0153
Test loss: 0.963051438331604
Test accuracy: 0.6983606815338135
from k3im.video_eanet import VideoEANet # fixed jax ✅
model = VideoEANet(
image_size=28,
image_patch_size=7,
frames=28,
frame_patch_size=7,
num_classes=num_classes,
dim=64,
spatial_depth=2,
temporal_depth=2,
heads=4,
mlp_dim=128,
pool="cls",
channels=1,
dim_coefficient=4,
projection_dropout=0.0,
attention_dropout=0,
emb_dropout=0.0,
)
model.summary()
train_model(model)
from k3im.vivit import ViViT
model = ViViT(
image_size=28,
image_patch_size=7,
frames=28,
frame_patch_size=7,
num_classes=11,
dim=64,
spatial_depth=1,
temporal_depth=1,
heads=4,
mlp_dim=128,
pool="cls",
channels=1,
dim_head=64,
dropout=0.0,
emb_dropout=0.0,
)
model.summary()
train_model(model)