NOTE(review): this patch was recovered from a whitespace-collapsed extraction
in which angle-bracketed text (template arguments, some MLIR types) had been
stripped and "&registry" had been mangled. That text is restored below from
the surrounding code. The template arguments on the three PassRegistration
context lines of the last hunk could not be grounded in the patch itself and
are a best guess -- verify them against the tree before applying.

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -142,6 +142,11 @@
 void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
                                              PatternBenefit benefit = 1);
 
+/// Collects patterns that lower scalar vector transfer ops to memref loads and
+/// stores when beneficial.
+void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns,
+                                                  PatternBenefit benefit = 1);
+
 /// Returns the integer type required for subscripts in the vector dialect.
 IntegerType getVectorSubscriptType(Builder &builder);
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -32,6 +32,7 @@
   MLIRMemRefDialect
   MLIRSCFDialect
   MLIRSideEffectInterfaces
+  MLIRTensorDialect
   MLIRTransforms
   MLIRVectorDialect
   MLIRVectorInterfaces
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -11,8 +11,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
@@ -556,6 +558,104 @@
   }
 };
 
+/// Rewrite extractelement(transfer_read) to memref.load.
+///
+/// Rewrite only if the extractelement op is the single user of the transfer op.
+/// E.g., do not rewrite IR such as:
+///   %0 = vector.transfer_read ... : vector<1024xf32>
+///   %1 = vector.extractelement %0[%a : index] : vector<1024xf32>
+///   %2 = vector.extractelement %0[%b : index] : vector<1024xf32>
+/// Rewriting such IR (replacing one vector load with multiple scalar loads) may
+/// negatively affect performance.
+class FoldScalarExtractOfTransferRead
+    : public OpRewritePattern<vector::ExtractElementOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
+    // The vector operand may be a block argument or the result of an
+    // unrelated op; getDefiningOp returns null in both cases.
+    if (!xferOp)
+      return failure();
+    // xfer result must have a single use. Otherwise, it may be better to
+    // perform a vector load.
+    if (!extractOp.getVector().hasOneUse())
+      return failure();
+    // Mask not supported.
+    if (xferOp.getMask())
+      return failure();
+    // Map not supported.
+    if (!xferOp.getPermutationMap().isMinorIdentity())
+      return failure();
+    // Cannot rewrite if the indices may be out of bounds. The starting point is
+    // always inbounds, so we don't care in case of 0d transfers.
+    if (xferOp.hasOutOfBoundsDim() && xferOp.getType().getRank() > 0)
+      return failure();
+    // Construct scalar load.
+    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
+                                  xferOp.getIndices().end());
+    if (extractOp.getPosition()) {
+      AffineExpr sym0, sym1;
+      bindSymbols(extractOp.getContext(), sym0, sym1);
+      OpFoldResult ofr = makeComposedFoldedAffineApply(
+          rewriter, extractOp.getLoc(), sym0 + sym1,
+          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
+      if (ofr.is<Value>()) {
+        newIndices[newIndices.size() - 1] = ofr.get<Value>();
+      } else {
+        newIndices[newIndices.size() - 1] =
+            rewriter.create<arith::ConstantIndexOp>(extractOp.getLoc(),
+                                                    *getConstantIntValue(ofr));
+      }
+    }
+    if (xferOp.getSource().getType().isa<MemRefType>()) {
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
+                                                  newIndices);
+    } else {
+      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
+          extractOp, xferOp.getSource(), newIndices);
+    }
+    return success();
+  }
+};
+
+/// Rewrite scalar transfer_write(broadcast) to memref.store.
+class FoldScalarTransferWriteOfBroadcast
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
+                                PatternRewriter &rewriter) const override {
+    // Must be a scalar write.
+    auto vecType = xferOp.getVectorType();
+    if (vecType.getRank() != 0 &&
+        (vecType.getRank() != 1 || vecType.getShape()[0] != 1))
+      return failure();
+    // Mask not supported.
+    if (xferOp.getMask())
+      return failure();
+    // Map not supported.
+    if (!xferOp.getPermutationMap().isMinorIdentity())
+      return failure();
+    // Must be a broadcast of a scalar. getDefiningOp returns null if the
+    // written vector is not produced by a broadcast at all.
+    auto broadcastOp = xferOp.getVector().getDefiningOp<vector::BroadcastOp>();
+    if (!broadcastOp || broadcastOp.getSource().getType().isa<VectorType>())
+      return failure();
+    // Construct a scalar store.
+    if (xferOp.getSource().getType().isa<MemRefType>()) {
+      rewriter.replaceOpWithNewOp<memref::StoreOp>(
+          xferOp, broadcastOp.getSource(), xferOp.getSource(),
+          xferOp.getIndices());
+    } else {
+      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
+          xferOp, broadcastOp.getSource(), xferOp.getSource(),
+          xferOp.getIndices());
+    }
+    return success();
+  }
+};
 } // namespace
 
 void mlir::vector::transferOpflowOpt(Operation *rootOp) {
@@ -574,6 +674,13 @@
   opt.removeDeadOp();
 }
 
