diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -521,6 +521,9 @@ "Perform full unrolling when converting vector transfers to SCF">, Option<"targetRank", "target-rank", "unsigned", /*default=*/"1", "Target vector rank to which transfer ops should be lowered">, + Option<"lowerPermutationMaps", "lower-permutation-maps", "bool", + /*default=*/"false", "Replace permutation maps with vector " + "transposes/broadcasts before lowering transfer ops"> ]; } diff --git a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h --- a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h +++ b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h @@ -50,6 +50,7 @@ struct VectorTransferToSCFOptions { bool unroll = false; unsigned targetRank = 1; + bool lowerPermutationMaps = false; VectorTransferToSCFOptions &setUnroll(bool u) { unroll = u; @@ -60,6 +61,11 @@ targetRank = r; return *this; } + + VectorTransferToSCFOptions &setLowerPermutationMaps(bool l) { + lowerPermutationMaps = l; + return *this; + } }; /// Collect a set of patterns to convert from the Vector dialect to SCF + std. diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -19,6 +19,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/VectorInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" @@ -86,9 +87,16 @@ /// Collect a set of transfer read/write lowering patterns. /// /// These patterns lower transfer ops to simpler ops like `vector.load`, -/// `vector.store` and `vector.broadcast`. +/// `vector.store` and `vector.broadcast`. Includes all patterns of +/// populateVectorTransferPermutationMapLoweringPatterns. void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns); +/// Collect a set of transfer read/write lowering patterns that simplify the +/// permutation map (e.g., converting it to a minor identity map) by inserting +/// broadcasts and transposes. +void populateVectorTransferPermutationMapLoweringPatterns( + RewritePatternSet &patterns); + /// These patterns materialize masks for various vector ops such as transfers. void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns, bool enableIndexOptimizations); diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -264,6 +264,7 @@ if (xferOp.mask()) { auto maskType = MemRefType::get({}, xferOp.mask().getType()); auto maskBuffer = memref_alloca(maskType).value; + b.setInsertionPoint(xferOp); memref_store(xferOp.mask(), maskBuffer); result.maskBuffer = memref_load(maskBuffer); } @@ -476,10 +477,11 @@ }; template -LogicalResult checkPrepareXferOp(OpTy xferOp, unsigned targetRank) { +LogicalResult checkPrepareXferOp(OpTy xferOp, + VectorTransferToSCFOptions options) { if (xferOp->hasAttr(kPassLabel)) return failure(); - if (xferOp.getVectorType().getRank() <= targetRank) + if (xferOp.getVectorType().getRank() <= options.targetRank) return failure(); return success(); } @@ -513,7 +515,7 @@ LogicalResult matchAndRewrite(TransferReadOp xferOp, PatternRewriter &rewriter) const override { - if (checkPrepareXferOp(xferOp, options.targetRank).failed()) + if (checkPrepareXferOp(xferOp, options).failed()) return failure(); ScopedContext scope(rewriter, xferOp.getLoc()); @@ -561,7 +563,7 @@ LogicalResult matchAndRewrite(TransferWriteOp xferOp, PatternRewriter &rewriter) const override { - if (checkPrepareXferOp(xferOp, options.targetRank).failed()) + if (checkPrepareXferOp(xferOp, options).failed()) return failure(); ScopedContext scope(rewriter, xferOp.getLoc()); @@ -1160,12 +1162,23 @@ ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) { this->fullUnroll = options.unroll; this->targetRank = options.targetRank; + this->lowerPermutationMaps = options.lowerPermutationMaps; } void runOnFunction() override { VectorTransferToSCFOptions options; - options.setUnroll(fullUnroll); - options.setTargetRank(targetRank); + options.unroll = fullUnroll; + options.targetRank = targetRank; + options.lowerPermutationMaps = lowerPermutationMaps; + + // Lower permutation maps first. + if (lowerPermutationMaps) { + RewritePatternSet lowerTransferPatterns(getFunction().getContext()); + mlir::vector::populateVectorTransferPermutationMapLoweringPatterns( + lowerTransferPatterns); + (void)applyPatternsAndFoldGreedily(getFunction(), + std::move(lowerTransferPatterns)); + } RewritePatternSet patterns(getFunction().getContext()); populateVectorToSCFConversionPatterns(patterns, options); diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -2934,8 +2934,8 @@ /// - The op has no mask. struct TransferReadToVectorLoadLowering : public OpRewritePattern { - TransferReadToVectorLoadLowering(MLIRContext *context) - : OpRewritePattern(context) {} + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::TransferReadOp read, PatternRewriter &rewriter) const override { SmallVector broadcastedDims; @@ -3009,8 +3009,8 @@ /// - The op has no mask. struct TransferWriteToVectorStoreLowering : public OpRewritePattern { - TransferWriteToVectorStoreLowering(MLIRContext *context) - : OpRewritePattern(context) {} + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::TransferWriteOp write, PatternRewriter &rewriter) const override { // TODO: Support non-minor-identity maps @@ -3086,6 +3086,7 @@ if (permutationMap.isIdentity()) return failure(); + permutationMap = map.getPermutationMap(permutation, op.getContext()); // Caluclate the map of the new read by applying the inverse permutation. permutationMap = inversePermutation(permutationMap); AffineMap newMap = permutationMap.compose(map); @@ -4149,13 +4150,18 @@ patterns.getContext()); } +void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns( + RewritePatternSet &patterns) { + patterns.add( + patterns.getContext()); +} + void mlir::vector::populateVectorTransferLoweringPatterns( RewritePatternSet &patterns) { - patterns - .add( - patterns.getContext()); + patterns.add(patterns.getContext()); + populateVectorTransferPermutationMapLoweringPatterns(patterns); } void mlir::vector::populateVectorMultiReductionLoweringPatterns( diff --git a/mlir/test/Dialect/Vector/vector-transfer-lowering-to-scf.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering-to-scf.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-transfer-lowering-to-scf.mlir @@ -0,0 +1,37 @@ +// RUN: mlir-opt %s -convert-vector-to-scf='lower-permutation-maps=true' -split-input-file | FileCheck %s + +// Ensure that the permutation map is lowered (by inserting a transpose op) +// before lowering the vector.transfer_read. + +// CHECK-LABEL: func @transfer_read_2d_mask_transposed( +// CHECK-DAG: %[[PADDING:.*]] = constant dense<-4.200000e+01> : vector<9xf32> +// CHECK-DAG: %[[MASK:.*]] = constant dense<{{.*}}> : vector<9x4xi1> +// CHECK: %[[MASK_MEM:.*]] = memref.alloca() : memref> +// CHECK: %[[MASK_T:.*]] = vector.transpose %[[MASK]], [1, 0] : vector<9x4xi1> to vector<4x9xi1> +// CHECK: memref.store %[[MASK_T]], %[[MASK_MEM]][] : memref> +// CHECK: %[[MASK_CASTED:.*]] = vector.type_cast %[[MASK_MEM]] : memref> to memref<4xvector<9xi1>> +// CHECK: scf.for {{.*}} { +// CHECK: scf.if {{.*}} { +// CHECK: %[[MASK_LOADED:.*]] = memref.load %[[MASK_CASTED]][%{{.*}}] : memref<4xvector<9xi1>> +// CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}, %{{.*}}, %[[MASK_LOADED]] : memref, vector<9xf32> +// CHECK: memref.store %[[READ]], %{{.*}} : memref<4xvector<9xf32>> +// CHECK: } +// CHECK: } +// CHECK: %[[RESULT:.*]] = memref.load %{{.*}} : memref> +// CHECK: %[[RESULT_T:.*]] = vector.transpose %[[RESULT]], [1, 0] : vector<4x9xf32> to vector<9x4xf32> +// CHECK: return %[[RESULT_T]] : vector<9x4xf32> + +// Vector load with mask + transpose. +func @transfer_read_2d_mask_transposed( + %A : memref, %base1: index, %base2: index) -> (vector<9x4xf32>) { + %fm42 = constant -42.0: f32 + %mask = constant dense<[[1, 0, 1, 0], [0, 0, 1, 0], + [1, 1, 1, 1], [0, 1, 1, 0], + [1, 1, 1, 1], [1, 1, 1, 1], + [1, 1, 1, 1], [0, 0, 0, 0], + [1, 1, 1, 1]]> : vector<9x4xi1> + %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask + {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : + memref, vector<9x4xf32> + return %f : vector<9x4xf32> +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-permutation-lowering.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-permutation-lowering.mlir deleted file mode 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-permutation-lowering.mlir +++ /dev/null @@ -1,41 +0,0 @@ -// Run test with and without test-vector-transfer-lowering-patterns. - -// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \ -// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ -// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ -// RUN: FileCheck %s - -// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \ -// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ -// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ -// RUN: FileCheck %s - - -memref.global "private" @gv : memref<3x4xf32> = dense<[[0. , 1. , 2. , 3. ], - [10., 11., 12., 13.], - [20., 21., 22., 23.]]> - -// Vector load with transpose. -func @transfer_read_2d(%A : memref, %base1: index, %base2: index) { - %fm42 = constant -42.0: f32 - %f = vector.transfer_read %A[%base1, %base2], %fm42 - {in_bounds = [true, false], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : - memref, vector<3x9xf32> - vector.print %f: vector<3x9xf32> - return -} - -func @entry() { - %c0 = constant 0: index - %c1 = constant 1: index - %c2 = constant 2: index - %c3 = constant 3: index - %0 = memref.get_global @gv : memref<3x4xf32> - %A = memref.cast %0 : memref<3x4xf32> to memref - - // 1. Read 2D vector from 2D memref with transpose. - call @transfer_read_2d(%A, %c1, %c2) : (memref, index, index) -> () - // CHECK: ( ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ), ( 20, 0, -42, -42, -42, -42, -42, -42, -42 ) ) - - return -} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir @@ -3,7 +3,17 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s -// RUN: mlir-opt %s -convert-vector-to-scf=full-unroll=true -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \ +// RUN: mlir-opt %s -convert-vector-to-scf='lower-permutation-maps=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true lower-permutation-maps=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \ // RUN: mlir-cpu-runner -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir @@ -3,7 +3,17 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s -// RUN: mlir-opt %s -convert-vector-to-scf=full-unroll=true -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \ +// RUN: mlir-opt %s -convert-vector-to-scf='lower-permutation-maps=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true lower-permutation-maps=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \ // RUN: mlir-cpu-runner -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir @@ -3,7 +3,17 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s -// RUN: mlir-opt %s -convert-vector-to-scf=full-unroll=true -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \ +// RUN: mlir-opt %s -convert-vector-to-scf='lower-permutation-maps=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true lower-permutation-maps=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \ // RUN: mlir-cpu-runner -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s