Index: mlir/include/mlir/Dialect/GPU/CMakeLists.txt
===================================================================
--- mlir/include/mlir/Dialect/GPU/CMakeLists.txt
+++ mlir/include/mlir/Dialect/GPU/CMakeLists.txt
@@ -22,4 +22,9 @@
 mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix GPU)
 add_public_tablegen_target(MLIRGPUPassIncGen)
+set(LLVM_TARGET_DEFINITIONS GPUOps.td)
+mlir_tablegen(GPUOpsEnums.h.inc -gen-enum-decls)
+mlir_tablegen(GPUOpsEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRGPUOpsEnumsGen)
+
 
 add_mlir_doc(Passes GPUPasses ./ -gen-pass-doc)
Index: mlir/include/mlir/Dialect/GPU/GPUBase.td
===================================================================
--- mlir/include/mlir/Dialect/GPU/GPUBase.td
+++ mlir/include/mlir/Dialect/GPU/GPUBase.td
@@ -115,18 +115,4 @@
   ];
 }
 
-// Cases of the String enum Attribute for SubgroupMmaOpLayout, representing
-// the layouts of the operands supported by the ops that use this attribute.
-def RowMajor: StrEnumAttrCase<"RowMajor", 0>;
-def ColMajor: StrEnumAttrCase<"ColMajor", 1>;
-
-// Specifies a String enum Attribute for Warp wide matrix operations,
-// representing the layout of respective operands. The layout later governs
-// the lowerings to appropriate intrinsics.
-def SubgroupMmaOpLayout: StrEnumAttr<"Layout", "Specifies whether op is row/col major",
-                                     [RowMajor, ColMajor]> {
-  let stringToSymbolFnName = "LayoutStrToEnum";
-  let symbolToStringFnName = "EnumToLayoutStr";
-}
-
 #endif // GPU_BASE
Index: mlir/include/mlir/Dialect/GPU/GPUDialect.h
===================================================================
--- mlir/include/mlir/Dialect/GPU/GPUDialect.h
+++ mlir/include/mlir/Dialect/GPU/GPUDialect.h
@@ -166,6 +166,8 @@
 } // end namespace gpu
 } // end namespace mlir
 
+#include "mlir/Dialect/GPU/GPUOpsEnums.h.inc"
+
 #include "mlir/Dialect/GPU/GPUOpsDialect.h.inc"
 
 #include "mlir/Dialect/GPU/GPUOpInterfaces.h.inc"
Index: mlir/include/mlir/Dialect/GPU/GPUOps.td
===================================================================
--- mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -591,13 +591,15 @@
 }
 
 // add, mul mirror the XLA ComparisonDirection enum.
-def GPU_AllReduceOpAdd : StrEnumAttrCase<"add">;
-def GPU_AllReduceOpAnd : StrEnumAttrCase<"and">;
-def GPU_AllReduceOpMax : StrEnumAttrCase<"max">;
-def GPU_AllReduceOpMin : StrEnumAttrCase<"min">;
-def GPU_AllReduceOpMul : StrEnumAttrCase<"mul">;
-def GPU_AllReduceOpOr : StrEnumAttrCase<"or">;
-def GPU_AllReduceOpXor : StrEnumAttrCase<"xor">;
+// C++ enumerator names are upper case (`and`, `or` and `xor` are reserved
+// alternative tokens in C++); the strings stored in the IR stay lower case.
+def GPU_AllReduceOpAdd : StrEnumAttrCase<"ADD", -1, "add">;
+def GPU_AllReduceOpAnd : StrEnumAttrCase<"AND", -1, "and">;
+def GPU_AllReduceOpMax : StrEnumAttrCase<"MAX", -1, "max">;
+def GPU_AllReduceOpMin : StrEnumAttrCase<"MIN", -1, "min">;
+def GPU_AllReduceOpMul : StrEnumAttrCase<"MUL", -1, "mul">;
+def GPU_AllReduceOpOr : StrEnumAttrCase<"OR", -1, "or">;
+def GPU_AllReduceOpXor : StrEnumAttrCase<"XOR", -1, "xor">;
 
 def GPU_AllReduceOperationAttr : StrEnumAttr<"AllReduceOperationAttr",
     "built-in reduction operations supported by gpu.allreduce.",
@@ -644,7 +646,7 @@
   let verifier = [{ return ::verifyAllReduce(*this); }];
 }
 
-def GPU_ShuffleOpXor : StrEnumAttrCase<"xor">;
+def GPU_ShuffleOpXor : StrEnumAttrCase<"XOR", -1, "xor">;
 
 def GPU_ShuffleModeAttr : StrEnumAttr<"ShuffleModeAttr",
     "Indexing modes supported by gpu.shuffle.",
