Figure: Inside a Shared-Attention Block. Per-modality projections feed one shared attention op, which then splits back into per-modality MLPs. Video tokens \(x^v\) and action tokens \(x^a\) each pass through their own QKV projections, \(W_Q^v, W_K^v, W_V^v\) and \(W_Q^a, W_K^a, W_V^a\). The resulting queries, keys, and values are concatenated along the sequence dimension and fed to the single shared op

\(\text{Attn} \;=\; \text{softmax}\!\big([Q^v;\,Q^a]\,[K^v;\,K^a]^\top / \sqrt{d}\big)\;[V^v;\,V^a]\)

the only place where the two modalities mix. The attention output is split back into per-modality streams, each with its own output projection (\(W_O^v\), \(W_O^a\)), FFN, and LayerNorm.
The "shared" in shared attention is just the softmax. Every other weight matrix in the block — \(W_Q\), \(W_K\), \(W_V\), \(W_O\), FFN, LayerNorm — is duplicated per modality. The single thing both modalities reach for is the scaled-dot-product softmax in the middle, which sees the concatenated Q / K / V from both streams. That is enough for an action token's query to attend to a video token's K/V (and vice versa), without forcing them to share any other weights.