diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -97,6 +97,12 @@
 void populateVectorUnrollPatterns(RewritePatternSet &patterns,
                                   const UnrollVectorOptions &options);
 
+/// Collect a set of patterns to reduce the rank of the operands of vector
+/// transfer ops to operate on the largest contiguous vector.
+/// These patterns are useful when lowering to dialects with 1d vector type
+/// such as LLVM, and they result in fewer memory reads.
+void populateVectorTransferRankReductionPatterns(RewritePatternSet &patterns);
+
 /// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
 /// masking) fastpath and a slowpath.
 /// If `ifOp` is not null and the result is `success, the `ifOp` points to the
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -3813,6 +3813,112 @@
   }
 };
 
+// Drops the innermost unit dimensions of a vector.transfer_read on a
+// contiguous memref, e.g.
+//   %v = vector.transfer_read %src[%i, %j], %pad
+//          : memref<16x1xf32>, vector<8x1xf32>
+// is rewritten into a rank-reducing memref.subview of %src to
+// memref<16xf32>, a vector.transfer_read of vector<8xf32> from that view and
+// a vector.shape_cast back to vector<8x1xf32>.
+class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+    // transfer_read also accepts tensor sources. Only statically shaped
+    // memrefs are handled, so use dyn_cast instead of an asserting cast.
+    auto srcType = readOp.source().getType().dyn_cast<MemRefType>();
+    if (!srcType || !srcType.hasStaticShape())
+      return failure();
+
+    // With an identity permutation map, memref dim `d` is read by vector dim
+    // `d`, so the trailing dims of the source and of the result line up.
+    if (!readOp.permutation_map().isIdentity())
+      return failure();
+
+    // Only the default layout or a single layout map is supported.
+    if (!srcType.getAffineMaps().empty() &&
+        srcType.getAffineMaps().size() != 1) {
+      return failure();
+    }
+
+    auto targetType = readOp.getResult().getType().cast<VectorType>();
+    if (targetType.getRank() <= 1)
+      return failure();
+
+    SmallVector<int64_t> srcStrides;
+    int64_t srcOffset;
+    if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
+      return failure();
+
+    // Walk outward from the innermost dim. Dim `dim + 1` is dropped when it
+    // has size 1 in both the memref and the vector and dim `dim` has unit
+    // stride, so the data stays contiguous once the unit dim is removed.
+    // Stop at the first dim that does not qualify: only a trailing run of
+    // dims may be dropped. At least one dim is always kept.
+    size_t dimsToDrop = 0;
+    for (int64_t dim = srcType.getRank() - 2; dim >= 0; --dim) {
+      if (srcStrides[dim] != 1 || srcType.getDimSize(dim + 1) != 1 ||
+          targetType.getDimSize(dim + 1) != 1)
+        break;
+      ++dimsToDrop;
+    }
+    if (dimsToDrop == 0)
+      return failure();
+
+    auto resultTargetVecType =
+        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
+                        targetType.getElementType());
+
+    // Rank-reduced memref type. With a layout map, the dropped dims (unit
+    // dims, hence always indexed at 0) are replaced by 0 in the map.
+    MemRefType resultMemrefType;
+    if (!srcType.getAffineMaps().empty()) {
+      AffineMap map = srcType.getAffineMaps()[0];
+      int numResultDims = map.getNumDims() - dimsToDrop;
+      int numSymbols = map.getNumSymbols();
+      for (size_t i = 0; i < dimsToDrop; ++i) {
+        int dim = srcType.getRank() - i - 1;
+        map = map.replace(rewriter.getAffineDimExpr(dim),
+                          rewriter.getAffineConstantExpr(0), numResultDims,
+                          numSymbols);
+      }
+      // Build the type once, after all dropped dims have been substituted.
+      resultMemrefType = MemRefType::get(
+          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
+          map, srcType.getMemorySpaceAsInt());
+    } else {
+      resultMemrefType = MemRefType::get(
+          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
+          {}, srcType.getMemorySpaceAsInt());
+    }
+
+    auto loc = readOp.getLoc();
+    SmallVector<int64_t> offsets(srcType.getRank(), 0);
+    SmallVector<int64_t> strides(srcType.getRank(), 1);
+    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
+        loc, resultMemrefType, readOp.source(), offsets, srcType.getShape(),
+        strides);
+
+    // Keep the in_bounds flags of the remaining dims.
+    SmallVector<bool> inBounds;
+    if (ArrayAttr inBoundsAttr = readOp.in_boundsAttr())
+      for (Attribute attr : inBoundsAttr.getValue().drop_back(dimsToDrop))
+        inBounds.push_back(attr.cast<BoolAttr>().getValue());
+
+    // Forward the original padding value instead of defaulting it to zero.
+    // NOTE(review): the indices of the dropped dims are discarded, which is
+    // only equivalent when they are 0 (guaranteed for in-bounds accesses to a
+    // unit dim); a mask operand, if the op has one, is not forwarded either.
+    Value result = rewriter.create<vector::TransferReadOp>(
+        loc, resultTargetVecType, rankedReducedView,
+        readOp.indices().drop_back(dimsToDrop), readOp.padding(), inBounds);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
+                                                     result);
+    return success();
+  }
+};
+
 void mlir::vector::populateVectorMaskMaterializationPatterns(
     RewritePatternSet &patterns, bool enableIndexOptimizations) {
   patterns.add(
       patterns.getContext(), options);
 }
