diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -300,14 +300,56 @@ DenseSet visitedValues; DenseSet visitedOps; + // For a (to-be-detensored) value, check if it "escapes" the block by being + // passed to the terminator. If it does, then workList is updated with the + // corresponding argument to the successor block. + auto updateWorkListWithSuccessorArguments = + [&](Value value, BranchOpInterface terminator) { + if (!terminator) + return; + + for (auto operandIdx : + llvm::seq(0, terminator->getOperands().size())) { + Value operand = terminator->getOperand(operandIdx); + + if (operand == value) { + auto succBlockArg = + terminator.getSuccessorBlockArgument(operandIdx); + + if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg)) + workList.push_back(*succBlockArg); + } + } + }; + while (!workList.empty()) { Value currentItem = workList.pop_back_val(); if (!visitedValues.insert(currentItem).second) continue; - // The current item is defined by a block argument. - if (auto bbarg = currentItem.dyn_cast()) { + // 1 - Look forward: + // 1.1 - If currentItem escapes to one or more successors, add + // the corresponding successor arguments to workList. + updateWorkListWithSuccessorArguments( + currentItem, dyn_cast( + currentItem.getParentBlock()->getTerminator())); + + // 1.2 - For each user of currentItem, add the defined values to + // workList. This way, the user ops can be inspected later if they are + // detensorable and if so, their operands will be added to workList to + // potentially discover other parts of the detensorable component. + for (auto *user : currentItem.getUsers()) + for (Value result : user->getResults()) + workList.push_back(result); + + // 2 - Look backward: + // 2.1 - The current item is defined by a block argument. 
If the owner + // block is a non-entry one, then: + // * Add the argument to blockArgsToDetensor. + // * Walk the use-def chain backwards to add each predecessor's + // terminator-operands corresponding to currentItem to workList. + if (currentItem.dyn_cast()) { BlockArgument currentItemBlockArgument = currentItem.cast(); Block *ownerBlock = currentItemBlockArgument.getOwner(); @@ -354,7 +396,11 @@ if (!visitedOps.insert(currentItemDefiningOp).second) continue; - // The current item is computed by a GenericOp. + // 2.2 - The current item is computed by a GenericOp. If the op should + // be detensored, then: + // * Add it to opsToDetensor. + // * Add its operands to workList to discover other parts of the + // potentially detensorable component. if (auto genericOp = dyn_cast(currentItemDefiningOp)) { // The op was encountered already, no need to inspect it again. if (opsToDetensor.count(genericOp)) @@ -376,7 +422,7 @@ continue; } - // The current item is the result of a FromElemntsOp, it will be + // 2.3 - The current item is the result of a FromElementsOp, it will be // trivially detensored later as part of canonicalization patterns // applied at the end of detensoring. // @@ -386,8 +432,8 @@ if (dyn_cast(currentItemDefiningOp)) continue; - // The current item is the result of a scalar op, add all its operands - // to the work list. + // 2.4 - The current item is the result of a scalar op, add all its + // operands to the work list. if (llvm::all_of( currentItemDefiningOp->getResultTypes(), [&](Type resultType) { return resultType.isIntOrFloat(); })) @@ -442,8 +488,8 @@ target.addDynamicallyLegalOp([&](FuncOp op) { // A function is legal if all of its non-entry blocks are legal. We - // don't legalize the entry block (i.e. the function's signature) since - // detensoring can't happen along external calling convention + // don't legalize the entry block (i.e. 
the function's signature) + // since detensoring can't happen along external calling convention // boundaries, which we conservatively approximate as all function // signatures. return llvm::all_of(llvm::drop_begin(op.getBody(), 1), [&](Block &block) { diff --git a/mlir/test/Dialect/Linalg/detensorized_0d.mlir b/mlir/test/Dialect/Linalg/detensorize_0d.mlir rename from mlir/test/Dialect/Linalg/detensorized_0d.mlir rename to mlir/test/Dialect/Linalg/detensorize_0d.mlir diff --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir --- a/mlir/test/Dialect/Linalg/detensorize_if.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize | FileCheck %s +// RUN: mlir-opt %s -split-input-file -allow-unregistered-dialect -linalg-detensorize | FileCheck %s #map0 = affine_map<() -> ()> @@ -48,18 +48,149 @@ // CHECK-NEXT: constant 10 // CHECK-NEXT: br ^[[bb1:.*]](%{{.*}}: i32) // CHECK-NEXT: ^[[bb1]](%{{.*}}: i32): -// CHECK-NEXT: tensor.from_elements %{{.*}} -// CHECK-NEXT: linalg.tensor_reshape %{{.*}} // CHECK-NEXT: cmpi slt, %{{.*}}, %{{.*}} -// CHECK-NEXT: cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : tensor), ^bb3(%{{.*}} : tensor) -// CHECK-NEXT: ^[[bb2]](%{{.*}}: tensor) -// CHECK-NEXT: linalg.init_tensor -// CHECK-NEXT: linalg.generic -// CHECK-NEXT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32) -// CHECK-NEXT: addi %{{.*}}, %{{.*}} -// CHECK-NEXT: linalg.yield %{{.*}} -// CHECK-NEXT: } -> tensor -// CHECK-NEXT: br ^[[bb3:.*]](%{{.*}} : tensor) -// CHECK-NEXT: ^[[bb3]](%{{.*}}: tensor) +// CHECK-NEXT: cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32) +// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32) +// CHECK-NEXT: addi %{{.*}}, %{{.*}} +// CHECK-NEXT: br ^[[bb3:.*]](%{{.*}} : i32) +// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32) +// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<1xi32> +// CHECK-NEXT: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> 
into tensor +// CHECK-NEXT: return %{{.*}} +// CHECK-NEXT: } + +// ----- + +// Similar to the above test with one change: one of the blocks after the +// if-condition passes/forwards its tensor argument to another block. + +#map0 = affine_map<() -> ()> + +#attrs = { + indexing_maps = [#map0, #map0, #map0], + iterator_types = [] +} + +func @main() -> (tensor) attributes {} { + %c0 = constant 0 : i32 + %0 = tensor.from_elements %c0 : tensor<1xi32> + %reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor + %c10 = constant 10 : i32 + %1 = tensor.from_elements %c10 : tensor<1xi32> + %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor + br ^bb1(%reshaped0 : tensor) + +^bb1(%2: tensor): // 2 preds: ^bb0, ^bb2 + %3 = linalg.init_tensor [] : tensor + %4 = linalg.generic #attrs + ins(%2, %reshaped1 : tensor, tensor) + outs(%3 : tensor) { + ^bb0(%arg0: i32, %arg1: i32, %arg2: i1): // no predecessors + %8 = cmpi slt, %arg0, %arg1 : i32 + linalg.yield %8 : i1 + } -> tensor + %5 = tensor.extract %4[] : tensor + cond_br %5, ^bb2(%2 : tensor), ^bb3(%2 : tensor) + +^bb2(%6: tensor): // pred: ^bb1 + %7 = linalg.init_tensor [] : tensor + %8 = linalg.generic #attrs + ins(%6, %6 : tensor, tensor) + outs(%7 : tensor) { + ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): // no predecessors + %9 = addi %arg0, %arg1 : i32 + linalg.yield %9 : i32 + } -> tensor + br ^bb3(%8 : tensor) + +^bb3(%10: tensor): // 2 preds: ^bb1, ^bb2 + br ^bb4(%10 : tensor) + +^bb4(%11: tensor): // pred: ^bb3 + return %11 : tensor +} + +// CHECK-LABEL: func @main() +// CHECK-NEXT: constant 0 +// CHECK-NEXT: constant 10 +// CHECK-NEXT: br ^[[bb1:.*]](%{{.*}}: i32) +// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32): +// CHECK-NEXT: cmpi slt, %{{.*}}, %{{.*}} +// CHECK-NEXT: cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32) +// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32) +// CHECK-NEXT: addi %{{.*}}, %{{.*}} +// CHECK-NEXT: br ^[[bb3:.*]](%{{.*}} : i32) +// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32) +// CHECK-NEXT: br 
^[[bb4:.*]](%{{.*}} : i32) +// CHECK-NEXT: ^[[bb4]](%{{.*}}: i32) +// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<1xi32> +// CHECK-NEXT: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor +// CHECK-NEXT: return %{{.*}} +// CHECK-NEXT: } + +// ----- + +#map0 = affine_map<() -> ()> + +#attrs = { + indexing_maps = [#map0, #map0, #map0], + iterator_types = [] +} + +func @main() -> (tensor) attributes {} { + %c0 = constant 0 : i32 + %0 = tensor.from_elements %c0 : tensor<1xi32> + %reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor + %c10 = constant 10 : i32 + %1 = tensor.from_elements %c10 : tensor<1xi32> + %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor + br ^bb1(%reshaped0 : tensor) + +^bb1(%2: tensor): // 2 preds: ^bb0, ^bb2 + %3 = linalg.init_tensor [] : tensor + %4 = linalg.generic #attrs + ins(%2, %reshaped1 : tensor, tensor) + outs(%3 : tensor) { + ^bb0(%arg0: i32, %arg1: i32, %arg2: i1): // no predecessors + %8 = cmpi slt, %arg0, %arg1 : i32 + linalg.yield %8 : i1 + } -> tensor + %5 = tensor.extract %4[] : tensor + // This cond_br intentionally has bb2 as its target for both branches. This + // is to make sure that the "forward phase" of the cost-model correctly adds + // the users of a block argument (in this case bb2's argument) to the work + // list. 
+ cond_br %5, ^bb2(%2 : tensor), ^bb2(%2 : tensor) + +^bb2(%6: tensor): // pred: ^bb1 + %12 = tensor.from_elements %c10 : tensor<1xi32> + %reshaped12 = linalg.tensor_reshape %12 [] : tensor<1xi32> into tensor + %7 = linalg.init_tensor [] : tensor + %8 = linalg.generic #attrs + ins(%6, %reshaped12 : tensor, tensor) + outs(%7 : tensor) { + ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): // no predecessors + %9 = addi %arg0, %arg1 : i32 + linalg.yield %9 : i32 + } -> tensor + br ^bb3(%8 : tensor) + +^bb3(%10: tensor): // pred: ^bb1 + return %10 : tensor +} + +// CHECK-LABEL: func @main() +// CHECK-NEXT: constant 0 +// CHECK-NEXT: constant 10 +// CHECK-NEXT: br ^[[bb1:.*]](%{{.*}}: i32) +// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32): +// CHECK-NEXT: cmpi slt, %{{.*}}, %{{.*}} +// CHECK-NEXT: cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb2(%{{.*}} : i32) +// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32) +// CHECK-NEXT: addi %{{.*}}, %{{.*}} +// CHECK-NEXT: br ^[[bb3:.*]](%{{.*}} : i32) +// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32) +// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<1xi32> +// CHECK-NEXT: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor // CHECK-NEXT: return %{{.*}} // CHECK-NEXT: } diff --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir --- a/mlir/test/Dialect/Linalg/detensorize_while.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir @@ -62,12 +62,12 @@ // DET-CF: tensor.extract {{.*}} // DET-CF: br ^[[bb1:.*]](%{{.*}} : i32) // DET-CF: ^[[bb1]](%{{.*}}: i32) -// DET-CF-DAG tensor.from_elements {{.*}} -// DET-CF-DAG: linalg.tensor_reshape {{.*}} -// DET-CF-DAG: cmpi slt, {{.*}} -// DET-CF: cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : tensor) +// DET-CF: cmpi slt, {{.*}} +// DET-CF: cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32) // DET-CF: ^[[bb2]](%{{.*}}: i32) // DET-CF: addi {{.*}} // DET-CF: br ^[[bb1]](%{{.*}} : i32) -// DET-CF: ^[[bb3]](%{{.*}}: tensor) +// 
DET-CF: ^[[bb3]](%{{.*}}: i32) +// DET-CF: tensor.from_elements %{{.*}} : tensor<1xi32> +// DET-CF: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor // DET-CF: return %{{.*}} : tensor