diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -658,6 +658,23 @@
 /// `t` with simplified layout.
 MemRefType canonicalizeStridedLayout(MemRefType t);
 
+/// Given MemRef `sizes` that are either static or dynamic, returns the
+/// canonical "contiguous" strides AffineExpr. Strides are multiplicative:
+/// scanning from the innermost dimension outward, once a dynamic dimension is
+/// encountered, the strides of all outer dimensions become dynamic and each
+/// is encoded with a fresh symbol.
+/// A negative size denotes a dynamic dimension; a size-0 dimension is
+/// degenerate and contributes no term to the expression.
+/// For canonical stride expressions, the offset is always 0 and the fastest
+/// varying stride is always `1`.
+///
+/// Examples:
+/// - memref<3x4x5xf32> has canonical stride expression `20*d0 + 5*d1 + d2`.
+/// - memref<3x?x5xf32> has canonical stride expression `s0*d0 + 5*d1 + d2`.
+/// - memref<3x4x?xf32> has canonical stride expression `s1*d0 + s0*d1 + d2`.
+AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
+                                          MLIRContext *context);
+
 /// Return true if the layout for `t` is compatible with strided semantics.
 bool isStrided(MemRefType t);
 
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -456,49 +456,6 @@
   return success();
 }
 
-/// Given MemRef `sizes` that are either static or dynamic, returns the
-/// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
-/// once a dynamic dimension is encountered, all canonical strides become
-/// dynamic and need to be encoded with a different symbol.
-/// For canonical strides expressions, the offset is always 0 and and fastest
-/// varying stride is always `1`.
-///
-/// Examples:
-/// - memref<3x4x5xf32> has canonical stride expression `20*d0 + 5*d1 + d2`.
-/// - memref<3x?x5xf32> has canonical stride expression `s0*d0 + 5*d1 + d2`.
-/// - memref<3x4x?xf32> has canonical stride expression `s1*d0 + s0*d1 + d2`.
-static AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
-                                                 MLIRContext *context) {
-  AffineExpr expr;
-  bool dynamicPoisonBit = false;
-  unsigned nSymbols = 0;
-  int64_t runningSize = 1;
-  unsigned rank = sizes.size();
-  for (auto en : llvm::enumerate(llvm::reverse(sizes))) {
-    auto size = en.value();
-    auto position = rank - 1 - en.index();
-    // Degenerate case, no size =-> no stride
-    if (size == 0)
-      continue;
-    auto d = getAffineDimExpr(position, context);
-    // Static case: stride = runningSize and runningSize *= size.
-    if (!dynamicPoisonBit) {
-      auto cst = getAffineConstantExpr(runningSize, context);
-      expr = expr ? expr + cst * d : cst * d;
-      if (size > 0)
-        runningSize *= size;
-      else
-        // From now on bail into dynamic mode.
-        dynamicPoisonBit = true;
-      continue;
-    }
-    // Dynamic case, new symbol for each new stride.
-    auto sym = getAffineSymbolExpr(nSymbols++, context);
-    expr = expr ? expr + d * sym : d * sym;
-  }
-  return simplifyAffineExpr(expr, rank, nSymbols);
-}
-
 // Fallback cases for terminal dim/sym/cst that are not part of a binary op (
 // i.e. single term). Accumulate the AffineExpr into the existing one.
 static void extractStridesFromTerm(AffineExpr e,
@@ -766,6 +723,45 @@
   return MemRefType::Builder(t).setAffineMaps({});
 }
 
+/// Return the canonical contiguous strided layout expression for `sizes`:
+/// row-major strides, offset 0, fresh symbols once a dynamic size is seen.
+AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
+                                                MLIRContext *context) {
+  AffineExpr expr;
+  bool dynamicPoisonBit = false;
+  unsigned nSymbols = 0;
+  int64_t runningSize = 1;
+  unsigned rank = sizes.size();
+  for (auto en : llvm::enumerate(llvm::reverse(sizes))) {
+    auto size = en.value();
+    auto position = rank - 1 - en.index();
+    // Degenerate case: a size-0 dimension contributes no stride term.
+    if (size == 0)
+      continue;
+    auto d = getAffineDimExpr(position, context);
+    // Static case: stride = runningSize and runningSize *= size.
+    if (!dynamicPoisonBit) {
+      auto cst = getAffineConstantExpr(runningSize, context);
+      expr = expr ? expr + cst * d : cst * d;
+      if (size > 0)
+        runningSize *= size;
+      else
+        // From now on bail into dynamic mode.
+        dynamicPoisonBit = true;
+      continue;
+    }
+    // Dynamic case, new symbol for each new stride.
+    auto sym = getAffineSymbolExpr(nSymbols++, context);
+    expr = expr ? expr + d * sym : d * sym;
+  }
+  // No term was produced (`sizes` is empty or every size is 0), so `expr`
+  // is still null: return the constant-0 layout instead of passing a null
+  // expression to simplifyAffineExpr.
+  if (!expr)
+    return getAffineConstantExpr(0, context);
+  return simplifyAffineExpr(expr, rank, nSymbols);
+}
+
 /// Return true if the layout for `t` is compatible with strided semantics.
 bool mlir::isStrided(MemRefType t) {
   int64_t offset;