diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -560,11 +560,137 @@
     return failure();
   }
 };
+
+/// Canonicalize the iter_args of an scf::ForOp that involve a tensor_load and
+/// for which only the last loop iteration is actually visible outside of the
+/// loop. The canonicalization looks for a pattern such as:
+/// ```
+///    %t0 = ... : tensor_type
+///    %0 = scf.for ... iter_args(%bb0 = %t0) -> (tensor_type) {
+///      ...
+///      // %m is either tensor_to_memref(%bb0) or defined above the loop
+///      %m... : memref_type
+///      ... // uses of %m with potential inplace updates
+///      %new_tensor = tensor_load %m : memref_type
+///      ...
+///      scf.yield %new_tensor : tensor_type
+///    }
+/// ```
+///
+/// `%bb0` may have either 0 or 1 use. If it has one use, that use must be a
+/// `%m = tensor_to_memref %bb0` op whose result feeds the yielded
+/// `tensor_load` op.
+///
+/// If no aliasing write to the memref `%m`, from which `%new_tensor` is
+/// loaded, occurs between the `tensor_load` and the `yield`, then the value
+/// `%0` visible outside of the loop is the last `tensor_load` produced inside
+/// the loop.
+///
+/// For now, we approximate the absence of aliasing by only supporting the case
+/// when the tensor_load is the operation immediately preceding the yield.
+///
+/// The canonicalization rewrites the pattern as:
+/// ```
+///    // %m is either a tensor_to_memref or defined above
+///    %m... : memref_type
+///    scf.for ... iter_args(%bb0 = %t0) -> (tensor_type) {
+///      ... // uses of %m with potential inplace updates
+///      scf.yield %bb0 : tensor_type
+///    }
+///    %0 = tensor_load %m : memref_type
+/// ```
+///
+/// A later bbArg canonicalization will further rewrite as:
+/// ```
+///    // %m is either a tensor_to_memref or defined above
+///    %m... : memref_type
+///    scf.for ... { // no iter_args
+///      ... // uses of %m with potential inplace updates
+///    }
+///    %0 = tensor_load %m : memref_type
+/// ```
+struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
+  using OpRewritePattern<ForOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    assert(std::next(forOp.region().begin()) == forOp.region().end() &&
+           "unexpected multiple blocks");
+
+    Location loc = forOp.getLoc();
+    DenseMap<Value, Value> replacements;
+    for (BlockArgument bbArg : forOp.getRegionIterArgs()) {
+      unsigned idx = bbArg.getArgNumber() - /*numIv=*/1;
+      auto yieldOp =
+          cast<scf::YieldOp>(forOp.region().front().getTerminator());
+      Value yieldVal = yieldOp->getOperand(idx);
+      auto tensorLoadOp = yieldVal.getDefiningOp<TensorLoadOp>();
+      bool isTensor = bbArg.getType().isa<TensorType>();
+
+      TensorToMemrefOp tensorToMemRefOp;
+      // Either bbArg has no use or it has a single tensor_to_memref use.
+      if (bbArg.hasOneUse())
+        tensorToMemRefOp =
+            dyn_cast<TensorToMemrefOp>(*bbArg.getUsers().begin());
+      if (!isTensor || !tensorLoadOp ||
+          (!bbArg.use_empty() && !tensorToMemRefOp))
+        continue;
+      // If tensorToMemRefOp is present, it must feed into the `tensorLoadOp`.
+      if (tensorToMemRefOp && tensorLoadOp.memref() != tensorToMemRefOp)
+        continue;
+      // TODO: Any aliasing write of tensorLoadOp.memref() nested under `forOp`
+      // must be before `tensorLoadOp` in the block so that the lastWrite
+      // property is not subject to additional side-effects.
+      // For now, we only support the case when tensorLoadOp appears
+      // immediately before the terminator.
+      if (tensorLoadOp->getNextNode() != yieldOp)
+        continue;
+
+      // Clone the optional tensorToMemRefOp before forOp.
+      if (tensorToMemRefOp) {
+        rewriter.setInsertionPoint(forOp);
+        rewriter.replaceOpWithNewOp<TensorToMemrefOp>(
+            tensorToMemRefOp, tensorToMemRefOp.memref().getType(),
+            tensorToMemRefOp.tensor());
+      }
+
+      // Clone the tensorLoad after forOp.
+      rewriter.setInsertionPointAfter(forOp);
+      Value newTensorLoad =
+          rewriter.create<TensorLoadOp>(loc, tensorLoadOp.memref());
+      Value forOpResult = forOp.getResult(bbArg.getArgNumber() - /*iv=*/1);
+      replacements.insert(std::make_pair(forOpResult, newTensorLoad));
+
+      // Make the terminator just yield the bbArg; the old tensorLoadOp and
+      // the old bbArg (that is now directly yielded) will canonicalize away.
+      rewriter.startRootUpdate(yieldOp);
+      yieldOp.setOperand(idx, bbArg);
+      rewriter.finalizeRootUpdate(yieldOp);
+    }
+    if (replacements.empty())
+      return failure();
+
+    // We want to replace a subset of the results of `forOp`. rewriter.replaceOp
+    // replaces the whole op and erases it unconditionally. This is wrong for
+    // `forOp` as it generally contains ops with side effects.
+    // Instead, use `rewriter.replaceOpWithIf`.
+    SmallVector<Value> newResults;
+    newResults.reserve(forOp.getNumResults());
+    for (Value v : forOp.getResults()) {
+      auto it = replacements.find(v);
+      newResults.push_back((it != replacements.end()) ? it->second : v);
+    }
+    unsigned idx = 0;
+    rewriter.replaceOpWithIf(forOp, newResults, [&](OpOperand &op) {
+      return op.get() != newResults[idx++];
+    });
+    return success();
+  }
+};
 } // namespace
 
 void ForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
-  results.insert<ForOpIterArgsFolder, SimplifyTrivialLoops>(context);
+  results.insert<ForOpIterArgsFolder, LastTensorLoadCanonicalization,
+                 SimplifyTrivialLoops>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1,4 +1,7 @@
-// RUN: mlir-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s
+
+
+// -----
 
 func @single_iteration(%A: memref<?x?x?xi32>) {
   %c0 = constant 0 : index
@@ -143,6 +146,8 @@
 // CHECK-NEXT: }
 // CHECK-NEXT: return %[[a]], %[[r1]], %[[b]] : i32, i32, i32
 
+// -----
+
 // CHECK-LABEL: @replace_true_if
 func @replace_true_if() {
   %true = constant true
@@ -155,6 +160,8 @@
   return
 }
 
+// -----
+
 // CHECK-LABEL: @remove_false_if
 func @remove_false_if() {
   %false = constant false
@@ -167,6 +174,8 @@
   return
 }
 
+// -----
+
 // CHECK-LABEL: @replace_true_if_with_values
 func @replace_true_if_with_values() {
   %true = constant true
@@ -184,6 +193,8 @@
   return
 }
 
+// -----
+
 // CHECK-LABEL: @replace_false_if_with_values
 func @replace_false_if_with_values() {
   %false = constant false
@@ -201,6 +212,8 @@
   return
 }
 
+// -----
+
 // CHECK-LABEL: @remove_zero_iteration_loop
 func @remove_zero_iteration_loop() {
   %c42 = constant 42 : index
@@ -217,6 +230,8 @@
   return
 }
 
+// -----
+
 // CHECK-LABEL: @remove_zero_iteration_loop_vals
 func @remove_zero_iteration_loop_vals(%arg0: index) {
   %c2 = constant 2 : index
@@ -233,6 +248,8 @@
   return
 }
 
+// -----
+
 // CHECK-LABEL: @replace_single_iteration_loop_1
 func @replace_single_iteration_loop_1() {
   // CHECK: %[[LB:.*]] = constant 42
@@ -252,6 +269,8 @@
   return
 }
 
+// -----
+
 // CHECK-LABEL: @replace_single_iteration_loop_2
 func @replace_single_iteration_loop_2() {
   // CHECK: %[[LB:.*]] = constant 5
@@ -271,6 +290,7 @@
   return
 }
 
+// -----
 // CHECK-LABEL: @replace_single_iteration_loop_non_unit_step
 func @replace_single_iteration_loop_non_unit_step() {
   // CHECK: %[[LB:.*]] = constant 42
@@ -291,6 +311,8 @@
   return
 }
 
+// -----
+
 // CHECK-LABEL: @remove_empty_parallel_loop
 func @remove_empty_parallel_loop(%lb: index, %ub: index, %s: index) {
   // CHECK: %[[INIT:.*]] = "test.init"
@@ -311,3 +333,52 @@
   "test.consume"(%0) : (f32) -> ()
   return
 }
+
+// -----
+func private @process(%0 : memref<128x128xf32>)
+func private @process_tensor(%0 : tensor<128x128xf32>) -> memref<128x128xf32>
+
+// CHECK-LABEL: last_value
+// CHECK-SAME: %[[T0:[0-9a-z]*]]: tensor<128x128xf32>
+// CHECK-SAME: %[[T1:[0-9a-z]*]]: tensor<128x128xf32>
+// CHECK-SAME: %[[T2:[0-9a-z]*]]: tensor<128x128xf32>
+// CHECK-SAME: %[[M0:[0-9a-z]*]]: memref<128x128xf32>
+func @last_value(%t0: tensor<128x128xf32>, %t1: tensor<128x128xf32>,
+                 %t2: tensor<128x128xf32>, %m0: memref<128x128xf32>,
+                 %lb : index, %ub : index, %step : index)
+  -> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>)
+{
+  // CHECK-NEXT: %[[M1:.*]] = tensor_to_memref %[[T1]] : memref<128x128xf32>
+  // CHECK-NEXT: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args(%[[BBARG_T2:.*]] = %[[T2]]) -> (tensor<128x128xf32>) {
+  %0:3 = scf.for %arg0 = %lb to %ub step %step iter_args(%arg1 = %t0, %arg2 = %t1, %arg3 = %t2)
+    -> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>)
+  {
+    %m1 = tensor_to_memref %arg2 : memref<128x128xf32>
+
+    // CHECK-NEXT: call @process(%[[M0]]) : (memref<128x128xf32>) -> ()
+    call @process(%m0) : (memref<128x128xf32>) -> ()
+
+    // CHECK-NEXT: call @process(%[[M1]]) : (memref<128x128xf32>) -> ()
+    call @process(%m1) : (memref<128x128xf32>) -> ()
+
+    // This does not hoist (fails the "bbArg has at most a single use" check).
+    // CHECK-NEXT: %[[T:.*]] = call @process_tensor(%[[BBARG_T2]]) : (tensor<128x128xf32>) -> memref<128x128xf32>
+    // CHECK-NEXT: %[[YIELD_T:.*]] = tensor_load %[[T:.*]]
+    %m2 = call @process_tensor(%arg3) : (tensor<128x128xf32>) -> memref<128x128xf32>
+    %3 = tensor_load %m2 : memref<128x128xf32>
+
+    // All this stuff goes away, incrementally.
+    %1 = tensor_load %m0 : memref<128x128xf32>
+    %2 = tensor_load %m1 : memref<128x128xf32>
+
+    // CHECK-NEXT: scf.yield %[[YIELD_T]] : tensor<128x128xf32>
+    scf.yield %1, %2, %3 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
+
+    // CHECK-NEXT: }
+  }
+
+  // CHECK-NEXT: %[[R0:.*]] = tensor_load %[[M0]] : memref<128x128xf32>
+  // CHECK-NEXT: %[[R1:.*]] = tensor_load %[[M1]] : memref<128x128xf32>
+  // CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
+  return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
+}
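For orientation only, and not part of the patch itself: below is a minimal sketch of the IR shape the new pattern targets, using made-up functions `@sketch` and `@update` and the same `tensor_load`/`tensor_to_memref` ops as the test above. Per the pattern's documentation, running it through `mlir-opt -canonicalize` should rewrite the `scf.yield` to yield `%bb0`, move the final `tensor_load` below the loop, and let the now-unused iter_arg fold away in a follow-up canonicalization.

```mlir
// Hypothetical input: %m is defined above the loop; the loop carries a tensor
// iter_arg only to re-yield the tensor_load of %m, which is the op
// immediately preceding the scf.yield.
func private @update(memref<8xf32>)

func @sketch(%m: memref<8xf32>, %t0: tensor<8xf32>,
             %lb: index, %ub: index, %step: index) -> tensor<8xf32> {
  %res = scf.for %i = %lb to %ub step %step iter_args(%bb0 = %t0) -> (tensor<8xf32>) {
    // In-place updates of %m.
    call @update(%m) : (memref<8xf32>) -> ()
    %t = tensor_load %m : memref<8xf32>
    scf.yield %t : tensor<8xf32>
  }
  return %res : tensor<8xf32>
}

// Expected shape (sketch) after LastTensorLoadCanonicalization plus the
// existing bbArg folding:
//   scf.for %i = %lb to %ub step %step {
//     call @update(%m) : (memref<8xf32>) -> ()
//   }
//   %res = tensor_load %m : memref<8xf32>
```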