+
+void mlir::vector::populateVectorTransferRankReductionPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<DropInnerMostUnitDims>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-rank-reduction-transforms.mlir b/mlir/test/Dialect/Vector/vector-transfer-rank-reduction-transforms.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transfer-rank-reduction-transforms.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s -test-vector-transfer-rank-reduction -split-input-file | FileCheck %s
+
+#map1 = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 3072 + s0 + d1 * 8 + d2 + d3)>
+func @contiguous_inner_most_view(%in: memref<1x1x8x1xf32, #map1>) -> vector<1x1x8x1xf32>{
+  %c0 = constant 0 : index
+  %cst = constant 0.0 : f32
+  %0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true]} : memref<1x1x8x1xf32, #map1>, vector<1x1x8x1xf32>
+  return %0 : vector<1x1x8x1xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 3072 + s0 + d1 * 8 + d2 + d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2)[s0] -> (d0 * 3072 + s0 + d1 * 8 + d2)>
+// CHECK: func @contiguous_inner_most_view(%[[SRC:.+]]: memref<1x1x8x1xf32, #[[MAP0]]>
+// CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]]
+// CHECK-SAME: memref<1x1x8x1xf32, #[[MAP0]]> to memref<1x1x8xf32, #[[MAP1]]>
+// CHECK: %[[VEC:.+]] = vector.transfer_read %[[SRC_0]]
+// CHECK-SAME: memref<1x1x8xf32, #[[MAP1]]>, vector<1x1x8xf32>
+// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[VEC]]
+//
CHECK: return %[[RESULT]]
+// -----
+
+func @contiguous_inner_most_dim(%A: memref<16x1xf32>, %i:index, %j:index) -> (vector<8x1xf32>) {
+  %c0 = constant 0 : index
+  %f0 = constant 0.0 : f32
+  %1 = vector.transfer_read %A[%i, %j], %f0 : memref<16x1xf32>, vector<8x1xf32>
+  return %1 : vector<8x1xf32>
+}
+// CHECK: func @contiguous_inner_most_dim(%[[SRC:.+]]: memref<16x1xf32>, %[[I:.+]]: index, %[[J:.+]]: index) -> vector<8x1xf32>
+// CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]]
+// CHECK-SAME: memref<16x1xf32> to memref<16xf32>
+// CHECK: %[[V:.+]] = vector.transfer_read %[[SRC_0]]
+// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[V]] : vector<8xf32> to vector<8x1xf32>
+// CHECK: return %[[RESULT]]
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -471,6 +471,33 @@
   }
 };
 
+struct TestVectorTransferRankReductionPatterns
+    : public PassWrapper<TestVectorTransferRankReductionPatterns,
+                         FunctionPass> {
+  TestVectorTransferRankReductionPatterns() = default;
+  TestVectorTransferRankReductionPatterns(
+      const TestVectorTransferRankReductionPatterns &pass) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<memref::MemRefDialect>();
+  }
+
+  StringRef getArgument() const final {
+    return "test-vector-transfer-rank-reduction";
+  }
+
+  StringRef getDescription() const final {
+    return "Test conversion patterns that reduce the rank of the vector "
+           "transfer memory and vector operands.";
+  }
+
+  void runOnFunction() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorTransferRankReductionPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+  }
+};
+
 } // end anonymous namespace
 
 namespace mlir {
@@ -495,6 +522,8 @@
   PassRegistration();
 
   PassRegistration();
+
+  PassRegistration<TestVectorTransferRankReductionPatterns>();
 }
 } // namespace test
 } // namespace mlir