class GPT2(nnx.Module):
def __init__(self, seqlen: int, vocab_size: int, embed_dim: int, num_heads: int, rate: float, feed_forward_dim: int, num_transformer_blocks: int, rngs: nnx.Rngs):
self.embedding_layer = TokenAndPositionEmbedding(
seqlen, vocab_size, embed_dim, rngs=rngs
)