The Global GEMM: Putting It All Together
Difficulty: Intermediate–Advanced. Prerequisites: Tutorial 02: Partitioning, Tutorial 04: TiledCopy, Tutorial 05: Swizzling, Tutorial 06: Hello, MMA
1. The Problem (The "Why")
Tutorial 06 triggered a single Tensor Core instruction on a 16×8×16 (M×N×K) micro-tile. That's 2048 multiply-adds, impressive for one instruction, but real-world matrices are thousands of elements wide. You need to compute C(M×N) = A(M×K) × Bᵀ(N×K) where M, N, and K can each be 4096 or more.
The answer isn't one giant instruction. It's tiling, the same idea from Tutorial 02, now applied at every level:
| Level | What's tiled | Controlled by |
|---|---|---|
| CTA grid | Output C is divided into BLK_M×BLK_N tiles, one per thread block | local_tile + blockIdx |
| K-loop | The reduction dimension K is divided into BLK_K chunks, processed sequentially | The mainloop for loop |
| Warp MMA | Each BLK_M×BLK_N tile is further divided by the TiledMMA across 4 warps | partition_A/B/C + gemm() |
This is the structure of every high-performance GEMM on GPUs, from CUTLASS to cuBLAS. The names change but the three-level tiling is always the same. This tutorial writes that kernel from scratch, using nothing but the CuTe primitives you already know.
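Before any CuTe machinery, the three levels can be written as plain loops. A host-side scalar sketch (illustration only, using this tutorial's tile sizes and the B-transposed layout; the kernel below maps level 1 onto the CTA grid and level 3 onto the TiledMMA):

#include <vector>

// Scalar reference with the same three-level tiling (assumes M, N, K are
// multiples of the tile sizes and C starts zeroed, as in the kernel below).
void tiled_gemm_ref(std::vector<float> const& A, std::vector<float> const& B,
                    std::vector<float>& C, int M, int N, int K) {
    const int BM = 128, BN = 128, BK = 32;          // CTA tile sizes
    for (int bm = 0; bm < M; bm += BM)              // Level 1: CTA grid over C
    for (int bn = 0; bn < N; bn += BN)
        for (int bk = 0; bk < K; bk += BK)          // Level 2: the K mainloop
        for (int m = bm; m < bm + BM; ++m)          // Level 3: work inside one tile
        for (int n = bn; n < bn + BN; ++n) {
            float acc = 0.0f;
            for (int k = bk; k < bk + BK; ++k)
                acc += A[m * K + k] * B[n * K + k]; // B is (N, K): C = A * B^T
            C[m * N + n] += acc;                    // accumulate the partial product
        }
}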
B200 Note: On Hopper/Blackwell, the structure is identical but the data-movement layer changes: TiledCopy with `cp.async` or TMA replaces explicit loads, and the K-loop becomes a software pipeline (Tutorial 10). The compute layer (TiledMMA + `gemm()`) stays the same.
2. The Mental Model (The Visual)
Running the Factory for a Full Production Day
Tutorials 03–06 built the factory:
- Loading dock (TiledCopy): moves raw materials from the truck (global memory) into the warehouse (shared memory), using dollies (vectorized loads) and staggered shelving (swizzle).
- Stamping press (TiledMMA): processes materials from the warehouse shelves into finished products (the C accumulator).
Now we run the factory for a full production day:
Global Memory (The Truck)
┌──────────────────────────────────────────────────┐
│   A (M × K)                    B (N × K)         │
└──────────┬─────────────────────────┬─────────────┘
           │                         │
┌──────────┼─────────────────────────┼───── K-loop (the shift) ─────┐
│          ▼                         ▼                              │
│  ┌───────────────┐        ┌───────────────┐                       │
│  │    A tile     │        │    B tile     │       TiledCopy       │
│  │  (128 × 32)   │        │  (128 × 32)   │       gmem → smem     │
│  └───────┬───────┘        └───────┬───────┘                       │
│          │      Shared Memory     │                               │
│          ▼       (swizzled)       ▼                               │
│  ┌───────────────────────────────────┐                            │
│  │          STAMPING PRESS           │       TiledMMA             │
│  │  partition → fragment → gemm      │       smem → regs → MMA    │
│  └─────────────────┬─────────────────┘                            │
│                    │                                              │
│           accum += partial C                                      │
│                    │                                              │
└────────────────────┼───────────────── repeat for each K tile ─────┘
                     │
                     ▼  Epilogue (end of shift: ship the product)
              ┌────────────┐
              │   C tile   │   ← copy accum → gmem
              │ (128 × 128)│
              └────────────┘
Each thread block is an independent factory. The grid of blocks tiles the M×N output:
Grid of CTAs over the output C (M × N):

            N →
    ┌─────────┬─────────┬─────────┬───┐
    │   CTA   │   CTA   │   CTA   │   │
M   │  (0,0)  │  (0,1)  │  (0,2)  │...│   Each CTA computes
↓   │ 128×128 │ 128×128 │ 128×128 │   │   one 128×128 tile of C
    ├─────────┼─────────┼─────────┼───┤
    │   CTA   │   CTA   │   CTA   │   │
    │  (1,0)  │  (1,1)  │  (1,2)  │...│
    │ 128×128 │ 128×128 │ 128×128 │   │
    ├─────────┼─────────┼─────────┼───┤
    │   ...   │   ...   │   ...   │   │
    └─────────┴─────────┴─────────┴───┘
    grid = (M / 128, N / 128)
The Three Phases
Phase 1: COPY (the loading dock)
  gmem A[BLK_M × BLK_K] ──TiledCopy──▶ smem_A (swizzled)
  gmem B[BLK_N × BLK_K] ──TiledCopy──▶ smem_B (swizzled)
  __syncthreads()
Phase 2: COMPUTE (the stamping press)
  smem_A ──partition_A──▶ registers ──┐
                                      ├──gemm()──▶ accum += partial
  smem_B ──partition_B──▶ registers ──┘
  __syncthreads()
Phase 3: EPILOGUE (shipping)
  accum ──partition_C──▶ gmem C[BLK_M × BLK_N]
Phases 1 and 2 repeat ceil(K / BLK_K) times. Phase 3 happens once.
3. The Solution (The Code)
A complete GEMM kernel: C (M×N, float) = A (M×K, half) × Bᵀ (N×K, half). Uses 128 threads (4 warps), 128×128 CTA tiles, and SM80 Tensor Cores.
Constraints: M and N must be multiples of 128; K must be a multiple of 32. (Removing these constraints requires boundary masking, which matters for production but is orthogonal to the core GEMM structure.)
#include <cute/tensor.hpp>
#include <cute/arch/mma_sm80.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/atom/copy_atom.hpp>
#include <cute/algorithm/gemm.hpp>
#include <cute/swizzle.hpp>
#include <cuda_fp16.h>
#include <cstdio>
#include <cstdlib>
#include <cmath>
using namespace cute;
// Tile sizes
constexpr int BLK_M = 128;
constexpr int BLK_N = 128;
constexpr int BLK_K = 32;
constexpr int NUM_THREADS = 128; // 4 warps
__global__ void gemm_kernel(
half const* __restrict__ A_ptr,
half const* __restrict__ B_ptr,
float* __restrict__ C_ptr,
int M, int N, int K)
{
// SETUP: Layouts, TiledCopy, TiledMMA
// Global memory tensors
// A: (M, K) with K stride-1 B: (N, K) with K stride-1
// C: (M, N) with N stride-1
auto mA = make_tensor(make_gmem_ptr(A_ptr), make_shape(M, K), make_stride(K, Int<1>{}));
auto mB = make_tensor(make_gmem_ptr(B_ptr), make_shape(N, K), make_stride(K, Int<1>{}));
auto mC = make_tensor(make_gmem_ptr(C_ptr), make_shape(M, N), make_stride(N, Int<1>{}));
// CTA tiling: each block claims its tile of A, B, C
auto cta_coord = make_coord(blockIdx.x, blockIdx.y);
// gA: (BLK_M, BLK_K, k_tiles), this block's strip of A, sliced along K
auto gA = local_tile(mA, make_shape(Int<BLK_M>{}, Int<BLK_K>{}),
make_coord(get<0>(cta_coord), _));
// gB: (BLK_N, BLK_K, k_tiles)
auto gB = local_tile(mB, make_shape(Int<BLK_N>{}, Int<BLK_K>{}),
make_coord(get<1>(cta_coord), _));
// gC: (BLK_M, BLK_N), this block's output tile
auto gC = local_tile(mC, make_shape(Int<BLK_M>{}, Int<BLK_N>{}), cta_coord);
// Shared memory layouts: K stride-1, swizzled
// Swizzle<3,3,3>: M=3 → 8 halfs (128 bits) stay contiguous for vectorized loads
auto sA_layout = composition(Swizzle<3, 3, 3>{},
make_layout(make_shape(Int<BLK_M>{}, Int<BLK_K>{}),
make_stride(Int<BLK_K>{}, Int<1>{})));
auto sB_layout = composition(Swizzle<3, 3, 3>{},
make_layout(make_shape(Int<BLK_N>{}, Int<BLK_K>{}),
make_stride(Int<BLK_K>{}, Int<1>{})));
// Shared memory allocation (static)
__shared__ half smem_A[cosize_v<decltype(sA_layout)>];
__shared__ half smem_B[cosize_v<decltype(sB_layout)>];
auto sA = make_tensor(make_smem_ptr(smem_A), sA_layout);
auto sB = make_tensor(make_smem_ptr(smem_B), sB_layout);
// TiledCopy: gmem → smem via cp.async
// Each atom loads 128 bits = 8 halfs along the K (stride-1) direction.
// 128 threads, one per M-row. Each thread loads 32 K-values (4 × 128-bit async copies).
auto tiled_copy = make_tiled_copy(
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, half>{},
Layout<Shape<_128, _1>>{}, // 128 threads all in M
Layout<Shape<_1, _32>>{} // 32 K-values per thread (4 atoms of 8 halfs)
);
// Coverage: (128 × 1, 1 × 32) = (128, 32) ✓
// TiledMMA: the Tensor Core plan
// SM80_16x8x16 atom, 4 warps arranged (2,2) in M×N.
// Base coverage: (2×16, 2×8) = (32, 16) per atom group.
// For 128×128: gemm() iterates 4× in M, 8× in N, 2× in K automatically.
auto tiled_mma = make_tiled_mma(
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>{},
Layout<Shape<_2, _2, _1>>{} // 4 warps: 2 in M, 2 in N
);
// PARTITIONS: thread-level views
// Copy partitions
auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
auto tAgA = thr_copy.partition_S(gA); // (CPY_VEC, CPY_M, CPY_K, k_tiles)
auto tAsA = thr_copy.partition_D(sA); // (CPY_VEC, CPY_M, CPY_K)
auto tBgB = thr_copy.partition_S(gB); // same structure for B
auto tBsB = thr_copy.partition_D(sB);
// MMA partitions
auto thr_mma = tiled_mma.get_thread_slice(threadIdx.x);
auto tCsA = thr_mma.partition_A(sA); // (MMA_VAL, MMA_M, MMA_K)
auto tCsB = thr_mma.partition_B(sB); // (MMA_VAL, MMA_N, MMA_K)
auto tCgC = thr_mma.partition_C(gC); // (MMA_VAL, MMA_M, MMA_N)
// Register fragments
auto tCrA = make_fragment_like(tCsA);
auto tCrB = make_fragment_like(tCsB);
auto accum = make_fragment_like(tCgC); // C accumulator (float)
clear(accum);
// MAINLOOP: iterate over K tiles
int k_tiles = size<2>(gA); // = K / BLK_K
for (int k = 0; k < k_tiles; ++k) {
// Phase 1: COPY gmem → smem
copy(tiled_copy, tAgA(_, _, _, k), tAsA);
copy(tiled_copy, tBgB(_, _, _, k), tBsB);
cp_async_fence();
cp_async_wait<0>();
__syncthreads();
// Phase 2: COMPUTE smem → registers → MMA
copy(tCsA, tCrA);
copy(tCsB, tCrB);
gemm(tiled_mma, tCrA, tCrB, accum);
__syncthreads();
}
// EPILOGUE: write C accumulator to global memory
copy(accum, tCgC);
}
// Host: verification + timing
int main()
{
constexpr int M = 512, N = 512, K = 256;
// Allocate and initialize
int sizeA = M * K, sizeB = N * K, sizeC = M * N;
half *h_A = new half[sizeA];
half *h_B = new half[sizeB];
float *h_C = new float[sizeC];
// A = all 1.0h, B = all 1.0h → C[m][n] = K
for (int i = 0; i < sizeA; ++i) h_A[i] = __float2half(1.0f);
for (int i = 0; i < sizeB; ++i) h_B[i] = __float2half(1.0f);
half *d_A, *d_B;
float *d_C;
cudaMalloc(&d_A, sizeA * sizeof(half));
cudaMalloc(&d_B, sizeB * sizeof(half));
cudaMalloc(&d_C, sizeC * sizeof(float));
cudaMemcpy(d_A, h_A, sizeA * sizeof(half), cudaMemcpyHostToDevice);
cudaMemcpy(d_B, h_B, sizeB * sizeof(half), cudaMemcpyHostToDevice);
// Launch
dim3 grid(M / BLK_M, N / BLK_N); // (4, 4) for 512×512
dim3 block(NUM_THREADS); // 128
gemm_kernel<<<grid, block>>>(d_A, d_B, d_C, M, N, K);
cudaDeviceSynchronize();
// Verify
cudaMemcpy(h_C, d_C, sizeC * sizeof(float), cudaMemcpyDeviceToHost);
float max_err = 0.0f;
float expected = float(K); // K * 1.0 * 1.0
for (int i = 0; i < sizeC; ++i)
max_err = fmaxf(max_err, fabsf(h_C[i] - expected));
printf("GEMM: C(%d x %d) = A(%d x %d) * B^T(%d x %d)\n", M, N, M, K, N, K);
printf("Tile: %d x %d x %d, Threads: %d\n\n", BLK_M, BLK_N, BLK_K, NUM_THREADS);
// Print a 4Γ4 corner
printf("C[0:4][0:4] (expected %.0f everywhere):\n", expected);
for (int m = 0; m < 4; ++m) {
printf(" ");
for (int n = 0; n < 4; ++n)
printf("%8.1f", h_C[m * N + n]);
printf("\n");
}
printf("\nMax absolute error: %e\n", max_err);
printf("Result: %s\n", max_err < 1e-3f ? "PASS" : "FAIL");
// Timing
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
int warmup = 5, iters = 20;
for (int i = 0; i < warmup; ++i)
gemm_kernel<<<grid, block>>>(d_A, d_B, d_C, M, N, K);
cudaDeviceSynchronize();
cudaEventRecord(start);
for (int i = 0; i < iters; ++i)
gemm_kernel<<<grid, block>>>(d_A, d_B, d_C, M, N, K);
cudaEventRecord(stop);
cudaEventSynchronize(stop);
float ms;
cudaEventElapsedTime(&ms, start, stop);
float avg_ms = ms / iters;
double flops = 2.0 * M * N * K;
double tflops = (flops / (avg_ms / 1000.0)) / 1e12;
printf("\nPerformance: %.3f ms (%.2f TFLOPS)\n", avg_ms, tflops);
// Cleanup
cudaFree(d_A); cudaFree(d_B); cudaFree(d_C);
cudaEventDestroy(start); cudaEventDestroy(stop);
delete[] h_A; delete[] h_B; delete[] h_C;
return 0;
}
Expected Output:
GEMM: C(512 x 512) = A(512 x 256) * B^T(512 x 256)
Tile: 128 x 128 x 32, Threads: 128
C[0:4][0:4] (expected 256 everywhere):
256.0 256.0 256.0 256.0
256.0 256.0 256.0 256.0
256.0 256.0 256.0 256.0
256.0 256.0 256.0 256.0
Max absolute error: 0.000000e+00
Result: PASS
Performance: X.XXX ms (Y.YY TFLOPS)
Each entry of C is 256.0: the dot product of a length-256 all-ones row from A with a length-256 all-ones row from B. The K-loop ran 256 / 32 = 8 iterations, each accumulating a BLK_K=32 partial product. After 8 iterations, the accumulator held the full result.
4. Step-by-Step Explanation
Setup Phase
Lines: Global tensor creation with make_tensor
auto mA = make_tensor(make_gmem_ptr(A_ptr), make_shape(M, K), make_stride(K, Int<1>{}));
Creates a CuTe tensor wrapping global memory. The stride (K, 1) means K is stride-1: the TN layout required by our MMA atom (Tutorial 06). Note that K is a runtime value while Int<1>{} is static; CuTe handles mixed static/dynamic strides.
Lines: CTA tiling with local_tile
auto gA = local_tile(mA, make_shape(Int<BLK_M>{}, Int<BLK_K>{}),
make_coord(get<0>(cta_coord), _));
This is Tutorial 02's local_tile at the CTA level, the casino floor manager assigning tables:
- Divides A's M dimension into chunks of BLK_M (128). Block `blockIdx.x` claims its chunk.
- Divides A's K dimension into chunks of BLK_K (32). The underscore `_` means "give me all K-tiles as a third mode."
- Result shape: `(128, 32, k_tiles)`, this block's full strip of A, with `k_tiles = K / BLK_K` slices along K.
For C, both coordinates are fixed (no underscore):
auto gC = local_tile(mC, make_shape(Int<BLK_M>{}, Int<BLK_N>{}), cta_coord);
Result shape: (128, 128), this block's one output tile.
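The coordinate math behind these views can be written by hand. A hypothetical helper (illustration only, not a CuTe API) showing which global element gA(i, j, kt) aliases:

// gA(i, j, kt) aliases mA(blockIdx.x * BLK_M + i, kt * BLK_K + j):
__device__ inline half const* gA_element(half const* A_ptr, int K,
                                         int i, int kt, int j) {
    return A_ptr + (blockIdx.x * BLK_M + i) * (size_t)K + (kt * BLK_K + j);
}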
Lines: Shared memory with swizzle
auto sA_layout = composition(Swizzle<3, 3, 3>{},
make_layout(make_shape(Int<128>{}, Int<32>{}), make_stride(Int<32>{}, Int<1>{})));
The base layout is (128, 32):(32, 1), K stride-1, same as gmem. composition(Swizzle<3,3,3>{}, ...) applies the XOR remapping from Tutorial 05. The parameters: M=3 preserves 8-half (128-bit) contiguity for vectorized stores, B=3 scrambles 8-bank groups, S=3 separates the bit fields. This is the standard swizzle for half-precision data.
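In bit terms (see the notebook section below), Swizzle<3, 3, 3> XORs offset bits [6:9) into bits [3:6). A standalone sketch of that remap, written out for illustration:

constexpr unsigned swizzle_3_3_3(unsigned offset) {
    // XOR bits [6:9) of the element offset into bits [3:6); bits [0:3)
    // are untouched, so 8-half (128-bit) vectors stay contiguous.
    return offset ^ ((offset & 0x1C0u) >> 3);
}
// In the (128, 32):(32, 1) tile: row 0 starts at offset 0 (unchanged),
// row 2 starts at offset 64 → 64 ^ 8 = 72, a different bank group.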
Lines: TiledCopy construction
auto tiled_copy = make_tiled_copy(
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, half>{},
Layout<Shape<_128, _1>>{},
Layout<Shape<_1, _32>>{}
);
This is Tutorial 04's supervisor's clipboard, now sized for the GEMM tile:
- Copy_Atom: `SM80_CP_ASYNC_CACHEALWAYS`, a hardware async copy moving 128 bits = 8 halfs per transaction along K (stride-1). Data goes directly from gmem to smem, bypassing registers.
- thr_layout `(128, 1)`: all 128 threads assigned to the M dimension, one thread per row.
- val_layout `(1, 32)`: each thread copies 32 K-values (4 atoms × 8 halfs each).
- Coverage: `(128 × 1, 1 × 32) = (128, 32)`, exactly one smem tile. Each K-loop iteration fills the tile with one `copy()` call per matrix. (The per-thread assignment is spelled out in the sketch below.)
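A hypothetical helper (not a CuTe API; the real mapping comes from partition_S/partition_D) that computes what thread t's v-th 128-bit transaction covers:

// Thread t's v-th cp.async transaction within one (128, 32) tile:
__device__ inline void copy_assignment(int tid, int v, int& row, int& k0) {
    row = tid;    // thr_layout (128, 1): thread t owns M-row t
    k0  = v * 8;  // val_layout (1, 32): vectors v = 0..3 cover K = [0,8) ... [24,32)
}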
Lines: TiledMMA construction
auto tiled_mma = make_tiled_mma(
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>{},
Layout<Shape<_2, _2, _1>>{}
);
Tutorial 06's stamping press, now with multiple warps. The second argument arranges 4 warps in a 2×2 grid over M and N:
- Base atom covers 16M × 8N × 16K.
- 2×2 warps → combined base coverage: 32M × 16N × 16K.
- For the 128×128×32 smem tile, `gemm()` internally loops 4× in M (128/32), 8× in N (128/16), and 2× in K (32/16). You don't write these loops; `gemm()` handles them (verified in the sketch below).
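Those repetition counts can be checked at compile time. A sanity-check sketch using this tutorial's constants:

constexpr int ATOM_M = 16, ATOM_N = 8, ATOM_K = 16; // SM80_16x8x16 atom shape
constexpr int WARPS_M = 2, WARPS_N = 2;             // the (2, 2, 1) warp layout
static_assert(BLK_M / (WARPS_M * ATOM_M) == 4, "gemm() repeats 4x in M");
static_assert(BLK_N / (WARPS_N * ATOM_N) == 8, "gemm() repeats 8x in N");
static_assert(BLK_K / ATOM_K == 2,             "gemm() repeats 2x in K");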
Partition Phase
Lines: Copy partitions
auto tAgA = thr_copy.partition_S(gA); // (CPY_VEC, CPY_M, CPY_K, k_tiles)
auto tAsA = thr_copy.partition_D(sA); // (CPY_VEC, CPY_M, CPY_K)
Same as Tutorial 04: `partition_S` creates this thread's view of the source (gmem), `partition_D` for the destination (smem). The extra `k_tiles` mode on `tAgA` comes from gA's third mode; it passes through untouched. In the mainloop, `tAgA(_, _, _, k)` selects the k-th K-tile for copying.
Lines: MMA partitions and fragments
auto tCsA = thr_mma.partition_A(sA); // (MMA_VAL, MMA_M, MMA_K)
auto tCrA = make_fragment_like(tCsA); // register copy
auto accum = make_fragment_like(tCgC); // C accumulator
Same as Tutorial 06: `partition_A` maps smem elements to this thread's MMA fragment slots. `make_fragment_like` creates a register tensor with the same shape. The accumulator `accum` lives in registers for the entire kernel; it's never written to smem.
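While developing, it helps to dump one thread's partition shapes. A debugging sketch using CuTe's real print and thread0() utilities, droppable anywhere after the partitions:

if (thread0()) {                         // true only for thread 0 of block 0
    print("tCsA : "); print(tCsA);  print("\n");
    print("tCrA : "); print(tCrA);  print("\n");
    print("accum: "); print(accum); print("\n");
}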
Mainloop
for (int k = 0; k < k_tiles; ++k) {
copy(tiled_copy, tAgA(_, _, _, k), tAsA); // gmem → smem
copy(tiled_copy, tBgB(_, _, _, k), tBsB);
cp_async_fence();
cp_async_wait<0>();
__syncthreads();
copy(tCsA, tCrA); // smem → registers
copy(tCsB, tCrB);
gemm(tiled_mma, tCrA, tCrB, accum); // MMA!
__syncthreads();
}
Each iteration processes one BLK_K = 32 slice of the K dimension:
1. Copy: TiledCopy loads A's k-th tile (128×32) and B's k-th tile (128×32) from global to shared memory. The swizzle is transparent; the TiledCopy writes through swizzled addresses automatically.
2. Sync: all 128 threads must finish writing to smem before any thread starts reading for the MMA.
3. Compute: copy from swizzled smem to register fragments, then `gemm()` executes the MMA atoms. Because the tile is larger than one atom, `gemm()` internally loops over the M, N, and K repetitions. The result accumulates into `accum`; each K-iteration adds to the running sum, building up the partial dot products.
4. Sync: all threads must finish reading smem before the next iteration overwrites it.
After k_tiles iterations, accum holds the complete C tile.
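The pair of cp.async calls is worth spelling out. A semantic sketch (both are real CuTe functions, used exactly this way in the mainloop):

cp_async_fence();   // commits all cp.async copies issued so far as one group
cp_async_wait<0>(); // blocks until at most 0 groups remain in flight, i.e. all done
// Waiting with <0> right after the fence makes the async copies effectively
// synchronous. The pipelined kernels of Tutorial 10 instead wait with <N-1>
// to keep N stages in flight.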
Epilogue
copy(accum, tCgC);
Each thread writes its accumulator fragment (128 floats for this config) to global memory via the MMA's C partition. As noted in Tutorial 06, the C fragments are scattered: each thread writes to non-contiguous addresses. For a production kernel, CUTLASS routes the epilogue through shared memory for coalesced writes. For correctness, the simple scattered write works fine.
5. Engineer's Notebook (Latent Space Notes)
Analogy: The Full Production Day.
- CTA grid = multiple independent factories, one per output tile. They never interact.
- K-loop (mainloop) = one production shift. Each iteration loads a batch of raw materials (A/B K-tiles via TiledCopy), processes them on the stamping press (TiledMMA + gemm), and adds the result to the running accumulator.
- Epilogue = end of shift. Ship the finished product (accumulated C tile) to global memory.
The three-level tiling is the single most important structural insight for GPU GEMM. Everything else (pipelining, double buffering, warpgroup MMA) is optimization on top of this skeleton.
Tile size cheat sheet:
| Parameter | Typical Values | Trade-off |
|---|---|---|
| BLK_M × BLK_N | 128×128, 128×256, 256×128 | Larger tiles → more compute per gmem load (better arithmetic intensity). But more smem usage → fewer CTAs per SM → less occupancy. |
| BLK_K | 32, 64 | Larger K → fewer mainloop iterations → less loop overhead. But more smem per tile (BLK_M × BLK_K × sizeof(half) per matrix). |
| Threads | 128 (4 warps), 256 (8 warps) | More threads → more copy bandwidth, but more overhead dividing the MMA repetitions across warps. |
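The first row's trade-off can be made concrete with this tutorial's numbers. A back-of-the-envelope sketch:

constexpr long long FLOPS_PER_K_ITER = 2LL * BLK_M * BLK_N * BLK_K;   // 1,048,576 FLOPs
constexpr long long BYTES_PER_K_ITER = (BLK_M + BLK_N) * BLK_K * 2LL; // 16,384 bytes of half
static_assert(FLOPS_PER_K_ITER / BYTES_PER_K_ITER == 64, "64 FLOPs per byte loaded");
// Doubling BLK_M and BLK_N quadruples the FLOPs but only doubles the bytes:
// 2x the arithmetic intensity, at the cost of 2x the smem footprint.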
Why the swizzle parameters are <3, 3, 3> for half:
- M=3: 2^3 = 8 halfs = 128 bits stay contiguous. Matches the 128-bit vectorized copy.
- B=3: 2^3 = 8 banks scrambled. Breaks the column-stride conflict pattern.
- S=3: source bits [6:9), target bits [3:6). Non-overlapping.
Compare with Tutorial 05's Swizzle<3, 2, 3> for float: M=2 (4 floats = 128 bits). The only difference is the free-bit count M, which must match the element size for 128-bit vectorization.
Why two __syncthreads() in the mainloop:
Iteration k:                      Iteration k+1:
  copy gmem → smem                  copy gmem → smem   ← OVERWRITES smem!
  ──sync── (1)                      ──sync──
  gemm reads smem                   gemm reads smem
  ──sync── (2)                      ──sync──
Sync (1): Ensure all threads finished writing smem before any thread reads it for gemm. Sync (2): Ensure all threads finished reading smem before the next iteration's copy overwrites it.
Without sync (2), a fast thread could start writing the (k+1)-th A tile into smem while a slow thread is still reading the k-th A tile for its gemm.
The epilogue bottleneck:
Our epilogue uses copy(accum, tCgC): each thread writes its C fragments directly to global memory with scattered stores. On A100 this wastes memory bus bandwidth, because each small scattered store occupies a full memory transaction, most of whose bytes go unused.
CUTLASS's production epilogues:
- Each warp writes its C fragment to smem (in a contiguous layout).
- Threads read back from smem in a coalesced pattern and write to gmem.
- Optional: fuse element-wise operations (bias add, ReLU) into the epilogue.
For this tutorial, the simple epilogue is correct and clear. Optimizing it is a straightforward extension.
Performance expectations:
This kernel will not match cuBLAS. Missing optimizations:
- No double-buffering (the copy and compute phases are serialized)
- No software pipelining (Tutorial 10)
- No TMA (Tutorial 08)
- Scattered C epilogue
- No predication for boundary tiles
On A100 with 512×512×256, expect 5–15 TFLOPS (vs. ~150 TFLOPS peak for cuBLAS with large matrices). The point is not speed; it's understanding the structure that every fast GEMM kernel shares.
Gotcha: dynamic vs. static shapes. The global tensors use dynamic shapes (runtime M, N, K), but the smem layouts and tile shapes use static `Int<N>{}`. This is intentional: CuTe needs static shapes to generate efficient code (compile-time layout math, unrolled loops). Dynamic shapes in gmem are fine because the copy and MMA operate on the statically shaped smem/register tiles.
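Mixing the two in a single shape is routine. A one-line sketch using this tutorial's names:

auto shp = make_shape(Int<BLK_M>{}, K); // Int<128>{} is compile-time; K is a runtime int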
Gotcha: smem capacity. Two smem tiles of 128×32 halfs = 2 × 128 × 32 × 2 bytes = 16 KB, well within SM80's 164 KB limit. But 128×128 tiles with BLK_K = 64 would need 32 KB: still fine, though it leaves less room for double buffering. Always check:
smem_bytes = 2 × BLK_M × BLK_K × sizeof(half).
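A compile-time guard is cheap. A sketch (note that static __shared__ arrays are capped at 48 KB; larger tiles need dynamic shared memory plus the cudaFuncAttributeMaxDynamicSharedMemorySize opt-in):

constexpr int SMEM_BYTES = 2 * BLK_M * BLK_K * (int)sizeof(half); // A + B tiles = 16 KB
static_assert(SMEM_BYTES <= 48 * 1024, "won't fit in static shared memory");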
What comes next: This kernel serializes copy and compute, so the Tensor Cores sit idle while data loads. Tutorial 08 (The TMA Revolution) introduces hardware-accelerated copies that free the threads. Tutorial 10 (Pipeline Overlap) shows how to overlap copy and compute using ping-pong buffering, approaching peak throughput.