diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -24,42 +24,47 @@
 using namespace mlir;
 
-// See SPV_NV_cooperative_matrix for supported element wise ops.
-static void createElementWiseOp(ConversionPatternRewriter &builder,
+/// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
+/// when the elementwise op is directly supported on cooperative matrix types.
+/// Returns false otherwise.
+///
+/// See SPV_NV_cooperative_matrix for supported elementwise ops.
+static bool createElementWiseOp(ConversionPatternRewriter &builder,
                                 gpu::SubgroupMmaElementwiseOp op,
                                 spirv::CooperativeMatrixNVType coopType,
                                 ValueRange operands) {
   switch (op.getOpType()) {
   case gpu::MMAElementwiseOp::ADDF:
     builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::ADDI:
     builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::SUBF:
     builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::SUBI:
     builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::DIVF:
     builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::DIVS:
     builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::DIVU:
     builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::NEGATEF:
     builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::NEGATES:
     builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
-    return;
+    return true;
   default:
-    llvm_unreachable("unknown op");
+    break;
   }
+  return false;
 }
 
 namespace {
@@ -162,13 +167,14 @@
   }
 };
 
-/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops.
-struct WmmaElementwiseOpToSPIRVLowering
+/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
+/// the default case.
+struct WmmaElementwiseOpToSPIRVDefaultLowering
     : public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
+  matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp,
                   OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // All operands should be of cooperative matrix types.
@@ -177,9 +183,58 @@
         return failure();
     }
     auto coopType = convertMMAToSPIRVType(
-        subgroupMmaElementwiseOp.getType().cast<gpu::MMAMatrixType>());
-    createElementWiseOp(rewriter, subgroupMmaElementwiseOp, coopType,
-                        adaptor.getOperands());
+        elementwiseOp.getType().cast<gpu::MMAMatrixType>());
+    return success(createElementWiseOp(rewriter, elementwiseOp, coopType,
+                                       adaptor.getOperands()));
+  }
+};
+
+/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
+/// the matrix-times-scalar case.
+struct WmmaElementwiseOpToSPIRVScalarMulLowering
+    : public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // All operands should be of cooperative matrix types.
+    for (Value operand : adaptor.getOperands()) {
+      if (!operand.getType().isa<spirv::CooperativeMatrixNVType>())
+        return failure();
+    }
+    if (adaptor.getOperands().size() != 2)
+      return failure();
+
+    // Use the original operands to check whether one of the operands is a
+    // splat scalar value.
+    Value lhs = elementwiseOp.getOperands().front();
+    Value rhs = elementwiseOp.getOperands().back();
+    Value splat = nullptr, matrix = nullptr;
+    if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
+      splat = adaptor.getOperands().front();
+      matrix = adaptor.getOperands().back();
+    } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
+      matrix = adaptor.getOperands().front();
+      splat = adaptor.getOperands().back();
+    }
+    if (!splat || !matrix)
+      return failure();
+
+    // Constant MMA matrix ops are converted to spirv.CompositeConstruct ops.
+    Value scalar = nullptr;
+    if (auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>()) {
+      assert(cc.getConstituents().size() == 1);
+      scalar = cc.getConstituents().front();
+    } else {
+      return failure();
+    }
+
+    auto coopType = convertMMAToSPIRVType(
+        elementwiseOp.getType().cast<gpu::MMAMatrixType>());
+    rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
+        elementwiseOp, coopType, ValueRange{matrix, scalar});
     return success();
   }
 };
@@ -197,8 +252,11 @@
 
 void mlir::populateGpuWMMAToSPIRVConversionPatterns(
     SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
   patterns.add<WmmaLoadOpToSPIRVLowering, WmmaMmaOpToSPIRVLowering,
                WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
-               WmmaElementwiseOpToSPIRVLowering>(converter,
-               patterns.getContext());
+               WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
+  // Give the following patterns higher benefit to prevail over the default one.
+  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
+                                                          /*benefit=*/2);
 }
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
@@ -4,10 +4,8 @@
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK: spirv.module @{{.*}} Logical GLSL450 {
     // CHECK-LABEL: spirv.func @gpu_wmma_load_op
-    // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<1024 x f16, stride=2> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
-    // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi
+    // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<1024 x f16, stride=2> [0])>, StorageBuffer>
     gpu.func @gpu_wmma_load_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>) kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi} {
       %i = arith.constant 16 : index
@@ -26,11 +24,9 @@
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK: spirv.module @{{.*}} Logical GLSL450 {
     // CHECK-LABEL: spirv.func @gpu_wmma_store_op
-    // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<1024 x f16, stride=2> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
-    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
-    // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi
+    // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<1024 x f16, stride=2> [0])>, StorageBuffer>
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
     gpu.func @gpu_wmma_store_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>, %arg1 : !gpu.mma_matrix<16x16xf16, "COp">) kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi} {
       %i = arith.constant 16 : index
@@ -49,12 +45,10 @@
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK: spirv.module @{{.*}} Logical GLSL450 {
     // CHECK-LABEL: spirv.func @gpu_wmma_mma_op
-    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
-    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
-    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 2)>})
-    // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
     gpu.func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi} {
       // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
@@ -71,7 +65,6 @@
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK: spirv.module @{{.*}} Logical GLSL450 {
     // CHECK-LABEL: spirv.func @gpu_wmma_constant_op
     gpu.func @gpu_wmma_constant_op() kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi} {
@@ -91,11 +84,10 @@
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK: spirv.module @{{.*}} Logical GLSL450 {
-    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op
-    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
-    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
-    gpu.func @gpu_wmma_elementwise_op(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) kernel
+    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+    gpu.func @gpu_wmma_elementwise_op_default(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi} {
       // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>
       %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
@@ -107,4 +99,25 @@
       gpu.return
     }
   }
-}
\ No newline at end of file
+}
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} {
+  gpu.module @kernels {
+    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_times_scalar
+    // CHECK-SAME: %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup>
+    // CHECK-SAME: %[[S:.+]]: f16
+    gpu.func @gpu_wmma_elementwise_op_matrix_times_scalar(%A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi} {
+      %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
+      %C = gpu.subgroup_mma_elementwise mulf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
+      %D = gpu.subgroup_mma_elementwise mulf %B, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      gpu.return
+    }
+  }
+}
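
Note: the sketch below is not part of the patch; it is a minimal illustration of how the patterns registered above are typically driven through the dialect-conversion framework. The helper name lowerWmmaToSPIRV and the exact include paths are assumptions; SPIRVConversionTarget, SPIRVTypeConverter, populateGpuWMMAToSPIRVConversionPatterns, and applyPartialConversion are the existing MLIR APIs. Because the matrix-times-scalar pattern is registered with /*benefit=*/2, the conversion driver tries it before the default elementwise lowering and falls back to the default only when the scalar-mul pattern does not match.

// Illustrative sketch only -- not part of this patch. Include paths and the
// helper name `lowerWmmaToSPIRV` are assumptions.
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

static LogicalResult lowerWmmaToSPIRV(Operation *op,
                                      spirv::TargetEnvAttr targetEnv) {
  // Build the conversion target and type converter for the given SPIR-V
  // target environment.
  std::unique_ptr<ConversionTarget> target =
      SPIRVConversionTarget::get(targetEnv);
  SPIRVTypeConverter typeConverter(targetEnv);

  // Register the WMMA patterns: the default elementwise lowering plus the
  // higher-benefit matrix-times-scalar lowering added in this patch.
  RewritePatternSet patterns(op->getContext());
  populateGpuWMMAToSPIRVConversionPatterns(typeConverter, patterns);

  return applyPartialConversion(op, *target, std::move(patterns));
}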