diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Traits.h" +#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" @@ -268,12 +269,72 @@ return success(); } }; + // Removes results from an assuming op that are not defined by operations in // its body (e.g. values from the enclosing scope): their users get the yielded // value directly, and the op is rebuilt to yield only body-defined values. struct AssumingBypassIndependentResult : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AssumingOp op, PatternRewriter &rewriter) const override { Block *body = op.getBody(); auto yieldOp = llvm::dyn_cast(body->getTerminator()); if (!yieldOp) return failure(); // See if there is at least one result that can bypass the assuming op. auto isDefinedInBody = [&](Value val) { Operation *def = val.getDefiningOp(); return def && def->getBlock() == body; }; if (llvm::all_of(yieldOp.operands(), isDefinedInBody)) return failure(); SmallVector replacementValues; auto newAssumingOp = rewriter.create( op.getLoc(), op.witness(), [&](OpBuilder &b, Location loc) { // Copy body. BlockAndValueMapping mapping; for (auto &nested : op.getBody()->without_terminator()) b.clone(nested, mapping); // Collect new yielded values. SmallVector mappedResults; for (auto result : yieldOp.getOperands()) { if (isDefinedInBody(result)) { // This value is a result of the assuming op. We can obtain the // replacement value only after the new op is fully constructed. mappedResults.push_back(mapping.lookup(result)); replacementValues.push_back(nullptr); } else { // When defined outside of the assuming block, we can use it // directly. There is no need to yield the value from within the // block.
replacementValues.push_back(result); } } return mappedResults; }); // Use the assuming op's results for the missing replacement values, which // could not bypass the op. auto src = newAssumingOp.getResults().begin(); for (auto &dst : replacementValues) { if (dst) continue; dst = *src++; } rewriter.replaceOp(op, replacementValues); return success(); } }; } // namespace void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, MLIRContext *context) { - // If taking a passing witness, inline region. - patterns.insert(context); + patterns.insert(context); } // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -1131,3 +1131,28 @@ "use"(%0) : (tensor) -> () return } + // ----- + // CHECK-LABEL: @bypass_assuming // CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>) func @bypass_assuming(%arg : tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) { // CHECK: %[[OUTER:.*]] = "some.tensor" // CHECK: %[[WITNESS:.*]] = "some.witness" // CHECK: %[[YIELDED:.*]] = shape.assuming %[[WITNESS]] -> (tensor<2x3xf32>) { // CHECK: %[[INNER:.*]] = "some.tensor" // CHECK: shape.assuming_yield %[[INNER]] : tensor<2x3xf32> // CHECK: } // CHECK: return %[[YIELDED]], %[[OUTER]], %[[ARG]] %outer = "some.tensor"() : () -> tensor<2x3xf32> %witness = "some.witness"() : () -> !shape.witness %results:3 = shape.assuming %witness -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) { %inner = "some.tensor"() : () -> tensor<2x3xf32> shape.assuming_yield %inner, %outer, %arg : tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32> } return %results#0, %results#1, %results#2 : tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32> }