diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -31,6 +31,7 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
@@ -2302,10 +2303,54 @@
 /// linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
 ///   tensor.dim %y, %c0 : tensor<...>
 /// }
+///
+/// Note: Dim ops are folded only if it can be proven that the runtime type of
+/// the yielded value (in case of outputs) does not change with loop iterations.
 template <typename OpTy>
 struct DimOfTiledLoopInsOutsFolder : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
 
+  /// Conservatively determines whether the loop is shape preserving with
+  /// respect to its `arg`-th output, i.e., whether the `arg`-th yielded value
+  /// has the same type as the corresponding output block argument of the loop.
+  ///
+  /// The yielded value is traced backwards through its defining ops until the
+  /// `arg`-th region output argument is reached. Only `tensor::InsertSliceOp`
+  /// (followed through its destination) and nested `TiledLoopOp`s that are
+  /// themselves shape preserving (followed through the matching output
+  /// operand) are looked through. Any other producer, or a value that is not
+  /// an op result, makes the analysis return `false`.
+  ///
+  /// Note: This function handles only simple cases. Expand as needed.
+  static bool isShapePreserving(TiledLoopOp loopOp, int64_t arg) {
+    auto yieldOp = cast<YieldOp>(loopOp.getLoopBody().front().getTerminator());
+    assert(arg >= 0 && arg < static_cast<int64_t>(yieldOp.values().size()) &&
+           "arg is out of bounds");
+    Value value = yieldOp.values()[arg];
+    while (value) {
+      if (value == loopOp.getRegionOutputArgs()[arg])
+        return true;
+      // Only op results can be traced further back.
+      OpResult opResult = value.dyn_cast<OpResult>();
+      if (!opResult)
+        return false;
+
+      using tensor::InsertSliceOp;
+      // A null Value marks an unsupported producer and ends the loop.
+      value = llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
+                  .template Case<InsertSliceOp>(
+                      [&](InsertSliceOp op) { return op.dest(); })
+                  .template Case<TiledLoopOp>([&](TiledLoopOp nestedLoop) {
+                    unsigned resultNumber = opResult.getResultNumber();
+                    return isShapePreserving(nestedLoop, resultNumber)
+                               ? nestedLoop.outputs()[resultNumber]
+                               : Value();
+                  })
+                  .Default([](Operation *) { return Value(); });
+    }
+    return false;
+  }
+
   LogicalResult matchAndRewrite(OpTy dimOp,
                                 PatternRewriter &rewriter) const final {
     auto src = dimOp.source().template dyn_cast<BlockArgument>();
@@ -2315,6 +2360,12 @@
         dyn_cast<TiledLoopOp>(src.getOwner()->getParent()->getParentOp());
     if (!loopOp)
       return failure();
+    unsigned numLoops = loopOp.getNumLoops();
+    unsigned numInputArgs = loopOp.getRegionInputArgs().size();
+    if (src.getArgNumber() >= numInputArgs + numLoops &&
+        !isShapePreserving(loopOp,
+                           src.getArgNumber() - numInputArgs - numLoops))
+      return failure();
     auto inputArgs = loopOp.getRegionInputArgs();
     auto it1 = llvm::find(inputArgs, src);