@@ -1121,4 +1123,62 @@
   }];
 }
 
+// Elementwise operations that `gpu.subgroup_mma_elementwise` can apply to its
+// `!gpu.mma_matrix` operands.
+def GPU_ELEMENTWISE_OP_ADDF : StrEnumAttrCase<"ADDF">;
+def GPU_ELEMENTWISE_OP_MULF : StrEnumAttrCase<"MULF">;
+def GPU_ELEMENTWISE_OP_MAXF : StrEnumAttrCase<"MAXF">;
+def GPU_ELEMENTWISE_OP_MINF : StrEnumAttrCase<"MINF">;
+
+def MMAElementWiseAttr : StrEnumAttr<"MMAElementwiseOp",
+  "elementwise operation to apply to mma matrix",
+  [GPU_ELEMENTWISE_OP_ADDF, GPU_ELEMENTWISE_OP_MULF,
+   GPU_ELEMENTWISE_OP_MAXF, GPU_ELEMENTWISE_OP_MINF]> {
+  let cppNamespace = "::mlir::gpu";
+  let storageType = "::mlir::StringAttr";
+  let returnType = "::mlir::gpu::MMAElementwiseOp";
+  let convertFromStorage = "*symbolizeMMAElementwiseOp($_self.getValue())";
+  let constBuilderCall = "$_builder.getStringAttr(stringifyEnum($0))";
+}
+
+def GPU_SubgroupMmaElementwiseOp : GPU_Op<"subgroup_mma_elementwise",
+    [NoSideEffect,
+     AllTypesMatch<["args"]>]>{
+
+  let summary = "GPU warp elementwise operation on a matrix";
+
+  let description = [{
+    The `gpu.subgroup_mma_elementwise` op takes `!gpu.mma_matrix` inputs and
+    computes a new `!gpu.mma_matrix` by applying an elementwise operation to
+    each element.
+
+    Since the operation is elementwise and the matrix types must match, the
+    matrix elements are processed independently of the matrix layout.
+
+    This op is meant to be used along with `gpu.subgroup_mma_compute`.
+
+    Example:
+
+    ```mlir
+     %0 = gpu.subgroup_mma_elementwise %A, %B { operation = "ADDF" } :
+      (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">)
+      -> !gpu.mma_matrix<16x16xf16, "COp">
+    ```
+  }];
+
+  let arguments = (ins Variadic<GPU_MMAMatrix>:$args, MMAElementWiseAttr:$operation);
+
+  let results = (outs GPU_MMAMatrix:$res);
+
+  let extraClassDeclaration = [{
+    gpu::MMAMatrixType getType() {
+      return res().getType().cast<gpu::MMAMatrixType>();
+    }
+  }];
+
+  let assemblyFormat = [{
+    $args attr-dict `:` functional-type($args, $res)
+  }];
+}
+
 #endif // GPU_OPS
Index: mlir/include/mlir/IR/OpBase.td
===================================================================
--- mlir/include/mlir/IR/OpBase.td
+++ mlir/include/mlir/IR/OpBase.td
@@ -1187,11 +1187,13 @@
 }
 
 // An enum attribute case stored with StringAttr.
-class StrEnumAttrCase<string sym, int val = -1> :
-    EnumAttrCaseInfo<sym, val, sym>,
+// `sym` is the C++ enumerator name, `val` its integer value (defaults to -1)
+// and `str` the string stored in the StringAttr, which defaults to `sym`.
+class StrEnumAttrCase<string sym, int val = -1, string str = sym> :
+    EnumAttrCaseInfo<sym, val, str>,
     StringBasedAttr<
-      CPred<"$_self.cast<::mlir::StringAttr>().getValue() == \"" # sym # "\"">,
-      "case " # sym>;
+      CPred<"$_self.cast<::mlir::StringAttr>().getValue() == \"" # str # "\"">,
+      "case " # str>;
 
 // An enum attribute case stored with IntegerAttr, which has an integer value,
 // its representation as a string and a C++ symbol name which may be different.
Index: mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
===================================================================
--- mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/IR/TypeUtilities.h"
 
 using namespace mlir;
 
@@ -352,13 +353,107 @@
   }
 };
 
