HLO Pre-Passes

Symbol names, VAs, and the build-id below apply to libtpu.so from the libtpu-0.0.40-cp314 wheel (build-id 89edbbe81c5b328a958fe628a9f2207d). Other versions differ; treat every VA as version-pinned.

Abstract

The HLO pre-passes are the ordered set of xla::HloPassInterface passes that run on libtpu's incoming XLA program before it leaves the HloInstruction world for MLIR. They scrub free-form frontend graphs into a normalized shape the rest of the compiler can assume: scheduling annotations are legalized, linear-algebra and RNG custom-calls are decomposed into primitive arithmetic, dynamic shapes are statically padded, sub-byte and high-precision dtypes are bracketed, sharding is propagated and the program is SPMD-partitioned per device, layout is assigned, and a final pass materializes TPU frontend_attribute strings into typed backend_config so the MLIR import sees nothing TPU-private. Every pass here consumes HLO and emits HLO; the handoff to MLIR is the next phase, owned by overview.md.

This page is the full ordered enumeration of that pre-pass set — the table that compile-phases.md delegates to. It fixes, for every pass: its C++ class, the builder function that adds it, the pipeline-builder phase it sits in (1–6, matching the compile-phases.md numbering), what HLO invariant it consumes, what HLO invariant it guarantees on output, whether it is TPU-specific or open-source XLA, and its byte-anchored entry symbol where one was recovered. The five builder functions live in the (anonymous namespace) of deepsea_compiler_hlo_passes.cc and are reached from DeepseaCompilerBase::RunHloPasses (0x1093a420), the body of the public C-ABI export TpuCompiler_RunHloPasses and of the separate-compilation entry xla::CompilePhase1HloOptimizations (0xf84ee00).

This is the enumeration page. It does not re-derive the top-level phase ordering or each phase's entry symbol — that spine lives on compile-phases.md and is linked, not duplicated. It does not re-derive per-pass transformation algorithms (the actual graph rewrites) — those have their own pages: algebraic-simplifier.md for TpuAlgebraicSimplifier, sharding-propagation.md / auto-sharding-spmd.md for the sharding flows. The 372-entry HloPassInterface RTTI catalog with name() strings lives on hlo-pass-registry.md.

For reimplementation, the pre-pass contract is:

  • One container type. xla::HloPassPipeline is the only HLO pass container — there is no TPU-private pipeline class. Every TPU pass derives directly from xla::HloPassInterface (so RunImpl(HloModule*, execution_threads) returns absl::StatusOr<bool>) and is added with AddPass<T>(...).
  • Six phases, five builders. The ordering decomposes into six pipeline-builder phases. Phases 1–5 are HLO→HLO; phase 6 is the final pre-MLIR HLO domain. Five decompiled builder functions add the passes: PreOptimizationPipeline, AddAutoShardingAndRelatedPasses, AddTpuPartitioningPasses, HloOptimizeThroughLayoutAssignment, PostOptimizationPipeline.
  • The pre-passes legitimately introduce ops. Expanders emit new ops; the partitioner emits collectives. The acceptance gate (TpuHloSupportChecker) therefore runs in phase 4 after expansion and partitioning, not at the front — checking earlier would reject programs that are actually compilable.
  • Fixed-point loops. Several passes are wrapped in xla::HloPassFix<P> and re-run to convergence; HloDCE and HloCSE are re-run between most stages whether or not wrapped.
  • Invariant checkers re-run after every pass. MaybeAddInvariantCheckers (0x10944600) adds HloVerifier, LegalizeSchedulingAnnotations (as checker), and HloCycleDetection via the separate AddInvariantChecker<T> API at the head of every nested pipeline — they re-validate after each pass, not once.
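The container and fixed-point parts of this contract can be modelled in a few lines. The following is a minimal Python sketch, not the xla::HloPassPipeline API: Pipeline, fix, drop_one_dead, and the max_iters cap are hypothetical names and values chosen for illustration, and the module is reduced to a list of op names.

```python
from typing import Callable, List

Module = List[str]               # toy stand-in for an HloModule: a list of op names
Pass = Callable[[Module], bool]  # a pass mutates the module and returns "changed"


class Pipeline:
    """Toy single pass container (hypothetical; not the real xla::HloPassPipeline)."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.passes: List[Pass] = []

    def add_pass(self, p: Pass) -> "Pipeline":
        self.passes.append(p)
        return self

    def run(self, module: Module) -> bool:
        changed = False
        for p in self.passes:
            changed = p(module) or changed  # always run p, then accumulate
        return changed


def fix(p: Pass, max_iters: int = 25) -> Pass:
    """Wrap a pass so it re-runs until it reports no change (a fixed point).

    max_iters is an assumed safety cap, not a recovered constant.
    """
    def run_to_fixed_point(module: Module) -> bool:
        changed_any = False
        for _ in range(max_iters):
            if not p(module):
                return changed_any
            changed_any = True
        raise RuntimeError("pass did not converge")
    return run_to_fixed_point


def drop_one_dead(module: Module) -> bool:
    """Toy DCE step: remove at most one 'dead' op per invocation."""
    if "dead" in module:
        module.remove("dead")
        return True
    return False


module = ["add", "dead", "mul", "dead"]
pipeline = Pipeline("toy").add_pass(fix(drop_one_dead))
changed = pipeline.run(module)
print(module, changed)
```

The wrapped pass needs two productive iterations plus one that reports no change; a second run of the same pipeline then reports no change at all, which is the convergence property the wrapper exists to provide.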
Pipeline driver: xla::jellyfish::DeepseaCompilerBase::RunHloPasses @ 0x1093a420
Sep-comp entry: xla::CompilePhase1HloOptimizations @ 0xf84ee00
Builder (phase 1): (anon)::PreOptimizationPipeline(Target const&, unique_ptr<HloModule>, long, CompilationStats*)
Builder (phase 2): xla::jellyfish::AddAutoShardingAndRelatedPasses(...) @ 0x10939c40
Builder (phase 3): xla::jellyfish::AddTpuPartitioningPasses(HloModule*, AliasInfo const*, PartitioningPipelineConfig) @ 0x1278a440
Builder (phase 4): (anon)::HloOptimizeThroughLayoutAssignment(Target const&, RunHloPassesConfig const&, long, HloModule*, ComputationLayout*, CompilationStats*); body in …::$_0::operator() @ 0x1094ada0 (std::function trampoline InvokeObject<…$_0> @ 0x1094ad80)
Builder (phase 5/6): (anon)::PostOptimizationPipeline(Target const&, AliasInfo const*, HloModule*, long, CompilationStats*, bool, bool) @ 0x1093fd40
Pipeline opener: (anon)::CreateHloPipeline(...) @ 0x1093efe0; nested variants CreateNestedHloPipeline @ 0x1093a180, CreateNestedHloPipelineFix @ 0x10953e20
Invariant checkers: (anon)::MaybeAddInvariantCheckers(...) @ 0x10944600
Gather/scatter helper: (anon)::AddGatherScatterExpanderPasses(Target const&, HloModule*, HloPassPipeline&) @ 0x1095a040
Acceptance gate: xla::TpuHloSupportChecker::RunImpl @ 0x11071480 (vtable _ZTVN3xla20TpuHloSupportCheckerE)
Last HLO-domain pass: xla::jellyfish::ConvertFrontendAttributesToBackendConfig
Source unit: platforms/xla/service/jellyfish/deepsea_compiler_hlo_passes.cc (string-anchored)
Confidence: CONFIRMED (byte-anchored) unless a row or callout says otherwise

Pipeline-Level Architecture

RunHloPasses (0x1093a420) is the monolithic driver. It builds and runs a sequence of xla::HloPassPipeline containers; the finer numbered stages map onto the five builder functions confirmed in the symbol table by their exact signatures:

RunHloPasses (0x1093a420)  ==  CompilePhase1HloOptimizations (0xf84ee00)
  ├─ PreOptimizationPipeline ........................ PHASE 1  input scrub
  │
  ├─ CreateHloPipeline (0x1093efe0)  opens the main pipeline
  │    ├─ AddAutoShardingAndRelatedPasses (0x10939c40) PHASE 2  sharding
  │    ├─ AddTpuPartitioningPasses (0x1278a440) ...... PHASE 3  SPMD prep
  │    ├─ HloOptimizeThroughLayoutAssignment (0x1094ada0) PHASE 4 pre-layout
  │    │    └─ LayoutAssignment  ........... (see layout-assignment.md)
  │    └─ PostOptimizationPipeline (0x1093fd40) ...... PHASE 5  post-layout
  │                                                    PHASE 6  pre-MLIR HLO
  └─ MLIR conversion  ..................... (see overview.md)

Two structural facts shape the whole set:

  • xla::HloPassFix<P> wraps a pass in a fixed-point loop. Vtable evidence in the binary shows it instantiated for HloPassPipeline, HloDCE, AllReduceReassociate, ReduceWindowRewriter, ReduceScatterReassociate, WhileLoopConstantSinking, AllReduceReduceScatterReorder, jellyfish::TpuReduceWindowRewriter, and jellyfish::MosaicFusion. The fixed-point mechanism is CONFIRMED. Its crash-on-non-convergence behaviour is reportedly gated by --xla_tpu_crash_if_hlo_pass_fix_did_not_converge; that flag string was not located in the sampled strings table. [Confidence: LOW on the crash flag; CONFIRMED on HloPassFix.]

  • Invariant checkers are re-added at the head of every nested pipeline. MaybeAddInvariantCheckers (0x10944600) uses AddInvariantChecker<T> (not AddPass<T>), so each checker re-runs after every pass in that pipeline rather than once:

| Checker class | Purpose |
| --- | --- |
| xla::HloVerifier + xla::jellyfish::TpuVerifierMetadata | full HLO structural / shape verification (shape-size via an HloVerifierOpts::shape_size member-fn-ptr wrapped in std::function<int64_t(Shape const&)>) |
| xla::LegalizeSchedulingAnnotations (as checker) | scheduling-annotation graph well-formedness |
| xla::HloCycleDetection | reject HLO graphs with circular dependencies |
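The difference between a checker and an ordinary pass is easiest to see in a model. The sketch below is a hypothetical Python analogue (CheckedPipeline, toy_verifier, and the pass names are illustrative, not recovered symbols). It assumes checkers also run once before the first pass; that initial run is an assumption of the sketch, not something byte-anchored here.

```python
from typing import Callable, List

Module = List[str]  # toy stand-in for an HloModule


class CheckedPipeline:
    """Toy pipeline whose invariant checkers re-run after every pass."""

    def __init__(self) -> None:
        self.passes: List[Callable[[Module], bool]] = []
        self.checkers: List[Callable[[Module], None]] = []
        self.check_runs = 0

    def add_pass(self, p: Callable[[Module], bool]) -> None:
        self.passes.append(p)

    def add_invariant_checker(self, c: Callable[[Module], None]) -> None:
        self.checkers.append(c)

    def _run_checkers(self, module: Module, after: str) -> None:
        for check in self.checkers:
            self.check_runs += 1
            try:
                check(module)
            except ValueError as err:
                # Attribute the failure to the pass that just ran.
                raise RuntimeError(f"invariant broken after {after}: {err}") from err

    def run(self, module: Module) -> bool:
        self._run_checkers(module, "pipeline-start")  # assumed initial check
        changed = False
        for p in self.passes:
            changed = p(module) or changed
            self._run_checkers(module, p.__name__)
        return changed


def toy_verifier(module: Module) -> None:
    if "bad-op" in module:
        raise ValueError("bad-op present")


def good_pass(module: Module) -> bool:
    module.append("add")
    return True


def buggy_pass(module: Module) -> bool:
    module.append("bad-op")
    return True


ok = CheckedPipeline()
ok.add_invariant_checker(toy_verifier)
ok.add_pass(good_pass)
ok.add_pass(good_pass)
ok.run([])

bad = CheckedPipeline()
bad.add_invariant_checker(toy_verifier)
for p in (good_pass, buggy_pass, good_pass):
    bad.add_pass(p)
try:
    bad.run([])
    failure = None
except RuntimeError as err:
    failure = str(err)
```

Because the checker runs between passes, the failure is pinned to buggy_pass and the third pass never executes. A checker added as an ordinary pass would validate the module only at its own position in the sequence.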

NOTE — decompile cross-check. All five builder signatures were confirmed verbatim in *_functions.json: PreOptimizationPipeline (lambdas $_0..$_8, predicate $_7 confirmed as the HloDomainIsolator argument), PostOptimizationPipeline (lambdas $_0..$_12), HloOptimizeThroughLayoutAssignment, AddAutoShardingAndRelatedPasses, AddTpuPartitioningPasses (takes PartitioningPipelineConfig), plus AddGatherScatterExpanderPasses, MaybeAddInvariantCheckers, and CreateHloPipeline. The four TPU vtables (TpuHloSupportChecker, jellyfish::TpuCallInliner, jellyfish::TpuInt2AutoUpDownCaster, jellyfish::TpuRngBitGeneratorTupleDecomposer) are present in *_names.json. [Confidence: CONFIRMED.]


The Ordered Pre-Pass Table

The master table below is the discovery-order enumeration recovered by walking each builder's AddPass<T> instantiations. The Phase column matches the pipeline-builder numbering on compile-phases.md: phase 1 = PreOptimizationPipeline, phase 2 = sharding, phase 3 = SPMD prep, phase 4 = HloOptimizeThroughLayoutAssignment (through layout assignment), phases 5–6 = PostOptimizationPipeline (post-layout refinement, then the final pre-MLIR HLO domain). Numbering is discovery order, not source contiguity; loop/fixed-point passes appear once but re-run.

The Src column: T = TPU-specific (xla::jellyfish::, xla::tpu::, or xla::Tpu* prefix); O = open-source XLA pass (some called with a TPU Target).

Phase 1 — PreOptimizationPipeline (input scrub)

| # | Pass class | HLO-input invariant | HLO-output invariant | Src | Notes |
| --- | --- | --- | --- | --- | --- |
| 1 | xla::LegalizeSchedulingAnnotations | sched. annotations may be textual | annotations normalized to backend-config form | O | gated by non_mitigatable_gap_checking config |
| 2 | xla::CheckNoDataDependencyInSchedulingAnnotations | sched. annotations on instructions | unchanged (verifier); error on dep cycle | O | invariant checker |
| 3 | xla::HloDomainIsolator (predicate $_7) | sharding domains may be implicit | explicit kDomain ops wrap sharding regions | O | only if EnableDomainPasses() |
| 4 | xla::DynamicIndexSplitter | dynamic-index ops, multi-dim index | dynamic-index ops, split scalar indices | O | |
| 5 | xla::BatchNormExpander(true,true,true) | BatchNorm{Training,Inference,Grad} | only primitive arithmetic ops | O | rewrites to log/sqrt/divide |
| 6 | xla::TpuCholeskyExpander | Cholesky custom-call | tile-recursive dots + triangular-solves | T | subclass of xla::CholeskyExpander |
| 7 | xla::TpuQrExpander | QrDecomposition | Givens-rotation HLO graph | T | subclass of xla::QrExpander |
| 8 | xla::LuDecompositionExpander | LuDecomposition | scatter/gather decomposition | O | |
| 9 | xla::TpuEighExpander | Eigh | Jacobi-rotation HLO graph | T | subclass of xla::EighExpander |
| 10 | xla::FftExpander(Target const&) | Fft op | Cooley–Tukey radix-2, TPU tile sizes | O* | TPU-aware via Target arg |
| 11 | xla::TpuTriangularSolveExpander | TriangularSolve | recursive blocked solve | T | subclass of xla::TriangularSolveExpander |
| 12 | xla::jellyfish::TpuRngBitGeneratorExpander | RngBitGenerator | Philox / ThreeFry HLO sequence | T | when impure-RNG enabled |
| 13 | xla::RngBitGeneratorExpander | RngBitGenerator | open-source default algorithm | O | alternate path when 12 not selected |
| 14 | xla::jellyfish::TpuRngBitGeneratorTupleDecomposer | tupled-output RngBitGenerator | separate GetTupleElement per output | T | |
| 15 | xla::HloDCE | any | dead instr/computations removed | O | re-run between most stages |
| 16 | xla::jellyfish::TpuInt2AutoUpDownCaster | mixed-precision arith with int2 | int2 ops bracketed by Convert ↔ int8 | T | MXU wire min is int8 |
| 17 | xla::jellyfish::TpuCallInliner(MustFuseInlineMode) | must_fuse-marked call sites | callees inlined into caller | T | gated by xla_tpu_impure_inline_must_fuse_early |
| 18 | xla::jellyfish::UserGuidedFusionIdAssigner | frontend_attribute: fusion_id strings | fusion_id as integer backend-config | T | feeds later fusion passes |
| 19 | xla::ConditionalCanonicalizer | Conditional with arbitrary nesting | canonical form (single branch root) | O | |
| 20 | xla::DynamicDimensionSimplifier | dynamic-dim ops | redundant dynamic-dim ops folded | O | |
| 21 | xla::DynamicPadder(DynamicPadderOptions) | dynamic shapes | static shapes + padding masks | O | runs dynamic-shape lowering (2× AddPass) |
| 22 | xla::jellyfish::PreX64RewriterOptimizations(Target const&) | any | sub-ops pre-rewritten for x64 host codegen | T | |
| 23 | xla::ScatterExpander(ScatterExpander::Mode) | Scatter op | While loop over indices | O | mode = slow-path vs simplified |
| 24 | xla::MapInliner | Map op | callee inlined as elementwise sequence | O | |
| 25 | xla::HloDomainRemover("sharding", ApplyDomainSharding) | sharding kDomain brackets | brackets removed, sharding stays as attribute | O | always before phase 2 |
| 26 | xla::jellyfish::TpuHloPrecisionTracer | instruction precisions | precision_config filled where missing | T | |
| 27 | xla::BitcastDtypesExpander | dtype-only Bitcast | reinterpret-cast HLO graph | O | |
| 28 | xla::jellyfish::XPrecisionRewriter(kX128Precision) | matmul with x128 precision | 8-step accumulation chain | T | high-precision dot decomposition |
| 29 | xla::jellyfish::XPrecisionRewriter() | matmul with x6 / x9 precision | 2-step / 3-step accumulation chain | T | run twice (cumulative; 2× AddPass confirmed) |
| 30 | xla::ComparisonExpander({{S64,S32}}) | s64 comparisons | s32 comparisons on hi/lo halves | O | configured for S64→S32 |

O* = open-source class, TPU-specialized through a Target argument.
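Row 30's output invariant (a 64-bit comparison expressed on 32-bit halves) is plain arithmetic and can be checked independently of the binary. The Python sketch below shows one standard way to do it: compare the signed high halves, and on a tie compare the low halves as unsigned values. The function names are hypothetical, and how the real pass encodes the unsigned low-half comparison in HLO is not recovered here.

```python
from typing import Tuple


def split_s64(x: int) -> Tuple[int, int]:
    """Split a signed 64-bit value into (signed high 32 bits, unsigned low 32 bits)."""
    assert -(1 << 63) <= x < (1 << 63), "value must fit in s64"
    lo = x & 0xFFFFFFFF  # low half, interpreted as unsigned
    hi = x >> 32         # Python's >> is arithmetic, so the sign is preserved
    return hi, lo


def less_than_s64(a: int, b: int) -> bool:
    """a < b on s64 using only comparisons of 32-bit halves."""
    a_hi, a_lo = split_s64(a)
    b_hi, b_lo = split_s64(b)
    # The high halves decide unless they tie; then the unsigned low halves decide.
    return a_hi < b_hi or (a_hi == b_hi and a_lo < b_lo)


# Spot-check against native comparison on boundary values.
edge_values = [-(1 << 63), -(1 << 32) - 1, -(1 << 32), -1, 0, 1,
               (1 << 32) - 1, 1 << 32, (1 << 63) - 1]
agrees = all(less_than_s64(a, b) == (a < b)
             for a in edge_values for b in edge_values)
```

The low half has to be treated as unsigned because two's-complement values that share a high half are ordered by their low 32 bits read as an unsigned number.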

Phase 2 — AddAutoShardingAndRelatedPasses (sharding)

| # | Pass class | HLO-input invariant | HLO-output invariant | Src | Notes |
| --- | --- | --- | --- | --- | --- |
| 31 | xla::FlattenCallGraph | nested calls | single-level call graph | O | re-run before every sharding pass |
| 32 | xla::jellyfish::TpuCallInliner | small reusable computations | inlined into caller for sharding | T | |
| 33 | xla::HloDCE | any | dead instr/computations removed | O | |
| 34 | xla::ShardingPropagation | partial sharding annotations | every op has a sharding (or Replicated) | O | manual-sharding flow (4× AddPass) |
| 35 | xla::TpuAutoSharding(AutoShardingOption, Target*, AliasInfo*) | unannotated HLO | sharding on every op | T | auto-sharding flow; auto_sharding flag family |
| 36 | xla::sdy::ShardyXLA(PropagationOptions, ...) | Shardy-format sharding ops | HLO sharding annotations | O | JAX/Shardy frontend detected |
| 37 | xla::jellyfish::TpuRngBitGeneratorTupleDecomposer(nullptr, bool) | tupled RNG output post-sharding | un-tupled | T | re-run after sharding |
| 38 | xla::TupleSimplifier | redundant tuple/GTE chains | simplified | O | |

Phase 3 — AddTpuPartitioningPasses (SPMD prep)

| # | Pass class | HLO-input invariant | HLO-output invariant | Src | Notes |
| --- | --- | --- | --- | --- | --- |
| 39 | xla::spmd::SpmdPrepare | HLO with sharding | normalized for SPMD partitioning | O | |
| 40 | xla::ConvOperandSwapper | conv with swapped operand layout | canonical operand order | O | |
| 41 | xla::jellyfish::TpuSpmdConcatRewriter(Target const&) | sharded Concatenate | replicated concatenate + slice | T | |
| 42 | xla::HloConstantSplitter(bool) | shared constants across shardings | per-sharding constant copies | O | |
| 43 | xla::jellyfish::TpuPartitionAssignment(Target const&, long) | sharded HLO | partition/device-id metadata attached | T | name() "tpu-partition-assignment" |
| 44 | xla::jellyfish::ConvolutionFolding(Target const&, bool) | conv with foldable bias/activation | fused conv-bias-activation | T | run twice (pre-sharding, in-layout) |
| 45 | xla::jellyfish::TpuSpmdPartitioner(...) | sharded HLO | per-partition HLO with collectives | T | TPU subclass of SpmdPartitioner |
| 46 | xla::RecognizeReduceWindow | ReduceWindow-shaped graphs | ReduceWindow op explicit | O | |
| 47 | xla::CollectivePermuteCSE | duplicate CollectivePermute | deduplicated | O | |
| 48 | xla::WholeGraphManualPass | manual sharding on whole graph | passthrough sharding annotation | O | |

Phase 4 — HloOptimizeThroughLayoutAssignment (through layout)

| # | Pass class | HLO-input invariant | HLO-output invariant | Src | Notes |
| --- | --- | --- | --- | --- | --- |
| 49 | xla::ZeroSizedHloElimination | ops with zero-size operand/output | eliminated / empty constant | O | |
| 50 | xla::TpuHloSupportChecker | any | unchanged (validator); error on unsupported Shape | T | the canonical TPU HLO acceptance test |
| 51 | xla::ConvertMemoryPlacementToInternalAnnotations | frontend_attribute: memory_space | backend-config memory_space integer | O | |
| 52 | xla::HloModuleDCE | module-level dead computations | removed | O | |
| 53 | xla::ConvolutionTypeCanonicalizer | mixed-type conv | unified-type conv with explicit converts | O | |
| 54 | xla::ConvolutionPrecisionNormalizer | conv with operand-only precision | normalized precision_config | O | |
| 55 | xla::BroadcastCanonicalizer | broadcast, non-canonical dim order | canonical broadcast | O | |
| 56 | xla::TransposeFolding | transpose absorbable into dot/conv | dot/conv with permuted operands | O | |
| 57 | xla::ConvertOperandFolding | Convert absorbable into dot/conv | dot/conv with mixed-precision operand | O | |
| 58 | xla::HloCSE(bool) | duplicate pure ops | shared single op | O | re-run between most stages |
| 59 | xla::HloPassFix<xla::jellyfish::TpuReduceWindowRewriter> | ReduceWindow with non-trivial window | repeated rewrites until canonical | T | fixed-point |
| 60 | xla::jellyfish::TpuAlgebraicSimplifier(Target, AlgSimpOptions) | any | algebraic-simplified, TPU-aware | T | superset of xla::AlgebraicSimplifier; see algebraic-simplifier.md |
| 61 | xla::GatherOptimizer(Target const&) | Gather op | TPU-friendly Gather (split/decomposed) | O* | |
| 62 | xla::AllReduceSimplifier | AllReduce, degenerate replica groups | simplified | O | |
| 63 | xla::jellyfish::TpuAllGatherSimplifier(Target const&) | AllGather, degenerate replica groups | simplified | T | |
| 64 | xla::AllToAllDecomposer(bool, int) | AllToAll with split-dim | per-partition slice + AllToAll | O | |
| 65 | xla::jellyfish::RaggedAllToAllExpander(long) | RaggedAllToAll custom-call | dense AllToAll + scatter | T | |
| 66 | xla::SortSimplifier | redundant Sort operands | simplified Sort | O | |
| 67 | xla::jellyfish::TpuReduceRewriter(bool) | Reduce with multiple outputs | per-output reduces | T | |
| 68 | xla::jellyfish::TpuDegenerateDimensionRewriter | ops with size-1 batch dim | size-1 dim eliminated via reshape | T | |
| 69 | xla::jellyfish::TpuBroadcastRewriter | broadcast, unfavorable target dim | reshape + broadcast to TPU-favored dim | T | |
| 70 | xla::jellyfish::TpuReduceRewriter | as 67, no-flag variant | per-output reduces | T | |
| 71 | xla::ReduceWindowResizer | ReduceWindow, non-pow-2 window | resized window via padding | O | name() "reduce-window-resizer" |
| 72 | xla::WhileLoopConstantSinking(bool) | While carrying constants | constants sunk into body | O | |
| 73 | xla::WhileLoopSimplifier(bool) | While with constant trip | unrolled / simplified | O | |
| 74 | xla::WhileLoopConcatCodeMotion(long) | While, Concatenate of invariants | concatenate hoisted out | O | |
| 75 | xla::HloConstantFolding | foldable constant ops | folded into literal | O | |
| 76 | xla::jellyfish::TpuConditionalSimplifier(Target const&) | Conditional with TPU patterns | simplified or rewritten | T | |
| 77 | xla::DeadDynamicUpdateSliceElimination | DUS chains with dead targets | dead DUS dropped | O | |
| 78 | xla::conditional_opt::ConditionalCodeMotion(...) | code identical in cond. branches | code hoisted before / sunk after | O | |
| 79 | xla::jellyfish::SortMerger | adjacent Sort over compatible keys | merged into single Sort | T | |
| 80 | xla::ScanExpander | Scan HLO | While loop of partial reduces | O | |
| 81 | xla::StableSortExpander | Sort requesting stable | augmented-key Sort + post-strip | O | |
| 82 | xla::InfeedTokenPropagation | Infeed without explicit token edges | token edges threaded for ordering | O | |
| 83 | xla::jellyfish::InfeedDecomposer | Infeed op | DMA + token sequence (TPU host-transfer) | T | |
| 84 | xla::jellyfish::OutfeedDecomposer | Outfeed op | DMA + token sequence | T | |
| 85 | xla::megascale::compiler::TpuAllReduceMerger(Target, mapper) | per-slice AllReduce | cross-slice AllReduce merged | T | MegaScale path |
| 86 | xla::megascale::compiler::CrossSliceLegalizer(Target const&) | cross-slice ops | legalized for MegaScale topology | T | MegaScale path |
| 87 | xla::TpuGatherScatterFlattener(Target, long) | high-rank gather/scatter | rank-flattened gather/scatter | T | in AddGatherScatterExpanderPasses |
| 88 | xla::TpuGatherExpander(Target const&) | Gather that can be expanded | While-loop of slices | T | |
| 89 | xla::TpuScatterExpander(Target const&) | Scatter that can be expanded | While-loop of DUS | T | |

Layout assignment (Phase2PreLayoutAssignment / TpuLayoutAssignment) runs at the tail of phase 4 and is owned by layout-assignment.md.
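Row 81's invariant ("augmented-key Sort + post-strip") describes a general technique: make an arbitrary sort behave stably by extending the key with each element's original position, then dropping that column afterwards. The Python sketch below illustrates the idea only. The function name is hypothetical, and the use of the original index as the tie-breaker is an assumption about what "augmented key" means here.

```python
from typing import List, Sequence, Tuple


def stable_sort_via_augmented_key(
    keys: Sequence[int], values: Sequence[str]
) -> Tuple[List[int], List[str]]:
    # Augment: attach each element's original position as a secondary key.
    augmented = [(k, i, v) for i, (k, v) in enumerate(zip(keys, values))]
    # Scramble the input order so the result cannot rely on the sort being stable.
    augmented.reverse()
    # Sort on (key, index): ties on key are broken by original position,
    # so every element has a unique sort key and any sort algorithm agrees.
    augmented.sort(key=lambda t: (t[0], t[1]))
    # Post-strip: drop the index column again.
    return [k for k, _, _ in augmented], [v for _, _, v in augmented]


sorted_keys, sorted_values = stable_sort_via_augmented_key(
    [2, 1, 2, 1], ["a", "b", "c", "d"]
)
```

Elements with equal keys come out in their original relative order ("b" before "d", "a" before "c") even though the list was reversed before sorting, because the augmented key is a total order.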

Phases 5–6 — PostOptimizationPipeline (post-layout refinement, then pre-MLIR)

| # | Pass class | Phase | HLO-input invariant | HLO-output invariant | Src | Notes |
| --- | --- | --- | --- | --- | --- | --- |
| 90 | xla::HloDomainIsolator (predicate $_7) | 5 | re-add domain brackets if EnableDomainPasses() | kDomain brackets re-added | O | inverse of step 25 |
| 91 | xla::HloCSE(bool=false) | 5 | duplicate pure ops | shared single op | O | post-layout CSE |
| 92 | xla::jellyfish::WrapFusionOutputForDebug | 5 | fused HLO | fusion outputs wrapped with kCopy for debug capture | T | conditional |
| 93 | xla::jellyfish::AlwaysCrash | 5 | any | intentionally fails (testing pass) | T | gated by xla_tpu_always_crash |
| 94 | xla::AddOriginalValue | 5 | HLO post-layout | each instruction tagged with provenance metadata | O | name() "add-original-value" |
| 95 | xla::jellyfish::AddRandomHostOffloading(double) | 6 | any | random instr wrapped with host-offload custom-calls | T | debug pass, flag-gated |
| 96 | xla::jellyfish::ConvertFrontendAttributesToBackendConfig | 6 | frontend_attribute strings | parsed into typed backend_config protobuf | T | runs last in HLO domain |
| 97 | xla::HloHostDeviceTypeCallWrapper(Options) | 6 | host_compute call sites | wrapped with type-call markers for MLIR import | O | final pre-MLIR pass |
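The string-to-typed-config step in row 96 follows a common shape: recognised attribute strings are parsed into typed fields and everything else is left alone. The Python sketch below is hypothetical throughout. BackendConfig and its two fields are illustrative; they reuse attribute names that the tables attach to other passes (#18, #51) and are not the recovered schema of this pass. Passing unknown attributes through untouched is also an assumption.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass
class BackendConfig:
    """Hypothetical typed config; the real backend_config fields are not recovered."""
    fusion_id: int = 0
    memory_space: int = 0


def convert_frontend_attributes(
    attrs: Dict[str, str]
) -> Tuple[BackendConfig, Dict[str, str]]:
    """Parse known string attributes into typed fields; return the leftovers."""
    config = BackendConfig()
    remaining: Dict[str, str] = {}
    for key, text in attrs.items():
        if key in ("fusion_id", "memory_space"):
            # int() raises ValueError on malformed text, surfacing bad input early.
            setattr(config, key, int(text))
        else:
            # Assumed behaviour: unrecognised attributes pass through unchanged.
            remaining[key] = text
    return config, remaining


config, remaining = convert_frontend_attributes(
    {"fusion_id": "7", "memory_space": "1", "user_note": "keep me"}
)
```

Parsing once at this boundary means later stages read typed integers instead of re-parsing free-form strings.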

GOTCHA — the acceptance gate (#50) is in phase 4, not phase 1. A reimplementation that wants to reject unsupported ops early will be tempted to front-load TpuHloSupportChecker. libtpu does not: it runs after the expanders (phase 1) and the SPMD partitioner (phase 3) because those passes legitimately introduce new ops (expander outputs, partitioned collectives) that must themselves pass the check. Checking before expansion would reject compilable programs. The checker never mutates — it walks every HloComputation and validates each result Shape with ShapeUtil::ValidateShapeWithOptionalLayout, returning an error Status on the first unsupported shape. Entry RunImpl @ 0x11071480. [Confidence: CONFIRMED on RunImpl VA + vtable; HIGH that name() returns "tpu-hlo-support-checker" (the literal did not surface in the sampled strings table).]
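The ordering argument can be reproduced with a toy pipeline. In the Python sketch below the module is a list of op names and the checker tests membership in a hypothetical SUPPORTED set. This is a simplification: the real checker is described above as validating result shapes, not op names. All function names are illustrative.

```python
from typing import Callable, List, Optional

Module = List[str]
SUPPORTED = {"add", "dot", "subtract", "all-reduce"}  # hypothetical supported set


def expand_cholesky(module: Module) -> bool:
    """Toy expander: replace each 'cholesky' with primitive ops."""
    out: Module = []
    changed = False
    for op in module:
        if op == "cholesky":
            out += ["dot", "subtract"]
            changed = True
        else:
            out.append(op)
    module[:] = out
    return changed


def partition(module: Module) -> bool:
    """Toy partitioner: introduces a collective."""
    module.append("all-reduce")
    return True


def support_checker(module: Module) -> bool:
    """Toy acceptance gate: never mutates, errors on the first unsupported op."""
    for op in module:
        if op not in SUPPORTED:
            raise ValueError(f"unsupported op: {op}")
    return False


def run(passes: List[Callable[[Module], bool]], module: Module) -> Module:
    work = list(module)  # leave the caller's module untouched
    for p in passes:
        p(work)
    return work


program = ["add", "cholesky"]

early_error: Optional[str] = None
try:
    run([support_checker, expand_cholesky, partition], program)
except ValueError as err:
    early_error = str(err)

late_result = run([expand_cholesky, partition, support_checker], program)
```

With the gate first, a compilable program is rejected because it still contains the op the expander would have removed. With the gate last, the same program passes and the newly introduced collective is checked as well.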


TPU-Specific vs Open-Source Split

Of the 97 pipeline-mentioned passes, 33 carry an explicit TPU prefix (xla::jellyfish::, xla::tpu::, or xla::Tpu*); the remainder are open-source XLA passes, some parameterized with a TPU Target. The TPU-specific ones cluster by role:

  • Custom-call decomposition (expanders): TpuCholeskyExpander, TpuQrExpander, TpuEighExpander, TpuTriangularSolveExpander, TpuGatherExpander, TpuScatterExpander, jellyfish::TpuRngBitGeneratorExpander (+ TupleDecomposer), jellyfish::RaggedAllToAllExpander, jellyfish::Infeed/OutfeedDecomposer.
  • Shape / canonicalization rewriters: jellyfish::TpuBroadcastRewriter, jellyfish::TpuDegenerateDimensionRewriter, jellyfish::TpuReduceRewriter (flag + no-flag), HloPassFix<jellyfish::TpuReduceWindowRewriter>, xla::TpuGatherScatterFlattener, xla::TpuGatherSplit, jellyfish::TpuConvolutionTypeCanonicalizer.
  • Algebraic / simplifier: jellyfish::TpuAlgebraicSimplifier (superset of the open-source simplifier — see algebraic-simplifier.md), PostFusionTpuSubgraphSimplifier, TpuTrivialFusionRemover, TpuTrivialInstructionUnfuser.
  • Call-graph / acceptance: jellyfish::TpuCallInliner (must-fuse aware; name()"tpu-call-inliner" / -must-fuse / -inner-must-fuse / -non-must-fuse), TpuHloSupportChecker.
  • Sharding / partitioning: TpuAutoSharding (wraps xla::AutoSharding), jellyfish::TpuSpmdPartitioner, TpuPartitionAssignment, TpuSpmdConcatRewriter.
  • Precision / dtype: jellyfish::TpuInt2AutoUpDownCaster, XPrecisionRewriter (x6/x9/x128 dot precision), TpuHloPrecisionTracer.
  • Custom-call / fusion prep: jellyfish::MosaicFusion (Pallas/Mosaic kernels), UserGuidedFusionIdAssigner, WrapFusionOutputForDebug, TpuCustomCallScopedVmemAdjuster (confirmed AddPass in PostOptimizationPipeline::$_12).

A further set of TPU passes is present in the binary's RTTI (vtable-only) but was not observed inside any of the five decompiled builders — they are added by configuration-flag-gated branches deeper inside RunHloPasses. These include the SparseCore offload family (SparseCoreComputeOffloader, OffloadGatherToSparseCore, etc.), the async-collective family (TpuAsyncCollectiveCreator, AsyncCollectiveMerger, AsyncOpScheduler), the fusion family (TpuInstructionFusion, TpuMultiOutputFusion), FlashAttention, Gmm (group-matmul), and the int4 path TpuInt4Rewriter. [Confidence: CONFIRMED these classes exist (vtable RTTI); LOW on their pipeline position and gating.]


Custom-Call Targets Consumed at the HLO Boundary

The pre-passes consume / produce HLO kCustomCall ops with these target strings (recovered from the string section). They form the surface between the JAX/PyTorch frontend and the TPU compiler:

| Target string | Consumed/produced by | Role |
| --- | --- | --- |
| Sharding | ShardingPropagation | sharding boundary marker |
| SPMDFullToShardShape / SPMDShardToFullShape | SPMD lowering | sharding lowering helpers |
| mhlo.sharding / _XlaSharding (attrs) | sdy::ShardyXLA | Shardy/MHLO import sharding markers, carried as MLIR attributes rather than as a Sharding-mhlo custom-call target |
| RngBitGenerator | TpuRngBitGenerator* | RNG, decomposed by the RNG expander family |
| TopK | | kept opaque, lowered later |
| tpu_custom_call | MosaicFusion / Mosaic emit | the registered Pallas/Mosaic kernel custom-call target (CustomCallRegistration::RegisterCompilationProperties("tpu_custom_call", …)); generic TPU custom-call wrapper |
| MoveToHost / MoveToDevice | host-offload legalizer | host-offload markers |
| Pin | jellyfish::PinPrecoloring | precoloring marker |
| inspect_sharding | jellyfish::RemoveInspectShardingCustomCall | debug-only, removed |

__cudnn$convForward and similar GPU targets are not consumed by any TPU pass — they are rejected by TpuHloSupportChecker (#50).


What Is Not Recovered

  • Exact ordering of the vtable-only TPU passes. The SparseCore, async-collective, fusion, FlashAttention, Gmm, and int4 families appear in RTTI but were not observed in the five decompiled builders; they live in flag-gated branches deeper inside RunHloPasses that this analysis did not walk. [Confidence: LOW on position.]
  • Per-TpuVersion divergence. The decompilation reflects one pipeline that branches on Target / TpuCompilationEnvironment flags; which passes are skipped for, e.g., TPU v3 vs v6e was not isolated. [Confidence: LOW on per-gen differences.]
  • Sharding-flow selector logic. Which of manual / auto / Shardy runs is decided by flags plus frontend-attribute detection inside the builders; the precise dispatch branch was not isolated (the auto_sharding flag family is CONFIRMED present in the string table, but the exact selector is LOW). See auto-sharding-spmd.md.
  • Per-pass rewrite algorithms. This page recovers names + order + I/O invariants, not the internal graph rewrites. The simplifier algorithm has its own page (algebraic-simplifier.md); the rest are pending.
  • The --xla_tpu_crash_if_hlo_pass_fix_did_not_converge flag string was not found in the sampled strings table. [Confidence: LOW.]

Cross-References