+void mlir::vector::populateScalarVectorTransferLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns
+      .add<FoldScalarExtractOfTransferRead, FoldScalarTransferWriteOfBroadcast>(
+          patterns.getContext(), benefit);
+}
+
 void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -0,0 +1,96 @@
+// RUN: mlir-opt %s -test-scalar-vector-transfer-lowering -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @transfer_read_0d(
+//  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index
+//       CHECK:   %[[r:.*]] = memref.load %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+//       CHECK:   return %[[r]]
+func.func @transfer_read_0d(%m: memref<?x?x?xf32>, %idx: index) -> f32 {
+  %cst = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %m[%idx, %idx, %idx], %cst : memref<?x?x?xf32>, vector<f32>
+  %1 = vector.extractelement %0[] : vector<f32>
+  return %1 : f32
+}
+
+// -----
+
+//       CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @transfer_read_1d(
+//  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[idx2:.*]]: index
+//       CHECK:   %[[added:.*]] = affine.apply #[[$map]]()[%[[idx]], %[[idx2]]]
+//       CHECK:   %[[r:.*]] = memref.load %[[m]][%[[idx]], %[[idx]], %[[added]]]
+//       CHECK:   return %[[r]]
+func.func @transfer_read_1d(%m: memref<?x?x?xf32>, %idx: index, %idx2: index) -> f32 {
+  %cst = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %0 = vector.transfer_read %m[%idx, %idx, %idx], %cst {in_bounds = [true]} : memref<?x?x?xf32>, vector<5xf32>
+  %1 = vector.extractelement %0[%idx2 : index] : vector<5xf32>
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_transfer_read_0d(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?xf32>, %[[idx:.*]]: index
+//       CHECK:   %[[r:.*]] = tensor.extract %[[t]][%[[idx]], %[[idx]], %[[idx]]]
+//       CHECK:   return %[[r]]
+func.func @tensor_transfer_read_0d(%t: tensor<?x?x?xf32>, %idx: index) -> f32 {
+  %cst = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %t[%idx, %idx, %idx], %cst : tensor<?x?x?xf32>, vector<f32>
+  %1 = vector.extractelement %0[] : vector<f32>
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_0d(
+//  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
+//       CHECK:   memref.store %[[f]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+func.func @transfer_write_0d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
+  %0 = vector.broadcast %f : f32 to vector<f32>
+  vector.transfer_write %0, %m[%idx, %idx, %idx] : vector<f32>, memref<?x?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_1d(
+//  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
+//       CHECK:   memref.store %[[f]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+func.func @transfer_write_1d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
+  %0 = vector.broadcast %f : f32 to vector<1xf32>
+  vector.transfer_write %0, %m[%idx, %idx, %idx] : vector<1xf32>, memref<?x?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_transfer_write_0d(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
+//       CHECK:   %[[r:.*]] = tensor.insert %[[f]] into %[[t]][%[[idx]], %[[idx]], %[[idx]]]
+//       CHECK:   return %[[r]]
+func.func @tensor_transfer_write_0d(%t: tensor<?x?x?xf32>, %idx: index, %f: f32) -> tensor<?x?x?xf32> {
+  %0 = vector.broadcast %f : f32 to vector<f32>
+  %1 = vector.transfer_write %0, %t[%idx, %idx, %idx] : vector<f32>, tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_without_transfer_read(
+//  CHECK-SAME:     %[[v:.*]]: vector<5xf32>, %[[idx:.*]]: index
+//       CHECK:   %[[r:.*]] = vector.extractelement %[[v]][%[[idx]] : index]
+//       CHECK:   return %[[r]]
+func.func @extract_without_transfer_read(%v: vector<5xf32>, %idx: index) -> f32 {
+  %0 = vector.extractelement %v[%idx : index] : vector<5xf32>
+  return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_without_broadcast(
+//  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[v:.*]]: vector<1xf32>
+//       CHECK:   vector.transfer_write %[[v]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+func.func @transfer_write_without_broadcast(%m: memref<?x?x?xf32>, %idx: index, %v: vector<1xf32>) {
+  vector.transfer_write %v, %m[%idx, %idx, %idx] : vector<1xf32>, memref<?x?x?xf32>
+  return
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -462,6 +462,33 @@
   }
 };
 
+struct TestScalarVectorTransferLoweringPatterns
+    : public PassWrapper<TestScalarVectorTransferLoweringPatterns,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestScalarVectorTransferLoweringPatterns)
+
+  StringRef getArgument() const final {
+    return "test-scalar-vector-transfer-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Test lowering of scalar vector transfers to memref loads/stores.";
+  }
+  TestScalarVectorTransferLoweringPatterns() = default;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, memref::MemRefDialect, tensor::TensorDialect,
+                    vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    vector::populateScalarVectorTransferLoweringPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct TestVectorTransferOpt
     : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)
@@ -869,6 +896,8 @@
 
   PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
 
+  PassRegistration<TestScalarVectorTransferLoweringPatterns>();
+
   PassRegistration<TestVectorTransferOpt>();
 
   PassRegistration<TestVectorTransferLoweringPatterns>();