diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -31,6 +31,7 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
@@ -2299,10 +2300,47 @@
 /// linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
 ///   tensor.dim %y, %c0 : tensor<...>
 /// }
+///
+/// Note: Dim ops are folded only if it can be proven that the runtime type of
+/// the yielded value (in case of outputs) does not change with loop iterations.
 template <typename OpTy>
 struct DimOfTiledLoopInsOutsFolder : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
 
+  /// A simple, conservative analysis to determine if the loop is shape
+  /// conserving. I.e., the type of the arg-th yielded value is the same as the
+  /// type of the corresponding basic block argument of the loop.
+  /// Note: This function handles only simple cases. Expand as needed.
+  static bool isShapePreserving(TiledLoopOp loopOp, int64_t arg) {
+    auto yieldOp = cast<YieldOp>(loopOp.getLoopBody().front().getTerminator());
+    if (yieldOp.values().empty())
+      // Tiled loop either has no outputs or is a "memref-based version". In
+      // either case, the loop is shape conserving.
+      return true;
+    assert(arg < static_cast<int64_t>(yieldOp.values().size()) &&
+           "arg is out of bounds");
+    Value value = yieldOp.values()[arg];
+    while (value) {
+      if (value == loopOp.getRegionOutputArgs()[arg])
+        return true;
+      OpResult opResult = value.dyn_cast<OpResult>();
+      if (!opResult)
+        return false;
+
+      using tensor::InsertSliceOp;
+      value = llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
+                  .template Case<InsertSliceOp>(
+                      [&](InsertSliceOp op) { return op.dest(); })
+                  .template Case<TiledLoopOp>([&](TiledLoopOp loopOp) {
+                    return isShapePreserving(loopOp, opResult.getResultNumber())
+                               ? loopOp.outputs()[opResult.getResultNumber()]
+                               : Value();
+                  })
+                  .Default([&](auto op) { return Value(); });
+    }
+    return false;
+  }
+
   LogicalResult matchAndRewrite(OpTy dimOp,
                                 PatternRewriter &rewriter) const final {
     auto src = dimOp.source().template dyn_cast<BlockArgument>();
@@ -2312,6 +2350,12 @@
         dyn_cast<TiledLoopOp>(src.getOwner()->getParent()->getParentOp());
     if (!loopOp)
       return failure();
+    unsigned numLoops = loopOp.getNumLoops();
+    unsigned numInputArgs = loopOp.getRegionInputArgs().size();
+    if (src.getArgNumber() >= numInputArgs + numLoops &&
+        !isShapePreserving(loopOp,
+                           src.getArgNumber() - numInputArgs - numLoops))
+      return failure();
 
     auto inputArgs = loopOp.getRegionInputArgs();
     auto it1 = llvm::find(inputArgs, src);
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -904,6 +904,34 @@
 
 // -----
 
+// CHECK-LABEL: func @dim_of_tiled_loop_input_no_canonicalize(
+// CHECK-SAME: %[[arg0:.*]]: tensor<?x?xf32>, %[[arg1:.*]]: tensor<?x?xf32>, %[[arg2:.*]]: tensor<?x?xf32>
+// CHECK: %[[c0:.*]] = constant 0 : index
+// CHECK: linalg.tiled_loop {{.*}} outs (%[[o:.*]] =
+// CHECK: %[[dim:.*]] = tensor.dim %[[o]], %[[c0]]
+// CHECK: index_cast %[[dim]]
+func @dim_of_tiled_loop_input_no_canonicalize(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>, %s: index)
+    -> tensor<?x?xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %r = linalg.tiled_loop (%iv0, %iv1) = (%c0, %c0)
+      to (%d0, %d1) step (%c1, %c1)
+      ins (%in0 = %arg0 : tensor<?x?xf32>, %in1 = %arg1 : tensor<?x?xf32>)
+      outs (%out1 = %arg2 : tensor<?x?xf32>) {
+    %inner_dim = tensor.dim %out1, %c0 : tensor<?x?xf32>
+    %cast1 = std.index_cast %inner_dim : index to i32
+    %cast2 = std.sitofp %cast1 : i32 to f32
+    %fill = linalg.fill(%cast2, %out1) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+    %slice = tensor.extract_slice %fill[0, 0][%s, %s][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+    linalg.yield %slice : tensor<?x?xf32>
+  }
+  return %r : tensor<?x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @dim_of_tiled_loop_input(
 // CHECK-SAME: %[[arg0:.*]]: tensor<?x?xf32>, %[[arg1:.*]]: tensor<?x?xf32>, %[[arg2:.*]]: tensor<?x?xf32>
 // CHECK: %[[c0:.*]] = constant 0 : index