+/// Builds a floating-point minimum (`isMin` = true) or maximum of `lhs` and
+/// `rhs` that propagates NaN: if either operand is NaN the result is a quiet
+/// NaN, otherwise the smaller (resp. larger) operand is selected. `lhs` and
+/// `rhs` may be float scalars or vectors; for vectors the comparison masks
+/// are vectors of i1 with the same shape.
+static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
+                           Value rhs, bool isMin) {
+  auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
+  Type i1Type = builder.getI1Type();
+  if (auto vecType = lhs.getType().dyn_cast<VectorType>())
+    i1Type = VectorType::get(vecType.getShape(), i1Type);
+  Value cmp = builder.create<LLVM::FCmpOp>(
+      loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
+      lhs, rhs);
+  Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
+  // `uno` is true when at least one operand is NaN.
+  Value isNan = builder.create<LLVM::FCmpOp>(
+      loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
+  Value nan = builder.create<LLVM::ConstantOp>(
+      loc, lhs.getType(),
+      builder.getFloatAttr(floatType,
+                           APFloat::getQNaN(floatType.getFloatSemantics())));
+  // Yield NaN where an operand is NaN and the ordered min/max elsewhere.
+  return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
+}
+
+/// Creates the LLVM dialect operation computing `op` on `operands`, the
+/// values extracted at the same struct position from every matrix operand.
+/// All supported operations are binary.
+static Value createScalarOp(OpBuilder &builder, Location loc,
+                            gpu::MMAElementwiseOp op,
+                            ArrayRef<Value> operands) {
+  switch (op) {
+  case gpu::MMAElementwiseOp::ADDF:
+    return builder.create<LLVM::FAddOp>(loc, operands[0].getType(), operands);
+  case gpu::MMAElementwiseOp::MULF:
+    return builder.create<LLVM::FMulOp>(loc, operands[0].getType(), operands);
+  case gpu::MMAElementwiseOp::MAXF:
+    return createMinMaxF(builder, loc, operands[0], operands[1],
+                         /*isMin=*/false);
+  case gpu::MMAElementwiseOp::MINF:
+    return createMinMaxF(builder, loc, operands[0], operands[1],
+                         /*isMin=*/true);
+  }
+  llvm_unreachable("unknown op");
+}
+
+/// Convert GPU MMA elementwise ops to extract + op + insert.
+struct WmmaElementwiseOpToNVVMLowering
+    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaElementwiseOp> {
+  using ConvertOpToLLVMPattern<
+      gpu::SubgroupMmaElementwiseOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(areAllLLVMTypes(subgroupMmaElementwiseOp.getOperation(),
+                               adaptor.getOperands(), rewriter)))
+      return failure();
+    // All supported operations are binary and `createScalarOp` reads the
+    // first two operands unconditionally, so reject any other arity.
+    size_t numOperands = adaptor.getOperands().size();
+    if (numOperands != 2)
+      return rewriter.notifyMatchFailure(subgroupMmaElementwiseOp,
+                                         "expected exactly two operands");
+    Location loc = subgroupMmaElementwiseOp.getLoc();
+    LLVM::LLVMStructType destType =
+        convertMMAToLLVMType(subgroupMmaElementwiseOp.getType());
+    Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, destType);
+    // Process the struct member by member: extract member `i` of every
+    // operand, combine them and insert the result as member `i`.
+    for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
+      SmallVector<Value> extractedOperands;
+      for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
+        Type elementType = adaptor.getOperands()[opIdx]
+                               .getType()
+                               .cast<LLVM::LLVMStructType>()
+                               .getBody()[i];
+        extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
+            loc, elementType, adaptor.getOperands()[opIdx],
+            rewriter.getI32ArrayAttr(i)));
+      }
+      Value element =
+          createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.operation(),
+                         extractedOperands);
+      matrixStruct = rewriter.create<LLVM::InsertValueOp>(
+          loc, matrixStruct, element, rewriter.getI32ArrayAttr(i));
+    }
+    rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct);
+    return success();
+  }
+};
+
 } // anonymous namespace
 
 namespace mlir {
 void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                              RewritePatternSet &patterns) {
   patterns.insert<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
-                  WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering>(
-      converter);
+                  WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
+                  WmmaElementwiseOpToNVVMLowering>(converter);
 }
 } // namespace mlir
Index: mlir/lib/Dialect/GPU/CMakeLists.txt
===================================================================
--- mlir/lib/Dialect/GPU/CMakeLists.txt
+++ mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -24,6 +24,7 @@
 
   DEPENDS
   MLIRGPUOpsIncGen
+  MLIRGPUOpsEnumsGen
   MLIRGPUOpInterfacesIncGen
 
   LINK_LIBS PUBLIC
Index: mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
===================================================================
--- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1185,6 +1185,7 @@
 }
 
 #include "mlir/Dialect/GPU/GPUOpInterfaces.cpp.inc"
+#include "mlir/Dialect/GPU/GPUOpsEnums.cpp.inc"
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/GPU/GPUOps.cpp.inc"
Index: mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
===================================================================
--- mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
+++ mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -220,3 +220,52 @@
     return %C : !gpu.mma_matrix<16x16xf16, "COp">
   }
 }
