API Reference


alibi_position_embedding(mask::Union{AbstractAttenMask, Nothing}, score, args...)

Add the non-trainable ALiBi position embedding to the attention score. The ALiBi embedding varies per head, which assumes a multi-head attention variant. The first dimension of the batch dimension of the attention score is treated as the head dimension (if used with single-head attention, the ALiBi value would vary across batches). mask can be either an attention mask or nothing. It is usually needed when there are gaps or prefix paddings in the samples.
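
To make the per-head bias concrete, here is a minimal plain-Julia sketch of the idea. It does not use the library; alibi_slopes and alibi_bias are hypothetical helper names, the slopes follow the geometric sequence from the ALiBi paper for a power-of-two number of heads, and the symmetric distance abs(i - j) is an assumption made for illustration.

```julia
# Hypothetical helpers, not part of the library API.
# Slope for head h out of nheads (nheads assumed to be a power of two).
alibi_slopes(nheads) = [2.0^(-8h / nheads) for h in 1:nheads]

# Bias for one head over a (key length, query length) score matrix:
# the penalty grows linearly with the distance between key i and query j.
alibi_bias(slope, klen, qlen) = [-slope * abs(i - j) for i in 1:klen, j in 1:qlen]

slopes = alibi_slopes(8)
bias = alibi_bias(slopes[1], 4, 4)   # would be added to the score of head 1
```

Each head gets a different slope, which is why the embedding only makes sense when a head dimension exists.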

biased_score(bias, score, args...)

Add a precomputed bias to the attention score. bias should have shape (key length, query length, ...), and size(bias, 1) == size(s, 1) == size(bias, 2) == size(s, 2) && ndims(bias) <= ndims(s) must hold, where s = score(args...).

layer_norm([epsilon = 1e-5,] alpha, beta, x)

Perform layer normalization on x. alpha and beta can each be a Vector, Number or Nothing.

$layer_norm(α, β, x) = α\frac{(x - μ)}{σ} + β$

If both alpha and beta are Nothing, this is just standardization applied along the first dimension.
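
The formula can be sketched in plain Julia as follows. This is an illustration rather than the library's implementation: layer_norm_sketch is a hypothetical name, and adding ϵ to the variance inside the square root is an assumption about where epsilon enters.

```julia
using Statistics

# Hypothetical re-implementation for illustration only.
function layer_norm_sketch(α, β, x; ϵ = 1e-5)
    μ = mean(x; dims = 1)                                   # per-column mean
    σ = sqrt.(var(x; dims = 1, corrected = false) .+ ϵ)     # assumed placement of ϵ
    y = (x .- μ) ./ σ
    y = α === nothing ? y : α .* y                          # optional scale
    return β === nothing ? y : y .+ β                       # optional shift
end

x = randn(8, 3)
y = layer_norm_sketch(nothing, nothing, x)   # plain standardization
z = layer_norm_sketch(2.0, 1.0, x)           # scaled and shifted
```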

move_head_dim_in_perm(x::AbstractArray{T, N}, nobatch=false)
move_head_dim_in_perm(N::Int, nobatch=false)

Dimension order for permutedims to move the head dimension (created by split_head) from the batch dimension to the feature dimension (for merge_head). Returns a tuple of integers of length N. nobatch specifies whether x is a batch of data.


julia> Functional.move_head_dim_in_perm(5, false)
(1, 4, 2, 3, 5)

julia> Functional.move_head_dim_in_perm(5, true)
(1, 5, 2, 3, 4)

See also: merge_head, move_head_dim_in

move_head_dim_out_perm(x::AbstractArray{T, N}, nobatch=false)
move_head_dim_out_perm(N::Int, nobatch=false)

Dimension order for permutedims to move the head dimension (created by split_head) to the batch dimension. Returns a tuple of integers of length N. nobatch specifies whether x is a batch of data.


julia> Functional.move_head_dim_out_perm(5, false)
(1, 3, 4, 2, 5)

julia> Functional.move_head_dim_out_perm(5, true)
(1, 3, 4, 5, 2)

See also: split_head, move_head_dim_out
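
The two permutations shown in the examples are inverses of each other, which can be checked with Base functions alone. The sketch below copies the N = 5, nobatch = false tuples from the examples above as literal vectors; the array sizes are arbitrary.

```julia
out_perm = [1, 3, 4, 2, 5]   # from move_head_dim_out_perm(5, false)
in_perm  = [1, 4, 2, 3, 5]   # from move_head_dim_in_perm(5, false)

x = reshape(collect(1:2*3*4*5*6), 2, 3, 4, 5, 6)

y = permutedims(x, out_perm)       # dimension 2 moves to position 4
roundtrip = permutedims(y, in_perm)
```

Applying the "in" permutation after the "out" permutation restores the original layout.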

naive_qkv_attention(q, k, v, mask=nothing)

The scaled dot-product attention of a regular transformer layer.

$Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V$

It's equivalent to generic_qkv_attention(weighted_sum_mixing, normalized_score(NNlib.softmax) $ masked_score(GenericMaskOp(), mask) $ scaled_dot_product_score, q, k, v).


julia> fdim, ldim, bdim = 32, 10, 4;

julia> x = randn(fdim, ldim, bdim);

julia> y = naive_qkv_attention(x, x, x); # simple self attention

# no mask here
julia> z = generic_qkv_attention(weighted_sum_mixing, normalized_score(NNlib.softmax) $ scaled_dot_product_score, x, x, x);

julia> y ≈ z
true

See also: generic_qkv_attention
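
For a single sample, the formula can be written out in plain Julia. This is a sketch under the layout used on this page (features in the first dimension, score of shape (key length, query length)), not the library's code; softmax_dim1 and attention_sketch are hypothetical names.

```julia
# Softmax over the first (key) dimension, written out to avoid dependencies.
function softmax_dim1(s)
    e = exp.(s .- maximum(s; dims = 1))
    return e ./ sum(e; dims = 1)
end

# q: (feature, query length), k: (feature, key length), v: (value feature, key length)
function attention_sketch(q, k, v)
    score = (k' * q) .* sqrt(inv(size(k, 1)))   # (key length, query length)
    a = softmax_dim1(score)                     # each query's weights sum to 1
    return v * a                                # (value feature, query length)
end

q = randn(32, 10); k = randn(32, 7); v = randn(16, 7)
y = attention_sketch(q, k, v)
```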

normalized_score(norm) = normalized_score $ norm
normalized_score(norm, score, args...)

Normalized attention score API. norm is the normalization function (like softmax) and score is the function that computes the attention score from args....

See also: naive_qkv_attention

rms_layer_norm([epsilon = 1e-5,] alpha, x)

Perform root-mean-square layer normalization on x. alpha can be a Vector, Number or Nothing.

$rms_layer_norm(α, x) = α\frac{x}{\sqrt{\sum_{i=1}^{N} x^2 / N}}$

If alpha is Nothing, this is just root-mean-square normalization applied along the first dimension.
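
A plain-Julia sketch of the same computation, for illustration only. rms_layer_norm_sketch is a hypothetical name, and adding ϵ under the square root is an assumption about how epsilon is used.

```julia
# Hypothetical re-implementation for illustration only.
function rms_layer_norm_sketch(α, x; ϵ = 1e-5)
    rms = sqrt.(sum(abs2, x; dims = 1) ./ size(x, 1) .+ ϵ)   # per-column RMS
    y = x ./ rms
    return α === nothing ? y : α .* y
end

x = randn(8, 3)
y = rms_layer_norm_sketch(nothing, x; ϵ = 0.0)
```

With ϵ = 0, every column of y has a root-mean-square of 1.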

scalar_relative_position_embedding(relative_position_id_func, embedding_table, score, args...)

A relative position embedding that produces a trainable scalar bias for each value in the attention score. relative_position_id_func is a function that takes the attention score and returns a relative_position_id matrix with the same size as the attention score with batches (normally (key length, query length)). This relative_position_id is used to index (or gather) the embedding_table. embedding_table is a multi-dimensional array whose first dimension is the number of possible "id"s; the remaining dimensions give different values to each head. By default, the last dimension of the attention score is treated as the batch dimension, and the dimensions between the last dimension and the "length" dimensions are treated as the head dimensions.

scaled_dot_product_score(q, k, s = sqrt(inv(size(k, 1))))

The scaled dot-product attention score function of a regular transformer layer.

$Score(Q, K) = \frac{QK^T}{\sqrt{d_k}}$

scaled_dot_product_score(f, q, k)

Apply a transform function f on q/k before dot-product.

See also: naive_qkv_attention

split_head(head::Int, x)

Split the first dimension into head pieces of smaller vectors. Equivalent to reshape(x, :, head, tail(size(x))...).
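
The reshape above can be tried directly with Base functions. The sizes below are arbitrary and only illustrate how the feature dimension is divided.

```julia
x = reshape(collect(1.0:80.0), 8, 5, 2)   # (feature, length, batch)
head = 4

# Same expression as in the description, with tail taken from Base.
y = reshape(x, :, head, Base.tail(size(x))...)

size(y)   # (2, 4, 5, 2): 8 features become 4 heads of 2 features each
```

Because Julia arrays are column-major, head 1 holds features 1:2, head 2 holds features 3:4, and so on.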

with_rotary_position_embedding([size,] x)

Apply rotary position embedding to x. Can take a size argument, in which case the rotary position embedding is only applied to x[1:size, :, ...]. Should be used with scaled_dot_product_score/dot_product_score.

PrefixedFunction(f, args::NTuple{N}) <: Function

A type representing a partially-applied version of the function f, with the first N arguments fixed to the values args. In other words, PrefixedFunction(f, args) behaves similarly to (xs...)->f(args..., xs...).

See also NeuralAttentionlib.:$.
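
The behavior described above can be imitated with a small callable struct. This is a sketch of the idea only; Prefixed is a hypothetical stand-in and not the library's type.

```julia
# Hypothetical stand-in that stores a function and its leading arguments.
struct Prefixed{F, A <: Tuple} <: Function
    f::F
    args::A
end

# Calling it splices the stored arguments in front of the new ones.
(p::Prefixed)(xs...) = p.f(p.args..., xs...)

add3 = Prefixed(+, (1, 2))
add3(3)                               # same as +(1, 2, 3)

join_ab = Prefixed(string, ("a", "b"))
join_ab("c")                          # same as string("a", "b", "c")
```

Storing the arguments in a struct rather than a closure keeps them inspectable and allows dispatch on the wrapped function type.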



apply_mask(op::GenericMaskOp, mask::AbstractMask, score)

Equivalent to op.apply(score, op.scale .* (op.flip ? .! mask : mask)).


julia> x = randn(10, 10);

julia> m = CausalMask()

julia> apply_mask(GenericMaskOp(.+, true, -1e9), m, x) == @. x + (!m * -1e9)
true

BandPartMask(l::Int, u::Int) <: AbstractAttenMask{DATALESS}

Attention mask that only allows band_part values to pass.


julia> trues(10, 10) .* BandPartMask(3, 5)
10×10 BitMatrix:
 1  1  1  1  1  1  0  0  0  0
 1  1  1  1  1  1  1  0  0  0
 1  1  1  1  1  1  1  1  0  0
 1  1  1  1  1  1  1  1  1  0
 0  1  1  1  1  1  1  1  1  1
 0  0  1  1  1  1  1  1  1  1
 0  0  0  1  1  1  1  1  1  1
 0  0  0  0  1  1  1  1  1  1
 0  0  0  0  0  1  1  1  1  1
 0  0  0  0  0  0  1  1  1  1
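
The pattern printed above can be reproduced with a plain comprehension. This sketch is inferred from that output for non-negative l and u only; band is a hypothetical helper, not a library function.

```julia
# Entry (i, j) passes when it lies at most l below and at most u above the diagonal.
band(l, u, n) = [i - j <= l && j - i <= u for i in 1:n, j in 1:n]

m = band(3, 5, 10)
m[1, :]    # first row: six ones followed by four zeros
```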

BatchedMask(mask::AbstractMask) <: AbstractWrapperMask

Attention mask wrapper over array mask for applying the same mask within the same batch.


julia> m = SymLengthMask([2,3])
SymLengthMask{1, Vector{Int32}}(Int32[2, 3])

julia> trues(3,3, 2) .* m
3×3×2 BitArray{3}:
[:, :, 1] =
 1  1  0
 1  1  0
 0  0  0

[:, :, 2] =
 1  1  1
 1  1  1
 1  1  1

julia> trues(3,3, 2, 2) .* m
ERROR: DimensionMismatch("arrays could not be broadcast to a common size; mask require ndims(A) == 3")

julia> trues(3,3, 2, 2) .* BatchedMask(m) # the 4th dim becomes the batch dim
3×3×2×2 BitArray{4}:
[:, :, 1, 1] =
 1  1  0
 1  1  0
 0  0  0

[:, :, 2, 1] =
 1  1  0
 1  1  0
 0  0  0

[:, :, 1, 2] =
 1  1  1
 1  1  1
 1  1  1

[:, :, 2, 2] =
 1  1  1
 1  1  1
 1  1  1
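
The result above matches an ordinary broadcast in which the mask's batch axis is moved to the last dimension. The sketch below shows that equivalence with plain arrays; sym is a hypothetical helper that builds the same pattern as SymLengthMask([2,3]) on a 3×3 score.

```julia
# Hypothetical helper: dense Bool array with the SymLengthMask pattern.
sym(len, n) = [i <= len[b] && j <= len[b] for i in 1:n, j in 1:n, b in eachindex(len)]

m = sym([2, 3], 3)                     # 3×3×2, one slice per sample
x = trues(3, 3, 2, 2)                  # (key, query, head, batch)

# Put the mask's batch axis on dimension 4 and broadcast over dimension 3.
y = x .& reshape(m, 3, 3, 1, 2)
```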

BiLengthMask(q_len::A, k_len::A) where {A <: AbstractArray{Int, N}} <: AbstractAttenMask{ARRAYDATA}

Attention mask specified by two arrays of integers that indicate the size of the length dimensions (query lengths and key lengths).


julia> bm = BiLengthMask([2,3], [3, 5])
BiLengthMask{1, Vector{Int32}}(Int32[2, 3], Int32[3, 5])

julia> trues(5,5, 2) .* bm
5×5×2 BitArray{3}:
[:, :, 1] =
 1  1  0  0  0
 1  1  0  0  0
 1  1  0  0  0
 0  0  0  0  0
 0  0  0  0  0

[:, :, 2] =
 1  1  1  0  0
 1  1  1  0  0
 1  1  1  0  0
 1  1  1  0  0
 1  1  1  0  0

See also: SymLengthMask, BiSeqMask, BatchedMask, RepeatMask
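
The printed mask corresponds to a simple rule on the (key, query, batch) indices. The comprehension below is inferred from the example output and is only an illustration; bilength is a hypothetical helper.

```julia
# Entry (i, j, b) passes when key i and query j are both within sample b's lengths.
bilength(q_len, k_len, klen, qlen) =
    [i <= k_len[b] && j <= q_len[b] for i in 1:klen, j in 1:qlen, b in eachindex(q_len)]

m = bilength([2, 3], [3, 5], 5, 5)
m[:, :, 1]   # ones in rows 1:3 (keys) and columns 1:2 (queries)
```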

BiSeqMask(qmask::A1, kmask::A2) where {A1 <: AbstractSeqMask, A2 <: AbstractSeqMask} <: AbstractAttenMask

Take two sequence masks and construct an attention mask.


julia> trues(7, 7, 2) .* Masks.BiSeqMask(Masks.LengthMask([3, 5]), Masks.RevLengthMask([3, 5]))
7×7×2 BitArray{3}:
[:, :, 1] =
 0  0  0  0  0  0  0
 0  0  0  0  0  0  0
 0  0  0  0  0  0  0
 0  0  0  0  0  0  0
 1  1  1  0  0  0  0
 1  1  1  0  0  0  0
 1  1  1  0  0  0  0

[:, :, 2] =
 0  0  0  0  0  0  0
 0  0  0  0  0  0  0
 1  1  1  1  1  0  0
 1  1  1  1  1  0  0
 1  1  1  1  1  0  0
 1  1  1  1  1  0  0
 1  1  1  1  1  0  0

See also: BiLengthMask, RevBiLengthMask
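
For one sample, combining a query mask and a key mask amounts to an outer "and". The sketch below rebuilds the first slice of the example with plain Bool vectors; qmask and kmask are illustrative stand-ins for LengthMask([3]) and RevLengthMask([3]) on a length-7 sequence.

```julia
n = 7
qmask = [j <= 3 for j in 1:n]        # first 3 queries are valid
kmask = [i > n - 3 for i in 1:n]     # last 3 keys are valid

# Rows index keys and columns index queries, matching (key length, query length).
atten = kmask .& permutedims(qmask)
```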

CausalMask() <: AbstractAttenMask{DATALESS}

Attention mask that blocks future values.

Similar to applying LinearAlgebra.triu! on the score matrix.


julia> trues(10, 10) .* CausalMask()
10×10 BitMatrix:
 1  1  1  1  1  1  1  1  1  1
 0  1  1  1  1  1  1  1  1  1
 0  0  1  1  1  1  1  1  1  1
 0  0  0  1  1  1  1  1  1  1
 0  0  0  0  1  1  1  1  1  1
 0  0  0  0  0  1  1  1  1  1
 0  0  0  0  0  0  1  1  1  1
 0  0  0  0  0  0  0  1  1  1
 0  0  0  0  0  0  0  0  1  1
 0  0  0  0  0  0  0  0  0  1
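
Since the score has shape (key length, query length), blocking the future means key i is visible to query j only when i <= j. A plain-Julia sketch of that rule, with causal as a hypothetical helper:

```julia
using LinearAlgebra

# Key i may be attended by query j only if it is not in the future.
causal(n) = [i <= j for i in 1:n, j in 1:n]

m = causal(10)
m == triu(ones(Bool, 10, 10))   # same pattern as the upper triangle
```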

GenericAttenMask <: AbstractAttenMask{ARRAYDATA}

Generic attention mask. Just a wrapper over AbstractArray{Bool} for dispatch.


julia> bitmask = rand(Bool, 10, 10)
10×10 Matrix{Bool}:
 1  0  1  1  0  0  1  0  1  1
 0  0  1  1  0  0  0  1  1  1
 0  1  0  1  0  1  0  0  1  0
 0  1  1  0  1  1  0  0  0  1
 1  0  1  1  1  0  0  0  0  1
 1  0  1  0  1  1  1  1  0  1
 0  0  0  1  1  1  0  1  1  1
 1  0  1  0  1  1  1  0  0  1
 0  1  0  1  0  0  1  1  0  1
 0  0  0  1  0  1  0  0  0  1

julia> trues(10, 10) .* GenericAttenMask(bitmask)
10×10 BitMatrix:
 1  0  1  1  0  0  1  0  1  1
 0  0  1  1  0  0  0  1  1  1
 0  1  0  1  0  1  0  0  1  0
 0  1  1  0  1  1  0  0  0  1
 1  0  1  1  1  0  0  0  0  1
 1  0  1  0  1  1  1  1  0  1
 0  0  0  1  1  1  0  1  1  1
 1  0  1  0  1  1  1  0  0  1
 0  1  0  1  0  0  1  1  0  1
 0  0  0  1  0  1  0  0  0  1

GenericSeqMask(mask::AbstractArray{Bool}) <: AbstractSeqMask{ARRAYDATA}

Create a sequence mask from an array of Bool.


julia> m = GenericSeqMask(rand(Bool, 10, 2))
GenericSeqMask{3, Array{Bool, 3}}([0 1 … 0 0;;; 1 0 … 1 0])

julia> trues(7, 10, 2) .* m
7×10×2 BitArray{3}:
[:, :, 1] =
 0  1  0  0  1  0  0  0  0  0
 0  1  0  0  1  0  0  0  0  0
 0  1  0  0  1  0  0  0  0  0
 0  1  0  0  1  0  0  0  0  0
 0  1  0  0  1  0  0  0  0  0
 0  1  0  0  1  0  0  0  0  0
 0  1  0  0  1  0  0  0  0  0

[:, :, 2] =
 1  0  1  1  0  1  1  1  1  0
 1  0  1  1  0  1  1  1  1  0
 1  0  1  1  0  1  1  1  1  0
 1  0  1  1  0  1  1  1  1  0
 1  0  1  1  0  1  1  1  1  0
 1  0  1  1  0  1  1  1  1  0
 1  0  1  1  0  1  1  1  1  0

julia> m.mask
1×10×2 Array{Bool, 3}:
[:, :, 1] =
 0  1  0  0  1  0  0  0  0  0

[:, :, 2] =
 1  0  1  1  0  1  1  1  1  0

Indexer(m::AbstractMask, size::Dims{N}) <: AbstractArray{Bool, N}
Indexer(m::AbstractMask, size::Dims{N}, scale::T) <: AbstractArray{T, N}

A lazy array-like object that "materializes" the mask m with the given size and an optional scale, without size checks.

See also: GetIndexer

LengthMask(len::AbstractArray{Int, N}) <: AbstractSeqMask{ARRAYDATA}

A sequence mask specified by an array of integers that indicate the size of the length dimension. Can be converted to an attention mask (SymLengthMask, BiLengthMask) with AttenMask.


julia> ones(7, 7, 2) .* LengthMask([3, 5])
7×7×2 Array{Float64, 3}:
[:, :, 1] =
 1.0  1.0  1.0  0.0  0.0  0.0  0.0
 1.0  1.0  1.0  0.0  0.0  0.0  0.0
 1.0  1.0  1.0  0.0  0.0  0.0  0.0
 1.0  1.0  1.0  0.0  0.0  0.0  0.0
 1.0  1.0  1.0  0.0  0.0  0.0  0.0
 1.0  1.0  1.0  0.0  0.0  0.0  0.0
 1.0  1.0  1.0  0.0  0.0  0.0  0.0

[:, :, 2] =
 1.0  1.0  1.0  1.0  1.0  0.0  0.0
 1.0  1.0  1.0  1.0  1.0  0.0  0.0
 1.0  1.0  1.0  1.0  1.0  0.0  0.0
 1.0  1.0  1.0  1.0  1.0  0.0  0.0
 1.0  1.0  1.0  1.0  1.0  0.0  0.0
 1.0  1.0  1.0  1.0  1.0  0.0  0.0
 1.0  1.0  1.0  1.0  1.0  0.0  0.0

LocalMask(width::Int) <: AbstractAttenMask{DATALESS}

Attention mask that only allows local (diagonal-like) values to pass.

width should be ≥ 0, and A .* LocalMask(1) is similar to Diagonal(A).


julia> trues(10, 10) .* LocalMask(3)
10×10 BitMatrix:
 1  1  1  0  0  0  0  0  0  0
 1  1  1  1  0  0  0  0  0  0
 1  1  1  1  1  0  0  0  0  0
 0  1  1  1  1  1  0  0  0  0
 0  0  1  1  1  1  1  0  0  0
 0  0  0  1  1  1  1  1  0  0
 0  0  0  0  1  1  1  1  1  0
 0  0  0  0  0  1  1  1  1  1
 0  0  0  0  0  0  1  1  1  1
 0  0  0  0  0  0  0  1  1  1
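
The output above is consistent with keeping entries whose distance from the diagonal is smaller than width. The comprehension below is inferred from that output and is not the library's code; local_mask is a hypothetical helper.

```julia
using LinearAlgebra

# Entry (i, j) passes when it is fewer than `width` steps from the diagonal.
local_mask(width, n) = [abs(i - j) < width for i in 1:n, j in 1:n]

m = local_mask(3, 10)
d = local_mask(1, 4)    # width 1 keeps only the diagonal
```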

NoMask{T}() <: AbstractDatalessMask{T}

A mask that masks nothing. It exists only to satisfy the type constraints of wrapper masks. Generally, use nothing instead of NoMask with apply_mask/mask_score to take the fast path.

RandomMask(p::Float32) <: AbstractAttenMask{DATALESS}

Attention mask that blocks values randomly.

p specifies the proportion of values to block, e.g. A .* RandomMask(0) is equivalent to identity(A) and A .* RandomMask(1) is equivalent to zero(A).


julia> trues(10, 10) .* RandomMask(0.5)
10×10 BitMatrix:
 1  1  1  1  1  1  0  1  1  1
 0  0  1  0  1  0  0  0  1  0
 0  0  1  1  0  0  0  0  1  1
 1  0  1  1  1  0  0  1  0  1
 1  1  0  1  0  0  1  0  1  1
 0  1  1  1  1  0  1  0  1  1
 1  1  0  0  0  0  1  0  0  0
 0  0  1  0  1  1  0  1  1  0
 1  1  1  1  1  1  0  0  1  1
 0  0  1  0  1  1  0  0  1  0

julia> trues(10, 10) .* RandomMask(0.5)
10×10 BitMatrix:
 1  0  1  1  0  0  1  1  0  1
 0  1  0  1  1  1  0  0  1  1
 0  0  1  0  0  0  1  1  0  0
 0  0  0  0  1  0  0  1  1  1
 0  1  1  1  1  0  1  0  0  1
 1  0  0  1  1  0  0  0  1  1
 1  1  1  0  1  1  1  0  0  0
 0  0  1  1  0  0  1  1  1  0
 0  1  1  1  1  0  1  0  1  0
 0  0  1  0  0  0  0  1  1  1

RepeatMask(mask::AbstractMask, num::Int) <: AbstractWrapperMask

Attention mask wrapper over array mask for doing inner repeat on the last dimension.


julia> m = SymLengthMask([2,3])
SymLengthMask{1, Vector{Int32}}(Int32[2, 3])

julia> trues(3,3, 2) .* m
3×3×2 BitArray{3}:
[:, :, 1] =
 1  1  0
 1  1  0
 0  0  0

[:, :, 2] =
 1  1  1
 1  1  1
 1  1  1

julia> trues(3,3, 4) .* m
ERROR: DimensionMismatch("arrays could not be broadcast to a common size; mask require 3-th dimension to be 2, but get 4")

julia> trues(3,3, 4) .* RepeatMask(m, 2)
3×3×4 BitArray{3}:
[:, :, 1] =
 1  1  0
 1  1  0
 0  0  0

[:, :, 2] =
 1  1  0
 1  1  0
 0  0  0

[:, :, 3] =
 1  1  1
 1  1  1
 1  1  1

[:, :, 4] =
 1  1  1
 1  1  1
 1  1  1
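
The "inner repeat" shown above behaves like Base's repeat with the inner keyword on a dense mask array. A small sketch of that correspondence, where sym is a hypothetical helper producing the SymLengthMask([2,3]) pattern:

```julia
# Hypothetical helper: dense Bool array with the SymLengthMask pattern.
sym(len, n) = [i <= len[b] && j <= len[b] for i in 1:n, j in 1:n, b in eachindex(len)]

m = sym([2, 3], 3)                    # 3×3×2
r = repeat(m; inner = (1, 1, 2))      # 3×3×4, each slice repeated twice in a row
```

The inner repeat keeps copies of the same sample adjacent (1, 1, 2, 2), as opposed to an outer repeat, which would give (1, 2, 1, 2).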

RevBiLengthMask(q_len::A, k_len::A) where {A <: AbstractArray{Int, N}} <: AbstractAttenMask{ARRAYDATA}

BiLengthMask but counts from the end of array, used for left padding.


julia> bm = RevBiLengthMask([2,3], [3, 5])
RevBiLengthMask{1, Vector{Int32}}(Int32[2, 3], Int32[3, 5])

julia> trues(5,5, 2) .* bm
5×5×2 BitArray{3}:
[:, :, 1] =
 0  0  0  0  0
 0  0  0  0  0
 0  0  0  1  1
 0  0  0  1  1
 0  0  0  1  1

[:, :, 2] =
 0  0  1  1  1
 0  0  1  1  1
 0  0  1  1  1
 0  0  1  1  1
 0  0  1  1  1

See also: RevLengthMask, RevSymLengthMask, BiSeqMask, BatchedMask, RepeatMask

RevLengthMask(len::AbstractArray{Int, N}) <: AbstractSeqMask{ARRAYDATA}

LengthMask but counts from the end of array, used for left padding. Can be converted to an attention mask (RevSymLengthMask, RevBiLengthMask) with AttenMask.


julia> ones(7, 7, 2) .* RevLengthMask([3, 5])
7×7×2 Array{Float64, 3}:
[:, :, 1] =
 0.0  0.0  0.0  0.0  1.0  1.0  1.0
 0.0  0.0  0.0  0.0  1.0  1.0  1.0
 0.0  0.0  0.0  0.0  1.0  1.0  1.0
 0.0  0.0  0.0  0.0  1.0  1.0  1.0
 0.0  0.0  0.0  0.0  1.0  1.0  1.0
 0.0  0.0  0.0  0.0  1.0  1.0  1.0
 0.0  0.0  0.0  0.0  1.0  1.0  1.0

[:, :, 2] =
 0.0  0.0  1.0  1.0  1.0  1.0  1.0
 0.0  0.0  1.0  1.0  1.0  1.0  1.0
 0.0  0.0  1.0  1.0  1.0  1.0  1.0
 0.0  0.0  1.0  1.0  1.0  1.0  1.0
 0.0  0.0  1.0  1.0  1.0  1.0  1.0
 0.0  0.0  1.0  1.0  1.0  1.0  1.0
 0.0  0.0  1.0  1.0  1.0  1.0  1.0
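
The difference between counting from the start and counting from the end can be sketched with dense arrays. The helpers length_mask and rev_length_mask are hypothetical and only mirror the example outputs; they build a 1×n×batch Bool array that broadcasts over the first dimension.

```julia
# Valid positions are the first len[b] entries of the length dimension.
length_mask(len, n) = [j <= len[b] for i in 1:1, j in 1:n, b in eachindex(len)]

# Valid positions are the last len[b] entries (left padding).
rev_length_mask(len, n) = [j > n - len[b] for i in 1:1, j in 1:n, b in eachindex(len)]

y = ones(7, 7, 2) .* rev_length_mask([3, 5], 7)
z = ones(7, 7, 2) .* length_mask([3, 5], 7)
```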

RevSymLengthMask(len::AbstractArray{Int, N}) <: AbstractAttenMask{ARRAYDATA}

SymLengthMask but counts from the end of array, used for left padding.


julia> m = RevSymLengthMask([2,3])
RevSymLengthMask{1, Vector{Int32}}(Int32[2, 3])

julia> trues(3,3, 2) .* m
3×3×2 BitArray{3}:
[:, :, 1] =
 0  0  0
 0  1  1
 0  1  1

[:, :, 2] =
 1  1  1
 1  1  1
 1  1  1

See also: BiLengthMask, BatchedMask, RepeatMask

SymLengthMask(len::AbstractArray{Int, N}) <: AbstractAttenMask{ARRAYDATA}

Attention mask specified by an array of integers that indicate the size of the length dimension, assuming the query length and key length are the same.


julia> m = SymLengthMask([2,3])
SymLengthMask{1, Vector{Int32}}(Int32[2, 3])

julia> trues(3,3, 2) .* m
3×3×2 BitArray{3}:
[:, :, 1] =
 1  1  0
 1  1  0
 0  0  0

[:, :, 2] =
 1  1  1
 1  1  1
 1  1  1

See also: LengthMask, BiLengthMask, BatchedMask, RepeatMask


Boolean not of an attention mask.

m1::AbstractMask & m2::AbstractMask

Logical and of two attention masks.

m1::AbstractMask | m2::AbstractMask

Logical or of two attention masks.


Convert mask into corresponding attention mask.

AttenMask(q_mask::AbstractSeqMask, k_mask::AbstractSeqMask)

Create an attention mask from two sequence masks, which specify the sequence masks for "query" and "key" respectively.

getmask(m::AbstractMask, score, scale = 1)

Convert m into a mask array (an AbstractArray) for score, scaled by scale.


julia> getmask(CausalMask(), randn(7,7), 2)
7×7 Matrix{Float64}:
 2.0  2.0  2.0  2.0  2.0  2.0  2.0
 0.0  2.0  2.0  2.0  2.0  2.0  2.0
 0.0  0.0  2.0  2.0  2.0  2.0  2.0
 0.0  0.0  0.0  2.0  2.0  2.0  2.0
 0.0  0.0  0.0  0.0  2.0  2.0  2.0
 0.0  0.0  0.0  0.0  0.0  2.0  2.0
 0.0  0.0  0.0  0.0  0.0  0.0  2.0


collapsed_size(x, ni, nj [, n])::Dim{3}

Collapse the dimensionality of x into 3 according to ni and nj, where ni and nj specify how many of the original dimensions the second and third collapsed dimensions take.

(X1, X2, ..., Xk, Xk+1, Xk+2, ..., Xk+ni, Xk+ni+1, ..., Xn)
 |______dim1___|  |_________ni_________|  |______nj______|


julia> x = randn(7,6,5,4,3,2);

julia> collapsed_size(x, 2, 2, 1)
42

julia> collapsed_size(x, 2, 2, 2)
20

julia> collapsed_size(x, 2, 2, 3)
6

julia> collapsed_size(x, 2, 2)
(42, 20, 6)

See also: noncollapsed_size
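
The grouping in the diagram can be expressed directly on a size tuple. This is a plain-Julia sketch of the arithmetic, with collapsed_size_sketch as a hypothetical name that takes the size tuple instead of the array.

```julia
# Group a size tuple into (leading dims, next ni dims, last nj dims) and multiply each group.
function collapsed_size_sketch(sz::Tuple, ni, nj)
    k = length(sz) - ni - nj                    # number of leading dimensions
    return (prod(sz[1:k]; init = 1),
            prod(sz[k+1:k+ni]; init = 1),
            prod(sz[k+ni+1:end]; init = 1))
end

collapsed_size_sketch((7, 6, 5, 4, 3, 2), 2, 2)   # (7*6, 5*4, 3*2)
```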

matmul(a::AbstractArray, b::AbstractArray, s::Number = 1)

Equivalent to s .* (a * b) if a and b are Vector or Matrix. For arrays with higher dimensions, it converts a and b to CollapsedDimsArray, performs batched matrix multiplication, and returns the result as a CollapsedDimsArray. This is useful for preserving the dimensionality. If the batch dimensions of a and b have different shapes, it picks the shape of b for the batch dimension. Works with NNlib.batch_transpose and NNlib.batch_adjoint.


# b-dim shape: (6,)
julia> a = CollapsedDimsArray(randn(3,4,2,3,6), 2, 1); size(a)
(12, 6, 6)

# b-dim shape: (3,1,2)
julia> b = CollapsedDimsArray(randn(6,2,3,1,2), 1, 3); size(b)
(6, 2, 6)

julia> c = matmul(a, b); size(c), typeof(c)
((12, 2, 6), CollapsedDimsArray{Float64, Array{Float64, 6}, Static.StaticInt{1}, Static.StaticInt{3}})

# b-dim shape: (3,1,2)
julia> d = unwrap_collapse(c); size(d), typeof(d)
((3, 4, 2, 3, 1, 2), Array{Float64, 6})

# equivalent to `batched_mul` but preserves shape
julia> NNlib.batched_mul(collapseddims(a), collapseddims(b)) == collapseddims(matmul(a, b))
true

See also: CollapsedDimsArray, unwrap_collapse, collapseddims

noncollapsed_size(x, ni, nj [, n])

Collapse the dimensionality of x into 3 according to ni and nj.

(X1, X2, ..., Xk, Xk+1, Xk+2, ..., Xk+ni, Xk+ni+1, ..., Xn)
 |______dim1___|  |_________ni_________|  |______nj______|

But return the sizes before collapsing, e.g. noncollapsed_size(x, ni, nj, 2) will be (Xi, Xi+1, ..., Xj-1).


julia> x = randn(7,6,5,4,3,2);

julia> noncollapsed_size(x, 2, 2, 1)
(7, 6)

julia> noncollapsed_size(x, 2, 2, 2)
(5, 4)

julia> noncollapsed_size(x, 2, 2, 3)
(3, 2)

julia> noncollapsed_size(x, 2, 2)
((7, 6), (5, 4), (3, 2))

See also: collapsed_size

scaled_matmul(a::AbstractArray, b::AbstractArray, s::Number = 1)

Basically equivalent to unwrap_collapse(matmul(a, b, s)), but not differentiable w.r.t. s.
