diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -272,25 +272,16 @@ /// Detensorize linalg ops involved in control-flow within a function. /// - /// This model starts from CondBranchOps within a function. For each cond_br, - /// the model then walks the use-def chain for the branch's condition - /// backwards in order to understand where the condition's value comes from. - /// If the condition value is (indirectly) computed by a linalg op that can be - /// detensored, the model then continues walking the use-def chain in order to - /// understand where the linalg op's operands come from. This leads to - /// discovering a "detensoring component". A detensoring component is the set - /// of operations + block arguments that are involved in control-flow AND can - /// be detensored. - /// - /// For examples where this model succeeds to discover a detensoring - /// component, see: - /// - test/Dialect/Linalg/detensorize_while.mlir - /// - test/Dialect/Linalg/detesorize_while_pure_cf.mlir. - /// - /// For an example where this model marks control-flow as "non-detensorable", - /// see: - /// - test/Dialect/Linalg/detensorize_while_failure.mlir - class PureControlFlowDetectionModel : public CostModel { + /// This model starts from BranchOps and CondBranchOps within a function. For + /// each such branch, the model then walks the use-def chain for the branch's + /// condition backwards in order to understand where the condition's value + /// comes from. If the condition value is (indirectly) computed by a linalg op + /// that can be detensored, the model then continues walking the use-def chain + /// in order to understand where the linalg op's operands come from. This + /// leads to discovering a "detensoring component". 
A detensoring component is + /// the set of operations + block arguments that are involved in control-flow + /// AND can be detensored. + class ControlFlowDetectionModel : public CostModel { public: void compute(FuncOp func, DetensorizeTypeConverter typeConverter, DenseSet &opsToDetensor, @@ -376,19 +367,19 @@ for (PredecessorIterator pred = ownerBlock->pred_begin(); pred != ownerBlock->pred_end(); ++pred) { - BranchOpInterface terminator = + BranchOpInterface predTerminator = dyn_cast((*pred)->getTerminator()); // TODO: For now, we give up if any of the control-flow components // in a function is not detensorable. Fix that. - if (!terminator) { + if (!predTerminator) { opsToDetensor.clear(); blockArgsToDetensor.clear(); return; } auto ownerBlockOperands = - terminator.getSuccessorOperands(pred.getSuccessorIndex()); + predTerminator.getSuccessorOperands(pred.getSuccessorIndex()); if (!ownerBlockOperands || ownerBlockOperands->empty()) continue; @@ -418,12 +409,10 @@ if (opsToDetensor.count(genericOp)) continue; - // TODO: For now, we give up if any of the control-flow components - // in a function is not detensorable. Fix that. + // The op should not be detensored, give up on it but continue with + // discovering the rest of the control-flow component. if (!shouldBeDetensored(genericOp, typeConverter)) { - opsToDetensor.clear(); - blockArgsToDetensor.clear(); - return; + continue; } opsToDetensor.insert(genericOp); @@ -452,6 +441,47 @@ for (Value scalarOpOperand : currentItemDefiningOp->getOperands()) workList.push_back(scalarOpOperand); } + + // Since the cost model gives up on some ops (see the details of step 2.2 + // above), block arguments that correspond to the values produced by those + // ops should not be detensored as well. 
+ + DenseSet blockArgsToRemove; + + for (auto &blockArg : blockArgsToDetensor) { + Block *block = blockArg.getParentBlock(); + + // For the potentially detensorable block argument, find the + // corresponding operands in predecessor blocks. + for (PredecessorIterator pred = block->pred_begin(); + pred != block->pred_end(); ++pred) { + BranchOpInterface terminator = + dyn_cast((*pred)->getTerminator()); + auto blockOperands = + terminator.getSuccessorOperands(pred.getSuccessorIndex()); + + if (!blockOperands || blockOperands->empty()) + continue; + + Operation *definingOp = + terminator + ->getOperand(blockOperands->getBeginOperandIndex() + + blockArg.getArgNumber()) + .getDefiningOp(); + + // If the operand is defined by a GenericOp that will not be + // detensored, then do not detensor the corresponding block argument. + if (dyn_cast_or_null(definingOp) && + opsToDetensor.count(definingOp) == 0) { + blockArgsToRemove.insert(blockArg); + break; + } + } + } + + for (auto &blockArg : blockArgsToRemove) { + blockArgsToDetensor.erase(blockArg); + } } }; @@ -487,7 +517,7 @@ blockArgsToDetensor); } else { - PureControlFlowDetectionModel costModel; + ControlFlowDetectionModel costModel; costModel.compute(getFunction(), typeConverter, opsToDetensor, blockArgsToDetensor); } diff --git a/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir rename from mlir/test/Dialect/Linalg/detensorize_while_failure.mlir rename to mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir --- a/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir @@ -93,15 +93,14 @@ // DET-ALL: return %{{.*}} : tensor // DET-ALL: } -// Try to detensor pure control-flow. However, that fails since the potential -// detensorable component contains some ops that cannot be detensored. 
-// // DET-CF-LABEL: func @main // DET-CF-SAME: (%{{.*}}: tensor<10xi32>, %{{.*}}: tensor) // DET-CF: br ^[[bb1:.*]](%{{.*}} : tensor<10xi32>) // DET-CF: ^bb1(%{{.*}}: tensor<10xi32>) // DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor) { -// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor, tensor) outs(%{{.*}} : tensor) { +// DET-CF: tensor.extract %{{.*}}[] : tensor +// DET-CF: tensor.extract %{{.*}}[] : tensor +// DET-CF: cmpi slt, %{{.*}}, %{{.*}} : i32 // DET-CF: cond_br %{{.*}}, ^bb2(%{{.*}} : tensor), ^bb3(%{{.*}} : tensor) // DET-CF: ^bb2(%{{.*}}: tensor) // DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor) outs(%{{.*}} : tensor<10xi32>) {