diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -24,42 +24,47 @@
 
 using namespace mlir;
 
-// See SPV_NV_cooperative_matrix for supported element wise ops.
-static void createElementWiseOp(ConversionPatternRewriter &builder,
+/// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
+/// when the elementwise op is directly supported on cooperative matrix types.
+/// Returns false if it cannot be handled.
+///
+/// See SPV_NV_cooperative_matrix for supported elementwise ops.
+static bool createElementwiseOp(ConversionPatternRewriter &builder,
                                 gpu::SubgroupMmaElementwiseOp op,
                                 spirv::CooperativeMatrixNVType coopType,
                                 ValueRange operands) {
   switch (op.getOpType()) {
   case gpu::MMAElementwiseOp::ADDF:
     builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::ADDI:
     builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::SUBF:
     builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::SUBI:
     builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::DIVF:
     builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::DIVS:
     builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::DIVU:
     builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::NEGATEF:
     builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::NEGATES:
     builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
-    return;
+    return true;
   default:
-    llvm_unreachable("unknown op");
+    break;
   }
+  return false;
 }
 
 namespace {
@@ -163,13 +168,14 @@
   }
 };
 
-/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops.
-struct WmmaElementwiseOpToSPIRVLowering
+/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
+/// the default case.
+struct WmmaElementwiseOpToSPIRVDefaultLowering
     : public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
+  matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp,
                   OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // All operands should be of cooperative matrix types.
@@ -178,9 +184,58 @@
         return failure();
     }
     auto coopType = convertMMAToSPIRVType(
-        subgroupMmaElementwiseOp.getType().cast<gpu::MMAMatrixType>());
-    createElementWiseOp(rewriter, subgroupMmaElementwiseOp, coopType,
-                        adaptor.getOperands());
+        elementwiseOp.getType().cast<gpu::MMAMatrixType>());
+    return success(createElementwiseOp(rewriter, elementwiseOp, coopType,
+                                       adaptor.getOperands()));
+  }
+};
+
+/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
+/// the matrix times scalar case.
+struct WmmaElementwiseOpToSPIRVScalarMulLowering
+    : public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (adaptor.getOperands().size() != 2)
+      return failure();
+    // All operands should be of cooperative matrix types.
+    for (Value operand : adaptor.getOperands()) {
+      if (!operand.getType().isa<spirv::CooperativeMatrixNVType>())
+        return failure();
+    }
+
+    // Use the original operands to check whether one of the operands is a
+    // splat scalar value.
+    Value lhs = elementwiseOp.getOperands().front();
+    Value rhs = elementwiseOp.getOperands().back();
+    Value splat = nullptr;
+    Value matrix = nullptr;
+    if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
+      splat = adaptor.getOperands().front();
+      matrix = adaptor.getOperands().back();
+    } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
+      matrix = adaptor.getOperands().front();
+      splat = adaptor.getOperands().back();
+    }
+    if (!splat || !matrix)
+      return failure();
+
+    // Constant MMA matrix ops are converted to spirv.CompositeConstruct ops.
+    Value scalar = nullptr;
+    auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
+    if (!cc)
+      return failure();
+    assert(cc.getConstituents().size() == 1);
+    scalar = cc.getConstituents().front();
+
+    auto coopType = convertMMAToSPIRVType(
+        elementwiseOp.getType().cast<gpu::MMAMatrixType>());
+    rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
+        elementwiseOp, coopType, ValueRange{matrix, scalar});
     return success();
   }
 };
@@ -198,8 +253,11 @@
 
 void mlir::populateGpuWMMAToSPIRVConversionPatterns(
     SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
   patterns.add<WmmaLoadOpToSPIRVLowering, WmmaMmaOpToSPIRVLowering,
                WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
-               WmmaElementwiseOpToSPIRVLowering>(converter,
-                                                 patterns.getContext());
+               WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
+  // Give the following patterns higher benefit to prevail over the default one.
+  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
+                                                          /*benefit=*/2);
 }
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
@@ -4,10 +4,8 @@
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK: spirv.module @{{.*}} Logical GLSL450 {
     // CHECK-LABEL: spirv.func @gpu_wmma_load_op
-    // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<1024 x f16, stride=2> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
-    // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi
+    // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<1024 x f16, stride=2> [0])>, StorageBuffer>
     gpu.func @gpu_wmma_load_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>) kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi} {
       %i = arith.constant 16 : index
@@ -27,7 +25,6 @@
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK: spirv.module @{{.*}} Logical GLSL450 {
     // CHECK-LABEL: spirv.func @gpu_wmma_load_op_transpose
     // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<1024 x f16, stride=2> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
     // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi
@@ -50,11 +47,9 @@
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK: spirv.module @{{.*}} Logical GLSL450 {
     // CHECK-LABEL: spirv.func @gpu_wmma_store_op
-    // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<1024 x f16, stride=2> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
-    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
-    // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi
+    // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<1024 x f16, stride=2> [0])>, StorageBuffer>
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
     gpu.func @gpu_wmma_store_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>, %arg1 : !gpu.mma_matrix<16x16xf16, "COp">) kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi} {
       %i = arith.constant 16 : index
@@ -74,7 +69,6 @@
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK: spirv.module @{{.*}} Logical GLSL450 {
     // CHECK-LABEL: spirv.func @gpu_wmma_store_op_transpose
     // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<1024 x f16, stride=2> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
     // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
@@ -98,12 +92,10 @@
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK: spirv.module @{{.*}} Logical GLSL450 {
     // CHECK-LABEL: spirv.func @gpu_wmma_mma_op
-    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
-    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
-    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 2)>})
-    // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
     gpu.func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi} {
       // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
@@ -120,7 +112,6 @@
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK: spirv.module @{{.*}} Logical GLSL450 {
     // CHECK-LABEL: spirv.func @gpu_wmma_constant_op
     gpu.func @gpu_wmma_constant_op() kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi} {
@@ -140,11 +131,10 @@
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK: spirv.module @{{.*}} Logical GLSL450 {
-    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op
-    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
-    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
-    gpu.func @gpu_wmma_elementwise_op(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) kernel
+    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+    gpu.func @gpu_wmma_elementwise_op_default(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi} {
       // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>
       %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
@@ -157,3 +147,24 @@
     }
   }
 }
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} {
+  gpu.module @kernels {
+    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_times_scalar
+    // CHECK-SAME: %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup>
+    // CHECK-SAME: %[[S:.+]]: f16
+    gpu.func @gpu_wmma_elementwise_op_matrix_times_scalar(%A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi} {
+      %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
+      %C = gpu.subgroup_mma_elementwise mulf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
+      %D = gpu.subgroup_mma_elementwise mulf %B, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      gpu.return
+    }
+  }
+}
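
For reviewers who want to try the patterns outside the gpu-to-spirv pass, below is a minimal C++ sketch (not part of the patch) of how a downstream conversion could pull them in. It assumes the standard SPIRVConversionTarget / applyPartialConversion flow and the usual header locations; the driver function name runGpuWmmaToSPIRV and the surrounding setup are illustrative only.

```cpp
// Hypothetical usage sketch -- not part of this patch. Only
// populateGpuWMMAToSPIRVConversionPatterns comes from the file changed above;
// everything else is the stock MLIR SPIR-V conversion machinery.
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Illustrative driver: registers the WMMA patterns (including the new
// WmmaElementwiseOpToSPIRVScalarMulLowering) and runs a partial conversion.
static LogicalResult runGpuWmmaToSPIRV(ModuleOp module) {
  // Derive the conversion target and type converter from the module's
  // spirv.target_env attribute (or the default environment).
  spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(module);
  std::unique_ptr<ConversionTarget> target =
      SPIRVConversionTarget::get(targetAttr);
  SPIRVTypeConverter typeConverter(targetAttr);

  RewritePatternSet patterns(module.getContext());
  populateGpuWMMAToSPIRVConversionPatterns(typeConverter, patterns);

  // The scalar-multiplication pattern was registered with /*benefit=*/2, so
  // the driver tries it before the default elementwise lowering.
  return applyPartialConversion(module, *target, std::move(patterns));
}
```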