diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -112,9 +112,12 @@ PatternBenefit benefit = 1); /// Collects patterns that lower scalar vector transfer ops to memref loads and -/// stores when beneficial. +/// stores when beneficial. If `allowMultipleUses` is true, transfer reads with +/// multiple uses are lowered too (when all users are extract ops). Otherwise, +/// only vector transfer reads with a single use will be lowered. void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns, - PatternBenefit benefit = 1); + PatternBenefit benefit, + bool allowMultipleUses); /// Populate the pattern set with the following patterns: /// diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -561,27 +561,35 @@ } }; -/// Rewrite extractelement(transfer_read) to memref.load. -/// -/// Rewrite only if the extractelement op is the single user of the transfer op. -/// E.g., do not rewrite IR such as: -/// %0 = vector.transfer_read ... : vector<1024xf32> -/// %1 = vector.extractelement %0[%a : index] : vector<1024xf32> -/// %2 = vector.extractelement %0[%b : index] : vector<1024xf32> -/// Rewriting such IR (replacing one vector load with multiple scalar loads) may -/// negatively affect performance. -class RewriteScalarExtractElementOfTransferRead - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +/// Base class for `vector.extract/vector.extract_element(vector.transfer_read)` +/// to `memref.load` patterns.
The `match` method is shared for both +/// `vector.extract` and `vector.extract_element`. +template +class RewriteScalarExtractOfTransferReadBase + : public OpRewritePattern { + using Base = OpRewritePattern; - LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp, - PatternRewriter &rewriter) const override { - auto xferOp = extractOp.getVector().getDefiningOp(); +public: + RewriteScalarExtractOfTransferReadBase(MLIRContext *context, + PatternBenefit benefit, + bool allowMultipleUses) + : Base::OpRewritePattern(context, benefit), + allowMultipleUses(allowMultipleUses) {} + + LogicalResult match(VectorExtractOp extractOp) const override { + auto xferOp = + extractOp.getVector().template getDefiningOp(); if (!xferOp) return failure(); - // xfer result must have a single use. Otherwise, it may be better to - // perform a vector load. - if (!extractOp.getVector().hasOneUse()) + // If multiple uses are not allowed, check if xfer has a single use. + if (!allowMultipleUses && !xferOp.getResult().hasOneUse()) + return failure(); + // If multiple uses are allowed, check if all the xfer uses are extract ops. + if (allowMultipleUses && + !llvm::all_of(xferOp->getUses(), [](OpOperand &use) { + return isa( + use.getOwner()); + })) return failure(); // Mask not supported. if (xferOp.getMask()) @@ -589,11 +597,38 @@ // Map not supported. if (!xferOp.getPermutationMap().isMinorIdentity()) return failure(); - // Cannot rewrite if the indices may be out of bounds. The starting point is - // always inbounds, so we don't care in case of 0d transfers. - if (xferOp.hasOutOfBoundsDim() && xferOp.getType().getRank() > 0) + // Cannot rewrite if the indices may be out of bounds. + if (xferOp.hasOutOfBoundsDim()) return failure(); + // Only match scalar extracts; a sub-vector result cannot become a load. + if (isa<VectorType>(extractOp.getResult().getType())) + return failure(); + return success(); + } + +private: + bool allowMultipleUses; +}; + +/// Rewrite extractelement(transfer_read) to memref.load. +/// +/// Rewrite only if the extractelement op is the single user of the transfer op.
+/// E.g., do not rewrite IR such as: +/// %0 = vector.transfer_read ... : vector<1024xf32> +/// %1 = vector.extractelement %0[%a : index] : vector<1024xf32> +/// %2 = vector.extractelement %0[%b : index] : vector<1024xf32> +/// Rewriting such IR (replacing one vector load with multiple scalar loads) may +/// negatively affect performance so multiple scalar loads are only generated +/// when `allowMultipleUses` is set to true. +class RewriteScalarExtractElementOfTransferRead + : public RewriteScalarExtractOfTransferReadBase { + using RewriteScalarExtractOfTransferReadBase:: + RewriteScalarExtractOfTransferReadBase; + + void rewrite(vector::ExtractElementOp extractOp, + PatternRewriter &rewriter) const override { // Construct scalar load. + auto xferOp = extractOp.getVector().getDefiningOp(); SmallVector newIndices(xferOp.getIndices().begin(), xferOp.getIndices().end()); if (extractOp.getPosition()) { @@ -617,7 +652,6 @@ rewriter.replaceOpWithNewOp( extractOp, xferOp.getSource(), newIndices); } - return success(); } }; @@ -629,34 +663,17 @@ /// %1 = vector.extract %0[0] : vector<1024xf32> /// %2 = vector.extract %0[5] : vector<1024xf32> /// Rewriting such IR (replacing one vector load with multiple scalar loads) may -/// negatively affect performance. +/// negatively affect performance so multiple scalar loads are only generated +/// when `allowMultipleUses` is set to true. class RewriteScalarExtractOfTransferRead - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + : public RewriteScalarExtractOfTransferReadBase { + using RewriteScalarExtractOfTransferReadBase:: + RewriteScalarExtractOfTransferReadBase; - LogicalResult matchAndRewrite(vector::ExtractOp extractOp, - PatternRewriter &rewriter) const override { - // Only match scalar extracts. - if (extractOp.getType().isa()) - return failure(); - auto xferOp = extractOp.getVector().getDefiningOp(); - if (!xferOp) - return failure(); - // xfer result must have a single use.
Otherwise, it may be better to - // perform a vector load. - if (!extractOp.getVector().hasOneUse()) - return failure(); - // Mask not supported. - if (xferOp.getMask()) - return failure(); - // Map not supported. - if (!xferOp.getPermutationMap().isMinorIdentity()) - return failure(); - // Cannot rewrite if the indices may be out of bounds. The starting point is - // always inbounds, so we don't care in case of 0d transfers. - if (xferOp.hasOutOfBoundsDim() && xferOp.getType().getRank() > 0) - return failure(); + void rewrite(vector::ExtractOp extractOp, + PatternRewriter &rewriter) const override { // Construct scalar load. + auto xferOp = extractOp.getVector().getDefiningOp(); SmallVector newIndices(xferOp.getIndices().begin(), xferOp.getIndices().end()); for (const auto &it : llvm::enumerate(extractOp.getPosition())) { @@ -680,7 +697,6 @@ rewriter.replaceOpWithNewOp( extractOp, xferOp.getSource(), newIndices); } - return success(); } }; @@ -744,10 +760,12 @@ } void mlir::vector::populateScalarVectorTransferLoweringPatterns( - RewritePatternSet &patterns, PatternBenefit benefit) { + RewritePatternSet &patterns, PatternBenefit benefit, + bool allowMultipleUses) { patterns.add( - patterns.getContext(), benefit); + RewriteScalarExtractOfTransferRead>(patterns.getContext(), + benefit, allowMultipleUses); + patterns.add(patterns.getContext(), benefit); } void mlir::vector::populateVectorTransferDropUnitDimsPatterns( diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir --- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir +++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -test-scalar-vector-transfer-lowering -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-scalar-vector-transfer-lowering=allow-multiple-uses -split-input-file | FileCheck %s --check-prefix=MULTIUSE // CHECK-LABEL: func @transfer_read_0d(
// CHECK-SAME: %[[m:.*]]: memref, %[[idx:.*]]: index @@ -108,3 +109,30 @@ vector.transfer_write %cst, %m[%idx, %idx, %idx] : vector<1x1xf32>, memref return } + +// ----- + +// CHECK-LABEL: func @transfer_read_multi_use( +// CHECK-SAME: %[[m:.*]]: memref, %[[idx:.*]]: index +// CHECK-NOT: memref.load +// CHECK: %[[r:.*]] = vector.transfer_read %[[m]][%[[idx]]] +// CHECK: %[[e0:.*]] = vector.extract %[[r]][0] +// CHECK: %[[e1:.*]] = vector.extract %[[r]][1] +// CHECK: return %[[e0]], %[[e1]] + +// MULTIUSE-LABEL: func @transfer_read_multi_use( +// MULTIUSE-SAME: %[[m:.*]]: memref, %[[idx0:.*]]: index +// MULTIUSE-NOT: vector.transfer_read +// MULTIUSE: %[[r0:.*]] = memref.load %[[m]][%[[idx0]] +// MULTIUSE: %[[idx1:.*]] = affine.apply +// MULTIUSE: %[[r1:.*]] = memref.load %[[m]][%[[idx1]] +// MULTIUSE: return %[[r0]], %[[r1]] + +func.func @transfer_read_multi_use(%m: memref, %idx: index) -> (f32, f32) { + %cst = arith.constant 0.0 : f32 + %0 = vector.transfer_read %m[%idx], %cst {in_bounds = [true]} : memref, vector<16xf32> + %1 = vector.extract %0[0] : vector<16xf32> + %2 = vector.extract %0[1] : vector<16xf32> + return %1, %2 : f32, f32 +} + diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -298,23 +298,33 @@ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( TestScalarVectorTransferLoweringPatterns) + TestScalarVectorTransferLoweringPatterns() = default; + TestScalarVectorTransferLoweringPatterns( + const TestScalarVectorTransferLoweringPatterns &pass) + : PassWrapper(pass) {} + StringRef getArgument() const final { return "test-scalar-vector-transfer-lowering"; } StringRef getDescription() const final { return "Test lowering of scalar vector transfers to memref loads/stores."; } - TestScalarVectorTransferLoweringPatterns() = default; void getDependentDialects(DialectRegistry 
&registry) const override { registry.insert(); } + Option allowMultipleUses{ + *this, "allow-multiple-uses", + llvm::cl::desc("Fold transfer operations with multiple uses"), + llvm::cl::init(false)}; + void runOnOperation() override { MLIRContext *ctx = &getContext(); RewritePatternSet patterns(ctx); - vector::populateScalarVectorTransferLoweringPatterns(patterns); + vector::populateScalarVectorTransferLoweringPatterns( + patterns, /*benefit=*/1, allowMultipleUses.getValue()); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } };