[mlir][vector] Model transform.vector.lower_vectors options as TableGen enums

The VectorTransposeLowering, VectorMultiReductionLowering,
VectorContractLowering and VectorTransferSplit enums move from a
hand-written block in VectorRewritePatterns.h to I32EnumAttr definitions in
the new VectorTransformsBase.td. CMake generates VectorTransformsEnums.h.inc
(-gen-enum-decls, included by VectorRewritePatterns.h) and
VectorTransformsEnums.cpp.inc (-gen-enum-defs, included by
VectorTransforms.cpp). The C++ enumerator names and numeric values are
unchanged.

transform.vector.lower_vectors now reads contraction_lowering,
multireduction_lowering, split_transfers and transpose_lowering as typed enum
values, so the llvm::StringSwitch decoding in VectorTransformOps.cpp is
removed. Those four options get an oilist custom assembly format;
transpose_avx2_lowering and unroll_vector_transfers are still only accepted
through the attribute dictionary.

User-visible spelling changes:
  contraction_lowering: "matrixintrinsics" is now "matmulintrinsics";
    "parallelarith" becomes reachable.
  split_transfers: "vector-transfers" is now "vector-transfer";
    "force-in-bounds" becomes reachable.
Previously an unrecognized string silently fell back to the StringSwitch
default. The two lower_vectors tests spelled the option "innerreduce", which
never matched "innerreduction", so they used to run with InnerParallel; they
now request InnerReduction.

NOTE(review): the DefaultValuedAttr default values are not legible in this
copy of the patch; confirm they match the old StringSwitch fallbacks
(EltWise, InnerParallel, OuterProduct, None) or call out the change.
NOTE(review): presumably unknown spellings are now rejected by the generated
enum parser rather than ignored; confirm with a negative test.

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h @@ -11,6 +11,7 @@ #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/IR/OpImplementation.h" namespace mlir { diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -13,6 +13,7 @@ include "mlir/Dialect/Transform/IR/TransformEffects.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" include "mlir/Dialect/PDL/IR/PDLTypes.td" +include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpBase.td" @@ -31,16 +32,30 @@ // TODO: evolve this to proper enums.
let arguments = (ins PDL_Operation:$target, - DefaultValuedAttr:$contraction_lowering, - DefaultValuedAttr:$multireduction_lowering, - DefaultValuedAttr:$split_transfers, - DefaultValuedAttr:$transpose_lowering, + DefaultValuedAttr:$contraction_lowering, + DefaultValuedAttr: + $multireduction_lowering, + DefaultValuedAttr:$split_transfers, + DefaultValuedAttr:$transpose_lowering, DefaultValuedAttr:$transpose_avx2_lowering, DefaultValuedAttr:$unroll_vector_transfers ); let results = (outs PDL_Operation:$results); - let assemblyFormat = "$target attr-dict"; + let assemblyFormat = [{ + $target + oilist ( + `contraction_lowering` `=` $contraction_lowering + | `multireduction_lowering` `=` $multireduction_lowering + | `split_transfers` `=` $split_transfers + | `transpose_lowering` `=` $transpose_lowering + ) + attr-dict + }]; } #endif // VECTOR_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Vector/Transforms/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Vector/Transforms/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Vector/Transforms/CMakeLists.txt @@ -1,3 +1,7 @@ +set(LLVM_TARGET_DEFINITIONS VectorTransformsBase.td) +mlir_tablegen(VectorTransformsEnums.h.inc -gen-enum-decls) +mlir_tablegen(VectorTransformsEnums.cpp.inc -gen-enum-defs) + set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls -name Vector) add_public_tablegen_target(MLIRVectorTransformsIncGen) diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -13,6 +13,7 @@ #include #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h" @@ -25,48 +26,6 @@ //===----------------------------------------------------------------------===// // Vector transformation options exposed as auxiliary structs. //===----------------------------------------------------------------------===// -/// Enum to control the lowering of `vector.transpose` operations. -enum class VectorTransposeLowering { - /// Lower transpose into element-wise extract and inserts. - EltWise = 0, - /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix - /// intrinsics. - Flat = 1, - /// Lower 2-D transpose to `vector.shuffle`. - Shuffle = 2, -}; -/// Enum to control the lowering of `vector.multi_reduction` operations. -enum class VectorMultiReductionLowering { - /// Lower multi_reduction into outer-reduction and inner-parallel ops. - InnerParallel = 0, - /// Lower multi_reduction into outer-parallel and inner-reduction ops. - InnerReduction = 1, -}; -/// Enum to control the lowering of `vector.contract` operations. -enum class VectorContractLowering { - /// Progressively lower to finer grained `vector.contract` and dot-products. - Dot = 0, - /// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics. - Matmul = 1, - /// Lower to `vector.outerproduct`. - OuterProduct = 2, - /// Lower contract with all reduction dimensions unrolled to 1 to a vector - /// elementwise operations. - ParallelArith = 3, -}; -/// Enum to control the splitting of `vector.transfer` operations into -/// in-bounds and out-of-bounds variants. -enum class VectorTransferSplit { - /// Do not split vector transfer operations. - None = 0, - /// Split using in-bounds + out-of-bounds vector.transfer operations. - VectorTransfer = 1, - /// Split using an in-bounds vector.transfer + linalg.fill + linalg.copy - /// operations. - LinalgCopy = 2, - /// Do not split vector transfer operation but instead mark it as "in-bounds".
- ForceInBounds = 3 -}; /// Structure to control the behavior of vector transform patterns. struct VectorTransformsOptions { /// Option to control the lowering of vector.contract. diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td @@ -0,0 +1,86 @@ +//===- VectorTransformsBase.td - Vector transform enums ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef VECTOR_TRANSFORMS_BASE +#define VECTOR_TRANSFORMS_BASE + +include "mlir/IR/EnumAttr.td" + +// Lower transpose into element-wise extracts and inserts. +def VectorTransposeLowering_Elementwise: + I32EnumAttrCase<"EltWise", 0, "eltwise">; +// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix +// intrinsics. +def VectorTransposeLowering_FlatTranspose: + I32EnumAttrCase<"Flat", 1, "flat_transpose">; +// Lower 2-D transpose to `vector.shuffle`. +def VectorTransposeLowering_Shuffle: + I32EnumAttrCase<"Shuffle", 2, "shuffle">; +def VectorTransposeLoweringAttr : I32EnumAttr< + "VectorTransposeLowering", + "control the lowering of `vector.transpose` operations.", + [VectorTransposeLowering_Elementwise, VectorTransposeLowering_FlatTranspose, + VectorTransposeLowering_Shuffle]> { + let cppNamespace = "::mlir::vector"; +} + +// Lower multi_reduction into outer-reduction and inner-parallel ops. +def VectorMultiReductionLowering_InnerParallel: + I32EnumAttrCase<"InnerParallel", 0, "innerparallel">; +// Lower multi_reduction into outer-parallel and inner-reduction ops.
+def VectorMultiReductionLowering_InnerReduction: + I32EnumAttrCase<"InnerReduction", 1, "innerreduction">; +def VectorMultiReductionLoweringAttr: I32EnumAttr< + "VectorMultiReductionLowering", + "control the lowering of `vector.multi_reduction`.", + [VectorMultiReductionLowering_InnerParallel, + VectorMultiReductionLowering_InnerReduction]> { + let cppNamespace = "::mlir::vector"; +} + +// Progressively lower to finer grained `vector.contract` and dot-products. +def VectorContractLowering_Dot: I32EnumAttrCase<"Dot", 0, "dot">; +// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics. +def VectorContractLowering_Matmul: + I32EnumAttrCase<"Matmul", 1, "matmulintrinsics">; +// Lower to `vector.outerproduct`. +def VectorContractLowering_OuterProduct: + I32EnumAttrCase<"OuterProduct", 2, "outerproduct">; +// Lower contract with all reduction dimensions unrolled to 1 to vector +// elementwise operations. +def VectorContractLowering_ParallelArith: + I32EnumAttrCase<"ParallelArith", 3, "parallelarith">; +def VectorContractLoweringAttr: I32EnumAttr< + "VectorContractLowering", + "control the lowering of `vector.contract` operations.", + [VectorContractLowering_Dot, VectorContractLowering_Matmul, + VectorContractLowering_OuterProduct, VectorContractLowering_ParallelArith]> { + let cppNamespace = "::mlir::vector"; +} + +// Do not split vector transfer operations. +def VectorTransferSplit_None: I32EnumAttrCase<"None", 0, "none">; +// Split using in-bounds + out-of-bounds vector.transfer operations. +def VectorTransferSplit_VectorTransfer: + I32EnumAttrCase<"VectorTransfer", 1, "vector-transfer">; +// Split using an in-bounds vector.transfer + linalg.fill + linalg.copy +// operations. +def VectorTransferSplit_LinalgCopy: + I32EnumAttrCase<"LinalgCopy", 2, "linalg-copy">; +// Do not split vector transfer operation but instead mark it as "in-bounds".
+def VectorTransferSplit_ForceInBounds: + I32EnumAttrCase<"ForceInBounds", 3, "force-in-bounds">; +def VectorTransferSplitAttr: I32EnumAttr< + "VectorTransferSplit", + "control the splitting of `vector.transfer` operations into in-bounds" + " and out-of-bounds variants.", + [VectorTransferSplit_None, VectorTransferSplit_VectorTransfer, + VectorTransferSplit_LinalgCopy, VectorTransferSplit_ForceInBounds]> { + let cppNamespace = "::mlir::vector"; +} +#endif // VECTOR_TRANSFORMS_BASE diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -53,32 +53,12 @@ MLIRContext *ctx = getContext(); RewritePatternSet patterns(ctx); vector::VectorTransposeLowering vectorTransposeLowering = - llvm::StringSwitch( - getTransposeLowering()) - .Case("eltwise", vector::VectorTransposeLowering::EltWise) - .Case("flat_transpose", vector::VectorTransposeLowering::Flat) - .Case("shuffle", vector::VectorTransposeLowering::Shuffle) - .Default(vector::VectorTransposeLowering::EltWise); + getTransposeLowering(); vector::VectorMultiReductionLowering vectorMultiReductionLowering = - llvm::StringSwitch( - getMultireductionLowering()) - .Case("innerreduction", - vector::VectorMultiReductionLowering::InnerReduction) - .Default(vector::VectorMultiReductionLowering::InnerParallel); + getMultireductionLowering(); vector::VectorContractLowering vectorContractLowering = - llvm::StringSwitch( - getContractionLowering()) - .Case("matrixintrinsics", vector::VectorContractLowering::Matmul) - .Case("dot", vector::VectorContractLowering::Dot) - .Case("outerproduct", vector::VectorContractLowering::OuterProduct) - .Default(vector::VectorContractLowering::OuterProduct); - vector::VectorTransferSplit vectorTransferSplit = - llvm::StringSwitch(getSplitTransfers()) - .Case("none", vector::VectorTransferSplit::None) -
.Case("linalg-copy", vector::VectorTransferSplit::LinalgCopy) - .Case("vector-transfers", - vector::VectorTransferSplit::VectorTransfer) - .Default(vector::VectorTransferSplit::None); + getContractionLowering(); + vector::VectorTransferSplit vectorTransferSplit = getSplitTransfers(); vector::VectorTransformsOptions vectorTransformOptions; vectorTransformOptions.setVectorTransformsOptions(vectorContractLowering) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -3040,3 +3040,9 @@ RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(patterns.getContext(), benefit); } + +//===----------------------------------------------------------------------===// +// TableGen'd enum attribute definitions +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc" diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir --- a/mlir/test/Dialect/LLVM/transform-e2e.mlir +++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir @@ -21,5 +21,5 @@ transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %module_op {bufferize_function_boundaries = true} %func = transform.structured.match ops{["func.func"]} in %module_op - transform.vector.lower_vectors %func { multireduction_lowering = "innerreduce"} + transform.vector.lower_vectors %func multireduction_lowering = "innerreduction" } diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir --- a/mlir/test/Dialect/Vector/transform-vector.mlir +++ b/mlir/test/Dialect/Vector/transform-vector.mlir @@ -22,5 +22,5 @@ transform.bufferization.one_shot_bufferize %module_op %func = transform.structured.match ops{["func.func"]} in %module_op -
transform.vector.lower_vectors %func { multireduction_lowering = "innerreduce"} + transform.vector.lower_vectors %func multireduction_lowering = "innerreduction" }