diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
--- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -238,7 +238,6 @@
 nvgpu::getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
                                       const LdMatrixParams &params) {
   // One thread per 128b row.
-  const int64_t kNumThreadsPerTile = kNumRowsPerTile;
   const int bitsPerElement = static_cast<int>(
       params.fragmentType.getElementType().getIntOrFloatBitWidth());
   const int kElementsPer128b = (128 / bitsPerElement);
@@ -249,27 +248,28 @@
     return AffineMap::get(1, 0, dimExprs, builder.getContext());
   };
-  // This case corresponds to row-major A|C or col-major B operands.
-  if (params.contiguousDimType == vector::IteratorType::reduction) {
-    AffineExpr row = d0 % (operandShape[0]);
-    AffineExpr col = d0.floorDiv(operandShape[0]) * (kElementsPer128b);
-    return makeMap({row, col});
-  }
+  // Index `idx` selects the dimension of the vector type `operandShape` that
+  // maps to the strided dimension of the LdMatrixOp's `srcMemref` memory.
+  int idx =
+      (params.contiguousDimType == vector::IteratorType::reduction) ? 0 : 1;
+
+  // The affine expressions for the strided and contiguous dimensions encode
+  // the coordinate of the element each lane points to in the warp-wide
+  // LdMatrixOp.
+  AffineExpr strided = d0 % (operandShape[idx]);
+  AffineExpr contiguous = d0.floorDiv(operandShape[idx]) * (kElementsPer128b);
+
+  // This case corresponds to row-major matrixA, col-major matrixB, or
+  // row-major matrixC, i.e. the memory layout in `srcMemref` matches the
+  // mma.sync hardware vector register operand layout.
+  if (params.contiguousDimType == vector::IteratorType::reduction)
+    return makeMap({strided, contiguous});
+
+  // This case corresponds to col-major matrixA, row-major matrixB, or
+  // col-major matrixC, i.e. the memory layout in `srcMemref` does not match
+  // the mma.sync hardware vector register operand layout.
+  if (params.contiguousDimType == vector::IteratorType::parallel)
+    return makeMap({contiguous, strided});
-  // This case Corresponds to col-major A|C or row-major B operands. The
-  // operandShape given is already pre-transposed (e.g. 8x16 = KxN).
-  if (params.contiguousDimType == vector::IteratorType::parallel) {
-    const int64_t num8x128bCols = (operandShape[0] * bitsPerElement) / 128;
-    // Threads are assigned in groups of 8 first across columns, then to
-    // rows. This is transpose of what `ldmatrix` expects, but when
-    // `ldmatrix` gets the `.trans` qualifier, final the effect will be to
-    // transpose just the blocks.
-    auto groupIdx = d0.floorDiv(kNumThreadsPerTile);
-    auto tileCol = (groupIdx % num8x128bCols);
-    auto tileRow = groupIdx.floorDiv(num8x128bCols);
-    return makeMap({tileCol * kElementsPer128b,
-                    tileRow * kNumRowsPerTile + (d0 % kNumRowsPerTile)});
-  }
   return failure();
 }
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
@@ -4,8 +4,8 @@
 // INT8 row-row-row
 //#########################################################
-// CHECK-DAG: [[$rowA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
-// CHECK-DAG: [[$colA0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 16 + 1)>
+// CHECK-DAG: [[$strided_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK-DAG: [[$contiguous_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 16)>
 // CHECK-DAG: [[$rowB0_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 39)>
 // CHECK-DAG: [[$colB0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 40)>
@@ -40,14 +40,15 @@
   %cst = arith.constant 0 : i8
   %cst0 = arith.constant 0 : i32
-  // Verify that the operand A is distributed to loads correctly.
+  // Verify that the operandA load is lowered to a warp-wide ldmatrix.
-  // CHECK: [[row:%.+]] = affine.apply [[$rowA0_map]]()[{{%.+}}]
-  // CHECK: [[col:%.+]] = affine.apply [[$colA0_map]]()[{{%.+}}]
-  // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<128x128xi8, 3> -> vector<4x4xi8>
+  // CHECK: [[m_coord:%.+]] = affine.apply [[$strided_map]]()[{{%.+}}]
+  // CHECK: [[k_coord:%.+]] = affine.apply [[$contiguous_map]]()[{{%.+}}]
+  // CHECK: nvgpu.ldmatrix %arg0[[[m_coord]], [[k_coord]]] {numTiles = 4 : i32, transpose = false} : memref<128x128xi8, 3> -> vector<4x4xi8>
-  // Verify that the operand B is distributed to loads correctly. It's elements
-  // must be loaded in a non-vectorized manner to do the transpose.
+  // Verify that the operandB load is lowered to scalar loads so the transpose
+  // can be done at 8-bit granularity; ldmatrix can only transpose at
+  // 16-bit granularity.
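
The new mapping in the MMAUtils.cpp hunk above is easiest to see with concrete numbers. The sketch below is a standalone illustration, not part of the patch: it hard-codes the ldmatrix x4 case of a 16 x 16 f16 operand tile, where bitsPerElement = 16, kElementsPer128b = 8, and operandShape[idx] = 16, and prints the (strided, contiguous) coordinate each lane points to. Lanes 0..15 land on rows 0..15 at column 0 and lanes 16..31 land on rows 0..15 at column 8, which is exactly the map pair (s0 mod 16) and ((s0 floordiv 16) * 8) that the updated CHECK lines in the f16 tests expect.

// Standalone illustration of the lane -> (strided, contiguous) mapping
// computed by getLaneIdToLdMatrixMatrixCoord. The constants below are
// assumed values for the ldmatrix x4, 16x16 f16 case, not code from the patch.
#include <cstdio>

int main() {
  const int bitsPerElement = 16;                     // f16 operand
  const int kElementsPer128b = 128 / bitsPerElement; // 8 elements per 128b row
  const int stridedDimSize = 16;                     // operandShape[idx]
  for (int laneId = 0; laneId < 32; ++laneId) {
    // Mirrors `d0 % operandShape[idx]` and
    // `d0.floorDiv(operandShape[idx]) * kElementsPer128b`.
    const int strided = laneId % stridedDimSize;
    const int contiguous = (laneId / stridedDimSize) * kElementsPer128b;
    std::printf("lane %2d -> (%2d, %2d)\n", laneId, strided, contiguous);
  }
  return 0;
}
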
   // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB0_map]]()[{{%.+}}]
   // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
@@ -84,7 +85,7 @@
   // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
   // CHECK-NOT: vector.load %arg2{{.*}}
-  %A = vector.transfer_read %arg0[%c1, %c1], %cst {in_bounds = [true, true]} : memref<128x128xi8, 3>, vector<16x32xi8>
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x128xi8, 3>, vector<16x32xi8>
   %B = vector.transfer_read %arg1[%c39, %c40], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<128x128xi8, 3>, vector<8x32xi8>
   %C = vector.transfer_read %arg2[%c49, %c40], %cst0 {in_bounds = [true, true]} : memref<128x128xi32>, vector<16x8xi32>
   // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
@@ -173,28 +174,23 @@
 #map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
-// CHECK-DAG: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
-// CHECK-DAG: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
-
-// CHECK-DAG: [[$rowB_map:#.+]] = affine_map<()[s0] -> (s0 + 3)>
-// CHECK-DAG: [[$colB_map:#.+]] = affine_map<() -> (3)>
+// CHECK-DAG: [[$strided_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK-DAG: [[$contiguous_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8)>
 // CHECK-LABEL: func @m16n8k16_fp16_row_row_row
 func.func @m16n8k16_fp16_row_row_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<20x20xf16, 3>, %arg2: memref<20x20xf16, 3>) {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
   %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c3 = arith.constant 3 : index
   %cst = arith.constant 0.000000e+00 : f16
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
-  // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
-  // CHECK: nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32, transpose = true}
-  %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16>
-  %B = vector.transfer_read %arg1[%c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16>
+  // CHECK-DAG: [[m_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK: nvgpu.ldmatrix %arg0[[[m_coord]], [[k_coord]]] {numTiles = 4 : i32, transpose = false}
+  // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK: nvgpu.ldmatrix %arg1[[[k_coord]], [[n_coord]]] {numTiles = 2 : i32, transpose = true}
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16>
   %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x8xf16>
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
   vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<20x20xf16, 3>
@@ -207,10 +203,8 @@
 // FP16 row-row-row (ldmatrix x4 for matrixA and ldmatrix x4 for matrixB)
 //#########################################################################
-// CHECK-DAG: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
-// CHECK-DAG: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8)>
-// CHECK-DAG: [[$rowB_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 8 - ((s0 floordiv 8) floordiv 2) * 16)>
-// CHECK-DAG: [[$colB_map:#.+]] = affine_map<()[s0] -> (s0 mod 8 + ((s0 floordiv 8) floordiv 2) * 8)>
+// CHECK-DAG: [[$strided_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK-DAG: [[$contiguous_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8)>
 #map0 = affine_map<(d0, d1) -> (d1, d0)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
@@ -224,19 +218,19 @@
   %c8 = arith.constant 8 : index
   %cst = arith.constant 0.000000e+00 : f16
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
-  // CHECK: [[fragmentA:%.+]] = nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}
+  // CHECK-DAG: [[m_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK: [[fragmentA:%.+]] = nvgpu.ldmatrix %arg0[[[m_coord]], [[k_coord]]] {numTiles = 4 : i32, transpose = false}
   %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<42x32xf16, 3>, vector<16x16xf16>
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
-  // CHECK-DAG: [[fragmentB:%.+]] = nvgpu.ldmatrix %arg1[[[col]], [[row]]] {numTiles = 4 : i32, transpose = true}
+  // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK-DAG: [[fragmentB:%.+]] = nvgpu.ldmatrix %arg1[[[k_coord]], [[n_coord]]] {numTiles = 4 : i32, transpose = true}
   %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<32x64xf16, 3>, vector<16x16xf16>
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
-  // CHECK-DAG: [[fragmentC:%.*]] = nvgpu.ldmatrix %arg2[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}
+  // CHECK-DAG: [[m_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK-DAG: [[fragmentC:%.*]] = nvgpu.ldmatrix %arg2[[[m_coord]], [[n_coord]]] {numTiles = 4 : i32, transpose = false}
   %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<42x64xf16, 3>, vector<16x16xf16>
   // CHECK-DAG: [[fragmentB0:%.+]] = vector.extract_strided_slice [[fragmentB]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
@@ -259,10 +253,8 @@
 }
 // -----
-// CHECK-DAG: [[$Arow_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
-// CHECK-DAG: [[$Acol_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
-// CHECK-DAG: [[$Bcol_map:#.+]] = affine_map<() -> (3)>
-// CHECK-DAG: [[$Brow_map:#.+]] = affine_map<()[s0] -> (s0 + 3)>
+// CHECK-DAG: [[$strided_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK-DAG: [[$contiguous_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8)>
 #map0 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
@@ -274,26 +266,24 @@
   %cst_0 = arith.constant dense<0.000000e+00> : vector<20x20xf16>
   // CHECK: [[C0:%.+]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c3 = arith.constant 3 : index
   %cst = arith.constant 0.000000e+00 : f16
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$Arow_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$Acol_map]]
-  // CHECK: nvgpu.ldmatrix %arg0[[[C0]], [[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<4x2xf16>
-  %A = vector.transfer_read %arg0[%c0, %c1, %c3], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x16xf16>
+  // CHECK-DAG: [[m_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK: nvgpu.ldmatrix %arg0[[[C0]], [[m_coord]], [[k_coord]]] {numTiles = 4 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<4x2xf16>
+  %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x16xf16>
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$Brow_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$Bcol_map]]
-  // CHECK: nvgpu.ldmatrix %arg1[[[C0]], [[row]], [[col]]] {numTiles = 2 : i32, transpose = true} : memref<2x20x20xf16, 3> -> vector<2x2xf16>
-  %B = vector.transfer_read %arg1[%c0, %c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<8x16xf16>
+  // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK: nvgpu.ldmatrix %arg1[[[C0]], [[k_coord]], [[n_coord]]] {numTiles = 2 : i32, transpose = true} : memref<2x20x20xf16, 3> -> vector<2x2xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<8x16xf16>
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$Arow_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$Acol_map]]
-  // CHECK: nvgpu.ldmatrix %arg2[[[C0]], [[row]], [[col]]] {numTiles = 2 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<2x2xf16>
-  %C = vector.transfer_read %arg2[%c0, %c1, %c3], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x8xf16>
+  // CHECK-DAG: [[m_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK: nvgpu.ldmatrix %arg2[[[C0]], [[m_coord]], [[n_coord]]] {numTiles = 2 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<2x2xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x8xf16>
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
-  vector.transfer_write %D, %arg2[%c0, %c1, %c3] {in_bounds = [true, true]} : vector<16x8xf16>, memref<2x20x20xf16, 3>
+  vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<2x20x20xf16, 3>
   return
 }
@@ -307,36 +297,36 @@
 #map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
-// CHECK: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
-// CHECK: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
+// Affine maps for ldmatrix x4 tile of `16 x 16` f16 elements in `strided x contiguous` dimensions.
+// CHECK: [[$strided_ldmatrix_x4_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK: [[$contiguous_ldmatrix_x4_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8)>
-// CHECK: [[$rowB_map:#.+]] = affine_map<()[s0] -> (s0 mod 8 + 1)>
-// CHECK: [[$colB_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 8 + 3)>
+// CHECK: [[$strided_ldmatrix_x2_map:#.+]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK: [[$contiguous_ldmatrix_x2_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 8)>
 // CHECK-LABEL: func @m16n8k16_fp16_row_col_row
 func.func @m16n8k16_fp16_row_col_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<20x20xf16, 3>, %arg2: memref<20x20xf16, 3>) {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
   %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c3 = arith.constant 3 : index
+
   %cst = arith.constant 0.000000e+00 : f16
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
-  // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32
+  // CHECK-DAG: [[m_coord:%.+]] = affine.apply [[$strided_ldmatrix_x4_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$contiguous_ldmatrix_x4_map]]
+  // CHECK: nvgpu.ldmatrix %arg0[[[m_coord]], [[k_coord]]] {numTiles = 4 : i32
   // CHECK-SAME: transpose = false
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
-  // CHECK: nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32
+  // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$strided_ldmatrix_x2_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$contiguous_ldmatrix_x2_map]]
+  // CHECK: nvgpu.ldmatrix %arg1[[[n_coord]], [[k_coord]]] {numTiles = 2 : i32
   // CHECK-SAME: transpose = false
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
-  // CHECK: nvgpu.ldmatrix %arg2[[[row]], [[col]]] {numTiles = 2 : i32
+  // CHECK-DAG: [[m_coord:%.+]] = affine.apply [[$strided_ldmatrix_x4_map]]
+  // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$contiguous_ldmatrix_x4_map]]
+  // CHECK: nvgpu.ldmatrix %arg2[[[m_coord]], [[n_coord]]] {numTiles = 2 : i32
   // CHECK-SAME: transpose = false
-  %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16>
-  %B = vector.transfer_read %arg1[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16>
-  %C = vector.transfer_read %arg2[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x8xf16>
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x8xf16>
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
   vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<20x20xf16, 3>
   return
@@ -345,7 +335,7 @@
 // -----
 //#########################################################
-// TF32 (multiplicand) F32 (accumulator) row-row-row
+// TF32 row-row-row
 //#########################################################
 #map0 = affine_map<(d0, d1) -> (d1, d0)>
@@ -406,6 +396,9 @@
 // -----
+//#########################################################
+// TF32 row-row-row
+//#########################################################
 #map0 = affine_map<(d0, d1) -> (d1, d0)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -467,13 +460,88 @@
 // -----
 //#########################################################
-// INT4 row-col-row
+// TF32 col-col-row
 //#########################################################
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-DAG: [[$rowA0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4)>
+// CHECK-DAG: [[$colA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 4)>
+// CHECK-DAG: [[$rowA8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 8)>
+// CHECK-DAG: [[$colA4_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 4)>
-// CHECK-DAG: [[$rowA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
-// CHECK-DAG: [[$colA0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 32)>
 // CHECK-DAG: [[$rowB0_map:#.+]] = affine_map<()[s0] -> (s0 mod 8)>
-// CHECK-DAG: [[$colB0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 32)>
+// CHECK-DAG: [[$colB0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 4)>
+
+// CHECK-DAG: [[$rowC_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 16)>
+// CHECK-DAG: [[$rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 24)>
+// CHECK-DAG: [[$colC_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8 + 8)>
+
+// CHECK-LABEL: func @m16n8k8_tf32_f32_col_col_row
+func.func @m16n8k8_tf32_f32_col_col_row(%arg0: memref<20x20xf32, 3>, %arg1: memref<20x20xf32, 3>, %arg2: memref<20x20xf32>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf32>
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c8 = arith.constant 8 : index
+  %cst = arith.constant 0.000000e+00 : f32
+
+  // CHECK: [[c_frag:%.+]] = arith.constant {{.*}} : vector<2x2xf32>
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA0_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA0_map]]
+  // CHECK: [[a_el0:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
+  // CHECK: [[a_frag0:%.+]] = vector.insert [[a_el0]], {{.*}} [0, 0] : f32 into vector<4x1xf32>
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA8_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA0_map]]
+  // CHECK: [[a_el0:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
+  // CHECK: [[a_frag0:%.+]] = vector.insert [[a_el0]], {{.*}} [1, 0] : f32 into vector<4x1xf32>
+
+  // CHECK: [[a_el:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
+  // CHECK: [[a_frag:%.+]] = vector.insert [[a_el]], {{.*}} [2, 0] : f32 into vector<4x1xf32>
+  // CHECK: [[a_el:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
+  // CHECK: [[a_frag:%.+]] = vector.insert [[a_el]], {{.*}} [3, 0] : f32 into vector<4x1xf32>
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB0_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]
+  // CHECK: [[b_frag:%.+]] = nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32, transpose = false}
+
+  // CHECK: [[d_frag:%.+]] = nvgpu.mma.sync([[a_frag]], [[b_frag]], [[c_frag]])
+  // CHECK-SAME: mmaShape = [16, 8, 8]
+  // CHECK-SAME: -> vector<2x2xf32>
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<20x20xf32, 3>, vector<16x8xf32>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf32, 3>, vector<8x8xf32>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x8xf32>, vector<8x8xf32> into vector<16x8xf32>
+
+  // CHECK: vector.extract [[d_frag]][0] : vector<2x2xf32>
+  // CHECK: affine.apply [[$rowC_map]]
+  // CHECK: affine.apply [[$colC_map]]
+  // CHECK: vector.store
+  // CHECK: vector.extract [[d_frag]][1] : vector<2x2xf32>
+  // CHECK: affine.apply [[$rowC8_map]]
+  // CHECK: affine.apply [[$colC_map]]
+  // CHECK: vector.store
+  vector.transfer_write %D, %arg2[%c16, %c8] {in_bounds = [true, true]} : vector<16x8xf32>, memref<20x20xf32>
+  return
+}
+
+// -----
+
+//#########################################################
+// INT4 row-col-row
+//#########################################################
+// Affine maps for loading operandA and operandB: they map a laneid to the
+// coordinate that lane points to in the ldmatrix operand tile.
+// CHECK-DAG: [[$strided_ldmatrix_x4_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK-DAG: [[$contiguous_ldmatrix_x4_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 32)>
+// CHECK-DAG: [[$strided_ldmatrix_x2_map:#.+]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-DAG: [[$contiguous_ldmatrix_x2_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 32)>
+
+// Affine maps for the accumulator registers: they map a laneid to the
+// coordinate that lane points to in the accumulator register tile.
 // CHECK-DAG: [[$rowC0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4)>
 // CHECK-DAG: [[$colC0_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8
 // CHECK-DAG: [[$rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 8)>
@@ -490,14 +558,14 @@
   %c0 = arith.constant 0 : index
   // CHECK: [[lane:%.+]] = gpu.lane_id
-  // CHECK: [[row:%.+]] = affine.apply [[$rowA0_map]]()[[[lane]]]
-  // CHECK: [[col:%.+]] = affine.apply [[$colA0_map]]()[[[lane]]]
-  // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<128x128xi4, 3> -> vector<4x8xi4>
+  // CHECK: [[m_coord:%.+]] = affine.apply [[$strided_ldmatrix_x4_map]]()[[[lane]]]
+  // CHECK: [[k_coord:%.+]] = affine.apply [[$contiguous_ldmatrix_x4_map]]()[[[lane]]]
+  // CHECK: nvgpu.ldmatrix %arg0[[[m_coord]], [[k_coord]]] {numTiles = 4 : i32, transpose = false} : memref<128x128xi4, 3> -> vector<4x8xi4>
   // CHECK: [[lane:%.+]] = gpu.lane_id
-  // CHECK: [[row:%.+]] = affine.apply [[$rowB0_map]]()[[[lane]]]
-  // CHECK: [[col:%.+]] = affine.apply [[$colB0_map]]()[[[lane]]]
-  // CHECK: nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32, transpose = false} : memref<128x128xi4, 3> -> vector<2x8xi4>
+  // CHECK: [[n_coord:%.+]] = affine.apply [[$strided_ldmatrix_x2_map]]()[[[lane]]]
+  // CHECK: [[k_coord:%.+]] = affine.apply [[$contiguous_ldmatrix_x2_map]]()[[[lane]]]
+  // CHECK: nvgpu.ldmatrix %arg1[[[n_coord]], [[k_coord]]] {numTiles = 2 : i32, transpose = false} : memref<128x128xi4, 3> -> vector<2x8xi4>
   // CHECK: [[lane:%.+]] = gpu.lane_id
   // CHECK: [[row:%.+]] = affine.apply [[$rowC0_map]]()[{{%.+}}]
@@ -534,12 +602,15 @@
 //#########################################################
 // INT8 row-col-row
 //#########################################################
-
-// CHECK-DAG: [[$rowA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
-// CHECK-DAG: [[$colA0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 16)>
-// CHECK-DAG: [[$rowB0_map:#.+]] = affine_map<()[s0] -> (s0 mod 8)>
-// CHECK-DAG: [[$colB0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 16)>
-
+// Affine maps for loading operandA and operandB: they map a laneid to the
+// coordinate that lane points to in the ldmatrix operand tile.
+// CHECK-DAG: [[$strided_ldmatrix_x4_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK-DAG: [[$contiguous_ldmatrix_x4_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 16)>
+// CHECK-DAG: [[$strided_ldmatrix_x2_map:#.+]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-DAG: [[$contiguous_ldmatrix_x2_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 16)>
+
+// Affine maps for the accumulator registers: they map a laneid to the
+// coordinate that lane points to in the accumulator register tile.
 // CHECK-DAG: [[$rowC0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4)>
 // CHECK-DAG: [[$colC0_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8)>
 // CHECK-DAG: [[$rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 8)>
@@ -554,27 +625,26 @@
 func.func @m16n8k32_int8_row_col_row(%arg0: memref<128x128xi8, 3>, %arg1: memref<128x128xi8, 3>, %arg2: memref<128x128xi32>) {
   %cst_0 = arith.constant dense<0> : vector<32x8xi8>
   %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
   %cst = arith.constant 0 : i8
   %cst0 = arith.constant 0 : i32
   // CHECK: [[lane:%.+]] = gpu.lane_id
-  // CHECK: [[row:%.+]] = affine.apply [[$rowA0_map]]()[[[lane]]]
-  // CHECK: [[col:%.+]] = affine.apply [[$colA0_map]]()[[[lane]]]
-  // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<128x128xi8, 3> -> vector<4x4xi8>
+  // CHECK: [[m_coord:%.+]] = affine.apply [[$strided_ldmatrix_x4_map]]()[[[lane]]]
+  // CHECK: [[k_coord:%.+]] = affine.apply [[$contiguous_ldmatrix_x4_map]]()[[[lane]]]
+  // CHECK: nvgpu.ldmatrix %arg0[[[m_coord]], [[k_coord]]] {numTiles = 4 : i32, transpose = false} : memref<128x128xi8, 3> -> vector<4x4xi8>
   // CHECK: [[lane:%.+]] = gpu.lane_id
-  // CHECK: [[row:%.+]] = affine.apply [[$rowB0_map]]()[[[lane]]]
-  // CHECK: [[col:%.+]] = affine.apply [[$colB0_map]]()[[[lane]]]
-  // CHECK: nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32, transpose = false} : memref<128x128xi8, 3> -> vector<2x4xi8>
+  // CHECK: [[n_coord:%.+]] = affine.apply [[$strided_ldmatrix_x2_map]]()[[[lane]]]
+  // CHECK: [[k_coord:%.+]] = affine.apply [[$contiguous_ldmatrix_x2_map]]()[[[lane]]]
+  // CHECK: nvgpu.ldmatrix %arg1[[[n_coord]], [[k_coord]]] {numTiles = 2 : i32, transpose = false} : memref<128x128xi8, 3> -> vector<2x4xi8>
   // CHECK: [[lane:%.+]] = gpu.lane_id
-  // CHECK: [[row:%.+]] = affine.apply [[$rowC0_map]]()[[[lane]]]
-  // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[[[lane]]]
-  // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
-  // CHECK: [[row:%.+]] = affine.apply [[$rowC8_map]]()[[[lane]]]
-  // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[[[lane]]]
-  // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+  // CHECK: [[m_coord:%.+]] = affine.apply [[$rowC0_map]]()[[[lane]]]
+  // CHECK: [[n_coord:%.+]] = affine.apply [[$colC0_map]]()[[[lane]]]
+  // CHECK: vector.load %arg2[[[m_coord]], [[n_coord]]] : memref<128x128xi32>, vector<2xi32>
+  // CHECK: [[m_coord:%.+]] = affine.apply [[$rowC8_map]]()[[[lane]]]
+  // CHECK: [[n_coord:%.+]] = affine.apply [[$colC0_map]]()[[[lane]]]
+  // CHECK: vector.load %arg2[[[m_coord]], [[n_coord]]] : memref<128x128xi32>, vector<2xi32>
   // CHECK-NOT: vector.load %arg2{{.*}}
   %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x128xi8, 3>, vector<16x32xi8>
@@ -595,73 +665,3 @@
   vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xi32>, memref<128x128xi32>
   return
 }
-
-// -----
-
-#map0 = affine_map<(d0, d1) -> (d1, d0)>
-#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
-#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
-
-// CHECK-DAG: [[$rowA0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4)>
-// CHECK-DAG: [[$colA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 4)>
-// CHECK-DAG: [[$rowA8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 8)>
-// CHECK-DAG: [[$colA4_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 4)>
-
-// CHECK-DAG: [[$rowB0_map:#.+]] = affine_map<()[s0] -> (s0 mod 8)>
-// CHECK-DAG: [[$colB0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 4)>
-
-// CHECK-DAG: [[$rowC_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 16)>
-// CHECK-DAG: [[$rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 24)>
-// CHECK-DAG: [[$colC_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8 + 8)>
-
-// CHECK-LABEL: func @m16n8k8_tf32_f32_col_col_row
-func.func @m16n8k8_tf32_f32_col_col_row(%arg0: memref<20x20xf32, 3>, %arg1: memref<20x20xf32, 3>, %arg2: memref<20x20xf32>) {
-  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf32>
-  %c0 = arith.constant 0 : index
-  %c16 = arith.constant 16 : index
-  %c8 = arith.constant 8 : index
-  %c1 = arith.constant 1 : index
-  %c3 = arith.constant 3 : index
-  %cst = arith.constant 0.000000e+00 : f32
-
-  // CHECK: [[c_frag:%.+]] = arith.constant {{.*}} : vector<2x2xf32>
-
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA0_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA0_map]]
-  // CHECK: [[a_el0:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
-  // CHECK: [[a_frag0:%.+]] = vector.insert [[a_el0]], {{.*}} [0, 0] : f32 into vector<4x1xf32>
-
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA8_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA0_map]]
-  // CHECK: [[a_el0:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
-  // CHECK: [[a_frag0:%.+]] = vector.insert [[a_el0]], {{.*}} [1, 0] : f32 into vector<4x1xf32>
-
-  // CHECK: [[a_el:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
-  // CHECK: [[a_frag:%.+]] = vector.insert [[a_el]], {{.*}} [2, 0] : f32 into vector<4x1xf32>
-  // CHECK: [[a_el:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
-  // CHECK: [[a_frag:%.+]] = vector.insert [[a_el]], {{.*}} [3, 0] : f32 into vector<4x1xf32>
-
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB0_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]
-  // CHECK: [[b_frag:%.+]] = nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32, transpose = false}
-
-  // CHECK: [[d_frag:%.+]] = nvgpu.mma.sync([[a_frag]], [[b_frag]], [[c_frag]])
-  // CHECK-SAME: mmaShape = [16, 8, 8]
-  // CHECK-SAME: -> vector<2x2xf32>
-  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<20x20xf32, 3>, vector<16x8xf32>
-  %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf32, 3>, vector<8x8xf32>
-  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
-                        kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x8xf32>, vector<8x8xf32> into vector<16x8xf32>
-
-  // CHECK: vector.extract [[d_frag]][0] : vector<2x2xf32>
-  // CHECK: affine.apply [[$rowC_map]]
-  // CHECK: affine.apply [[$colC_map]]
-  // CHECK: vector.store
-  // CHECK: vector.extract [[d_frag]][1] : vector<2x2xf32>
-  // CHECK: affine.apply [[$rowC8_map]]
-  // CHECK: affine.apply [[$colC_map]]
-  // CHECK: vector.store
-  vector.transfer_write %D, %arg2[%c16, %c8] {in_bounds = [true, true]} : vector<16x8xf32>, memref<20x20xf32>
-  return
-}
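
The ldmatrix_x4 / ldmatrix_x2 map pairs checked by the INT4 and INT8 row-col-row tests all follow from the same lane-to-coordinate formula; only the tile height and the element width change. As a standalone sketch (assumed values restating the formula, not code from the patch): for i8, kElementsPer128b = 16, so the x4 tile (16 rows of 128b) yields (s0 mod 16, (s0 floordiv 16) * 16) and the x2 tile (8 rows) yields (s0 mod 8, (s0 floordiv 8) * 16); for i4 the contiguous step becomes 32.

// Standalone sketch of the x4 vs. x2 ldmatrix lane maps checked by the
// INT8 row-col-row test. Tile heights and element widths are assumptions
// chosen to match the CHECK-DAG maps above, not values read from the patch.
#include <cstdio>

static void printTileMap(const char *name, int bitsPerElement, int numRows) {
  const int kElementsPer128b = 128 / bitsPerElement;
  std::printf("%s:\n", name);
  for (int laneId = 0; laneId < 32; ++laneId) {
    const int strided = laneId % numRows;                         // s0 mod numRows
    const int contiguous = (laneId / numRows) * kElementsPer128b; // floordiv then scale
    std::printf("  lane %2d -> (%2d, %2d)\n", laneId, strided, contiguous);
  }
}

int main() {
  printTileMap("i8 x4 tile (16 rows)", /*bitsPerElement=*/8, /*numRows=*/16);
  printTileMap("i8 x2 tile (8 rows)", /*bitsPerElement=*/8, /*numRows=*/8);
  return 0;
}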