def CAiT3DModel(
image_size,
image_patch_size,
frames,
frame_patch_size,
num_classes,
dim,
depth,
cls_depth,
heads,
mlp_dim,
channels=3,
dim_head=64,
):
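    """Build a CaiT-style 3D (video) classifier.

    The clip is split into non-overlapping tubelets that are linearly
    embedded and processed by a self-attention transformer; a learnable CLS
    token then attends to the resulting tokens through `cls_depth`
    class-attention layers before the classification head.
    """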
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(image_patch_size)
assert (
image_height % patch_height == 0 and image_width % patch_width == 0
), "Image dimensions must be divisible by the patch size."
assert (
frames % frame_patch_size == 0
), "Frames must be divisible by the frame patch size"
nf, nh, nw = (
frames // frame_patch_size,
image_height // patch_height,
image_width // patch_width,
)
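    # Flattened size of a single tubelet (all pixels across frame_patch_size frames).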
patch_dim = channels * patch_height * patch_width * frame_patch_size
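    # Video input in channels-last layout: (frames, height, width, channels).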
i_p = layers.Input((frames, image_height, image_width, channels))
    # Split the clip into non-overlapping tubelets: each axis is factored with
    # the tubelet index outermost, so every tubelet covers contiguous frames and pixels.
    tubelets = layers.Reshape(
        (nf, frame_patch_size, nh, patch_height, nw, patch_width, channels)
    )(i_p)
    # Move the tubelet grid (nf, nh, nw) to the front and the within-tubelet
    # axes to the back, then flatten each tubelet into a patch_dim-long vector.
    tubelets = ops.transpose(tubelets, (0, 1, 3, 5, 2, 4, 6, 7))
    tubelets = layers.Reshape((nf, nh, nw, patch_dim))(tubelets)
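    # Embed each tubelet: LayerNorm, linear projection to the model width, LayerNorm.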
tubelets = layers.LayerNormalization()(tubelets)
tubelets = layers.Dense(dim)(tubelets)
tubelets = layers.LayerNormalization()(tubelets)
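    # Flatten the (nf, nh, nw) grid into one token sequence and run the
    # self-attention transformer over the tubelet tokens.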
tubelets = layers.Reshape((-1, dim))(tubelets)
tubelets = Transformer(dim, depth, heads, dim_head, mlp_dim)(tubelets)
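    # CaiT-style class attention: a learnable CLS token attends to the tubelet
    # tokens (passed as context) for cls_depth layers.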
_, cls_token = CLS_Token(dim)(tubelets)
cls_token = Transformer(dim, cls_depth, heads, dim_head, mlp_dim)(
cls_token, context=tubelets
)
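    # Drop the singleton token axis and map the CLS embedding to class logits.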
cls_token = ops.squeeze(cls_token, axis=1)
o_p = layers.Dense(num_classes)(cls_token)
return keras.Model(inputs=i_p, outputs=o_p)
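

# Minimal usage sketch (assumes the `pair`, `Transformer`, and `CLS_Token`
# helpers defined above are in scope); the hyperparameters below are
# illustrative only, not a reference configuration.
if __name__ == "__main__":
    model = CAiT3DModel(
        image_size=64,
        image_patch_size=16,
        frames=16,
        frame_patch_size=4,
        num_classes=10,
        dim=128,
        depth=4,
        cls_depth=2,
        heads=4,
        mlp_dim=256,
    )
    # One forward pass on a random batch of two 16-frame RGB clips.
    clips = keras.random.uniform((2, 16, 64, 64, 3))
    logits = model(clips)
    print(logits.shape)  # (2, 10)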