diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -560,11 +560,128 @@
     return failure();
   }
 };
+
+/// Canonicalize the iter_args of an scf::ForOp that involve a tensor_load and
+/// for which only the last loop iteration is actually visible outside of the
+/// loop. The canonicalization looks for a pattern such as:
+/// ```
+///    %t0 = ... : tensor_type
+///    %0 = scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
+///      ...
+///      // %m is either tensor_to_memref(%bb0) or defined above the loop
+///      %m... : memref_type
+///      ... // uses of %m with potential inplace updates
+///      %new_tensor = tensor_load %m : memref_type
+///      ...
+///      scf.yield %new_tensor : tensor_type
+///    }
+/// ```
+///
+/// `%bb0` may have either 0 or 1 use. If it has 1 use it must be exactly a
+/// `%m = tensor_to_memref %bb0` op that feeds into the yielded `tensor_load`
+/// op. If it has no use, `%m` must be defined above the loop.
+///
+/// If no aliasing write to `%m` occurs between the tensor_load and the yield,
+/// then the value %0 visible outside of the loop is the last `tensor_load`
+/// produced in the loop.
+///
+/// For now, we approximate the absence of aliasing by only supporting the case
+/// when the tensor_load is the operation immediately preceding the yield.
+///
+/// The canonicalization rewrites the pattern as:
+/// ```
+///    // %m is either tensor_to_memref(%t0) or defined above the loop
+///    %m... : memref_type
+///    scf.for ... { // no iter_args
+///      ... // uses of %m with potential inplace updates
+///    }
+///    %0 = tensor_load %m : memref_type
+/// ```
+///
+/// The pattern itself only makes the loop yield `%bb0` as is; the forwarded
+/// iter_arg and the now dead in-loop tensor_load then canonicalize away.
+struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
+  using OpRewritePattern<ForOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = forOp.getLoc();
+    // The terminator is the same for every iter_arg: look it up once.
+    auto yieldOp = cast<scf::YieldOp>(forOp.region().front().getTerminator());
+    bool changed = false;
+    for (BlockArgument bbArg : forOp.getRegionIterArgs()) {
+      // Position among the iter_args; it also indexes the init operands, the
+      // yielded values and the loop results.
+      unsigned idx = bbArg.getArgNumber() - /*numIv=*/1;
+      auto tensorLoadOp =
+          yieldOp->getOperand(idx).getDefiningOp<TensorLoadOp>();
+      if (!bbArg.getType().isa<TensorType>() || !tensorLoadOp)
+        continue;
+
+      // `bbArg` must have either no use or a single tensor_to_memref use.
+      TensorToMemrefOp tensorToMemRefOp;
+      if (bbArg.hasOneUse())
+        tensorToMemRefOp =
+            dyn_cast<TensorToMemrefOp>(*bbArg.getUsers().begin());
+      if (!bbArg.use_empty() && !tensorToMemRefOp)
+        continue;
+
+      Value memref = tensorLoadOp.memref();
+      if (tensorToMemRefOp) {
+        // The tensor_to_memref must feed the yielded tensor_load, otherwise
+        // hoisting it would change the buffer seen by its in-loop users.
+        if (memref != tensorToMemRefOp.memref())
+          continue;
+      } else if (forOp.region().isAncestor(memref.getParentRegion())) {
+        // %m must be defined above the loop so that the tensor_load emitted
+        // after the loop can still use it.
+        continue;
+      }
+
+      // TODO: Any aliasing write of tensorLoadOp.memref() nested under `forOp`
+      // must be before `tensorLoadOp` in the block so that the lastWrite
+      // property is not subject to additional side-effects.
+      // For now, we only support the case when tensorLoadOp appears
+      // immediately before the terminator.
+      if (tensorLoadOp->getNextNode() != yieldOp)
+        continue;
+
+      // Hoist the optional tensor_to_memref above the loop. It must read the
+      // init value: `bbArg` belongs to the loop body and does not dominate
+      // `forOp`.
+      if (tensorToMemRefOp) {
+        rewriter.setInsertionPoint(forOp);
+        rewriter.replaceOpWithNewOp<TensorToMemrefOp>(
+            tensorToMemRefOp, tensorToMemRefOp.memref().getType(),
+            forOp.getIterOperands()[idx]);
+        // `tensorLoadOp` now reads the hoisted op: refresh the stale value.
+        memref = tensorLoadOp.memref();
+      }
+
+      // Emit the last tensor_load after the loop and redirect all users of
+      // the matching loop result to it.
+      rewriter.setInsertionPointAfter(forOp);
+      Value newTensorLoad = rewriter.create<TensorLoadOp>(loc, memref);
+      Value forOpResult = forOp.getResult(idx);
+      for (OpOperand &use : llvm::make_early_inc_range(forOpResult.getUses()))
+        rewriter.updateRootInPlace(use.getOwner(),
+                                   [&]() { use.set(newTensorLoad); });
+
+      // Make the terminator just yield the bbArg, the old tensorLoadOp + the
+      // old bbArg (that is now directly yielded) will canonicalize away.
+      rewriter.updateRootInPlace(yieldOp,
+                                 [&]() { yieldOp->setOperand(idx, bbArg); });
+      changed = true;
+    }
+    return success(changed);
+  }
+};
 } // namespace
 
 void ForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
