diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -114,7 +114,8 @@
 /// Collects patterns that lower scalar vector transfer ops to memref loads and
 /// stores when beneficial.
 void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns,
-                                                  PatternBenefit benefit = 1);
+                                                  PatternBenefit benefit,
+                                                  bool allowMultipleUses);
 
 /// Populate the pattern set with the following patterns:
 ///
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -574,14 +574,27 @@
     : public OpRewritePattern<vector::ExtractElementOp> {
   using OpRewritePattern::OpRewritePattern;
 
+public:
+  RewriteScalarExtractElementOfTransferRead(MLIRContext *context,
+                                            PatternBenefit benefit,
+                                            bool allowMultipleUses)
+      : OpRewritePattern(context, benefit),
+        allowMultipleUses(allowMultipleUses) {}
+
   LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp,
                                 PatternRewriter &rewriter) const override {
     auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
     if (!xferOp)
       return failure();
-    // xfer result must have a single use. Otherwise, it may be better to
-    // perform a vector load.
-    if (!extractOp.getVector().hasOneUse())
+    // If multiple uses are not allowed, check if xfer has a single use.
+    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
+      return failure();
+    // If multiple uses are allowed, check if all the xfer uses are extract ops.
+    if (allowMultipleUses &&
+        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
+          return isa<vector::ExtractOp>(use.getOwner()) ||
+                 isa<vector::ExtractElementOp>(use.getOwner());
+        }))
       return failure();
     // Mask not supported.
     if (xferOp.getMask())
@@ -589,9 +602,8 @@
     // Map not supported.
     if (!xferOp.getPermutationMap().isMinorIdentity())
       return failure();
-    // Cannot rewrite if the indices may be out of bounds. The starting point is
-    // always inbounds, so we don't care in case of 0d transfers.
-    if (xferOp.hasOutOfBoundsDim() && xferOp.getType().getRank() > 0)
+    // Cannot rewrite if the indices may be out of bounds.
+    if (xferOp.hasOutOfBoundsDim())
       return failure();
     // Construct scalar load.
     SmallVector<Value> newIndices(xferOp.getIndices().begin(),
@@ -619,6 +631,9 @@
     }
     return success();
   }
+
+private:
+  bool allowMultipleUses;
 };
 
 /// Rewrite extract(transfer_read) to memref.load.
@@ -634,6 +649,13 @@
     : public OpRewritePattern<vector::ExtractOp> {
   using OpRewritePattern::OpRewritePattern;
 
+public:
+  RewriteScalarExtractOfTransferRead(MLIRContext *context,
+                                     PatternBenefit benefit,
+                                     bool allowMultipleUses)
+      : OpRewritePattern(context, benefit),
+        allowMultipleUses(allowMultipleUses) {}
+
   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                 PatternRewriter &rewriter) const override {
     // Only match scalar extracts.
@@ -642,9 +664,15 @@
     auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
     if (!xferOp)
       return failure();
-    // xfer result must have a single use. Otherwise, it may be better to
-    // perform a vector load.
-    if (!extractOp.getVector().hasOneUse())
+    // If multiple uses are not allowed, check if xfer has a single use.
+    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
+      return failure();
+    // If multiple uses are allowed, check if all the xfer uses are extract ops.
+    if (allowMultipleUses &&
+        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
+          return isa<vector::ExtractOp>(use.getOwner()) ||
+                 isa<vector::ExtractElementOp>(use.getOwner());
+        }))
       return failure();
     // Mask not supported.
     if (xferOp.getMask())
@@ -652,9 +680,8 @@
     // Map not supported.
     if (!xferOp.getPermutationMap().isMinorIdentity())
       return failure();
-    // Cannot rewrite if the indices may be out of bounds. The starting point is
-    // always inbounds, so we don't care in case of 0d transfers.
-    if (xferOp.hasOutOfBoundsDim() && xferOp.getType().getRank() > 0)
+    // Cannot rewrite if the indices may be out of bounds.
+    if (xferOp.hasOutOfBoundsDim())
       return failure();
     // Construct scalar load.
     SmallVector<Value> newIndices(xferOp.getIndices().begin(),
@@ -682,6 +709,9 @@
     }
     return success();
   }
