diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -59,6 +59,12 @@
 /// single deallocate if it exists or nullptr.
 std::optional<Operation *> findDealloc(Value allocValue);
 
+/// Return the size of dimension `dim` of the given memref value. The result is
+/// an index attribute if that dimension is static; otherwise it is the result
+/// of a `memref.dim` op that is created (or folded) at `loc` using `builder`.
+OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value,
+                          int64_t dim);
+
 /// Return the dimensions of the given memref value.
 SmallVector<OpFoldResult> getMixedSizes(OpBuilder &builder, Location loc,
                                         Value value);
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -108,18 +108,22 @@
   return NoneType::get(type.getContext());
 }
 
+OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
+                                  int64_t dim) {
+  auto memrefType = llvm::cast<MemRefType>(value.getType());
+  // Dynamic sizes need a `memref.dim` op; static ones fold to an attribute.
+  if (memrefType.isDynamicDim(dim))
+    return builder.createOrFold<memref::DimOp>(loc, value, dim);
+
+  return builder.getIndexAttr(memrefType.getDimSize(dim));
+}
+
 SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
                                                 Location loc, Value value) {
   auto memrefType = llvm::cast<MemRefType>(value.getType());
   SmallVector<OpFoldResult> result;
-  for (int64_t i = 0; i < memrefType.getRank(); ++i) {
-    if (memrefType.isDynamicDim(i)) {
-      Value size = builder.create<memref::DimOp>(loc, value, i);
-      result.push_back(size);
-    } else {
-      result.push_back(builder.getIndexAttr(memrefType.getDimSize(i)));
-    }
-  }
+  for (int64_t i = 0; i < memrefType.getRank(); ++i)
+    result.push_back(getMixedSize(builder, loc, value, i));
   return result;
 }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -38,26 +38,6 @@
 using namespace mlir;
 using namespace mlir::vector;
 
-static std::optional<int64_t> extractConstantIndex(Value v) {
-  if (auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>())
-    return cstOp.value();
-  if (auto affineApplyOp = v.getDefiningOp<affine::AffineApplyOp>())
-    if (affineApplyOp.getAffineMap().isSingleConstant())
-      return affineApplyOp.getAffineMap().getSingleConstantResult();
-  return std::nullopt;
-}
-
-// Missing foldings of scf.if make it necessary to perform poor man's folding
-// eagerly, especially in the case of unrolling. In the future, this should go
-// away once scf.if folds properly.
-static Value createFoldedSLE(RewriterBase &b, Value v, Value ub) {
-  auto maybeCstV = extractConstantIndex(v);
-  auto maybeCstUb = extractConstantIndex(ub);
-  if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
-    return Value();
-  return b.create<arith::CmpIOp>(v.getLoc(), arith::CmpIPredicate::sle, v, ub);
-}
-
 /// Build the condition to ensure that a particular VectorTransferOpInterface
 /// is in-bounds.
 static Value createInBoundsCond(RewriterBase &b,
@@ -74,14 +54,20 @@
     // Fold or create the check that `index + vector_size` <= `memref_size`.
     Location loc = xferOp.getLoc();
     int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
-    auto d0 = getAffineDimExpr(0, xferOp.getContext());
-    auto vs = getAffineConstantExpr(vectorSize, xferOp.getContext());
-    Value sum = affine::makeComposedAffineApply(b, loc, d0 + vs,
-                                                {xferOp.indices()[indicesIdx]});
-    Value cond = createFoldedSLE(
-        b, sum, vector::createOrFoldDimOp(b, loc, xferOp.source(), indicesIdx));
-    if (!cond)
+    OpFoldResult sum = affine::makeComposedFoldedAffineApply(
+        b, loc, b.getAffineDimExpr(0) + b.getAffineConstantExpr(vectorSize),
+        {xferOp.indices()[indicesIdx]});
+    OpFoldResult dimSz =
+        memref::getMixedSize(b, loc, xferOp.source(), indicesIdx);
+    auto maybeCstSum = getConstantIntValue(sum);
+    auto maybeCstDimSz = getConstantIntValue(dimSz);
+    // Statically in-bounds along this dimension: no runtime check is needed.
+    if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
       return;
+    Value cond =
+        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle,
+                                getValueOrCreateConstantIndexOp(b, loc, sum),
+                                getValueOrCreateConstantIndexOp(b, loc, dimSz));
     // Conjunction over all dims for which we are in-bounds.
     if (inBoundsCond)
       inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
@@ -199,8 +185,8 @@
   auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
   xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
-    Value dimMemRef = vector::createOrFoldDimOp(b, xferOp.getLoc(),
-                                                xferOp.source(), indicesIdx);
+    Value dimMemRef =
+        b.create<memref::DimOp>(xferOp.getLoc(), xferOp.source(), indicesIdx);
     Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
     Value index = xferOp.indices()[indicesIdx];
     AffineExpr i, j, k;