-  results.insert<ForOpIterArgsFolder, SimplifyTrivialLoops>(context);
+  results.insert<ForOpIterArgsFolder, SimplifyTrivialLoops,
+                 LastTensorLoadCanonicalization>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -1,4 +1,6 @@ -// RUN: mlir-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s +// RUN: mlir-opt %s -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s + +// ----- func @single_iteration(%A: memref) { %c0 = constant 0 : index @@ -143,6 +145,8 @@ // CHECK-NEXT: } // CHECK-NEXT: return %[[a]], %[[r1]], %[[b]] : i32, i32, i32 +// ----- + // CHECK-LABEL: @replace_true_if func @replace_true_if() { %true = constant true @@ -155,6 +159,8 @@ return } +// ----- + // CHECK-LABEL: @remove_false_if func @remove_false_if() { %false = constant false @@ -167,6 +173,8 @@ return } +// ----- + // CHECK-LABEL: @replace_true_if_with_values func @replace_true_if_with_values() { %true = constant true @@ -184,6 +192,8 @@ return } +// ----- + // CHECK-LABEL: @replace_false_if_with_values func @replace_false_if_with_values() { %false = constant false @@ -201,6 +211,8 @@ return } +// ----- + // CHECK-LABEL: @remove_zero_iteration_loop func @remove_zero_iteration_loop() { %c42 = constant 42 : index @@ -217,6 +229,8 @@ return } +// ----- + // CHECK-LABEL: @remove_zero_iteration_loop_vals func @remove_zero_iteration_loop_vals(%arg0: index) { %c2 = constant 2 : index @@ -233,6 +247,8 @@ return } +// ----- + // CHECK-LABEL: @replace_single_iteration_loop_1 func @replace_single_iteration_loop_1() { // CHECK: %[[LB:.*]] = constant 42 @@ -252,6 +268,8 @@ return } +// ----- + // 
CHECK-LABEL: @replace_single_iteration_loop_2 func @replace_single_iteration_loop_2() { // CHECK: %[[LB:.*]] = constant 5 @@ -271,6 +289,7 @@ return } +// ----- // CHECK-LABEL: @replace_single_iteration_loop_non_unit_step func @replace_single_iteration_loop_non_unit_step() { @@ -291,6 +310,8 @@ return } +// ----- + // CHECK-LABEL: @remove_empty_parallel_loop func @remove_empty_parallel_loop(%lb: index, %ub: index, %s: index) { // CHECK: %[[INIT:.*]] = "test.init" @@ -311,3 +332,43 @@ "test.consume"(%0) : (f32) -> () return } + +// ----- + +func private @process(%0 : memref<128x128xf32>) + +// CHECK-LABEL: last_value +// CHECK-SAME: %[[T0:[0-9a-z]*]]: tensor<128x128xf32> +// CHECK-SAME: %[[T1:[0-9a-z]*]]: tensor<128x128xf32> +// CHECK-SAME: %[[M0:[0-9a-z]*]]: memref<128x128xf32> +func @last_value(%t0: tensor<128x128xf32>, %t1: tensor<128x128xf32>, + %m0: memref<128x128xf32>, + %lb : index, %ub : index, %step : index) + -> (tensor<128x128xf32>, tensor<128x128xf32>) +{ + // CHECK-NEXT: %[[M1:.*]] = tensor_to_memref %[[T1]] : memref<128x128xf32> + // CHECK-NEXT: scf.for + // CHECK-NOT: iter_args + %0:2 = scf.for %arg0 = %lb to %ub step %step iter_args(%arg1 = %t0, %arg2 = %t1) + -> (tensor<128x128xf32>, tensor<128x128xf32>) + { + %m1 = tensor_to_memref %arg2 : memref<128x128xf32> + + // CHECK-NEXT: call @process(%[[M0]]) : (memref<128x128xf32>) -> () + call @process(%m0) : (memref<128x128xf32>) -> () + + // CHECK-NEXT: call @process(%[[M1]]) : (memref<128x128xf32>) -> () + call @process(%m1) : (memref<128x128xf32>) -> () + + // These in-loop tensor_loads and the iter_args they feed are folded away; the expected loads after the loop are checked below. + %1 = tensor_load %m0 : memref<128x128xf32> + %2 = tensor_load %m1 : memref<128x128xf32> + scf.yield %1, %2 : tensor<128x128xf32>, tensor<128x128xf32> + // CHECK-NEXT: } + } + + // CHECK-NEXT: %[[R0:.*]] = tensor_load %[[M0]] : memref<128x128xf32> + // CHECK-NEXT: %[[R1:.*]] = tensor_load %[[M1]] : memref<128x128xf32> + // CHECK-NEXT: return %[[R0]], %[[R1]] : tensor<128x128xf32>, tensor<128x128xf32> + return %0#0, %0#1
: tensor<128x128xf32>, tensor<128x128xf32> +}