Index: mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
===================================================================
--- mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -112,6 +112,15 @@
   return true;
 }
 
+// Return true if the constant is a splat to a 2D vector so that it can be
+// converted to an MMA constant matrix op.
+static bool constantSupportsMMAMatrixType(ConstantOp constantOp) {
+  auto vecType = constantOp.getType().dyn_cast<VectorType>();
+  if (!vecType || vecType.getRank() != 2)
+    return false;
+  return constantOp.value().isa<SplatElementsAttr>();
+}
+
 static bool supportsMMaMatrixType(Operation *op) {
   if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
     return transferReadSupportsMMAMatrixType(transferRead);
@@ -119,6 +128,8 @@
     return transferWriteSupportsMMAMatrixType(transferWrite);
   if (auto contract = dyn_cast<vector::ContractionOp>(op))
     return contractSupportsMMAMatrixType(contract);
+  if (auto constant = dyn_cast<ConstantOp>(op))
+    return constantSupportsMMAMatrixType(constant);
   return false;
 }
 
@@ -240,10 +251,11 @@
 } // namespace
 
 // MMA types have different layout based on how they are used in matmul ops.
-// Figure the right layout to use by looking at Transfer op uses.
+// Figure the right layout to use by looking at op uses.
 // TODO: Change the GPU dialect to abstract the layout at the this level and
 // only care about it during lowering to NVVM.
-static const char *inferFragType(vector::TransferReadOp op) {
+template <typename OpTy>
+static const char *inferFragType(OpTy op) {
   for (Operation *users : op->getUsers()) {
     auto contract = dyn_cast<vector::ContractionOp>(users);
     if (!contract)
@@ -296,6 +308,22 @@
   valueMapping[op.getResult()] = matmul;
 }
 
+static void convertConstantOp(ConstantOp op,
+                              llvm::DenseMap<Value, Value> &valueMapping) {
+  assert(constantSupportsMMAMatrixType(op));
+  OpBuilder b(op);
+  Attribute splat = op.getValue().cast<SplatElementsAttr>().getSplatValue();
+  auto scalarConstant =
+      b.create<ConstantOp>(op.getLoc(), splat.getType(), splat);
+  const char *fragType = inferFragType(op);
+  auto vecType = op.getType().cast<VectorType>();
+  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
+      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
+  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
+                                                           scalarConstant);
+  valueMapping[op.getResult()] = matrix;
+}
+
 namespace mlir {
 
 void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) {
@@ -313,6 +341,8 @@
       convertTransferWriteOp(transferWrite, valueMapping);
     } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
       convertContractOp(contractOp, valueMapping);
+    } else if (auto constantOp = dyn_cast<ConstantOp>(op)) {
+      convertConstantOp(constantOp, valueMapping);
     }
   }
 }
Index: mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
===================================================================
--- mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -23,6 +23,24 @@
   return
 }
 
+// CHECK-LABEL: func @matmul_cst
+// CHECK-DAG: %[[CST:.+]] = constant 0.000000e+00 : f16
+// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_constant_matrix %[[CST]] : !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+func @matmul_cst(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>) {
+  %cst_0 = constant dense<0.000000e+00> : vector<16x16xf16>
+  %c0 = constant 0 : index
+  %cst = constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+  return
+}
+
 // Negative test until scf.for support is added.
 // CHECK-LABEL: func @matmul_loop
 // CHECK: vector.contract
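
For reference, a minimal sketch of the rewrite this patch enables, distilled from the CHECK lines in the test above (the SSA value names are illustrative, not taken from the patch): a 2D splat constant used as the accumulator of a vector.contract is replaced by a scalar constant plus a gpu.subgroup_mma_constant_matrix op, with the fragment kind inferred from the contract use.

  // Before: vector dialect, splat accumulator.
  %acc = constant dense<0.000000e+00> : vector<16x16xf16>

  // After conversion: gpu dialect. The constant feeds the accumulator
  // operand of the contraction, so it becomes a "COp" matrix.
  %cst = constant 0.000000e+00 : f16
  %acc = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">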