Index: mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
===================================================================
--- mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -123,6 +123,12 @@
   return constantOp.value().isa<SplatElementsAttr>();
 }
 
+/// Return true if this is a broadcast from scalar to a 2D vector.
+static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
+  return broadcastOp.getVectorType().getRank() == 2 &&
+         broadcastOp.source().getType().isa<FloatType>();
+}
+
 static bool supportsMMaMatrixType(Operation *op) {
   if (isa<scf::ForOp, scf::YieldOp>(op))
     return true;
@@ -134,6 +140,8 @@
     return contractSupportsMMAMatrixType(contract);
   if (auto constant = dyn_cast<ConstantOp>(op))
     return constantSupportsMMAMatrixType(constant);
+  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
+    return broadcastSupportsMMAMatrixType(broadcast);
   return false;
 }
 
@@ -141,8 +149,11 @@
 // slice can be converted to MMA operations.
 static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
   auto hasVectorDest = [](Operation *op) {
-    return op->getNumResults() == 0 ||
-           llvm::any_of(op->getResultTypes(),
+    return llvm::any_of(op->getResultTypes(),
+                        [](Type t) { return t.isa<VectorType>(); });
+  };
+  auto hasVectorSrc = [](Operation *op) {
+    return llvm::any_of(op->getOperandTypes(),
                         [](Type t) { return t.isa<VectorType>(); });
   };
   SetVector<Operation *> opToConvert;
@@ -150,7 +161,7 @@
     if (opToConvert.contains(contract.getOperation()))
       return;
     SetVector<Operation *> dependentOps =
-        getSlice(contract, hasVectorDest, hasVectorDest);
+        getSlice(contract, hasVectorDest, hasVectorSrc);
     // If any instruction cannot use MMA matrix type drop the whole
     // chaine. MMA matrix are stored in an opaque type so they cannot be used
     // by all operations.
@@ -329,6 +340,20 @@
   valueMapping[op.getResult()] = matrix;
 }
 
+/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
+static void convertBroadcastOp(vector::BroadcastOp op,
+                               llvm::DenseMap<Value, Value> &valueMapping) {
+  assert(broadcastSupportsMMAMatrixType(op));
+  OpBuilder b(op);
+  const char *fragType = inferFragType(op);
+  auto vecType = op.getVectorType();
+  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
+      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
+  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
+                                                           op.source());
+  valueMapping[op.getResult()] = matrix;
+}
+
 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not
 // updated and needs to be updated separatly for the loop to be correct.
 static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
@@ -416,6 +441,8 @@
       convertContractOp(contractOp, valueMapping);
     } else if (auto constantOp = dyn_cast<ConstantOp>(op)) {
       convertConstantOp(constantOp, valueMapping);
+    } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
+      convertBroadcastOp(broadcastOp, valueMapping);
     } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
       convertForOp(forOp, valueMapping);
     } else if (auto yiledOp = dyn_cast<scf::YieldOp>(op)) {
Index: mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
===================================================================
--- mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -41,6 +41,24 @@
   return
 }
 
+// CHECK-LABEL: func @matmul_broadcast
+// CHECK-SAME: (%{{.*}}: memref<16x16xf16>, %{{.*}}: memref<16x16xf16>, %{{.*}}: memref<16x16xf16>, %[[F:.*]]: f16)
+// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_constant_matrix %[[F]] : !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+func @matmul_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>, %f: f16) {
+  %C = vector.broadcast %f : f16 to vector<16x16xf16>
+  %c0 = constant 0 : index
+  %cst = constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+  return
+}
+
 // CHECK-LABEL: func @matmul_loop
 // CHECK: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
 // CHECK: %[[ACC:.+]] = scf.for {{.*}} iter_args(%[[ACC1:.+]] = %[[C]]) -> (!gpu.mma_matrix<16x16xf16, "COp">) {
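
Note (not part of the patch): a minimal before/after sketch of the rewrite this change enables. The fragment kind is inferred from the broadcast's use; here the broadcast is assumed to feed the accumulator operand of a vector.contract, so inferFragType resolves to "COp".

    // Before: a scalar f16 broadcast into a 2-D accumulator vector.
    %acc = vector.broadcast %f : f16 to vector<16x16xf16>

    // After conversion: the broadcast becomes a constant MMA matrix.
    %acc = gpu.subgroup_mma_constant_matrix %f : !gpu.mma_matrix<16x16xf16, "COp">

The rewrite only applies when broadcastSupportsMMAMatrixType holds, i.e. the result is a rank-2 vector and the source is a scalar float; any other broadcast in the contraction's slice causes the whole chain to be skipped, as getOpToConvert requires every dependent op to support the MMA matrix type.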