diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -16,4 +16,38 @@
 #ifndef MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H
 #define MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H
 
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace memref {
+
+/// Given the `indices` used to access a memref that is the result of
+/// `subViewOp`, computes the corresponding indices into the subview's source
+/// memref by inverting the indexing transform of the subview. For example:
+///
+///   %0 = ... : memref<12x42xf32>
+///   %1 = memref.subview %0[%arg0, %arg1][4, 4][%stride0, %stride1] :
+///          memref<12x42xf32> to memref<4x4xf32, #map0>
+///   %2 = memref.load %1[%i0, %i1] : memref<4x4xf32, #map0>
+///
+/// The access %1[%i0, %i1] is equivalent to the following access into %0:
+///
+///   memref.load %0[%arg0 + %i0 * %stride0, %arg1 + %i1 * %stride1] :
+///          memref<12x42xf32>
+///
+/// For a rank-reducing subview, a zero index is used for every dropped unit
+/// dimension. The required `arith.constant` and `affine.apply` ops are created
+/// with `rewriter`. On success, `sourceIndices` is resized to the source rank
+/// and holds one index per source dimension. Returns failure when the mapping
+/// cannot be computed, e.g. when the number of subview offsets does not match
+/// the source rank.
+LogicalResult invertSubViewIndexMapping(Location loc, PatternRewriter &rewriter,
+                                        memref::SubViewOp subViewOp,
+                                        ValueRange indices,
+                                        SmallVectorImpl<Value> &sourceIndices);
+
+} // namespace memref
+} // namespace mlir
+
 #endif // MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -20,6 +20,7 @@
   MLIRInferTypeOpInterface
   MLIRLoopLikeInterface
   MLIRMemRefDialect
+  MLIRMemRefUtils
   MLIRPass
   MLIRTensorDialect
   MLIRTransforms
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -27,69 +28,6 @@
 // Utility functions
 //===----------------------------------------------------------------------===//
 
-/// Given the 'indices' of an load/store operation where the memref is a result
-/// of a subview op, returns the indices w.r.t to the source memref of the
-/// subview op. For example
-///
-/// %0 = ... : memref<12x42xf32>
-/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
-///          memref<4x4xf32, offset=?, strides=[?, ?]>
-/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
-///
-/// could be folded into
-///
-/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
-///          memref<12x42xf32>
-static LogicalResult
-resolveSourceIndices(Location loc, PatternRewriter &rewriter,
-                     memref::SubViewOp subViewOp, ValueRange indices,
-                     SmallVectorImpl<Value> &sourceIndices) {
-  SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
-  SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
-  SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
-
-  SmallVector<Value> useIndices;
-  // Check if this is rank-reducing case. Then for every unit-dim size add a
-  // zero to the indices.
-  unsigned resultDim = 0;
-  llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
-  for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) {
-    if (unusedDims.test(dim))
-      useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
-    else
-      useIndices.push_back(indices[resultDim++]);
-  }
-  if (useIndices.size() != mixedOffsets.size())
-    return failure();
-  sourceIndices.resize(useIndices.size());
-  for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
-    SmallVector<Value> dynamicOperands;
-    AffineExpr expr = rewriter.getAffineDimExpr(0);
-    unsigned numSymbols = 0;
-    dynamicOperands.push_back(useIndices[index]);
-
-    // Multiply the stride;
-    if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
-      expr = expr * attr.cast<IntegerAttr>().getInt();
-    } else {
-      dynamicOperands.push_back(mixedStrides[index].get<Value>());
-      expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
-    }
-
-    // Add the offset.
-    if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
-      expr = expr + attr.cast<IntegerAttr>().getInt();
-    } else {
-      dynamicOperands.push_back(mixedOffsets[index].get<Value>());
-      expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
-    }
-    Location loc = subViewOp.getLoc();
-    sourceIndices[index] = rewriter.create<AffineApplyOp>(
-        loc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
-  }
-  return success();
-}
-
 /// Helpers to access the memref operand for each op.
template static Value getMemRefOperand(LoadOrStoreOpTy op) { @@ -216,8 +154,8 @@ return failure(); SmallVector sourceIndices; - if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp, - loadOp.getIndices(), sourceIndices))) + if (failed(invertSubViewIndexMapping(loadOp.getLoc(), rewriter, subViewOp, + loadOp.getIndices(), sourceIndices))) return failure(); replaceOp(loadOp, subViewOp, sourceIndices, rewriter); @@ -234,8 +172,8 @@ return failure(); SmallVector sourceIndices; - if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp, - storeOp.getIndices(), sourceIndices))) + if (failed(invertSubViewIndexMapping(storeOp.getLoc(), rewriter, subViewOp, + storeOp.getIndices(), sourceIndices))) return failure(); replaceOp(storeOp, subViewOp, sourceIndices, rewriter); diff --git a/mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt --- a/mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt @@ -3,5 +3,10 @@ ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect + + DEPENDS + MLIRAffineDialect + MLIRArithmeticDialect + MLIRMemRefDialect ) diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp --- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp +++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp @@ -11,3 +11,58 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" + +using namespace mlir; +using namespace mlir::memref; + +LogicalResult +memref::invertSubViewIndexMapping(Location loc, PatternRewriter &rewriter, + SubViewOp subViewOp, ValueRange indices, + SmallVectorImpl &sourceIndices) { + SmallVector mixedOffsets = subViewOp.getMixedOffsets(); + SmallVector mixedSizes = subViewOp.getMixedSizes(); + SmallVector 
mixedStrides = subViewOp.getMixedStrides(); + + SmallVector useIndices; + // Check if this is rank-reducing case. Then for every unit-dim size add a + // zero to the indices. + unsigned resultDim = 0; + llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims(); + for (auto dim : llvm::seq(0, subViewOp.getSourceType().getRank())) { + if (unusedDims.test(dim)) + useIndices.push_back(rewriter.create(loc, 0)); + else + useIndices.push_back(indices[resultDim++]); + } + if (useIndices.size() != mixedOffsets.size()) + return failure(); + sourceIndices.resize(useIndices.size()); + for (auto index : llvm::seq(0, mixedOffsets.size())) { + SmallVector dynamicOperands; + AffineExpr expr = rewriter.getAffineDimExpr(0); + unsigned numSymbols = 0; + dynamicOperands.push_back(useIndices[index]); + + // Multiply the stride; + if (auto attr = mixedStrides[index].dyn_cast()) { + expr = expr * attr.cast().getInt(); + } else { + dynamicOperands.push_back(mixedStrides[index].get()); + expr = expr * rewriter.getAffineSymbolExpr(numSymbols++); + } + + // Add the offset. + if (auto attr = mixedOffsets[index].dyn_cast()) { + expr = expr + attr.cast().getInt(); + } else { + dynamicOperands.push_back(mixedOffsets[index].get()); + expr = expr + rewriter.getAffineSymbolExpr(numSymbols++); + } + Location loc = subViewOp.getLoc(); + sourceIndices[index] = rewriter.create( + loc, AffineMap::get(1, numSymbols, expr), dynamicOperands); + } + return success(); +}