Index: mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
===================================================================
--- mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -238,7 +238,6 @@
 nvgpu::getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
                                       const LdMatrixParams &params) {
   // One thread per 128b row.
-  const int64_t kNumThreadsPerTile = kNumRowsPerTile;
   const int bitsPerElement = static_cast<int>(
       params.fragmentType.getElementType().getIntOrFloatBitWidth());
   const int kElementsPer128b = (128 / bitsPerElement);
@@ -249,27 +248,28 @@
     return AffineMap::get(1, 0, dimExprs, builder.getContext());
   };
 
-  // This case corresponds to row-major A|C or col-major B operands.
-  if (params.contiguousDimType == vector::IteratorType::reduction) {
-    AffineExpr row = d0 % (operandShape[0]);
-    AffineExpr col = d0.floorDiv(operandShape[0]) * (kElementsPer128b);
-    return makeMap({row, col});
-  }
+  // Index `idx` into the vector type's `operandShape` selects the strided
+  // dimension of the LdMatrixOp's `srcMemref`.
+  int idx =
+      (params.contiguousDimType == vector::IteratorType::reduction) ? 0 : 1;
+
+  // Affine expressions in the strided and contiguous dimensions encode the
+  // coordinate of the element each thread points to in a warp-wide LdMatrixOp.
+  AffineExpr strided = d0 % (operandShape[idx]);
+  AffineExpr contiguous = d0.floorDiv(operandShape[idx]) * (kElementsPer128b);
+
+  // This case corresponds to row-major matrixA, col-major matrixB, or
+  // row-major matrixC. This is when the memory layout in `srcMemref`
+  // matches the mma.sync hardware vector register operand layout.
+  if (params.contiguousDimType == vector::IteratorType::reduction)
+    return makeMap({strided, contiguous});
+
+  // This case corresponds to col-major matrixA, row-major matrixB, or
+  // col-major matrixC. This is when the memory layout in `srcMemref` does not
+  // match the mma.sync hardware vector register operand layout.
+  if (params.contiguousDimType == vector::IteratorType::parallel)
+    return makeMap({contiguous, strided});
 
-  // This case Corresponds to col-major A|C or row-major B operands. The
-  // operandShape given is already pre-transposed (e.g. 8x16 = KxN).
-  if (params.contiguousDimType == vector::IteratorType::parallel) {
-    const int64_t num8x128bCols = (operandShape[0] * bitsPerElement) / 128;
-    // Threads are assigned in groups of 8 first across columns, then to
-    // rows. This is transpose of what `ldmatrix` expects, but when
-    // `ldmatrix` gets the `.trans` qualifier, final the effect will be to
-    // transpose just the blocks.
-    auto groupIdx = d0.floorDiv(kNumThreadsPerTile);
-    auto tileCol = (groupIdx % num8x128bCols);
-    auto tileRow = groupIdx.floorDiv(num8x128bCols);
-    return makeMap({tileCol * kElementsPer128b,
-                    tileRow * kNumRowsPerTile + (d0 % kNumRowsPerTile)});
-  }
 
   return failure();
 }
Index: mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
===================================================================
--- mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
+++ mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
@@ -173,28 +173,23 @@
 #map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
 
-// CHECK-DAG: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
-// CHECK-DAG: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
-
-// CHECK-DAG: [[$rowB_map:#.+]] = affine_map<()[s0] -> (s0 + 3)>
-// CHECK-DAG: [[$colB_map:#.+]] = affine_map<() -> (3)>
+// CHECK-DAG: [[$strided_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK-DAG: [[$contiguous_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8)>
 
 // CHECK-LABEL: func @m16n8k16_fp16_row_row_row
 func.func @m16n8k16_fp16_row_row_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<20x20xf16, 3>, %arg2: memref<20x20xf16, 3>) {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
   %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c3 = arith.constant 3 : index
   %cst = arith.constant 0.000000e+00 : f16
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
-  // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
-  // CHECK: nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32, transpose = true}
-  %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16>
-  %B = vector.transfer_read %arg1[%c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16>
+  // CHECK-DAG: [[m_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK: nvgpu.ldmatrix %arg0[[[m_coord]], [[k_coord]]] {numTiles = 4 : i32, transpose = false}
+  // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK: nvgpu.ldmatrix %arg1[[[k_coord]], [[n_coord]]] {numTiles = 2 : i32, transpose = true}
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16>
   %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x8xf16>
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
   vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<20x20xf16, 3>
@@ -207,10 +202,8 @@
 // FP16 row-row-row (ldmatrix x4 for matrixA and ldmatrix x4 for matrixB)
 //#########################################################################
 
-// CHECK-DAG: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
-// CHECK-DAG: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8)>
-// CHECK-DAG: [[$rowB_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 8 - ((s0 floordiv 8) floordiv 2) * 16)>
-// CHECK-DAG: [[$colB_map:#.+]] = affine_map<()[s0] -> (s0 mod 8 + ((s0 floordiv 8) floordiv 2) * 8)>
+// CHECK-DAG: [[$strided_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK-DAG: [[$contiguous_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8)>
 
 #map0 = affine_map<(d0, d1) -> (d1, d0)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
@@ -224,19 +217,19 @@
   %c8 = arith.constant 8 : index
   %cst = arith.constant 0.000000e+00 : f16
 
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
-  // CHECK: [[fragmentA:%.+]] = nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}
+  // CHECK-DAG: [[m_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK: [[fragmentA:%.+]] = nvgpu.ldmatrix %arg0[[[m_coord]], [[k_coord]]] {numTiles = 4 : i32, transpose = false}
   %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<42x32xf16, 3>, vector<16x16xf16>
 
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
-  // CHECK-DAG: [[fragmentB:%.+]] = nvgpu.ldmatrix %arg1[[[col]], [[row]]] {numTiles = 4 : i32, transpose = true}
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK-DAG: [[fragmentB:%.+]] = nvgpu.ldmatrix %arg1[[[k_coord]], [[n_coord]]] {numTiles = 4 : i32, transpose = true}
   %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<32x64xf16, 3>, vector<16x16xf16>
 
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
-  // CHECK-DAG: [[fragmentC:%.*]] = nvgpu.ldmatrix %arg2[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}
+  // CHECK-DAG: [[m_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK-DAG: [[fragmentC:%.*]] = nvgpu.ldmatrix %arg2[[[m_coord]], [[n_coord]]] {numTiles = 4 : i32, transpose = false}
   %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<42x64xf16, 3>, vector<16x16xf16>
 
   // CHECK-DAG: [[fragmentB0:%.+]] = vector.extract_strided_slice [[fragmentB]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
@@ -259,10 +252,8 @@
 }
 
 // -----
 
-// CHECK-DAG: [[$Arow_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
-// CHECK-DAG: [[$Acol_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
-// CHECK-DAG: [[$Bcol_map:#.+]] = affine_map<() -> (3)>
-// CHECK-DAG: [[$Brow_map:#.+]] = affine_map<()[s0] -> (s0 + 3)>
+// CHECK-DAG: [[$strided_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK-DAG: [[$contiguous_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8)>
 
 #map0 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
@@ -274,26 +265,24 @@
   %cst_0 = arith.constant dense<0.000000e+00> : vector<20x20xf16>
   // CHECK: [[C0:%.+]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c3 = arith.constant 3 : index
   %cst = arith.constant 0.000000e+00 : f16
 
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$Arow_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$Acol_map]]
-  // CHECK: nvgpu.ldmatrix %arg0[[[C0]], [[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<4x2xf16>
-  %A = vector.transfer_read %arg0[%c0, %c1, %c3], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x16xf16>
+  // CHECK-DAG: [[m_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK: nvgpu.ldmatrix %arg0[[[C0]], [[m_coord]], [[k_coord]]] {numTiles = 4 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<4x2xf16>
+  %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x16xf16>
 
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$Brow_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$Bcol_map]]
-  // CHECK: nvgpu.ldmatrix %arg1[[[C0]], [[row]], [[col]]] {numTiles = 2 : i32, transpose = true} : memref<2x20x20xf16, 3> -> vector<2x2xf16>
-  %B = vector.transfer_read %arg1[%c0, %c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<8x16xf16>
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK: nvgpu.ldmatrix %arg1[[[C0]], [[k_coord]], [[n_coord]]] {numTiles = 2 : i32, transpose = true} : memref<2x20x20xf16, 3> -> vector<2x2xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<8x16xf16>
 
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$Arow_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$Acol_map]]
-  // CHECK: nvgpu.ldmatrix %arg2[[[C0]], [[row]], [[col]]] {numTiles = 2 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<2x2xf16>
-  %C = vector.transfer_read %arg2[%c0, %c1, %c3], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x8xf16>
+  // CHECK-DAG: [[m_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK: nvgpu.ldmatrix %arg2[[[C0]], [[m_coord]], [[n_coord]]] {numTiles = 2 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<2x2xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x8xf16>
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
-  vector.transfer_write %D, %arg2[%c0, %c1, %c3] {in_bounds = [true, true]} : vector<16x8xf16>, memref<2x20x20xf16, 3>
+  vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<2x20x20xf16, 3>
   return
 }
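
For cross-checking the new mapping, the following is a minimal standalone C++ sketch of the lane-to-coordinate computation that the patched getLaneIdToLdMatrixMatrixCoord builds symbolically as affine maps. The helper name laneIdToLdMatrixCoord and the flag contiguousIsReduction are hypothetical stand-ins for the MLIR types and enums; this is not the MLIR API, just a numeric mirror of the patched logic.

#include <cstdint>
#include <cstdio>
#include <utility>

// Returns the (row, col) memref coordinate that lane `laneId` supplies to a
// warp-wide ldmatrix, for an operand of shape `operandShape` whose elements
// are `bitsPerElement` bits wide. `contiguousIsReduction` plays the role of
// params.contiguousDimType == vector::IteratorType::reduction in the patch.
std::pair<int64_t, int64_t>
laneIdToLdMatrixCoord(int64_t laneId, const int64_t operandShape[2],
                      int bitsPerElement, bool contiguousIsReduction) {
  const int64_t kElementsPer128b = 128 / bitsPerElement;
  // `idx` selects the strided dimension of the source memref.
  const int idx = contiguousIsReduction ? 0 : 1;
  const int64_t strided = laneId % operandShape[idx];
  const int64_t contiguous = (laneId / operandShape[idx]) * kElementsPer128b;
  // When the memory layout already matches the mma.sync register layout
  // (row-major A, col-major B, row-major C), emit (strided, contiguous);
  // otherwise swap the two coordinates.
  return contiguousIsReduction ? std::make_pair(strided, contiguous)
                               : std::make_pair(contiguous, strided);
}

int main() {
  // 16x16 f16 operand A, row-major: reproduces the affine maps checked in
  // the updated tests, (s0 mod 16) and ((s0 floordiv 16) * 8).
  const int64_t shapeA[2] = {16, 16};
  for (int64_t lane = 0; lane < 32; ++lane) {
    auto [row, col] = laneIdToLdMatrixCoord(lane, shapeA, /*bitsPerElement=*/16,
                                            /*contiguousIsReduction=*/true);
    std::printf("lane %2lld -> (%lld, %lld)\n", (long long)lane,
                (long long)row, (long long)col);
  }
  return 0;
}

Run on lanes 0-31 this prints lanes 0-15 at coordinates (0..15, 0) and lanes 16-31 at (0..15, 8): each half-warp supplies the row addresses for one pair of 8x8 128b tiles, which is the per-lane base-address pattern a numTiles = 4 ldmatrix consumes.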