Examples 3D/Space-Time

!pip install k3im medmnist -qq --upgrade

import os
os.environ['KERAS_BACKEND'] = 'jax'
import keras
import medmnist
import numpy as np
import tensorflow as tf # For data processes only
DATASET_NAME = "organmnist3d"
BATCH_SIZE = 32
AUTO = tf.data.AUTOTUNE
num_classes = 11

def download_and_prepare_dataset(data_info: dict):
    """Utility function to download the dataset.

    Arguments:
        data_info (dict): Dataset metadata.
    """
    data_path = keras.utils.get_file(origin=data_info["url"], md5_hash=data_info["MD5"])

    with np.load(data_path) as data:
        # Get videos
        train_videos = data["train_images"]
        valid_videos = data["val_images"]
        test_videos = data["test_images"]

        # Get labels
        train_labels = data["train_labels"].flatten()
        valid_labels = data["val_labels"].flatten()
        test_labels = data["test_labels"].flatten()

    return (
        (train_videos, train_labels),
        (valid_videos, valid_labels),
        (test_videos, test_labels),
    )


# Get the metadata of the dataset
info = medmnist.INFO[DATASET_NAME]

# Get the dataset
prepared_dataset = download_and_prepare_dataset(info)
(train_videos, train_labels) = prepared_dataset[0]
(valid_videos, valid_labels) = prepared_dataset[1]
(test_videos, test_labels) = prepared_dataset[2]
@tf.function
def preprocess(frames: tf.Tensor, label: tf.Tensor):
    """Preprocess the frames tensors and parse the labels."""
    # Preprocess images
    frames = tf.image.convert_image_dtype(
        frames[
            ..., tf.newaxis
        ],  # The new axis is to help for further processing with Conv3D layers
        tf.float32,
    )
    # Parse label
    label = tf.cast(label, tf.float32)
    return frames, label


def prepare_dataloader(
    videos: np.ndarray,
    labels: np.ndarray,
    loader_type: str = "train",
    batch_size: int = BATCH_SIZE,
):
    """Utility function to prepare the dataloader."""
    dataset = tf.data.Dataset.from_tensor_slices((videos, labels))

    if loader_type == "train":
        dataset = dataset.shuffle(BATCH_SIZE * 2)

    dataloader = (
        dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )
    return dataloader


trainloader = prepare_dataloader(train_videos, train_labels, "train")
validloader = prepare_dataloader(valid_videos, valid_labels, "valid")
testloader = prepare_dataloader(test_videos, test_labels, "test")
batch_size = 32
epochs = 2
def train_model(model):
    model.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer="adam", metrics=["accuracy"])
    model.fit(trainloader, epochs=epochs, validation_data=validloader)
    score = model.evaluate(testloader)
    print("Test loss:", score[0])
    print("Test accuracy:", score[1])