+
+// -----
+
+gpu.module @test_module {
+
+// CHECK-LABEL: func @gpu_wmma_elementwise
+// CHECK: %[[M0:.*]] = llvm.mlir.undef : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[A0:.*]] = llvm.extractvalue %{{.*}}[0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[B0:.*]] = llvm.extractvalue %{{.*}}[0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[C0:.*]] = llvm.fadd %[[A0]], %[[B0]] : vector<2xf16>
+// CHECK: %[[M1:.*]] = llvm.insertvalue %[[C0]], %[[M0]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[A1:.*]] = llvm.extractvalue %{{.*}}[1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[B1:.*]] = llvm.extractvalue %{{.*}}[1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[C1:.*]] = llvm.fadd %[[A1]], %[[B1]] : vector<2xf16>
+// CHECK: %[[M2:.*]] = llvm.insertvalue %[[C1]], %[[M1]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[A2:.*]] = llvm.extractvalue %{{.*}}[2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[B2:.*]] = llvm.extractvalue %{{.*}}[2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[C2:.*]] = llvm.fadd %[[A2]], %[[B2]] : vector<2xf16>
+// CHECK: %[[M3:.*]] = llvm.insertvalue %[[C2]], %[[M2]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[A3:.*]] = llvm.extractvalue %{{.*}}[3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[B3:.*]] = llvm.extractvalue %{{.*}}[3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[C3:.*]] = llvm.fadd %[[A3]], %[[B3]] : vector<2xf16>
+// CHECK: %[[M4:.*]] = llvm.insertvalue %[[C3]], %[[M3]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: llvm.return %[[M4]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+  builtin.func @gpu_wmma_elementwise(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) ->(!gpu.mma_matrix<16x16xf16, "COp">) {
+    %C = gpu.subgroup_mma_elementwise %A, %B { operation = "ADDF" } :
+      (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+    return %C : !gpu.mma_matrix<16x16xf16, "COp">
+  }
+}
+
+// -----
+
+gpu.module @test_module {
+
+// CHECK-LABEL: func @gpu_wmma_elementwise_maxf
+// CHECK: %[[A0:.*]] = llvm.extractvalue %{{.*}}[0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[B0:.*]] = llvm.extractvalue %{{.*}}[0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[CMP:.*]] = llvm.fcmp "ogt" %[[A0]], %[[B0]] : vector<2xf16>
+// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[A0]], %[[B0]] : vector<2xi1>, vector<2xf16>
+// CHECK: %[[ISNAN:.*]] = llvm.fcmp "uno" %[[A0]], %[[B0]] : vector<2xf16>
+// CHECK: %[[NAN:.*]] = llvm.mlir.constant({{.*}}) : vector<2xf16>
+// CHECK: %{{.*}} = llvm.select %[[ISNAN]], %[[NAN]], %[[SEL]] : vector<2xi1>, vector<2xf16>
+  builtin.func @gpu_wmma_elementwise_maxf(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) ->(!gpu.mma_matrix<16x16xf16, "COp">) {
+    %C = gpu.subgroup_mma_elementwise %A, %B { operation = "MAXF" } :
+      (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+    return %C : !gpu.mma_matrix<16x16xf16, "COp">
+  }
+}
Index: mlir/test/Dialect/GPU/ops.mlir
===================================================================
--- mlir/test/Dialect/GPU/ops.mlir
+++ mlir/test/Dialect/GPU/ops.mlir
@@ -220,7 +220,11 @@
     %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
     // CHECK: gpu.subgroup_mma_load_matrix %[[wg]][%[[i]], %[[i]]] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
     %1 = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf32, "COp">
     // CHECK: gpu.subgroup_mma_constant_matrix %[[cst]] : !gpu.mma_matrix<16x16xf32, "COp">
+    // CHECK: gpu.subgroup_mma_elementwise %{{.*}}, %{{.*}} {operation = "ADDF"} : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+    %2 = gpu.subgroup_mma_elementwise %1, %1 {operation = "ADDF"} : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+    // CHECK: gpu.subgroup_mma_elementwise %{{.*}}, %{{.*}} {operation = "MAXF"} : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+    %3 = gpu.subgroup_mma_elementwise %2, %1 {operation = "MAXF"} : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
     return
   }
 }
Index: utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
===================================================================
--- utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2768,6 +2768,14 @@
             ["-gen-op-defs"],
             "include/mlir/Dialect/GPU/GPUOps.cpp.inc",
         ),
+        (
+            ["-gen-enum-decls"],
+            "include/mlir/Dialect/GPU/GPUOpsEnums.h.inc",
+        ),
+        (
+            ["-gen-enum-defs"],
+            "include/mlir/Dialect/GPU/GPUOpsEnums.cpp.inc",
+        ),
     ],
     tblgen = ":mlir-tblgen",
     td_file = "include/mlir/Dialect/GPU/GPUOps.td",