Index: mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
===================================================================
--- mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -624,7 +624,8 @@
   // at least 8 rows to read and the width to read for the transpose is 128
   // bits.
   if (!op.getPermutationMap().isMinorIdentity() &&
-      (vecTy.getDimSize(1) < 8 || vecTy.getDimSize(0) * bitWidth < 128))
+      (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
+       vecTy.getDimSize(0) * bitWidth < 128))
     isLdMatrixCompatible = false;
 
   if (!isLdMatrixCompatible)
Index: mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
===================================================================
--- mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
+++ mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
@@ -347,3 +347,63 @@
   vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<20x20xf32>
   return
 }
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-DAG: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
+// CHECK-DAG: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 4 + 3)>
+
+// CHECK-DAG: [[$rowB_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 3)>
+// CHECK-DAG: [[$colB_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 3)>
+
+// CHECK-DAG: [[$rowC_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4)>
+// CHECK-DAG: [[$rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 8)>
+// CHECK-DAG: [[$colC_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8)>
+
+// CHECK-LABEL: func @m16n8k8_tf32_f32_row_row_row
+func.func @m16n8k8_tf32_f32_row_row_row(%arg0: memref<20x20xf32, 3>, %arg1: memref<20x20xf32, 3>, %arg2: memref<20x20xf32>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %cst = arith.constant 0.000000e+00 : f32
+
+  // CHECK: [[c_frag:%.+]] = arith.constant {{.*}} : vector<2x2xf32>
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
+  // CHECK: [[a_frag:%.+]] = nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}
+
+  // b and c are not loaded by ldmatrix in this test.
+  // CHECK-NOT: nvgpu.ldmatrix
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
+  // CHECK: [[b_el0:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
+  // CHECK: [[b_frag0:%.+]] = vector.insert [[b_el0]], {{.*}} : f32 into vector<2x1xf32>
+  // CHECK: [[b_el1:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
+  // CHECK: [[b_frag1:%.+]] = vector.insert [[b_el1]], {{.*}} : f32 into vector<2x1xf32>
+
+  // CHECK: [[d_frag:%.+]] = nvgpu.mma.sync([[a_frag]], [[b_frag1]], [[c_frag]])
+  // CHECK-SAME: mmaShape = [16, 8, 8]
+  // CHECK-SAME: -> vector<2x2xf32>
+  %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf32, 3>, vector<16x8xf32>
+  %B = vector.transfer_read %arg1[%c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf32, 3>, vector<8x8xf32>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x8xf32>, vector<8x8xf32> into vector<16x8xf32>
+
+  // CHECK: vector.extract [[d_frag]][0] : vector<2x2xf32>
+  // CHECK: affine.apply [[$rowC_map]]
+  // CHECK: affine.apply [[$colC_map]]
+  // CHECK: vector.store
+  // CHECK: vector.extract [[d_frag]][1] : vector<2x2xf32>
+  // CHECK: affine.apply [[$rowC8_map]]
+  // CHECK: affine.apply [[$colC_map]]
+  // CHECK: vector.store
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<20x20xf32>
+  return
+}