from k3im.cait_3d import CAiT3DModel # fixed jax ✅
model = CAiT3DModel(
    image_size=28,
    image_patch_size=7,
    frames=28,
    frame_patch_size=7,
    num_classes=11,
    dim=128,
    depth=3,
    cls_depth=3,
    heads=8,
    mlp_dim=84,
    channels=1,
    dim_head=64,
)
model.summary()
Model: "functional_15"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)               Output Shape               Param #  Connected to               ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ input_layer_6             │ (None, 28, 28, 28, 1)  │          0 │ -                          │
│ (InputLayer)              │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ reshape (Reshape)         │ (None, 7, 4, 7, 4, 7,  │          0 │ input_layer_6[0][0]        │
│                           │ 4, 1)                  │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ transpose (Transpose)     │ (None, 4, 4, 4, 7, 7,  │          0 │ reshape[0][0]              │
│                           │ 7, 1)                  │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ reshape_1 (Reshape)       │ (None, 4, 4, 4, 343)   │          0 │ transpose[0][0]            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization       │ (None, 4, 4, 4, 343)   │        686 │ reshape_1[0][0]            │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_2 (Dense)           │ (None, 4, 4, 4, 128)   │     44,032 │ layer_normalization[0][0]  │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_1     │ (None, 4, 4, 4, 128)   │        256 │ dense_2[0][0]              │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ reshape_2 (Reshape)       │ (None, 64, 128)        │          0 │ layer_normalization_1[0][ │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ multi_head_attention      │ (None, 64, 128)        │    263,808 │ reshape_2[0][0],           │
│ (MultiHeadAttention)      │                        │            │ reshape_2[0][0]            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_4 (Add)               │ (None, 64, 128)        │          0 │ reshape_2[0][0],           │
│                           │                        │            │ multi_head_attention[0][0] │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ sequential_4 (Sequential) │ (None, 64, 128)        │     21,972 │ add_4[0][0]                │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_5 (Add)               │ (None, 64, 128)        │          0 │ add_4[0][0],               │
│                           │                        │            │ sequential_4[0][0]         │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ multi_head_attention_1    │ (None, 64, 128)        │    263,808 │ add_5[0][0], add_5[0][0]   │
│ (MultiHeadAttention)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_6 (Add)               │ (None, 64, 128)        │          0 │ add_5[0][0],               │
│                           │                        │            │ multi_head_attention_1[0]… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ sequential_5 (Sequential) │ (None, 64, 128)        │     21,972 │ add_6[0][0]                │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_7 (Add)               │ (None, 64, 128)        │          0 │ add_6[0][0],               │
│                           │                        │            │ sequential_5[0][0]         │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ multi_head_attention_2    │ (None, 64, 128)        │    263,808 │ add_7[0][0], add_7[0][0]   │
│ (MultiHeadAttention)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_8 (Add)               │ (None, 64, 128)        │          0 │ add_7[0][0],               │
│                           │                        │            │ multi_head_attention_2[0]… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ sequential_6 (Sequential) │ (None, 64, 128)        │     21,972 │ add_8[0][0]                │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_9 (Add)               │ (None, 64, 128)        │          0 │ add_8[0][0],               │
│                           │                        │            │ sequential_6[0][0]         │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_5     │ (None, 64, 128)        │        256 │ add_9[0][0]                │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ cls__token (CLS_Token)    │ [(None, 65, 128),      │        128 │ layer_normalization_5[0][ │
│                           │ (None, 1, 128)]        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ concatenate (Concatenate) │ (None, 65, 128)        │          0 │ cls__token[0][1],          │
│                           │                        │            │ layer_normalization_5[0][ │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ multi_head_attention_3    │ (None, 1, 128)         │    263,808 │ cls__token[0][1],          │
│ (MultiHeadAttention)      │                        │            │ concatenate[0][0]          │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_10 (Add)              │ (None, 1, 128)         │          0 │ cls__token[0][1],          │
│                           │                        │            │ multi_head_attention_3[0]… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ sequential_7 (Sequential) │ (None, 1, 128)         │     21,972 │ add_10[0][0]               │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_11 (Add)              │ (None, 1, 128)         │          0 │ add_10[0][0],              │
│                           │                        │            │ sequential_7[0][0]         │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ concatenate_1             │ (None, 65, 128)        │          0 │ add_11[0][0],              │
│ (Concatenate)             │                        │            │ layer_normalization_5[0][ │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ multi_head_attention_4    │ (None, 1, 128)         │    263,808 │ add_11[0][0],              │
│ (MultiHeadAttention)      │                        │            │ concatenate_1[0][0]        │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_12 (Add)              │ (None, 1, 128)         │          0 │ add_11[0][0],              │
│                           │                        │            │ multi_head_attention_4[0]… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ sequential_8 (Sequential) │ (None, 1, 128)         │     21,972 │ add_12[0][0]               │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_13 (Add)              │ (None, 1, 128)         │          0 │ add_12[0][0],              │
│                           │                        │            │ sequential_8[0][0]         │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ concatenate_2             │ (None, 65, 128)        │          0 │ add_13[0][0],              │
│ (Concatenate)             │                        │            │ layer_normalization_5[0][ │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ multi_head_attention_5    │ (None, 1, 128)         │    263,808 │ add_13[0][0],              │
│ (MultiHeadAttention)      │                        │            │ concatenate_2[0][0]        │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_14 (Add)              │ (None, 1, 128)         │          0 │ add_13[0][0],              │
│                           │                        │            │ multi_head_attention_5[0]… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ sequential_9 (Sequential) │ (None, 1, 128)         │     21,972 │ add_14[0][0]               │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_15 (Add)              │ (None, 1, 128)         │          0 │ add_14[0][0],              │
│                           │                        │            │ sequential_9[0][0]         │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_9     │ (None, 1, 128)         │        256 │ add_15[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ squeeze (Squeeze)         │ (None, 128)            │          0 │ layer_normalization_9[0][ │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_15 (Dense)          │ (None, 11)             │      1,419 │ squeeze[0][0]              │
└───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
 Total params: 1,761,713 (6.72 MB)
 Trainable params: 1,761,713 (6.72 MB)
 Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 29s 596ms/step - accuracy: 0.4146 - loss: 1.8926 - val_accuracy: 0.9255 - val_loss: 0.2574
Epoch 2/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 1s 22ms/step - accuracy: 0.7658 - loss: 0.7132 - val_accuracy: 0.9068 - val_loss: 0.2550
20/20 ━━━━━━━━━━━━━━━━━━━━ 5s 268ms/step - accuracy: 0.7239 - loss: 1.0073
Test loss: 0.9272577166557312
Test accuracy: 0.7180327773094177

from k3im.cct_3d import CCT3DModel # jax ✅, tensorflow ✅, torch ✅
model = CCT3DModel(input_shape=(28, 28, 28, 1),
    num_heads=4,
    projection_dim=64,
    kernel_size=4,
    stride=4,
    padding=2,
    transformer_units=[16, 64],
    stochastic_depth_rate=0.6,
    transformer_layers=2,
    num_classes=num_classes,
    positional_emb=False,)
model.summary()
Model: "functional_18"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)               Output Shape               Param #  Connected to               ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ input_layer_13            │ (None, 28, 28, 28, 1)  │          0 │ -                          │
│ (InputLayer)              │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ cct_tokenizer3d           │ (None, 27, 64)         │    266,240 │ input_layer_13[0][0]       │
│ (CCTTokenizer3D)          │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_10    │ (None, 27, 64)         │        128 │ cct_tokenizer3d[0][0]      │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ multi_head_attention_6    │ (None, 27, 64)         │     66,368 │ layer_normalization_10[0]… │
│ (MultiHeadAttention)      │                        │            │ layer_normalization_10[0]… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ stochastic_depth          │ (None, 27, 64)         │          0 │ multi_head_attention_6[0]… │
│ (StochasticDepth)         │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_16 (Add)              │ (None, 27, 64)         │          0 │ stochastic_depth[0][0],    │
│                           │                        │            │ cct_tokenizer3d[0][0]      │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_11    │ (None, 27, 64)         │        128 │ add_16[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_16 (Dense)          │ (None, 27, 16)         │      1,040 │ layer_normalization_11[0]… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dropout_7 (Dropout)       │ (None, 27, 16)         │          0 │ dense_16[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_17 (Dense)          │ (None, 27, 64)         │      1,088 │ dropout_7[0][0]            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dropout_8 (Dropout)       │ (None, 27, 64)         │          0 │ dense_17[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ stochastic_depth_1        │ (None, 27, 64)         │          0 │ dropout_8[0][0]            │
│ (StochasticDepth)         │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_17 (Add)              │ (None, 27, 64)         │          0 │ stochastic_depth_1[0][0],  │
│                           │                        │            │ add_16[0][0]               │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_12    │ (None, 27, 64)         │        128 │ add_17[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ multi_head_attention_7    │ (None, 27, 64)         │     66,368 │ layer_normalization_12[0]… │
│ (MultiHeadAttention)      │                        │            │ layer_normalization_12[0]… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ stochastic_depth_2        │ (None, 27, 64)         │          0 │ multi_head_attention_7[0]… │
│ (StochasticDepth)         │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_18 (Add)              │ (None, 27, 64)         │          0 │ stochastic_depth_2[0][0],  │
│                           │                        │            │ add_17[0][0]               │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_13    │ (None, 27, 64)         │        128 │ add_18[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_18 (Dense)          │ (None, 27, 16)         │      1,040 │ layer_normalization_13[0]… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dropout_10 (Dropout)      │ (None, 27, 16)         │          0 │ dense_18[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_19 (Dense)          │ (None, 27, 64)         │      1,088 │ dropout_10[0][0]           │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dropout_11 (Dropout)      │ (None, 27, 64)         │          0 │ dense_19[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ stochastic_depth_3        │ (None, 27, 64)         │          0 │ dropout_11[0][0]           │
│ (StochasticDepth)         │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_19 (Add)              │ (None, 27, 64)         │          0 │ stochastic_depth_3[0][0],  │
│                           │                        │            │ add_18[0][0]               │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_14    │ (None, 27, 64)         │        128 │ add_19[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ sequence_pooling          │ (None, 64)             │         65 │ layer_normalization_14[0]… │
│ (SequencePooling)         │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_21 (Dense)          │ (None, 11)             │        715 │ sequence_pooling[0][0]     │
└───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
 Total params: 404,652 (1.54 MB)
 Trainable params: 404,652 (1.54 MB)
 Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 18s 311ms/step - accuracy: 0.1076 - loss: 2.4914 - val_accuracy: 0.1801 - val_loss: 2.2091
Epoch 2/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - accuracy: 0.2319 - loss: 2.0499 - val_accuracy: 0.4099 - val_loss: 1.4763
20/20 ━━━━━━━━━━━━━━━━━━━━ 3s 134ms/step - accuracy: 0.3311 - loss: 1.6405
Test loss: 1.6255520582199097
Test accuracy: 0.34590163826942444

from k3im.convmixer_3d import ConvMixer3DModel # jax ✅, # tensorflow ✅,  #torch fail
model = ConvMixer3DModel(image_size=28,
    num_frames=28,
    filters=32,
    depth=2,
    kernel_size=4,
    kernel_depth=3,
    patch_size=3,
    patch_depth=3,
    num_classes=10,
    num_channels=1)
model.summary()
Model: "functional_22"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)               Output Shape               Param #  Connected to               ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ input_layer_15            │ (None, 28, 28, 28, 1)  │          0 │ -                          │
│ (InputLayer)              │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ rescaling_2 (Rescaling)   │ (None, 28, 28, 28, 1)  │          0 │ input_layer_15[0][0]       │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ conv3d_16 (Conv3D)        │ (None, 9, 9, 9, 32)    │        896 │ rescaling_2[0][0]          │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ activation_10             │ (None, 9, 9, 9, 32)    │          0 │ conv3d_16[0][0]            │
│ (Activation)              │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ batch_normalization_10    │ (None, 9, 9, 9, 32)    │        128 │ activation_10[0][0]        │
│ (BatchNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ conv2_plus1d_4            │ (None, 9, 9, 9, 32)    │     19,520 │ batch_normalization_10[0]… │
│ (Conv2Plus1D)             │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ activation_11             │ (None, 9, 9, 9, 32)    │          0 │ conv2_plus1d_4[0][0]       │
│ (Activation)              │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ batch_normalization_11    │ (None, 9, 9, 9, 32)    │        128 │ activation_11[0][0]        │
│ (BatchNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_20 (Add)              │ (None, 9, 9, 9, 32)    │          0 │ batch_normalization_11[0]… │
│                           │                        │            │ batch_normalization_10[0]… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ conv3d_19 (Conv3D)        │ (None, 9, 9, 9, 32)    │      1,056 │ add_20[0][0]               │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ activation_12             │ (None, 9, 9, 9, 32)    │          0 │ conv3d_19[0][0]            │
│ (Activation)              │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ batch_normalization_12    │ (None, 9, 9, 9, 32)    │        128 │ activation_12[0][0]        │
│ (BatchNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ conv2_plus1d_5            │ (None, 9, 9, 9, 32)    │     19,520 │ batch_normalization_12[0]… │
│ (Conv2Plus1D)             │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ activation_13             │ (None, 9, 9, 9, 32)    │          0 │ conv2_plus1d_5[0][0]       │
│ (Activation)              │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ batch_normalization_13    │ (None, 9, 9, 9, 32)    │        128 │ activation_13[0][0]        │
│ (BatchNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_21 (Add)              │ (None, 9, 9, 9, 32)    │          0 │ batch_normalization_13[0]… │
│                           │                        │            │ batch_normalization_12[0]… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ conv3d_22 (Conv3D)        │ (None, 9, 9, 9, 32)    │      1,056 │ add_21[0][0]               │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ activation_14             │ (None, 9, 9, 9, 32)    │          0 │ conv3d_22[0][0]            │
│ (Activation)              │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ batch_normalization_14    │ (None, 9, 9, 9, 32)    │        128 │ activation_14[0][0]        │
│ (BatchNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ global_average_pooling3d… │ (None, 32)             │          0 │ batch_normalization_14[0]… │
│ (GlobalAveragePooling3D)  │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_22 (Dense)          │ (None, 10)             │        330 │ global_average_pooling3d_… │
└───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
 Total params: 43,018 (168.04 KB)
 Trainable params: 42,698 (166.79 KB)
 Non-trainable params: 320 (1.25 KB)
train_model(model)
Epoch 1/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 8s 151ms/step - accuracy: 0.2606 - loss: 1.7970 - val_accuracy: 0.0994 - val_loss: 2.1118
Epoch 2/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 1s 21ms/step - accuracy: 0.4494 - loss: 1.4487 - val_accuracy: 0.0994 - val_loss: 2.1147
20/20 ━━━━━━━━━━━━━━━━━━━━ 1s 66ms/step - accuracy: 0.0999 - loss: 2.0564
Test loss: 2.065363645553589
Test accuracy: 0.11311475187540054

from k3im.eanet3d import EANet3DModel # jax ✅, tensorflow ✅, torch ✅
model = EANet3DModel(
    image_size=28,
    image_patch_size=7,
    frames=28,
    frame_patch_size=7,
    num_classes=num_classes,
    dim=64,
    depth=2,
    heads=4,
    mlp_dim=32,
    channels=1,
    dim_coefficient=4,
    projection_dropout=0.0,
    attention_dropout=0,
)
model.summary()
Model: "functional_26"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)               Output Shape               Param #  Connected to               ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ input_layer_18            │ (None, 28, 28, 28, 1)  │          0 │ -                          │
│ (InputLayer)              │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ reshape_3 (Reshape)       │ (None, 7, 4, 7, 4, 7,  │          0 │ input_layer_18[0][0]       │
│                           │ 4, 1)                  │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ transpose_1 (Transpose)   │ (None, 4, 4, 4, 7, 7,  │          0 │ reshape_3[0][0]            │
│                           │ 7, 1)                  │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ reshape_4 (Reshape)       │ (None, 4, 4, 4, 343)   │          0 │ transpose_1[0][0]          │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_15    │ (None, 4, 4, 4, 343)   │        686 │ reshape_4[0][0]            │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_23 (Dense)          │ (None, 4, 4, 4, 64)    │     22,016 │ layer_normalization_15[0]… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_16    │ (None, 4, 4, 4, 64)    │        128 │ dense_23[0][0]             │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ reshape_5 (Reshape)       │ (None, 64, 64)         │          0 │ layer_normalization_16[0]… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_24 (Dense)          │ (None, 64, 256)        │     16,640 │ reshape_5[0][0]            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ reshape_6 (Reshape)       │ (None, 64, 16, 16)     │          0 │ dense_24[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ transpose_2 (Transpose)   │ (None, 16, 64, 16)     │          0 │ reshape_6[0][0]            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_25 (Dense)          │ (None, 16, 64, 16)     │        272 │ transpose_2[0][0]          │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ softmax_8 (Softmax)       │ (None, 16, 64, 16)     │          0 │ dense_25[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ lambda (Lambda)           │ (None, 16, 64, 16)     │          0 │ softmax_8[0][0]            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dropout_12 (Dropout)      │ (None, 16, 64, 16)     │          0 │ lambda[0][0]               │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_26 (Dense)          │ (None, 16, 64, 16)     │        272 │ dropout_12[0][0]           │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ transpose_3 (Transpose)   │ (None, 64, 16, 16)     │          0 │ dense_26[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ reshape_7 (Reshape)       │ (None, 64, 256)        │          0 │ transpose_3[0][0]          │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_27 (Dense)          │ (None, 64, 64)         │     16,448 │ reshape_7[0][0]            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dropout_13 (Dropout)      │ (None, 64, 64)         │          0 │ dense_27[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_22 (Add)              │ (None, 64, 64)         │          0 │ reshape_5[0][0],           │
│                           │                        │            │ dropout_13[0][0]           │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ sequential_13             │ (None, 64, 64)         │      4,320 │ add_22[0][0]               │
│ (Sequential)              │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_23 (Add)              │ (None, 64, 64)         │          0 │ add_22[0][0],              │
│                           │                        │            │ sequential_13[0][0]        │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_30 (Dense)          │ (None, 64, 256)        │     16,640 │ add_23[0][0]               │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ reshape_8 (Reshape)       │ (None, 64, 16, 16)     │          0 │ dense_30[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ transpose_4 (Transpose)   │ (None, 16, 64, 16)     │          0 │ reshape_8[0][0]            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_31 (Dense)          │ (None, 16, 64, 16)     │        272 │ transpose_4[0][0]          │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ softmax_9 (Softmax)       │ (None, 16, 64, 16)     │          0 │ dense_31[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ lambda_1 (Lambda)         │ (None, 16, 64, 16)     │          0 │ softmax_9[0][0]            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dropout_14 (Dropout)      │ (None, 16, 64, 16)     │          0 │ lambda_1[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_32 (Dense)          │ (None, 16, 64, 16)     │        272 │ dropout_14[0][0]           │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ transpose_5 (Transpose)   │ (None, 64, 16, 16)     │          0 │ dense_32[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ reshape_9 (Reshape)       │ (None, 64, 256)        │          0 │ transpose_5[0][0]          │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_33 (Dense)          │ (None, 64, 64)         │     16,448 │ reshape_9[0][0]            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dropout_15 (Dropout)      │ (None, 64, 64)         │          0 │ dense_33[0][0]             │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_24 (Add)              │ (None, 64, 64)         │          0 │ add_23[0][0],              │
│                           │                        │            │ dropout_15[0][0]           │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ sequential_14             │ (None, 64, 64)         │      4,320 │ add_24[0][0]               │
│ (Sequential)              │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_25 (Add)              │ (None, 64, 64)         │          0 │ add_24[0][0],              │
│                           │                        │            │ sequential_14[0][0]        │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_19    │ (None, 64, 64)         │        128 │ add_25[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ avg_pool                  │ (None, 64)             │          0 │ layer_normalization_19[0]… │
│ (GlobalAveragePooling1D)  │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_36 (Dense)          │ (None, 11)             │        715 │ avg_pool[0][0]             │
└───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
 Total params: 99,577 (388.97 KB)
 Trainable params: 99,577 (388.97 KB)
 Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 12s 248ms/step - accuracy: 0.3390 - loss: 2.0201 - val_accuracy: 0.9130 - val_loss: 0.5298
Epoch 2/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 1s 22ms/step - accuracy: 0.7829 - loss: 0.8192 - val_accuracy: 0.9441 - val_loss: 0.2836
20/20 ━━━━━━━━━━━━━━━━━━━━ 3s 133ms/step - accuracy: 0.7109 - loss: 0.9908
Test loss: 0.9404497146606445
Test accuracy: 0.7180327773094177

from k3im.fnet_3d import FNet3DModel # jax ✅ , tensorflow ✅, torch ✅
model = FNet3DModel(
    image_size=28,
    image_patch_size=7,
    frames=28,
    frame_patch_size=7,
    num_classes=11,
    dim=256,
    depth=5,
    hidden_units=256,
    dropout_rate=0.6,
    channels=1,
)
model.summary()
Model: "functional_33"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓
┃ Layer (type)                        Output Shape                       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩
│ input_layer_21 (InputLayer)        │ (None, 28, 28, 28, 1)         │           0 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ reshape_10 (Reshape)               │ (None, 7, 4, 7, 4, 7, 4, 1)   │           0 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ transpose_6 (Transpose)            │ (None, 4, 4, 4, 7, 7, 7, 1)   │           0 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ reshape_11 (Reshape)               │ (None, 4, 4, 4, 343)          │           0 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ layer_normalization_20             │ (None, 4, 4, 4, 343)          │         686 │
│ (LayerNormalization)               │                               │             │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ dense_37 (Dense)                   │ (None, 4, 4, 4, 256)          │      88,064 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ layer_normalization_21             │ (None, 4, 4, 4, 256)          │         512 │
│ (LayerNormalization)               │                               │             │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ reshape_12 (Reshape)               │ (None, 64, 256)               │           0 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ f_net_layer (FNetLayer)            │ (None, 64, 256)               │     132,608 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ f_net_layer_1 (FNetLayer)          │ (None, 64, 256)               │     132,608 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ f_net_layer_2 (FNetLayer)          │ (None, 64, 256)               │     132,608 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ f_net_layer_3 (FNetLayer)          │ (None, 64, 256)               │     132,608 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ f_net_layer_4 (FNetLayer)          │ (None, 64, 256)               │     132,608 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ avg_pool (GlobalAveragePooling1D)  │ (None, 256)                   │           0 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ dense_48 (Dense)                   │ (None, 11)                    │       2,827 │
└────────────────────────────────────┴───────────────────────────────┴─────────────┘
 Total params: 755,129 (2.88 MB)
 Trainable params: 755,129 (2.88 MB)
 Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 26s 681ms/step - accuracy: 0.1565 - loss: 2.3726 - val_accuracy: 0.6149 - val_loss: 1.6513
Epoch 2/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - accuracy: 0.6299 - loss: 1.4855 - val_accuracy: 0.7391 - val_loss: 0.7485
20/20 ━━━━━━━━━━━━━━━━━━━━ 3s 134ms/step - accuracy: 0.5232 - loss: 1.4321
Test loss: 1.3722052574157715
Test accuracy: 0.5278688669204712

from k3im.gmlp_3d import gMLP3DModel # jax ✅ , tensorflow ✅, torch ✅
model = gMLP3DModel(
    image_size=28,
    image_patch_size=7,
    frames=28,
    frame_patch_size=7,
    num_classes=11,
    dim=32,
    depth=4,
    hidden_units=32,
    dropout_rate=0.4,
    channels=1,
)
train_model(model)
Epoch 1/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 23s 552ms/step - accuracy: 0.2921 - loss: 2.1528 - val_accuracy: 0.7329 - val_loss: 0.9436
Epoch 2/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.7036 - loss: 1.1000 - val_accuracy: 0.9193 - val_loss: 0.4642
20/20 ━━━━━━━━━━━━━━━━━━━━ 3s 134ms/step - accuracy: 0.6815 - loss: 1.1396
Test loss: 1.0670228004455566
Test accuracy: 0.685245931148529

from k3im.mlp_mixer_3d import MLPMixer3DModel # jax ✅, tensorflow ✅, torch ✅

model = MLPMixer3DModel(
    image_size=28,
    image_patch_size=7,
    frames=28,
    frame_patch_size=7,
    num_classes=num_classes,
    dim=32,
    depth=4,
    hidden_units=32,
    dropout_rate=0.4,
    channels=1,
)
model.summary()
Model: "functional_49"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓
┃ Layer (type)                        Output Shape                       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩
│ input_layer_32 (InputLayer)        │ (None, 28, 28, 28, 1)         │           0 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ reshape_16 (Reshape)               │ (None, 7, 4, 7, 4, 7, 4, 1)   │           0 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ transpose_8 (Transpose)            │ (None, 4, 4, 4, 7, 7, 7, 1)   │           0 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ reshape_17 (Reshape)               │ (None, 4, 4, 4, 343)          │           0 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ layer_normalization_42             │ (None, 4, 4, 4, 343)          │         686 │
│ (LayerNormalization)               │                               │             │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ dense_63 (Dense)                   │ (None, 4, 4, 4, 32)           │      11,008 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ layer_normalization_43             │ (None, 4, 4, 4, 32)           │          64 │
│ (LayerNormalization)               │                               │             │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ reshape_18 (Reshape)               │ (None, 64, 32)                │           0 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ mlp_mixer_layer (MLPMixerLayer)    │ (None, 64, 32)                │      12,576 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ mlp_mixer_layer_1 (MLPMixerLayer)  │ (None, 64, 32)                │      12,576 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ mlp_mixer_layer_2 (MLPMixerLayer)  │ (None, 64, 32)                │      12,576 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ mlp_mixer_layer_3 (MLPMixerLayer)  │ (None, 64, 32)                │      12,576 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ avg_pool (GlobalAveragePooling1D)  │ (None, 32)                    │           0 │
├────────────────────────────────────┼───────────────────────────────┼─────────────┤
│ dense_80 (Dense)                   │ (None, 11)                    │         363 │
└────────────────────────────────────┴───────────────────────────────┴─────────────┘
 Total params: 62,425 (243.85 KB)
 Trainable params: 62,425 (243.85 KB)
 Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 24s 655ms/step - accuracy: 0.3077 - loss: 2.1705 - val_accuracy: 0.8944 - val_loss: 0.3976
Epoch 2/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - accuracy: 0.7351 - loss: 0.8639 - val_accuracy: 0.9317 - val_loss: 0.2207
20/20 ━━━━━━━━━━━━━━━━━━━━ 1s 66ms/step - accuracy: 0.7119 - loss: 1.0056
Test loss: 0.9401963353157043
Test accuracy: 0.7147541046142578

from k3im.simple_vit_3d import SimpleViT3DModel # jax ✅, tensorflow ✅, torch ✅

model = SimpleViT3DModel(
    image_size=28,
    image_patch_size=7,
    frames=28,
    frame_patch_size=7,
    num_classes=num_classes,
    dim=32,
    depth=2,
    heads=4,
    mlp_dim=32,
    channels=1,
    dim_head=64,
)
model.summary()
Model: "functional_53"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)               Output Shape               Param #  Connected to               ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ input_layer_41            │ (None, 28, 28, 28, 1)  │          0 │ -                          │
│ (InputLayer)              │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ reshape_19 (Reshape)      │ (None, 7, 4, 7, 4, 7,  │          0 │ input_layer_41[0][0]       │
│                           │ 4, 1)                  │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ transpose_9 (Transpose)   │ (None, 4, 4, 4, 7, 7,  │          0 │ reshape_19[0][0]           │
│                           │ 7, 1)                  │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ reshape_20 (Reshape)      │ (None, 4, 4, 4, 343)   │          0 │ transpose_9[0][0]          │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_48    │ (None, 4, 4, 4, 343)   │        686 │ reshape_20[0][0]           │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_81 (Dense)          │ (None, 4, 4, 4, 32)    │     11,008 │ layer_normalization_48[0]… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_49    │ (None, 4, 4, 4, 32)    │         64 │ dense_81[0][0]             │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ reshape_21 (Reshape)      │ (None, 64, 32)         │          0 │ layer_normalization_49[0]… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ multi_head_attention_8    │ (None, 64, 32)         │     33,568 │ reshape_21[0][0],          │
│ (MultiHeadAttention)      │                        │            │ reshape_21[0][0]           │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_26 (Add)              │ (None, 64, 32)         │          0 │ reshape_21[0][0],          │
│                           │                        │            │ multi_head_attention_8[0]… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ sequential_32             │ (None, 64, 32)         │      2,176 │ add_26[0][0]               │
│ (Sequential)              │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_27 (Add)              │ (None, 64, 32)         │          0 │ add_26[0][0],              │
│                           │                        │            │ sequential_32[0][0]        │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ multi_head_attention_9    │ (None, 64, 32)         │     33,568 │ add_27[0][0], add_27[0][0] │
│ (MultiHeadAttention)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_28 (Add)              │ (None, 64, 32)         │          0 │ add_27[0][0],              │
│                           │                        │            │ multi_head_attention_9[0]… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ sequential_33             │ (None, 64, 32)         │      2,176 │ add_28[0][0]               │
│ (Sequential)              │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_29 (Add)              │ (None, 64, 32)         │          0 │ add_28[0][0],              │
│                           │                        │            │ sequential_33[0][0]        │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_52    │ (None, 64, 32)         │         64 │ add_29[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ avg_pool                  │ (None, 32)             │          0 │ layer_normalization_52[0]… │
│ (GlobalAveragePooling1D)  │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_86 (Dense)          │ (None, 11)             │        363 │ avg_pool[0][0]             │
└───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
 Total params: 83,673 (326.85 KB)
 Trainable params: 83,673 (326.85 KB)
 Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 13s 296ms/step - accuracy: 0.3410 - loss: 2.0619 - val_accuracy: 0.8571 - val_loss: 0.7134
Epoch 2/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - accuracy: 0.7661 - loss: 0.9615 - val_accuracy: 0.9006 - val_loss: 0.4115
20/20 ━━━━━━━━━━━━━━━━━━━━ 1s 66ms/step - accuracy: 0.6486 - loss: 1.1380
Test loss: 1.0664023160934448
Test accuracy: 0.6672131419181824

from k3im.vit_3d import ViT3DModel # jax ✅, tensorflow ✅, torch ✅
model = ViT3DModel(
    image_size=28,
    image_patch_size=7,
    frames=28,
    frame_patch_size=7,
    num_classes=num_classes,
    dim=32,
    depth=2,
    heads=4,
    mlp_dim=32,
    pool='cls',
    channels=1,
    dim_head=64,
)
model.summary()
Model: "functional_57"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)               Output Shape               Param #  Connected to               ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ input_layer_44            │ (None, 28, 28, 28, 1)  │          0 │ -                          │
│ (InputLayer)              │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ reshape_22 (Reshape)      │ (None, 7, 4, 7, 4, 7,  │          0 │ input_layer_44[0][0]       │
│                           │ 4, 1)                  │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ transpose_10 (Transpose)  │ (None, 4, 4, 4, 7, 7,  │          0 │ reshape_22[0][0]           │
│                           │ 7, 1)                  │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ reshape_23 (Reshape)      │ (None, 4, 4, 4, 343)   │          0 │ transpose_10[0][0]         │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_53    │ (None, 4, 4, 4, 343)   │        686 │ reshape_23[0][0]           │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_87 (Dense)          │ (None, 4, 4, 4, 32)    │     11,008 │ layer_normalization_53[0]… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_54    │ (None, 4, 4, 4, 32)    │         64 │ dense_87[0][0]             │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ reshape_24 (Reshape)      │ (None, 64, 32)         │          0 │ layer_normalization_54[0]… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ class_token_position_emb  │ (None, 65, 32)         │      2,112 │ reshape_24[0][0]           │
│ (ClassTokenPositionEmb)   │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ multi_head_attention_10   │ (None, 65, 32)         │     33,568 │ class_token_position_emb[ │
│ (MultiHeadAttention)      │                        │            │ class_token_position_emb[ │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_30 (Add)              │ (None, 65, 32)         │          0 │ class_token_position_emb[ │
│                           │                        │            │ multi_head_attention_10[0… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ sequential_34             │ (None, 65, 32)         │      2,176 │ add_30[0][0]               │
│ (Sequential)              │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_31 (Add)              │ (None, 65, 32)         │          0 │ add_30[0][0],              │
│                           │                        │            │ sequential_34[0][0]        │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ multi_head_attention_11   │ (None, 65, 32)         │     33,568 │ add_31[0][0], add_31[0][0] │
│ (MultiHeadAttention)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_32 (Add)              │ (None, 65, 32)         │          0 │ add_31[0][0],              │
│                           │                        │            │ multi_head_attention_11[0… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ sequential_35             │ (None, 65, 32)         │      2,176 │ add_32[0][0]               │
│ (Sequential)              │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ add_33 (Add)              │ (None, 65, 32)         │          0 │ add_32[0][0],              │
│                           │                        │            │ sequential_35[0][0]        │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ layer_normalization_57    │ (None, 65, 32)         │         64 │ add_33[0][0]               │
│ (LayerNormalization)      │                        │            │                            │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ get_item (GetItem)        │ (None, 32)             │          0 │ layer_normalization_57[0]… │
├───────────────────────────┼────────────────────────┼────────────┼────────────────────────────┤
│ dense_92 (Dense)          │ (None, 11)             │        363 │ get_item[0][0]             │
└───────────────────────────┴────────────────────────┴────────────┴────────────────────────────┘
 Total params: 85,785 (335.10 KB)
 Trainable params: 85,785 (335.10 KB)
 Non-trainable params: 0 (0.00 B)
train_model(model)
Epoch 1/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 13s 262ms/step - accuracy: 0.3574 - loss: 1.9998 - val_accuracy: 0.8696 - val_loss: 0.5885
Epoch 2/2
31/31 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.7141 - loss: 0.9330 - val_accuracy: 0.9255 - val_loss: 0.3551
20/20 ━━━━━━━━━━━━━━━━━━━━ 1s 63ms/step - accuracy: 0.6813 - loss: 1.0153
Test loss: 0.963051438331604
Test accuracy: 0.6983606815338135

from k3im.video_eanet import VideoEANet # fixed jax ✅
model = VideoEANet(
    image_size=28,
    image_patch_size=7,
    frames=28,
    frame_patch_size=7,
    num_classes=num_classes,
    dim=64,
    spatial_depth=2,
    temporal_depth=2,
    heads=4,
    mlp_dim=128,
    pool="cls",
    channels=1,
    dim_coefficient=4,
    projection_dropout=0.0,
    attention_dropout=0,
    emb_dropout=0.0,
)
model.summary()
train_model(model)
from k3im.vivit import ViViT

model = ViViT(
    image_size=28,
    image_patch_size=7,
    frames=28,
    frame_patch_size=7,
    num_classes=11,
    dim=64,
    spatial_depth=1,
    temporal_depth=1,
    heads=4,
    mlp_dim=128,
    pool="cls",
    channels=1,
    dim_head=64,
    dropout=0.0,
    emb_dropout=0.0,
)
model.summary()
train_model(model)