Embedding Minibatching Decomposition
Every address, opcode, operand index, custom-call name, and source-line tag on this page was read byte-exactly from libtpu.so in the libtpu-0.0.40-cp314 wheel (build-id 89edbbe81c5b328a958fe628a9f2207d, build libtpu_lts_20260413_b_RC00; not stripped, so nm -C resolves every method). .text VMA equals its file offset (0xe63c000); .rodata is at 0x84a0000; .data.rel.ro is VMA − 0x200000 (the operand-name tables, filled by R_X86_64_RELATIVE at load). Addresses apply to this build; other versions differ.
Abstract
This page is the HLO-layer decomposition that sits above the SC scan lowering: how the front-end's monolithic sparse-embedding-matmul custom-call (SparseDenseMatmulWithMinibatchingOp) is split into one inner SparseDenseMatmulOp per minibatch per physical SparseCore core, how each split's per-core CSR window is addressed, and how the result feeds the SC-dialect op set that the SampleCombiner emitter and the valency loop consume downstream. It is owned by three sibling HLO/MLIR passes in xla::tpu::sparse_core: MinibatchingDecomposition (the per-minibatch CSR-slice arithmetic + the forward/backward pass roster), EmbeddingDataFormattingDecomposer (the dense-per-table ↔ SC-packed-stacked-table activations/gradients adapter), and PackedOperandsLowering (the MLIR full-conversion pass that re-creates the dialect sparse_core::SegmentedScanOp on packed-width operands).
The single decisive structural fact — and the one that most often gets mis-read — is that the op named DynamicSliceCsr is not an HLO kDynamicSlice (0x36). CreateDynamicSliceCsr (0x13489ea0) emits no dynamic-slice opcode at all. It builds a {sliced-csr, base-offset, padded-count} 3-tuple out of a GetCoreIndex custom-call, a GetPaddedRowCount granule clamp, and a pure integer multiply/add chain (three kMultiply=0x4b plus one kAdd=0x3). The actual per-minibatch CSR row-pointer window is sliced by the inner SparseDenseMatmulOp custom-call the forward pass emits next, parameterised by that tuple. DynamicSliceCsr/GetCoreIndex are SparseCoreOperationType custom-call names (op-types 0x10 / 0xc), not HLO opcodes.
The second decisive fact is the division of labour across the three passes. MinibatchingDecomposition is the only one that touches the CSR row-pointers — it is the segment-id provenance the SegmentedScan ultimately reduces over. EmbeddingDataFormattingDecomposer reformats activations and gradients between the dense per-table XLA layout and the SC packed stacked-table layout; it is not the CSR/id/gain operand reformatter. PackedOperandsLowering is the dialect-level op-packing pass that re-creates SegmentedScanOp (and ~40 other AluEp ops) on the target packed width via an unpack→create→pack rewrite. A reimplementer must keep these three concerns separate.
The page is three units plus the downstream binding: the minibatching decomposition pipeline (op recognition, CreateDynamicSliceCsr, the granule clamp, the forward/backward roster, while-fusion vs no-while), the operand partition (the per-core CSR base-offset arithmetic + the forward/backward operand-name tables), the packed-operands lowering (the SegmentedScanOp re-builder), and the binding into the gather→sort→uniquify→reduce→scatter SC Stream-op DAG the DotCombiner emitter produces.
| Item | Symbol (address) and detail |
|---|---|
| Decomposition pass | MinibatchingDecomposition::RunImpl (0x1348f940) — scan ops by custom-call name, build ArgSpecs, dispatch forward/backward |
| Op recogniser | minibatching_decomposer_util::IsSparseDenseMatmulWithMinibatchingOp (0x13c86da0) — IsCustomCall("…WithMinibatchingOp", 35) OR IsCustomCall("…GradOptimizerUpdateWithMinibatchingOp", 54) |
| Per-minibatch slice | MinibatchingDecomposition::CreateDynamicSliceCsr (0x13489ea0, 0x8c0 B) → {sliced-csr, base, padded} 3-tuple. No HLO kDynamicSlice. |
| Padded count | sparse_dense_matmul_decomposer_util::GetPaddedRowCount (0x13c90280) = max( max(GranuleBytes/4, num), cfg[+0x948]→[+0x94] ), gated by SupportsSparseCore() |
| Custom-call op-types | GetCoreIndex = 0xc, DynamicSliceCsr = 0x10 (SparseCoreOperationTypeToString 0x14b7f480) |
| HLO opcodes emitted | kMultiply = 0x4b (×3), kAdd = 0x3 (×1) — verified against StringToHloOpcode init (0x1e5ef040) |
| Decomposed inner op | GetSparseDenseMatmulOpCustomCallTarget (0x13c86e60) → "SparseDenseMatmulOp" (19 B) |
| Data-format adapter | EmbeddingDataFormattingDecomposer::RunImpl (0x1368b4a0) — op-types 0x1a..0x1d; Sc/Tc gated by EnableEmbeddingDataFormattingOffload |
| Packed-op lowering | ScanOpLowering<SegmentedScanOp>::matchAndRewrite (0x135f3000) — unpack→getReductionOp→SegmentedScanOp::create→pack |
| Downstream emitter | SparseDenseMatmulDotCombinerEmitter::Emit (0x1332bda0) → EmitValencyLoop / EmitVectorizedLoop / EmitSampleCombiner |
Unit 1 — The Minibatching Decomposition Pipeline
Why minibatching exists
A TPU-embedding lookup produces, per training step, a concatenated CSR (compressed-sparse-row) structure describing which embedding-table rows each sample touches. When the global lookup is too large to process in one pass, the front end tags the custom-call as a minibatching op: the lookup is to be processed in several sub-batches, each handled independently, with the results recombined. MinibatchingDecomposition is the HLO pass that turns that one tagged custom-call into one concrete inner SparseDenseMatmulOp per (minibatch, physical SC core) pair — and that per-pair op is the unit the SC scan datapath actually emits a sequencer program for.
The structurally important point is that the decomposition is purely an HLO graph rewrite. It emits no SC dialect, no SPMEM traffic, no sequencer instructions — it produces a tree of HLO custom-calls and arithmetic whose leaves are the per-minibatch inner ops. The SC-dialect lowering happens later, when each inner SparseDenseMatmulOp is itself lowered (Unit 4).
Op recognition — by custom-call name
RunImpl (0x1348f940) walks the module and recognises its target op by custom-call target string, not by opcode or attribute. The recogniser IsSparseDenseMatmulWithMinibatchingOp (0x13c86da0) is a two-name disjunction, byte-confirmed from the decompiled body:
// IsSparseDenseMatmulWithMinibatchingOp (0x13c86da0) — decompile-exact
bool IsSparseDenseMatmulWithMinibatchingOp(const HloInstruction *op) {
  if (op->IsCustomCall("SparseDenseMatmulWithMinibatchingOp", 35))  // forward
    return true;
  return op->IsCustomCall("SparseDenseMatmulGradOptimizerUpdateWithMinibatchingOp", 54);  // backward
}
Once matched, GetArgSpec (0x13c87040) selects the per-op operand spec — ForwardPassArgSpec (vtable 0x21937cc8) vs BackwardPassArgSpec (vtable 0x219382c0) — and GetSparseDenseMatmulOpCustomCallTarget (0x13c86e60) maps the matched minibatching op onto the decomposed inner op name. The mapping, read from the decompiled string-assignment arms:
| Matched custom-call (length B) | Decomposed inner op |
|---|---|
| SparseDenseMatmulWithMinibatchingOp (35) | SparseDenseMatmulOp (19) |
| SparseDenseMatmulGradOptimizerUpdateWithMinibatchingOp (54) | grad-with-optimizer-update inner (kSparseDenseMatmulGradOpWithOptimizerUpdate) |
| SparseDenseMatmulCustomCombinerMegachipOp (41) | SparseDenseMatmulCustomCombinerOp |
| …CustomCombinerTcCombinerMegachipOp | …CustomCombinerTcCombinerOp |
NOTE: recognition is string-keyed, so the op vocabulary is a closed set of .rodata strings. The match is HloInstruction::IsCustomCall(target_string, length) against literal byte strings, not an enum compare. A reimplementer reproduces this by string-matching the custom-call target; the megachip/custom-combiner variants are additional literal strings in the same recogniser family.
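The string-keyed recognition described above can be sketched as follows. This is a minimal illustration rather than the binary's code: the custom-call target is modelled as a plain `std::string_view`, `IsMinibatchingTarget` and `InnerTargetFor` are hypothetical helper names, and only the two mappings whose names are spelled out in full in the table are included.

```cpp
#include <optional>
#include <string_view>

// Hypothetical helper: true for the two minibatching custom-call targets
// the recogniser accepts (forward first, then backward).
inline bool IsMinibatchingTarget(std::string_view target) {
  return target == "SparseDenseMatmulWithMinibatchingOp" ||
         target == "SparseDenseMatmulGradOptimizerUpdateWithMinibatchingOp";
}

// Hypothetical helper: map a matched target onto its decomposed inner op name.
// Only the pairs written out in full in the table above are modelled.
inline std::optional<std::string_view> InnerTargetFor(std::string_view target) {
  if (target == "SparseDenseMatmulWithMinibatchingOp")
    return std::string_view("SparseDenseMatmulOp");
  if (target == "SparseDenseMatmulCustomCombinerMegachipOp")
    return std::string_view("SparseDenseMatmulCustomCombinerOp");
  return std::nullopt;  // unknown target, or a mapping not modelled here
}
```

Comparing whole strings has the same effect as the (pointer, length) match in the decompile, since the lengths 35 and 54 are simply the lengths of the two literals.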
CreateDynamicSliceCsr — the per-minibatch slice descriptor (NOT a dynamic-slice)
CreateDynamicSliceCsr (0x13489ea0) is the heart of the partition. It is named for the SC custom-call it produces, not for any HLO dynamic-slice opcode. Its job: for one minibatch, build a {sliced-csr, base-offset, padded-count} 3-tuple that the inner SparseDenseMatmulOp will use to read exactly its slice of the concatenated CSR row-pointers. The decompiled HLO-builder call sequence, in emit order:
CreateDynamicSliceCsr (0x13489ea0) — decompile-confirmed op DAG, in order
guard RetCheckFail "max_ids_per_partition > 0" (minibatching_decomposer.cc:154) if num <= 0
0 padded = GetPaddedRowCount(target, num) // 0x13c90280 — granule clamp (see below)
1 C_pad = CreateConstant( LiteralUtil::CreateR0<int>(padded) ) // s32 scalar
2 GCI = CreateCustomCall(s32[1], "GetCoreIndex", {}) // op-type 0xc; runtime per-core index
3 mul1 = CreateBinary(s32[1], MULTIPLY/*0x4b*/, a6, C_pad) // a6 * padded
4 mul2 = CreateBinary(s32[1], MULTIPLY/*0x4b*/, GCI, mul1) // GCI * a6 * padded
5 mul3 = CreateBinary(s32[1], MULTIPLY/*0x4b*/, C_pad, a7) // padded * a7
6 base = CreateBinary(s32[1], ADD/*0x3*/, mul2, mul3) // per-core CSR base offset = padded·(GCI·a6 + a7)
7 dsc = CreateCustomCall(…, "DynamicSliceCsr", {csr, base, C_pad}) // op-type 0x10; 3 operands
8 gte = CreateGetTupleElement(shape, dsc, 0)
9 tuple = CreateTuple({ gte, base, C_pad }) // StatusOr<HloInstruction*> 3-tuple
The three multiplies are opcode 0x4b (MULTIPLY) and the single add is 0x3 (ADD), both read against the alphabetical StringToHloOpcode init table (0x1e5ef040), where add=0x3, multiply=0x4b, and dynamic-slice=0x36. 0x36 is never emitted here — there is no HLO dynamic-slice anywhere in this function. The shapes are built with ShapeUtil::MakeValidatedShape(S32 /*PrimitiveType 4*/, …).
NOTE: CreateDynamicSliceCsr emits no HLO dynamic-slice. The decompiled body emits only kMultiply (×3) and kAdd (×1) as HLO opcodes, plus Constant/CustomCall/GetTupleElement/Tuple. The "DynamicSliceCsr" string is a SparseCoreOperationType custom-call name (op-type 0x10), not HloOpcode::kDynamicSlice. The per-minibatch CSR row-pointer window is sliced inside the inner SparseDenseMatmulOp, parameterised by the {base, padded} this tuple carries.
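The integer chain in steps 3 to 6 of the op DAG can be checked with plain integers. The sketch below is a model under stated assumptions, not the HLO builder: the s32 scalars are ordinary `int32_t` values, the `GetCoreIndex` result is passed in as a parameter, `SliceDescriptor` and `ModelSliceDescriptor` are hypothetical names, and `a6`/`a7` keep the decompiler's argument labels because their semantic identity is only inferred.

```cpp
#include <cstdint>

// Hypothetical plain-integer model of the {base, padded} pair.
struct SliceDescriptor {
  std::int32_t base;    // per-core CSR window start
  std::int32_t padded;  // window length (the C_pad constant)
};

// Mirrors steps 3-6 of the DAG: three multiplies followed by one add.
inline SliceDescriptor ModelSliceDescriptor(std::int32_t core_index,
                                            std::int32_t a6, std::int32_t a7,
                                            std::int32_t padded) {
  const std::int32_t mul1 = a6 * padded;        // step 3: a6 * padded
  const std::int32_t mul2 = core_index * mul1;  // step 4: GCI * a6 * padded
  const std::int32_t mul3 = padded * a7;        // step 5: padded * a7
  return {mul2 + mul3, padded};                 // step 6: padded * (GCI*a6 + a7)
}
```

For instance, with padded = 8, a6 = 4 and a7 = 1, core index 2 yields base = 8 · (2·4 + 1) = 72, and core index 0 with a7 = 0 yields base 0.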
NOTE: the early-return guard string. The num <= 0 guard RetChecks with the message "max_ids_per_partition > 0" (platforms/xla/sparse_core/hlo/minibatching_decomposer.cc:154), and a second RetCheck on the inner-call result checks "concatenated_csr_tuple.size() > kCsrTupleRowPtrIndex" (line 177). Both strings are read directly from the decompiled body.
GetPaddedRowCount — the granule clamp
GetPaddedRowCount (0x13c90280) computes the per-segment fixed stride the SegmentedScan operates over: it raises the minibatch row count to at least one granule's worth of int32 words (GranuleBytes/4), then to at least a per-target minimum. Both steps are max operations; nothing is rounded to a granule multiple. The decompiled body is short enough to be exact:
// GetPaddedRowCount (0x13c90280) — decompile-exact
int GetPaddedRowCount(const Target &t, int num) {
  uint64_t granule = t.GranuleBytes();  // vtable[+0x5c0] (0x1d617f80)
  if (!t.SupportsSparseCore())  // vtable[+0x260]
    LOG(FATAL) << "SparseCore is not supported by this target";  // target.h:1709
  uint64_t r = granule >> 2;  // GranuleBytes / sizeof(int32) = i32 words/granule
  if (num > (int)r) r = num;  // r = max(granule/4, num)
  int cfg = *(int*)( *((long*)&t + 297) + 148 );  // t[+0x948] -> [+0x94] (per-target min-row config)
  if ((int)r <= cfg) r = cfg;  // r = max(r, cfg) <-- floor, not cap
  return (int)r;
}
So the padded row count is max( max(GranuleBytes/4, num), cfg ): the per-minibatch row count, floored first at one granule of int32 words and then at the per-target minimum.
NOTE: the config bound is a floor (max), not a cap (min). The decompiled comparison is if ((int)r <= cfg) r = cfg, which raises r up to cfg when r is below it, so cfg is a lower bound. The padded count is therefore max(max(GranuleBytes/4, num), cfg). The semantic identity of the cfg field ([Target+0x948]→[+0x94]) is read structurally: it is a per-target row-count config, but its proto field name is not decoded.
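The clamp is easy to exercise in isolation. The following is a minimal sketch with plain integers: `PaddedRowCountModel` is a hypothetical name, `granule_bytes` stands in for Target::GranuleBytes(), `cfg_min` stands in for the config word at [Target+0x948]→[+0x94], and the SupportsSparseCore() check is omitted.

```cpp
#include <algorithm>
#include <cstdint>

// Plain-integer model of the clamp: max(max(granule_bytes / 4, num), cfg_min).
// granule_bytes and cfg_min are hypothetical stand-ins for the target queries.
inline int PaddedRowCountModel(std::uint64_t granule_bytes, int num, int cfg_min) {
  const int words_per_granule = static_cast<int>(granule_bytes >> 2);  // int32 words per granule
  const int r = std::max(words_per_granule, num);  // at least one granule of words
  return std::max(r, cfg_min);                     // config bound is a floor, not a cap
}
```

Assuming a 32-byte granule purely for illustration, num = 3 with cfg_min = 0 gives 8, num = 20 gives 20 (no rounding to a multiple of 8), and num = 3 with cfg_min = 16 gives 16.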
The forward / backward pass roster
For each matched op, DecomposeForwardPass (0x1348a9c0) — and its grad twin DecomposeBackwardPass (0x1348b600) — emits the per-minibatch inner ops. Per minibatch the forward pass: reads core counts via GetNumSparseCores (0x13c9eba0) + GetCores (0x14b79900); reads the division-level field from the SparseDenseMatmulConfig globals (0x223a95b8, field +0x30); calls CreateDynamicSliceCsr; emits the inner SparseDenseMatmulOp via CreateCustomCall; and pulls the per-minibatch result with CreateGetTupleElement. Two emission modes exist, dispatched by the top-level driver:
| Function | Address | Role |
|---|---|---|
| RunImpl | 0x1348f940 | scan ops, build ArgSpecs, dispatch |
| IsSparseDenseMatmulWithMinibatchingOp | 0x13c86da0 | IsCustomCall(name) recogniser |
| GetArgSpec | 0x13c87040 | → ForwardPassArgSpec / BackwardPassArgSpec |
| GetSparseDenseMatmulOpCustomCallTarget | 0x13c86e60 | minibatching → decomposed inner op name |
| CreateDynamicSliceCsr | 0x13489ea0 | per-minibatch {base, padded} tuple |
| GetPaddedRowCount | 0x13c90280 | granule clamp |
| CreateAddComputation | 0x13c90320 | builds the "add" reduction HLO computation |
| DecomposeForwardPass | 0x1348a9c0 | per-minibatch forward emit |
| DecomposeBackwardPass | 0x1348b600 | per-minibatch grad emit |
| DecomposeForwardPassesWithNoWhileLoop | 0x1348c720 | inline forward (minibatches unrolled) |
| DecomposeForwardPassesWithWhileFusion | 0x1348cd60 | while-loop forward |
| DecomposeBackwardPassesWithNoWhileLoop / …WhileFusion | 0x1348ca40 / 0x1348e260 | backward analogues |
| DecomposeSparseDenseMatmulWithMinibatchingWithWhileFusion | 0x1348f200 | top-level while-fusion driver |
| CombineParamsIntoTupleAndUpdateOutputShape | 0x13c87260 | while-carry tuple (CreateTuple) |
| CreateInitialWhileLoopInductionVar | 0x1348a760 | while induction var |
GOTCHA: two emission modes, same descriptor. The no-while mode (0x1348c720) emits each minibatch's inner op inline (unrolled); the while-fusion mode (0x1348cd60) wraps them in an HLO while loop whose carry tuple is assembled by CombineParamsIntoTupleAndUpdateOutputShape (0x13c87260) and whose induction var is seeded by CreateInitialWhileLoopInductionVar (0x1348a760). Both consume the same CreateDynamicSliceCsr {base, padded} descriptor: the mode only changes how the per-minibatch bodies are laid out in the graph, not the per-minibatch slice arithmetic. A reimplementer must produce both forms or accept that a fixed-trip-count while is the canonical large-batch shape.
Unit 2 — The Operand Partition
The per-core CSR base offset
The whole point of the multiply/add chain in CreateDynamicSliceCsr is to give each physical SparseCore a contiguous, padded window into the single concatenated_csr_pointers operand. The window's start is the base value; its length is padded. The register-identity dataflow, byte-confirmed:
| symbol | expression | provenance |
|---|---|---|
| padded | max( max(GranuleBytes/4, num), cfg ) | GetPaddedRowCount; num = arg int |
| C_pad | const s32 = padded | CreateConstant |
| GCI | custom-call "GetCoreIndex" → s32[1] | runtime per-physical-SC index (op-type 0xc) |
| mul1 | b * C_pad | b = HloInstruction* arg (a6 in the Unit 1 DAG; ≈ minibatch index, INFERRED) |
| mul2 | GCI * mul1 = GCI · b · padded | |
| mul3 | C_pad * a7 = padded · a7 | a7 = the second HloInstruction* index arg, labelled as in the Unit 1 DAG |
| base | mul2 + mul3 = GCI·b·padded + padded·a7 = padded · (GCI·b + a7) | the per-core CSR window base offset |
So each physical SparseCore reads a contiguous padded-length window of the concatenated row-pointers starting at padded · (GCI·b + a7), the same value step 6 of the Unit 1 DAG computes. The csr operand (an HloInstruction* arg) is consumed by the "DynamicSliceCsr" custom-call as its first of three operands when present, and the {sliced-csr, base, padded} is wrapped into the returned 3-tuple.
NOTE: operand identity of b is structural. The two HloInstruction* args to CreateDynamicSliceCsr (call them a and b) are read from the DecomposeForwardPass call-site register convention (r8 = a, r9 = b), not from a named accessor. b is inferred to be the minibatch index operand and a the csr operand. The register-identity dataflow inside CreateDynamicSliceCsr is byte-exact; which SSA value is the per-table vs per-minibatch index is read structurally. The SparseDenseMatmulConfig field [+0x30] that the forward pass reads (≈ FLAGS_xla_sparse_core_minibatch_max_division_level, 0x222bd280) is a byte read whose proto field name is inferred from flag adjacency.
The decomposed operand-name tables
The operand layout the decomposition produces is fixed by two .data.rel.ro name tables, reloc-resolved. The forward pass has 7 operands, the backward pass 9 (entry 8 repeats the csr pointers). Operand 0, concatenated_csr_pointers, is the segment-id source that CreateDynamicSliceCsr slices per minibatch and that the SegmentedScan ultimately reduces over:
ForwardPassArgSpec::kForwardPassOperandNames (0x21937d80, 7 entries):
0 concatenated_csr_pointers <- segment-id source (per-minibatch sliced by CreateDynamicSliceCsr)
1 concatenated_embedding_ids (gather indices = sorted_token_ids)
2 concatenated_sample_ids (output rows)
3 concatenated_gains (combiner weights — the per-id gain the DotCombiner applies)
4 num_mini_batches_per_sparse_core (scalar)
5 embedding_table
6 activations_init
BackwardPassArgSpec::kBackwardPassOperandNames (0x21938320, 9 entries):
0..4 same as forward
5 tables
6 gradients
7 hyperparameters (SGD / Adam / Ftrl / Adagrad / AdagradMomentum families)
8 concatenated_csr_pointers (repeated)
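A reimplementation could carry the two operand layouts as constant tables and look operands up by name. This is a minimal sketch: the operand names are the ones listed above, while the array names `kForwardNames` / `kBackwardNames` and the `OperandIndex` helper are hypothetical.

```cpp
#include <array>
#include <cstddef>
#include <optional>
#include <string_view>

// Operand names as listed above; the array and helper names are hypothetical.
inline constexpr std::array<std::string_view, 7> kForwardNames = {
    "concatenated_csr_pointers", "concatenated_embedding_ids",
    "concatenated_sample_ids", "concatenated_gains",
    "num_mini_batches_per_sparse_core", "embedding_table", "activations_init"};

inline constexpr std::array<std::string_view, 9> kBackwardNames = {
    "concatenated_csr_pointers", "concatenated_embedding_ids",
    "concatenated_sample_ids", "concatenated_gains",
    "num_mini_batches_per_sparse_core", "tables", "gradients",
    "hyperparameters", "concatenated_csr_pointers"};

// Returns the first operand index carrying `name`, or nullopt if absent.
template <std::size_t N>
constexpr std::optional<std::size_t> OperandIndex(
    const std::array<std::string_view, N>& names, std::string_view name) {
  for (std::size_t i = 0; i < N; ++i)
    if (names[i] == name) return i;
  return std::nullopt;
}
```

Because the backward table repeats concatenated_csr_pointers at entry 8, a by-name lookup returns the first occurrence (index 0); code that needs the repeated entry has to address it by index.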
The four concatenated_* operands map directly onto the SC embedding sum-lookup roles: csr_pointers → segment boundaries (the SegmentedScan's reset points), embedding_ids → the gather indices, sample_ids → output rows, gains → the per-id multiplicative weight the DotCombiner FMA applies verbatim.
| Table | Address | Entries |
|---|---|---|
| kForwardPassOperandNames | 0x21937d80 | 7 |
| kBackwardPassOperandNames | 0x21938320 | 9 |
Unit 3 — Packed-Operands Lowering (the SegmentedScanOp re-builder)
Once the inner SparseDenseMatmulOp ops exist, the SC dialect sparse_core::SegmentedScanOp they lower to must be re-created on packed-width operands (the SC vector engines pack sub-byte / bf16 data). PackedOperandsLowering (runOnOperation 0x135d8520, ctor CreatePackedOperandsLoweringPass 0x135d82a0) is the MLIR full-conversion pass that does this. It builds a ConversionTarget with per-op dynamic-legality callbacks (setLegalityCallback 0x1c957e40 / 0x1c958640) plus a RewritePatternSet, then runs mlir::applyFullConversion (0x1c958ac0). An op is legal iff its operands are already at the target packed width; otherwise the matching rewrite wraps it Unpack → op → Pack.
The SegmentedScanOp arm is registered by AddDynamicallyLegalScanOps<SegmentedScanOp> (legality lambda 0x135f3920) → ScanOpLowering<SegmentedScanOp,SegmentedScanOp>::matchAndRewrite (0x135f3000). The byte-confirmed rewrite body:
ScanOpLowering<SegmentedScanOp>::matchAndRewrite (0x135f3000) — decompile-confirmed
1 ReductionOp = SegmentedScanOp::getReductionOp() // 0x145fd460 — read reduction_op StringAttr
2 UnpackOperand<UnpackFOp>(data, …) // 0x1360fac0 — split bf16/sub-byte data (F path)
2' UnpackOperand<UnpackUIOp>(segment-id, …) // 0x136104e0 — split (unsigned-int path)
3 newop = SegmentedScanOp::create(b, loc, T, data, seg, reduction_op) // 0x145fd5a0 — re-create on packed ops
4 PackResults<PackFOp>(…) // 0x13610940 — re-pack F results
4' PackResults<PackUIOp>(…) // 0x13610de0 — re-pack UI results
This is the dialect SegmentedScanOp builder: it produces the op that the SC scan lowering (the segmented-scan emission, scan datapath) then turns into a sequencer intrinsic. The same pass legalizes ~40 other AluEp arith/math ops (AddF/I, MulF/I, DivF, Max/Min, CmpF/I, Exp, Tanh, Rsqrt, Clamp, …) each with its own Unpack{F,SI,UI} / Pack{F,SI,UI} pair — SegmentedScan/Scan are two members of that packing table. The plain (non-segmented) ScanOp arm is 0x135f2580 (legality lambda 0x135f2f60).
NOTE: the partition feeds the scan, not the reverse. The CSR partition (Unit 2) decides which row-pointers each minibatch's SegmentedScan reduces over; PackedOperandsLowering decides what packed width that SegmentedScan's operands carry. They are orthogonal: the partition is HLO-level index arithmetic, the packing is dialect-level operand-width legalization. A reimplementer runs the minibatching decomposition first (producing per-minibatch inner ops over CSR slices), then the packed-operands lowering (re-creating each SegmentedScan on packed operands).
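The unpack → op → pack shape of the rewrite can be illustrated on a toy representation. The sketch below is not the MLIR pattern: it assumes a made-up packing of two 16-bit lanes per 32-bit word (low lane first), hard-codes a sum reduction where the real pattern reads reduction_op from the op, and uses hypothetical function names throughout.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Toy packing: two 16-bit lanes per 32-bit word, low lane first.
// This layout is an assumption for illustration, not the SC packed format.
inline std::vector<std::uint16_t> Unpack(const std::vector<std::uint32_t>& packed) {
  std::vector<std::uint16_t> out;
  out.reserve(packed.size() * 2);
  for (std::uint32_t w : packed) {
    out.push_back(static_cast<std::uint16_t>(w & 0xFFFFu));
    out.push_back(static_cast<std::uint16_t>(w >> 16));
  }
  return out;
}

inline std::vector<std::uint32_t> Pack(const std::vector<std::uint16_t>& lanes) {
  std::vector<std::uint32_t> out;
  out.reserve((lanes.size() + 1) / 2);
  for (std::size_t i = 0; i < lanes.size(); i += 2) {
    const std::uint32_t lo = lanes[i];
    const std::uint32_t hi = (i + 1 < lanes.size()) ? lanes[i + 1] : 0u;
    out.push_back(lo | (hi << 16));
  }
  return out;
}

// Inclusive sum scan that restarts whenever the segment id changes.
// Assumes seg.size() == data.size().
inline std::vector<std::uint16_t> SegmentedSumScan(
    const std::vector<std::uint16_t>& data,
    const std::vector<std::uint16_t>& seg) {
  std::vector<std::uint16_t> out(data.size());
  std::uint16_t acc = 0;
  for (std::size_t i = 0; i < data.size(); ++i) {
    if (i == 0 || seg[i] != seg[i - 1]) acc = 0;  // segment boundary resets the sum
    acc = static_cast<std::uint16_t>(acc + data[i]);
    out[i] = acc;
  }
  return out;
}

// The wrapper shape: unpack both operands, run the op, re-pack the result.
inline std::vector<std::uint32_t> PackedSegmentedSumScan(
    const std::vector<std::uint32_t>& packed_data,
    const std::vector<std::uint32_t>& packed_seg) {
  return Pack(SegmentedSumScan(Unpack(packed_data), Unpack(packed_seg)));
}
```

The scan itself never sees the packed form, which is the property the legality callback relies on: only the wrapper knows about the operand width.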
The activations / gradients layout adapter
EmbeddingDataFormattingDecomposer (RunImpl 0x1368b4a0) is the third pass — and the one most easily confused with the operand partition. It is not the CSR/id/gain reformatter. It reformats the dense per-table activations (forward output) and gradients (backward input) between the dense XLA tensor layout and the SC packed stacked-table layout. It matches four SparseCoreOperationType custom-calls:
| op-type | name | direction |
|---|---|---|
| 0x1a | SparseActivationsUnstack | SC packed → per-table dense (forward output) |
| 0x1b | SparseActivationsUnstackInterleaved | as above, interleaved |
| 0x1c | SparseGradientsStack | per-table dense → SC packed (backward input) |
| 0x1d | SparseGradientsStackInterleaved | as above, interleaved |
Each dispatches on EnableEmbeddingDataFormattingOffload(GetTpuCompEnv(op)) (0x1d6b94a0 / 0x1d73de80): true → the Sc (on-device SparseCore) variant; false → the Tc (host/TensorCore) variant. The Sc unstack (DecomposeActivationsUnstackSc 0x13682d40) splits the packed stacked-table activation into per-table dense rows via CreateSlice + CreateReshape + CreateConvert + CreateTuple, driven by StackedTableConfig::Extract (0x13681160), sized by ElementPackingFactor (0x1d6b03e0) × NumEmbeddingDevices (0x1d6b8a00). The Sc stack (DecomposeGradientsStackSc 0x13684d40) is the inverse: per-table dense grads → packed stacked grad via CreateConcatenate + CreatePad (pad to stacked extent) + CreateSlice + CreateConvert + CreateUnary.
GOTCHA: this pass does not feed the SegmentedScan operands. It would be natural to assume the "data formatting decomposer" reformats the CSR/id/sample/gain operands that the SegmentedScan reads. It does not. The CSR→segment-id provenance is MinibatchingDecomposition (Unit 1). EmbeddingDataFormattingDecomposer only touches activations/gradients (the dense ↔ packed-stacked-table layout adapter). Keep the two passes' operand domains separate.
Unit 4 — Binding: from inner op to the Stream-op DAG
Each per-minibatch inner SparseDenseMatmulOp the decomposition emits is itself lowered into an SC-dialect gather → sort → uniquify → segmented-reduce → scatter Stream-op DAG. That lowering is not done by any of the three passes above — it is the job of SparseDenseMatmulDotCombinerEmitter::Emit (0x1332bda0, via LoweringEmitter::EmitSparseDenseMatmulDotCombiner 0x131a7ca0), which splits into EmitValencyLoop (0x1332cee0), EmitVectorizedLoop (0x1332e1c0), and EmitSampleCombiner (0x1332c640). The DAG it produces:
| stage | SC dialect op | role |
|---|---|---|
| gather | IndirectStreamStartOp / IndirectVectorStreamStartOp (build 0x145cf440) | indirect HBM→SPMEM embedding-row load keyed by sorted token ids |
| sort | SortOp (build 0x14604480 / 0x146046c0) | lexicographic sort of (sample, token) ids |
| uniquify | UniqueOp / UniqueWithLaneIdsOp + DuplicateCountOp / …WithLaneIdsOp | dedup token ids, count dups per lane |
| reduce | SegmentedScanOp (reduction_op="sum") | per-segment sum-scan, reset at each CSR boundary |
| scatter | IndirectStreamAddStartOp / IndirectVectorStreamAddStartOp (build 0x145d4420) | scatter-ADD to HBM (forward drain / backward grad accumulate) |
The chain is: MinibatchingDecomposition produces the per-minibatch inner ops over CSR slices → PackedOperandsLowering re-creates each SegmentedScanOp on packed operands → the DotCombiner emitter wires the gather/sort/uniquify/scatter around that scan → LowerToSparseCoreLlvm lowers the dialect to the sequencer program. The concatenated_gains operand (Unit 2 operand 3) is the per-id multiplicative weight the DotCombiner FMA applies; the concatenated_csr_pointers (operand 0) is the segment boundary the valency loop reads as its per-sample trip count.
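The roles of the four concatenated_* operands can be made concrete with a scalar reference model. This is a sketch under loudly stated assumptions, not the emitter's algorithm: it assumes the CSR array has one more entry than there are segments with segment s covering ids [csr[s], csr[s + 1]), that sample_ids holds one output row per id, and it skips the sort and uniquify stages entirely. `SumLookupModel` and `Row` are hypothetical names.

```cpp
#include <cstddef>
#include <vector>

using Row = std::vector<float>;

// Scalar reference model of a CSR-driven weighted sum-lookup.
// Assumed layout (not taken from the binary): csr has num_segments + 1
// entries and segment s covers ids [csr[s], csr[s + 1]).
inline std::vector<Row> SumLookupModel(const std::vector<int>& csr,
                                       const std::vector<int>& embedding_ids,
                                       const std::vector<int>& sample_ids,
                                       const std::vector<float>& gains,
                                       const std::vector<Row>& table,
                                       std::size_t num_output_rows) {
  const std::size_t width = table.empty() ? 0 : table[0].size();
  std::vector<Row> out(num_output_rows, Row(width, 0.0f));
  for (std::size_t s = 0; s + 1 < csr.size(); ++s) {
    for (int i = csr[s]; i < csr[s + 1]; ++i) {  // trip count = segment length
      const Row& src = table[embedding_ids[i]];  // gather one table row
      Row& dst = out[sample_ids[i]];             // output row for this id
      for (std::size_t k = 0; k < width; ++k)
        dst[k] += gains[i] * src[k];             // per-id weighted accumulate
    }
  }
  return out;
}
```

With a three-row table, csr = {0, 2, 3}, ids = {0, 1, 2}, sample rows = {0, 0, 1} and gains = {1, 0.5, 2}, output row 0 accumulates two weighted table rows and output row 1 accumulates one.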
A separate, non-minibatch path exists for the gather-mul-scatter form: GatherMulScatterSparseDenseMatmulOpDecomposer (0x13c861e0) with the GatherEmitter::ScatterOperandSlicesToHbm{,ForSortedIndices,ForChunkGather,ForColumnWiseGather} family (0x138e6c80 …). That is the symmetric non-CSR lowering and is out of scope here.
NOTE: the deep emitter wiring is surveyed, not instruction-decoded here. The per-instruction gather/scan/scatter wiring inside EmitVectorizedLoop/EmitSampleCombiner (which IndirectStream variant per dtype, SPMEM tile sizes, Sfence/TileBarrier/SyncAdd placement, and the GetScheduleType (0x131d5300) valency-vs-vectorized selection) is owned by the EmitValencyLoop and SampleCombiner Emitter pages. The DAG shape (the five stages above) is established by op ::build/::create presence; the instruction sequence inside the loops is documented on those pages.
Cross-References
- SparseCore Overview: the navigational entry for Part IX; engine names, per-gen presence, the embedding data path this decomposition opens.
- SparseCore Hardware Architecture: engine roles, the 4:1 SC:TC ratio, and the physical-core geometry GetCoreIndex indexes.
- SC Backend Pipeline: the twelve-pass SC-MLO pipeline the dialect ops produced here are lowered through (and the MEGACORE barrier).
- SC Core Selection: the physical-core selection policy that decides which SC cores each collective occupies (the counterpart to this page's per-core GetCoreIndex partition).
- SampleCombiner Emitter: the per-sample gather-multiply-accumulate emitter that lowers each inner SparseDenseMatmulOp; consumes the concatenated_gains operand this decomposition partitions.
- EmitValencyLoop: the per-id scalar loop whose trip count is the per-sample CSR segment length the concatenated_csr_pointers operand defines.
- SC Scan Datapath: the segmented-scan emission the re-created SegmentedScanOp lowers into downstream.
- Stream Gather / Scatter: the IndirectStream*StartOp gather/scatter-add primitives the inner-op lowering issues.
- getSequencerType: the SCS/TAC/TEC engine-selection function the lowered bundles route through.
- Binary: extracted/libtpu-0.0.40-cp314-cp314-manylinux_2_31_x86_64/libtpu/libtpu.so (build-id 89edbbe81c5b328a958fe628a9f2207d)
- Index entry: Part IX (SparseCore & BarnaCore) / SparseCore datapath (embeddings); back to index