Learned Cost-Model Client

Addresses apply to libtpu.so from the libtpu-0.0.40-cp314 wheel (BuildID md5 89edbbe81c5b328a958fe628a9f2207d). The binary is not stripped — every symbol below is a demangled C++ name. .text/.rodata VMA == file offset; .data.rel.ro VMA − 0x200000 == file offset. Other versions differ.
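
The two address rules above can be written as one small helper. This is a minimal sketch: the `Section` enum and the function name are illustrative, and which section a given address falls in is something the caller has to know (the example in the usage note assumes the proto vtable at 0x21cffc10 lives in .data.rel.ro, which the text does not state explicitly).

```cpp
#include <cassert>
#include <cstdint>

// Sections covered by the note above; anything else is out of scope here.
enum class Section { kText, kRodata, kDataRelRo };

// Maps a virtual address in this build of libtpu.so to a file offset using
// only the two stated rules: identity for .text/.rodata, and a fixed
// 0x200000 delta for .data.rel.ro.
constexpr std::uint64_t VmaToFileOffset(std::uint64_t vma, Section section) {
  constexpr std::uint64_t kDataRelRoDelta = 0x200000;
  return section == Section::kDataRelRo ? vma - kDataRelRoDelta : vma;
}

// The hook site is code, so its VMA is also its file offset.
static_assert(VmaToFileOffset(0x13172c80, Section::kText) == 0x13172c80,
              ".text addresses map to themselves");
```

Under that assumption, `VmaToFileOffset(0x21cffc10, Section::kDataRelRo)` yields 0x21affc10. The delta is specific to this wheel and should be re-derived from the program headers for any other version.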

Abstract

The TPU convolution lowering path contains a complete client-side hook for an ML-learned cost model, but no server, no client class, and no predictor ship in this build. The hook lives inside SpatialMajorConvolution::ComputeWindowConfigInternal (@ 0x13172c80) — the function that searches the window-tiling space for the fastest convolution schedule. The hook reaches a borrowed EmitterLearnedCostModelBase* pointer stored on the emitter at this+0x20d0; when that pointer is non-null and the enable_learned_cost_model byte is set, the search consults the learned model through four virtual slots: a per-instruction enable check (vtable+0x10), a candidate-window registration call (vtable+0x18, RegisterCandidateWindow), a fastest-window status query (vtable+0x30) whose absl::Status result selects between the learned and the analytic answer, and — only on a non-error status — a paired value-fetch (vtable+0x38) that returns the learned cycle estimate (a float at +0x138 of the returned object). The pointer is borrowed all the way down from the 80-parameter LoweringEmitter ctor through ConvolutionEmitter::Create.

The familiar reference frame is XLA's pluggable cost-model interface — an abstract base with virtual Predict/Register methods, a concrete client that batches features into RPCs to an embedding/inference service, and an absl::Status-gated fallback to the analytic model on any RPC failure. libtpu-0.0.40 ships the interface contract (proto options, the four-enum mode state machine, the gflag wiring, the call sites, the CHECK_OK on the registration result, and the Failed to get fastest window using learned cost model failure-log fallback) but not the implementation: there is no xla::jellyfish::EmitterLearnedCostModelBase vtable, no concrete LearnedCostModelClient, no LearnedCostModelService::Stub, and no embedded model bytes. The EmitterLearnedCostModelBase* is therefore always null in the shipping binary, so every consult is short-circuited and the analytic TpuHloCostAnalysis / window-search result is used unchanged.

This page documents what is real and reimplementable — the wire-level options proto (LearnedCostModelClientOptions with its ServiceType enum, RPC-endpoint fields, and recovered C++ struct layout), the four-enum mode/validation/DB-query state machine the interface implements, the recovered EmitterLearnedCostModelBase vtable shape (four exercised slots), the MLCostModelWindowInfo request payload the client receives, the status/error handling at each call site, and the precise null/flag double-gate that drops the whole layer back to the analytic model.

For reimplementation, the contract is:

  • The EmitterLearnedCostModelBase vtable: IsEnabled(HloInstruction*) @ +0x10 (returns bool), RegisterCandidateWindow(LcmKey, MLCostModelWindowInfo) @ +0x18 (returns absl::Status, CHECK_OK'd), the fastest-window status query @ +0x30 (returns absl::Status, soft-failed), and the paired value-fetch @ +0x38 (returns a pointer to a result object holding the learned cycle estimate as a float at +0x138).
  • The MLCostModelWindowInfo request payload — the eight designated-initializer fields the client receives for every candidate window, including estimated_cycles_classic (the analytic answer, always supplied as the floor).
  • The status handling: the +0x30 query result absl::Status selects learned-vs-analytic; on non-OK the layer logs spatial_major_convolution.cc:4006 and falls through to SetupBestConfig with the classic search result.
  • The double-gate to the analytic model: the this+0x20d0 null check and the TpuCompEnv+0xed6 (enable_learned_cost_model) byte; either-false drops to analytic.
  • The LearnedCostModelClientOptions proto wire shape, ServiceType (LOCAL/REMOTE) enum, and the recovered struct layout — what a client implementation would have to parse from the gflag.

| Item | Value |
|---|---|
| Hook site | SpatialMajorConvolution::ComputeWindowConfigInternal @ 0x13172c80 |
| Registration lambda | …::ComputeWindowConfigInternal(…)::$_0 (policy_func) @ 0x1317fe00 |
| Client pointer slot | this+0x20d0 (SpatialMajorConvolution), borrowed EmitterLearnedCostModelBase* |
| Enable slot | vtable+0x10: bool IsEnabled(HloInstruction*) (@ call 0x13173…) |
| Register slot | vtable+0x18: absl::Status RegisterCandidateWindow(LcmKey, MLCostModelWindowInfo) |
| Query slot | vtable+0x30: absl::Status fastest-window status query (soft-failed) |
| Value-fetch slot | vtable+0x38: paired fetch of the learned cycle estimate (float @ result +0x138) |
| Enable flag byte | *(byte*)(GetTpuCompEnv(inst)+0xed6) == enable_learned_cost_model |
| Options proto / vtable | xla::jellyfish::LearnedCostModelClientOptions @ 0x21cffc10 |
| gflag | xla_tpu_emitter_learned_cost_model_options = AutoOr<EmitterLearnedCostModelOptions> |
| Failure-log fallback | spatial_major_convolution.cc:4006 → SetupBestConfig (analytic) |
| Shipping default | client pointer null everywhere → analytic model always used |

What Ships vs What Does Not

The learned cost model is a textbook "future-extension hook": the schema, the gflag, the consumer call sites, and the failure-fallback all ship; the predictor does not. The split is exact and verifiable by symbol scan.

| Component | Present? | Status |
|---|---|---|
| EmitterLearnedCostModelOptions proto | YES | Reachable |
| LearnedCostModelClientOptions proto | YES | Reachable |
| FusionDataProtoGenerationOptions proto | YES | Reachable |
| EmbeddingCacheEntry / EmbeddingCacheDB protos | YES | Reachable |
| LearnedCostModelMode / DbQueryType / MLOutputValidationStrategy / ServiceType enums | YES | Decoded |
| gflag xla_tpu_emitter_learned_cost_model_options | YES | Parsed |
| Consumer call sites (4 vtable slots) | YES | Code-present, runtime-dead |
| CHECK_OK on RegisterCandidateWindow result | YES | Source spatial_major_convolution.cc:3996 |
| Failed to get fastest window … analytic fallback | YES | Source spatial_major_convolution.cc:4006 |
| EmitterLearnedCostModelBase vtable / typeinfo | NO | Type-only |
| concrete LearnedCostModelClient class | NO | Absent |
| LearnedCostModelService::Stub (gRPC) | NO | Absent |
| Predict / Inference / Score / EstimateCycles method | NO | Absent |
| embedded model (SavedModel / ONNX / TFLite blob) | NO | Absent |

NOTE — the absence is positive evidence, not a gap in analysis. Scanning the (non-stripped) symbol table finds the …ClientOptions proto family and its ser/deser methods, but EmitterLearnedCostModelBase exists only as a function-parameter type in the mangled names of ConvolutionEmitter::Create, SpatialMajorConvolution::SpatialMajorConvolution, and LoweringEmitter::LoweringEmitter. A class used only as a borrowed pointer needs no emitted vtable or typeinfo in the consumer's translation unit — which is exactly the footprint of an interface whose only implementation lives out-of-tree.


The Client Interface — EmitterLearnedCostModelBase vtable

Purpose

EmitterLearnedCostModelBase is the abstract interface the convolution emitter calls. It is borrowed (never owned) by the emitter, so no destructor slot is exercised. Four virtual slots are reached from the decompiled code; their signatures are pinned by the call-site register usage and the CHECK_OK literal.

Recovered vtable

| Slot | Method (inferred) | Returns | Call-site evidence |
|---|---|---|---|
| +0x00, +0x08 | not exercised; most likely the virtual destructor pair | | Itanium ABI: slot offsets are relative to the vptr the call sites dereference (*v44 + 16), and offset_to_top and the typeinfo pointer sit before that address point, at -0x10 and -0x08 |
| +0x10 | bool IsEnabled(const HloInstruction*) | bool (in %al) | (*(…)(*v44 + 16))(v44, hlo); arg is *(HloInstruction**)(this+72); result drives the consult branch |
| +0x18 | absl::Status RegisterCandidateWindow(const LcmKey&, const MLCostModelWindowInfo&) | absl::Status | (*(…)(*v52 + 24))(v52, key.first, key.second, &status) in lambda; CHECK_OK'd |
| +0x30 | fastest-window status query | absl::Status (stack-returned StatusRep*) | (*(…)(**(this+0x20d0) + 48))(&out, this_lcm, fp.first, fp.second); status != OK selects fallback (logs .cc:4006) |
| +0x38 | paired value-fetch (learned cycle estimate) | pointer to result object | (*(…)(**(this+0x20d0) + 56))(this_lcm, fp.first, fp.second); reached only when +0x30 is OK; result[1] is presence-checked; float read at result +0x138 |

// SpatialMajorConvolution::ComputeWindowConfigInternal  @0x13172c80
bool consult = false;                                    // v172
void* lcm = *(void**)(this + 0x20d0);                    // borrowed EmitterLearnedCostModelBase*
if (lcm) {                                               // null → analytic (default ship)
    if (enable_learned_cost_model) {                     // bool param a9, see gate below
        const HloInstruction* hlo = *(const HloInstruction**)(this + 72);
        consult = (*(bool(**)(void*, const HloInstruction*))(*(void**)lcm + 0x10))(lcm, hlo);  // IsEnabled
        if (consult) {
            fp = xla::GetHloInstructionFingerprint(hlo);  // @0x13180b80 — the prediction key
        }
    } else {
        consult = false;
    }
}

// ... the window-tiling search runs (IterateThroughWindowConfigs), invoking the
//     registration lambda once per candidate window when consult is true ...

if (consult != true) {                                   // learned model not used
    SetupBestConfig(/* classic search result */);        // analytic answer
} else {
    // fastest-window status query @ vtable+0x30, keyed by the fingerprint
    status = (*(Status(**)(…))(*(void**)lcm + 0x30))(&out, lcm, fp.first, fp.second);
    if (!status.ok()) {                                  // analytic fallback
        LOG("Failed to get fastest window using learned cost model "
            "for instruction: ", hlo, " with status: ", status);   // .cc:4006
        SetupBestConfig(/* classic search result */);    // <-- fall through to analytic
    } else {
        // paired value-fetch @ vtable+0x38, same fingerprint key
        result = (*(void*(**)(…))(*(void**)lcm + 0x38))(lcm, fp.first, fp.second);
        float learned_cycles = (result[1] ? *(float*)((char*)*result + 0x138) : 0.0f);
        SetupBestConfig(/* learned-selected window, learned_cycles */);
    }
}

GOTCHA — the vtable+0x10 slot takes the HloInstruction* (loaded from this+72) and returns a bool; do not mistake it for a no-arg IsEnabled(). The per-instruction argument is what lets a real client gate the learned path on op shape/type. In the shipping binary the slot is never reached because lcm is null.
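
For reimplementation, the recovered slot order can be mirrored by an abstract class. The following is a sketch under stated assumptions, not the original header: the `Status`, `HloInstruction`, `TwoWordKey`, and `WindowResult` types are simplified stand-ins, the method names for the +0x30 and +0x38 slots are invented, and the virtual destructor and the two unexercised slots at +0x20/+0x28 are guesses that make the declared order line up with the observed offsets on an Itanium-ABI compiler. `ConvOnlyClient` is a toy that illustrates per-instruction gating.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <utility>

// Hypothetical stand-ins for types the real interface refers to.
struct HloInstruction { std::string opcode; };
struct Status {
  bool ok_flag = true;
  std::string message;
  bool ok() const { return ok_flag; }
};
using TwoWordKey = std::pair<std::uint64_t, std::uint64_t>;  // LcmKey or fingerprint
struct MLCostModelWindowInfo { std::int64_t estimated_cycles_classic = 0; };
struct WindowResult { float learned_cycles = 0.0f; };

// Interface sketch. Offsets in comments are vptr-relative.
class EmitterLearnedCostModelBase {
 public:
  virtual ~EmitterLearnedCostModelBase() = default;       // assumed: +0x00 / +0x08
  virtual bool IsEnabled(const HloInstruction* hlo) = 0;  // +0x10
  virtual Status RegisterCandidateWindow(                 // +0x18
      TwoWordKey key, const MLCostModelWindowInfo& info) = 0;
  virtual void UnexercisedSlot20() {}                     // +0x20, assumed placeholder
  virtual void UnexercisedSlot28() {}                     // +0x28, assumed placeholder
  virtual Status QueryFastestWindow(TwoWordKey fp) = 0;   // +0x30, name assumed
  virtual const WindowResult* FetchFastestWindow(TwoWordKey fp) = 0;  // +0x38, name assumed
};

// Toy client: enables itself only for convolutions and replays the lowest
// classic estimate it has been shown.
class ConvOnlyClient final : public EmitterLearnedCostModelBase {
 public:
  bool IsEnabled(const HloInstruction* hlo) override {
    return hlo != nullptr && hlo->opcode == "convolution";
  }
  Status RegisterCandidateWindow(TwoWordKey,
                                 const MLCostModelWindowInfo& info) override {
    const float cycles = static_cast<float>(info.estimated_cycles_classic);
    if (registered_ == 0 || cycles < best_.learned_cycles) {
      best_.learned_cycles = cycles;
    }
    ++registered_;
    return Status{};
  }
  Status QueryFastestWindow(TwoWordKey) override {
    return registered_ > 0 ? Status{}
                           : Status{false, "no candidates registered"};
  }
  const WindowResult* FetchFastestWindow(TwoWordKey) override { return &best_; }
  int registered() const { return registered_; }

 private:
  WindowResult best_;
  int registered_ = 0;
};
```

Passing the instruction into `IsEnabled` is what allows a client like this one to decline non-convolution ops without the emitter knowing why.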

The registration call — RegisterCandidateWindow

The window-search iterator invokes a std::function (…::$_0 policy_func @ 0x1317fe00) once per candidate tiling. When the learned path is active, the lambda builds an MLCostModelWindowInfo on the stack from its parameters (the InlinedVectors are deep-copied via inlined_vector_internal::Storage::InitFrom) and calls vtable+0x18:

// lambda body @0x1317fe00 — one call per candidate window
if (*(byte*)(GetTpuCompEnv(hlo) + 0xed6) == 1            // enable_learned_cost_model
    && capture_flag == 1 && a8 < threshold) {            // **(this+0x80)+0x10 byte + window bound
    void* lcm = *(void**)(this + 0x20d0);
    Status s = (*(Status(**)(void*, long, long, Status*))(*(void**)lcm + 0x18))(
                   lcm, lcm_key.first, lcm_key.second, &out);   // RegisterCandidateWindow
    // CHECK_OK — fatal on a non-OK registration:
    //   "learned_cost_model_->RegisterCandidateWindow( *lcm_key,
    //    MLCostModelWindowInfo( {.activations_window = …, .kernel_window = …,
    //    .output_window = …, .iteration_bounds = …, .window_info = …,
    //    .vmem_footprint_granules = estimated_granules, .bundles = estimated_bundles,
    //    .estimated_cycles_classic = estimated_cycles})) is OK"  // .cc:3996
}

NOTE — RegisterCandidateWindow is CHECK_OK'd (fatal on failure), while the fastest-window query @ +0x30 is soft-failed (logs and falls back). The asymmetry is deliberate: registration is a local bookkeeping push into the consideration set (must not fail), whereas the prediction query can fail transiently (RPC down) and must degrade to the analytic model rather than abort compilation.
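
The asymmetry can be illustrated with two small helpers: one that aborts on a non-OK status, in the spirit of CHECK_OK, and one that degrades to the classic estimate. This is a sketch with a simplified `Status` stand-in rather than absl::Status; the function names and the `Selection` struct are hypothetical, and the log text only paraphrases the recovered message.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <string>

// Simplified stand-in for absl::Status.
struct Status {
  bool ok_flag = true;
  std::string message;
  bool ok() const { return ok_flag; }
};

// Fatal path: registration must not fail, so a non-OK status aborts.
inline void CheckOk(const Status& s, const char* expression) {
  if (!s.ok()) {
    std::cerr << expression << " is OK: " << s.message << "\n";
    std::abort();
  }
}

struct Selection {
  bool used_learned;
  double cycles;
};

// Soft path: a failed query logs and returns the classic (analytic) estimate.
inline Selection SelectCycles(const Status& query_status, float learned_cycles,
                              std::int64_t classic_cycles) {
  if (!query_status.ok()) {
    std::cerr << "Failed to get fastest window using learned cost model"
              << " with status: " << query_status.message << "\n";
    return {false, static_cast<double>(classic_cycles)};
  }
  return {true, static_cast<double>(learned_cycles)};
}
```

With this split, a transient query failure changes only which number is used, while a registration failure stops compilation outright.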


The Request Payload — MLCostModelWindowInfo

Reconstructed from the CHECK_OK designated-initializer literal and the lambda's stack-build sequence. This is the per-candidate-window feature record the client receives; the trailing estimated_cycles_classic is the analytic answer, supplied as a baseline/floor.

struct MLCostModelWindowInfo {
  absl::InlinedVector<int64_t, 6> activations_window;      // input  window dims (deep-copied)
  absl::InlinedVector<int64_t, 6> kernel_window;           // kernel window dims
  absl::InlinedVector<int64_t, 6> output_window;           // output window dims
  absl::InlinedVector<int64_t, 6> iteration_bounds;        // outer loop bounds
  WindowSizingInfo               window_info;              // sizing metadata (copied from a7)
  int64_t                        vmem_footprint_granules;  // estimated_granules
  int64_t                        bundles;                  // estimated_bundles
  int64_t                        estimated_cycles_classic; // analytic-model estimate (baseline)
};

The first argument to RegisterCandidateWindow is *lcm_key — an LcmKey whose two 8-byte halves (v53 = *v50; v54 = v50[1]) are passed by value. Its full shape is not recoverable beyond being a two-word key; it deduplicates candidates within one fusion search and most likely encodes operation type, MXU format, and a per-window hash. The fastest-window query @ +0x30 is keyed instead by xla::GetHloInstructionFingerprint(hlo) (@ 0x13180b80), also passed as a two-word value (fp.first, fp.second).

CONTRACT: every candidate window the analytic search considers is reported to the client with its classic cycle estimate attached. A real client therefore never starts cold: it can return the classic number verbatim (NEVER_TRUST), accept a prediction only when it is non-negative (NO_NEGATIVE_CYCLES), or take the prediction as-is in place of the classic number (ALWAYS_TRUST). This is the wire-level meaning of the MLOutputValidationStrategy enum below.
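
How a client might apply the validation strategy to one prediction can be sketched as a single function. The enum values follow the MLOutputValidationStrategy descriptor; everything else is an assumption, in particular that a rejected prediction falls back to the classic estimate and that NONE passes the prediction through unchecked. The function name is hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Values as decoded from the embedded descriptor; C++ names are illustrative.
enum class MLOutputValidationStrategy {
  kNone = 0,
  kNeverTrust = 1,
  kAlwaysTrust = 2,
  kNoNegativeCycles = 3,
};

// Returns the cycle count a client would report for one candidate window.
// Assumption: "reject" means "use the classic estimate instead".
inline double ValidatedCycles(MLOutputValidationStrategy strategy,
                              double learned, std::int64_t classic) {
  const double fallback = static_cast<double>(classic);
  switch (strategy) {
    case MLOutputValidationStrategy::kNeverTrust:
      return fallback;  // always the classic cost model
    case MLOutputValidationStrategy::kNoNegativeCycles:
      return learned < 0.0 ? fallback : learned;  // reject only negatives
    case MLOutputValidationStrategy::kNone:
    case MLOutputValidationStrategy::kAlwaysTrust:
      return learned;  // taken verbatim
  }
  return fallback;  // unknown enum value: stay on the analytic answer
}
```

Because the classic estimate travels with every registered window, this check needs no extra lookup.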


The Mode State Machine — Four Enums

The options proto encodes a small state machine the (missing) client implements. All four enums are decoded from the embedded FileDescriptorProto; value-name strings are present in .rodata.

LearnedCostModelMode

| Value | Name | Semantic |
|---|---|---|
| 0 | LEARNED_COST_MODEL_MODE_INVALID | default; treated as "no learned cost model" |
| 1 | LEARNED_COST_MODEL_MODE_ONLY_DB | look up cycles from a pre-built DB only |
| 2 | LEARNED_COST_MODEL_MODE_ONLY_ML_PREDICTION | always use the ML predictor |
| 3 | LEARNED_COST_MODEL_MODE_DB_WITH_FALLBACK_TO_ML_PREDICTION | DB first, ML on miss |
| 4 | LEARNED_COST_MODEL_MODE_ONLY_DATA_COLLECTION | dump FusionData protos for offline training; no scoring |

The mode names disclose the intended design: an offline DB of pre-measured (window-config → cycles) tuples, with an ML predictor filling gaps. The +0x18 RegisterCandidateWindow slot is the data-collection / consideration-set push (used in all modes); the +0x30 query slot is the DB-lookup-or-predict.
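
The dispatch those mode names imply can be sketched as follows. No client ships in this build, so this is a guess at the intended behaviour: the `Lookup` callbacks stand in for a DB query and an ML prediction, and the `Source` and `Answer` types and the function names are hypothetical.

```cpp
#include <cassert>
#include <functional>
#include <optional>

// Values as decoded from the embedded descriptor; C++ names are illustrative.
enum class LearnedCostModelMode {
  kInvalid = 0,
  kOnlyDb = 1,
  kOnlyMlPrediction = 2,
  kDbWithFallbackToMlPrediction = 3,
  kOnlyDataCollection = 4,
};

enum class Source { kNone, kDb, kMl };

struct Answer {
  Source source;
  std::optional<double> cycles;
};

// A lookup returns a cycle estimate or std::nullopt on a miss.
using Lookup = std::function<std::optional<double>()>;

inline Answer FromLookup(Source source, const Lookup& lookup) {
  std::optional<double> cycles = lookup();
  return cycles ? Answer{source, cycles} : Answer{Source::kNone, std::nullopt};
}

// Picks where the learned estimate comes from for one query.
inline Answer Resolve(LearnedCostModelMode mode, const Lookup& db,
                      const Lookup& ml) {
  switch (mode) {
    case LearnedCostModelMode::kOnlyDb:
      return FromLookup(Source::kDb, db);
    case LearnedCostModelMode::kOnlyMlPrediction:
      return FromLookup(Source::kMl, ml);
    case LearnedCostModelMode::kDbWithFallbackToMlPrediction: {
      Answer from_db = FromLookup(Source::kDb, db);
      return from_db.cycles ? from_db : FromLookup(Source::kMl, ml);
    }
    case LearnedCostModelMode::kInvalid:
    case LearnedCostModelMode::kOnlyDataCollection:
      break;  // no scoring in these modes
  }
  return Answer{Source::kNone, std::nullopt};
}
```

A `Source::kNone` answer corresponds to the case where the emitter would keep the analytic result.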

MLOutputValidationStrategy

| Value | Name | Semantic |
|---|---|---|
| 0 | ML_OUTPUT_VALIDATION_STRATEGY_NONE | no validation |
| 1 | ML_OUTPUT_VALIDATION_STRATEGY_NEVER_TRUST | always fall back to the classic cost model |
| 2 | ML_OUTPUT_VALIDATION_STRATEGY_ALWAYS_TRUST | take ML output verbatim |
| 3 | ML_OUTPUT_VALIDATION_STRATEGY_NO_NEGATIVE_CYCLES | reject only negative-cycle predictions |

DbQueryType

| Value | Name | Semantic |
|---|---|---|
| 0 | DB_QUERY_TYPE_NONE | no DB query |
| 1 | DB_QUERY_TYPE_REPLAY_PREDICITIONS | (sic; proto typo) replay stored ML cycles |
| 2 | DB_QUERY_TYPE_GROUND_TRUTH | look up measured ground-truth cycles |

LearnedCostModelClientOptions.ServiceType

| Value | Name | Semantic |
|---|---|---|
| 0 | SERVICE_TYPE_UNSPECIFIED | invalid sentinel |
| 1 | SERVICE_TYPE_LOCAL | load model from local_embedding_model_path, run in-process |
| 2 | SERVICE_TYPE_REMOTE | issue RPCs to remote_embedding_server_address |

GOTCHA — both SERVICE_TYPE_LOCAL and SERVICE_TYPE_REMOTE are unbuildable in this wheel. There is no in-process model loader for the LOCAL path and no LearnedCostModelService::Stub / BlockingUnaryCall for the REMOTE path. The binary does ship unrelated gRPC stubs (BarnaCoreInterWorkerCommunicationRpc::Stub, RuntimeMetricService::Stub, MegaScaleTransport::Stub), which is the proof that a learned-cost-model RPC stub would be visible if it existed.


The Options Proto Wire Shape

The RPC/service configuration is carried entirely as a serialized proto inside one gflag — there is no dedicated boolean or endpoint flag. The two proto files are linked by a sub-message: EmitterLearnedCostModelOptions.learned_cost_model_client_options holds a LearnedCostModelClientOptions.

package xla.jellyfish;
import "third_party/tensorflow/core/framework/tensor.proto";

message LearnedCostModelClientOptions {
  enum ServiceType { SERVICE_TYPE_UNSPECIFIED = 0; SERVICE_TYPE_LOCAL = 1; SERVICE_TYPE_REMOTE = 2; }

  optional ServiceType embedding_service_type                        = 1;
  optional string      remote_embedding_server_address               = 2;  // REMOTE RPC endpoint
  optional string      remote_embedding_model_name                   = 3;  // REMOTE model selector
  optional int32       inflight_rpc_monitoring_interval_milliseconds = 4;  // RPC liveness poll (serialized int32)
  optional string      local_embedding_model_path                    = 5;  // LOCAL model file
  optional string      embedding_cache_path                          = 6;  // EmbeddingCacheDB on disk
  optional FusionDataProtoGenerationOptions fusion_data_proto_generation_options = 7;
  optional int32       max_batch_size                                = 8;  // RPC batch size (serialized int32)
}

message EmbeddingCacheEntry { optional bytes fingerprint = 1; optional tensorflow.TensorProto embedding = 2; }
message EmbeddingCacheDB    { repeated EmbeddingCacheEntry entries = 1; }
message FusionDataProtoGenerationOptions {
  optional bool include_standalone_fusion_module   = 1;
  optional bool include_expert_and_gating_features = 2;
}

C++ struct layout, recovered byte-exact from the copy-ctor LearnedCostModelClientOptions(Arena*, const&) @ 0x1db653e0 (vtable stored is off_21CFFC20):

| Offset | Field | Size | Notes |
|---|---|---|---|
| +0x00 | vtable ptr | 8 | 0x21cffc10+0x10 |
| +0x08 | internal::InternalMetadata | 8 | arena / unknown-field tag |
| +0x10 | uint32_t _has_bits_ | 4 | presence bitmap |
| +0x14 | cached_size | 4 | |
| +0x18 | TaggedStringPtr remote_embedding_server_address | 8 | ForceCopy if tag bits set |
| +0x20 | TaggedStringPtr remote_embedding_model_name | 8 | |
| +0x28 | TaggedStringPtr local_embedding_model_path | 8 | |
| +0x30 | TaggedStringPtr embedding_cache_path | 8 | |
| +0x38 | FusionDataProtoGenerationOptions* | 8 | copied iff _has_bits_ & 0x10 |
| +0x40 | int32_t embedding_service_type (enum, field 1) | 4 | serialized first, tag byte 0x08, from *((int*)this+16) |
| +0x44 | int32_t inflight_rpc_monitoring_interval_milliseconds (field 4) | 4 | WriteInt32ToArrayWithField<4> from *((int*)this+17) |
| +0x48 | int32_t max_batch_size (field 8) | 4 | WriteInt32ToArrayWithField<8> from *((int*)this+18) |

NOTE: the copy-ctor guards the sub-message with (*(byte*)(this+0x10) & 0x10), i.e. _has_bits_ bit 4 governs fusion_data_proto_generation_options. The four string fields are proto2::internal::TaggedStringPtr and are ForceCopy'd when their low tag bits are set (arena-owned vs inline). The three tail integers are pinned by _InternalSerialize @ 0x1db65920: field 1 (embedding_service_type) writes from +0x40, field 4 (inflight_rpc_monitoring_interval_milliseconds) from +0x44, field 8 (max_batch_size) from +0x48, all via WriteInt32ToArrayWithField, so the interval and batch-size fields are serialized as 32-bit, which is why the .proto text above declares them int32. The copy-ctor copies the +0x40/+0x44 pair as one 8-byte word and +0x48 as a separate dword. MEDIUM confidence on the proto-declared width of the two non-enum integers (the serializer uses the int32 path; the on-wire varint is width-agnostic for small values).
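
The recovered offsets and the wire encoding of the three tail integers can be cross-checked with a small mirror struct. This is a sketch assuming a 64-bit LP64 target: `ClientOptionsLayout` is a hypothetical plain struct rather than the generated protobuf class, the helper names are invented, and unlike a real proto2 serializer it writes all three fields unconditionally instead of consulting the presence bits.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Plain mirror of the recovered layout (hypothetical; not the generated class).
struct ClientOptionsLayout {
  void* vtable;                                    // +0x00
  std::uint64_t internal_metadata;                 // +0x08
  std::uint32_t has_bits;                          // +0x10
  std::uint32_t cached_size;                       // +0x14
  std::uintptr_t remote_embedding_server_address;  // +0x18 (tagged pointer)
  std::uintptr_t remote_embedding_model_name;      // +0x20
  std::uintptr_t local_embedding_model_path;       // +0x28
  std::uintptr_t embedding_cache_path;             // +0x30
  void* fusion_data_proto_generation_options;      // +0x38
  std::int32_t embedding_service_type;             // +0x40, field 1
  std::int32_t inflight_rpc_monitoring_interval_milliseconds;  // +0x44, field 4
  std::int32_t max_batch_size;                     // +0x48, field 8
};

static_assert(sizeof(void*) == 8, "layout sketch assumes a 64-bit target");
static_assert(offsetof(ClientOptionsLayout, has_bits) == 0x10, "has_bits");
static_assert(offsetof(ClientOptionsLayout, embedding_service_type) == 0x40, "f1");
static_assert(offsetof(ClientOptionsLayout, max_batch_size) == 0x48, "f8");

// Appends a base-128 varint, least significant group first.
inline void AppendVarint(std::vector<std::uint8_t>& out, std::uint64_t v) {
  while (v >= 0x80) {
    out.push_back(static_cast<std::uint8_t>(v | 0x80));
    v >>= 7;
  }
  out.push_back(static_cast<std::uint8_t>(v));
}

// Tag is (field_number << 3) | wire_type, with wire type 0 for varints.
// int32 values are sign-extended to 64 bits before encoding.
inline void AppendInt32Field(std::vector<std::uint8_t>& out, int field,
                             std::int32_t value) {
  AppendVarint(out, static_cast<std::uint64_t>(field) << 3);
  AppendVarint(out, static_cast<std::uint64_t>(static_cast<std::int64_t>(value)));
}

// Writes fields 1, 4, and 8 in field-number order.
inline std::vector<std::uint8_t> SerializeTailInts(const ClientOptionsLayout& o) {
  std::vector<std::uint8_t> out;
  AppendInt32Field(out, 1, o.embedding_service_type);
  AppendInt32Field(out, 4, o.inflight_rpc_monitoring_interval_milliseconds);
  AppendInt32Field(out, 8, o.max_batch_size);
  return out;
}
```

Field 1 produces the tag byte 0x08 noted in the layout table; fields 4 and 8 produce 0x20 and 0x40 by the same rule.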

The owning EmitterLearnedCostModelOptions adds the top-level switches: enable_learned_cost_model (tag 1, the gate byte), cost_model_mode, db_query_type, ml_output_validation_strategy, db_path, max_num_considered_windows, dump_fusion_data_proto[_dir]. Of these, only enable_learned_cost_model has a runtime consumer (the +0xed6 gate); the rest are deserialized but dead because the client they configure is absent.


The Analytic Fallback — Double Gate

Two independent gates drop the entire learned layer back to the analytic window search, and the registration lambda adds a third, defence-in-depth check of its own. Any one of them being false is sufficient.

// Gate 1 — pointer null check (ComputeWindowConfigInternal @0x13172c80)
void* lcm = *(void**)(this + 0x20d0);
if (!lcm) consult = false;            // DEFAULT in shipping libtpu-0.0.40 — always taken

// Gate 2 — enable flag (lambda @0x1317fe00, and the IsEnabled branch)
if (*(byte*)(GetTpuCompEnv(hlo) + 0xed6) != 1) /* skip */ ;   // enable_learned_cost_model

// Gate 3 (defence-in-depth) — captured byte flag inside the lambda
if (**(byte**)(capture + 0x80 ... + 0x10) != 1) /* skip */ ;

The this+0x20d0 pointer is set verbatim from the EmitterLearnedCostModelBase* constructor parameter (SpatialMajorConvolution C2 @ 0x130dd180 stores it with a raw mov, no allocation), which is propagated null from ConvolutionEmitter::Create (@ 0x130d86c0) and ultimately from LoweringEmitter::LoweringEmitter (@ 0x10c309c0). Because no caller ever supplies a non-null pointer in this build, Gate 1 always fires and Gates 2–3 are unreachable — the analytic TpuHloCostAnalysis flop model and the classic window search (SetupBestConfig) drive every convolution-lowering decision.

When a client is present and a query fails, the soft fallback at spatial_major_convolution.cc:4006 logs the instruction and the failing absl::Status, then calls SetupBestConfig with the classic search result — the same code path Gate 1 reaches, so a learned-model RPC outage is functionally identical to having no client at all.
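
The gate ordering, and the equivalence between a failed query and a missing client, can be modelled as a small decision function. This is a sketch: `FakeClient`, `Path`, and the function names are hypothetical, and the lambda-side capture check (Gate 3) is left out because it affects only per-candidate registration, not the final choice.

```cpp
#include <cassert>

// Which branch the window search ends on.
enum class Path {
  kAnalyticNoClient,     // Gate 1: client pointer is null
  kAnalyticFlagOff,      // Gate 2: enable_learned_cost_model is not set
  kAnalyticNotEnabled,   // IsEnabled(hlo) returned false
  kAnalyticQueryFailed,  // soft fallback: query status was not OK
  kLearned,              // learned estimate is used
};

// Minimal model of a client: just the two outcomes the gates depend on.
struct FakeClient {
  bool enabled_for_instruction;
  bool query_ok;
};

inline Path Decide(const FakeClient* lcm, bool enable_learned_cost_model) {
  if (lcm == nullptr) return Path::kAnalyticNoClient;
  if (!enable_learned_cost_model) return Path::kAnalyticFlagOff;
  if (!lcm->enabled_for_instruction) return Path::kAnalyticNotEnabled;
  if (!lcm->query_ok) return Path::kAnalyticQueryFailed;
  return Path::kLearned;
}

// Every path except kLearned ends in the classic SetupBestConfig result.
inline bool UsesAnalytic(Path path) { return path != Path::kLearned; }
```

In the shipping build only the first branch is ever taken, since the client pointer is always null.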

| Gate | Location | Field | Shipping value | Effect when false |
|---|---|---|---|---|
| 1 (pointer) | this+0x20d0 | borrowed EmitterLearnedCostModelBase* | null | skip consult; analytic |
| 2 (enable) | GetTpuCompEnv(hlo)+0xed6 | enable_learned_cost_model | 0 (proto default) | skip register/query |
| 3 (capture) | lambda capture +0x10 byte | propagated enable | 0 | skip per-candidate register |
| soft (query) | +0x30 result absl::Status | n/a (unreachable) | OK | log .cc:4006, analytic |

Function & Symbol Map

| Symbol | Address | Role |
|---|---|---|
| SpatialMajorConvolution::ComputeWindowConfigInternal | 0x13172c80 | hook site: enable check, fingerprint, query, fallback |
| …::ComputeWindowConfigInternal(…)::$_0 (policy_func) | 0x1317fe00 | per-candidate RegisterCandidateWindow call + CHECK_OK |
| SpatialMajorConvolution::SpatialMajorConvolution (C2) | 0x130dd180 | stores EmitterLearnedCostModelBase* at this+0x20d0 |
| ConvolutionEmitter::Create | 0x130d86c0 | forwards the (null) client pointer |
| LoweringEmitter::LoweringEmitter (C1) | 0x10c309c0 | originates the borrowed client pointer |
| xla::GetHloInstructionFingerprint | 0x13180b80 | builds the +0x30 query key |
| LearnedCostModelClientOptions(Arena*, const&) | 0x1db653e0 | copy-ctor → struct layout |
| LearnedCostModelClientOptions::_InternalSerialize | 0x1db65920 | proto wire encode |
| EmitterLearnedCostModelOptions(Arena*) | 0x1db63f20 | owning options proto ctor |
| AutoOr<EmitterLearnedCostModelOptions>::ParseFlag | 0x1d745680 | gflag → proto parse |
| LearnedCostModelClientOptions vtable | 0x21cffc10 | proto vtable |
| EmitterLearnedCostModelBase vtable / typeinfo | | does not exist (interface only) |
| LearnedCostModelClient concrete class | | does not exist |
| LearnedCostModelService::Stub (gRPC) | | does not exist |

QUIRK — the failure-log call site reads spatial_major_convolution.cc:4006 and the RegisterCandidateWindow CHECK_OK reads :3996 in this build (0.0.40). Source line numbers are build-version-specific; the surrounding VAs and the wire contract are the stable anchors.


Cross-References