Index: mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h =================================================================== --- mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h +++ mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h @@ -13,16 +13,28 @@ #ifndef MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H #define MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { class SPIRVTypeConverter; +namespace gpu { +class MMAMatrixType; +} // namespace gpu + /// Appends to a pattern list additional patterns for translating GPU Ops to /// SPIR-V ops. For a gpu.func to be converted, it should have a /// spirv.entry_point_abi attribute. void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns); + +/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV. +void populateGpuWMMAToSPIRVConversionPatterns(SPIRVTypeConverter &typeConverter, + RewritePatternSet &patterns); + +spirv::CooperativeMatrixNVType convertMMAToSPIRVType(gpu::MMAMatrixType type); } // namespace mlir #endif // MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H Index: mlir/include/mlir/Dialect/GPU/IR/GPUOps.td =================================================================== --- mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -1246,19 +1246,35 @@ }]; } -def GPU_ElementwiseOpAdd : I32EnumAttrCase<"ADDF", 0, "addf">; -def GPU_ElementwiseOpMul : I32EnumAttrCase<"MULF", 1, "mulf">; -def GPU_ElementwiseOpMaxF : I32EnumAttrCase<"MAXF", 2, "maxf">; -def GPU_ElementwiseOpMinF : I32EnumAttrCase<"MINF", 3, "minf">; -def GPU_ElementwiseOpDivF : I32EnumAttrCase<"DIVF", 4, "divf">; +def GPU_ElementwiseOpAddF : I32EnumAttrCase<"ADDF", 0, "addf">; +def GPU_ElementwiseOpMulF : I32EnumAttrCase<"MULF", 1, "mulf">; +def GPU_ElementwiseOpSUBF : I32EnumAttrCase<"SUBF", 2, "subf">; +def 
GPU_ElementwiseOpMaxF : I32EnumAttrCase<"MAXF", 3, "maxf">; +def GPU_ElementwiseOpMinF : I32EnumAttrCase<"MINF", 4, "minf">; +def GPU_ElementwiseOpDivF : I32EnumAttrCase<"DIVF", 5, "divf">; +def GPU_ElementwiseOpAddI : I32EnumAttrCase<"ADDI", 6, "addi">; +def GPU_ElementwiseOpMulI : I32EnumAttrCase<"MULI", 7, "muli">; +def GPU_ElementwiseOpSUBI : I32EnumAttrCase<"SUBI", 8, "subi">; +def GPU_ElementwiseOpDivS : I32EnumAttrCase<"DIVS", 9, "divs">; +def GPU_ElementwiseOpDivU : I32EnumAttrCase<"DIVU", 10, "divu">; +def GPU_ElementwiseOpNEGF : I32EnumAttrCase<"NEGATEF", 11, "negatef">; +def GPU_ElementwiseOpNEGS : I32EnumAttrCase<"NEGATES", 12, "negates">; def MMAElementWise : I32EnumAttr<"MMAElementwiseOp", "elementwise operation to apply to mma matrix", [ - GPU_ElementwiseOpAdd, - GPU_ElementwiseOpMul, + GPU_ElementwiseOpAddF, + GPU_ElementwiseOpMulF, + GPU_ElementwiseOpSUBF, GPU_ElementwiseOpMaxF, GPU_ElementwiseOpMinF, - GPU_ElementwiseOpDivF + GPU_ElementwiseOpDivF, + GPU_ElementwiseOpAddI, + GPU_ElementwiseOpMulI, + GPU_ElementwiseOpSUBI, + GPU_ElementwiseOpDivS, + GPU_ElementwiseOpDivU, + GPU_ElementwiseOpNEGF, + GPU_ElementwiseOpNEGS ]> { let genSpecializedAttr = 0; let cppNamespace = "::mlir::gpu"; Index: mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td =================================================================== --- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td +++ mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td @@ -42,7 +42,18 @@ // Operand type same as result type. SPIRV_UnaryOp; + [Pure, SameOperandsAndResultType])> { + // In addition to normal types arithmetic instructions can support cooperative + // matrix. 
+ let arguments = (ins + SPIRV_ScalarOrVectorOrCoopMatrixOf:$operand + ); + + let results = (outs + SPIRV_ScalarOrVectorOrCoopMatrixOf:$result + ); + let assemblyFormat = "operands attr-dict `:` type($result)"; + } // ----- Index: mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp =================================================================== --- mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -311,8 +311,9 @@ case gpu::MMAElementwiseOp::MINF: return createMinMaxF(builder, loc, operands[0], operands[1], /*isMin=*/true); + default: + llvm_unreachable("unknown op"); } - llvm_unreachable("unknown op"); } /// Convert GPU MMA elementwise ops to extract + op + insert. Index: mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt =================================================================== --- mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt +++ mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_conversion_library(MLIRGPUToSPIRV GPUToSPIRV.cpp GPUToSPIRVPass.cpp + WmmaOpsToSPIRV.cpp DEPENDS MLIRConversionPassIncGen Index: mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp =================================================================== --- mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp +++ mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp @@ -86,9 +86,12 @@ SPIRVConversionTarget::get(targetAttr); SPIRVTypeConverter typeConverter(targetAttr); + typeConverter.addConversion([&](gpu::MMAMatrixType type) -> Type { + return convertMMAToSPIRVType(type); + }); RewritePatternSet patterns(context); populateGPUToSPIRVPatterns(typeConverter, patterns); - + populateGpuWMMAToSPIRVConversionPatterns(typeConverter, patterns); // TODO: Change SPIR-V conversion to be progressive and remove the following // patterns. 
mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns); Index: mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp =================================================================== --- /dev/null +++ mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp @@ -0,0 +1,203 @@ +//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains definitions of patterns to lower GPU Subgroup MMA ops to +// SPIRV Dialect ops. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h" +#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" +#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" +#include "mlir/IR/TypeUtilities.h" + +using namespace mlir; + +// See SPV_NV_cooperative_matrix for supported element wise ops. 
+/// Replaces `op` with the SPIR-V arithmetic op that matches its elementwise
+/// kind, building the new op from `operands` with result type `coopType`.
+///
+/// Handled kinds: ADDF, ADDI, SUBF, SUBI, DIVF, DIVS, DIVU, NEGATEF, NEGATES.
+/// Any other kind reaches the llvm_unreachable below, so a caller that can see
+/// other kinds has to reject them before calling.
+static void createElementWiseOp(ConversionPatternRewriter &builder,
+                                gpu::SubgroupMmaElementwiseOp op,
+                                spirv::CooperativeMatrixNVType coopType,
+                                ValueRange operands) {
+  switch (op.getOpType()) {
+  case gpu::MMAElementwiseOp::ADDF:
+    builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
+    return;
+  case gpu::MMAElementwiseOp::ADDI:
+    builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
+    return;
+  case gpu::MMAElementwiseOp::SUBF:
+    builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
+    return;
+  case gpu::MMAElementwiseOp::SUBI:
+    builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
+    return;
+  case gpu::MMAElementwiseOp::DIVF:
+    builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
+    return;
+  case gpu::MMAElementwiseOp::DIVS:
+    builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
+    return;
+  case gpu::MMAElementwiseOp::DIVU:
+    builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
+    return;
+  case gpu::MMAElementwiseOp::NEGATEF:
+    builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
+    return;
+  case gpu::MMAElementwiseOp::NEGATES:
+    builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
+    return;
+  default:
+    // No case above matched: the kind has no lowering in this helper.
+    llvm_unreachable("unknown op");
+  }
+}
+
+namespace {
+
+/// This class implements the conversion of GPU MMA loadOp to
+/// CooperativeMatrixLoad op in the SPIRV dialect.
+struct WmmaLoadOpToSPIRVLowering
+    : public OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  /// Rewrites the load into an access chain on the converted source memref
+  /// followed by a cooperative matrix load whose stride operand is the op's
+  /// `leadDimension` and whose column-major operand is the constant `false`.
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = subgroupMmaLoadMatrixOp->getLoc();
+    gpu::MMAMatrixType retType =
+        subgroupMmaLoadMatrixOp.getRes().getType().cast<gpu::MMAMatrixType>();
+    auto memrefType =
+        subgroupMmaLoadMatrixOp.getSrcMemref().getType().cast<MemRefType>();
+    // Keep the helper's own return type (`auto`) so the null test happens
+    // before anything is converted to a Value.
+    auto bufferPtr = spirv::getElementPtr(
+        *getTypeConverter<SPIRVTypeConverter>(), memrefType,
+        adaptor.getSrcMemref(), adaptor.getIndices(), loc, rewriter);
+    if (!bufferPtr)
+      return rewriter.notifyMatchFailure(
+          subgroupMmaLoadMatrixOp, "failed to compute the element pointer");
+    auto coopType = convertMMAToSPIRVType(retType);
+    int64_t stride = subgroupMmaLoadMatrixOp.getLeadDimension().getSExtValue();
+    auto i32Type = rewriter.getI32Type();
+    auto strideValue = rewriter.create<spirv::ConstantOp>(
+        loc, i32Type, IntegerAttr::get(i32Type, stride));
+    auto columnMajor = rewriter.create<spirv::ConstantOp>(
+        loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
+    rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixLoadOp>(
+        subgroupMmaLoadMatrixOp, coopType, bufferPtr, strideValue, columnMajor,
+        spirv::MemoryAccessAttr());
+    return success();
+  }
+};
+
+/// This class implements the conversion of GPU MMA StoreOp to
+/// CooperativeMatrixStore op in the SPIRV dialect.
+struct WmmaStoreOpToSPIRVLowering
+    : public OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  /// Rewrites the store into an access chain on the converted destination
+  /// memref followed by a cooperative matrix store whose stride operand is the
+  /// op's `leadDimension` and whose column-major operand is the constant
+  /// `false`.
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = subgroupMmaStoreMatrixOp->getLoc();
+    auto memrefType =
+        subgroupMmaStoreMatrixOp.getDstMemref().getType().cast<MemRefType>();
+    // Keep the helper's own return type (`auto`) so the null test happens
+    // before anything is converted to a Value.
+    auto bufferPtr = spirv::getElementPtr(
+        *getTypeConverter<SPIRVTypeConverter>(), memrefType,
+        adaptor.getDstMemref(), adaptor.getIndices(), loc, rewriter);
+    if (!bufferPtr)
+      return rewriter.notifyMatchFailure(
+          subgroupMmaStoreMatrixOp, "failed to compute the element pointer");
+    int64_t stride = subgroupMmaStoreMatrixOp.getLeadDimension().getSExtValue();
+    auto i32Type = rewriter.getI32Type();
+    auto strideValue = rewriter.create<spirv::ConstantOp>(
+        loc, i32Type, IntegerAttr::get(i32Type, stride));
+    auto columnMajor = rewriter.create<spirv::ConstantOp>(
+        loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
+    rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixStoreOp>(
+        subgroupMmaStoreMatrixOp, bufferPtr, adaptor.getSrc(), strideValue,
+        columnMajor, spirv::MemoryAccessAttr());
+    return success();
+  }
+};
+
+/// This class implements the conversion of GPU MMA Compute to
+/// CooperativeMatrixMulAdd op in the SPIRV dialect.
+struct WmmaMmaOpToSPIRVLowering
+    : public OpConversionPattern<gpu::SubgroupMmaComputeOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  /// Replaces the compute op with a cooperative matrix multiply-add over the
+  /// converted A, B and C operands. The result type is taken from the
+  /// converted C operand.
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixMulAddOp>(
+        subgroupMmaComputeOp, adaptor.getOpC().getType(), adaptor.getOpA(),
+        adaptor.getOpB(), adaptor.getOpC());
+    return success();
+  }
+};
+
+/// Convert GPU MMA ConstantMatrixOp to constant SPIR-V cooperative matrix ops.
+struct WmmaConstantOpToSPIRVLowering
+    : public OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  /// Replaces the constant-matrix op with a composite construct of the
+  /// converted cooperative matrix type, fed by the op's single converted
+  /// operand.
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantMatrixOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value cst = adaptor.getOperands()[0];
+    auto coopType = convertMMAToSPIRVType(
+        subgroupMmaConstantMatrixOp.getType().cast<gpu::MMAMatrixType>());
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
+        subgroupMmaConstantMatrixOp, coopType, cst);
+    return success();
+  }
+};
+
+/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops.
+struct WmmaElementwiseOpToSPIRVLowering
+    : public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  /// Fails to match when an operand has not been converted to a cooperative
+  /// matrix type or when the elementwise kind has no lowering in
+  /// createElementWiseOp; otherwise replaces the op through that helper.
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // All operands should be of cooperative matrix types.
+    for (Value operand : adaptor.getOperands()) {
+      if (!operand.getType().isa<spirv::CooperativeMatrixNVType>())
+        return failure();
+    }
+    // createElementWiseOp has no case for these kinds and would run into its
+    // llvm_unreachable; report a match failure instead of aborting on valid
+    // input IR.
+    switch (subgroupMmaElementwiseOp.getOpType()) {
+    case gpu::MMAElementwiseOp::MULF:
+    case gpu::MMAElementwiseOp::MULI:
+    case gpu::MMAElementwiseOp::MAXF:
+    case gpu::MMAElementwiseOp::MINF:
+      return rewriter.notifyMatchFailure(
+          subgroupMmaElementwiseOp,
+          "elementwise kind has no cooperative matrix lowering");
+    default:
+      break;
+    }
+    auto coopType = convertMMAToSPIRVType(
+        subgroupMmaElementwiseOp.getType().cast<gpu::MMAMatrixType>());
+    createElementWiseOp(rewriter, subgroupMmaElementwiseOp, coopType,
+                        adaptor.getOperands());
+    return success();
+  }
+};
+
+} // namespace
+
+/// Returns the SPIR-V cooperative matrix type (subgroup scope) that corresponds
+/// to the MMAMatrixType `type`.
+mlir::spirv::CooperativeMatrixNVType
+mlir::convertMMAToSPIRVType(gpu::MMAMatrixType type) {
+  // The element type is carried over unchanged; shape[0] and shape[1] are
+  // forwarded in that order as the two matrix dimensions.
+  ArrayRef<int64_t> retTypeShape = type.getShape();
+  Type elementType = type.getElementType();
+  return spirv::CooperativeMatrixNVType::get(
+      elementType, spirv::Scope::Subgroup, retTypeShape[0], retTypeShape[1]);
+}
+
+/// Adds the load, compute, store, constant and elementwise subgroup-MMA
+/// lowering patterns defined in this file to `patterns`, all constructed with
+/// `converter`.
+void mlir::populateGpuWMMAToSPIRVConversionPatterns(
+    SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
+  patterns.add<WmmaLoadOpToSPIRVLowering, WmmaMmaOpToSPIRVLowering,
+               WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
+               WmmaElementwiseOpToSPIRVLowering>(converter,
+                                                 patterns.getContext());
+}
\ No newline at end of file
Index: mlir/lib/Dialect/GPU/IR/GPUDialect.cpp =================================================================== --- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -1192,18 +1192,11 @@ auto resMatrixType = resType.cast(); auto operand = resMatrixType.getOperand(); auto srcMemrefType = srcType.cast(); - auto srcMemSpace = srcMemrefType.getMemorySpaceAsInt(); if (!isLastMemrefDimUnitStride(srcMemrefType)) return emitError( "expected source memref most minor dim must have unit stride"); - if (srcMemSpace != kGenericMemorySpace && srcMemSpace != kSharedMemorySpace && - srcMemSpace != kGlobalMemorySpace) - return emitError( - "source memorySpace kGenericMemorySpace, kSharedMemorySpace or " - "kGlobalMemorySpace only allowed"); - if (!operand.equals("AOp") && !operand.equals("BOp") && !operand.equals("COp")) return emitError("only AOp, BOp and COp can be loaded"); @@ -1220,17 +1213,11 @@ auto dstType = getDstMemref().getType(); auto srcMatrixType = srcType.cast(); auto dstMemrefType = dstType.cast(); - auto dstMemSpace = dstMemrefType.getMemorySpaceAsInt(); if (!isLastMemrefDimUnitStride(dstMemrefType)) return emitError( "expected destination memref most minor dim must have unit stride"); - if (dstMemSpace != kGenericMemorySpace && dstMemSpace != kSharedMemorySpace && - dstMemSpace != kGlobalMemorySpace) - return emitError("destination memorySpace of kGenericMemorySpace, " - "kGlobalMemorySpace or kSharedMemorySpace only allowed"); -
if (!srcMatrixType.getOperand().equals("COp")) return emitError( "expected the operand matrix being stored to have 'COp' operand type"); Index: mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir =================================================================== --- /dev/null +++ mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir @@ -0,0 +1,110 @@ +// RUN: mlir-opt -allow-unregistered-dialect -convert-gpu-to-spirv -split-input-file -verify-diagnostics %s | FileCheck %s + +module attributes { + gpu.container_module, + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { + gpu.module @kernels { + // CHECK: spirv.module @{{.*}} Logical GLSL450 { + // CHECK-LABEL: spirv.func @gpu_wmma_load_op + // CHECK-SAME: {{%.*}}: !spirv.ptr [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>} + // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi : vector<3xi32>> + gpu.func @gpu_wmma_load_op(%arg0 : memref<32x32xf16, #spirv.storage_class>) kernel + attributes {spirv.entry_point_abi = #spirv.entry_point_abi: vector<3xi32>>} { + %i = arith.constant 16 : index + %j = arith.constant 16 : index + // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr as !spirv.coopmatrix<16x16xf16, Subgroup> + %0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, #spirv.storage_class> -> !gpu.mma_matrix<16x16xf16, "COp"> + // CHECK: spirv.Return + gpu.return + } + } +} + +// ----- + +module attributes { + gpu.container_module, + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { + gpu.module @kernels { + // CHECK: spirv.module @{{.*}} Logical GLSL450 { + // CHECK-LABEL: spirv.func @gpu_wmma_store_op + // CHECK-SAME: {{%.*}}: !spirv.ptr [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>} + // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = 
#spirv.interface_var_abi<(0, 1)>}) + // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi : vector<3xi32>> + gpu.func @gpu_wmma_store_op(%arg0 : memref<32x32xf16, #spirv.storage_class>, %arg1 : !gpu.mma_matrix<16x16xf16, "COp">) kernel + attributes {spirv.entry_point_abi = #spirv.entry_point_abi: vector<3xi32>>} { + %i = arith.constant 16 : index + %j = arith.constant 16 : index + // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr, !spirv.coopmatrix<16x16xf16, Subgroup> + gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class> + // CHECK: spirv.Return + gpu.return + } + } +} + +// ----- + +module attributes { + gpu.container_module, + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { + gpu.module @kernels { + // CHECK: spirv.module @{{.*}} Logical GLSL450 { + // CHECK-LABEL: spirv.func @gpu_wmma_mma_op + // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>} + // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>} + // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 2)>}) + // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi : vector<3xi32>> + gpu.func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) kernel + attributes {spirv.entry_point_abi = #spirv.entry_point_abi: vector<3xi32>>} { + // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup> + %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, 
"BOp"> -> !gpu.mma_matrix<16x16xf16, "COp"> + // CHECK: spirv.Return + gpu.return + } + } +} + +// ----- + +module attributes { + gpu.container_module, + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { + gpu.module @kernels { + // CHECK: spirv.module @{{.*}} Logical GLSL450 { + // CHECK-LABEL: spirv.func @gpu_wmma_constant_op + gpu.func @gpu_wmma_constant_op() kernel + attributes {spirv.entry_point_abi = #spirv.entry_point_abi: vector<3xi32>>} { + // CHECK: {{%.*}} = spirv.Constant + %cst = arith.constant 1.0 : f16 + // CHECK: {{%.*}} = spirv.CompositeConstruct {{%.*}} : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup> + %C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp"> + // CHECK: spirv.Return + gpu.return + } + } +} + +// ----- + +module attributes { + gpu.container_module, + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { + gpu.module @kernels { + // CHECK: spirv.module @{{.*}} Logical GLSL450 { + // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op + // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>} + // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) + gpu.func @gpu_wmma_elementwise_op(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) kernel + attributes {spirv.entry_point_abi = #spirv.entry_point_abi: vector<3xi32>>} { + // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup> + %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp"> + // CHECK: {{%.*}} = spirv.FNegate {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup> + %D = gpu.subgroup_mma_elementwise negatef %C : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp"> + // CHECK: {{%.*}} = spirv.FDiv 
{{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup> + %E = gpu.subgroup_mma_elementwise divf %D, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp"> + // CHECK: spirv.Return + gpu.return + } + } +} \ No newline at end of file Index: mlir/test/Dialect/GPU/invalid.mlir =================================================================== --- mlir/test/Dialect/GPU/invalid.mlir +++ mlir/test/Dialect/GPU/invalid.mlir @@ -515,16 +515,6 @@ // ----- -func.func @mmaLoadOp_invalid_mem_space(){ - %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 5> - %i = arith.constant 16 : index - // expected-error @+1 {{source memorySpace kGenericMemorySpace, kSharedMemorySpace or kGlobalMemorySpace only allowed}} - %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 5> -> !gpu.mma_matrix<16x16xf16, "AOp"> - return -} - -// ----- - #layout_map_col_major = affine_map<(i, j) -> (j, i)> func.func @wmmaStoreOp_invalid_map(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () { @@ -538,17 +528,6 @@ // ----- -func.func @wmmaStoreOp_invalid_mem_space(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () { - %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 5> - %i = arith.constant 16 : index - %j = arith.constant 16 : index - // expected-error @+1 {{destination memorySpace of kGenericMemorySpace, kGlobalMemorySpace or kSharedMemorySpace only allowed}} - gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 5> - return -} - -// ----- - func.func @wmmaStoreOp_invalid_store_operand(%arg0 : !gpu.mma_matrix<16x16xf16, "AOp">) -> () { %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3> %i = arith.constant 16 : index