diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -96,8 +96,12 @@
 // i.e. (d0, d1, ...) -> (dn-1, dn-2).
 static bool isTransposeMatrixLoadMap(OpBuilder &b, AffineMap permutationMap) {
   auto nDim = permutationMap.getNumDims();
-  if (nDim < 2)
-    return false;
+  if (nDim < 2) {
+    // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
+    AffineExpr dim0 = b.getAffineDimExpr(0);
+    AffineExpr zero = b.getAffineConstantExpr(0);
+    return permutationMap == AffineMap::get(1, 0, {dim0, zero}, b.getContext());
+  }
 
   AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
   AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
@@ -458,12 +462,18 @@
                       llvm::DenseMap<Value, Value> &valueMapping) {
   assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
   assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
+
   std::optional<int64_t> stride =
       getMemrefConstantHorizontalStride(op.getShapedType());
+
   AffineMap map = op.getPermutationMap();
+  OpBuilder b(op);
+  bool isTranspose = isTransposeMatrixLoadMap(b, map);
+
   // Handle broadcast by setting the stride to 0.
-  if (map.getResult(0).isa<AffineConstantExpr>()) {
-    assert(map.getResult(0).cast<AffineConstantExpr>().getValue() == 0);
+  if (auto cstExpr =
+          map.getResult(isTranspose).dyn_cast<AffineConstantExpr>()) {
+    assert(cstExpr.getValue() == 0);
     stride = 0;
   }
   assert(stride);
@@ -471,8 +481,6 @@
   gpu::MMAMatrixType type =
       gpu::MMAMatrixType::get(op.getVectorType().getShape(),
                               op.getVectorType().getElementType(), fragType);
-  OpBuilder b(op);
-  bool isTranspose = isTransposeMatrixLoadMap(b, map);
   Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
       op.getLoc(), type, op.getSource(), op.getIndices(),
       b.getIndexAttr(*stride), isTranspose ? b.getUnitAttr() : UnitAttr());
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -189,3 +189,21 @@
   vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
   return
 }
+
+// CHECK-LABEL: func @matmul_transposed_broadcasted
+// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}] {leadDimension = 0 : index, transpose} : memref<16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}] {leadDimension = 0 : index} : memref<16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+func.func @matmul_transposed_broadcasted(%arg0: memref<16xf16>, %arg1: memref<16xf16>, %arg2: memref<16x16xf16>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0) -> (d0, 0)>} : memref<16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0) -> (d0, 0)>} : memref<16xf16>, vector<16x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+  return
+}
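
For context, a minimal sketch of the lowering this change enables, distilled from the new test above (SSA value names are illustrative, not taken from actual pass output): a transfer_read with the transposed+broadcasted map affine_map<(d0) -> (d0, 0)> now becomes a subgroup MMA load where leadDimension = 0 encodes the broadcast and the unit attribute encodes the transpose.

  // Input: 1-D source read with a transposed broadcast permutation map.
  %a = vector.transfer_read %arg0[%c0], %cst
         {in_bounds = [true, true],
          permutation_map = affine_map<(d0) -> (d0, 0)>}
         : memref<16xf16>, vector<16x16xf16>

  // After the VectorToGPU conversion (shown here for the "AOp" operand):
  %a_mma = gpu.subgroup_mma_load_matrix %arg0[%c0]
             {leadDimension = 0 : index, transpose}
             : memref<16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">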