diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -143,12 +143,17 @@
   /// `*this` and apply replace with `map` on its subexpressions.
   AffineExpr replace(const DenseMap<AffineExpr, AffineExpr> &map) const;
 
-  /// Replace dims[0 .. numDims - 1] by dims[shift .. shift + numDims - 1].
-  AffineExpr shiftDims(unsigned numDims, unsigned shift) const;
-
-  /// Replace symbols[0 .. numSymbols - 1] by
-  ///         symbols[shift .. shift + numSymbols - 1].
-  AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift) const;
+  /// Replace dims[offset ... numDims)
+  /// by dims[offset + shift ... shift + numDims).
+  /// dims[0 ... offset) are unchanged; requires offset <= numDims.
+  AffineExpr shiftDims(unsigned numDims, unsigned shift,
+                       unsigned offset = 0) const;
+
+  /// Replace symbols[offset ... numSymbols)
+  /// by symbols[offset + shift ... shift + numSymbols).
+  /// symbols[0 ... offset) are unchanged; requires offset <= numSymbols.
+  AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift,
+                          unsigned offset = 0) const;
 
   AffineExpr operator+(int64_t v) const;
   AffineExpr operator+(AffineExpr other) const;
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -207,24 +207,31 @@
   AffineMap replace(const DenseMap<AffineExpr, AffineExpr> &map,
                     unsigned numResultDims, unsigned numResultSyms) const;
 
-  /// Replace dims[0 .. numDims - 1] by dims[shift .. shift + numDims - 1].
-  AffineMap shiftDims(unsigned shift) const {
-    return AffineMap::get(
-        getNumDims() + shift, getNumSymbols(),
-        llvm::to_vector<4>(llvm::map_range(
-            getResults(),
-            [&](AffineExpr e) { return e.shiftDims(getNumDims(), shift); })),
-        getContext());
+  /// Replace dims[offset ... numDims)
+  /// by dims[offset + shift ... shift + numDims).
+  /// dims[0 ... offset) are unchanged; requires offset <= getNumDims().
+  AffineMap shiftDims(unsigned shift, unsigned offset = 0) const {
+    assert(offset <= getNumDims());
+    return AffineMap::get(getNumDims() + shift, getNumSymbols(),
+                          llvm::to_vector<4>(llvm::map_range(
+                              getResults(),
+                              [&](AffineExpr e) {
+                                return e.shiftDims(getNumDims(), shift, offset);
+                              })),
+                          getContext());
   }
 
-  /// Replace symbols[0 .. numSymbols - 1] by
-  ///         symbols[shift .. shift + numSymbols - 1].
-  AffineMap shiftSymbols(unsigned shift) const {
+  /// Replace symbols[offset ... numSymbols)
+  /// by symbols[offset + shift ... shift + numSymbols).
+  /// symbols[0 ... offset) are unchanged; requires offset <= getNumSymbols().
+  AffineMap shiftSymbols(unsigned shift, unsigned offset = 0) const {
+    assert(offset <= getNumSymbols());
     return AffineMap::get(getNumDims(), getNumSymbols() + shift,
                           llvm::to_vector<4>(llvm::map_range(
                               getResults(),
                               [&](AffineExpr e) {
-                                return e.shiftSymbols(getNumSymbols(), shift);
+                                return e.shiftSymbols(getNumSymbols(), shift,
+                                                      offset);
                               })),
                           getContext());
   }
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -101,19 +101,28 @@
   return replaceDimsAndSymbols({}, symReplacements);
 }
 
-/// Replace symbols[0 .. numDims - 1] by symbols[shift .. shift + numDims - 1].
-AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift) const {
+/// Replace dims[offset ... numDims)
+/// by dims[offset + shift ... shift + numDims).
+AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift,
+                                 unsigned offset) const {
+  assert(offset <= numDims);
   SmallVector<AffineExpr, 4> dims;
-  for (unsigned idx = 0; idx < numDims; ++idx)
+  for (unsigned idx = 0; idx < offset; ++idx)
+    dims.push_back(getAffineDimExpr(idx, getContext()));
+  for (unsigned idx = offset; idx < numDims; ++idx)
     dims.push_back(getAffineDimExpr(idx + shift, getContext()));
   return replaceDimsAndSymbols(dims, {});
 }
 
-/// Replace symbols[0 .. numSymbols - 1] by
-///         symbols[shift .. shift + numSymbols - 1].
-AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift) const {
+/// Replace symbols[offset ... numSymbols)
+/// by symbols[offset + shift ... shift + numSymbols).
+AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift,
+                                    unsigned offset) const {
+  assert(offset <= numSymbols);
   SmallVector<AffineExpr, 4> symbols;
-  for (unsigned idx = 0; idx < numSymbols; ++idx)
+  for (unsigned idx = 0; idx < offset; ++idx)
+    symbols.push_back(getAffineSymbolExpr(idx, getContext()));
+  for (unsigned idx = offset; idx < numSymbols; ++idx)
     symbols.push_back(getAffineSymbolExpr(idx + shift, getContext()));
   return replaceDimsAndSymbols({}, symbols);
 }