
ICLR 2026 · interactive exposition · paper #01

Deep Hierarchical Learning with
Nested Subspace Networks

Paulius Rauba · Mihaela van der Schaar · University of Cambridge

Modern neural networks are trained for a fixed compute budget, a rigid bargain between performance and cost. We instead reparameterize each linear layer as a low-rank product W_r = B[:, :r] A[:r, :] over shared factors A and B, so that every rank-r sub-model is a strict prefix of every higher-rank one. A single set of weights then encodes a continuous hierarchy of models that can be traversed at inference time by simply choosing r.

[Interactive dial: active rank r out of the max rank, with live readouts of capacity, MACs per token as a fraction of the dense layer, and estimated accuracy relative to the full-rank model.]

panel 01

Subspace flow

A token vector is projected through a shared factor stack. Only the first r channels carry signal — everything beyond rank r is dark, but instantly available the moment you slide the dial up.

x ∈ ℝ^{d_in}  →  A·x ∈ ℝ^r  →  B(A·x) ∈ ℝ^{d_out}
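A minimal sketch of this flow, assuming shared factors A ∈ ℝ^{R×d_in} and B ∈ ℝ^{d_out×R} with the same shapes as the reference layer in section 05 (sizes here are illustrative):

import torch

d_in, d_out, R, r = 64, 32, 48, 24      # illustrative sizes; r is the active rank

A = torch.randn(R, d_in)                # shared down-projection factor
B = torch.randn(d_out, R)               # shared up-projection factor
x = torch.randn(d_in)                   # one token vector

h = A[:r, :] @ x                        # only the first r channels carry signal
y = B[:, :r] @ h                        # rank-r output in d_out dimensions

print(h.shape, y.shape)                 # torch.Size([24]) torch.Size([32])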

panel 02

Matrix factorization · live

The dense weight W_r = B[:, :r] A[:r, :] is rebuilt from the first r columns of B and the first r rows of A. Watch the high-frequency texture of W_r sharpen as the rank grows.

[Interactive matrices: B with the first r columns active · A with the first r rows active.]
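A small sketch of the reconstruction the panel animates, using illustrative random factors; because each W_r is built from a prefix of the same factors, it is a literal sub-computation of W_{r+1}:

import torch

d_in, d_out, R = 64, 32, 48
A = torch.randn(R, d_in)
B = torch.randn(d_out, R)

def dense_weight(r):
    # first r columns of B times first r rows of A
    return B[:, :r] @ A[:r, :]          # (d_out, d_in)

W_full = dense_weight(R)
for r in (1, 8, 24, 48):
    err = torch.norm(W_full - dense_weight(r)) / torch.norm(W_full)
    print(f"rank {r:2d}: relative error {err:.3f}")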

panel 03

SVD · on a synthetic target

The same construction, applied to a synthetic target. A real decompose_and_replace truncates the SVD of each pretrained weight to the active rank.

[Panels: target · rank-r reconstruction · |target − rank-r| · singular-value spectrum (σ_k) with an energy-retained readout.]
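The page names a decompose_and_replace step but does not show its signature; a minimal sketch of the underlying idea (truncated SVD of a pretrained weight, folded into the two factors) might look like this. The function name and the choice to fold the singular values into B are illustrative assumptions, not the reference implementation:

import torch

def svd_factors(W, max_rank):
    # W: pretrained (d_out, d_in) weight matrix.
    # Returns A, B such that B[:, :r] @ A[:r, :] is the best rank-r
    # approximation of W for every r <= max_rank (Eckart-Young).
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    A = Vh[:max_rank, :]                    # (max_rank, d_in)
    B = U[:, :max_rank] * S[:max_rank]      # (d_out, max_rank); singular values folded into B
    return A, B, S

W = torch.randn(128, 64)                    # stand-in for a pretrained weight
A, B, S = svd_factors(W, max_rank=48)
r = 16
energy = (S[:r] ** 2).sum() / (S ** 2).sum()   # the "energy retained" readout in the panel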

panel 04

Inference budget

Per-token MACs at rank r are exactly r · (d_in + d_out), a straight line in r. The accuracy, by contrast, saturates fast.

[Interactive: MACs saved vs. a dense layer, and accuracy as the active rank sweeps 1 → R. Example readout: d_model = 2048, max rank = 1024, active rank = 512, net Δ = −50.0% MACs.]
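The straight line is easy to reproduce. A quick sketch of the per-token arithmetic, using the example sizes from the panel (the r · d_in term is the A projection, the d_out · r term is the B projection):

def macs_per_token(d_in, d_out, r):
    # A_r x costs r * d_in multiply-accumulates, B_r h costs d_out * r
    return r * (d_in + d_out)

d_model = 2048
dense   = d_model * d_model                      # a square dense layer for comparison
low     = macs_per_token(d_model, d_model, 512)  # active rank 512 out of max 1024
print(f"saved: {1 - low / dense:.1%}")           # saved: 50.0%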

section 05

How it actually works

The mechanism is a few lines of PyTorch — verbatim from the reference implementation.

Each nn.Linear(d_in, d_out) is replaced by two factors A ∈ ℝ^{R×d_in} and B ∈ ℝ^{d_out×R}, where R is the maximum rank. At inference, calling set_rank(r) simply slices the shared factors:

Because rank r uses a prefix of the same factors as any rank r′ > r, the image spaces are nested: span(B[:, :1]) ⊆ span(B[:, :2]) ⊆ ⋯ ⊆ span(B[:, :R]).

Training jointly couples ranks via an uncertainty-aware objective (learned log-variances per rank) plus a curriculum that introduces low ranks gradually. The result is a single set of weights that behaves as a continuum of models — anywhere from a 50% FLOPs reduction at a 5pp accuracy drop, to interpolating to ranks the network never even saw at training time.

import torch.nn as nn
import torch.nn.functional as F

class DynamicLowRankLinear(nn.Module):
    def __init__(self, in_features, out_features, max_rank, bias=True):
        super().__init__()
        self.in_features  = in_features
        self.out_features = out_features
        self.max_rank     = max_rank
        self.active_rank  = max_rank

        self.A = nn.Linear(in_features,  max_rank,    bias=False)
        self.B = nn.Linear(max_rank,     out_features, bias=bias)

    def set_rank(self, rank):
        self.active_rank = max(1, min(rank, self.max_rank))

    def forward(self, x):
        r = self.active_rank
        A_r = self.A.weight[:r, :]      # (r, in_features)
        B_r = self.B.weight[:, :r]      # (out_features, r)
        h   = F.linear(x, A_r)
        return F.linear(h, B_r, self.B.bias)
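The uncertainty-aware objective is only described in prose above. A minimal sketch of one plausible form, assuming Kendall-style uncertainty weighting with a learned log-variance per trained rank and reusing DynamicLowRankLinear from the block above; the rank set, the weighting form, and the class name MultiRankObjective are illustrative assumptions, not the paper's exact recipe:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiRankObjective(nn.Module):
    def __init__(self, ranks):
        super().__init__()
        self.ranks = ranks
        # one learned log-variance per rank, as in uncertainty-weighted multi-task losses
        self.log_var = nn.Parameter(torch.zeros(len(ranks)))

    def forward(self, model, batch, targets):
        total = 0.0
        for i, r in enumerate(self.ranks):
            for m in model.modules():                 # set every nested layer to rank r
                if isinstance(m, DynamicLowRankLinear):
                    m.set_rank(r)
            loss_r = F.cross_entropy(model(batch), targets)
            total = total + torch.exp(-self.log_var[i]) * loss_r + self.log_var[i]
        return total

A curriculum in the spirit of the description above would then grow the rank set during training, starting near the full rank and gradually adding lower ranks.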

section 06

Cite

@inproceedings{rauba2026deep,
  title     = {Deep Hierarchical Learning with Nested Subspace Networks},
  author    = {Paulius Rauba and Mihaela van der Schaar},
  booktitle = {The Fourteenth International Conference on Learning Representations},
  year      = {2026},
  url       = {https://openreview.net/forum?id=ymUOPsbxLi}
}