diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -130,6 +130,26 @@
          broadcastOp.source().getType().isa<FloatType>();
 }
 
+/// Return the MMA elementwise enum associated with `op` if it is supported.
+/// Return `llvm::None` otherwise.
+static llvm::Optional<gpu::MMAElementwiseOp>
+convertElementwiseOpToMMA(Operation *op) {
+  if (isa<arith::AddFOp>(op))
+    return gpu::MMAElementwiseOp::ADDF;
+  if (isa<arith::MulFOp>(op))
+    return gpu::MMAElementwiseOp::MULF;
+  if (isa<arith::MaxFOp>(op))
+    return gpu::MMAElementwiseOp::MAXF;
+  if (isa<arith::MinFOp>(op))
+    return gpu::MMAElementwiseOp::MINF;
+  return llvm::None;
+}
+
+/// Return true if the op is supported as elementwise op on MMAMatrix type.
+static bool elementwiseSupportsMMAMatrixType(Operation *op) {
+  return convertElementwiseOpToMMA(op).hasValue();
+}
+
 static bool supportsMMaMatrixType(Operation *op) {
   if (isa<scf::ForOp, scf::YieldOp>(op))
     return true;
@@ -143,7 +163,7 @@
     return constantSupportsMMAMatrixType(constant);
   if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
     return broadcastSupportsMMAMatrixType(broadcast);
-  return false;
+  return elementwiseSupportsMMAMatrixType(op);
 }
 
 // Analyze slice of operations based on convert op to figure out if the whole
@@ -423,6 +443,18 @@
   op.erase();
 }
 
+/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
+static void convertElementwiseOp(Operation *op, gpu::MMAElementwiseOp opType,
+                                 llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder b(op);
+  SmallVector<Value> matrixOperands;
+  for (Value operand : op->getOperands())
+    matrixOperands.push_back(valueMapping.find(operand)->second);
+  Value newOp = b.create<gpu::SubgroupMmaElementwiseOp>(
+      op->getLoc(), matrixOperands[0].getType(), matrixOperands, opType);
+  valueMapping[op->getResult(0)] = newOp;
+}
+
 namespace mlir {
 
 void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) {
@@ -448,6 +480,8 @@
       convertForOp(forOp, valueMapping);
     } else if (auto yiledOp = dyn_cast<scf::YieldOp>(op)) {
       convertYieldOp(yiledOp, valueMapping);
+    } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
+      convertElementwiseOp(op, *elementwiseType, valueMapping);
     }
   }
 }
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -83,3 +83,26 @@
   vector.transfer_write %14, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<128x128xf16>
   return
 }
+
+// CHECK-LABEL: func @matmul_fused_elementwise
+//   CHECK-DAG: %[[CST_0:.+]] = arith.constant 0.000000e+00 : f16
+//   CHECK-DAG: %[[CST_1:.+]] = arith.constant 1.000000e+00 : f16
+//   CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+//   CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+//   CHECK-DAG: %[[C0:.+]] = gpu.subgroup_mma_constant_matrix %[[CST_0]] : !gpu.mma_matrix<16x16xf16, "COp">
+//   CHECK-DAG: %[[C1:.+]] = gpu.subgroup_mma_constant_matrix %[[CST_1]] : !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C0]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK: %[[E:.+]] = gpu.subgroup_mma_elementwise %[[D]], %[[C1]] {operation = "ADDF"} : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK: gpu.subgroup_mma_store_matrix %[[E]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+func @matmul_fused_elementwise(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+  %cst_1 = arith.constant dense<1.000000e+00> : vector<16x16xf16>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+  %E = arith.addf %D, %cst_1 : vector<16x16xf16>
+  vector.transfer_write %E, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+  return
+}
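
Note: the test hunk above only appends a new function and CHECK lines; the RUN line that drives the file sits at the top of vector-to-mma-ops.mlir, outside this diff's context. Assuming the file's existing header uses the standard -convert-vector-to-gpu pass flag, the new case can be exercised standalone with an invocation of roughly this shape:

  mlir-opt mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir \
    -convert-vector-to-gpu -canonicalize \
    | FileCheck mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir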