TMA Atoms

Abstract

The cute_nvgpu TMA atom family surfaces Hopper and Blackwell tensor-memory transfers as descriptor-driven IR. A TMA descriptor records the global tensor, tile box, strides, rank, swizzle, fill behaviour, and cache policy. Executable TMA atoms bind that descriptor to coordinates, an mbarrier, optional multicast state, and cache hints, then lower to asynchronous tensor copy or reduce instructions. This page documents the atom family, the descriptor contract, the verifier rules, and the lowering shape.

Atom Family

| Operation | Role |
| --- | --- |
| atom.tma_load | Execute asynchronous global-to-shared tensor load. |
| atom.tma_store | Execute asynchronous shared-to-global tensor store. |
| atom.tma_reduce | Execute asynchronous tensor reduction into global memory. |
| atom.non_exec_tiled_tma_load | Describe a tiled TMA load before mbarrier/cache binding. |
| atom.non_exec_tiled_tma_store | Describe a tiled TMA store before execution binding. |
| atom.non_exec_tiled_tma_reduce | Describe a tiled TMA reduce before execution binding. |
| prefetch_tma_desc | Prefetch descriptor state before a transfer. |
| tma_descriptor_tiled | Descriptor type for ordinary tiled tensor movement. |
| tma_descriptor_im2col | Descriptor type for im2col tensor movement. |
| atom.make_exec_tma | Bind a non-exec atom with mbarrier, multicast, and cache mode. |

The non-exec atoms pay off because layout and partitioning can be verified before any pass commits to a runtime barrier or cache policy.

Partition Op and Mode Enums

The TMA atom family is rooted at cute_nvgpu.atom.tma_partition. Every executable and non-exec TMA atom is routed through this one partition op, which is the canonical place where descriptor shape, transfer mode, multicast cardinality, and reduce kind are validated together. The partition verifier enforces eleven invariants on every TMA partition op and, on success, returns one packed result record per partitioned tile.

Three mode enums select the transfer variant. Load-mode covers single-CTA, two-CTA cooperative, and warp-multicast loads at two granularities; store-mode covers tiled stores and im2col-flavour stores; reduce-kind covers the asynchronous reduces the hardware supports.

typedef enum TmaLoadMode {
    TMA_LOAD_NO_MULTICAST   = 0,   // single-CTA load
    TMA_LOAD_TWO_CTA        = 1,   // 2-CTA cluster cooperative load
    TMA_LOAD_W_MULTICAST    = 2,   // warp multicast (16-thread)
    TMA_LOAD_W128_MULTICAST = 3,   // wide warp multicast (128-thread)
} TmaLoadMode;

typedef enum TmaStoreMode {
    TMA_STORE_TILED       = 0,     // tiled SMEM -> GMEM
    TMA_STORE_IM2COL      = 1,     // im2col-flavor tiled store
    TMA_STORE_IM2COL_W    = 2,     // im2col + warp multicast
    TMA_STORE_IM2COL_W128 = 3,     // im2col, wide (128-thread) variant
} TmaStoreMode;

typedef enum TmaReduceKind {
    TMA_REDUCE_ADD = 0, TMA_REDUCE_MIN = 1, TMA_REDUCE_MAX = 2,
    TMA_REDUCE_INC = 3, TMA_REDUCE_DEC = 4,
    TMA_REDUCE_AND = 5, TMA_REDUCE_OR  = 6, TMA_REDUCE_XOR = 7,
} TmaReduceKind;

The enums are part of the verifier's input contract. Consistency between load mode, store mode, and reduce kind is checked together with rank and swizzle in the eleven-step walk below.

Partition Result ABI

The partition verifier returns one packed TmaPartitionResult per partitioned tile into a SmallVector owned by the verifier and forwarded to the executable-atom builder. The 24-byte record carries the interned TMA tensor type, the tile element count, a flags word, and the non-exec atom body that downstream lowering consumes.

typedef struct TmaPartitionResult {
    /*+0x00*/ uint64_t   tma_tensor_type;       // interned MLIR Type * (TmaLoad/Store/ReduceAtomType)
    /*+0x08*/ uint32_t   tile_element_count;    // size(canonical_smem) * num_multicast
    /*+0x0C*/ uint16_t   flags;                 // see "Flags Word" below
    /*+0x0E*/ uint8_t    swizzle_mode;          // 0=none, 2=128B, 3=32B/64B blend
    /*+0x0F*/ uint8_t    rank;                  // descriptor rank (1..5)
    /*+0x10*/ uint64_t   non_exec_atom_body;    // interned non-exec atom Attribute *
} TmaPartitionResult;

Only the tensor type and atom body get consumed during executable-atom binding. The flags word, swizzle mode, and rank are echoed back so the executable-atom builder and downstream prefetch logic do not have to re-derive them from the descriptor type.
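
The 24-byte layout can be pinned down at compile time. The following is a minimal sketch, assuming a C11 compiler and the field order shown above; it is not taken from the tileiras sources, and the assertion messages are illustrative only.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Field order and widths copied from the record above. */
typedef struct TmaPartitionResult {
    uint64_t tma_tensor_type;     /* +0x00 */
    uint32_t tile_element_count;  /* +0x08 */
    uint16_t flags;               /* +0x0C */
    uint8_t  swizzle_mode;        /* +0x0E */
    uint8_t  rank;                /* +0x0F */
    uint64_t non_exec_atom_body;  /* +0x10 */
} TmaPartitionResult;

/* Compile-time guards: a layout drift stops the build instead of
 * silently breaking whoever reads the packed record. */
static_assert(sizeof(TmaPartitionResult) == 24, "record must stay 24 bytes");
static_assert(offsetof(TmaPartitionResult, tile_element_count) == 0x08, "tile_element_count offset");
static_assert(offsetof(TmaPartitionResult, flags) == 0x0C, "flags offset");
static_assert(offsetof(TmaPartitionResult, swizzle_mode) == 0x0E, "swizzle_mode offset");
static_assert(offsetof(TmaPartitionResult, rank) == 0x0F, "rank offset");
static_assert(offsetof(TmaPartitionResult, non_exec_atom_body) == 0x10, "atom body offset");
```

Because every field sits on its natural alignment, the compiler inserts no padding, which is why the field sizes sum to exactly 24 bytes.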

Flags Word

The 16-bit flags word records every property the partition verifier learned about the tile while it was walking the layout — the multicast mode, the im2col shape, the sparsity tier, the two-CTA cooperative bit, and a handful of operand-source bits used to short-circuit later checks. Downstream passes read this word bit-by-bit rather than re-running the partition algorithm.

| Bit | Field | Meaning |
| --- | --- | --- |
| 0 | multicast | tile lowers to a multicast TMA load (W or W128) |
| 1 | im2col | tile is im2col-flavoured (rank reduced before transfer) |
| 2 | im2col_w | im2col with warp-cooperative offset table |
| 3 | im2col_w128 | im2col with wide warp-cooperative offset table (128-thread) |
| 4 | two_cta | 2-CTA cooperative load; CTA V-map has been folded into the SMEM layout |
| 5 | sparse | metadata operand present; sparsity-aware stride walk |
| 6 | static_smem | SMEM layout passed the static-shape predicate |
| 7 | static_vmap | CTA V-map passed the static-shape predicate |
| 8 | gmem_int_stride | GMEM layout passed the integer-stride walk |
| 9 | smem_int_stride | SMEM layout passed the integer-stride walk |
| 10 | shape_equiv | top-level shape equivalence between SMEM and V-map held |
| 11 | g_basis_ok | G-basis computation returned a valid layout |
| 12 | s2t_descriptor | result wraps a get_copy_s2t_smem_desc view (Blackwell SMEM-to-tmem descriptor) |
| 13 | prefetch_eligible | descriptor handle survives prefetch (no per-axis dynamism that would invalidate it) |
| 14 | reserved | |
| 15 | reserved | |

Bits 6 through 13 mirror the predicate checks the eleven-step verifier ran in steps 4 through 7 and 11. Folding the outcomes back into the flags word lets the executable-atom builder skip the equivalent predicates entirely — the partition verifier is the only place where these layout invariants get checked.
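
A downstream pass reading the word bit-by-bit could look like the sketch below. The bit positions come from the table above; the enumerator names, the helper functions, and the expectation that reserved bits are zero are assumptions made for illustration.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

/* Bit positions from the flags table; names are hypothetical. */
enum {
    TMA_FLAG_MULTICAST         = 1u << 0,
    TMA_FLAG_IM2COL            = 1u << 1,
    TMA_FLAG_IM2COL_W          = 1u << 2,
    TMA_FLAG_IM2COL_W128       = 1u << 3,
    TMA_FLAG_TWO_CTA           = 1u << 4,
    TMA_FLAG_SPARSE            = 1u << 5,
    TMA_FLAG_STATIC_SMEM       = 1u << 6,
    TMA_FLAG_STATIC_VMAP       = 1u << 7,
    TMA_FLAG_GMEM_INT_STRIDE   = 1u << 8,
    TMA_FLAG_SMEM_INT_STRIDE   = 1u << 9,
    TMA_FLAG_SHAPE_EQUIV       = 1u << 10,
    TMA_FLAG_G_BASIS_OK        = 1u << 11,
    TMA_FLAG_S2T_DESCRIPTOR    = 1u << 12,
    TMA_FLAG_PREFETCH_ELIGIBLE = 1u << 13,
    TMA_FLAG_RESERVED_MASK     = (1u << 14) | (1u << 15)
};

static_assert(TMA_FLAG_RESERVED_MASK == 0xC000, "bits 14 and 15 are reserved");

/* True when the given flag bit is set in the word. */
static bool tma_flag_is_set(uint16_t flags, uint16_t bit) {
    return (flags & bit) != 0;
}

/* Assumption: a well-formed word keeps the reserved bits clear. */
static bool tma_flags_reserved_clear(uint16_t flags) {
    return (flags & TMA_FLAG_RESERVED_MASK) == 0;
}
```

For example, a word of 0x0011 has bits 0 and 4 set, so it describes a multicast, two-CTA tile with no recorded layout predicates.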

Eleven-Step Partition Verifier

The partition verifier walks eleven invariants in fixed order. Each gate emits a verbatim diagnostic on failure; the strings are part of the user-visible contract and a reimplementation must preserve them byte-for-byte.

| # | Step | Verbatim diagnostic |
| --- | --- | --- |
| 1 | Type gate on the SMEM and GMEM operands | "invalid operand types, got " |
| 2 | SMEM layout-kind gate (LayoutType or ComposedLayoutType) | "invalid smem layout type, expected LayoutType or ComposedLayoutType, got " |
| 3 | GMEM layout-kind gate | "unsupported layout for the GMEM tensor, got " |
| 4 | Integer-stride walk on both layouts | "expected the GMEM and SMEM layouts to have integer stride elements, but got " |
| 5 | SMEM layout must be a swizzle layout | "expected the SMEM layout to be a swizzle layout, but got " |
| 6 | SMEM layout and CTA V-map must be static | "expected the SMEM layout and the CTA V-map to be static, but got " |
| 7 | Top-level shape equivalence between SMEM and V-map | "expected top-level shape equivalence between the SMEM layout and the CTA V-map, but got " |
| 8 | TMA G-basis computation | "failed to compute the TMA G-basis, got " |
| 9 | Final TMA layout validity check | "Computed TMA layout is invalid, got " |
| 10 | TMA tensor-type construction | "Failed to construct the TMA tensor type" |
| 11 | Multicast-count consistency (load variant only) | "missing or invalid num_multicast for a multicast TMA load" |

Order matters. The cheap type and structural gates — steps 1 through 6 — run before the more expensive G-basis and layout-product computations in steps 8 and 9. Step 11 is specific to the load variant; the store and reduce variants skip it because TMA store and reduce never take a multicast operand.
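
The fixed-order walk can be modelled as a lookup from step number to diagnostic prefix plus a first-failure scan. This is a sketch only: the strings are copied from the table above, while the bitmask encoding of per-step outcomes and both function names are hypothetical.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

#define TMA_PARTITION_STEPS 11

/* Diagnostic prefixes indexed by (step - 1), verbatim from the table. */
static const char *const tma_partition_diag[TMA_PARTITION_STEPS] = {
    "invalid operand types, got ",
    "invalid smem layout type, expected LayoutType or ComposedLayoutType, got ",
    "unsupported layout for the GMEM tensor, got ",
    "expected the GMEM and SMEM layouts to have integer stride elements, but got ",
    "expected the SMEM layout to be a swizzle layout, but got ",
    "expected the SMEM layout and the CTA V-map to be static, but got ",
    "expected top-level shape equivalence between the SMEM layout and the CTA V-map, but got ",
    "failed to compute the TMA G-basis, got ",
    "Computed TMA layout is invalid, got ",
    "Failed to construct the TMA tensor type",
    "missing or invalid num_multicast for a multicast TMA load",
};

/* passed_mask: bit (step - 1) is set when that step's predicate held.
 * Returns the 1-based number of the first failing step, or 0 if none fail.
 * Step 11 only participates for the load variant. */
static int tma_first_failing_step(uint16_t passed_mask, bool is_load) {
    for (int step = 1; step <= TMA_PARTITION_STEPS; ++step) {
        if (step == TMA_PARTITION_STEPS && !is_load)
            continue;
        if ((passed_mask & (1u << (step - 1))) == 0)
            return step;
    }
    return 0;
}

/* Returns the diagnostic prefix for a step, or NULL if out of range. */
static const char *tma_step_diagnostic(int step) {
    if (step < 1 || step > TMA_PARTITION_STEPS)
        return NULL;
    return tma_partition_diag[step - 1];
}
```

Scanning in ascending order is what guarantees that an input violating several invariants always reports the earliest gate first.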

A twelfth string, "got num_multicast of ", is emitted as a companion to step 11 when the load mode is non-multicast (mode 0 or mode 1) but the supplied num_multicast value is not 1. The two error paths share the same FAIL label and treat the pair as one diagnostic: a missing or zero multicast count for a multicast mode, or a non-unit count for a non-multicast mode.
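
The paired check can be sketched as follows. It assumes that the multicast modes are the W and W128 variants (as bit 0 of the flags word describes) and that a missing count is encoded as 0; the function names are hypothetical.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

typedef enum TmaLoadMode {
    TMA_LOAD_NO_MULTICAST   = 0,
    TMA_LOAD_TWO_CTA        = 1,
    TMA_LOAD_W_MULTICAST    = 2,
    TMA_LOAD_W128_MULTICAST = 3
} TmaLoadMode;

static_assert(TMA_LOAD_W128_MULTICAST == 3, "mode values are part of the verifier contract");

/* Assumption: only the W and W128 variants carry a multicast operand. */
static bool tma_load_mode_is_multicast(TmaLoadMode mode) {
    return mode == TMA_LOAD_W_MULTICAST || mode == TMA_LOAD_W128_MULTICAST;
}

/* Returns NULL when the count is consistent with the mode, otherwise
 * the diagnostic prefix for the failing path. A count of 0 stands in
 * for "missing" in this sketch. */
static const char *tma_check_num_multicast(TmaLoadMode mode, uint32_t num_multicast) {
    if (tma_load_mode_is_multicast(mode)) {
        if (num_multicast == 0)
            return "missing or invalid num_multicast for a multicast TMA load";
    } else if (num_multicast != 1) {
        return "got num_multicast of ";
    }
    return NULL;
}
```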

Treat only the descriptor base pointer, per-axis dimension sizes, and non-leading strides as device-mutable. Rank, element type, swizzle, multicast count, and mode are descriptor-construction facts and cannot change once the partition op has verified.

Worked Example: Rank-6 Rejection

A TMA descriptor builder consuming a rank-6 input lands on step 1 of the partition verifier. The SMEM and GMEM types print as a rank-6 layout, which is outside the accepted LayoutType / ComposedLayoutType union the partition core requires, and the diagnostic chain emits the verbatim ladder shown below before the verifier returns failure.

// Input op — rank 6 is one above the TMA hardware cap.
%bad = cute_nvgpu.atom.non_exec_tiled_tma_load
       %desc_r6, %tile_r6, %cta_map_r6 : !cute_nvgpu.tma_descriptor_tiled
error: invalid operand types, got !cute.layout<(a,b,c,d,e,f),...>, !cute.layout<(a,b,c,d,e,f),...>, and !cute.layout<...>

The rank-6 tile prints into the first <smem_ty> slot, the rank-6 GMEM type into the second <gmem_ty> slot, and the CTA V-map into the trailing <v_map_ty> slot. The verifier prints all three because the failing condition is a combination — the type gate runs on the trio as a unit, so the diagnostic must show every operand that participated.

An input with a 4-byte stride (not 16-byte-aligned) fails later, at step 4. Steps 1 through 3 pass because the layout kinds are accepted; step 4 walks the GMEM stride tuple, finds an entry that is non-integer or below 16 bytes, and emits:

error: expected the GMEM and SMEM layouts to have integer stride elements, but got !cute.layout<...>, and !cute.layout<...>

The two printed layouts are the GMEM and SMEM layouts in the same order step 1 printed them. The trailing " and " between the two arguments is the same shared format helper the type-gate diagnostic uses.

Descriptor Builder

Descriptor construction consumes a global tensor, a layout, dynamic shapes, dynamic strides, padding values, TMA mode, store mode, element width, multicast metadata, and operand segment sizes.

TmaDescriptor build_tma_descriptor(Tensor tensor,
                                   Layout layout,
                                   ArrayRef<Value> shapes,
                                   ArrayRef<Value> strides,
                                   TmaMode mode,
                                   TmaStoreMode store_mode) {
    require(tensor.memory_space == GLOBAL_MEMORY);
    require(rank(tensor) >= 1 && rank(tensor) <= 5);
    require(!is_composed_layout(layout));
    require(layout_is_static_enough_for_tma(layout));

    TmaDescriptor desc;
    desc.base = tensor.base;
    desc.element_bits = bit_width(tensor.element_type);
    desc.rank = rank(tensor);
    desc.box = compute_box_sizes(layout, shapes);
    desc.strides = compute_tma_strides(layout, strides);
    desc.mode = mode;
    desc.store_mode = store_mode;
    desc.cache_policy = default_cache_policy();
    return desc;
}

The first box dimension times the element bit width must be evenly divisible by the TMA transfer granularity. Padding values are restricted: non-zero padding requires a mode that explicitly supports it.
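
The divisibility rule reduces to a single modulo test. The sketch below takes the granularity as a parameter; the 16-byte default is an assumption based on the inner-box requirement commonly documented for tiled tensor maps, not a value stated on this page, and the helper name is hypothetical.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

/* Assumption: 16 bytes, not a value taken from this page. */
#define TMA_ASSUMED_GRANULARITY_BYTES 16u

/* True when box0 elements of element_bits each fill a whole number of
 * granularity_bytes units. Working in bits keeps sub-byte element
 * types (for example 4-bit) exact. */
static bool tma_inner_box_is_aligned(uint32_t box0,
                                     uint32_t element_bits,
                                     uint32_t granularity_bytes) {
    assert(granularity_bytes != 0);
    if (box0 == 0 || element_bits == 0)
        return false;
    uint64_t total_bits = (uint64_t)box0 * element_bits;
    return total_bits % ((uint64_t)granularity_bytes * 8u) == 0;
}
```

With 16-bit elements, a box of 64 spans 128 bytes and passes, while a box of 3 spans 6 bytes and is rejected.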

Non-Exec Atom Verification

The shared non-exec verifier checks the tuple of shared-memory layout, global layout, partitioner tile, and CTA value map. Success yields a TMA tensor type and a non-executing atom body ready to bind to runtime state later.

LogicalResult verify_non_exec_tma(NonExecTmaAtom atom) {
    require(is_smem_layout(atom.smem_layout));
    require(is_global_layout(atom.global_layout));
    require(is_tile_like(atom.partitioner));
    require(is_cta_value_map(atom.cta_v_map));
    require(smem_layout_uses_supported_swizzle(atom.smem_layout));
    require(layouts_are_statically_resolvable(atom.smem_layout, atom.cta_v_map));
    require(tma_partition_is_valid(atom));
    return success();
}

Load, store, and reduce variants add mode-specific checks. TMA reduce accepts only the reductions the target instruction family supports.
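
One way to picture the reduce-kind check is as a bounded table lookup. The enum is the TmaReduceKind shown earlier; pairing each value with the like-named PTX reduction suffix is an assumption of this sketch, and the function name is hypothetical.

```c
#include <assert.h>
#include <stddef.h>

typedef enum TmaReduceKind {
    TMA_REDUCE_ADD = 0, TMA_REDUCE_MIN = 1, TMA_REDUCE_MAX = 2,
    TMA_REDUCE_INC = 3, TMA_REDUCE_DEC = 4,
    TMA_REDUCE_AND = 5, TMA_REDUCE_OR  = 6, TMA_REDUCE_XOR = 7
} TmaReduceKind;

/* Assumed name-for-name pairing with PTX reduction suffixes. */
static const char *const tma_reduce_suffix[] = {
    ".add", ".min", ".max", ".inc", ".dec", ".and", ".or", ".xor"
};

static_assert(sizeof tma_reduce_suffix / sizeof tma_reduce_suffix[0] == 8,
              "one suffix per TmaReduceKind value");

/* Returns the suffix for a supported kind, or NULL for anything the
 * table does not cover, which a verifier would reject. */
static const char *tma_reduce_ptx_suffix(int kind) {
    if (kind < TMA_REDUCE_ADD || kind > TMA_REDUCE_XOR)
        return NULL;
    return tma_reduce_suffix[kind];
}
```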

Executable Atom Binding

atom.make_exec_tma turns a non-exec atom into an executable atom by attaching runtime state:

ExecTmaAtom make_exec_tma(NonExecTmaAtom atom,
                          MBarrier barrier,
                          CacheMode cache,
                          Optional<MulticastMask> multicast) {
    require(atom.verified);
    require(barrier.memory_space == SHARED_MEMORY);

    ExecTmaAtom exec;
    exec.atom = atom;
    exec.barrier = barrier;
    exec.cache_mode = cache;
    exec.multicast = multicast;
    return exec;
}

Executable TMA lowering increments the barrier transaction count by the number of bytes the transfer will complete.
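
The byte count is simple arithmetic over the tile. The following sketch assumes the count is derived from an element count and the descriptor's element bit width; whether the lowering uses tile_element_count from the partition result for this purpose is an assumption, and the helper name is hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Bytes a transfer of tile_element_count elements will complete.
 * Sub-byte element types must still total a whole number of bytes. */
static uint64_t tma_expected_tx_bytes(uint32_t tile_element_count,
                                      uint32_t element_bits) {
    uint64_t bits = (uint64_t)tile_element_count * element_bits;
    assert(bits % 8 == 0);
    return bits / 8;
}
```

A 64 x 64 tile of 16-bit elements, for instance, is 4096 elements and therefore 8192 bytes of expected transaction count.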

Lowering Shape

void lower_tma_load(ExecTmaAtom atom, MemRef dst, Coord coord) {
    require(atom.atom.kind == TMA_LOAD);
    require(dst.memory_space == SHARED_MEMORY);
    require(coord.rank == atom.atom.descriptor.rank);

    prefetch_descriptor_if_requested(atom.atom.descriptor);
    emit_cp_async_bulk_tensor_load(atom.atom.descriptor,
                                   dst,
                                   coord,
                                   atom.barrier,
                                   atom.cache_mode,
                                   atom.multicast);
}

void lower_tma_store(ExecTmaAtom atom, MemRef src, Coord coord) {
    require(atom.atom.kind == TMA_STORE);
    require(src.memory_space == SHARED_MEMORY);
    emit_cp_async_bulk_tensor_store(atom.atom.descriptor, src, coord, atom.cache_mode);
}

TMA load completes through an mbarrier — a consumer must wait on the barrier before using the destination tile. TMA store and reduce follow the target's async-bulk ordering rules and must not be reordered across conflicting memory effects.

Descriptor Mutation

Device-side descriptor mutation is limited to three fields: the global base pointer, per-axis dimension extents, and non-leading strides (the leading stride is implicit element-size and never written). The atom dialect exposes those three changes as dedicated update kinds rather than as a general byte write, so verification can reject any other mutation at IR construction time:

void update_tma_descriptor(TmaDescriptor *desc, TmaUpdate update) {
    switch (update.kind) {
    case UPDATE_BASE_POINTER:
        desc->base = update.base;
        break;
    case UPDATE_DIM:
        desc->shape[update.axis] = update.value;
        break;
    case UPDATE_STRIDE:
        require(update.axis > 0);
        desc->strides[update.axis] = update.value;
        break;
    default:
        fail("TMA descriptor field is not device-mutable");
    }
}

The three update kinds map directly to the tensormap.replace.tile.{global_address, global_dim, global_stride} PTX mutator family. The device-side rebind sequence is: acquire fence, one address write, one dim write per axis (rank writes), one stride write per non-leading axis (rank - 1 writes), release fence. That sequence, together with the proxy-fence ordering that pairs each rebind with its cp.async.bulk.tensor.* consumer, is documented in TMA Tensormap and cp.async.bulk — TMA Descriptor Mutators. The descriptor builder above is the partition-time view; the atom-lowering page covers how the runtime side issues those three mutators in the contractually mandated order.
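
The rebind ordering can be written down as a function of rank. This sketch only models the order and number of steps (fence, address, one dim write per axis, one stride write per non-leading axis, fence); it emits no PTX, and the enum and function names are hypothetical.

```c
#include <assert.h>
#include <stddef.h>

typedef enum TmaRebindStep {
    REBIND_NONE = -1,          /* index past the end of the sequence */
    REBIND_ACQUIRE_FENCE,
    REBIND_WRITE_ADDRESS,
    REBIND_WRITE_DIM,
    REBIND_WRITE_STRIDE,
    REBIND_RELEASE_FENCE
} TmaRebindStep;

#define TMA_MAX_RANK 5u

/* Total steps: 2 fences + 1 address + rank dims + (rank - 1) strides. */
static size_t tma_rebind_step_count(unsigned rank) {
    assert(rank >= 1 && rank <= TMA_MAX_RANK);
    return 2u * rank + 2u;
}

/* Kind of the step at position index in the rebind sequence. */
static TmaRebindStep tma_rebind_step_at(unsigned rank, size_t index) {
    size_t count = tma_rebind_step_count(rank);
    if (index >= count)
        return REBIND_NONE;
    if (index == 0)
        return REBIND_ACQUIRE_FENCE;
    if (index == 1)
        return REBIND_WRITE_ADDRESS;
    if (index < 2u + rank)
        return REBIND_WRITE_DIM;
    if (index < count - 1)
        return REBIND_WRITE_STRIDE;
    return REBIND_RELEASE_FENCE;
}
```

For a rank-3 descriptor this yields eight steps: positions 2 to 4 are dim writes and positions 5 and 6 are the two non-leading stride writes.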

If You Know CUTLASS (open source) — cross-walk

Coming from CUTLASS Hopper/Blackwell TMA usage:

| CUTLASS C++ | tileiras IR (cute_nvgpu) |
| --- | --- |
| cuTensorMapEncodeTiled(&tmap, ...) (host-side, runtime API) | nv_tileas.make_tiled_tma_desc op materialising a !tma_descriptor_tiled typed value |
| cuTensorMapEncodeIm2col(&tmap, ...) | nv_tileas.make_tiled_tma_desc with im2col mode → !tma_descriptor_im2col |
| cute::SM90_TMA_LOAD::copy(...) | cute_nvgpu.atom.tma_load op (after make_exec_tma binding) |
| cute::SM90_TMA_STORE::copy(...) | cute_nvgpu.atom.tma_store op |
| cute::SM90_TMA_REDUCE_ADD::copy(...) | cute_nvgpu.atom.tma_reduce with kind = TMA_REDUCE_ADD |
| Multicast TMA (SM90_TMA_LOAD_MULTICAST) | tma_load_mode attribute on the partition op |
| cute::prefetch_tma_descriptor(tmap) | cute_nvgpu.prefetch_tma_desc op |
| mbarrier::arrive_and_expect_tx(mbar, bytes) paired with TMA | barrier operand + expect_tx attribute on the executable TMA op |

The structural difference: in CUTLASS the descriptor is an opaque CUtensorMap blob bound at runtime. Tileiras carries rank, element width, swizzle mode, box shape, and stride layout as typed IR attributes the partition verifier re-checks before each TMA op lowers. Device-side mutation is restricted to base pointer, per-axis dimension, and non-leading stride (see Descriptor Mutation above) — the same surface the hardware allows, exposed through dedicated ops rather than raw byte writes.

Worked Example

%desc = nv_tileas.make_tiled_tma_desc %tensor, %layout
    shapes(%m, %n, %k) strides(%sn, %sk) paddings()
    {mode = #cute_nvgpu.tma_load_mode<tiled>,
     elementBitWidth = 16} : !cute_nvgpu.tma_descriptor_tiled

%atom = cute_nvgpu.atom.non_exec_tiled_tma_load %desc, %tile, %cta_map
    {num_multicast = 1}

%exec = cute_nvgpu.atom.make_exec_tma %atom, %mbar
    {cache_mode = #cute_nvgpu.load_cache_mode<cg>}

cute_nvgpu.atom.tma_load %exec, %smem_tile, %coord
    {allow_tma = true, inBounds = true}

After lowering the executable load becomes a cp.async.bulk.tensor-style op with descriptor, coordinate, destination, barrier, and optional cache or multicast modifiers.

Invariants

  • TMA rank is between one and five.
  • Descriptor pointers are aligned to the hardware descriptor requirement.
  • Composed layouts are rejected where the descriptor builder needs a plain static layout.
  • Shared-memory layouts use supported swizzle modes.
  • Global and shared layouts agree with the partitioner and CTA value map.
  • Descriptor base, dimensions, and strides are the only mutable device fields.
  • TMA load completion is ordered through an mbarrier.
  • Im2col and multicast modes are architecture-gated.

Cross-References

Mode Pattern Verifiers — Swizzle Legality documents the swizzle-legality, UMMA Canonical Layout Verifier, and tcgen05.mma Kind-Word Verifier verifiers that the TMA partition core composes with. SM Tier Roster and Copy Atom Registry — Atom TypeID Registry covers the SM90/SM100/SM120 atom interfaces TMA atoms implement. cute Atom Builders and Desugar — Kernel-entry ABI covers the kernel-entry ABI that hoists TMA descriptors as .param constant-space arguments.