diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -39,9 +39,9 @@
 /// Collects a set of patterns to rewrite ops within the memref dialect.
 void populateExpandOpsPatterns(RewritePatternSet &patterns);
 
-/// Appends patterns for folding memref.subview ops into consumer load/store ops
-/// into `patterns`.
-void populateFoldSubViewOpPatterns(RewritePatternSet &patterns);
+/// Appends patterns for folding memref aliasing ops into consumer load/store
+/// ops into `patterns`.
+void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns);
 
 /// Appends patterns that resolve `memref.dim` operations with values that are
 /// defined by operations that implement the
@@ -91,9 +91,9 @@
 /// `memref_reinterpret_cast`.
 std::unique_ptr<Pass> createExpandOpsPass();
 
-/// Creates an operation pass to fold memref.subview ops into consumer
+/// Creates an operation pass to fold memref aliasing ops into consumer
 /// load/store ops into `patterns`.
-std::unique_ptr<Pass> createFoldSubViewOpsPass();
+std::unique_ptr<Pass> createFoldMemRefAliasOpsPass();
 
 /// Creates an interprocedural pass to normalize memrefs to have a trivial
 /// (identity) layout map.
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -16,13 +16,13 @@
   let constructor = "mlir::memref::createExpandOpsPass()";
 }
 
-def FoldSubViewOps : Pass<"fold-memref-subview-ops"> {
-  let summary = "Fold memref.subview ops into consumer load/store ops";
+def FoldMemRefAliasOps : Pass<"fold-memref-alias-ops"> {
+  let summary = "Fold memref alias ops into consumer load/store ops";
   let description = [{
-    The pass folds loading/storing from/to subview ops to loading/storing
+    The pass folds loading/storing from/to memref aliasing ops to loading/storing
     from/to the original memref.
   }];
-  let constructor = "mlir::memref::createFoldSubViewOpsPass()";
+  let constructor = "mlir::memref::createFoldMemRefAliasOpsPass()";
   let dependentDialects = [
       "AffineDialect", "memref::MemRefDialect", "vector::VectorDialect"
   ];
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_DIALECT_UTILS_INDEXINGUTILS_H
 #define MLIR_DIALECT_UTILS_INDEXINGUTILS_H
 
+#include "mlir/IR/Builders.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
@@ -47,6 +48,15 @@
 SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
                                        unsigned dropBack = 0);
+
+/// Computes and returns a linearized affine expression w.r.t. `basis`.
+mlir::AffineExpr getLinearAffineExpr(ArrayRef<int64_t> basis, mlir::Builder &b);
+
+/// Given the strides in the dimension space, returns the affine expressions
+/// for the vector-space offsets in each dimension of a de-linearized index.
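+/// For example, with strides [6, 1] the linear index d0 de-linearizes to
+/// (d0 floordiv 6, d0 mod 6).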
+SmallVector<mlir::AffineExpr, 4>
+getDelinearizedAffineExpr(ArrayRef<int64_t> strides, mlir::Builder &b);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_INDEXINGUTILS_H
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1486,6 +1486,13 @@
       "ValueRange":$indices,
       CArg<"Optional<ArrayRef<bool>>", "::llvm::None">:$inBounds)>,
   ];
+
+  let extraClassDeclaration = [{
+    /// This method is added to maintain uniformity with load/store
+    /// ops of other dialects.
+    Value getValue() { return getVector(); }
+  }];
+
   let hasFolder = 1;
   let hasCanonicalizer = 1;
   let hasCustomAssemblyFormat = 1;
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -1,7 +1,7 @@
 add_mlir_dialect_library(MLIRMemRefTransforms
   ComposeSubView.cpp
   ExpandOps.cpp
-  FoldSubViewOps.cpp
+  FoldMemRefAliasOps.cpp
   MultiBuffer.cpp
   NormalizeMemRefs.cpp
   ResolveShapedTypeResultDims.cpp
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -0,0 +1,562 @@
+//===- FoldMemRefAliasOps.cpp - Fold memref alias ops ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass folds loading/storing from/to memref aliasing ops
+// (subview, expand_shape, collapse_shape) into loading/storing from/to the
+// original memref.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// Given the 'indices' of a load/store operation where the memref is a result
+/// of an expand_shape op, returns the indices w.r.t. the source memref of the
+/// expand_shape op. For example
+///
+/// %0 = ... : memref<12x42xf32>
+/// %1 = memref.expand_shape %0 [[0, 1], [2]]
+///    : memref<12x42xf32> into memref<2x6x42xf32>
+/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32>
+///
+/// could be folded into
+///
+/// %2 = load %0[6 * %i1 + %i2, %i3] :
+///    memref<12x42xf32>
+static LogicalResult
+resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
+                                memref::ExpandShapeOp expandShapeOp,
+                                ValueRange indices,
+                                SmallVectorImpl<Value> &sourceIndices) {
+  for (SmallVector<int64_t, 2> groups :
+       expandShapeOp.getReassociationIndices()) {
+    assert(!groups.empty() && "association indices groups cannot be empty");
+    unsigned groupSize = groups.size();
+    SmallVector<int64_t> suffixProduct(groupSize);
+    // Calculate suffix product of dimension sizes for all dimensions of expand
+    // shape op result.
+    suffixProduct[groupSize - 1] = 1;
+    for (unsigned i = groupSize - 1; i > 0; i--)
+      suffixProduct[i - 1] =
+          suffixProduct[i] *
+          expandShapeOp.getType().cast<MemRefType>().getDimSize(groups[i]);
+    SmallVector<Value> dynamicIndices(groupSize);
+    for (unsigned i = 0; i < groupSize; i++)
+      dynamicIndices[i] = indices[groups[i]];
+    // Construct the expression for the index value w.r.t. the expand_shape op
+    // source, corresponding to the indices w.r.t. the expand_shape op result.
+    AffineExpr srcIndexExpr = getLinearAffineExpr(suffixProduct, rewriter);
+    sourceIndices.push_back(rewriter.create<AffineApplyOp>(
+        loc,
+        AffineMap::get(/*numDims=*/groupSize, /*numSymbols=*/0, srcIndexExpr),
+        dynamicIndices));
+  }
+  return success();
+}
+
+/// Given the 'indices' of a load/store operation where the memref is a result
+/// of a collapse_shape op, returns the indices w.r.t. the source memref of
+/// the collapse_shape op. For example
+///
+/// %0 = ... : memref<2x6x42xf32>
+/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
+///    : memref<2x6x42xf32> into memref<12x42xf32>
+/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
+///
+/// could be folded into
+///
+/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
+///    memref<2x6x42xf32>
+static LogicalResult
+resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
+                                  memref::CollapseShapeOp collapseShapeOp,
+                                  ValueRange indices,
+                                  SmallVectorImpl<Value> &sourceIndices) {
+  unsigned cnt = 0;
+  SmallVector<Value> tmp(indices.size());
+  SmallVector<Value> dynamicIndices;
+  for (SmallVector<int64_t, 2> groups :
+       collapseShapeOp.getReassociationIndices()) {
+    assert(!groups.empty() && "association indices groups cannot be empty");
+    dynamicIndices.push_back(indices[cnt++]);
+    unsigned groupSize = groups.size();
+    SmallVector<int64_t> suffixProduct(groupSize);
+    // Calculate suffix product for all collapse op source dimension sizes.
+    suffixProduct[groupSize - 1] = 1;
+    for (unsigned i = groupSize - 1; i > 0; i--)
+      suffixProduct[i - 1] =
+          suffixProduct[i] * collapseShapeOp.getSrcType().getDimSize(groups[i]);
+    // Derive the index values along all dimensions of the source corresponding
+    // to the index w.r.t. the collapse_shape op result.
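+    // For example, for a group collapsing dims [0, 1] of memref<2x6x42xf32>,
+    // suffixProduct is [6, 1], so index %i1 de-linearizes to
+    // (%i1 floordiv 6, %i1 mod 6).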
+ SmallVector srcIndexExpr = + getDelinearizedAffineExpr(suffixProduct, rewriter); + for (unsigned i = 0; i < groupSize; i++) + sourceIndices.push_back(rewriter.create( + loc, AffineMap::get(/*numDims=*/1, /*numSymbols=*/0, srcIndexExpr[i]), + dynamicIndices)); + dynamicIndices.clear(); + } + if (collapseShapeOp.getReassociationIndices().empty()) { + auto zeroAffineMap = rewriter.getConstantAffineMap(0); + unsigned srcRank = + collapseShapeOp.getViewSource().getType().cast().getRank(); + for (unsigned i = 0; i < srcRank; i++) + sourceIndices.push_back( + rewriter.create(loc, zeroAffineMap, dynamicIndices)); + } + return success(); +} + +/// Given the 'indices' of an load/store operation where the memref is a result +/// of a subview op, returns the indices w.r.t to the source memref of the +/// subview op. For example +/// +/// %0 = ... : memref<12x42xf32> +/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to +/// memref<4x4xf32, offset=?, strides=[?, ?]> +/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]> +/// +/// could be folded into +/// +/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] : +/// memref<12x42xf32> +static LogicalResult +resolveSourceIndicesSubView(Location loc, PatternRewriter &rewriter, + memref::SubViewOp subViewOp, ValueRange indices, + SmallVectorImpl &sourceIndices) { + SmallVector mixedOffsets = subViewOp.getMixedOffsets(); + SmallVector mixedSizes = subViewOp.getMixedSizes(); + SmallVector mixedStrides = subViewOp.getMixedStrides(); + + SmallVector useIndices; + // Check if this is rank-reducing case. Then for every unit-dim size add a + // zero to the indices. + unsigned resultDim = 0; + llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims(); + for (auto dim : llvm::seq(0, subViewOp.getSourceType().getRank())) { + if (unusedDims.test(dim)) + useIndices.push_back(rewriter.create(loc, 0)); + else + useIndices.push_back(indices[resultDim++]); + } + if (useIndices.size() != mixedOffsets.size()) + return failure(); + sourceIndices.resize(useIndices.size()); + for (auto index : llvm::seq(0, mixedOffsets.size())) { + SmallVector dynamicOperands; + AffineExpr expr = rewriter.getAffineDimExpr(0); + unsigned numSymbols = 0; + dynamicOperands.push_back(useIndices[index]); + + // Multiply the stride; + if (auto attr = mixedStrides[index].dyn_cast()) { + expr = expr * attr.cast().getInt(); + } else { + dynamicOperands.push_back(mixedStrides[index].get()); + expr = expr * rewriter.getAffineSymbolExpr(numSymbols++); + } + + // Add the offset. + if (auto attr = mixedOffsets[index].dyn_cast()) { + expr = expr + attr.cast().getInt(); + } else { + dynamicOperands.push_back(mixedOffsets[index].get()); + expr = expr + rewriter.getAffineSymbolExpr(numSymbols++); + } + Location loc = subViewOp.getLoc(); + sourceIndices[index] = rewriter.create( + loc, AffineMap::get(1, numSymbols, expr), dynamicOperands); + } + return success(); +} + +/// Helpers to access the memref operand for each op. +template +static Value getMemRefOperand(LoadOrStoreOpTy op) { + return op.getMemref(); +} + +static Value getMemRefOperand(vector::TransferReadOp op) { + return op.getSource(); +} + +static Value getMemRefOperand(vector::TransferWriteOp op) { + return op.getSource(); +} + +/// Given the permutation map of the original +/// `vector.transfer_read`/`vector.transfer_write` operations compute the +/// permutation map to use after the subview is folded with it. 
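+/// For example, if the subview drops dim 1 of a rank-3 source and the original
+/// permutation map is (d0, d1) -> (d1, d0), the folded map over the source
+/// dims becomes (d0, d1, d2) -> (d2, d0).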
+static AffineMapAttr getPermutationMapAttr(MLIRContext *context, + memref::SubViewOp subViewOp, + AffineMap currPermutationMap) { + llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims(); + SmallVector exprs; + int64_t sourceRank = subViewOp.getSourceType().getRank(); + for (auto dim : llvm::seq(0, sourceRank)) { + if (unusedDims.test(dim)) + continue; + exprs.push_back(getAffineDimExpr(dim, context)); + } + auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context); + return AffineMapAttr::get( + currPermutationMap.compose(resultDimToSourceDimMap)); +} + +//===----------------------------------------------------------------------===// +// Patterns +//===----------------------------------------------------------------------===// + +namespace { +/// Merges subview operation with load/transferRead operation. +template +class LoadOpOfSubViewOpFolder final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy loadOp, + PatternRewriter &rewriter) const override; +}; + +/// Merges expand_shape operation with load/transferRead operation. +template +class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy loadOp, + PatternRewriter &rewriter) const override; +}; + +/// Merges collapse_shape operation with load/transferRead operation. +template +class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy loadOp, + PatternRewriter &rewriter) const override; +}; + +/// Merges subview operation with store/transferWriteOp operation. +template +class StoreOpOfSubViewOpFolder final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy storeOp, + PatternRewriter &rewriter) const override; +}; + +/// Merges expand_shape operation with store/transferWriteOp operation. +template +class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy storeOp, + PatternRewriter &rewriter) const override; +}; + +/// Merges collapse_shape operation with store/transferWriteOp operation. +template +class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy storeOp, + PatternRewriter &rewriter) const override; +}; + +} // namespace + +static SmallVector +calculateExpandedAccessIndices(AffineMap affineMap, SmallVector indices, + Location loc, PatternRewriter &rewriter) { + SmallVector expandedIndices; + for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) + expandedIndices.push_back( + rewriter.create(loc, affineMap.getSubMap({i}), indices)); + return expandedIndices; +} + +template +LogicalResult LoadOpOfSubViewOpFolder::matchAndRewrite( + OpTy loadOp, PatternRewriter &rewriter) const { + auto subViewOp = + getMemRefOperand(loadOp).template getDefiningOp(); + + if (!subViewOp) + return failure(); + + SmallVector indices(loadOp.getIndices().begin(), + loadOp.getIndices().end()); + // For affine ops, we need to apply the map to get the operands to get the + // "actual" indices. 
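+  // For example, an affine.load %src[%i + 1] carries the map (d0) -> (d0 + 1);
+  // materializing each map result as an affine.apply yields the plain index
+  // values that the folding below expects.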
+ if (auto affineLoadOp = dyn_cast(loadOp.getOperation())) { + AffineMap affineMap = affineLoadOp.getAffineMap(); + auto expandedIndices = calculateExpandedAccessIndices( + affineMap, indices, loadOp.getLoc(), rewriter); + indices.assign(expandedIndices.begin(), expandedIndices.end()); + } + SmallVector sourceIndices; + if (failed(resolveSourceIndicesSubView(loadOp.getLoc(), rewriter, subViewOp, + indices, sourceIndices))) + return failure(); + llvm::TypeSwitch(loadOp) + .Case([&](auto op) { + rewriter.replaceOpWithNewOp(loadOp, subViewOp.source(), + sourceIndices); + }) + .Case([&](vector::TransferReadOp transferReadOp) { + if (transferReadOp.getTransferRank() == 0) { + // TODO: Propagate the error. + return; + } + rewriter.replaceOpWithNewOp( + transferReadOp, transferReadOp.getVectorType(), subViewOp.source(), + sourceIndices, + getPermutationMapAttr(rewriter.getContext(), subViewOp, + transferReadOp.getPermutationMap()), + transferReadOp.getPadding(), + /*mask=*/Value(), transferReadOp.getInBoundsAttr()); + }) + .Default([](Operation *) { llvm_unreachable("unexpected operation."); }); + return success(); +} + +template +LogicalResult LoadOpOfExpandShapeOpFolder::matchAndRewrite( + OpTy loadOp, PatternRewriter &rewriter) const { + auto expandShapeOp = + getMemRefOperand(loadOp).template getDefiningOp(); + + if (!expandShapeOp) + return failure(); + + SmallVector indices(loadOp.getIndices().begin(), + loadOp.getIndices().end()); + // For affine ops, we need to apply the map to get the operands to get the + // "actual" indices. + if (auto affineLoadOp = dyn_cast(loadOp.getOperation())) { + AffineMap affineMap = affineLoadOp.getAffineMap(); + auto expandedIndices = calculateExpandedAccessIndices( + affineMap, indices, loadOp.getLoc(), rewriter); + indices.assign(expandedIndices.begin(), expandedIndices.end()); + } + SmallVector sourceIndices; + if (failed(resolveSourceIndicesExpandShape( + loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices))) + return failure(); + llvm::TypeSwitch(loadOp) + .Case([&](auto op) { + rewriter.replaceOpWithNewOp( + loadOp, expandShapeOp.getViewSource(), sourceIndices); + }) + .Default([](Operation *) { llvm_unreachable("unexpected operation."); }); + return success(); +} + +template +LogicalResult LoadOpOfCollapseShapeOpFolder::matchAndRewrite( + OpTy loadOp, PatternRewriter &rewriter) const { + auto collapseShapeOp = getMemRefOperand(loadOp) + .template getDefiningOp(); + + if (!collapseShapeOp) + return failure(); + + SmallVector indices(loadOp.getIndices().begin(), + loadOp.getIndices().end()); + // For affine ops, we need to apply the map to get the operands to get the + // "actual" indices. 
+ if (auto affineLoadOp = dyn_cast(loadOp.getOperation())) { + AffineMap affineMap = affineLoadOp.getAffineMap(); + auto expandedIndices = calculateExpandedAccessIndices( + affineMap, indices, loadOp.getLoc(), rewriter); + indices.assign(expandedIndices.begin(), expandedIndices.end()); + } + SmallVector sourceIndices; + if (failed(resolveSourceIndicesCollapseShape( + loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices))) + return failure(); + llvm::TypeSwitch(loadOp) + .Case([&](auto op) { + rewriter.replaceOpWithNewOp( + loadOp, collapseShapeOp.getViewSource(), sourceIndices); + }) + .Default([](Operation *) { llvm_unreachable("unexpected operation."); }); + return success(); +} + +template +LogicalResult StoreOpOfSubViewOpFolder::matchAndRewrite( + OpTy storeOp, PatternRewriter &rewriter) const { + auto subViewOp = + getMemRefOperand(storeOp).template getDefiningOp(); + + if (!subViewOp) + return failure(); + + SmallVector indices(storeOp.getIndices().begin(), + storeOp.getIndices().end()); + // For affine ops, we need to apply the map to get the operands to get the + // "actual" indices. + if (auto affineStoreOp = dyn_cast(storeOp.getOperation())) { + AffineMap affineMap = affineStoreOp.getAffineMap(); + auto expandedIndices = calculateExpandedAccessIndices( + affineMap, indices, storeOp.getLoc(), rewriter); + indices.assign(expandedIndices.begin(), expandedIndices.end()); + } + SmallVector sourceIndices; + if (failed(resolveSourceIndicesSubView(storeOp.getLoc(), rewriter, subViewOp, + indices, sourceIndices))) + return failure(); + llvm::TypeSwitch(storeOp) + .Case([&](auto op) { + rewriter.replaceOpWithNewOp( + storeOp, storeOp.getValue(), subViewOp.source(), sourceIndices); + }) + .Case([&](vector::TransferWriteOp op) { + // TODO: support 0-d corner case. + if (op.getTransferRank() == 0) + return; + rewriter.replaceOpWithNewOp( + op, op.getValue(), subViewOp.source(), sourceIndices, + getPermutationMapAttr(rewriter.getContext(), subViewOp, + op.getPermutationMap()), + op.getInBoundsAttr()); + }) + .Default([](Operation *) { llvm_unreachable("unexpected operation."); }); + return success(); +} + +template +LogicalResult StoreOpOfExpandShapeOpFolder::matchAndRewrite( + OpTy storeOp, PatternRewriter &rewriter) const { + auto expandShapeOp = + getMemRefOperand(storeOp).template getDefiningOp(); + + if (!expandShapeOp) + return failure(); + + SmallVector indices(storeOp.getIndices().begin(), + storeOp.getIndices().end()); + // For affine ops, we need to apply the map to get the operands to get the + // "actual" indices. 
+ if (auto affineStoreOp = dyn_cast(storeOp.getOperation())) { + AffineMap affineMap = affineStoreOp.getAffineMap(); + auto expandedIndices = calculateExpandedAccessIndices( + affineMap, indices, storeOp.getLoc(), rewriter); + indices.assign(expandedIndices.begin(), expandedIndices.end()); + } + SmallVector sourceIndices; + if (failed(resolveSourceIndicesExpandShape( + storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices))) + return failure(); + llvm::TypeSwitch(storeOp) + .Case([&](auto op) { + rewriter.replaceOpWithNewOp(storeOp, storeOp.getValue(), + expandShapeOp.getViewSource(), + sourceIndices); + }) + .Default([](Operation *) { llvm_unreachable("unexpected operation."); }); + return success(); +} + +template +LogicalResult StoreOpOfCollapseShapeOpFolder::matchAndRewrite( + OpTy storeOp, PatternRewriter &rewriter) const { + auto collapseShapeOp = getMemRefOperand(storeOp) + .template getDefiningOp(); + + if (!collapseShapeOp) + return failure(); + + SmallVector indices(storeOp.getIndices().begin(), + storeOp.getIndices().end()); + // For affine ops, we need to apply the map to get the operands to get the + // "actual" indices. + if (auto affineStoreOp = dyn_cast(storeOp.getOperation())) { + AffineMap affineMap = affineStoreOp.getAffineMap(); + auto expandedIndices = calculateExpandedAccessIndices( + affineMap, indices, storeOp.getLoc(), rewriter); + indices.assign(expandedIndices.begin(), expandedIndices.end()); + } + SmallVector sourceIndices; + if (failed(resolveSourceIndicesCollapseShape( + storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices))) + return failure(); + llvm::TypeSwitch(storeOp) + .Case([&](auto op) { + rewriter.replaceOpWithNewOp( + storeOp, storeOp.getValue(), collapseShapeOp.getViewSource(), + sourceIndices); + }) + .Default([](Operation *) { llvm_unreachable("unexpected operation."); }); + return success(); +} + +void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) { + patterns.add, + LoadOpOfSubViewOpFolder, + LoadOpOfSubViewOpFolder, + StoreOpOfSubViewOpFolder, + StoreOpOfSubViewOpFolder, + StoreOpOfSubViewOpFolder, + LoadOpOfExpandShapeOpFolder, + LoadOpOfExpandShapeOpFolder, + StoreOpOfExpandShapeOpFolder, + StoreOpOfExpandShapeOpFolder, + LoadOpOfCollapseShapeOpFolder, + LoadOpOfCollapseShapeOpFolder, + StoreOpOfCollapseShapeOpFolder, + StoreOpOfCollapseShapeOpFolder>( + patterns.getContext()); +} + +//===----------------------------------------------------------------------===// +// Pass registration +//===----------------------------------------------------------------------===// + +namespace { + +#define GEN_PASS_CLASSES +#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" + +struct FoldMemRefAliasOpsPass final + : public FoldMemRefAliasOpsBase { + void runOnOperation() override; +}; + +} // namespace + +void FoldMemRefAliasOpsPass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + memref::populateFoldMemRefAliasOpPatterns(patterns); + (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), + std::move(patterns)); +} + +std::unique_ptr memref::createFoldMemRefAliasOpsPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp +++ /dev/null @@ -1,276 +0,0 @@ -//===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===// -// -// Part of the LLVM Project, under the Apache 
License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This transformation pass folds loading/storing from/to subview ops into -// loading/storing from/to the original memref. -// -//===----------------------------------------------------------------------===// - -#include "PassDetail.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/MemRef/Transforms/Passes.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/SmallBitVector.h" - -using namespace mlir; - -//===----------------------------------------------------------------------===// -// Utility functions -//===----------------------------------------------------------------------===// - -/// Given the 'indices' of an load/store operation where the memref is a result -/// of a subview op, returns the indices w.r.t to the source memref of the -/// subview op. For example -/// -/// %0 = ... : memref<12x42xf32> -/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to -/// memref<4x4xf32, offset=?, strides=[?, ?]> -/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]> -/// -/// could be folded into -/// -/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] : -/// memref<12x42xf32> -static LogicalResult -resolveSourceIndices(Location loc, PatternRewriter &rewriter, - memref::SubViewOp subViewOp, ValueRange indices, - SmallVectorImpl &sourceIndices) { - SmallVector mixedOffsets = subViewOp.getMixedOffsets(); - SmallVector mixedSizes = subViewOp.getMixedSizes(); - SmallVector mixedStrides = subViewOp.getMixedStrides(); - - SmallVector useIndices; - // Check if this is rank-reducing case. Then for every unit-dim size add a - // zero to the indices. - unsigned resultDim = 0; - llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims(); - for (auto dim : llvm::seq(0, subViewOp.getSourceType().getRank())) { - if (unusedDims.test(dim)) - useIndices.push_back(rewriter.create(loc, 0)); - else - useIndices.push_back(indices[resultDim++]); - } - if (useIndices.size() != mixedOffsets.size()) - return failure(); - sourceIndices.resize(useIndices.size()); - for (auto index : llvm::seq(0, mixedOffsets.size())) { - SmallVector dynamicOperands; - AffineExpr expr = rewriter.getAffineDimExpr(0); - unsigned numSymbols = 0; - dynamicOperands.push_back(useIndices[index]); - - // Multiply the stride; - if (auto attr = mixedStrides[index].dyn_cast()) { - expr = expr * attr.cast().getInt(); - } else { - dynamicOperands.push_back(mixedStrides[index].get()); - expr = expr * rewriter.getAffineSymbolExpr(numSymbols++); - } - - // Add the offset. - if (auto attr = mixedOffsets[index].dyn_cast()) { - expr = expr + attr.cast().getInt(); - } else { - dynamicOperands.push_back(mixedOffsets[index].get()); - expr = expr + rewriter.getAffineSymbolExpr(numSymbols++); - } - Location loc = subViewOp.getLoc(); - sourceIndices[index] = rewriter.create( - loc, AffineMap::get(1, numSymbols, expr), dynamicOperands); - } - return success(); -} - -/// Helpers to access the memref operand for each op. 
-template -static Value getMemRefOperand(LoadOrStoreOpTy op) { - return op.getMemref(); -} - -static Value getMemRefOperand(vector::TransferReadOp op) { - return op.getSource(); -} - -static Value getMemRefOperand(vector::TransferWriteOp op) { - return op.getSource(); -} - -/// Given the permutation map of the original -/// `vector.transfer_read`/`vector.transfer_write` operations compute the -/// permutation map to use after the subview is folded with it. -static AffineMapAttr getPermutationMapAttr(MLIRContext *context, - memref::SubViewOp subViewOp, - AffineMap currPermutationMap) { - llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims(); - SmallVector exprs; - int64_t sourceRank = subViewOp.getSourceType().getRank(); - for (auto dim : llvm::seq(0, sourceRank)) { - if (unusedDims.test(dim)) - continue; - exprs.push_back(getAffineDimExpr(dim, context)); - } - auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context); - return AffineMapAttr::get( - currPermutationMap.compose(resultDimToSourceDimMap)); -} - -//===----------------------------------------------------------------------===// -// Patterns -//===----------------------------------------------------------------------===// - -namespace { -/// Merges subview operation with load/transferRead operation. -template -class LoadOpOfSubViewFolder final : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(OpTy loadOp, - PatternRewriter &rewriter) const override; - -private: - void replaceOp(OpTy loadOp, memref::SubViewOp subViewOp, - ArrayRef sourceIndices, - PatternRewriter &rewriter) const; -}; - -/// Merges subview operation with store/transferWriteOp operation. -template -class StoreOpOfSubViewFolder final : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(OpTy storeOp, - PatternRewriter &rewriter) const override; - -private: - void replaceOp(OpTy storeOp, memref::SubViewOp subViewOp, - ArrayRef sourceIndices, - PatternRewriter &rewriter) const; -}; - -template -void LoadOpOfSubViewFolder::replaceOp( - LoadOpTy loadOp, memref::SubViewOp subViewOp, ArrayRef sourceIndices, - PatternRewriter &rewriter) const { - rewriter.replaceOpWithNewOp(loadOp, subViewOp.getSource(), - sourceIndices); -} - -template <> -void LoadOpOfSubViewFolder::replaceOp( - vector::TransferReadOp transferReadOp, memref::SubViewOp subViewOp, - ArrayRef sourceIndices, PatternRewriter &rewriter) const { - // TODO: support 0-d corner case. - if (transferReadOp.getTransferRank() == 0) - return; - rewriter.replaceOpWithNewOp( - transferReadOp, transferReadOp.getVectorType(), subViewOp.getSource(), - sourceIndices, - getPermutationMapAttr(rewriter.getContext(), subViewOp, - transferReadOp.getPermutationMap()), - transferReadOp.getPadding(), - /*mask=*/Value(), transferReadOp.getInBoundsAttr()); -} - -template -void StoreOpOfSubViewFolder::replaceOp( - StoreOpTy storeOp, memref::SubViewOp subViewOp, - ArrayRef sourceIndices, PatternRewriter &rewriter) const { - rewriter.replaceOpWithNewOp(storeOp, storeOp.getValue(), - subViewOp.getSource(), sourceIndices); -} - -template <> -void StoreOpOfSubViewFolder::replaceOp( - vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp, - ArrayRef sourceIndices, PatternRewriter &rewriter) const { - // TODO: support 0-d corner case. 
- if (transferWriteOp.getTransferRank() == 0) - return; - rewriter.replaceOpWithNewOp( - transferWriteOp, transferWriteOp.getVector(), subViewOp.getSource(), - sourceIndices, - getPermutationMapAttr(rewriter.getContext(), subViewOp, - transferWriteOp.getPermutationMap()), - transferWriteOp.getInBoundsAttr()); -} -} // namespace - -template -LogicalResult -LoadOpOfSubViewFolder::matchAndRewrite(OpTy loadOp, - PatternRewriter &rewriter) const { - auto subViewOp = - getMemRefOperand(loadOp).template getDefiningOp(); - if (!subViewOp) - return failure(); - - SmallVector sourceIndices; - if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp, - loadOp.getIndices(), sourceIndices))) - return failure(); - - replaceOp(loadOp, subViewOp, sourceIndices, rewriter); - return success(); -} - -template -LogicalResult -StoreOpOfSubViewFolder::matchAndRewrite(OpTy storeOp, - PatternRewriter &rewriter) const { - auto subViewOp = - getMemRefOperand(storeOp).template getDefiningOp(); - if (!subViewOp) - return failure(); - - SmallVector sourceIndices; - if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp, - storeOp.getIndices(), sourceIndices))) - return failure(); - - replaceOp(storeOp, subViewOp, sourceIndices, rewriter); - return success(); -} - -void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) { - patterns.add, - LoadOpOfSubViewFolder, - LoadOpOfSubViewFolder, - StoreOpOfSubViewFolder, - StoreOpOfSubViewFolder, - StoreOpOfSubViewFolder>( - patterns.getContext()); -} - -//===----------------------------------------------------------------------===// -// Pass registration -//===----------------------------------------------------------------------===// - -namespace { - -struct FoldSubViewOpsPass final - : public FoldSubViewOpsBase { - void runOnOperation() override; -}; - -} // namespace - -void FoldSubViewOpsPass::runOnOperation() { - RewritePatternSet patterns(&getContext()); - memref::populateFoldSubViewOpPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); -} - -std::unique_ptr memref::createFoldSubViewOpsPass() { - return std::make_unique(); -} diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp --- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp +++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp @@ -8,6 +8,8 @@ #include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" int64_t mlir::linearize(ArrayRef offsets, ArrayRef basis) { @@ -42,3 +44,26 @@ res.push_back((*it).getValue().getSExtValue()); return res; } + +mlir::AffineExpr mlir::getLinearAffineExpr(ArrayRef basis, + mlir::Builder &b) { + AffineExpr resultExpr = b.getAffineDimExpr(0); + resultExpr = resultExpr * basis[0]; + for (unsigned i = 1; i < basis.size(); i++) + resultExpr = resultExpr + b.getAffineDimExpr(i) * basis[i]; + return resultExpr; +} + +llvm::SmallVector +mlir::getDelinearizedAffineExpr(mlir::ArrayRef strides, Builder &b) { + AffineExpr resultExpr = b.getAffineDimExpr(0); + int64_t rank = strides.size(); + SmallVector vectorOffsets(rank); + vectorOffsets[0] = resultExpr.floorDiv(strides[0]); + resultExpr = resultExpr % strides[0]; + for (unsigned i = 1; i < rank; i++) { + vectorOffsets[i] = resultExpr.floorDiv(strides[i]); + resultExpr = resultExpr % strides[i]; + } + return vectorOffsets; +} diff --git a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir rename 
from mlir/test/Dialect/MemRef/fold-subview-ops.mlir rename to mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir --- a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir +++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -fold-memref-subview-ops -split-input-file %s -o - | FileCheck %s +// RUN: mlir-opt -fold-memref-alias-ops -split-input-file %s -o - | FileCheck %s func.func @fold_static_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 { %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]> @@ -272,3 +272,154 @@ // CHECK-NEXT: return return %1 : f32 } + +// ----- + +// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 * 6 + d1)> +// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape +// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32 { +func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index) -> f32 { + %0 = memref.expand_shape %arg0 [[0, 1], [2]] : memref<12x32xf32> into memref<2x6x32xf32> + %1 = affine.load %0[%arg1, %arg2, %arg3] : memref<2x6x32xf32> + return %1 : f32 +} +// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]](%[[ARG1]], %[[ARG2]]) +// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG3]]] : memref<12x32xf32> +// CHECK-NEXT: return %[[RESULT]] : f32 + +// ----- + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (d0 floordiv 6)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 mod 6)> +// CHECK-LABEL: @fold_static_stride_subview_with_affine_load_store_collapse_shape +// CHECK-SAME: (%[[ARG0:.*]]: memref<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) +func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg0 : memref<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 { + %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<2x6x32xf32> into memref<12x32xf32> + %1 = affine.load %0[%arg1, %arg2] : memref<12x32xf32> + return %1 : f32 +} +// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]](%[[ARG1]]) +// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]](%[[ARG1]]) +// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<2x6x32xf32> +// CHECK-NEXT: return %[[RESULT]] : f32 + +// ----- + +// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1, d2) -> (d0 * 6 + d1 * 3 + d2)> +// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_3d +// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) -> f32 { +func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: index) -> f32 { + %0 = memref.expand_shape %arg0 [[0, 1, 2], [3]] : memref<12x32xf32> into memref<2x2x3x32xf32> + %1 = affine.load %0[%arg1, %arg2, %arg3, %arg4] : memref<2x2x3x32xf32> + return %1 : f32 +} +// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]](%[[ARG1]], %[[ARG2]], %[[ARG3]]) +// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG4]]] : memref<12x32xf32> +// CHECK-NEXT: return %[[RESULT]] : f32 + +// ----- + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1)> +// CHECK-DAG: #[[$MAP1:.*]] = 
affine_map<(d0, d1) -> (d0 + d1)> +// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape +// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index) +func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 { + %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32> + affine.for %arg3 = 0 to 1 { + affine.for %arg4 = 0 to 1024 { + affine.for %arg5 = 0 to 1020 { + affine.for %arg6 = 0 to 1 { + %1 = affine.load %0[%arg3, %arg4, %arg5, %arg6] : memref<1x1024x1024x1xf32> + affine.store %1, %arg1[%arg2] : memref<1xf32> + } + } + } + } + %2 = affine.load %arg1[%arg2] : memref<1xf32> + return %2 : f32 +} +// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 1 { +// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 { +// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 { +// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 { +// CHECK-NEXT: %[[IDX1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]]) +// CHECK-NEXT: %[[IDX2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]]) +// CHECK-NEXT: affine.load %[[ARG0]][%[[IDX1]], %[[IDX2]]] : memref<1024x1024xf32> + +// ----- + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1 + d0)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1)> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0 + d1)> +// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression +// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index) +func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 { + %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32> + affine.for %arg3 = 0 to 1 { + affine.for %arg4 = 0 to 1024 { + affine.for %arg5 = 0 to 1020 { + affine.for %arg6 = 0 to 1 { + %1 = affine.load %0[%arg3, %arg4 + %arg3, %arg5, %arg6] : memref<1x1024x1024x1xf32> + affine.store %1, %arg1[%arg2] : memref<1xf32> + } + } + } + } + %2 = affine.load %arg1[%arg2] : memref<1xf32> + return %2 : f32 +} +// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 1 { +// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 { +// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 { +// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 { +// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]], %[[ARG5]], %[[ARG6]]) +// CHECK-NEXT: %[[TMP2:.*]] = affine.apply #[[$MAP1]](%[[ARG3]], %[[TMP1]]) +// CHECK-NEXT: %[[TMP3:.*]] = affine.apply #map2(%[[ARG5]], %[[ARG6]]) +// CHECK-NEXT: affine.load %[[ARG0]][%[[TMP2]], %[[TMP3]]] : memref<1024x1024xf32> + +// ----- + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)> +// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index +// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index) +func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 { + %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32> + %cst = arith.constant 0 : 
index + affine.for %arg3 = 0 to 1 { + affine.for %arg4 = 0 to 1024 { + affine.for %arg5 = 0 to 1020 { + affine.for %arg6 = 0 to 1 { + %1 = memref.load %0[%arg3, %cst, %arg5, %arg6] : memref<1x1024x1024x1xf32> + memref.store %1, %arg1[%arg2] : memref<1xf32> + } + } + } + } + %2 = memref.load %arg1[%arg2] : memref<1xf32> + return %2 : f32 +} +// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : index +// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 1 { +// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 { +// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 { +// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 { +// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ZERO]]) +// CHECK-NEXT: %[[TMP2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]]) +// CHECK-NEXT: memref.load %[[ARG0]][%[[TMP1]], %[[TMP2]]] : memref<1024x1024xf32> + +// ----- + +// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_collapse_shape_with_0d_result +// CHECK-SAME: (%[[ARG0:.*]]: memref<1xf32>, %[[ARG1:.*]]: memref<1xf32>) +func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape_with_0d_result(%arg0: memref<1xf32>, %arg1: memref<1xf32>) -> memref<1xf32> { + %0 = memref.collapse_shape %arg0 [] : memref<1xf32> into memref + affine.for %arg2 = 0 to 3 { + %1 = affine.load %0[] : memref + affine.store %1, %arg1[0] : memref<1xf32> + } + return %arg1 : memref<1xf32> +} +// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : index +// CHECK-NEXT: affine.for %{{.*}} = 0 to 3 { +// CHECK-NEXT: affine.load %[[ARG0]][%[[ZERO]]] : memref<1xf32> diff --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp --- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp +++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp @@ -46,7 +46,7 @@ applyPassManagerCLOptions(passManager); passManager.addPass(createGpuKernelOutliningPass()); - passManager.addPass(memref::createFoldSubViewOpsPass()); + passManager.addPass(memref::createFoldMemRefAliasOpsPass()); passManager.addPass(createConvertGPUToSPIRVPass(/*mapMemorySpace=*/true)); OpPassManager &modulePM = passManager.nest();
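
Note: downstream pipelines that previously ran `-fold-memref-subview-ops` or called `createFoldSubViewOpsPass()` switch to the renamed entry points. A minimal sketch of the two ways to hook in the rewrites (the entry points are the ones declared in the headers above; the surrounding pipeline setup and the `context` variable are hypothetical):

    #include "mlir/Dialect/MemRef/Transforms/Passes.h"
    #include "mlir/IR/PatternMatch.h"
    #include "mlir/Pass/PassManager.h"

    // As a standalone pass in a pipeline (mirrors the mlir-vulkan-runner change).
    mlir::PassManager pm(&context);
    pm.addPass(mlir::memref::createFoldMemRefAliasOpsPass());

    // Or reuse just the rewrite patterns inside another pass.
    mlir::RewritePatternSet patterns(&context);
    mlir::memref::populateFoldMemRefAliasOpPatterns(patterns);

From the command line, the equivalent is `mlir-opt -fold-memref-alias-ops`, as exercised by the renamed test file above.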