NormalizedComputationCost

All addresses on this page apply to libtpu.so from the libtpu-0.0.40-cp314 wheel (BuildID md5 89edbbe81c5b328a958fe628a9f2207d, 781,691,048 bytes, not stripped — every symbol is a demangled C++ name). Section map: .text/.rodata VMA == file offset; .data.rel.ro VMA − 0x200000 == file offset. All addresses are VMA. Other libtpu builds will differ.
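For cross-checking an address against the file on disk, the section map above can be written as a small lookup. This is a minimal sketch: the name vma_to_file_offset and the SECTION_DELTA dictionary are illustrative (they are not libtpu symbols), and the deltas hold only for the build identified above.

```python
# VMA minus file-offset delta per section, as stated in the section map above.
# Valid only for the libtpu build this page describes.
SECTION_DELTA = {
    ".text": 0x0,            # VMA == file offset
    ".rodata": 0x0,          # VMA == file offset
    ".data.rel.ro": 0x200000,  # VMA - 0x200000 == file offset
}


def vma_to_file_offset(section: str, vma: int) -> int:
    """Translate a VMA to a file offset using the page's section map."""
    return vma - SECTION_DELTA[section]


# The jump table discussed on this page lives in .rodata, so it maps 1:1.
print(hex(vma_to_file_offset(".rodata", 0xAE0DCDC)))
```

Any other section would need its own delta read from the ELF section headers; the sketch deliberately raises KeyError for sections the page does not cover.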

Abstract

TpuPriorityFusionQueue::NormalizedComputationCost (@ 0x130989a0) is the compute term the TPU priority-fusion queue subtracts when it ranks a producer→consumer fusion. It answers one question per HLO op — how much vector/matrix work does this op add to a fused region — and returns a single double. Unlike the bundle-occupancy model (TpuHloCostAnalysis), which deposits per-op cycles into a 23-slot ResourceVector, this function collapses each op to a scalar weight multiplied by Target::ChunksIn(shape) (the op's chunk-granule element count). The LLVM analogue is a TargetTransformInfo::getInstructionCost query that returns a coarse "expense class" rather than a cycle-accurate latency — and like TTI, it is a ladder of a few discrete weights, not a continuous model.

The weight ladder is a switch over HloOpcode compiled to a 106-entry self-relative jump table at .rodata 0xae0dcdc (index = opcode − 0x18, opcode > 0x81 → default). It resolves to exactly six scalar tiers: 0.0 (free data-layout ops), 1.0 (the default cheap elementwise), 2.0 (a fused parameter read), 4.0 (logistic / reduce / cross-lane broadcast), 10.0 (divide), and 42.0 (erf). Three structural escapes sit beside them: a convolution that is priced by flop_count / peak rather than the ladder, a fusion that is priced by recursively summing its body's per-op costs, and a dot that is CHECK-fatal because dots must have been lowered to convolutions before this point.

GetCyclesIfFused (@ 0x130aba40) is the other half of the priority formula — the bundle cost of fusing a (producer, consumer) pair. It is not a hand-written producer+consumer merge; it builds a FusionState that augments the consumer's operand set with the producer, then re-prices the merged op through the same GetHloResourcesImpl machinery used for any fusion. The cross-functional-unit "packing" — back-to-back MXU ops overlap, the producer→consumer HBM round-trip is dropped, the input-DMA startup latency is paid once — happens inside that shared machinery (ScaleAndSumOutputFusionResourceVectors @ 0x130b8320), which is what makes total_fused < total_unfused and the fusion priority positive. This page documents both: the opcode→weight table with its byte-exact constants, and the GetCyclesIfFused fused-merge driver.

For reimplementation, the contract is:

  • The six-tier opcode→weight table, the .rodata weight constants, and the per-class rationale (why divide=10, erf=42, why reduce uses the operand's ChunksIn and parameter is ×2 only for the first two slots).
  • The LoopFusion operand-accumulation pre-path and its element-type gate mask.
  • The broadcast cross-lane path and the convolution flop→cycle formula.
  • The fusion recursive-sum-and-cache path (it is not a pure lookup).
  • GetCyclesIfFused's eligibility gates, the conv-like main-op selection, the max-pool FLT_MAX sentinel, and the FusionState → GetHloResourcesImpl → ResourceVector::Add merge driver.
  • The four-sub-emitter ScaleAndSumOutputFusionResourceVectors combine with its explicit slot-9/11 MAX-combine.

| Item | Detail |
|---|---|
| Compute weight | TpuPriorityFusionQueue::NormalizedComputationCost(HloInstruction*, long) @ 0x130989a0 |
| Jump table | .rodata 0xae0dcdc, 106 × i32 self-relative (index = opcode − 0x18) |
| Weight tiers | 0.0, 1.0, 2.0, 4.0, 10.0, 42.0 + conv-flop / fusion-recurse / dot-fatal |
| Chunk count | Target::ChunksIn(Shape&) @ 0x1d619900 |
| Fused merge | CostModel::GetCyclesIfFused(producer, consumer, opts, ResourceVector*) @ 0x130aba40 |
| Merge core | ScaleAndSumOutputFusionResourceVectors @ 0x130b8320 (4 sub-emitters; slots 9/11 MAX) |
| Eligibility gate | IsFusionSupportedHlo @ 0x130abee0 |
| Source files | tpu_instruction_fusion.cc (weight) · cost_model/cost_model.cc (fused merge) |

NormalizedComputationCost — the Opcode→Weight Ladder

Purpose

This is the compute penalty half of the bundle-aware fusion priority (priority = total_unfused − total_fused). For a candidate fusion edge it estimates how much vector/matrix work the consumer (or, recursively, a nested fusion) adds. The result is one double; the priority queue subtracts it scaled by the user count. The function takes (HloInstruction* inst, long operand_index): rbx is the owning TpuPriorityFusionQueue this (holding the conv flop cache and a Target* at this+0x170 / *((q*)this+46)); inst is the op being priced; operand_index drives both a recursion guard and the parameter-slot tier.

Entry Point

NormalizedComputationCost (0x130989a0)                 ── double, computes a scalar weight
  ├─ (operand_index>0 + binary + iota/broadcast op0)   ── recursion guard → flop/return tail
  ├─ LoopFusion pre-path (0x13098a13)                  ── multi-operand elementwise estimate
  │     └─ Target::ChunksIn (0x1d619900) × (1 + #non-zero-minor operands)
  └─ per-opcode switch (0x13098b62)                    ── jump table @ .rodata 0xae0dcdc
        ├─ 0.0  block  0x130995ee   (layout/metadata ops)
        ├─ 1.0  block  0x13098e2f   (DEFAULT — cheap elementwise)
        ├─ 2.0  block  0x13098de1   (parameter, operand_index ≤ 1)
        ├─ 4.0  blocks 0x13098d8e / 0x13098d74 / 0x13098b87  (logistic / reduce / broadcast)
        ├─ 10.0 block  0x13098d52   (divide)
        ├─ 42.0 block  0x13098e49   (erf)
        ├─ conv block  0x13098db0   (flop_count / peak — see CONV FORMULA)
        ├─ fusion block 0x13098e09  (cached, else recursive sum of body)
        └─ dot block   0x130996c5   (CHECK-fatal "Dots should have been replaced…")

Algorithm

double NormalizedComputationCost(this /*rbx*/, inst /*r14*/, operand_index /*r15*/):  // 0x130989a0
    // (0) RECURSION GUARD — when called per fused-edge with operand_index>0,
    //     a binary op whose operand(0) is iota(0x43) or broadcast(0x1a) returns 0.0.
    if operand_index > 0:
        if (inst.operand_count_field & ~1) == 2:          // [inst+0x10] low form == 2
            op0 = inst.operand(0).opcode;                 // [operand(0)+12]
            if op0 == 0x43 /*iota*/ or op0 == 0x1a /*broadcast*/:
                return 0.0;
        goto per_opcode_switch;                           // LABEL_29

    // (1) LOOP-FUSION PRE-PATH — multi-operand elementwise estimate.
    if operand_index == 0
       and inst.IsLoopFusion()
       and root_element_type in NUMERIC_MASK            // 0x2FFF91FFE | {0x20,0x21} | 0x400048000
       and inst.fused_instructions_computation().instruction_count() <= 254:  // [comp+88]
        multiplier = 1.0;                                 // qword_A2DF230
        if inst.operand_count >= 2:                       // [inst+0x10] >= 2
            for each operand i:
                op_minor   = operand(i).shape.dimensions()[0] >> 1;   // minor dim, halved
                if op_minor != 0: multiplier += 1.0;      // +1 per NON-zero-minor operand (gated)
                root_minor = root.shape.dimensions()[0] >> 1;
                if op_minor >= root_minor: break;         // operand as-wide-or-wider → abandon
            else:
                return Target::ChunksIn(root.shape) * multiplier;     // estimate kept
        // any operand as-wide-or-wider, or 0/1 operands → fall through to the switch
        goto per_opcode_switch;

per_opcode_switch:                                        // 0x13098b62
    xmm0 = 0.0;                                           // pre-zeroed (the 0.0 tier needs no store)
    op = inst.opcode;                                     // byte [inst+0xc]
    if (op - 0x18) > 0x69: goto default_1_0;
    switch (op):                                          // jump table @ .rodata 0xae0dcdc

      case 0x18,0x27,0x29,0x2A,0x43,0x61,0x81:            // bitcast/concat/constant/convert/
        return 0.0;                                       //   iota/reshape/tuple — no compute

      case 0x1A /*broadcast*/:  return BroadcastWeight(inst);          // see BROADCAST PATH
      case 0x2B /*convolution*/: return ConvolutionWeight(this, inst); // see CONV FORMULA
      case 0x32 /*divide*/:     return Target::ChunksIn(root) * 10.0;  // qword_A2DF498
      case 0x38 /*erf*/:        return Target::ChunksIn(root) * 42.0;  // qword_A2DF1A0
      case 0x47 /*logistic*/:   return Target::ChunksIn(root) *  4.0;  // qword_A2DE830
      case 0x5B /*reduce*/:     return Target::ChunksIn(operand(0).shape) * 4.0;  // OPERAND, ×4
      case 0x52 /*parameter*/:
        if operand_index <= 1:  return Target::ChunksIn(root) * 2.0;   // vaddsd self
        else:                   return 0.0;
      case 0x3D /*fusion*/:     return FusionWeight(this, inst);       // see FUSION PATH
      case 0x34 /*dot*/:        LOG(FATAL) "Dots should have been replaced by convolutions.";
                                                          // tpu_instruction_fusion.cc:863

      default:                  return Target::ChunksIn(root) * 1.0;   // vcvtsi2sd, no vmulsd

NOTE — the 0.0 tier needs no work because xmm0 is zeroed before the dispatch (the vxorpd xmm0,xmm0 at 0x13098b5e); the 1.0 default emits only a vcvtsi2sd of ChunksIn with no multiply. Every other tier is a vmulsd by a .rodata double. A reimplementation can fold the 0.0 and 1.0 cases into the chunk-count computation, but the dot case is a hard CHECK — it must abort, not return a number.

The Weight Table

root is inst.shape; ChunksIn(s) = Target::ChunksIn(s) (@ 0x1d619900) is the op's chunk-granule element count. Jump-table targets are byte-verified from .rodata 0xae0dcdc; weight constants are byte-verified from .rodata.

| Weight / path | Constant | Block @ | Opcodes (hex = name) |
|---|---|---|---|
| 0.0: no cost | (xmm0 pre-zeroed, ret) | 0x130995ee | 0x18 bitcast, 0x27 concatenate, 0x29 constant, 0x2A convert, 0x43 iota, 0x61 reshape, 0x81 tuple |
| 1.0: ChunksIn(root) (DEFAULT) | (no vmulsd) | 0x13098e2f | every opcode 0x19..0x80 not listed below (cheap unary/binary elementwise, structural, collective, control-flow, I/O) |
| 2.0: ChunksIn(root) × 2 | vaddsd self | 0x13098de1 | 0x52 parameter, only when operand_index ≤ 1; else 0.0 |
| 4.0: ChunksIn(root) × 4 | 0xa2de830 = 4.0 | 0x13098d8e | 0x47 logistic |
| 4.0: ChunksIn(operand(0)) × 4 | 0xa2de830 = 4.0 | 0x13098d74 | 0x5B reduce (priced over the reduced-over operand, not the root) |
| 4.0: broadcast cross-lane | 0xa2de830 = 4.0 | 0x13098b87 | 0x1A broadcast (conditional, see BROADCAST PATH) |
| 10.0: ChunksIn(root) × 10 | 0xa2df498 = 10.0 | 0x13098d52 | 0x32 divide |
| 42.0: ChunksIn(root) × 42 | 0xa2df1a0 = 42.0 | 0x13098e49 | 0x38 erf |
| conv flop path | (see CONV FORMULA) | 0x13098db0 | 0x2B convolution |
| fusion recurse / cache | flat_hash_map @ this+0x80 | 0x13098e09 | 0x3D fusion |
| dot CHECK-fatal | LogMessageFatal | 0x130996c5 | 0x34 dot |

The jump table has exactly 11 distinct targets: the shared 0.0 block (0x130995ee) plus the ten other blocks listed above. Opcode→name resolution follows the alphabetical XLA HloOpcode enum (0x18 kBitcast … 0x81 kTuple); the opcode values used by the function (0x34 dot (fatal), 0x38 erf, 0x47 logistic, 0x52 parameter, 0x5B reduce) are corroborated by explicit opcode comparisons in the binary (e.g. GetCyclesIfFused tests opcode != 91 for the reduce special-case, and 91 == 0x5B).
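The scalar part of the table can be sketched as a plain lookup. This is an illustrative reimplementation under stated assumptions, not the decompiled code: ladder_cost and its parameters are hypothetical names, the caller supplies the ChunksIn values, the recursion guard and LoopFusion pre-path are not modelled, and the three structural escapes (broadcast, convolution, fusion) are left to their own paths.

```python
# Opcode values are the ones listed in the weight table on this page.
ZERO_OPS = {0x18, 0x27, 0x29, 0x2A, 0x43, 0x61, 0x81}
SCALAR_WEIGHTS = {0x32: 10.0, 0x38: 42.0, 0x47: 4.0}  # divide, erf, logistic
STRUCTURAL = {0x1A: "broadcast", 0x2B: "convolution", 0x3D: "fusion"}


def ladder_cost(opcode, chunks_root, chunks_operand0=0, operand_index=0):
    """Scalar tiers only; chunk counts are passed in by the caller."""
    if opcode in ZERO_OPS:
        return 0.0
    if opcode == 0x34:  # dot: must abort, never return a number
        raise RuntimeError("Dots should have been replaced by convolutions.")
    if opcode in STRUCTURAL:
        raise NotImplementedError(STRUCTURAL[opcode] + " has its own path")
    if opcode == 0x5B:  # reduce: priced over operand(0), not the root
        return chunks_operand0 * 4.0
    if opcode == 0x52:  # parameter: x2 only for the first two slots
        return chunks_root * 2.0 if operand_index <= 1 else 0.0
    # Everything else, including opcodes above 0x81, takes the 1.0 default.
    return chunks_root * SCALAR_WEIGHTS.get(opcode, 1.0)


print(ladder_cost(0x32, 8))                      # divide
print(ladder_cost(0x5B, 1, chunks_operand0=16))  # reduce over a larger operand
```

Raising on the dot case rather than returning a sentinel mirrors the hard CHECK described in the NOTE above.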

Considerations — Why the Ladder Has These Tiers

The four scalar tiers above 0.0/1.0 rank elementwise expense, not op category:

  • 0.0 ops are pure data-layout / metadata — bitcast, reshape, convert, concatenate, constant, iota, tuple. Fusing them adds no functional-unit work, so they never penalise the priority. This matches the bundle model's ZERO arm (TpuHloCostAnalysis).
  • divide is 10.0 — it is the costliest cheap elementwise op: a reciprocal-plus-multiply micro-sequence. The bundle model agrees, expanding divide into a four-deposit reciprocal sequence.
  • erf is 42.0 — the single most expensive elementwise op in the model: a long polynomial. This 42.0 is the elementwise-erf weight, not a matmul/conv class. Matmul and convolution are priced by the separate flop path below; there is no ×42 matmul branch.
  • reduce uniquely multiplies by the OPERAND's ChunksIn (operand(0).shape), not the root's. A reduce's compute scales with the large input it reduces over, not the small reduced output — the priority mirror of HandleReduce's ExtentProduct(operand) flop formula.
  • parameter is ×2 only for operand_index ≤ 1 — the first two fused-parameter slots cost the read-plus-forward of the param tensor; higher slots are free. This is the only tier gated on operand_index rather than opcode.

GOTCHA — 42.0 is the erf elementwise weight (block 0x13098e49, constant 0xa2df1a0), not a matmul/conv class. Convolution takes the separate flop_count / peak path (block 0x13098db0); there is no opcode mapping to a ×42 matmul tier. Pricing matmul at ×42·ChunksIn is wrong by orders of magnitude on large convs.


The LoopFusion Operand-Accumulation Pre-Path

Purpose

Before the per-opcode switch, a loop (elementwise) fusion gets a special multi-operand estimate at 0x13098a13. The intuition: a loop fusion's compute is roughly its output chunk count times one unit, plus one more unit per operand that must be broadcast/expanded to reach the output width. The estimate is kept only when every operand is strictly narrower than the output: each such operand (needing a lane/sublane expansion) adds one unit, while any operand as wide as the output abandons the estimate and falls through to the switch.

Algorithm

// reached only when operand_index == 0
if inst.IsLoopFusion()
   and element_type_in_numeric_mask(root.shape.element_type)
   and fused_instructions_computation().instruction_count() <= 254:   // 0xFE cap, [comp+88]
    multiplier = 1.0;
    if inst.operand_count >= 2:
        for each operand i in [0 .. operand_count):
            op_minor   = operand(i).shape.dimensions()[0] >> 1;     // [operand.shape+8] >> 1
            if op_minor != 0: multiplier += 1.0;                    // gated: skip zero-minor operands
            root_minor = root.shape.dimensions()[0] >> 1;           // re-read [inst.shape+8] each iter
            if op_minor >= root_minor:
                goto per_opcode_switch;          // operand as-wide-or-wider → abandon estimate
        return Target::ChunksIn(root.shape) * multiplier;
    goto per_opcode_switch;                       // 0 or 1 operand → switch on the root opcode

The numeric element-type gate is the same mask used throughout the cost model: _bittest64(0x2FFF91FFE, et) for et ≤ 0x21, OR (et & ~1) == 0x20, OR _bittest64(0x400048000, et) for et ≤ 0x22 — covering the numeric and packed types the bundle model can price (the same mask appears in GetHloResourcesImpl and GetCyclesIfFused). The ≤ 254 instruction-count cap bounds the estimate to small fusions.
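The three-part gate can be written out directly from the constants quoted above. This is a sketch: in_numeric_mask is a hypothetical name, and the element-type values are treated as opaque integers because this page does not map them to type names.

```python
MASK_A = 0x2FFF91FFE  # consulted when et <= 0x21
MASK_B = 0x400048000  # consulted when et <= 0x22


def in_numeric_mask(et: int) -> bool:
    """Return True when element-type value `et` passes the numeric gate."""
    if et < 0:
        return False
    if et <= 0x21 and (MASK_A >> et) & 1:
        return True
    if (et & ~1) == 0x20:  # et is 0x20 or 0x21
        return True
    if et <= 0x22 and (MASK_B >> et) & 1:
        return True
    return False


accepted = [et for et in range(0x40) if in_numeric_mask(et)]
print([hex(et) for et in accepted])
```

Enumerating the accepted values this way is a cheap consistency check when porting the gate: the three clauses overlap (for example 0x21 is covered by both the first and second), so only the union matters.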

GOTCHA — the loop reads dimensions()[0] >> 1 for both the operand and the root minor dimension, and the comparison is op_minor >= root_minor. The >> 1 (halving) applies symmetrically, so it cancels in the comparison — it is the minor dimension in chunk units. The estimate is abandoned (falls to the switch on the fusion-root opcode, i.e. case 0x3D) the instant any operand is as wide as the output; only "all operands strictly narrower than the output" keeps the estimate. Two further byte-level subtleties: (1) the multiplier += 1.0 increment is gated on op_minor != 0 — an operand whose halved minor dim is zero is counted in the loop but does not bump the multiplier (the vaddsd result is discarded when v17 == 0 at 0x13098…/decompile line 732); (2) root_minor is re-read from inst.shape on every iteration rather than hoisted. The kept estimate is therefore ChunksIn(root) × (1 + #non-zero-minor operands), not (1 + #operands). A reimplementer who returns the estimate unconditionally, or who bumps the multiplier for every operand, will over-cost fusions whose operands are already full-width or zero-minor.
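The keep-or-abandon behaviour described in the GOTCHA can be sketched as follows. The function name and its inputs are hypothetical: the raw dimensions()[0] values and ChunksIn(root) are passed in by the caller, the IsLoopFusion, element-type, and instruction-count gates are assumed to have passed already, and None stands for "fall through to the per-opcode switch".

```python
def loop_fusion_estimate(root_dim0, operand_dim0s, chunks_root):
    """Return the kept estimate, or None when the pre-path is abandoned."""
    if len(operand_dim0s) < 2:
        return None  # 0 or 1 operands: no estimate, go to the switch
    multiplier = 1.0
    for dim0 in operand_dim0s:
        op_minor = dim0 >> 1
        if op_minor != 0:
            multiplier += 1.0  # bump only for non-zero halved minor dims
        root_minor = root_dim0 >> 1  # re-read each iteration, as in the binary
        if op_minor >= root_minor:
            return None  # operand as wide or wider: abandon the estimate
    return chunks_root * multiplier


print(loop_fusion_estimate(128, [1, 64], 10))    # one zero-minor, one narrower
print(loop_fusion_estimate(128, [128, 1], 10))   # first operand is full width
```

The first call keeps the estimate with a multiplier of 2.0 (the zero-minor operand does not bump it); the second is abandoned on the first operand.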


Broadcast Path (opcode 0x1A, block 0x13098b87)

Algorithm

double BroadcastWeight(inst):                              // 0x13098b87
    target = this.Target;                                  // *((q*)this+46)
    if target[0x398] == 0: return 0.0;                     // broadcast-cost flag off (Target+920)
    operand = inst.operand(0);
    if operand.shape.rank > 3: return 0.0;                  // high-rank broadcasts not priced
    phys = LayoutUtil::MakeLogicalToPhysical(inst.shape.layout);
    sort(inst.dimensions());                                // ascending
    if ShapeUtil::IsEffectiveScalar(operand.shape):         return 0.0;   // splat of a scalar
    if phys.minor_most_dim is a broadcast dimension:        return 0.0;   // free minor-axis splat
    return Target::ChunksIn(inst.shape) * 4.0;             // cross-lane movement, qword_A2DE830

A broadcast along the minor-most (sublane/lane) axis is a free register splat; a broadcast that materialises across lanes costs the 4.0 mid weight. The whole path is gated by the per-Target broadcast-cost flag at Target+0x398 (= Target+920) — when clear, every broadcast is 0.0.

QUIRK — the bundle-occupancy model prices a broadcast at zero unconditionally (TpuHloCostAnalysis ZERO arm), while this priority model may charge 4.0·ChunksIn for a cross-lane broadcast. The two surfaces disagree on layout ops by design — the bundle path treats a broadcast as zero functional-unit occupancy; the priority path accounts for the cross-lane data movement. Do not unify them.
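The ordering of the broadcast gates can be sketched with the layout queries reduced to booleans. All parameter names here are hypothetical stand-ins: the caller is assumed to have already evaluated the Target+0x398 flag, the operand rank, IsEffectiveScalar, and whether the physical minor-most dimension is a broadcast dimension.

```python
def broadcast_weight(cost_flag_on, operand_rank, operand_is_effective_scalar,
                     minor_most_is_broadcast_dim, chunks_root):
    """Priority-model weight of a broadcast; every early exit returns 0.0."""
    if not cost_flag_on:
        return 0.0  # per-Target broadcast-cost flag is clear
    if operand_rank > 3:
        return 0.0  # high-rank broadcasts are not priced
    if operand_is_effective_scalar:
        return 0.0  # splat of a scalar
    if minor_most_is_broadcast_dim:
        return 0.0  # free minor-axis splat
    return chunks_root * 4.0  # cross-lane movement


print(broadcast_weight(True, 2, False, False, 8))
```

Only the last branch charges anything, which is why the table above marks the broadcast row as conditional.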


Convolution Formula (opcode 0x2B, block 0x13098db0)

Algorithm

Convolution escapes the scalar ladder entirely: it is priced by its flop count converted to MXU cycles. The flop count is cached per unique_id so repeated priority queries are cheap.

double ConvolutionWeight(this, inst):                      // 0x13098db0
    // (1) FLOP CACHE — flat_hash_map<int64,double> at this+0xd0 / this+208
    uid  = inst.unique_id();
    flop = this.flop_cache.find(uid);
    if MISS:
        TpuHloCostAnalysis ca(this.Target, /*model_tpu_specific=*/true);   // ctor 0x130a1620
        st = ca.HandleConvolution(inst);                   // 0x1e480be0
        CHECK(st.ok()) << "hlo_cost_analysis.HandleConvolution(instruction) is OK";
                                                            // tpu_instruction_fusion.cc:871
        flop = ca.flop_count(inst);                        // 0x1e4841e0
        this.flop_cache.emplace(uid, flop)                 // CHECK emplace.second @ :875
    // (2) FLOP → CYCLES
    if inst.batch_group_count() == 1 and inst.feature_group_count() == 1:   // dense conv
        fmt        = LhsFormatForConvInstruction(inst, this.Target);  // 0x1307bd40
        peak_flops = Target::FlopsPerSecond(fmt);          // vtable+0x718, 0x1d61f280
        freq_MHz   = Target::TensorCoreFrequencyInMegaHertz();        // 0x1d615b60
        peak_per_cycle = peak_flops / freq_MHz / 1e6;      // qword_A2E0208 = 1e6
        mxu_cycles = flop / peak_per_cycle;
        valu_slots = Target::VectorAluSlotsPerTensorCore();// vtable+0x500, 0x1d61e380 (int)
        derate     = Target[0x4ac] * (-0.03) + 1.0;        // A2E05A8 = -0.03; A2DF230 = 1.0
        return (valu_slots * mxu_cycles) / derate;
    else:                                                  // grouped / depthwise conv
        return flop * 0.00048828125;                       // A2E0118 = 1/2048

A dense conv's cost is its flop count converted to MXU cycles at the per-LHS-format peak rate, scaled by the vector-ALU slot count and divided by a (1 − 0.03·N) headroom derate. Grouped convolutions — which the MXU does not accelerate as well — use a flat flop/2048 estimate. The HandleConvolution flop is the same one documented in TpuHloCostAnalysis (it divides by both group counts); the conv-shaped state it walks is on ConvolutionCostState.

NOTE — the absolute per-gen integers Target::FlopsPerSecond(format), Target::VectorAluSlotsPerTensorCore(), and the Target+0x4ac derate operand are per-codename / chip_parts-sourced and not enumerated here (only the v7 chip parameters are embedded in this build). The formula is byte-exact; the constants are owned by the per-gen MXU pages (MXU Latency Overview, MatmulMode and Modifiers). The LhsFormatForConvInstruction → MatmulDataFormat selection is also not decoded here (MEDIUM confidence on which peak each conv selects).
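The flop→cycle arithmetic can be sketched with the per-generation values passed in as parameters. The function and parameter names are hypothetical, and the numbers in the usage lines are made-up inputs chosen for easy arithmetic, not real chip parameters.

```python
def convolution_weight(flop, batch_group_count, feature_group_count,
                       peak_flops_per_second, freq_mhz, valu_slots,
                       derate_operand):
    """Sketch of the conv cost formula; all Target values are caller-supplied."""
    if batch_group_count == 1 and feature_group_count == 1:  # dense conv
        peak_per_cycle = peak_flops_per_second / freq_mhz / 1e6
        mxu_cycles = flop / peak_per_cycle
        derate = derate_operand * (-0.03) + 1.0
        return (valu_slots * mxu_cycles) / derate
    # Grouped / depthwise conv: flat flop / 2048 estimate.
    return flop * 0.00048828125


# Made-up inputs: 1e12 flop/s at 1000 MHz gives 1000 flop per cycle.
print(convolution_weight(1e6, 1, 1, 1e12, 1000.0, 4, 0))
print(convolution_weight(4096, 1, 2, 1e12, 1000.0, 4, 0))
```

Note that the derate only appears on the dense branch; the grouped branch ignores every Target value.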


Fusion Path (opcode 0x3D, block 0x13098e09)

Algorithm

A fusion op is priced by the sum of its body's per-op costs, recursively, with the result cached per HloInstruction*. This is not a pure lookup — on a cache miss it walks the fused computation's root chain and recurses through NormalizedComputationCost for each fused instruction (with operand_index advanced, which is exactly what the recursion guard at the top of the function protects against double-counting iota/broadcast operands).

double FusionWeight(this, inst):                           // 0x13098e09
    // cache: flat_hash_map<HloInstruction*, double> at this+0x80 / this+128
    hit = this.fusion_cache.find(inst);                    // 0x13098e09
    if hit: return hit->second;

    comp = inst.fused_instructions_computation();          // 0x13098f1c
    cost = 0.0;
    // walk the fused instruction list (16-byte stride), skipping null slots,
    // tracking a (lo,hi) index pair into the computation's instruction array:
    for each fused_instruction f in comp:                  // LABEL_105..LABEL_112
        cost += NormalizedComputationCost(this, f, edge_index);   // recursion @0x13098f50
    this.fusion_cache.emplace(inst, cost);                 // SOO flat_hash_map insert
    return cost;

NOTE — Case 0x3D is a cache lookup whose miss (0x13098f18 onward) computes the cost in place: it recurses NormalizedComputationCost over every instruction of fused_instructions_computation(), accumulating into xmm0 (the vaddsd xmm0, var_30, xmm0 at 0x13098fd0), then stores the sum into the this+0x80 map. The cache is populated by NormalizedComputationCost itself on the parent fusion's first query, not by an external pass.

The recursion guard at the function's top (operand_index > 0 + binary op + operand(0) is iota/broadcast → 0.0) prevents a fused binary op from charging for an iota/broadcast operand that the loop pre-path or a sibling edge already accounts for.
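The cache-then-recurse shape of the fusion path can be sketched with a plain dictionary standing in for the flat_hash_map. The names are hypothetical, and per_op_cost is a caller-supplied stand-in for the recursive NormalizedComputationCost call.

```python
def fusion_weight(fusion_key, body, cache, per_op_cost):
    """Sum the body's per-op costs once, then serve later queries from cache."""
    if fusion_key in cache:
        return cache[fusion_key]
    cost = 0.0
    for instr in body:
        if instr is None:  # skip null slots in the instruction list
            continue
        cost += per_op_cost(instr)
    cache[fusion_key] = cost
    return cost


calls = []


def toy_cost(instr):
    # Stand-in per-op cost that also records how often it is invoked.
    calls.append(instr)
    return float(instr)


cache = {}
first = fusion_weight("fusion.1", [1, None, 2, 3], cache, toy_cost)
second = fusion_weight("fusion.1", [1, None, 2, 3], cache, toy_cost)
print(first, second, len(calls))
```

The second query never touches the body, which is the property that makes repeated priority queries on the same fusion cheap.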


GetCyclesIfFused — the Fused Bundle Merge

Purpose

GetCyclesIfFused (@ 0x130aba40) returns the bundle cost of fusing a (producer, consumer) pair — the total_fused term of the priority formula. Its key structural fact: it does not hand-merge two ResourceVectors. It treats the consumer as a fusion whose operand set is augmented by the producer (a FusionState), then prices that merged op with the same GetHloResourcesImpl used for any fusion. The cross-FU packing therefore happens inside the shared machinery, not in this function. The signature is StatusOr<double> GetCyclesIfFused(producer, consumer, const PerCalculationOptions&, ResourceVector* out); rbx is the StatusOr<double>* return slot, the producer arrives in a4/v19, the consumer in a3, and out in a5/a6.

Entry Point

GetCyclesIfFused (0x130aba40)                          ── StatusOr<double>
  ├─ IsFusionSupportedHlo      (0x130abee0)            ── eligibility gate → trivial 1.0 cy
  ├─ numeric element-type mask (0x2FFF91FFE | …)       ── non-numeric → 1.0 cy
  ├─ ShapeUtil::IsZeroElementArray                      ── (consumer ≠ reduce) zero-elem → 1.0 cy
  ├─ IsConvLowerable (0x14553620) / ExtractConvLikeHlo (0x1d6aa140)  ── conv-like main-op pick
  │     └─ GetReduceWindowType (0x1454d4a0) == -1/2     ── max-pool → FLT_MAX sentinel
  ├─ FusionState::Create(consumer, producer) (0x130ab320)           ── combined operand set
  ├─ GetHloResourcesImpl(consumer, opts, &fs, isFused=1) (0x130aa580) ── price merged op
  │     └─ ScaleAndSumOutputFusionResourceVectors (0x130b8320)       ── 4-emitter combine
  └─ ResourceVector::Add(out, merged, Defaults) (0x1c89b820)         ── fold into caller's out

Algorithm

StatusOr<double> GetCyclesIfFused(producer /*a4*/, consumer /*a3*/, opts, out /*a5*/):  // 0x130aba40
    // (A) ELIGIBILITY GATES — return a trivial 1.0-cycle cost if the op is not modellable.
    if !IsFusionSupportedHlo(consumer, opts.target)        // 0x130abee0
       or consumer.shape.element_type not in NUMERIC_MASK  // 0x2FFF91FFE | {0x20,0x21} | 0x400048000
       or (consumer.opcode != 0x5B /*reduce*/              // reduce is exempt from the next test
           and ShapeUtil::IsZeroElementArray(producer.shape)):
        out.scalar = 0;                                    // *(this+1) = 0
        return StatusOr<double>(1.0);                      // *(this) = 1  (1 trivial cycle)

    // (B) CONV-LIKE MAIN-OP SELECTION — pick the conv/reduce-window "hero" between the two ops.
    main = consumer;                                       // default
    for cand in {producer, consumer}:                      // probe both
        if IsConvLowerable(cand):                          // 0x14553620
            conv = ExtractConvLikeHlo(cand);               // 0x1d6aa140
            if conv and conv.opcode == 0x5E /*reduce-window*/:
                t = GetReduceWindowType(conv);             // 0x1454d4a0
                // (C) MAX-POOL SENTINEL: a max-pool (t==2) or unknown (t==-1) window is not
                //     bundle-cost-modellable → emit FLT_MAX cycles so it never fuses.
                if t == 2 or t == -1:
                    out.resources[…] = FLT_MAX;            // 0x47EFFFFFE0000000 = qword_A2E0530
                    return StatusOr<double>(out.MaxResourceCycles());   // = FLT_MAX
            else:
                main = cand;                               // this cand is the conv-lowerable main op

    // (D) BUILD THE MERGED FUSION STATE — consumer augmented by producer's operands.
    FusionState fs;
    FusionState::Create(&fs, consumer, producer);          // 0x130ab320
    //   fs records consumer.operands() + producer.operands() as one combined set, and the
    //   per-producer operand-indices at which the consumer USES the producer (the internal edges).

    // (E) PRICE THE MERGED OP AS A FUSION.
    StatusOr<ResourceVector> merged =
        GetHloResourcesImpl(consumer, opts, &fs, /*isFused=*/1);   // 0x130aa580
    if !merged.ok(): return merged.status();               // propagate w/ source loc cost_model.cc:2154

    double fused_cycles = merged.value().scalar_cycles;    // [merged + 0x240-ish]

    // (F) FOLD THE MERGED BYTE/SLOT MAP INTO THE CALLER'S out VECTOR.
    out.Add(merged.value(), AddOptions::Defaults());       // 0x1c89b820 — slots 9/11 MAX, rest ADD
    return StatusOr<double>(fused_cycles);

GetHloResourcesImpl walks the consumer's fused expression plus the producer (via fs). For each operand edge it consults IsProducerUse (@ 0x130ab0c0): edges recorded in the FusionState are internal, so their input-DMA bytes are not deposited into the MemXfer slots R[9..12] — this is what drops the producer→consumer HBM round-trip from the fused cost. The producer's compute is accumulated into the same ResourceVector as the consumer's, and the sub-emitter costs are combined by ScaleAndSumOutputFusionResourceVectors.

The Max-Pool FLT_MAX Sentinel

The one early-exit worth calling out: a conv-like op that resolves to a reduce-window (opcode 0x5E) whose GetReduceWindowType is 2 (max-pool) or -1 (unknown) is not bundle-cost-modellable. Rather than mis-price it, the function writes the FLT_MAX bit pattern 0x47EFFFFFE0000000 (= .rodata 0xa2e0530, byte-verified) into the result ResourceVector's slots and returns MaxResourceCycles() over it — an effectively infinite cost that guarantees the fusion is never chosen. This mirrors the same t ∈ {−1, 2} reduce-window sentinel in GetHloResourcesImpl's routing (TpuHloCostAnalysis).
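The bit pattern can be checked independently of the binary: reinterpreting 0x47EFFFFFE0000000 as an IEEE 754 double should give exactly the largest finite float32 value widened to double. A small check using only the standard library (the helper name is illustrative):

```python
import struct


def bits_to_double(bits: int) -> float:
    """Reinterpret a 64-bit integer as an IEEE 754 double."""
    return struct.unpack("<d", struct.pack("<Q", bits))[0]


# Largest finite float32 (bit pattern 0x7F7FFFFF), widened to a Python float.
FLT_MAX = struct.unpack("<f", struct.pack("<I", 0x7F7FFFFF))[0]

sentinel = bits_to_double(0x47EFFFFFE0000000)
print(sentinel == FLT_MAX)
```

Because the value is a float32 maximum stored in a double slot, it is still finite: comparisons and MAX reductions over it behave normally, unlike an infinity or NaN.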


ScaleAndSumOutputFusionResourceVectors — the Per-FU Combine

Purpose

This (@ 0x130b8320) is the routine GetHloResourcesImpl invokes to fuse the sub-emitter costs of a merged op into one ResourceVector. It combines up to four sub-emitter vectors — activations, kernel, output, and conv-compute — each scaled by its own iteration count, with the memory-transfer latency slots (9 and 11) MAX-combined rather than summed.

Algorithm

ResourceVector ScaleAndSumOutputFusionResourceVectors(            // 0x130b8320
        out, rv_act,  act_count,
             rv_kern, kern_count,
             rv_out,  out_count,
             rv_conv, conv_count):
    CHECK(conv_count >= act_count)  << "conv_compute_iteration_count >= activations_iteration_count";  // :1567
    CHECK(conv_count >= kern_count) << "… >= kernel_iteration_count";                                   // :1568
    CHECK(conv_count >= out_count)  << "… >= output_iteration_count";                                   // :1569
    out = 0;                                              // zero all 23 slots + byte maps

    // (1) ADD each scaled sub-emitter (default Add: every slot accumulates).
    for (rv, count, subset) in {(rv_act,  act_count,  SubsetOpts@0xae0f0c8),
                                (rv_kern, kern_count, SubsetOpts@0xae0f0cc),
                                (rv_out,  out_count,  SubsetOpts@0xae0f0d0),
                                (rv_conv, conv_count, SubsetOpts@0xae0f0d4)}:
        sub    = rv.GetSubset(subset);                    // 0x1c89bb00
        scaled = sub.GetScaled((double)count, ScaleOptions::Defaults());   // 0x1c89b3c0
        out.Add(scaled, AddOptions::Defaults());          // 0x1c89b820

    // (2) OVERRIDE slots 9 and 11 (MemXfer Input/Output LATENCY) with MAX-across-emitters × count.
    out.Acc(9,  max(rv_act[9],  rv_kern[9],  rv_out[9],  rv_conv[9])  * conv_count);   // +0x48
    out.Acc(11, max(rv_act[11], rv_kern[11], rv_out[11], rv_conv[11]) * conv_count);   // +0x58
    return out;

The four GetSubset calls use distinct SubsetOptions literals at .rodata 0xae0f0c8 / 0xcc / 0xd0 / 0xd4 (each the 4-byte tuple {0x00,0x01,0x01,0x01}, byte-verified) — one subset shape per emitter. The final two Acc calls are the explicit slot-9/11 MAX-combine: the input-DMA and output-DMA startup latency terms are paid once (the maximum across the four sub-emitters), not summed per emitter — a transfer's startup latency overlaps across the fused sub-regions, while its bandwidth (slots 10/12) and the compute slots accumulate.

NOTE — This routine builds the combined ResourceVector (with slots 9/11 MAX-combined) and returns it after Acc(this, 11, …); it does not itself reduce to cycles. The final scalar reduction happens in the caller via a separate MaxResourceCycles (@0x1c89b9e0). The MAX-combine of slots 9/11 here is an explicit per-slot override, distinct from the MaxResourceCycles reduction's serial-sum of the full {9,10,11,12} memory group (Resource Enum).
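A simplified model of the combine, for intuition only. It rests on several assumptions that are not confirmed by this page: ResourceVector is reduced to a 23-element list, GetSubset is treated as the identity, the byte maps are ignored, and the final Acc calls are modelled as an override of slots 9 and 11 (following the "OVERRIDE" label in the pseudocode above). All names are hypothetical.

```python
NUM_SLOTS = 23
LATENCY_SLOTS = (9, 11)  # MemXfer input / output latency


def scale_and_sum(emitters, conv_count):
    """emitters: list of (vector, iteration_count); returns the combined vector."""
    for _, count in emitters:
        if conv_count < count:
            raise ValueError("conv iteration count must be the largest")
    out = [0.0] * NUM_SLOTS
    # (1) Scaled add: every slot accumulates vector * iteration count.
    for vec, count in emitters:
        for slot, value in enumerate(vec):
            out[slot] += value * count
    # (2) Latency slots: max across emitters, scaled by the conv count.
    for slot in LATENCY_SLOTS:
        out[slot] = max(vec[slot] for vec, _ in emitters) * conv_count
    return out


def vec(**slots):
    # Helper: build a 23-slot vector from keyword args like s9=5.0.
    v = [0.0] * NUM_SLOTS
    for name, value in slots.items():
        v[int(name[1:])] = value
    return v


combined = scale_and_sum(
    [(vec(s9=5.0, s10=1.0), 2), (vec(s9=3.0), 1),
     (vec(s11=7.0), 2), (vec(s1=10.0), 4)],
    conv_count=4,
)
print(combined[1], combined[9], combined[10], combined[11])
```

In this toy input, the bandwidth slot 10 and compute slot 1 simply accumulate, while slots 9 and 11 carry one startup latency (the largest) rather than the per-emitter sum.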

Why total_fused < total_unfused

total_unfused = GetHloCycles(producer)·user_count + Σ_users GetHloCycles(user)
   each op is its OWN bundle: the producer's output is WRITTEN to HBM (output-DMA, R[11]/R[12])
   and each user READS it back (input-DMA, R[9]/R[10]). Those DMA cycles are counted (user_count+1)×.

total_fused = Σ_users GetCyclesIfFused(producer, user)
   the producer→user edge is INTERNAL (IsProducerUse drops its input-DMA), the producer's HBM
   output write is eliminated (it stays in VMEM), and the producer's compute slots OVERLAP the
   user's in the MaxResourceCycles plain-MAX group (e.g. producer Matmul R[1] overlaps user
   VectorAlu R[3..5]) instead of being two serial bundles.

priority = total_unfused − total_fused > 0  whenever the saved DMA + the bundle-packing overlap
           exceed any duplicated compute the fusion introduces.

For a single-user producer the priority is (C_p + C_u) − C_f; for an n-user producer it is (n·C_p + Σ_i C_u_i) − Σ_i GetCyclesIfFused(producer, user_i), where C_p = MaxResourceCycles(RV_producer), C_u likewise for the consumer, and C_f = GetCyclesIfFused(producer, consumer).
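The priority arithmetic above is simple enough to state as code. This is a sketch with hypothetical names; the cycle values are inputs the caller would obtain from MaxResourceCycles and GetCyclesIfFused, and the numbers in the usage lines are invented for illustration.

```python
def fusion_priority(c_producer, c_users, c_fused_per_user):
    """priority = total_unfused - total_fused for an n-user producer."""
    n = len(c_users)
    total_unfused = n * c_producer + sum(c_users)
    total_fused = sum(c_fused_per_user)
    return total_unfused - total_fused


# Single user: (C_p + C_u) - C_f
print(fusion_priority(10.0, [6.0], [12.0]))
# Two users: (2*C_p + C_u1 + C_u2) - (C_f1 + C_f2)
print(fusion_priority(10.0, [6.0, 4.0], [12.0, 11.0]))
```

A positive result means the saved DMA and packing overlap outweigh any compute the fusion duplicates across users.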


IsFusionSupportedHlo — the Eligibility Gate

IsFusionSupportedHlo (@ 0x130abee0) is the first gate of GetCyclesIfFused. It returns false (→ the trivial 1-cycle cost) when the consumer is not bundle-cost-modellable:

  • the element type has bit-width 64 (ShapeUtil::ElementHasBitWidth(shape, 0x40)), unless the op is a custom fusion or a collective-compute fusion;
  • the op is a zero-element array;
  • the opcode is in the REJECT set, tested as _bittest(0x2000100000400001, opcode−5) for opcodes in [5..0x42], which selects {0x05 all-gather-done, 0x1B call, 0x31 custom-call, 0x42 infeed}: control-flow / host-I/O ops that have no functional-unit occupancy.
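The reject mask from the list above can be decoded mechanically, which is a useful check when porting it. The function name is hypothetical; the mask, the −5 bias, and the [5..0x42] range are taken from the text.

```python
REJECT_MASK = 0x2000100000400001


def is_rejected_opcode(opcode: int) -> bool:
    """Bit (opcode - 5) of the mask, only consulted for opcodes in [5..0x42]."""
    if not 5 <= opcode <= 0x42:
        return False
    return bool((REJECT_MASK >> (opcode - 5)) & 1)


rejected = [op for op in range(0x100) if is_rejected_opcode(op)]
print([hex(op) for op in rejected])
```

The set bits are 0, 22, 44, and 61, which after adding the bias of 5 land on the four opcodes named in the list.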

Worked Example — a small elementwise fusion

A loop fusion rooted at multiply (0x4B, default tier 1.0), three operands, output [256,128]:

operand0 = parameter [256,128]   (root minor = 128)
operand1 = broadcast scalar      (minor << 128)
operand2 = parameter [256,128]   (minor = 128)

LoopFusion pre-path:

multiplier = 1.0
operand0: op_minor 128 >= root_minor 128  → break immediately (operand as-wide as root)
→ estimate abandoned → fall to per-opcode switch on the root (kFusion 0x3D) → recurse the body

For the body's multiply leaf (default tier) the per-op weight is ChunksIn([256,128]) × 1.0. Replace the root with divide and the leaf is ChunksIn × 10.0; with erf, ChunksIn × 42.0; with a dense convolution, flop_count / peak (heavy). This is the compute term the priority queue subtracts; the bundle term is what GetCyclesIfFused computes via the merge above.


Function Map

| Function | Address | Role |
|---|---|---|
| TpuPriorityFusionQueue::NormalizedComputationCost | 0x130989a0 | opcode→weight scalar + conv/fusion escapes |
| Target::ChunksIn(Shape&) | 0x1d619900 | chunk-granule element count (the ×multiplier base) |
| TpuHloCostAnalysis ctor | 0x130a1620 | conv flop sub-analysis |
| HloCostAnalysis::HandleConvolution | 0x1e480be0 | conv flop emitter |
| HloCostAnalysis::flop_count | 0x1e4841e0 | reads cached flop property |
| LhsFormatForConvInstruction | 0x1307bd40 | conv LHS → MatmulDataFormat (peak select) |
| Target::FlopsPerSecond | 0x1d61f280 | per-format peak (vtable+0x718) |
| Target::VectorAluSlotsPerTensorCore | 0x1d61e380 | VALU slot count (vtable+0x500) |
| Target::TensorCoreFrequencyInMegaHertz | 0x1d615b60 | TC clock (cycles ← seconds) |
| CostModel::GetCyclesIfFused | 0x130aba40 | fused-pair bundle cost driver |
| IsFusionSupportedHlo | 0x130abee0 | eligibility gate (→ 1-cycle trivial) |
| IsConvLowerable | 0x14553620 | conv-lowerable predicate |
| ExtractConvLikeHlo | 0x1d6aa140 | pull the conv/reduce-window root |
| GetReduceWindowType | 0x1454d4a0 | −1/2 max-pool sentinel |
| FusionState::Create | 0x130ab320 | combined operand set + internal-edge map |
| CostModel::IsProducerUse | 0x130ab0c0 | drops internal-edge input DMA |
| CostModel::GetHloResourcesImpl | 0x130aa580 | prices the merged op |
| ScaleAndSumOutputFusionResourceVectors | 0x130b8320 | 4-emitter combine; slots 9/11 MAX |
| ResourceVector::Add | 0x1c89b820 | per-slot accumulate (Defaults) |
| ResourceVector::MaxResourceCycles | 0x1c89b9e0 | scalar bundle-cycle reduction |

Weight / Formula Constants (.rodata, byte-verified)

| Address | Value | Used by |
|---|---|---|
| 0xa2df230 | 1.0 | default weight / conv derate +1.0 / multiplier base |
| 0xa2de830 | 4.0 | logistic, reduce, cross-lane broadcast |
| 0xa2df498 | 10.0 | divide |
| 0xa2df1a0 | 42.0 | erf |
| 0xa2e0530 | 3.4028e38 (FLT_MAX) | max-pool GetCyclesIfFused sentinel |
| 0xa2e0208 | 1.0e6 | conv freq_MHz → Hz |
| 0xa2e05a8 | -0.03 | conv derate slope 1 − 0.03·Target[+0x4ac] |
| 0xa2e0118 | 0.00048828125 (1/2048) | grouped-conv flop→cost factor |

| Component | Relationship |
|---|---|
| TpuHloCostAnalysis | Supplies the conv flop_count this page caches; the bundle-occupancy peer of the scalar ladder |
| Resource Enum (23-slot) | The ResourceVector slots and MaxResourceCycles reduction the fused merge feeds |
| ConvolutionCostState | The conv-shaped state HandleConvolution walks before the flop is cached |
| Reduce-Window / Pooling Cost | The GetReduceWindowType taxonomy behind the max-pool FLT_MAX sentinel |
| Per-Opcode Cycle Constants | The per-gen cycles deposited into the merged ResourceVector |

Cross-References