diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp @@ -208,6 +208,9 @@ return failure(); } + if (elementwiseOp.getOpType() != gpu::MMAElementwiseOp::MULF) + return failure(); + // Use the original operands to check whether one of the operands is a splat // scalar value. Value lhs = elementwiseOp.getOperands().front(); diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir --- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir @@ -164,6 +164,27 @@ %C = gpu.subgroup_mma_elementwise mulf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp"> // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup>, f16 %D = gpu.subgroup_mma_elementwise mulf %B, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp"> + // CHECK: spirv.Return + gpu.return + } + } +} + +// ----- + +module attributes { + gpu.container_module, + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { + gpu.module @kernels { + // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_plus_scalar + // CHECK-SAME: %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup> + // CHECK-SAME: %[[S:.+]]: f16 + gpu.func @gpu_wmma_elementwise_op_matrix_plus_scalar(%A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16) kernel + attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { + // CHECK: %[[SM:.+]] = spirv.CompositeConstruct %[[S]] : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup> + %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp"> + // CHECK: %{{.+}} = spirv.FAdd %[[A]], %[[SM]] : !spirv.coopmatrix<16x16xf16, Subgroup> + %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp"> gpu.return } }