+
+private:
+  bool allowMultipleUses;
 };
 
 /// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
@@ -744,10 +774,12 @@
 }
 
 void mlir::vector::populateScalarVectorTransferLoweringPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
+    RewritePatternSet &patterns, PatternBenefit benefit,
+    bool allowMultipleUses) {
   patterns.add<RewriteScalarExtractElementOfTransferRead,
-               RewriteScalarExtractOfTransferRead,
-               RewriteScalarWrite>(patterns.getContext(), benefit);
+               RewriteScalarExtractOfTransferRead>(patterns.getContext(),
+                                                   benefit, allowMultipleUses);
+  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -test-scalar-vector-transfer-lowering -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-scalar-vector-transfer-lowering=allow-multiple-uses -split-input-file | FileCheck %s --check-prefix=MULTIUSE
 
 // CHECK-LABEL: func @transfer_read_0d(
 //  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index
@@ -108,3 +109,30 @@
   vector.transfer_write %cst, %m[%idx, %idx, %idx] : vector<1x1xf32>, memref<?x?x?xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_multi_use(
+//  CHECK-SAME:     %[[m:.*]]: memref<?xf32>, %[[idx:.*]]: index
+//   CHECK-NOT:   memref.load
+//       CHECK:   %[[r:.*]] = vector.transfer_read %[[m]][%[[idx]]]
+//       CHECK:   %[[e0:.*]] = vector.extract %[[r]][0]
+//       CHECK:   %[[e1:.*]] = vector.extract %[[r]][1]
+//       CHECK:   return %[[e0]], %[[e1]]
+
+// MULTIUSE-LABEL: func @transfer_read_multi_use(
+//  MULTIUSE-SAME:     %[[m:.*]]: memref<?xf32>, %[[idx0:.*]]: index
+//   MULTIUSE-NOT:   vector.transfer_read
+//       MULTIUSE:   %[[r0:.*]] = memref.load %[[m]][%[[idx0]]
+//       MULTIUSE:   %[[idx1:.*]] = affine.apply
+//       MULTIUSE:   %[[r1:.*]] = memref.load %[[m]][%[[idx1]]
+//       MULTIUSE:   return %[[r0]], %[[r1]]
+
+func.func @transfer_read_multi_use(%m: memref<?xf32>, %idx: index) -> (f32, f32) {
+  %cst = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %m[%idx], %cst {in_bounds = [true]} : memref<?xf32>, vector<16xf32>
+  %1 = vector.extract %0[0] : vector<16xf32>
+  %2 = vector.extract %0[1] : vector<16xf32>
+  return %1, %2 : f32, f32
+}
+
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -298,23 +298,33 @@
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
       TestScalarVectorTransferLoweringPatterns)
 
+  TestScalarVectorTransferLoweringPatterns() = default;
+  TestScalarVectorTransferLoweringPatterns(
+      const TestScalarVectorTransferLoweringPatterns &pass)
+      : PassWrapper(pass) {}
+
   StringRef getArgument() const final {
     return "test-scalar-vector-transfer-lowering";
   }
   StringRef getDescription() const final {
     return "Test lowering of scalar vector transfers to memref loads/stores.";
   }
-  TestScalarVectorTransferLoweringPatterns() = default;
 
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<memref::MemRefDialect>();
   }
 
+  Option<bool> allowMultipleUses{
+      *this, "allow-multiple-uses",
+      llvm::cl::desc("Fold transfer operations with multiple uses"),
+      llvm::cl::init(false)};
+
   void runOnOperation() override {
     MLIRContext *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    vector::populateScalarVectorTransferLoweringPatterns(patterns);
+    vector::populateScalarVectorTransferLoweringPatterns(
+        patterns, /*benefit=*/1, allowMultipleUses.getValue());
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
 