diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -16,4 +16,31 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpBase.td" +def LowerVectorsOp : Op, + DeclareOpInterfaceMethods]> { + let description = [{ + Indicates that the vector operations nested under the isolated-from-above op + `target` should be lowered to finer-grained vector primitives. + + At this time, the transform is all or nothing. + + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. + }]; + + // TODO: evolve this to proper enums. + let arguments = (ins PDL_Operation:$target, + DefaultValuedAttr:$contraction_lowering, + DefaultValuedAttr:$multireduction_lowering, + DefaultValuedAttr:$split_transfers, + DefaultValuedAttr:$transpose_lowering, + DefaultValuedAttr:$transpose_avx2_lowering, + DefaultValuedAttr:$unroll_vector_transfers + ); + let results = (outs PDL_Operation:$results); + + let assemblyFormat = "$target attr-dict"; +} + #endif // VECTOR_TRANSFORM_OPS diff --git a/mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt --- a/mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt @@ -16,4 +16,6 @@ MLIRSideEffectInterfaces MLIRTransformDialect MLIRVectorDialect + MLIRVectorToSCF + MLIRX86VectorTransforms ) diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -8,11 +8,14 @@ #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
+#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" +#include "mlir/Dialect/X86Vector/Transforms.h" #include "mlir/Parser/Parser.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -20,6 +23,128 @@ using namespace mlir::vector; using namespace mlir::transform; +//===----------------------------------------------------------------------===// +// LowerVectorsOp +//===----------------------------------------------------------------------===// + +void transform::LowerVectorsOp::getEffects( + SmallVectorImpl &effects) { + consumesHandle(getTarget(), effects); + producesHandle(getResults(), effects); + modifiesPayload(effects); +} + +DiagnosedSilenceableFailure transform::LowerVectorsOp::apply( + mlir::transform::TransformResults &transformResults, + mlir::transform::TransformState &state) { + + SmallVector results; + ArrayRef payloadOps = state.getPayloadOps(getTarget()); + for (Operation *target : payloadOps) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist.
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + MLIRContext *ctx = getContext(); + RewritePatternSet patterns(ctx); + vector::VectorTransposeLowering vectorTransposeLowering = + llvm::StringSwitch( + getTransposeLowering()) + .Case("eltwise", vector::VectorTransposeLowering::EltWise) + .Case("flat_transpose", vector::VectorTransposeLowering::Flat) + .Case("shuffle", vector::VectorTransposeLowering::Shuffle) + .Default(vector::VectorTransposeLowering::EltWise); + vector::VectorMultiReductionLowering vectorMultiReductionLowering = + llvm::StringSwitch( + getMultireductionLowering()) + .Case("innerreduction", + vector::VectorMultiReductionLowering::InnerReduction) + .Default(vector::VectorMultiReductionLowering::InnerParallel); + vector::VectorContractLowering vectorContractLowering = + llvm::StringSwitch( + getContractionLowering()) + .Case("matrixintrinsics", vector::VectorContractLowering::Matmul) + .Case("dot", vector::VectorContractLowering::Dot) + .Case("outerproduct", vector::VectorContractLowering::OuterProduct) + .Default(vector::VectorContractLowering::OuterProduct); + vector::VectorTransferSplit vectorTransferSplit = + llvm::StringSwitch(getSplitTransfers()) + .Case("none", vector::VectorTransferSplit::None) + .Case("linalg-copy", vector::VectorTransferSplit::LinalgCopy) + .Case("vector-transfers", + vector::VectorTransferSplit::VectorTransfer) + .Default(vector::VectorTransferSplit::None); + + vector::VectorTransformsOptions vectorTransformOptions; + vectorTransformOptions.setVectorTransformsOptions(vectorContractLowering) + .setVectorMultiReductionLowering(vectorMultiReductionLowering) + .setVectorTransposeLowering(vectorTransposeLowering) + .setVectorTransferSplit(vectorTransferSplit); + + VectorTransferToSCFOptions vectorTransferToSCFOptions = + VectorTransferToSCFOptions()
+ .enableFullUnroll(getUnrollVectorTransfers()) + .enableLowerPermutationMaps(); + + int maxTransferRank = 1; + + auto avx2LoweringOptions = + x86vector::avx2::LoweringOptions().setTransposeOptions( + x86vector::avx2::TransposeLoweringOptions() + .lower4x8xf32(getTransposeAvx2Lowering()) + .lower8x8xf32(getTransposeAvx2Lowering())); + + vector::populateVectorToVectorCanonicalizationPatterns(patterns); + + // In the future we may want to more finely select particular stages. + // Stage 1: contraction lowerings. + patterns.add(vectorTransformOptions, + ctx); + vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); + + // Stage 2: multi-reduction lowerings. + vector::populateVectorMultiReductionLoweringPatterns( + patterns, vectorTransformOptions.vectorMultiReductionLowering); + + // Stage 3: Rewrite vector.transfer into full and partial parts. + patterns.add( + ctx, vectorTransformOptions); + + // Stage 4: Lower vector transfers. + vector::populateVectorTransferLoweringPatterns(patterns, maxTransferRank); + + // Stage 5: Vector to scf patterns. + populateVectorToSCFConversionPatterns( + patterns, vectorTransferToSCFOptions.setTargetRank(maxTransferRank)); + + // Stage 6: Lower vector.shape_cast. + vector::populateVectorShapeCastLoweringPatterns(patterns); + + // Stage 7: Lower vector.transpose. + vector::populateVectorTransposeLoweringPatterns(patterns, + vectorTransformOptions); + if (getTransposeAvx2Lowering()) + x86vector::avx2::populateSpecializedTransposeLoweringPatterns( + patterns, avx2LoweringOptions, /*benefit=*/10); + + // Apply everything. The greedy driver fails if it does not converge.
+ if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return mlir::emitDefiniteFailure(target, "failed to apply patterns"); + + results.push_back(target); + } + + transformResults.set(getResults().cast(), results); + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/transform-vector.mlir @@ -0,0 +1,26 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s + +// CHECK-LABEL: func @matmul_tensors +func.func @matmul_tensors( + %arg0: tensor<8x16xf32>, %arg1: tensor<16x32xf32>, %arg2: tensor<8x32xf32>) + -> tensor<8x32xf32> { +// CHECK-NOT: linalg +// CHECK: vector.extract {{.*}} : vector<8x4xf32> +// CHECK: vector.store {{.*}} : memref<8x32xf32>, vector<4xf32> + %0 = linalg.matmul ins(%arg0, %arg1: tensor<8x16xf32>, tensor<16x32xf32>) + outs(%arg2: tensor<8x32xf32>) + -> tensor<8x32xf32> + return %0 : tensor<8x32xf32> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %0 = transform.structured.match ops{["linalg.matmul"]} in %module_op + %1, %loops:3 = transform.structured.tile %0 [8, 4, 2] + %2 = get_closest_isolated_parent %1 : (!pdl.operation) -> !pdl.operation + transform.structured.vectorize %2 + transform.bufferization.one_shot_bufferize %module_op + + %func = transform.structured.match ops{["func.func"]} in %module_op + transform.vector.lower_vectors %func { multireduction_lowering = "innerreduction"} +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++
b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -3363,14 +3363,16 @@ ":ArithDialect", ":AsmParser", ":IR", - ":VectorDialect", - ":VectorTransformOpsIncGen", - ":VectorTransforms", ":PDLDialect", ":Parser", ":SideEffectInterfaces", ":TransformDialect", ":TransformUtils", + ":VectorDialect", + ":VectorToSCF", + ":VectorTransformOpsIncGen", + ":VectorTransforms", + ":X86VectorTransforms", "//llvm:Support", ], )