diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -113,6 +113,15 @@
   return true;
 }
 
+/// Return true if the constant is a splat to a 2D vector so that it can be
+/// converted to a MMA constant matrix op.
+static bool constantSupportsMMAMatrixType(ConstantOp constantOp) {
+  auto vecType = constantOp.getType().dyn_cast<VectorType>();
+  if (!vecType || vecType.getRank() != 2)
+    return false;
+  return constantOp.value().isa<SplatElementsAttr>();
+}
+
 static bool supportsMMaMatrixType(Operation *op) {
   if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
     return transferReadSupportsMMAMatrixType(transferRead);
@@ -120,6 +129,8 @@
     return transferWriteSupportsMMAMatrixType(transferWrite);
   if (auto contract = dyn_cast<vector::ContractionOp>(op))
     return contractSupportsMMAMatrixType(contract);
+  if (auto constant = dyn_cast<ConstantOp>(op))
+    return constantSupportsMMAMatrixType(constant);
   return false;
 }
 
@@ -241,10 +252,11 @@
 } // namespace
 
 // MMA types have different layout based on how they are used in matmul ops.
-// Figure the right layout to use by looking at Transfer op uses.
+// Figure the right layout to use by looking at op uses.
 // TODO: Change the GPU dialect to abstract the layout at the this level and
 // only care about it during lowering to NVVM.
-static const char *inferFragType(vector::TransferReadOp op) {
+template <typename OpTy>
+static const char *inferFragType(OpTy op) {
   for (Operation *users : op->getUsers()) {
     auto contract = dyn_cast<vector::ContractionOp>(users);
     if (!contract)
@@ -297,6 +309,23 @@
   valueMapping[op.getResult()] = matmul;
 }
 
+/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
+static void convertConstantOp(ConstantOp op,
+                              llvm::DenseMap<Value, Value> &valueMapping) {
+  assert(constantSupportsMMAMatrixType(op));
+  OpBuilder b(op);
+  Attribute splat = op.getValue().cast<SplatElementsAttr>().getSplatValue();
+  auto scalarConstant =
+      b.create<ConstantOp>(op.getLoc(), splat.getType(), splat);
+  const char *fragType = inferFragType(op);
+  auto vecType = op.getType().cast<VectorType>();
+  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
+      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
+  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
+                                                           scalarConstant);
+  valueMapping[op.getResult()] = matrix;
+}
+
 namespace mlir {
 
 void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) {
@@ -314,6 +343,8 @@
       convertTransferWriteOp(transferWrite, valueMapping);
     } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
       convertContractOp(contractOp, valueMapping);
+    } else if (auto constantOp = dyn_cast<ConstantOp>(op)) {
+      convertConstantOp(constantOp, valueMapping);
     }
   }
 }
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -23,6 +23,24 @@
   return
 }
 
+// CHECK-LABEL: func @matmul_cst
+// CHECK-DAG: %[[CST:.+]] = constant 0.000000e+00 : f16
+// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_constant_matrix %[[CST]] : !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+func @matmul_cst(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>) {
+  %cst_0 = constant dense<0.000000e+00> : vector<16x16xf16>
+  %c0 = constant 0 : index
+  %cst = constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+  return
+}
+
 // Negative test until scf.for support is added.
 // CHECK-LABEL: func @matmul_loop
 // CHECK: vector.contract